19#define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
30 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
31 return arith::ConstantOp::create(rewriter, loc,
34 return arith::ConstantOp::create(rewriter, loc, attr);
41 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
42 return arith::ConstantOp::create(rewriter, loc,
45 return arith::ConstantOp::create(rewriter, loc, attr);
52 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
53 return arith::ConstantOp::create(rewriter, loc,
57 return arith::ConstantOp::create(rewriter, loc, attr);
62 if (
auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
63 return shapedTy.clone(cloneTo);
74 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
75 PatternRewriter &rewriter)
const final {
76 Location loc = op.getLoc();
77 Value a = op.getLhs();
78 Value
b = op.getRhs();
81 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero);
83 Value minusOne = arith::SubIOp::create(rewriter, loc, a, one);
84 Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne,
b);
85 Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
86 rewriter.replaceOpWithNewOp<arith::SelectOp>(op,
compare, zero, plusOne);
100 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
101 PatternRewriter &rewriter)
const final {
102 Location loc = op.getLoc();
103 Type type = op.getType();
104 Value a = op.getLhs();
105 Value
b = op.getRhs();
110 Value quotient = arith::DivSIOp::create(rewriter, loc, a,
b);
111 Value
product = arith::MulIOp::create(rewriter, loc, quotient,
b);
112 Value notEqualDivisor = arith::CmpIOp::create(
113 rewriter, loc, arith::CmpIPredicate::ne, a,
product);
115 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
117 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
120 Value signEqual = arith::CmpIOp::create(
121 rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg);
123 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual);
125 Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
127 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
140struct FloorDivSIOpConverter :
public OpRewritePattern<arith::FloorDivSIOp> {
142 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
143 PatternRewriter &rewriter)
const final {
144 Location loc = op.getLoc();
145 Type type = op.getType();
146 Value a = op.getLhs();
147 Value
b = op.getRhs();
149 Value quotient = arith::DivSIOp::create(rewriter, loc, a,
b);
150 Value
product = arith::MulIOp::create(rewriter, loc, quotient,
b);
151 Value notEqualDivisor = arith::CmpIOp::create(
152 rewriter, loc, arith::CmpIPredicate::ne, a,
product);
155 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
157 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
160 Value signOpposite = arith::CmpIOp::create(
161 rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg);
163 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite);
165 Value minusOne =
createConst(loc, type, -1, rewriter);
166 Value quotientMinusOne =
167 arith::AddIOp::create(rewriter, loc, quotient, minusOne);
169 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
175template <
typename OpTy, arith::CmpIPredicate pred>
178 using OpRewritePattern<OpTy>::OpRewritePattern;
180 LogicalResult matchAndRewrite(OpTy op,
181 PatternRewriter &rewriter)
const final {
182 Value
lhs = op.getLhs();
183 Value
rhs = op.getRhs();
185 Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred,
lhs,
rhs);
186 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp,
lhs,
rhs);
191template <
typename OpTy, arith::CmpFPredicate pred>
194 using OpRewritePattern<OpTy>::OpRewritePattern;
196 LogicalResult matchAndRewrite(OpTy op,
197 PatternRewriter &rewriter)
const final {
198 Value
lhs = op.getLhs();
199 Value
rhs = op.getRhs();
201 Location loc = op.getLoc();
203 static_assert(pred == arith::CmpFPredicate::UGT ||
204 pred == arith::CmpFPredicate::ULT,
205 "pred must be either UGT or ULT");
206 Value cmp = arith::CmpFOp::create(rewriter, loc, pred,
lhs,
rhs);
207 Value select = arith::SelectOp::create(rewriter, loc, cmp,
lhs,
rhs);
210 Value isNaN = arith::CmpFOp::create(rewriter, loc,
211 arith::CmpFPredicate::UNO,
rhs,
rhs);
212 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN,
rhs, select);
217template <
typename OpTy, arith::CmpFPredicate pred>
220 using OpRewritePattern<OpTy>::OpRewritePattern;
222 LogicalResult matchAndRewrite(OpTy op,
223 PatternRewriter &rewriter)
const final {
224 Value
lhs = op.getLhs();
225 Value
rhs = op.getRhs();
227 Location loc = op.getLoc();
229 static_assert(pred == arith::CmpFPredicate::UGT ||
230 pred == arith::CmpFPredicate::ULT,
231 "pred must be either UGT or ULT");
232 Value cmp = arith::CmpFOp::create(rewriter, loc, pred,
lhs,
rhs);
233 Value select = arith::SelectOp::create(rewriter, loc, cmp,
lhs,
rhs);
236 Value isNaN = arith::CmpFOp::create(rewriter, loc,
237 arith::CmpFPredicate::UNO,
lhs,
lhs);
238 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN,
rhs, select);
245 LogicalResult matchAndRewrite(arith::ExtFOp op,
246 PatternRewriter &rewriter)
const final {
247 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
248 auto operand = op.getOperand();
249 Type operandTy = operand.getType();
250 Type resultTy = op.getType();
255 return rewriter.notifyMatchFailure(op,
"not a ext of bf16 to f32.");
261 Value bitcast = arith::BitcastOp::create(
b, i16Ty, operand);
262 Value exti = arith::ExtUIOp::create(
b, i32Ty, bitcast);
264 Value c16 =
createConst(op.getLoc(), i32Ty, 16, rewriter);
265 Value shl = arith::ShLIOp::create(
b, exti, c16);
266 Value
result = arith::BitcastOp::create(
b, resultTy, shl);
268 rewriter.replaceOp(op,
result);
273struct BFloat16TruncFOpConverter :
public OpRewritePattern<arith::TruncFOp> {
275 LogicalResult matchAndRewrite(arith::TruncFOp op,
276 PatternRewriter &rewriter)
const final {
277 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
278 auto operand = op.getOperand();
279 Type operandTy = operand.getType();
280 Type resultTy = op.getType();
285 return rewriter.notifyMatchFailure(op,
"not a trunc of f32 to bf16.");
288 if (op.getRoundingmodeAttr()) {
289 return rewriter.notifyMatchFailure(
290 op,
"only applicable to default rounding mode.");
310 arith::CmpFOp::create(
b, arith::CmpFPredicate::UNE, operand, operand);
312 Value c7FFF =
createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
314 Value c7FC0I16 =
createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
316 Value c16 =
createConst(op.getLoc(), i32Ty, 16, rewriter);
317 Value c1 =
createConst(op.getLoc(), i32Ty, 1, rewriter);
319 Value bitcast = arith::BitcastOp::create(
b, i32Ty, operand);
322 arith::AndIOp::create(
b, arith::ShRUIOp::create(
b, bitcast, c16), c1);
325 Value roundingBias = arith::AddIOp::create(
b, bit16, c7FFF);
332 Value biased = arith::AddIOp::create(
b, bitcast, roundingBias);
335 Value biasedAndShifted = arith::ShRUIOp::create(
b, biased, c16);
336 Value normalCaseResultI16 =
337 arith::TruncIOp::create(
b, i16Ty, biasedAndShifted);
341 arith::SelectOp::create(
b, isNan, c7FC0I16, normalCaseResultI16);
342 Value
result = arith::BitcastOp::create(
b, resultTy, select);
343 rewriter.replaceOp(op,
result);
379 LogicalResult matchAndRewrite(arith::ExtFOp op,
380 PatternRewriter &rewriter)
const final {
381 Location loc = op.getLoc();
382 ImplicitLocOpBuilder
b(loc, rewriter);
383 Value operand = op.getOperand();
384 Type operandTy = operand.
getType();
385 Type resultTy = op.getType();
389 if (!isa<Float4E2M1FNType>(operandETy))
390 return rewriter.notifyMatchFailure(op,
"not a ext of F4E2M1FN");
395 Value i4Bits = arith::BitcastOp::create(
b, i4Ty, operand);
397 Value c0x0 =
createConst(loc, i4Ty, 0x0, rewriter);
398 Value c0x1 =
createConst(loc, i4Ty, 0x1, rewriter);
399 Value c0x2 =
createConst(loc, i4Ty, 0x2, rewriter);
400 Value c0x4 =
createConst(loc, i4Ty, 0x4, rewriter);
401 Value c0x7 =
createConst(loc, i4Ty, 0x7, rewriter);
403 Value i4BitsNoSign = arith::AndIOp::create(
b, i4Bits, c0x7);
406 Value c0x00000014 =
createConst(loc, i32Ty, 0x14, rewriter);
407 Value bits1To24 = arith::ShLIOp::create(
b, i4BitsNoSign, c0x2);
409 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
410 bits1To24 = arith::SelectOp::create(
b, isHalf, c0x0, bits1To24);
411 bits1To24 = arith::ExtUIOp::create(
b, i32Ty, bits1To24);
412 bits1To24 = arith::ShLIOp::create(
b, bits1To24, c0x00000014);
415 Value zeroExpBits =
createConst(loc, i32Ty, 0x00000000, rewriter);
416 Value highExpBits =
createConst(loc, i32Ty, 0x40000000, rewriter);
417 Value lowExpBits =
createConst(loc, i32Ty, 0x3f000000, rewriter);
419 arith::CmpIOp::create(
b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
421 arith::SelectOp::create(
b, useLargerExp, highExpBits, lowExpBits);
423 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x0);
424 bits25To31 = arith::SelectOp::create(
b, zeroExp, zeroExpBits, bits25To31);
427 Value c0x80000000 =
createConst(loc, i32Ty, 0x80000000, rewriter);
428 Value c0x8 =
createConst(loc, i4Ty, 0x8, rewriter);
430 arith::CmpIOp::create(
b, arith::CmpIPredicate::uge, i4Bits, c0x8);
432 arith::SelectOp::create(
b, negative, c0x80000000, zeroExpBits);
435 Value bits1To31 = arith::AddIOp::create(
b, bits1To24, bits25To31);
436 Value bits1To32 = arith::AddIOp::create(
b, bits1To31, bit32);
437 Value
result = arith::BitcastOp::create(
b, f32Ty, bits1To32);
438 if (!isa<Float32Type>(resultETy))
441 rewriter.replaceOp(op,
result);
448 LogicalResult matchAndRewrite(arith::ExtFOp op,
449 PatternRewriter &rewriter)
const final {
450 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
451 Value operand = op.getOperand();
452 Type operandTy = operand.
getType();
453 Type resultTy = op.getType();
457 if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
458 return rewriter.notifyMatchFailure(op,
"not a ext of F8E8M0FNU");
465 Value bitcast = arith::BitcastOp::create(
b, i8Ty, operand);
466 Value cF32MantissaWidth =
createConst(op->getLoc(), i32Ty, 23, rewriter);
467 Value exti = arith::ExtUIOp::create(
b, i32Ty, bitcast);
468 Value f32Bits = arith::ShLIOp::create(
b, exti, cF32MantissaWidth);
471 auto fastMath = op.getFastmathAttr();
472 bool NoNaN = fastMath
473 ? (fastMath.getValue() & arith::FastMathFlags::nnan) ==
474 arith::FastMathFlags::nnan
477 Value cF8NaN =
createConst(op.getLoc(), i8Ty, 0xff, rewriter);
478 Value cF32NaN =
createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
480 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
482 f32Bits = arith::SelectOp::create(
b, isNan, cF32NaN, f32Bits);
485 Value
result = arith::BitcastOp::create(
b, f32Ty, f32Bits);
487 result = arith::TruncFOp::create(
b, resultTy,
result,
nullptr,
488 op.getFastmathAttr());
490 result = arith::ExtFOp::create(
b, resultTy,
result, op.getFastmathAttr());
492 rewriter.replaceOp(op,
result);
527 LogicalResult matchAndRewrite(arith::TruncFOp op,
528 PatternRewriter &rewriter)
const final {
529 Location loc = op.getLoc();
530 ImplicitLocOpBuilder
b(loc, rewriter);
531 Value operand = op.getOperand();
532 Type operandTy = operand.
getType();
533 Type resultTy = op.getType();
542 if (!isa<Float4E2M1FNType>(resultETy))
543 return rewriter.notifyMatchFailure(op,
"not a trunc of F4E2M1FN");
544 if (!isa<Float32Type>(operandETy))
545 operand = arith::ExtFOp::create(
b, f32Ty, operand);
549 Value c0x00000016 =
createConst(loc, i32Ty, 22, rewriter);
550 Value c0x00 =
createConst(loc, i8Ty, 0x00, rewriter);
551 Value c0xff =
createConst(loc, i8Ty, 0xff, rewriter);
552 Value zeroExpBits =
createConst(loc, i32Ty, 0, rewriter);
557 Value operandClamped = arith::MinNumFOp::create(
b, cHigherBound, operand);
558 operandClamped = arith::MaxNumFOp::create(
b, cLowerBound, operandClamped);
559 Value f32Bits = arith::BitcastOp::create(
b, i32Ty, operandClamped);
562 Value cF32ExpManWidth =
createConst(loc, i32Ty, 31, rewriter);
563 Value f32Sign = arith::ShRUIOp::create(
b, f32Bits, cF32ExpManWidth);
564 Value f4Sign = arith::TruncIOp::create(
b, i4Ty, f32Sign);
565 Value f4Bits = arith::ShLIOp::create(
b, f4Sign, c0x3);
568 Value biasAdjustment =
createConst(loc, i32Ty, 0x7e, rewriter);
569 Value cF4MantissaWidth = c0x1;
570 Value cF32MantissaWidth =
createConst(loc, i32Ty, 23, rewriter);
571 Value f32SignExp = arith::ShRUIOp::create(
b, f32Bits, cF32MantissaWidth);
572 Value biasAdjustedSignExp =
573 arith::SubIOp::create(
b, f32SignExp, biasAdjustment);
574 Value f4Exp = arith::TruncIOp::create(
b, i4Ty, biasAdjustedSignExp);
575 f4Exp = arith::ShLIOp::create(
b, f4Exp, cF4MantissaWidth);
576 f4Bits = arith::AddIOp::create(
b, f4Bits, f4Exp);
579 Value cF32FirstBitMask =
createConst(loc, i32Ty, 0x400000, rewriter);
580 Value man1Bit = arith::AndIOp::create(
b, f32Bits, cF32FirstBitMask);
581 man1Bit = arith::ShRUIOp::create(
b, man1Bit, c0x00000016);
582 Value f4Man = arith::TruncIOp::create(
b, i4Ty, man1Bit);
583 f4Bits = arith::AddIOp::create(
b, f4Bits, f4Man);
586 Value cF32MantissaMask =
createConst(loc, i32Ty, 0x7fffff, rewriter);
587 Value f8Exp = arith::TruncIOp::create(
b, i8Ty, biasAdjustedSignExp);
589 arith::CmpIOp::create(
b, arith::CmpIPredicate::sle, f8Exp, c0x00);
591 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, f8Exp, c0xff);
592 Value man23Bits = arith::AndIOp::create(
b, f32Bits, cF32MantissaMask);
593 Value isNonZeroMan = arith::CmpIOp::create(
b, arith::CmpIPredicate::ugt,
594 man23Bits, zeroExpBits);
595 Value roundToHalf = arith::AndIOp::create(
b, isNegOneExp, isNonZeroMan);
597 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, f8Exp, c0x00);
598 Value subnormalF4Bits =
createConst(loc, i4Ty, 0xf, rewriter);
599 Value halfF4Bits =
createConst(loc, i4Ty, 0x0, rewriter);
601 arith::SelectOp::create(
b, isSubnormal, subnormalF4Bits, f4Bits);
602 subResult = arith::SelectOp::create(
b, roundToHalf, halfF4Bits, subResult);
603 f4Bits = arith::SelectOp::create(
b, isZeroExp, f4Bits, subResult);
606 Value cF32Last22BitMask =
createConst(loc, i32Ty, 0x3fffff, rewriter);
607 Value cRound =
createConst(loc, i32Ty, 0x200000, rewriter);
608 Value man22Bits = arith::AndIOp::create(
b, f32Bits, cF32Last22BitMask);
610 arith::CmpIOp::create(
b, arith::CmpIPredicate::uge, man22Bits, cRound);
611 shouldRound = arith::OrIOp::create(
b, shouldRound, isSubnormal);
612 Value roundedF4Bits = arith::AddIOp::create(
b, f4Bits, c0x1);
613 f4Bits = arith::SelectOp::create(
b, shouldRound, roundedF4Bits, f4Bits);
615 Value
result = arith::BitcastOp::create(
b, resultTy, f4Bits);
616 rewriter.replaceOp(op,
result);
628 LogicalResult matchAndRewrite(arith::TruncFOp op,
629 PatternRewriter &rewriter)
const final {
630 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
631 Value operand = op.getOperand();
632 Type operandTy = operand.
getType();
634 Type resultTy = op.getType();
636 if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
637 return rewriter.notifyMatchFailure(op,
"not a truncf to f8E8M0FNU");
640 if (op.getRoundingmodeAttr()) {
641 return rewriter.notifyMatchFailure(
642 op,
"only applicable to default rounding mode.");
650 operand = arith::ExtFOp::create(
b, f32Ty, operand, op.getFastmathAttr());
652 operand = arith::TruncFOp::create(
653 b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
655 Value f32Bits = arith::BitcastOp::create(
b, i32Ty, operand);
656 Value cF32MantissaWidth =
createConst(op->getLoc(), i32Ty, 23, rewriter);
657 Value f32SignExp = arith::ShRUIOp::create(
b, f32Bits, cF32MantissaWidth);
658 Value exp8Bits = arith::TruncIOp::create(
b, i8Ty, f32SignExp);
659 Value
result = arith::BitcastOp::create(
b, resultTy, exp8Bits);
660 rewriter.replaceOp(op,
result);
665struct ScalingExtFOpConverter :
public OpRewritePattern<arith::ScalingExtFOp> {
667 LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
668 PatternRewriter &rewriter)
const final {
669 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
670 Value inputOperand = op.getIn();
671 Value scaleOperand = op.getScale();
672 Type scaleTy = scaleOperand.
getType();
676 scaleETy =
b.getF8E8M0Type();
678 scaleOperand = arith::TruncFOp::create(
b, scaleTy, scaleOperand,
nullptr,
679 op.getFastmathAttr());
682 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
683 return rewriter.notifyMatchFailure(
684 op,
"scaling_extf is using scales of type which can not be converted "
687 Type resultTy = op.getType();
691 arith::ExtFOp::create(
b, resultTy, scaleOperand, op.getFastmathAttr());
693 arith::ExtFOp::create(
b, resultTy, inputOperand, op.getFastmathAttr());
695 arith::MulFOp::create(
b, inputExt, scaleExt, op.getFastmathAttr());
696 rewriter.replaceOp(op,
result);
706struct ScalingTruncFOpConverter
709 LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
710 PatternRewriter &rewriter)
const final {
711 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
712 Value inputOperand = op.getIn();
713 Value scaleOperand = op.getScale();
714 Type scaleTy = scaleOperand.
getType();
718 scaleETy =
b.getF8E8M0Type();
720 scaleOperand = arith::TruncFOp::create(
b, scaleTy, scaleOperand,
nullptr,
721 op.getFastmathAttr());
723 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
724 return rewriter.notifyMatchFailure(
725 op,
"scaling_truncf is using scales type which can not be converted "
728 Type resultTy = op.getType();
729 Type inputTy = inputOperand.
getType();
733 arith::ExtFOp::create(
b, inputTy, scaleOperand, op.getFastmathAttr());
734 Value
result = arith::DivFOp::create(
b, inputOperand, scaleOperand,
735 op.getFastmathAttr());
736 Value resultCast = arith::TruncFOp::create(
737 b, resultTy,
result, op.getRoundingmodeAttr(), op.getFastmathAttr());
738 rewriter.replaceOp(op, resultCast);
759struct FlushDenormalsOpConverter
762 LogicalResult matchAndRewrite(arith::FlushDenormalsOp op,
763 PatternRewriter &rewriter)
const final {
764 Location loc = op.getLoc();
765 ImplicitLocOpBuilder
b(loc, rewriter);
766 Value operand = op.getOperand();
767 Type operandTy = operand.
getType();
770 return rewriter.notifyMatchFailure(op,
"operand is not a float type");
772 const llvm::fltSemantics &sem = floatTy.getFloatSemantics();
775 if (!llvm::APFloatBase::isIEEELikeFP(sem))
776 return rewriter.notifyMatchFailure(
777 op,
"only IEEE-like floating-point types are supported");
779 unsigned totalBits = llvm::APFloatBase::semanticsSizeInBits(sem);
780 unsigned precision = llvm::APFloatBase::semanticsPrecision(sem);
783 if (precision < 1 || precision > totalBits)
784 return rewriter.notifyMatchFailure(op,
"unexpected float semantics");
785 unsigned mantissaBits = precision - 1;
786 unsigned expBits = totalBits - 1 - mantissaBits;
787 if (expBits == 0 || mantissaBits == 0)
788 return rewriter.notifyMatchFailure(
789 op,
"degenerate float encoding has no exponent or mantissa");
793 Value bits = arith::BitcastOp::create(
b, intTy, operand);
795 APInt::getBitsSet(totalBits, mantissaBits, mantissaBits + expBits);
796 APInt clearMantissaMaskVal = ~APInt::getLowBitsSet(totalBits, mantissaBits);
797 APInt zeroVal = APInt::getZero(totalBits);
799 Value clearMantissaMask =
804 Value expField = arith::AndIOp::create(
b, bits, expMask);
806 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, expField, zero);
809 Value cleared = arith::AndIOp::create(
b, bits, clearMantissaMask);
810 Value resultBits = arith::SelectOp::create(
b, expIsZero, cleared, bits);
811 Value
result = arith::BitcastOp::create(
b, operandTy, resultBits);
813 rewriter.replaceOp(op,
result);
818struct ArithExpandOpsPass
819 :
public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
820 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
822 void runOnOperation()
override {
826 arith::populateArithExpandOpsPatterns(patterns);
828 target.addLegalDialect<arith::ArithDialect>();
829 target.addLegalDialect<vector::VectorDialect>();
844 arith::ScalingExtFOp,
845 arith::ScalingTruncFOp
849 arith::populateExpandBFloat16Patterns(patterns);
851 arith::populateExpandF8E8M0Patterns(patterns);
853 arith::populateExpandF4E2M1Patterns(patterns);
854 if (includeFlushDenormals) {
855 arith::populateExpandFlushDenormalsPatterns(patterns);
858 target.addDynamicallyLegalOp<arith::FlushDenormalsOp>(
859 [](arith::FlushDenormalsOp op) {
864 return !llvm::APFloatBase::isIEEELikeFP(
865 floatTy.getFloatSemantics());
869 target.addDynamicallyLegalOp<arith::ExtFOp>(
870 [=](arith::ExtFOp op) {
873 bool legalTypes =
true;
877 legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
879 legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
883 target.addDynamicallyLegalOp<arith::TruncFOp>(
884 [=](arith::TruncFOp op) {
887 bool legalTypes =
true;
891 legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
893 legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
898 if (
failed(applyPartialConversion(getOperation(),
target,
899 std::move(patterns))))
909 .
add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
914 patterns.
add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
919 patterns.
add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
924 patterns.
add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
930 patterns.
add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
936 patterns.
add<FlushDenormalsOpConverter>(patterns.
getContext());
944 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
945 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
946 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
947 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
948 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
949 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
950 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
951 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
static int64_t product(ArrayRef< int64_t > vals)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns)
Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateExpandF4E2M1Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts.
void populateExpandFlushDenormalsPatterns(RewritePatternSet &patterns)
Add patterns to expand arith.flush_denormals into integer arithmetic (bitcast + bit masks + compare +...
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...