24 #define GEN_PASS_DEF_MATHEXPANDOPSPASS
25 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
31 bool losesInfo =
false;
34 value.convert(cast<FloatType>(eltType).getFloatSemantics(),
35 APFloat::rmNearestTiesToEven, &losesInfo);
37 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
38 return arith::ConstantOp::create(b, loc,
42 return arith::ConstantOp::create(b, loc, attr);
54 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
55 return arith::ConstantOp::create(b, loc,
59 return arith::ConstantOp::create(b, loc, attr);
65 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
66 i64Ty = shapedTy.clone(i64Ty);
67 Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
68 Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
71 return math::CopySignOp::create(b, fpFixedConvert, operand);
77 Value operand = op.getOperand();
80 Value exp = math::ExpOp::create(b, operand);
81 Value neg = arith::NegFOp::create(b, operand);
82 Value nexp = math::ExpOp::create(b, neg);
83 Value sub = arith::SubFOp::create(b, exp, nexp);
85 Value res = arith::MulFOp::create(b, sub, half);
93 Value operand = op.getOperand();
96 Value exp = math::ExpOp::create(b, operand);
97 Value neg = arith::NegFOp::create(b, operand);
98 Value nexp = math::ExpOp::create(b, neg);
99 Value add = arith::AddFOp::create(b, exp, nexp);
101 Value res = arith::MulFOp::create(b,
add, half);
115 auto floatType = op.getOperand().getType();
122 Value isNegative = arith::CmpFOp::create(
123 rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
124 Value isNegativeFloat =
125 arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
126 Value isNegativeTimesNegTwo =
127 arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
128 Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
131 Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
134 Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
135 Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
136 Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
137 Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
138 Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
149 Value operand = op.getOperand();
151 Value sin = math::SinOp::create(b, type, operand);
152 Value cos = math::CosOp::create(b, type, operand);
153 Value div = arith::DivFOp::create(b, type, sin, cos);
162 Value operand = op.getOperand();
166 Value fma = math::FmaOp::create(b, operand, operand, one);
167 Value sqrt = math::SqrtOp::create(b, fma);
168 Value add = arith::AddFOp::create(b, operand, sqrt);
169 Value res = math::LogOp::create(b,
add);
178 Value operand = op.getOperand();
182 Value fma = math::FmaOp::create(b, operand, operand, negOne);
183 Value sqrt = math::SqrtOp::create(b, fma);
184 Value add = arith::AddFOp::create(b, operand, sqrt);
185 Value res = math::LogOp::create(b,
add);
194 Value operand = op.getOperand();
198 Value add = arith::AddFOp::create(b, operand, one);
199 Value neg = arith::NegFOp::create(b, operand);
200 Value sub = arith::AddFOp::create(b, neg, one);
201 Value div = arith::DivFOp::create(b,
add, sub);
202 Value log = math::LogOp::create(b, div);
204 Value res = arith::MulFOp::create(b, log, half);
211 Value operandA = op.getOperand(0);
212 Value operandB = op.getOperand(1);
213 Value operandC = op.getOperand(2);
214 Type type = op.getType();
215 Value mult = arith::MulFOp::create(b, type, operandA, operandB);
216 Value add = arith::AddFOp::create(b, type, mult, operandC);
228 auto shapedType = dyn_cast<ShapedType>(op.getType());
229 if (shapedType && !shapedType.hasStaticShape())
233 Value operand = op.getOperand();
241 Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
244 arith::SelectOp::create(b, op->
getLoc(), gtCheck, one, zero);
246 Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
258 Value base = op.getOperand(0);
259 Value power = op.getOperand(1);
262 auto convertFPowItoPowf = [&]() -> LogicalResult {
263 Value castPowerToFp =
264 arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
265 Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
273 return convertFPowItoPowf();
277 return convertFPowItoPowf();
279 int64_t powerInt = value.getSExtValue();
280 bool isNegative = powerInt < 0;
281 int64_t absPower =
std::abs(powerInt);
285 while (absPower > 0) {
287 res = arith::MulFOp::create(b, baseType, base, res);
289 base = arith::MulFOp::create(b, baseType, base, base);
295 .getFloatSemantics();
304 APFloat::getInf(sem,
false), rewriter);
307 APFloat::getInf(sem,
true), rewriter);
309 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
310 Value negZeroEqCheck =
311 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
312 res = arith::DivFOp::create(b, baseType, one, res);
314 arith::SelectOp::create(b, op->
getLoc(), zeroEqCheck, posInfinity, res);
315 res = arith::SelectOp::create(b, op->
getLoc(), negZeroEqCheck, negInfinity,
328 Value operandA = op.getOperand(0);
329 Value operandB = op.getOperand(1);
330 auto typeA = operandA.
getType();
331 auto typeB = operandB.
getType();
337 return arith::MulFOp::create(b, x, y);
340 if (valueB.isZero()) {
346 if (valueB.isExactlyValue(1.0)) {
351 if (valueB.isExactlyValue(-1.0)) {
354 Value div = arith::DivFOp::create(b, one, operandA);
358 if (valueB.isExactlyValue(0.5)) {
360 Value sqrt = math::SqrtOp::create(b, operandA);
364 if (valueB.isExactlyValue(-0.5)) {
366 Value rsqrt = math::RsqrtOp::create(b, operandA);
370 if (valueB.isExactlyValue(2.0)) {
372 rewriter.
replaceOp(op, mulf(operandA, operandA));
375 if (valueB.isExactlyValue(-2.0)) {
379 Value div = arith::DivFOp::create(b, one, mulf(operandA, operandA));
383 if (valueB.isExactlyValue(3.0)) {
384 rewriter.
replaceOp(op, mulf(mulf(operandA, operandA), operandA));
389 Value logA = math::LogOp::create(b, operandA);
390 Value mult = arith::MulFOp::create(b, operandB, logA);
391 Value expResult = math::ExpOp::create(b, mult);
403 Value operand = op.getOperand();
406 Value mult = arith::MulFOp::create(b, opType, operand, ln2);
407 Value exp = math::ExpOp::create(b, op->
getLoc(), mult);
416 Value operand = op.getOperand();
420 if (!opEType.
isF32()) {
425 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
426 i32Ty = shapedTy.clone(i32Ty);
433 Value incrValue = math::CopySignOp::create(b, half, operand);
434 Value add = arith::AddFOp::create(b, opType, operand, incrValue);
457 Value operandBitcast = arith::BitcastOp::create(b, i32Ty, operand);
458 Value operandExp = arith::AndIOp::create(
459 b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
460 Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
461 Value isSpecialValOrLargeVal = arith::CmpIOp::create(
462 b, arith::CmpIPredicate::sge, operandBiasedExp, c23);
464 Value result = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand,
474 auto operand = op.getOperand();
475 auto operandTy = operand.getType();
479 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
483 uint64_t allbits = -1;
485 allbits = allbits >> (64 - bitwidth);
490 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
493 auto mask =
createIntConst(loc, operandTy, allbits >> half, rewriter);
495 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ule,
497 Value add = arith::AddIOp::create(rewriter, loc, count, bits);
498 Value shift = arith::ShLIOp::create(rewriter, loc, x, bits);
500 x = arith::SelectOp::create(rewriter, loc, pred, shift, x);
501 count = arith::SelectOp::create(rewriter, loc, pred,
add, count);
505 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
509 Value sel = arith::SelectOp::create(rewriter, loc, pred, bwval, count);
519 auto operand = op.getOperand();
520 Type operandTy = operand.getType();
521 Type resultTy = op.getType();
525 if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
529 Type fTy = operandTy;
531 if (
auto shapedTy = dyn_cast<ShapedType>(fTy)) {
532 iTy = shapedTy.clone(iTy);
537 unsigned mantissaWidth =
538 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
539 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
556 Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
557 Value round = math::RoundOp::create(b, operand);
558 Value roundBitcast = arith::BitcastOp::create(b, iTy,
round);
561 Value operandExp = arith::AndIOp::create(
562 b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
563 Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
564 Value roundExp = arith::AndIOp::create(
565 b, arith::ShRUIOp::create(b, roundBitcast, c23), expMask);
566 Value roundBiasedExp = arith::SubIOp::create(b, roundExp, c127);
570 Value clampedShift = arith::MaxSIOp::create(b, shift, c0);
571 clampedShift = arith::MinSIOp::create(b, clampedShift, c31);
572 return arith::ShRUIOp::create(b, x, clampedShift);
575 auto maskMantissa = [&](
Value mantissa,
577 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
578 return arith::AndIOp::create(b, mantissa, shiftedMantissaMask);
595 Value roundBiasedExpEq0 =
596 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, roundBiasedExp, c0);
597 Value roundBiasedExpMinus1 = arith::SubIOp::create(b, roundBiasedExp, c1);
598 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
599 Value roundIsNotEvenOrSpecialVal = arith::CmpIOp::create(
600 b, arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
601 roundIsNotEvenOrSpecialVal =
602 arith::OrIOp::create(b, roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
611 Value operandBiasedExpEqNeg1 = arith::CmpIOp::create(
612 b, arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
613 Value expectedOperandMaskedMantissa = arith::SelectOp::create(
614 b, operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
615 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
616 Value operandIsHalfway =
617 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, operandMaskedMantissa,
618 expectedOperandMaskedMantissa);
620 Value operandBiasedExpGeNeg1 = arith::CmpIOp::create(
621 b, arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
622 Value operandBiasedExpLt23 = arith::CmpIOp::create(
623 b, arith::CmpIPredicate::slt, operandBiasedExp, c23);
625 arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpLt23);
627 arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpGeNeg1);
631 Value sign = math::CopySignOp::create(b, c1Float, operand);
632 Value roundShifted = arith::SubFOp::create(b,
round, sign);
636 arith::AndIOp::create(b, roundIsNotEvenOrSpecialVal, operandIsHalfway);
637 Value result = arith::SelectOp::create(b, needsShift, roundShifted,
round);
641 result = math::CopySignOp::create(b, result, operand);
650 auto operand = op.getOperand();
651 auto operandTy = operand.getType();
653 auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
654 if (shapedOperandType && !shapedOperandType.hasStaticShape())
658 if (!isa<FloatType>(eTy))
663 auto sqrtOp = math::SqrtOp::create(rewriter, loc, operand);
671 auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
672 op.getMin(), op.getFastmath());
680 auto filter = [&](StringRef name) {
684 assert(
"math" == MathDialect::getDialectNamespace());
685 name.consume_front(
"math.");
686 return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0);
688 if (filter(CountLeadingZerosOp::getOperationName()))
690 if (filter(SinhOp::getOperationName()))
692 if (filter(CoshOp::getOperationName()))
694 if (filter(TanOp::getOperationName()))
696 if (filter(TanhOp::getOperationName()))
698 if (filter(AsinhOp::getOperationName()))
700 if (filter(AcoshOp::getOperationName()))
702 if (filter(AtanhOp::getOperationName()))
704 if (filter(FmaOp::getOperationName()))
706 if (filter(CeilOp::getOperationName()))
708 if (filter(Exp2Op::getOperationName()))
710 if (filter(PowFOp::getOperationName()))
712 if (filter(FPowIOp::getOperationName()))
714 if (filter(RoundOp::getOperationName()))
716 if (filter(RoundEvenOp::getOperationName()))
718 if (filter(RsqrtOp::getOperationName()))
720 if (filter(ClampFOp::getOperationName()))
728 struct MathExpandOpsPass final
729 : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
730 using MathExpandOpsPassBase::MathExpandOpsPassBase;
732 void runOnOperation()
override {
735 llvm::to_vector_of<StringRef>(opMnemonics);
738 return signalPassFailure();
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateExpansionPatterns(RewritePatternSet &patterns, ArrayRef< StringRef > opMnemonics={})
Adds patterns to expand math operations into other more fundamental operations.
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...