27 bool losesInfo =
false;
30 value.convert(cast<FloatType>(eltType).getFloatSemantics(),
31 APFloat::rmNearestTiesToEven, &losesInfo);
33 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
34 return arith::ConstantOp::create(b, loc,
38 return arith::ConstantOp::create(b, loc, attr);
50 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
51 return arith::ConstantOp::create(b, loc,
55 return arith::ConstantOp::create(b, loc, attr);
61 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
62 i64Ty = shapedTy.clone(i64Ty);
63 Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
64 Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
67 return math::CopySignOp::create(b, fpFixedConvert, operand);
73 Value operand = op.getOperand();
76 Value exp = math::ExpOp::create(b, operand);
77 Value neg = arith::NegFOp::create(b, operand);
78 Value nexp = math::ExpOp::create(b, neg);
79 Value sub = arith::SubFOp::create(b, exp, nexp);
81 Value res = arith::MulFOp::create(b, sub, half);
89 Value operand = op.getOperand();
92 Value exp = math::ExpOp::create(b, operand);
93 Value neg = arith::NegFOp::create(b, operand);
94 Value nexp = math::ExpOp::create(b, neg);
95 Value add = arith::AddFOp::create(b, exp, nexp);
97 Value res = arith::MulFOp::create(b,
add, half);
111 auto floatType = op.getOperand().getType();
118 Value isNegative = arith::CmpFOp::create(
119 rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
120 Value isNegativeFloat =
121 arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
122 Value isNegativeTimesNegTwo =
123 arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
124 Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
127 Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
130 Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
131 Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
132 Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
133 Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
134 Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
145 Value operand = op.getOperand();
147 Value sin = math::SinOp::create(b, type, operand);
148 Value cos = math::CosOp::create(b, type, operand);
149 Value div = arith::DivFOp::create(b, type, sin, cos);
158 Value operand = op.getOperand();
162 Value fma = math::FmaOp::create(b, operand, operand, one);
163 Value sqrt = math::SqrtOp::create(b, fma);
164 Value add = arith::AddFOp::create(b, operand, sqrt);
165 Value res = math::LogOp::create(b,
add);
174 Value operand = op.getOperand();
178 Value fma = math::FmaOp::create(b, operand, operand, negOne);
179 Value sqrt = math::SqrtOp::create(b, fma);
180 Value add = arith::AddFOp::create(b, operand, sqrt);
181 Value res = math::LogOp::create(b,
add);
190 Value operand = op.getOperand();
194 Value add = arith::AddFOp::create(b, operand, one);
195 Value neg = arith::NegFOp::create(b, operand);
196 Value sub = arith::AddFOp::create(b, neg, one);
197 Value div = arith::DivFOp::create(b,
add, sub);
198 Value log = math::LogOp::create(b, div);
200 Value res = arith::MulFOp::create(b, log, half);
207 Value operandA = op.getOperand(0);
208 Value operandB = op.getOperand(1);
209 Value operandC = op.getOperand(2);
210 Type type = op.getType();
211 Value mult = arith::MulFOp::create(b, type, operandA, operandB);
212 Value add = arith::AddFOp::create(b, type, mult, operandC);
224 auto shapedType = dyn_cast<ShapedType>(op.getType());
225 if (shapedType && !shapedType.hasStaticShape())
229 Value operand = op.getOperand();
237 Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
240 arith::SelectOp::create(b, op->
getLoc(), gtCheck, one, zero);
242 Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
254 Value base = op.getOperand(0);
255 Value power = op.getOperand(1);
258 auto convertFPowItoPowf = [&]() -> LogicalResult {
259 Value castPowerToFp =
260 arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
261 Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
269 return convertFPowItoPowf();
273 return convertFPowItoPowf();
275 int64_t powerInt = value.getSExtValue();
276 bool isNegative = powerInt < 0;
277 int64_t absPower =
std::abs(powerInt);
281 while (absPower > 0) {
283 res = arith::MulFOp::create(b, baseType, base, res);
285 base = arith::MulFOp::create(b, baseType, base, base);
291 .getFloatSemantics();
300 APFloat::getInf(sem,
false), rewriter);
303 APFloat::getInf(sem,
true), rewriter);
305 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
306 Value negZeroEqCheck =
307 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
308 res = arith::DivFOp::create(b, baseType, one, res);
310 arith::SelectOp::create(b, op->
getLoc(), zeroEqCheck, posInfinity, res);
311 res = arith::SelectOp::create(b, op->
getLoc(), negZeroEqCheck, negInfinity,
324 Value operandA = op.getOperand(0);
325 Value operandB = op.getOperand(1);
326 auto typeA = operandA.
getType();
327 auto typeB = operandB.
getType();
333 return arith::MulFOp::create(b, x, y);
336 if (valueB.isZero()) {
342 if (valueB.isExactlyValue(1.0)) {
347 if (valueB.isExactlyValue(-1.0)) {
350 Value div = arith::DivFOp::create(b, one, operandA);
354 if (valueB.isExactlyValue(0.5)) {
356 Value sqrt = math::SqrtOp::create(b, operandA);
360 if (valueB.isExactlyValue(-0.5)) {
362 Value rsqrt = math::RsqrtOp::create(b, operandA);
366 if (valueB.isExactlyValue(2.0)) {
368 rewriter.
replaceOp(op, mulf(operandA, operandA));
371 if (valueB.isExactlyValue(-2.0)) {
375 Value div = arith::DivFOp::create(b, one, mulf(operandA, operandA));
379 if (valueB.isExactlyValue(3.0)) {
380 rewriter.
replaceOp(op, mulf(mulf(operandA, operandA), operandA));
385 Value logA = math::LogOp::create(b, operandA);
386 Value mult = arith::MulFOp::create(b, operandB, logA);
387 Value expResult = math::ExpOp::create(b, mult);
399 Value operand = op.getOperand();
402 Value mult = arith::MulFOp::create(b, opType, operand, ln2);
403 Value exp = math::ExpOp::create(b, op->
getLoc(), mult);
412 Value operand = op.getOperand();
416 if (!opEType.
isF32()) {
421 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
422 i32Ty = shapedTy.clone(i32Ty);
429 Value incrValue = math::CopySignOp::create(b, half, operand);
430 Value add = arith::AddFOp::create(b, opType, operand, incrValue);
453 Value operandBitcast = arith::BitcastOp::create(b, i32Ty, operand);
454 Value operandExp = arith::AndIOp::create(
455 b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
456 Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
457 Value isSpecialValOrLargeVal = arith::CmpIOp::create(
458 b, arith::CmpIPredicate::sge, operandBiasedExp, c23);
460 Value result = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand,
470 auto operand = op.getOperand();
471 auto operandTy = operand.getType();
475 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
479 uint64_t allbits = -1;
481 allbits = allbits >> (64 - bitwidth);
486 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
489 auto mask =
createIntConst(loc, operandTy, allbits >> half, rewriter);
491 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ule,
493 Value add = arith::AddIOp::create(rewriter, loc, count, bits);
494 Value shift = arith::ShLIOp::create(rewriter, loc, x, bits);
496 x = arith::SelectOp::create(rewriter, loc, pred, shift, x);
497 count = arith::SelectOp::create(rewriter, loc, pred,
add, count);
501 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
505 Value sel = arith::SelectOp::create(rewriter, loc, pred, bwval, count);
515 auto operand = op.getOperand();
516 Type operandTy = operand.getType();
517 Type resultTy = op.getType();
521 if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
525 Type fTy = operandTy;
527 if (
auto shapedTy = dyn_cast<ShapedType>(fTy)) {
528 iTy = shapedTy.clone(iTy);
533 unsigned mantissaWidth =
534 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
535 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
552 Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
553 Value round = math::RoundOp::create(b, operand);
554 Value roundBitcast = arith::BitcastOp::create(b, iTy,
round);
557 Value operandExp = arith::AndIOp::create(
558 b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
559 Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
560 Value roundExp = arith::AndIOp::create(
561 b, arith::ShRUIOp::create(b, roundBitcast, c23), expMask);
562 Value roundBiasedExp = arith::SubIOp::create(b, roundExp, c127);
566 Value clampedShift = arith::MaxSIOp::create(b, shift, c0);
567 clampedShift = arith::MinSIOp::create(b, clampedShift, c31);
568 return arith::ShRUIOp::create(b, x, clampedShift);
571 auto maskMantissa = [&](
Value mantissa,
573 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
574 return arith::AndIOp::create(b, mantissa, shiftedMantissaMask);
591 Value roundBiasedExpEq0 =
592 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, roundBiasedExp, c0);
593 Value roundBiasedExpMinus1 = arith::SubIOp::create(b, roundBiasedExp, c1);
594 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
595 Value roundIsNotEvenOrSpecialVal = arith::CmpIOp::create(
596 b, arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
597 roundIsNotEvenOrSpecialVal =
598 arith::OrIOp::create(b, roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
607 Value operandBiasedExpEqNeg1 = arith::CmpIOp::create(
608 b, arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
609 Value expectedOperandMaskedMantissa = arith::SelectOp::create(
610 b, operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
611 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
612 Value operandIsHalfway =
613 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, operandMaskedMantissa,
614 expectedOperandMaskedMantissa);
616 Value operandBiasedExpGeNeg1 = arith::CmpIOp::create(
617 b, arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
618 Value operandBiasedExpLt23 = arith::CmpIOp::create(
619 b, arith::CmpIPredicate::slt, operandBiasedExp, c23);
621 arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpLt23);
623 arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpGeNeg1);
627 Value sign = math::CopySignOp::create(b, c1Float, operand);
628 Value roundShifted = arith::SubFOp::create(b,
round, sign);
632 arith::AndIOp::create(b, roundIsNotEvenOrSpecialVal, operandIsHalfway);
633 Value result = arith::SelectOp::create(b, needsShift, roundShifted,
round);
637 result = math::CopySignOp::create(b, result, operand);
646 auto operand = op.getOperand();
647 auto operandTy = operand.getType();
649 auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
650 if (shapedOperandType && !shapedOperandType.hasStaticShape())
654 if (!isa<FloatType>(eTy))
659 auto sqrtOp = math::SqrtOp::create(rewriter, loc, operand);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)
Create a float constant.
static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into 1-exp^{-2x} / 1+exp^{-2x} To avoid overflow we exploit the reflection symmetry t...
static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void populateExpandSinhPattern(RewritePatternSet &patterns)
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
void populateExpandRsqrtPattern(RewritePatternSet &patterns)
void populateExpandTanhPattern(RewritePatternSet &patterns)
void populateExpandFmaFPattern(RewritePatternSet &patterns)
void populateExpandAcoshPattern(RewritePatternSet &patterns)
void populateExpandFPowIPattern(RewritePatternSet &patterns)
void populateExpandPowFPattern(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateExpandTanPattern(RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
void populateExpandCoshPattern(RewritePatternSet &patterns)
void populateExpandRoundFPattern(RewritePatternSet &patterns)
void populateExpandExp2FPattern(RewritePatternSet &patterns)
void populateExpandCeilFPattern(RewritePatternSet &patterns)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void populateExpandCtlzPattern(RewritePatternSet &patterns)
void populateExpandAsinhPattern(RewritePatternSet &patterns)
void populateExpandRoundEvenPattern(RewritePatternSet &patterns)
void populateExpandAtanhPattern(RewritePatternSet &patterns)
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...