19#define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
// Fragments of four small helpers whose signature lines are not part of this
// view. The first three share one shape: a ShapedType branch and a fallthrough
// that both return an arith.constant.
// NOTE(review): helper names and how `attr` is built are not visible here.
30 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
31 return arith::ConstantOp::create(rewriter, loc,
// Fallthrough after the ShapedType branch: constant built from `attr` itself.
34 return arith::ConstantOp::create(rewriter, loc, attr);
// Second helper fragment, same structure as the one above.
41 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
42 return arith::ConstantOp::create(rewriter, loc,
45 return arith::ConstantOp::create(rewriter, loc, attr);
// Third helper fragment, same structure again.
52 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
53 return arith::ConstantOp::create(rewriter, loc,
57 return arith::ConstantOp::create(rewriter, loc, attr);
// Fourth helper fragment: when `cloneFrom` is a ShapedType, return the result
// of calling clone(cloneTo) on it.
// NOTE(review): presumably `cloneTo` becomes the element type — confirm
// against ShapedType::clone.
62 if (
auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
63 return shapedTy.clone(cloneTo);
// Rewrites arith.ceildivui with the ops visible below: an `a == 0` compare,
// then ((a - 1) divui b) + 1, and a final select that yields `zero` when the
// compare holds and the "+ 1" value otherwise.
// NOTE(review): the definitions of `zero`, `one` and `compare` sit on lines
// that are not part of this view; `compare` presumably names the CmpIOp below.
74 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
75 PatternRewriter &rewriter)
const final {
76 Location loc = op.getLoc();
77 Value a = op.getLhs();
78 Value
b = op.getRhs();
// Equality test of the dividend against `zero`.
81 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero);
// (a - 1) / b + 1 using unsigned division.
83 Value minusOne = arith::SubIOp::create(rewriter, loc, a, one);
84 Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne,
b);
85 Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
// Replace the op: `zero` if the compare is true, otherwise `plusOne`.
86 rewriter.replaceOpWithNewOp<arith::SelectOp>(op,
compare, zero, plusOne);
// Rewrites arith.ceildivsi in terms of a signed division followed by a
// conditional "+ 1" correction selected on a computed condition.
// NOTE(review): the operands of the two `slt` compares, the binding of `cond`,
// the definition of `one` and the last select operand are cut off in this view.
100 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
101 PatternRewriter &rewriter)
const final {
102 Location loc = op.getLoc();
103 Type type = op.getType();
104 Value a = op.getLhs();
105 Value
b = op.getRhs();
// quotient = a divsi b; the division is inexact iff quotient * b != a.
110 Value quotient = arith::DivSIOp::create(rewriter, loc, a,
b);
111 Value
product = arith::MulIOp::create(rewriter, loc, quotient,
b);
112 Value notEqualDivisor = arith::CmpIOp::create(
113 rewriter, loc, arith::CmpIPredicate::ne, a,
product);
// Sign tests on each operand (signed less-than; right-hand sides not visible,
// presumably zero).
115 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
117 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
// True when both sign tests agree.
120 Value signEqual = arith::CmpIOp::create(
121 rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg);
// Inexact division AND matching signs (presumably bound to `cond`).
123 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual);
125 Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
// Pick quotient + 1 when `cond` holds.
127 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
// Pattern that rewrites arith.floordivsi as a signed division followed by a
// conditional "- 1" correction selected on a computed condition.
// NOTE(review): the operands of the two `slt` compares, the binding of `cond`
// and the last select operand are cut off in this view.
140struct FloorDivSIOpConverter :
public OpRewritePattern<arith::FloorDivSIOp> {
142 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
143 PatternRewriter &rewriter)
const final {
144 Location loc = op.getLoc();
145 Type type = op.getType();
146 Value a = op.getLhs();
147 Value
b = op.getRhs();
// quotient = a divsi b; the division is inexact iff quotient * b != a.
149 Value quotient = arith::DivSIOp::create(rewriter, loc, a,
b);
150 Value
product = arith::MulIOp::create(rewriter, loc, quotient,
b);
151 Value notEqualDivisor = arith::CmpIOp::create(
152 rewriter, loc, arith::CmpIPredicate::ne, a,
product);
// Sign tests on each operand (signed less-than; right-hand sides not visible,
// presumably zero).
155 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
157 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
// True when the two sign tests disagree.
160 Value signOpposite = arith::CmpIOp::create(
161 rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg);
// Inexact division AND differing signs (presumably bound to `cond`).
163 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite);
// quotient - 1 is formed by adding a -1 constant of the result type.
165 Value minusOne =
createConst(loc, type, -1, rewriter);
166 Value quotientMinusOne =
167 arith::AddIOp::create(rewriter, loc, quotient, minusOne);
// Pick quotient - 1 when `cond` holds.
169 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
// Generic pattern for a binary integer op `OpTy`, parameterised on the integer
// compare predicate `pred`: the op becomes `select(cmpi pred lhs, rhs; lhs; rhs)`.
// NOTE(review): the struct header line is not part of this view.
175template <
typename OpTy, arith::CmpIPredicate pred>
178 using OpRewritePattern<OpTy>::OpRewritePattern;
180 LogicalResult matchAndRewrite(OpTy op,
181 PatternRewriter &rewriter)
const final {
182 Value
lhs = op.getLhs();
183 Value
rhs = op.getRhs();
// Compare the operands with the template predicate ...
185 Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred,
lhs,
rhs);
// ... and keep lhs when the compare is true, rhs otherwise.
186 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp,
lhs,
rhs);
// Generic pattern for a binary floating-point op `OpTy`, parameterised on an
// unordered compare predicate (UGT or ULT, enforced by the static_assert).
// Result: `rhs` when the `isNaN` test on rhs is true, otherwise
// `select(cmpf pred lhs, rhs; lhs; rhs)`.
// NOTE(review): the struct header line is not part of this view.
191template <
typename OpTy, arith::CmpFPredicate pred>
194 using OpRewritePattern<OpTy>::OpRewritePattern;
196 LogicalResult matchAndRewrite(OpTy op,
197 PatternRewriter &rewriter)
const final {
198 Value
lhs = op.getLhs();
199 Value
rhs = op.getRhs();
201 Location loc = op.getLoc();
// Only the two unordered predicates are accepted at compile time.
203 static_assert(pred == arith::CmpFPredicate::UGT ||
204 pred == arith::CmpFPredicate::ULT,
205 "pred must be either UGT or ULT");
206 Value cmp = arith::CmpFOp::create(rewriter, loc, pred,
lhs,
rhs);
207 Value select = arith::SelectOp::create(rewriter, loc, cmp,
lhs,
rhs);
// Unordered self-compare of rhs, used as the NaN test for rhs.
210 Value isNaN = arith::CmpFOp::create(rewriter, loc,
211 arith::CmpFPredicate::UNO,
rhs,
rhs);
// If the test on rhs is true the result is rhs itself; otherwise `select`.
212 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN,
rhs, select);
// Generic pattern for a binary floating-point op `OpTy`, parameterised on an
// unordered compare predicate (UGT or ULT, enforced by the static_assert).
// Unlike the sibling pattern that tests rhs, the `isNaN` test here is on lhs:
// the result is `rhs` when that test is true, otherwise
// `select(cmpf pred lhs, rhs; lhs; rhs)`.
// NOTE(review): the struct header line is not part of this view.
217template <
typename OpTy, arith::CmpFPredicate pred>
220 using OpRewritePattern<OpTy>::OpRewritePattern;
222 LogicalResult matchAndRewrite(OpTy op,
223 PatternRewriter &rewriter)
const final {
224 Value
lhs = op.getLhs();
225 Value
rhs = op.getRhs();
227 Location loc = op.getLoc();
// Only the two unordered predicates are accepted at compile time.
229 static_assert(pred == arith::CmpFPredicate::UGT ||
230 pred == arith::CmpFPredicate::ULT,
231 "pred must be either UGT or ULT");
232 Value cmp = arith::CmpFOp::create(rewriter, loc, pred,
lhs,
rhs);
233 Value select = arith::SelectOp::create(rewriter, loc, cmp,
lhs,
rhs);
// Unordered self-compare of lhs, used as the NaN test for lhs.
236 Value isNaN = arith::CmpFOp::create(rewriter, loc,
237 arith::CmpFPredicate::UNO,
lhs,
lhs);
// If the test on lhs is true the result is rhs; otherwise `select`.
238 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN,
rhs, select);
// Rewrites an arith.extf with integer ops: bitcast the operand to `i16Ty`,
// zero-extend to `i32Ty`, shift left by `c16`, and bitcast to the result type.
// The match-failure message below indicates this applies to bf16 -> f32 only.
// NOTE(review): the type checks, the builder `b`, `i16Ty`, `i32Ty` and `c16`
// are defined on lines that are not part of this view.
245 LogicalResult matchAndRewrite(arith::ExtFOp op,
248 auto operand = op.getOperand();
249 Type operandTy = operand.getType();
250 Type resultTy = op.getType();
255 return rewriter.notifyMatchFailure(op,
"not a ext of bf16 to f32.");
// Reinterpret the operand bits as an integer and widen them.
261 Value bitcast = arith::BitcastOp::create(
b, i16Ty, operand);
262 Value exti = arith::ExtUIOp::create(
b, i32Ty, bitcast);
// Move the widened bits up by `c16`, then reinterpret as the result type.
265 Value shl = arith::ShLIOp::create(
b, exti, c16);
266 Value result = arith::BitcastOp::create(
b, resultTy, shl);
268 rewriter.replaceOp(op,
result);
// Pattern that rewrites an arith.truncf with integer ops. Visible steps:
// reject ops carrying a rounding-mode attribute, bitcast the operand to
// `i32Ty`, add a bias derived from one bit of the value, shift right by `c16`,
// truncate to `i16Ty`, and substitute `c7FC0I16` when `isNan` is set.
// NOTE(review): the type checks, the builder `b`, the constants (`c1`, `c16`,
// `c7FFF`, `c7FC0I16`) and the bindings of `isNan`, `bit16` and `select` are
// on lines that are not part of this view.
273struct BFloat16TruncFOpConverter :
public OpRewritePattern<arith::TruncFOp> {
275 LogicalResult matchAndRewrite(arith::TruncFOp op,
278 auto operand = op.getOperand();
279 Type operandTy = operand.getType();
280 Type resultTy = op.getType();
// An explicit rounding mode is not handled by this expansion.
288 if (op.getRoundingmodeAttr()) {
290 op,
"only applicable to default rounding mode.");
// Unordered-or-not-equal self-compare (presumably bound to `isNan`).
310 arith::CmpFOp::create(
b, arith::CmpFPredicate::UNE, operand, operand);
319 Value bitcast = arith::BitcastOp::create(
b, i32Ty, operand);
// (bitcast >> c16) & c1 (presumably bound to `bit16`).
322 arith::AndIOp::create(
b, arith::ShRUIOp::create(
b, bitcast, c16), c1);
// Bias = bit16 + c7FFF, added to the raw bits before dropping the low part.
325 Value roundingBias = arith::AddIOp::create(
b, bit16, c7FFF);
332 Value biased = arith::AddIOp::create(
b, bitcast, roundingBias);
335 Value biasedAndShifted = arith::ShRUIOp::create(
b, biased, c16);
336 Value normalCaseResultI16 =
337 arith::TruncIOp::create(
b, i16Ty, biasedAndShifted);
// NaN inputs take the fixed `c7FC0I16` pattern instead of the rounded bits.
341 arith::SelectOp::create(
b, isNan, c7FC0I16, normalCaseResultI16);
342 Value result = arith::BitcastOp::create(
b, resultTy, select);
// Rewrites an arith.extf whose operand element type is Float4E2M1FN by
// assembling the f32 bit pattern from three pieces (`bits1To24`, `bits25To31`,
// `bit32`) with integer ops, then bitcasting to `f32Ty`.
// NOTE(review): the builder `b`, the integer types, every `c0x...` constant,
// `highExpBits`/`lowExpBits`/`zeroExpBits`, and the bindings of `isHalf`,
// `useLargerExp`, `zeroExp`, `negative`, `bits25To31` and `bit32` are on lines
// that are not part of this view.
379 LogicalResult matchAndRewrite(arith::ExtFOp op,
383 Value operand = op.getOperand();
385 Type resultTy = op.getType();
// Guard on the operand element type (the branch body is not visible).
389 if (!isa<Float4E2M1FNType>(operandETy))
395 Value i4Bits = arith::BitcastOp::create(
b, i4Ty, operand);
// Mask with `c0x7` to drop the top bit of the 4-bit value.
403 Value i4BitsNoSign = arith::AndIOp::create(
b, i4Bits, c0x7);
// Low piece: shift by `c0x2`, force to `c0x0` when the masked value equals
// `c0x1`, widen to i32 and shift by `c0x00000014`.
407 Value bits1To24 = arith::ShLIOp::create(
b, i4BitsNoSign, c0x2);
409 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
410 bits1To24 = arith::SelectOp::create(
b, isHalf, c0x0, bits1To24);
411 bits1To24 = arith::ExtUIOp::create(
b, i32Ty, bits1To24);
412 bits1To24 = arith::ShLIOp::create(
b, bits1To24, c0x00000014);
// Middle piece: high or low exponent bits depending on `>= c0x4`, replaced by
// `zeroExpBits` when the masked value equals `c0x0`.
419 arith::CmpIOp::create(
b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
421 arith::SelectOp::create(
b, useLargerExp, highExpBits, lowExpBits);
423 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x0);
424 bits25To31 = arith::SelectOp::create(
b, zeroExp, zeroExpBits, bits25To31);
// Top piece: `c0x80000000` when the unmasked value is `>= c0x8`.
430 arith::CmpIOp::create(
b, arith::CmpIPredicate::uge, i4Bits, c0x8);
432 arith::SelectOp::create(
b, negative, c0x80000000, zeroExpBits);
// Combine the three pieces by addition and reinterpret as f32.
435 Value bits1To31 = arith::AddIOp::create(
b, bits1To24, bits25To31);
436 Value bits1To32 = arith::AddIOp::create(
b, bits1To31, bit32);
437 Value result = arith::BitcastOp::create(
b, f32Ty, bits1To32);
// Extra handling when the result element type is not f32 (body not visible).
438 if (!isa<Float32Type>(resultETy))
// Rewrites an arith.extf whose operand element type is Float8E8M0FNU: the
// 8 operand bits are zero-extended to i32 and shifted left by 23 to form the
// f32 bit pattern, with the 0xff encoding mapped to 0xffffffff via a select.
// The f32 value is then converted to the result type with truncf or extf.
// NOTE(review): the integer/float type definitions, the binding of `isNan`,
// the use of `NoNaN`, and the conditions choosing between the final truncf and
// extf are on lines that are not part of this view.
448 LogicalResult matchAndRewrite(arith::ExtFOp op,
449 PatternRewriter &rewriter)
const final {
450 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
451 Value operand = op.getOperand();
452 Type operandTy = operand.
getType();
453 Type resultTy = op.getType();
// Guard on the operand element type (the branch body is not visible).
457 if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
// Place the 8 input bits at bit position 23 of an i32.
465 Value bitcast = arith::BitcastOp::create(
b, i8Ty, operand);
466 Value cF32MantissaWidth =
createConst(op->getLoc(), i32Ty, 23, rewriter);
467 Value exti = arith::ExtUIOp::create(
b, i32Ty, bitcast);
468 Value f32Bits = arith::ShLIOp::create(
b, exti, cF32MantissaWidth);
// NoNaN is derived from the `nnan` fast-math flag when the attribute is set
// (the other arm of the conditional is not visible).
471 auto fastMath = op.getFastmathAttr();
472 bool NoNaN = fastMath
473 ? (fastMath.getValue() & arith::FastMathFlags::nnan) ==
474 arith::FastMathFlags::nnan
// Input bits equal to 0xff select the all-ones i32 pattern instead.
477 Value cF8NaN =
createConst(op.getLoc(), i8Ty, 0xff, rewriter);
478 Value cF32NaN =
createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
480 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
482 f32Bits = arith::SelectOp::create(
b, isNan, cF32NaN, f32Bits);
485 Value
result = arith::BitcastOp::create(
b, f32Ty, f32Bits);
// Convert the f32 value to `resultTy`: truncf with no rounding-mode attribute
// (nullptr) in one case, extf in another; both forward the fast-math attribute.
487 result = arith::TruncFOp::create(
b, resultTy,
result,
nullptr,
488 op.getFastmathAttr());
490 result = arith::ExtFOp::create(
b, resultTy,
result, op.getFastmathAttr());
// Rewrites an arith.truncf whose result element type is Float4E2M1FN. The
// operand is brought to f32, clamped, bitcast to i32, and a 4-bit pattern is
// assembled from sign, exponent and mantissa pieces; small-exponent inputs
// and rounding are then handled with compares and selects.
// NOTE(review): the type definitions, `c0x1`, `c0x3`, `cLowerBound`,
// `cHigherBound`, and the bindings of `isSubnormal`, `isNegOneExp`,
// `isZeroExp`, `subResult` and `shouldRound` are on lines that are not part of
// this view; each is presumably the result of the unnamed op just before its
// first use.
527 LogicalResult matchAndRewrite(arith::TruncFOp op,
528 PatternRewriter &rewriter)
const final {
529 Location loc = op.getLoc();
530 ImplicitLocOpBuilder
b(loc, rewriter);
531 Value operand = op.getOperand();
532 Type operandTy = operand.
getType();
533 Type resultTy = op.getType();
// Guard on the result element type (the branch body is not visible).
542 if (!isa<Float4E2M1FNType>(resultETy))
// Non-f32 operands are first extended to f32.
544 if (!isa<Float32Type>(operandETy))
545 operand = arith::ExtFOp::create(
b, f32Ty, operand);
549 Value c0x00000016 =
createConst(loc, i32Ty, 22, rewriter);
550 Value c0x00 =
createConst(loc, i8Ty, 0x00, rewriter);
551 Value c0xff =
createConst(loc, i8Ty, 0xff, rewriter);
552 Value zeroExpBits =
createConst(loc, i32Ty, 0, rewriter);
// Clamp the operand between `cLowerBound` and `cHigherBound`, then take bits.
557 Value operandClamped = arith::MinNumFOp::create(
b, cHigherBound, operand);
558 operandClamped = arith::MaxNumFOp::create(
b, cLowerBound, operandClamped);
559 Value f32Bits = arith::BitcastOp::create(
b, i32Ty, operandClamped);
// Sign: bit 31 of the f32 pattern, truncated to i4 and shifted left by `c0x3`.
562 Value cF32ExpManWidth =
createConst(loc, i32Ty, 31, rewriter);
563 Value f32Sign = arith::ShRUIOp::create(
b, f32Bits, cF32ExpManWidth);
564 Value f4Sign = arith::TruncIOp::create(
b, i4Ty, f32Sign);
565 Value f4Bits = arith::ShLIOp::create(
b, f4Sign, c0x3);
// Exponent: (f32Bits >> 23) - 0x7e, truncated to i4, shifted left by the f4
// mantissa width and added into the result bits.
568 Value biasAdjustment =
createConst(loc, i32Ty, 0x7e, rewriter);
569 Value cF4MantissaWidth = c0x1;
570 Value cF32MantissaWidth =
createConst(loc, i32Ty, 23, rewriter);
571 Value f32SignExp = arith::ShRUIOp::create(
b, f32Bits, cF32MantissaWidth);
572 Value biasAdjustedSignExp =
573 arith::SubIOp::create(
b, f32SignExp, biasAdjustment);
574 Value f4Exp = arith::TruncIOp::create(
b, i4Ty, biasAdjustedSignExp);
575 f4Exp = arith::ShLIOp::create(
b, f4Exp, cF4MantissaWidth);
576 f4Bits = arith::AddIOp::create(
b, f4Bits, f4Exp);
// Mantissa: bit 22 of the f32 pattern (mask 0x400000) moved down to bit 0.
579 Value cF32FirstBitMask =
createConst(loc, i32Ty, 0x400000, rewriter);
580 Value man1Bit = arith::AndIOp::create(
b, f32Bits, cF32FirstBitMask);
581 man1Bit = arith::ShRUIOp::create(
b, man1Bit, c0x00000016);
582 Value f4Man = arith::TruncIOp::create(
b, i4Ty, man1Bit);
583 f4Bits = arith::AddIOp::create(
b, f4Bits, f4Man);
// Small-exponent handling, driven by the bias-adjusted exponent as an i8:
// compares against 0x00 (sle and eq) and 0xff (eq), plus a non-zero test of
// the 23 low bits.
586 Value cF32MantissaMask =
createConst(loc, i32Ty, 0x7fffff, rewriter);
587 Value f8Exp = arith::TruncIOp::create(
b, i8Ty, biasAdjustedSignExp);
589 arith::CmpIOp::create(
b, arith::CmpIPredicate::sle, f8Exp, c0x00);
591 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, f8Exp, c0xff);
592 Value man23Bits = arith::AndIOp::create(
b, f32Bits, cF32MantissaMask);
593 Value isNonZeroMan = arith::CmpIOp::create(
b, arith::CmpIPredicate::ugt,
594 man23Bits, zeroExpBits);
595 Value roundToHalf = arith::AndIOp::create(
b, isNegOneExp, isNonZeroMan);
597 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, f8Exp, c0x00);
598 Value subnormalF4Bits =
createConst(loc, i4Ty, 0xf, rewriter);
599 Value halfF4Bits =
createConst(loc, i4Ty, 0x0, rewriter);
// Select chain: start from 0xf for `isSubnormal`, override with 0x0 for
// `roundToHalf`, and keep the assembled bits when `isZeroExp` holds.
601 arith::SelectOp::create(
b, isSubnormal, subnormalF4Bits, f4Bits);
602 subResult = arith::SelectOp::create(
b, roundToHalf, halfF4Bits, subResult);
603 f4Bits = arith::SelectOp::create(
b, isZeroExp, f4Bits, subResult);
// Rounding: add `c0x1` when the low 22 bits are >= 0x200000 or when
// `isSubnormal` holds.
606 Value cF32Last22BitMask =
createConst(loc, i32Ty, 0x3fffff, rewriter);
607 Value cRound =
createConst(loc, i32Ty, 0x200000, rewriter);
608 Value man22Bits = arith::AndIOp::create(
b, f32Bits, cF32Last22BitMask);
610 arith::CmpIOp::create(
b, arith::CmpIPredicate::uge, man22Bits, cRound);
611 shouldRound = arith::OrIOp::create(
b, shouldRound, isSubnormal);
612 Value roundedF4Bits = arith::AddIOp::create(
b, f4Bits, c0x1);
613 f4Bits = arith::SelectOp::create(
b, shouldRound, roundedF4Bits, f4Bits);
// Reinterpret the assembled 4-bit pattern as the result type.
615 Value
result = arith::BitcastOp::create(
b, resultTy, f4Bits);
// Rewrites an arith.truncf whose result element type is Float8E8M0FNU: the
// operand is converted to f32 (extf or truncf), bitcast to i32, shifted right
// by 23, truncated to i8 and bitcast to the result type.
// NOTE(review): the type definitions and the conditions that choose between
// the extf and truncf conversions are on lines that are not part of this view.
628 LogicalResult matchAndRewrite(arith::TruncFOp op,
629 PatternRewriter &rewriter)
const final {
630 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
631 Value operand = op.getOperand();
632 Type operandTy = operand.
getType();
634 Type resultTy = op.getType();
// Guard on the result element type (the branch body is not visible).
636 if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
// An explicit rounding mode is not handled by this expansion.
640 if (op.getRoundingmodeAttr()) {
642 op,
"only applicable to default rounding mode.");
// Bring the operand to f32, forwarding the op's attributes.
650 operand = arith::ExtFOp::create(
b, f32Ty, operand, op.getFastmathAttr());
652 operand = arith::TruncFOp::create(
653 b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
// Keep the bits above the 23 low bits, truncated to 8 bits.
655 Value f32Bits = arith::BitcastOp::create(
b, i32Ty, operand);
656 Value cF32MantissaWidth =
createConst(op->getLoc(), i32Ty, 23, rewriter);
657 Value f32SignExp = arith::ShRUIOp::create(
b, f32Bits, cF32MantissaWidth);
658 Value exp8Bits = arith::TruncIOp::create(
b, i8Ty, f32SignExp);
659 Value
result = arith::BitcastOp::create(
b, resultTy, exp8Bits);
// Pattern that rewrites arith.scaling_extf: the scale is brought to the
// f8E8M0 element type if needed, both scale and input are extended to the
// result type with extf, and the two are multiplied with mulf.
// NOTE(review): the condition guarding the scale truncation, the rest of the
// failure message, and the bindings of `scaleExt`, `inputExt` and the product
// are on lines that are not part of this view.
665struct ScalingExtFOpConverter :
public OpRewritePattern<arith::ScalingExtFOp> {
667 LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
668 PatternRewriter &rewriter)
const final {
669 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
670 Value inputOperand = op.getIn();
671 Value scaleOperand = op.getScale();
672 Type scaleTy = scaleOperand.
getType();
// Switch the scale element type to f8E8M0 and truncate the scale to it
// (no rounding-mode attribute; fast-math attribute forwarded).
676 scaleETy =
b.getF8E8M0Type();
678 scaleOperand = arith::TruncFOp::create(
b, scaleTy, scaleOperand,
nullptr,
679 op.getFastmathAttr());
// Bail out when the scale element type is still not f8E8M0FNU.
682 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
684 op,
"scaling_extf is using scales of type which can not be converted "
687 Type resultTy = op.getType();
// Extend scale and input to the result type, then multiply them.
691 arith::ExtFOp::create(
b, resultTy, scaleOperand, op.getFastmathAttr());
693 arith::ExtFOp::create(
b, resultTy, inputOperand, op.getFastmathAttr());
695 arith::MulFOp::create(
b, inputExt, scaleExt, op.getFastmathAttr());
// Pattern that rewrites arith.scaling_truncf: the scale is brought to the
// f8E8M0 element type if needed, extended to the input type, the input is
// divided by it with divf, and the quotient is truncated to the result type.
// NOTE(review): the base-class line, the condition guarding the scale
// truncation, the rest of the failure message, and the rebinding of
// `scaleOperand` after the extf are on lines that are not part of this view.
706struct ScalingTruncFOpConverter
709 LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
710 PatternRewriter &rewriter)
const final {
711 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
712 Value inputOperand = op.getIn();
713 Value scaleOperand = op.getScale();
714 Type scaleTy = scaleOperand.
getType();
// Switch the scale element type to f8E8M0 and truncate the scale to it
// (no rounding-mode attribute; fast-math attribute forwarded).
718 scaleETy =
b.getF8E8M0Type();
720 scaleOperand = arith::TruncFOp::create(
b, scaleTy, scaleOperand,
nullptr,
721 op.getFastmathAttr());
// Bail out when the scale element type is still not f8E8M0FNU.
723 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
725 op,
"scaling_truncf is using scales type which can not be converted "
728 Type resultTy = op.getType();
729 Type inputTy = inputOperand.
getType();
// Extend the scale to the input type so the division is type-consistent.
733 arith::ExtFOp::create(
b, inputTy, scaleOperand, op.getFastmathAttr());
734 Value
result = arith::DivFOp::create(
b, inputOperand, scaleOperand,
735 op.getFastmathAttr());
// Final truncation keeps the op's rounding-mode and fast-math attributes.
736 Value resultCast = arith::TruncFOp::create(
737 b, resultTy,
result, op.getRoundingmodeAttr(), op.getFastmathAttr());
// Pattern that rewrites arith.flush_denormals with integer ops: bitcast the
// operand to an integer, test whether the exponent field is all zero, and if
// so clear the mantissa bits (sign and exponent bits are kept by the mask);
// otherwise pass the bits through unchanged. The result is bitcast back to
// the operand type.
// NOTE(review): the base-class line, `floatTy`, `intTy`, the materialised
// `expMask`/`zero`/`clearMantissaMask` constants and the binding of
// `expIsZero` are on lines that are not part of this view.
759struct FlushDenormalsOpConverter
762 LogicalResult matchAndRewrite(arith::FlushDenormalsOp op,
763 PatternRewriter &rewriter)
const final {
764 Location loc = op.getLoc();
765 ImplicitLocOpBuilder
b(loc, rewriter);
766 Value operand = op.getOperand();
767 Type operandTy = operand.
getType();
772 const llvm::fltSemantics &sem = floatTy.getFloatSemantics();
// Only IEEE-like semantics are handled.
775 if (!llvm::APFloatBase::isIEEELikeFP(sem))
777 op,
"only IEEE-like floating-point types are supported");
// Derive the field widths: mantissa = precision - 1, exponent = the remaining
// bits after removing one bit and the mantissa.
779 unsigned totalBits = llvm::APFloatBase::semanticsSizeInBits(sem);
780 unsigned precision = llvm::APFloatBase::semanticsPrecision(sem);
// Guard so the subtractions below cannot wrap (branch body not visible).
783 if (precision < 1 || precision > totalBits)
785 unsigned mantissaBits = precision - 1;
786 unsigned expBits = totalBits - 1 - mantissaBits;
787 if (expBits == 0 || mantissaBits == 0)
789 op,
"degenerate float encoding has no exponent or mantissa");
793 Value bits = arith::BitcastOp::create(
b, intTy, operand);
// Exponent-field mask: bits [mantissaBits, mantissaBits + expBits).
795 APInt::getBitsSet(totalBits, mantissaBits, mantissaBits + expBits);
// Mask with every bit set except the low `mantissaBits` bits.
796 APInt clearMantissaMaskVal = ~APInt::getLowBitsSet(totalBits, mantissaBits);
797 APInt zeroVal = APInt::getZero(totalBits);
799 Value clearMantissaMask =
// Exponent field == 0 identifies the values whose mantissa gets cleared.
804 Value expField = arith::AndIOp::create(
b, bits, expMask);
806 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, expField, zero);
809 Value cleared = arith::AndIOp::create(
b, bits, clearMantissaMask);
810 Value resultBits = arith::SelectOp::create(
b, expIsZero, cleared, bits);
811 Value
result = arith::BitcastOp::create(
b, operandTy, resultBits);
// Pass that collects the expansion patterns and runs them through
// applyPartialConversion. The arith and vector dialects are marked legal, and
// FlushDenormalsOp, ExtFOp and TruncFOp get dynamic legality callbacks.
// NOTE(review): the base-class line, the declarations of `patterns` and
// `target`, the op list ending in ScalingExtFOp/ScalingTruncFOp, the option
// checks guarding each populate call, and most of the callback bodies are on
// lines that are not part of this view.
818struct ArithExpandOpsPass
820 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
822 void runOnOperation()
override {
826 arith::populateArithExpandOpsPatterns(patterns);
828 target.addLegalDialect<arith::ArithDialect>();
829 target.addLegalDialect<vector::VectorDialect>();
// Tail of an op list passed to `target` (the call itself is not visible).
844 arith::ScalingExtFOp,
845 arith::ScalingTruncFOp
// Additional pattern sets for the narrow float formats.
849 arith::populateExpandBFloat16Patterns(patterns);
851 arith::populateExpandF8E8M0Patterns(patterns);
853 arith::populateExpandF4E2M1Patterns(patterns);
// Flush-denormals expansion is opt-in via `includeFlushDenormals`.
854 if (includeFlushDenormals) {
855 arith::populateExpandFlushDenormalsPatterns(patterns);
// The visible return reports the op as legal when its float semantics are
// not IEEE-like.
858 target.addDynamicallyLegalOp<arith::FlushDenormalsOp>(
859 [](arith::FlushDenormalsOp op) {
864 return !llvm::APFloatBase::isIEEELikeFP(
865 floatTy.getFloatSemantics());
// ExtFOp: `legalTypes` starts true and is cleared for f8E8M0FNU / f4E2M1FN
// input element types.
869 target.addDynamicallyLegalOp<arith::ExtFOp>(
870 [=](arith::ExtFOp op) {
873 bool legalTypes =
true;
877 legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
879 legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
// TruncFOp: same scheme, keyed on the output element type.
883 target.addDynamicallyLegalOp<arith::TruncFOp>(
884 [=](arith::TruncFOp op) {
887 bool legalTypes =
true;
891 legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
893 legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
// Run the conversion; the failure branch body is not visible here.
898 if (
failed(applyPartialConversion(getOperation(),
target,
899 std::move(patterns))))
// Fragments of the pattern-registration functions: each `add<...>` call
// registers a group of converter patterns on a RewritePatternSet.
// NOTE(review): the enclosing function signatures and most call arguments are
// on lines that are not part of this view.
// Ceil/floor integer division converters.
909 .
add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
// bf16 extf/truncf converters.
914 patterns.
add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
// f4E2M1 extf/truncf converters.
919 patterns.
add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
// f8E8M0 extf/truncf converters.
924 patterns.
add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
// Scaling extf/truncf converters.
930 patterns.
add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
// Flush-denormals converter, constructed with the pattern set's context.
936 patterns.
add<FlushDenormalsOpConverter>(patterns.
getContext());
// Template instantiations pairing each max/min op with its compare predicate:
// signed/unsigned integer predicates for the integer ops, and the unordered
// UGT/ULT predicates for the floating-point ops.
944 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
945 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
946 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
947 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
948 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
949 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
950 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
951 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
static int64_t product(ArrayRef< int64_t > vals)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ArithExpandOpsPassBase Base
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns)
Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateExpandF4E2M1Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts.
void populateExpandFlushDenormalsPatterns(RewritePatternSet &patterns)
Add patterns to expand arith.flush_denormals into integer arithmetic (bitcast + bit masks + compare + select).
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...