28 bool losesInfo =
false;
31 value.convert(cast<FloatType>(eltType).getFloatSemantics(),
32 APFloat::rmNearestTiesToEven, &losesInfo);
34 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
35 return b.
create<arith::ConstantOp>(loc,
39 return b.
create<arith::ConstantOp>(loc, attr);
51 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
52 return b.
create<arith::ConstantOp>(loc,
56 return b.
create<arith::ConstantOp>(loc, attr);
62 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
63 i64Ty = shapedTy.clone(i64Ty);
64 Value fixedConvert = b.
create<arith::FPToSIOp>(i64Ty, operand);
65 Value fpFixedConvert = b.
create<arith::SIToFPOp>(opType, fixedConvert);
68 return b.
create<math::CopySignOp>(fpFixedConvert, operand);
74 Value operand = op.getOperand();
90 Value operand = op.getOperand();
112 auto floatType = op.getOperand().getType();
120 loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
121 Value isNegativeFloat =
122 rewriter.
create<arith::UIToFPOp>(loc, floatType, isNegative);
123 Value isNegativeTimesNegTwo =
124 rewriter.
create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
125 Value sign = rewriter.
create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
128 Value positiveX = rewriter.
create<arith::MulFOp>(loc, sign, op.getOperand());
131 Value negDoubledX = rewriter.
create<arith::MulFOp>(loc, negTwo, positiveX);
132 Value exp2x = rewriter.
create<math::ExpOp>(loc, negDoubledX);
133 Value dividend = rewriter.
create<arith::SubFOp>(loc, one, exp2x);
134 Value divisor = rewriter.
create<arith::AddFOp>(loc, one, exp2x);
135 Value positiveRes = rewriter.
create<arith::DivFOp>(loc, dividend, divisor);
146 Value operand = op.getOperand();
150 Value div = b.
create<arith::DivFOp>(type, sin, cos);
159 Value operand = op.getOperand();
163 Value fma = b.
create<math::FmaOp>(operand, operand, one);
165 Value add = b.
create<arith::AddFOp>(operand, sqrt);
175 Value operand = op.getOperand();
179 Value fma = b.
create<math::FmaOp>(operand, operand, negOne);
181 Value add = b.
create<arith::AddFOp>(operand, sqrt);
191 Value operand = op.getOperand();
208 Value operandA = op.getOperand(0);
209 Value operandB = op.getOperand(1);
210 Value operandC = op.getOperand(2);
211 Type type = op.getType();
212 Value mult = b.
create<arith::MulFOp>(type, operandA, operandB);
213 Value add = b.
create<arith::AddFOp>(type, mult, operandC);
226 Value operand = op.getOperand();
235 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
237 b.
create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero);
238 Value ret = b.
create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
250 Value operand = op.getOperand();
258 Value gtCheck = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
260 Value incrValue = b.
create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
262 Value ret = b.
create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
274 Value base = op.getOperand(0);
275 Value power = op.getOperand(1);
278 auto convertFPowItoPowf = [&]() -> LogicalResult {
279 Value castPowerToFp =
280 rewriter.
create<arith::SIToFPOp>(op.getLoc(), baseType, power);
281 Value res = rewriter.
create<math::PowFOp>(op.getLoc(), baseType, base,
289 return convertFPowItoPowf();
293 return convertFPowItoPowf();
295 int64_t powerInt = value.getSExtValue();
296 bool isNegative = powerInt < 0;
297 int64_t absPower =
std::abs(powerInt);
301 while (absPower > 0) {
303 res = b.
create<arith::MulFOp>(baseType, base, res);
305 base = b.
create<arith::MulFOp>(baseType, base, base);
311 .getFloatSemantics();
320 APFloat::getInf(sem,
false), rewriter);
323 APFloat::getInf(sem,
true), rewriter);
325 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
326 Value negZeroEqCheck =
327 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
328 res = b.
create<arith::DivFOp>(baseType, one, res);
330 b.
create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
331 res = b.
create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
342 Value operandA = op.getOperand(0);
343 Value operandB = op.getOperand(1);
349 Value opASquared = b.
create<arith::MulFOp>(opType, operandA, operandA);
350 Value opBHalf = b.
create<arith::DivFOp>(opType, operandB, two);
352 Value logA = b.
create<math::LogOp>(opType, opASquared);
353 Value mult = b.
create<arith::MulFOp>(opType, opBHalf, logA);
354 Value expResult = b.
create<math::ExpOp>(opType, mult);
355 Value negExpResult = b.
create<arith::MulFOp>(opType, expResult, negOne);
356 Value remainder = b.
create<arith::RemFOp>(opType, operandB, two);
358 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
360 b.
create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
361 Value oddAndNeg = b.
create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
368 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
369 Value res = b.
create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
371 res = b.
create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
383 Value operand = op.getOperand();
386 Value mult = b.
create<arith::MulFOp>(opType, operand, ln2);
387 Value exp = b.
create<math::ExpOp>(op->getLoc(), mult);
396 Value operand = op.getOperand();
400 if (!opEType.
isF32()) {
405 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
406 i32Ty = shapedTy.clone(i32Ty);
413 Value incrValue = b.
create<math::CopySignOp>(half, operand);
414 Value add = b.
create<arith::AddFOp>(opType, operand, incrValue);
437 Value operandBitcast = b.
create<arith::BitcastOp>(i32Ty, operand);
439 b.
create<arith::ShRUIOp>(operandBitcast, c23), expMask);
440 Value operandBiasedExp = b.
create<arith::SubIOp>(operandExp, c127);
441 Value isSpecialValOrLargeVal =
442 b.
create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
444 Value result = b.
create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
454 auto operand = op.getOperand();
455 auto operandTy = operand.getType();
459 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
463 uint64_t allbits = -1;
465 allbits = allbits >> (64 - bitwidth);
470 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
473 auto mask =
createIntConst(loc, operandTy, allbits >> half, rewriter);
476 rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
477 Value add = rewriter.
create<arith::AddIOp>(loc, count, bits);
478 Value shift = rewriter.
create<arith::ShLIOp>(loc, x, bits);
480 x = rewriter.
create<arith::SelectOp>(loc, pred, shift, x);
481 count = rewriter.
create<arith::SelectOp>(loc, pred, add, count);
485 Value pred = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
489 Value sel = rewriter.
create<arith::SelectOp>(loc, pred, bwval, count);
499 auto operand = op.getOperand();
500 Type operandTy = operand.getType();
501 Type resultTy = op.getType();
505 if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
509 Type fTy = operandTy;
511 if (
auto shapedTy = dyn_cast<ShapedType>(fTy)) {
512 iTy = shapedTy.clone(iTy);
517 unsigned mantissaWidth =
518 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
519 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
536 Value operandBitcast = b.
create<arith::BitcastOp>(iTy, operand);
542 b.
create<arith::ShRUIOp>(operandBitcast, c23), expMask);
543 Value operandBiasedExp = b.
create<arith::SubIOp>(operandExp, c127);
545 b.
create<arith::ShRUIOp>(roundBitcast, c23), expMask);
546 Value roundBiasedExp = b.
create<arith::SubIOp>(roundExp, c127);
550 Value clampedShift = b.
create<arith::MaxSIOp>(shift, c0);
551 clampedShift = b.
create<arith::MinSIOp>(clampedShift, c31);
552 return b.
create<arith::ShRUIOp>(x, clampedShift);
555 auto maskMantissa = [&](
Value mantissa,
557 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
558 return b.
create<arith::AndIOp>(mantissa, shiftedMantissaMask);
575 Value roundBiasedExpEq0 =
576 b.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
577 Value roundBiasedExpMinus1 = b.
create<arith::SubIOp>(roundBiasedExp, c1);
578 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
579 Value roundIsNotEvenOrSpecialVal = b.
create<arith::CmpIOp>(
580 arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
581 roundIsNotEvenOrSpecialVal =
582 b.
create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
591 Value operandBiasedExpEqNeg1 = b.
create<arith::CmpIOp>(
592 arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
593 Value expectedOperandMaskedMantissa = b.
create<arith::SelectOp>(
594 operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
595 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
596 Value operandIsHalfway =
597 b.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
598 expectedOperandMaskedMantissa);
600 Value operandBiasedExpGeNeg1 = b.
create<arith::CmpIOp>(
601 arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
602 Value operandBiasedExpLt23 =
603 b.
create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
605 b.
create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
607 b.
create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
611 Value sign = b.
create<math::CopySignOp>(c1Float, operand);
616 b.
create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
621 result = b.
create<math::CopySignOp>(result, operand);
630 auto operand = op.getOperand();
631 auto operandTy = operand.getType();
633 if (!isa<FloatType>(eTy))
638 auto sqrtOp = rewriter.
create<math::SqrtOp>(loc, operand);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertFloorOp(math::FloorOp op, PatternRewriter &rewriter)
static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)
Create a float constant.
static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into 1-exp^{-2x} / 1+exp^{-2x} To avoid overflow we exploit the reflection symmetry t...
static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void populateExpandSinhPattern(RewritePatternSet &patterns)
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
void populateExpandRsqrtPattern(RewritePatternSet &patterns)
void populateExpandTanhPattern(RewritePatternSet &patterns)
void populateExpandFmaFPattern(RewritePatternSet &patterns)
void populateExpandAcoshPattern(RewritePatternSet &patterns)
void populateExpandFPowIPattern(RewritePatternSet &patterns)
void populateExpandPowFPattern(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateExpandTanPattern(RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
void populateExpandCoshPattern(RewritePatternSet &patterns)
void populateExpandRoundFPattern(RewritePatternSet &patterns)
void populateExpandExp2FPattern(RewritePatternSet &patterns)
void populateExpandCeilFPattern(RewritePatternSet &patterns)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void populateExpandCtlzPattern(RewritePatternSet &patterns)
void populateExpandAsinhPattern(RewritePatternSet &patterns)
void populateExpandRoundEvenPattern(RewritePatternSet &patterns)
void populateExpandAtanhPattern(RewritePatternSet &patterns)
void populateExpandFloorFPattern(RewritePatternSet &patterns)