24#define GEN_PASS_DEF_MATHEXPANDOPSPASS
25#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
31 bool losesInfo =
false;
34 value.convert(cast<FloatType>(eltType).getFloatSemantics(),
35 APFloat::rmNearestTiesToEven, &losesInfo);
36 auto attr =
b.getFloatAttr(eltType, value);
37 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
38 return arith::ConstantOp::create(
b, loc,
42 return arith::ConstantOp::create(
b, loc, attr);
54 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
55 return arith::ConstantOp::create(
b, loc,
59 return arith::ConstantOp::create(
b, loc, attr);
64 Type i64Ty =
b.getI64Type();
65 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
66 i64Ty = shapedTy.clone(i64Ty);
67 Value fixedConvert = arith::FPToSIOp::create(
b, i64Ty, operand);
68 Value fpFixedConvert = arith::SIToFPOp::create(
b, opType, fixedConvert);
71 return math::CopySignOp::create(
b, fpFixedConvert, operand);
77 Value operand = op.getOperand();
80 Value exp = math::ExpOp::create(
b, operand);
81 Value neg = arith::NegFOp::create(
b, operand);
82 Value nexp = math::ExpOp::create(
b, neg);
83 Value sub = arith::SubFOp::create(
b, exp, nexp);
85 Value res = arith::MulFOp::create(
b, sub, half);
93 Value operand = op.getOperand();
96 Value exp = math::ExpOp::create(
b, operand);
97 Value neg = arith::NegFOp::create(
b, operand);
98 Value nexp = math::ExpOp::create(
b, neg);
99 Value add = arith::AddFOp::create(
b, exp, nexp);
101 Value res = arith::MulFOp::create(
b,
add, half);
115 auto floatType = op.getOperand().getType();
122 Value isNegative = arith::CmpFOp::create(
123 rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
124 Value isNegativeFloat =
125 arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
126 Value isNegativeTimesNegTwo =
127 arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
128 Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
131 Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
134 Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
135 Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
136 Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
137 Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
138 Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
149 Value operand = op.getOperand();
151 Value sin = math::SinOp::create(
b, type, operand);
152 Value cos = math::CosOp::create(
b, type, operand);
153 Value div = arith::DivFOp::create(
b, type, sin, cos);
162 Value operand = op.getOperand();
166 Value fma = math::FmaOp::create(
b, operand, operand, one);
167 Value sqrt = math::SqrtOp::create(
b, fma);
168 Value add = arith::AddFOp::create(
b, operand, sqrt);
178 Value operand = op.getOperand();
182 Value fma = math::FmaOp::create(
b, operand, operand, negOne);
183 Value sqrt = math::SqrtOp::create(
b, fma);
184 Value add = arith::AddFOp::create(
b, operand, sqrt);
194 Value operand = op.getOperand();
198 Value add = arith::AddFOp::create(
b, operand, one);
199 Value neg = arith::NegFOp::create(
b, operand);
200 Value sub = arith::AddFOp::create(
b, neg, one);
204 Value res = arith::MulFOp::create(
b, log, half);
211 Value operandA = op.getOperand(0);
212 Value operandB = op.getOperand(1);
213 Value operandC = op.getOperand(2);
214 Type type = op.getType();
215 Value mult = arith::MulFOp::create(
b, type, operandA, operandB);
216 Value add = arith::AddFOp::create(
b, type, mult, operandC);
228 auto shapedType = dyn_cast<ShapedType>(op.getType());
229 if (shapedType && !shapedType.hasStaticShape())
233 Value operand = op.getOperand();
236 FloatType floatTy = llvm::dyn_cast<FloatType>(operandETy);
237 const llvm::fltSemantics &semantics = floatTy.getFloatSemantics();
239 unsigned bitWidth = floatTy.getWidth();
240 unsigned mantissaWidth = floatTy.getFPMantissaWidth() - 1;
241 const int bias = (&semantics == &APFloat::Float8E8M0FNU())
242 ? -semantics.minExponent
243 : -(semantics.minExponent - 1);
244 bool hasNegativeZeroNaNEncoding =
245 (semantics.nanEncoding == llvm::fltNanEncoding::NegativeZero);
248 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
249 iTy = shapedTy.clone(iTy);
258 Value operandBitcast = arith::BitcastOp::create(
b, iTy, operand);
260 op->getLoc(), iTy,
static_cast<int64_t>((1ull << (bitWidth - 1)) - 1),
b);
261 Value unsignedBits = arith::AndIOp::create(
b, operandBitcast, cMask);
264 static_cast<int64_t>((uint64_t(bias + mantissaWidth)) << mantissaWidth),
266 Value isLargeExp = arith::CmpIOp::create(
b, arith::CmpIPredicate::uge,
267 unsignedBits, cThreshold);
268 Value isSpecialValOrLargeVal = isLargeExp;
272 if (hasNegativeZeroNaNEncoding) {
274 op->getLoc(), iTy,
static_cast<int64_t>(1ull << (bitWidth - 1)),
b);
275 Value isNegZeroEncoding = arith::CmpIOp::create(
276 b, arith::CmpIPredicate::eq, operandBitcast, cNegZeroBits);
277 isSpecialValOrLargeVal =
278 arith::OrIOp::create(
b, isLargeExp, isNegZeroEncoding);
287 Value gtCheck = arith::CmpFOp::create(
b, arith::CmpFPredicate::OGT, operand,
290 arith::SelectOp::create(
b, op->getLoc(), gtCheck, one, zero);
292 Value add = arith::AddFOp::create(
b, opType, fpFixedConvert, incrValue);
293 Value ret = arith::SelectOp::create(
b, isSpecialValOrLargeVal, operand,
add);
305 Value base = op.getOperand(0);
306 Value power = op.getOperand(1);
309 auto convertFPowItoPowf = [&]() -> LogicalResult {
310 Value castPowerToFp =
311 arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
312 Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
320 return convertFPowItoPowf();
324 return convertFPowItoPowf();
326 int64_t powerInt = value.getSExtValue();
327 bool isNegative = powerInt < 0;
328 int64_t absPower = std::abs(powerInt);
332 while (absPower > 0) {
334 res = arith::MulFOp::create(
b, baseType, base, res);
336 base = arith::MulFOp::create(
b, baseType, base, base);
342 .getFloatSemantics();
345 APFloat::getZero(sem,
false), rewriter);
348 APFloat::getZero(sem,
true), rewriter);
351 APFloat::getInf(sem,
false), rewriter);
354 APFloat::getInf(sem,
true), rewriter);
356 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ, res, zero);
357 Value negZeroEqCheck =
358 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ, res, negZero);
359 res = arith::DivFOp::create(
b, baseType, one, res);
361 arith::SelectOp::create(
b, op->getLoc(), zeroEqCheck, posInfinity, res);
362 res = arith::SelectOp::create(
b, op->getLoc(), negZeroEqCheck, negInfinity,
375 Value operandA = op.getOperand(0);
376 Value operandB = op.getOperand(1);
377 auto typeA = operandA.
getType();
378 auto typeB = operandB.
getType();
384 return arith::MulFOp::create(
b, x, y);
387 if (valueB.isZero()) {
393 if (valueB.isExactlyValue(1.0)) {
398 if (valueB.isExactlyValue(-1.0)) {
401 Value div = arith::DivFOp::create(
b, one, operandA);
405 if (valueB.isExactlyValue(0.5)) {
407 Value sqrt = math::SqrtOp::create(
b, operandA);
411 if (valueB.isExactlyValue(-0.5)) {
413 Value rsqrt = math::RsqrtOp::create(
b, operandA);
417 if (valueB.isExactlyValue(2.0)) {
419 rewriter.
replaceOp(op, mulf(operandA, operandA));
422 if (valueB.isExactlyValue(-2.0)) {
426 Value div = arith::DivFOp::create(
b, one, mulf(operandA, operandA));
430 if (valueB.isExactlyValue(3.0)) {
431 rewriter.
replaceOp(op, mulf(mulf(operandA, operandA), operandA));
436 Value logA = math::LogOp::create(
b, operandA);
437 Value mult = arith::MulFOp::create(
b, operandB, logA);
438 Value expResult = math::ExpOp::create(
b, mult);
450 Value operand = op.getOperand();
453 Value mult = arith::MulFOp::create(
b, opType, operand, ln2);
454 Value exp = math::ExpOp::create(
b, op->getLoc(), mult);
463 Value operand = op.getOperand();
467 if (!opEType.
isF32()) {
471 Type i32Ty =
b.getI32Type();
472 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
473 i32Ty = shapedTy.clone(i32Ty);
480 Value incrValue = math::CopySignOp::create(
b, half, operand);
481 Value add = arith::AddFOp::create(
b, opType, operand, incrValue);
504 Value operandBitcast = arith::BitcastOp::create(
b, i32Ty, operand);
505 Value operandExp = arith::AndIOp::create(
506 b, arith::ShRUIOp::create(
b, operandBitcast, c23), expMask);
507 Value operandBiasedExp = arith::SubIOp::create(
b, operandExp, c127);
508 Value isSpecialValOrLargeVal = arith::CmpIOp::create(
509 b, arith::CmpIPredicate::sge, operandBiasedExp, c23);
511 Value result = arith::SelectOp::create(
b, isSpecialValOrLargeVal, operand,
521 auto operand = op.getOperand();
522 auto operandTy = operand.getType();
527 if (!eTy.isIntOrFloat()) {
528 return rewriter.
notifyMatchFailure(op,
"ctlz expansion only supports int or float types");
531 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
535 uint64_t allbits = -1;
537 allbits = allbits >> (64 - bitwidth);
542 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
545 auto mask =
createIntConst(loc, operandTy, allbits >> half, rewriter);
547 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ule,
549 Value add = arith::AddIOp::create(rewriter, loc, count, bits);
550 Value shift = arith::ShLIOp::create(rewriter, loc, x, bits);
552 x = arith::SelectOp::create(rewriter, loc, pred, shift, x);
553 count = arith::SelectOp::create(rewriter, loc, pred,
add, count);
557 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
561 Value sel = arith::SelectOp::create(rewriter, loc, pred, bwval, count);
571 auto operand = op.getOperand();
572 Type operandTy = operand.getType();
573 Type resultTy = op.getType();
577 if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
581 Type fTy = operandTy;
583 if (
auto shapedTy = dyn_cast<ShapedType>(fTy)) {
584 iTy = shapedTy.clone(iTy);
589 unsigned mantissaWidth =
590 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
591 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
608 Value operandBitcast = arith::BitcastOp::create(
b, iTy, operand);
609 Value round = math::RoundOp::create(
b, operand);
610 Value roundBitcast = arith::BitcastOp::create(
b, iTy, round);
613 Value operandExp = arith::AndIOp::create(
614 b, arith::ShRUIOp::create(
b, operandBitcast, c23), expMask);
615 Value operandBiasedExp = arith::SubIOp::create(
b, operandExp, c127);
616 Value roundExp = arith::AndIOp::create(
617 b, arith::ShRUIOp::create(
b, roundBitcast, c23), expMask);
618 Value roundBiasedExp = arith::SubIOp::create(
b, roundExp, c127);
622 Value clampedShift = arith::MaxSIOp::create(
b, shift, c0);
623 clampedShift = arith::MinSIOp::create(
b, clampedShift, c31);
624 return arith::ShRUIOp::create(
b, x, clampedShift);
627 auto maskMantissa = [&](
Value mantissa,
629 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
630 return arith::AndIOp::create(
b, mantissa, shiftedMantissaMask);
647 Value roundBiasedExpEq0 =
648 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, roundBiasedExp, c0);
649 Value roundBiasedExpMinus1 = arith::SubIOp::create(
b, roundBiasedExp, c1);
650 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
651 Value roundIsNotEvenOrSpecialVal = arith::CmpIOp::create(
652 b, arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
653 roundIsNotEvenOrSpecialVal =
654 arith::OrIOp::create(
b, roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
663 Value operandBiasedExpEqNeg1 = arith::CmpIOp::create(
664 b, arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
665 Value expectedOperandMaskedMantissa = arith::SelectOp::create(
666 b, operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
667 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
668 Value operandIsHalfway =
669 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, operandMaskedMantissa,
670 expectedOperandMaskedMantissa);
672 Value operandBiasedExpGeNeg1 = arith::CmpIOp::create(
673 b, arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
674 Value operandBiasedExpLt23 = arith::CmpIOp::create(
675 b, arith::CmpIPredicate::slt, operandBiasedExp, c23);
677 arith::AndIOp::create(
b, operandIsHalfway, operandBiasedExpLt23);
679 arith::AndIOp::create(
b, operandIsHalfway, operandBiasedExpGeNeg1);
683 Value sign = math::CopySignOp::create(
b, c1Float, operand);
684 Value roundShifted = arith::SubFOp::create(
b, round, sign);
688 arith::AndIOp::create(
b, roundIsNotEvenOrSpecialVal, operandIsHalfway);
689 Value result = arith::SelectOp::create(
b, needsShift, roundShifted, round);
702 auto operand = op.getOperand();
703 auto operandTy = operand.getType();
705 auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
706 if (shapedOperandType && !shapedOperandType.hasStaticShape())
710 if (!isa<FloatType>(eTy))
715 auto sqrtOp = math::SqrtOp::create(rewriter, loc, operand);
723 auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
724 op.getMax(), op.getFastmath());
732 auto filter = [&](StringRef name) {
736 assert(
"math" == MathDialect::getDialectNamespace());
737 name.consume_front(
"math.");
738 return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0);
740 if (filter(CountLeadingZerosOp::getOperationName()))
742 if (filter(SinhOp::getOperationName()))
744 if (filter(CoshOp::getOperationName()))
746 if (filter(TanOp::getOperationName()))
748 if (filter(TanhOp::getOperationName()))
750 if (filter(AsinhOp::getOperationName()))
752 if (filter(AcoshOp::getOperationName()))
754 if (filter(AtanhOp::getOperationName()))
756 if (filter(FmaOp::getOperationName()))
758 if (filter(CeilOp::getOperationName()))
760 if (filter(Exp2Op::getOperationName()))
762 if (filter(PowFOp::getOperationName()))
764 if (filter(FPowIOp::getOperationName()))
766 if (filter(RoundOp::getOperationName()))
768 if (filter(RoundEvenOp::getOperationName()))
770 if (filter(RsqrtOp::getOperationName()))
772 if (filter(ClampFOp::getOperationName()))
780struct MathExpandOpsPass final
781 : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
782 using MathExpandOpsPassBase::MathExpandOpsPassBase;
784 void runOnOperation()
override {
787 llvm::to_vector_of<StringRef>(opMnemonics);
790 return signalPassFailure();
Attributes are known-constant values of operations.
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacements).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
void populateExpansionPatterns(RewritePatternSet &patterns, ArrayRef< StringRef > opMnemonics={})
Adds patterns to expand math operations into other more fundamental operations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner until a fixpoint is reached.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_value.