19 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
30 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
31 return arith::ConstantOp::create(rewriter, loc,
34 return arith::ConstantOp::create(rewriter, loc, attr);
41 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
42 return arith::ConstantOp::create(rewriter, loc,
46 return arith::ConstantOp::create(rewriter, loc, attr);
51 if (
auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
52 return shapedTy.clone(cloneTo);
63 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
66 Value a = op.getLhs();
67 Value b = op.getRhs();
70 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero);
72 Value minusOne = arith::SubIOp::create(rewriter, loc, a, one);
73 Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne, b);
74 Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
75 rewriter.replaceOpWithNewOp<arith::SelectOp>(op,
compare, zero, plusOne);
89 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
92 Type type = op.getType();
93 Value a = op.getLhs();
94 Value b = op.getRhs();
99 Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
100 Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
101 Value notEqualDivisor = arith::CmpIOp::create(
102 rewriter, loc, arith::CmpIPredicate::ne, a,
product);
104 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
106 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
109 Value signEqual = arith::CmpIOp::create(
110 rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg);
112 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual);
114 Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
116 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
129 struct FloorDivSIOpConverter :
public OpRewritePattern<arith::FloorDivSIOp> {
131 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
134 Type type = op.getType();
135 Value a = op.getLhs();
136 Value b = op.getRhs();
138 Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
139 Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
140 Value notEqualDivisor = arith::CmpIOp::create(
141 rewriter, loc, arith::CmpIPredicate::ne, a,
product);
144 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
146 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
149 Value signOpposite = arith::CmpIOp::create(
150 rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg);
152 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite);
155 Value quotientMinusOne =
156 arith::AddIOp::create(rewriter, loc, quotient, minusOne);
158 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
164 template <
typename OpTy, arith::CmpIPredicate pred>
169 LogicalResult matchAndRewrite(OpTy op,
171 Value lhs = op.getLhs();
172 Value rhs = op.getRhs();
174 Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred, lhs, rhs);
175 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
180 template <
typename OpTy, arith::CmpFPredicate pred>
185 LogicalResult matchAndRewrite(OpTy op,
187 Value lhs = op.getLhs();
188 Value rhs = op.getRhs();
192 static_assert(pred == arith::CmpFPredicate::UGT ||
193 pred == arith::CmpFPredicate::ULT,
194 "pred must be either UGT or ULT");
195 Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
196 Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
199 Value isNaN = arith::CmpFOp::create(rewriter, loc,
200 arith::CmpFPredicate::UNO, rhs, rhs);
201 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
206 template <
typename OpTy, arith::CmpFPredicate pred>
211 LogicalResult matchAndRewrite(OpTy op,
213 Value lhs = op.getLhs();
214 Value rhs = op.getRhs();
218 static_assert(pred == arith::CmpFPredicate::UGT ||
219 pred == arith::CmpFPredicate::ULT,
220 "pred must be either UGT or ULT");
221 Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
222 Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
225 Value isNaN = arith::CmpFOp::create(rewriter, loc,
226 arith::CmpFPredicate::UNO, lhs, lhs);
227 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
234 LogicalResult matchAndRewrite(arith::ExtFOp op,
237 auto operand = op.getOperand();
238 Type operandTy = operand.getType();
239 Type resultTy = op.getType();
244 return rewriter.notifyMatchFailure(op,
"not a ext of bf16 to f32.");
250 Value bitcast = arith::BitcastOp::create(b, i16Ty, operand);
251 Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
254 Value shl = arith::ShLIOp::create(b, exti, c16);
255 Value result = arith::BitcastOp::create(b, resultTy, shl);
257 rewriter.replaceOp(op, result);
262 struct BFloat16TruncFOpConverter :
public OpRewritePattern<arith::TruncFOp> {
264 LogicalResult matchAndRewrite(arith::TruncFOp op,
267 auto operand = op.getOperand();
268 Type operandTy = operand.getType();
269 Type resultTy = op.getType();
274 return rewriter.notifyMatchFailure(op,
"not a trunc of f32 to bf16.");
277 if (op.getRoundingmodeAttr()) {
278 return rewriter.notifyMatchFailure(
279 op,
"only applicable to default rounding mode.");
299 arith::CmpFOp::create(b, arith::CmpFPredicate::UNE, operand, operand);
308 Value bitcast = arith::BitcastOp::create(b, i32Ty, operand);
311 arith::AndIOp::create(b, arith::ShRUIOp::create(b, bitcast, c16), c1);
314 Value roundingBias = arith::AddIOp::create(b, bit16, c7FFF);
321 Value biased = arith::AddIOp::create(b, bitcast, roundingBias);
324 Value biasedAndShifted = arith::ShRUIOp::create(b, biased, c16);
325 Value normalCaseResultI16 =
326 arith::TruncIOp::create(b, i16Ty, biasedAndShifted);
330 arith::SelectOp::create(b, isNan, c7FC0I16, normalCaseResultI16);
331 Value result = arith::BitcastOp::create(b, resultTy, select);
332 rewriter.replaceOp(op, result);
368 LogicalResult matchAndRewrite(arith::ExtFOp op,
372 Value operand = op.getOperand();
374 Type resultTy = op.getType();
378 if (!isa<Float4E2M1FNType>(operandETy))
379 return rewriter.notifyMatchFailure(op,
"not a ext of F4E2M1FN");
384 Value i4Bits = arith::BitcastOp::create(b, i4Ty, operand);
393 Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2);
395 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1);
396 bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
397 bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
398 bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
405 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4);
407 arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
409 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x0);
410 bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31);
416 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x8);
418 arith::SelectOp::create(b, negative, c0x80000000, zeroExpBits);
421 Value bits1To31 = arith::AddIOp::create(b, bits1To24, bits25To31);
422 Value bits1To32 = arith::AddIOp::create(b, bits1To31, bit32);
423 Value result = arith::BitcastOp::create(b, f32Ty, bits1To32);
424 if (!isa<Float32Type>(resultETy))
425 result = arith::TruncFOp::create(b, resultTy, result);
427 rewriter.replaceOp(op, result);
434 LogicalResult matchAndRewrite(arith::ExtFOp op,
437 Value operand = op.getOperand();
439 Type resultTy = op.getType();
443 if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
444 return rewriter.notifyMatchFailure(op,
"not a ext of F8E8M0FNU");
451 Value bitcast = arith::BitcastOp::create(b, i8Ty, operand);
457 Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
458 Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth);
461 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
463 f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
464 Value result = arith::BitcastOp::create(b, f32Ty, f32Bits);
466 result = arith::TruncFOp::create(b, resultTy, result,
nullptr,
467 op.getFastmathAttr());
469 result = arith::ExtFOp::create(b, resultTy, result, op.getFastmathAttr());
471 rewriter.replaceOp(op, result);
506 LogicalResult matchAndRewrite(arith::TruncFOp op,
510 Value operand = op.getOperand();
512 Type resultTy = op.getType();
521 if (!isa<Float4E2M1FNType>(resultETy))
522 return rewriter.notifyMatchFailure(op,
"not a trunc of F4E2M1FN");
523 if (!isa<Float32Type>(operandETy))
524 operand = arith::ExtFOp::create(b, f32Ty, operand);
536 Value operandClamped = arith::MinNumFOp::create(b, cHigherBound, operand);
537 operandClamped = arith::MaxNumFOp::create(b, cLowerBound, operandClamped);
538 Value f32Bits = arith::BitcastOp::create(b, i32Ty, operandClamped);
542 Value f32Sign = arith::ShRUIOp::create(b, f32Bits, cF32ExpManWidth);
543 Value f4Sign = arith::TruncIOp::create(b, i4Ty, f32Sign);
544 Value f4Bits = arith::ShLIOp::create(b, f4Sign, c0x3);
548 Value cF4MantissaWidth = c0x1;
550 Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
551 Value biasAdjustedSignExp =
552 arith::SubIOp::create(b, f32SignExp, biasAdjustment);
553 Value f4Exp = arith::TruncIOp::create(b, i4Ty, biasAdjustedSignExp);
554 f4Exp = arith::ShLIOp::create(b, f4Exp, cF4MantissaWidth);
555 f4Bits = arith::AddIOp::create(b, f4Bits, f4Exp);
559 Value man1Bit = arith::AndIOp::create(b, f32Bits, cF32FirstBitMask);
560 man1Bit = arith::ShRUIOp::create(b, man1Bit, c0x00000016);
561 Value f4Man = arith::TruncIOp::create(b, i4Ty, man1Bit);
562 f4Bits = arith::AddIOp::create(b, f4Bits, f4Man);
566 Value f8Exp = arith::TruncIOp::create(b, i8Ty, biasAdjustedSignExp);
568 arith::CmpIOp::create(b, arith::CmpIPredicate::sle, f8Exp, c0x00);
570 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0xff);
571 Value man23Bits = arith::AndIOp::create(b, f32Bits, cF32MantissaMask);
572 Value isNonZeroMan = arith::CmpIOp::create(b, arith::CmpIPredicate::ugt,
573 man23Bits, zeroExpBits);
574 Value roundToHalf = arith::AndIOp::create(b, isNegOneExp, isNonZeroMan);
576 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0x00);
580 arith::SelectOp::create(b, isSubnormal, subnormalF4Bits, f4Bits);
581 subResult = arith::SelectOp::create(b, roundToHalf, halfF4Bits, subResult);
582 f4Bits = arith::SelectOp::create(b, isZeroExp, f4Bits, subResult);
587 Value man22Bits = arith::AndIOp::create(b, f32Bits, cF32Last22BitMask);
589 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, man22Bits, cRound);
590 shouldRound = arith::OrIOp::create(b, shouldRound, isSubnormal);
591 Value roundedF4Bits = arith::AddIOp::create(b, f4Bits, c0x1);
592 f4Bits = arith::SelectOp::create(b, shouldRound, roundedF4Bits, f4Bits);
594 Value result = arith::BitcastOp::create(b, resultTy, f4Bits);
595 rewriter.replaceOp(op, result);
607 LogicalResult matchAndRewrite(arith::TruncFOp op,
610 Value operand = op.getOperand();
613 Type resultTy = op.getType();
615 if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
616 return rewriter.notifyMatchFailure(op,
"not a truncf to f8E8M0FNU");
619 if (op.getRoundingmodeAttr()) {
620 return rewriter.notifyMatchFailure(
621 op,
"only applicable to default rounding mode.");
629 operand = arith::ExtFOp::create(b, f32Ty, operand, op.getFastmathAttr());
631 operand = arith::TruncFOp::create(
632 b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
634 Value f32Bits = arith::BitcastOp::create(b, i32Ty, operand);
636 Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
637 Value exp8Bits = arith::TruncIOp::create(b, i8Ty, f32SignExp);
638 Value result = arith::BitcastOp::create(b, resultTy, exp8Bits);
639 rewriter.replaceOp(op, result);
644 struct ScalingExtFOpConverter :
public OpRewritePattern<arith::ScalingExtFOp> {
646 LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
649 Value inputOperand = op.getIn();
650 Value scaleOperand = op.getScale();
655 scaleETy = b.getF8E8M0Type();
657 scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand,
nullptr,
658 op.getFastmathAttr());
661 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
662 return rewriter.notifyMatchFailure(
663 op,
"scaling_extf is using scales of type which can not be converted "
666 Type resultTy = op.getType();
670 arith::ExtFOp::create(b, resultTy, scaleOperand, op.getFastmathAttr());
672 arith::ExtFOp::create(b, resultTy, inputOperand, op.getFastmathAttr());
674 arith::MulFOp::create(b, inputExt, scaleExt, op.getFastmathAttr());
675 rewriter.replaceOp(op, result);
685 struct ScalingTruncFOpConverter
688 LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
691 Value inputOperand = op.getIn();
692 Value scaleOperand = op.getScale();
697 scaleETy = b.getF8E8M0Type();
699 scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand,
nullptr,
700 op.getFastmathAttr());
702 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
703 return rewriter.notifyMatchFailure(
704 op,
"scaling_truncf is using scales type which can not be converted "
707 Type resultTy = op.getType();
712 arith::ExtFOp::create(b, inputTy, scaleOperand, op.getFastmathAttr());
713 Value result = arith::DivFOp::create(b, inputOperand, scaleOperand,
714 op.getFastmathAttr());
715 Value resultCast = arith::TruncFOp::create(
716 b, resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
717 rewriter.replaceOp(op, resultCast);
722 struct ArithExpandOpsPass
723 :
public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
724 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
726 void runOnOperation()
override {
732 target.addLegalDialect<arith::ArithDialect>();
733 target.addLegalDialect<vector::VectorDialect>();
748 arith::ScalingExtFOp,
749 arith::ScalingTruncFOp
759 target.addDynamicallyLegalOp<arith::ExtFOp>(
760 [=](arith::ExtFOp op) {
763 bool legalTypes =
true;
767 legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
769 legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
773 target.addDynamicallyLegalOp<arith::TruncFOp>(
774 [=](arith::TruncFOp op) {
777 bool legalTypes =
true;
781 legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
783 legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
799 .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
804 patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
809 patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
814 patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
820 patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
829 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
830 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
831 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
832 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
833 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
834 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
835 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
836 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns)
Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateExpandF4E2M1Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.