#include "llvm/ADT/SmallVectorExtras.h"

#define GEN_PASS_DEF_ARITHEXPANDOPSPASS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
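// The constant helpers below materialize integer and float constants. When the
// requested type is shaped (vector/tensor), the scalar attribute is splatted
// via DenseElementsAttr::get; scalar types fall through to a plain
// arith.constant of the element type.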
  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
    return rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(shapedTy, attr));
  }
  return rewriter.create<arith::ConstantOp>(loc, attr);
  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
    return rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(shapedTy, attr));
  }
  return rewriter.create<arith::ConstantOp>(loc, attr);
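// cloneToShapedType: if `cloneFrom` is shaped, clone its shape with `cloneTo`
// as the new element type; a scalar `cloneFrom` simply yields `cloneTo`.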
  if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
    return shapedTy.clone(cloneTo);
  }
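// Expands arith.ceildivui(n, m) into the overflow-safe select form
//   n == 0 ? 0 : ((n - 1) / m) + 1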
  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value a = op.getLhs();
    Value b = op.getRhs();
    Value zero = createConst(loc, a.getType(), 0, rewriter);
    Value compare =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
    Value one = createConst(loc, a.getType(), 1, rewriter);
    Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
    Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
    Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
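// Expands arith.ceildivsi(a, b): compute q = a / b, then add one when the
// division was inexact and both operands have the same sign:
//   (q * b != a && (a < 0) == (b < 0)) ? q + 1 : q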
  LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Type type = op.getType();
    Value a = op.getLhs();
    Value b = op.getRhs();
    Value zero = createConst(loc, type, 0, rewriter);
    Value one = createConst(loc, type, 1, rewriter);

    Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
    Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
    Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ne, a, product);

    Value aNeg =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
    Value bNeg =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);

    Value signEqual = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, aNeg, bNeg);
    Value cond =
        rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual);

    Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);

    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
                                                 quotient);
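// Expands arith.floordivsi(a, b): compute q = a / b, then subtract one when
// the division was inexact and the operands have opposite signs:
//   (q * b != a && (a < 0) != (b < 0)) ? q - 1 : q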
struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Type type = op.getType();
    Value a = op.getLhs();
    Value b = op.getRhs();

    Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
    Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
    Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ne, a, product);
    Value zero = createConst(loc, type, 0, rewriter);

    Value aNeg =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
    Value bNeg =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);

    Value signOpposite = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ne, aNeg, bNeg);
    Value cond =
        rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);

    Value minusOne = createConst(loc, type, -1, rewriter);
    Value quotientMinusOne =
        rewriter.create<arith::AddIOp>(loc, quotient, minusOne);

    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
                                                 quotient);
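// Integer max/min ops become a compare-and-select; the predicate (sgt/ugt for
// max, slt/ult for min) is supplied as a template argument when the pattern is
// registered.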
template <typename OpTy, arith::CmpIPredicate pred>
struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const final {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();

    Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
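// arith.maximumf/arith.minimumf propagate NaN. With an unordered predicate
// (UGT/ULT) the first select already picks lhs whenever either operand is NaN;
// the second select then forwards rhs when rhs itself is NaN, so any NaN input
// produces NaN.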
template <typename OpTy, arith::CmpFPredicate pred>
struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const final {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();

    Location loc = op.getLoc();
    // If any operand is NaN, 'cmp' is true and 'select' picks lhs.
    static_assert(pred == arith::CmpFPredicate::UGT ||
                      pred == arith::CmpFPredicate::ULT,
                  "pred must be either UGT or ULT");
    Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
    Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);

    // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
    Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
                                                 rhs, rhs);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
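// arith.maxnumf/arith.minnumf instead return the non-NaN operand when exactly
// one input is NaN: the final select replaces a NaN lhs with rhs.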
template <typename OpTy, arith::CmpFPredicate pred>
struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const final {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();

    Location loc = op.getLoc();
    // If any operand is NaN, 'cmp' is true and 'select' picks lhs.
    static_assert(pred == arith::CmpFPredicate::UGT ||
                      pred == arith::CmpFPredicate::ULT,
                  "pred must be either UGT or ULT");
    Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
    Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);

    // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
    Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
                                                 lhs, lhs);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
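// bf16 -> f32 extension is exact: a bf16 bit pattern is the upper half of the
// corresponding f32, so a zero-extend, a 16-bit left shift and a bitcast
// suffice.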
  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const final {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto operand = op.getOperand();
    Type operandTy = operand.getType();
    Type resultTy = op.getType();
    Type operandETy = getElementTypeOrSelf(operandTy);
    Type resultETy = getElementTypeOrSelf(resultTy);
    if (!operandETy.isBF16() || !resultETy.isF32()) {
      return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
    }
    Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
    Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
    Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
    Value shl = b.create<arith::ShLIOp>(exti, c16);
    Value result = b.create<arith::BitcastOp>(resultTy, shl);
    rewriter.replaceOp(op, result);
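// f32 -> bf16 truncation rounds to nearest-even by adding a bias of 0x7FFF
// plus the lowest kept mantissa bit before dropping the low 16 bits; NaN
// inputs are normalized to the canonical bf16 NaN (0x7FC0).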
struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const final {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto operand = op.getOperand();
    Type operandTy = operand.getType();
    Type resultTy = op.getType();
    Type operandETy = getElementTypeOrSelf(operandTy);
    Type resultETy = getElementTypeOrSelf(resultTy);

    if (!operandETy.isF32() || !resultETy.isBF16()) {
      return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
    }
    if (op.getRoundingmodeAttr()) {
      return rewriter.notifyMatchFailure(
          op, "only applicable to default rounding mode.");
    }
    // A NaN input compares unequal to itself.
    Value isNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);

    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
    // Round to nearest even: add 0x7FFF plus the lowest kept mantissa bit
    // before shifting out the low 16 bits.
    Value bit16 =
        b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
    Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
    Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
    Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
    Value normalCaseResultI16 =
        b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
    // NaN inputs map to the canonical bf16 NaN (0x7FC0).
    Value select =
        b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
    Value result = b.create<arith::BitcastOp>(resultTy, select);
    rewriter.replaceOp(op, result);
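// f4E2M1FN -> float extension: the 4-bit encoding (1 sign, 2 exponent and
// 1 mantissa bit) is decomposed and reassembled as an f32 bit pattern, then
// truncated if the requested result type is narrower than f32.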
  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const final {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value operand = op.getOperand();
    Type operandTy = operand.getType();
    Type resultTy = op.getType();
    Type operandETy = getElementTypeOrSelf(operandTy);
    Type resultETy = getElementTypeOrSelf(resultTy);
    if (!isa<Float4E2M1FNType>(operandETy))
      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
    Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);

    // Bits 1-24: shift the mantissa/low-exponent bit into place; the value
    // 0.5 (i4 pattern 0x1) is special-cased to a zero mantissa.
    Value bits1To24 = b.create<arith::ShLIOp>(i4Bits, c0x2);
    Value isHalf =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x1);
    bits1To24 = b.create<arith::SelectOp>(isHalf, c0x0, bits1To24);
    bits1To24 = b.create<arith::ExtUIOp>(i32Ty, bits1To24);
    bits1To24 = b.create<arith::ShLIOp>(bits1To24, c0x00000014);

    // Bits 25-31: pick the high or low exponent pattern, or all zeros.
    Value useLargerExp =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x4);
    Value bits25To31 =
        b.create<arith::SelectOp>(useLargerExp, highExpBits, lowExpBits);
    Value zeroExp =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x0);
    bits25To31 = b.create<arith::SelectOp>(zeroExp, zeroExpBits, bits25To31);

    // Bit 32: sign.
    Value negative =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8);
    Value bit32 = b.create<arith::SelectOp>(negative, c0x80000000, zeroExpBits);

    // Assemble the f32 bit pattern and cast to the result type.
    Value bits1To31 = b.create<arith::AddIOp>(bits1To24, bits25To31);
    Value bits1To32 = b.create<arith::AddIOp>(bits1To31, bit32);
    Value result = b.create<arith::BitcastOp>(f32Ty, bits1To32);
    if (!isa<Float32Type>(resultETy))
      result = b.create<arith::TruncFOp>(resultTy, result);

    rewriter.replaceOp(op, result);
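// f8E8M0FNU -> float extension: the 8 stored bits are exactly an f32 exponent
// field (no sign, no mantissa bits), so they are shifted into the exponent
// position; the all-ones encoding maps to NaN.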
  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const final {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value operand = op.getOperand();
    Type operandTy = operand.getType();
    Type resultTy = op.getType();
    Type operandETy = getElementTypeOrSelf(operandTy);
    Type resultETy = getElementTypeOrSelf(resultTy);
    if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
      return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
    }
    Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
    // Place the 8 exponent bits into the f32 exponent field.
    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
    // The all-ones encoding is NaN; map it to the canonical f32 NaN.
    Value isNan =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
    if (resultETy.getIntOrFloatBitWidth() < 32) {
      result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
                                         op.getFastmathAttr());
    } else if (resultETy.getIntOrFloatBitWidth() > 32) {
      result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr());
    }
    rewriter.replaceOp(op, result);
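// float -> f4E2M1FN truncation: clamp to the representable range, rebuild the
// sign, exponent and mantissa fields bit by bit, special-case subnormals and
// values that round to 0.5, and round up where needed.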
  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const final {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value operand = op.getOperand();
    Type operandTy = operand.getType();
    Type resultTy = op.getType();
    Type operandETy = getElementTypeOrSelf(operandTy);
    Type resultETy = getElementTypeOrSelf(resultTy);
    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());

    if (!isa<Float32Type>(operandETy))
      operand = b.create<arith::ExtFOp>(f32Ty, operand);
    if (!isa<Float4E2M1FNType>(resultETy))
      return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
    // Step 0: Clamp to the representable f4E2M1 range.
    Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand);
    operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped);
    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);

    // Step 1: Set the sign bit.
    Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
    Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
    Value f4Bits = b.create<arith::ShLIOp>(f4Sign, c0x3);

    // Step 2: Convert the exponent by adjusting the bias.
    Value cF4MantissaWidth = c0x1;
    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
    Value biasAdjustedSignExp =
        b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
    Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
    f4Exp = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
    f4Bits = b.create<arith::AddIOp>(f4Bits, f4Exp);

    // Step 3: Set the single mantissa bit.
    Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
    man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
    Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
    f4Bits = b.create<arith::AddIOp>(f4Bits, f4Man);

    // Step 4: Special handling for subnormals and rounding to 0.5.
    Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
    Value isSubnormal =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
    Value isNegOneExp =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
    Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
                                                 man23Bits, zeroExpBits);
    Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
    Value isZeroExp =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
    Value subResult =
        b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
    subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
    f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);

    // Step 5: Round up if necessary.
    Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
    Value shouldRound =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
    shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal);
    Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
    f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);

    Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
    rewriter.replaceOp(op, result);
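// float -> f8E8M0FNU truncation keeps only the f32 exponent field and drops
// sign and mantissa; no other rounding behaviour is implemented, which is why
// a non-default rounding-mode attribute is rejected below.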
  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const final {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value operand = op.getOperand();
    Type operandTy = operand.getType();
    Type operandETy = getElementTypeOrSelf(operandTy);
    Type resultTy = op.getType();
    Type resultETy = getElementTypeOrSelf(resultTy);
    if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
      return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
    }
    if (op.getRoundingmodeAttr()) {
      return rewriter.notifyMatchFailure(
          op, "only applicable to default rounding mode.");
    }
    // Normalize the operand to f32 first.
    if (operandETy.getIntOrFloatBitWidth() < 32) {
      operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr());
    } else if (operandETy.getIntOrFloatBitWidth() > 32) {
      operand = b.create<arith::TruncFOp>(
          f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
    }
    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
    // Shift out the mantissa; truncating to i8 then keeps the 8 exponent bits.
    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
    Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
    Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
    rewriter.replaceOp(op, result);
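// arith.scaling_extf: an f8E8M0 scale encodes a power of two, so extending it
// to the result type yields 2^scale and the result is simply input * 2^scale.
// Wider float scales are first truncated to f8E8M0 to extract their exponent.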
struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
                                PatternRewriter &rewriter) const final {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value inputOperand = op.getIn();
    Value scaleOperand = op.getScale();
    Type scaleTy = scaleOperand.getType();
    Type scaleETy = getElementTypeOrSelf(scaleOperand);
    // Allow implicit exponent extraction from 16/32-bit float scales.
    if (scaleETy.getIntOrFloatBitWidth() >= 16) {
      scaleETy = b.getF8E8M0Type();
      scaleTy = cloneToShapedType(scaleTy, scaleETy);
      scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
                                               op.getFastmathAttr());
    }
    if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
      return rewriter.notifyMatchFailure(
          op,
          "scaling_extf is using scales of type which can not be converted "
          "to f8E8M0FNU");
    }
    Type resultTy = op.getType();
    // Extending the f8E8M0 scale to resultTy yields 2^scale (and propagates
    // NaN); the scaled result is then a single multiplication.
    Value scaleExt =
        b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr());
    Value inputExt =
        b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr());
    Value result =
        b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr());
    rewriter.replaceOp(op, result);
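// arith.scaling_truncf is the inverse: divide the input by 2^scale (obtained
// by extending the f8E8M0 scale to the input type), then truncate the quotient
// to the result type.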
struct ScalingTruncFOpConverter
    : public OpRewritePattern<arith::ScalingTruncFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
                                PatternRewriter &rewriter) const final {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value inputOperand = op.getIn();
    Value scaleOperand = op.getScale();
    Type scaleTy = scaleOperand.getType();
    Type scaleETy = getElementTypeOrSelf(scaleOperand);
    // Allow implicit exponent extraction from 16/32-bit float scales.
    if (scaleETy.getIntOrFloatBitWidth() >= 16) {
      scaleETy = b.getF8E8M0Type();
      scaleTy = cloneToShapedType(scaleTy, scaleETy);
      scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
                                               op.getFastmathAttr());
    }
    if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
      return rewriter.notifyMatchFailure(
          op,
          "scaling_truncf is using scales type which can not be converted "
          "to f8E8M0FNU");
    }
    Type resultTy = op.getType();
    Type inputTy = inputOperand.getType();
    // Extending the scale to the input type produces 2^scale (and propagates
    // NaN).
    scaleOperand =
        b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr());
    Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand,
                                           op.getFastmathAttr());
    Value resultCast = b.create<arith::TruncFOp>(
        resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
    rewriter.replaceOp(op, resultCast);
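// The pass wires these patterns into a partial dialect conversion: the ops the
// default patterns expand are declared illegal, and extf/truncf are only
// illegal for the float types whose expansion was requested via pass options.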
struct ArithExpandOpsPass
    : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
  using ArithExpandOpsPassBase::ArithExpandOpsPassBase;

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    ConversionTarget target(getContext());

    arith::populateArithExpandOpsPatterns(patterns);

    target.addLegalDialect<arith::ArithDialect>();
    target.addLegalDialect<vector::VectorDialect>();

    // Ops rewritten by the default expansion patterns are marked illegal,
    // among them the scaling casts:
    target.addIllegalOp<
      arith::ScalingExtFOp,
      arith::ScalingTruncFOp
    >();
    target.addDynamicallyLegalOp<arith::ExtFOp>(
      [=](arith::ExtFOp op) {
        Type inETy = getElementTypeOrSelf(op.getOperand().getType());
        bool legalTypes = true;
        if (includeF8E8M0)
          legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
        if (includeF4E2M1)
          legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
        return legalTypes;
      });
    target.addDynamicallyLegalOp<arith::TruncFOp>(
      [=](arith::TruncFOp op) {
        Type outETy = getElementTypeOrSelf(op.getType());
        bool legalTypes = true;
        if (includeF8E8M0)
          legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
        if (includeF4E2M1)
          legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
        return legalTypes;
      });
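// Pattern registration for the public populate* entry points
// (populateCeilFloorDivExpandOpsPatterns, populateExpandBFloat16Patterns,
// populateExpandF4E2M1Patterns, populateExpandF8E8M0Patterns,
// populateExpandScalingExtTruncPatterns and populateArithExpandOpsPatterns):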
  patterns.add<CeilDivSIOpConverter, CeilDivUIOpConverter,
               FloorDivSIOpConverter>(patterns.getContext());
  patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
      patterns.getContext());
  patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
      patterns.getContext());
  patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
      patterns.getContext());
  patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
      patterns.getContext());
  patterns.add<
    MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
    MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
    MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
    MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
    MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
    MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
    MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
    MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
  >(patterns.getContext());