19 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
30 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
31 return arith::ConstantOp::create(rewriter, loc,
34 return arith::ConstantOp::create(rewriter, loc, attr);
41 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
42 return arith::ConstantOp::create(rewriter, loc,
46 return arith::ConstantOp::create(rewriter, loc, attr);
51 if (
auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
52 return shapedTy.clone(cloneTo);
63 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
66 Value a = op.getLhs();
67 Value b = op.getRhs();
70 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero);
72 Value minusOne = arith::SubIOp::create(rewriter, loc, a, one);
73 Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne, b);
74 Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
75 rewriter.replaceOpWithNewOp<arith::SelectOp>(op,
compare, zero, plusOne);
89 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
92 Type type = op.getType();
93 Value a = op.getLhs();
94 Value b = op.getRhs();
99 Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
100 Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
101 Value notEqualDivisor = arith::CmpIOp::create(
102 rewriter, loc, arith::CmpIPredicate::ne, a,
product);
104 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
106 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
109 Value signEqual = arith::CmpIOp::create(
110 rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg);
112 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual);
114 Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
116 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
129 struct FloorDivSIOpConverter :
public OpRewritePattern<arith::FloorDivSIOp> {
131 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
134 Type type = op.getType();
135 Value a = op.getLhs();
136 Value b = op.getRhs();
138 Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
139 Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
140 Value notEqualDivisor = arith::CmpIOp::create(
141 rewriter, loc, arith::CmpIPredicate::ne, a,
product);
144 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
146 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
149 Value signOpposite = arith::CmpIOp::create(
150 rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg);
152 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite);
155 Value quotientMinusOne =
156 arith::AddIOp::create(rewriter, loc, quotient, minusOne);
158 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
164 template <
typename OpTy, arith::CmpIPredicate pred>
169 LogicalResult matchAndRewrite(OpTy op,
171 Value lhs = op.getLhs();
172 Value rhs = op.getRhs();
174 Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred, lhs, rhs);
175 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
180 template <
typename OpTy, arith::CmpFPredicate pred>
185 LogicalResult matchAndRewrite(OpTy op,
187 Value lhs = op.getLhs();
188 Value rhs = op.getRhs();
192 static_assert(pred == arith::CmpFPredicate::UGT ||
193 pred == arith::CmpFPredicate::ULT,
194 "pred must be either UGT or ULT");
195 Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
196 Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
199 Value isNaN = arith::CmpFOp::create(rewriter, loc,
200 arith::CmpFPredicate::UNO, rhs, rhs);
201 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
206 template <
typename OpTy, arith::CmpFPredicate pred>
211 LogicalResult matchAndRewrite(OpTy op,
213 Value lhs = op.getLhs();
214 Value rhs = op.getRhs();
218 static_assert(pred == arith::CmpFPredicate::UGT ||
219 pred == arith::CmpFPredicate::ULT,
220 "pred must be either UGT or ULT");
221 Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
222 Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
225 Value isNaN = arith::CmpFOp::create(rewriter, loc,
226 arith::CmpFPredicate::UNO, lhs, lhs);
227 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
234 LogicalResult matchAndRewrite(arith::ExtFOp op,
237 auto operand = op.getOperand();
238 Type operandTy = operand.getType();
239 Type resultTy = op.getType();
244 return rewriter.notifyMatchFailure(op,
"not a ext of bf16 to f32.");
250 Value bitcast = arith::BitcastOp::create(b, i16Ty, operand);
251 Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
254 Value shl = arith::ShLIOp::create(b, exti, c16);
255 Value result = arith::BitcastOp::create(b, resultTy, shl);
257 rewriter.replaceOp(op, result);
262 struct BFloat16TruncFOpConverter :
public OpRewritePattern<arith::TruncFOp> {
264 LogicalResult matchAndRewrite(arith::TruncFOp op,
267 auto operand = op.getOperand();
268 Type operandTy = operand.getType();
269 Type resultTy = op.getType();
274 return rewriter.notifyMatchFailure(op,
"not a trunc of f32 to bf16.");
277 if (op.getRoundingmodeAttr()) {
278 return rewriter.notifyMatchFailure(
279 op,
"only applicable to default rounding mode.");
299 arith::CmpFOp::create(b, arith::CmpFPredicate::UNE, operand, operand);
308 Value bitcast = arith::BitcastOp::create(b, i32Ty, operand);
311 arith::AndIOp::create(b, arith::ShRUIOp::create(b, bitcast, c16), c1);
314 Value roundingBias = arith::AddIOp::create(b, bit16, c7FFF);
321 Value biased = arith::AddIOp::create(b, bitcast, roundingBias);
324 Value biasedAndShifted = arith::ShRUIOp::create(b, biased, c16);
325 Value normalCaseResultI16 =
326 arith::TruncIOp::create(b, i16Ty, biasedAndShifted);
330 arith::SelectOp::create(b, isNan, c7FC0I16, normalCaseResultI16);
331 Value result = arith::BitcastOp::create(b, resultTy, select);
332 rewriter.replaceOp(op, result);
368 LogicalResult matchAndRewrite(arith::ExtFOp op,
372 Value operand = op.getOperand();
374 Type resultTy = op.getType();
378 if (!isa<Float4E2M1FNType>(operandETy))
379 return rewriter.notifyMatchFailure(op,
"not a ext of F4E2M1FN");
384 Value i4Bits = arith::BitcastOp::create(b, i4Ty, operand);
392 Value i4BitsNoSign = arith::AndIOp::create(b, i4Bits, c0x7);
396 Value bits1To24 = arith::ShLIOp::create(b, i4BitsNoSign, c0x2);
398 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
399 bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
400 bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
401 bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
408 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
410 arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
412 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x0);
413 bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31);
419 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x8);
421 arith::SelectOp::create(b, negative, c0x80000000, zeroExpBits);
424 Value bits1To31 = arith::AddIOp::create(b, bits1To24, bits25To31);
425 Value bits1To32 = arith::AddIOp::create(b, bits1To31, bit32);
426 Value result = arith::BitcastOp::create(b, f32Ty, bits1To32);
427 if (!isa<Float32Type>(resultETy))
428 result = arith::TruncFOp::create(b, resultTy, result);
430 rewriter.replaceOp(op, result);
437 LogicalResult matchAndRewrite(arith::ExtFOp op,
440 Value operand = op.getOperand();
442 Type resultTy = op.getType();
446 if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
447 return rewriter.notifyMatchFailure(op,
"not a ext of F8E8M0FNU");
454 Value bitcast = arith::BitcastOp::create(b, i8Ty, operand);
460 Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
461 Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth);
464 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
466 f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
467 Value result = arith::BitcastOp::create(b, f32Ty, f32Bits);
469 result = arith::TruncFOp::create(b, resultTy, result,
nullptr,
470 op.getFastmathAttr());
472 result = arith::ExtFOp::create(b, resultTy, result, op.getFastmathAttr());
474 rewriter.replaceOp(op, result);
509 LogicalResult matchAndRewrite(arith::TruncFOp op,
513 Value operand = op.getOperand();
515 Type resultTy = op.getType();
524 if (!isa<Float4E2M1FNType>(resultETy))
525 return rewriter.notifyMatchFailure(op,
"not a trunc of F4E2M1FN");
526 if (!isa<Float32Type>(operandETy))
527 operand = arith::ExtFOp::create(b, f32Ty, operand);
539 Value operandClamped = arith::MinNumFOp::create(b, cHigherBound, operand);
540 operandClamped = arith::MaxNumFOp::create(b, cLowerBound, operandClamped);
541 Value f32Bits = arith::BitcastOp::create(b, i32Ty, operandClamped);
545 Value f32Sign = arith::ShRUIOp::create(b, f32Bits, cF32ExpManWidth);
546 Value f4Sign = arith::TruncIOp::create(b, i4Ty, f32Sign);
547 Value f4Bits = arith::ShLIOp::create(b, f4Sign, c0x3);
551 Value cF4MantissaWidth = c0x1;
553 Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
554 Value biasAdjustedSignExp =
555 arith::SubIOp::create(b, f32SignExp, biasAdjustment);
556 Value f4Exp = arith::TruncIOp::create(b, i4Ty, biasAdjustedSignExp);
557 f4Exp = arith::ShLIOp::create(b, f4Exp, cF4MantissaWidth);
558 f4Bits = arith::AddIOp::create(b, f4Bits, f4Exp);
562 Value man1Bit = arith::AndIOp::create(b, f32Bits, cF32FirstBitMask);
563 man1Bit = arith::ShRUIOp::create(b, man1Bit, c0x00000016);
564 Value f4Man = arith::TruncIOp::create(b, i4Ty, man1Bit);
565 f4Bits = arith::AddIOp::create(b, f4Bits, f4Man);
569 Value f8Exp = arith::TruncIOp::create(b, i8Ty, biasAdjustedSignExp);
571 arith::CmpIOp::create(b, arith::CmpIPredicate::sle, f8Exp, c0x00);
573 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0xff);
574 Value man23Bits = arith::AndIOp::create(b, f32Bits, cF32MantissaMask);
575 Value isNonZeroMan = arith::CmpIOp::create(b, arith::CmpIPredicate::ugt,
576 man23Bits, zeroExpBits);
577 Value roundToHalf = arith::AndIOp::create(b, isNegOneExp, isNonZeroMan);
579 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0x00);
583 arith::SelectOp::create(b, isSubnormal, subnormalF4Bits, f4Bits);
584 subResult = arith::SelectOp::create(b, roundToHalf, halfF4Bits, subResult);
585 f4Bits = arith::SelectOp::create(b, isZeroExp, f4Bits, subResult);
590 Value man22Bits = arith::AndIOp::create(b, f32Bits, cF32Last22BitMask);
592 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, man22Bits, cRound);
593 shouldRound = arith::OrIOp::create(b, shouldRound, isSubnormal);
594 Value roundedF4Bits = arith::AddIOp::create(b, f4Bits, c0x1);
595 f4Bits = arith::SelectOp::create(b, shouldRound, roundedF4Bits, f4Bits);
597 Value result = arith::BitcastOp::create(b, resultTy, f4Bits);
598 rewriter.replaceOp(op, result);
610 LogicalResult matchAndRewrite(arith::TruncFOp op,
613 Value operand = op.getOperand();
616 Type resultTy = op.getType();
618 if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
619 return rewriter.notifyMatchFailure(op,
"not a truncf to f8E8M0FNU");
622 if (op.getRoundingmodeAttr()) {
623 return rewriter.notifyMatchFailure(
624 op,
"only applicable to default rounding mode.");
632 operand = arith::ExtFOp::create(b, f32Ty, operand, op.getFastmathAttr());
634 operand = arith::TruncFOp::create(
635 b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
637 Value f32Bits = arith::BitcastOp::create(b, i32Ty, operand);
639 Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
640 Value exp8Bits = arith::TruncIOp::create(b, i8Ty, f32SignExp);
641 Value result = arith::BitcastOp::create(b, resultTy, exp8Bits);
642 rewriter.replaceOp(op, result);
647 struct ScalingExtFOpConverter :
public OpRewritePattern<arith::ScalingExtFOp> {
649 LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
652 Value inputOperand = op.getIn();
653 Value scaleOperand = op.getScale();
658 scaleETy = b.getF8E8M0Type();
660 scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand,
nullptr,
661 op.getFastmathAttr());
664 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
665 return rewriter.notifyMatchFailure(
666 op,
"scaling_extf is using scales of type which can not be converted "
669 Type resultTy = op.getType();
673 arith::ExtFOp::create(b, resultTy, scaleOperand, op.getFastmathAttr());
675 arith::ExtFOp::create(b, resultTy, inputOperand, op.getFastmathAttr());
677 arith::MulFOp::create(b, inputExt, scaleExt, op.getFastmathAttr());
678 rewriter.replaceOp(op, result);
688 struct ScalingTruncFOpConverter
691 LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
694 Value inputOperand = op.getIn();
695 Value scaleOperand = op.getScale();
700 scaleETy = b.getF8E8M0Type();
702 scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand,
nullptr,
703 op.getFastmathAttr());
705 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
706 return rewriter.notifyMatchFailure(
707 op,
"scaling_truncf is using scales type which can not be converted "
710 Type resultTy = op.getType();
715 arith::ExtFOp::create(b, inputTy, scaleOperand, op.getFastmathAttr());
716 Value result = arith::DivFOp::create(b, inputOperand, scaleOperand,
717 op.getFastmathAttr());
718 Value resultCast = arith::TruncFOp::create(
719 b, resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
720 rewriter.replaceOp(op, resultCast);
725 struct ArithExpandOpsPass
726 :
public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
727 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
729 void runOnOperation()
override {
735 target.addLegalDialect<arith::ArithDialect>();
736 target.addLegalDialect<vector::VectorDialect>();
751 arith::ScalingExtFOp,
752 arith::ScalingTruncFOp
762 target.addDynamicallyLegalOp<arith::ExtFOp>(
763 [=](arith::ExtFOp op) {
766 bool legalTypes =
true;
770 legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
772 legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
776 target.addDynamicallyLegalOp<arith::TruncFOp>(
777 [=](arith::TruncFOp op) {
780 bool legalTypes =
true;
784 legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
786 legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
802 .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
807 patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
812 patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
817 patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
823 patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
832 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
833 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
834 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
835 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
836 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
837 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
838 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
839 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns)
Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateExpandF4E2M1Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...