28 bool losesInfo =
false;
31 value.convert(cast<FloatType>(eltType).getFloatSemantics(),
32 APFloat::rmNearestTiesToEven, &losesInfo);
34 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
35 return b.
create<arith::ConstantOp>(loc,
39 return b.
create<arith::ConstantOp>(loc, attr);
51 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
52 return b.
create<arith::ConstantOp>(loc,
56 return b.
create<arith::ConstantOp>(loc, attr);
62 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
63 i64Ty = shapedTy.clone(i64Ty);
64 Value fixedConvert = b.
create<arith::FPToSIOp>(i64Ty, operand);
65 Value fpFixedConvert = b.
create<arith::SIToFPOp>(opType, fixedConvert);
68 return b.
create<math::CopySignOp>(fpFixedConvert, operand);
74 Value operand = op.getOperand();
90 Value operand = op.getOperand();
112 auto floatType = op.getOperand().getType();
120 loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
121 Value isNegativeFloat =
122 rewriter.
create<arith::UIToFPOp>(loc, floatType, isNegative);
123 Value isNegativeTimesNegTwo =
124 rewriter.
create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
125 Value sign = rewriter.
create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
128 Value positiveX = rewriter.
create<arith::MulFOp>(loc, sign, op.getOperand());
131 Value negDoubledX = rewriter.
create<arith::MulFOp>(loc, negTwo, positiveX);
132 Value exp2x = rewriter.
create<math::ExpOp>(loc, negDoubledX);
133 Value dividend = rewriter.
create<arith::SubFOp>(loc, one, exp2x);
134 Value divisor = rewriter.
create<arith::AddFOp>(loc, one, exp2x);
135 Value positiveRes = rewriter.
create<arith::DivFOp>(loc, dividend, divisor);
146 Value operand = op.getOperand();
150 Value div = b.
create<arith::DivFOp>(type, sin, cos);
159 Value operand = op.getOperand();
163 Value fma = b.
create<math::FmaOp>(operand, operand, one);
165 Value add = b.
create<arith::AddFOp>(operand, sqrt);
175 Value operand = op.getOperand();
179 Value fma = b.
create<math::FmaOp>(operand, operand, negOne);
181 Value add = b.
create<arith::AddFOp>(operand, sqrt);
191 Value operand = op.getOperand();
208 Value operandA = op.getOperand(0);
209 Value operandB = op.getOperand(1);
210 Value operandC = op.getOperand(2);
211 Type type = op.getType();
212 Value mult = b.
create<arith::MulFOp>(type, operandA, operandB);
213 Value add = b.
create<arith::AddFOp>(type, mult, operandC);
226 Value operand = op.getOperand();
235 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
237 b.
create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero);
238 Value ret = b.
create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
250 Value operand = op.getOperand();
258 Value gtCheck = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
260 Value incrValue = b.
create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
262 Value ret = b.
create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
274 Value base = op.getOperand(0);
275 Value power = op.getOperand(1);
278 auto convertFPowItoPowf = [&]() -> LogicalResult {
279 Value castPowerToFp =
280 rewriter.
create<arith::SIToFPOp>(op.getLoc(), baseType, power);
281 Value res = rewriter.
create<math::PowFOp>(op.getLoc(), baseType, base,
289 return convertFPowItoPowf();
293 return convertFPowItoPowf();
295 int64_t powerInt = value.getSExtValue();
296 bool isNegative = powerInt < 0;
297 int64_t absPower =
std::abs(powerInt);
301 while (absPower > 0) {
303 res = b.
create<arith::MulFOp>(baseType, base, res);
305 base = b.
create<arith::MulFOp>(baseType, base, base);
311 .getFloatSemantics();
320 APFloat::getInf(sem,
false), rewriter);
323 APFloat::getInf(sem,
true), rewriter);
325 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
326 Value negZeroEqCheck =
327 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
328 res = b.
create<arith::DivFOp>(baseType, one, res);
330 b.
create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
331 res = b.
create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
342 Value operandA = op.getOperand(0);
343 Value operandB = op.getOperand(1);
348 Value opASquared = b.
create<arith::MulFOp>(opType, operandA, operandA);
349 Value opBHalf = b.
create<arith::DivFOp>(opType, operandB, two);
351 Value logA = b.
create<math::LogOp>(opType, opASquared);
352 Value mult = b.
create<arith::MulFOp>(opType, opBHalf, logA);
353 Value expResult = b.
create<math::ExpOp>(opType, mult);
354 Value negExpResult = b.
create<arith::MulFOp>(opType, expResult, negOne);
355 Value remainder = b.
create<arith::RemFOp>(opType, operandB, two);
357 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
359 b.
create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
360 Value oddAndNeg = b.
create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
362 Value res = b.
create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
375 Value operand = op.getOperand();
378 Value mult = b.
create<arith::MulFOp>(opType, operand, ln2);
379 Value exp = b.
create<math::ExpOp>(op->getLoc(), mult);
388 Value operand = op.getOperand();
392 if (!opEType.
isF32()) {
397 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
398 i32Ty = shapedTy.clone(i32Ty);
405 Value incrValue = b.
create<math::CopySignOp>(half, operand);
406 Value add = b.
create<arith::AddFOp>(opType, operand, incrValue);
429 Value operandBitcast = b.
create<arith::BitcastOp>(i32Ty, operand);
431 b.
create<arith::ShRUIOp>(operandBitcast, c23), expMask);
432 Value operandBiasedExp = b.
create<arith::SubIOp>(operandExp, c127);
433 Value isSpecialValOrLargeVal =
434 b.
create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
436 Value result = b.
create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
446 auto operand = op.getOperand();
447 auto operandTy = operand.getType();
451 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
455 uint64_t allbits = -1;
457 allbits = allbits >> (64 - bitwidth);
462 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
465 auto mask =
createIntConst(loc, operandTy, allbits >> half, rewriter);
468 rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
469 Value add = rewriter.
create<arith::AddIOp>(loc, count, bits);
470 Value shift = rewriter.
create<arith::ShLIOp>(loc, x, bits);
472 x = rewriter.
create<arith::SelectOp>(loc, pred, shift, x);
473 count = rewriter.
create<arith::SelectOp>(loc, pred, add, count);
477 Value pred = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
481 Value sel = rewriter.
create<arith::SelectOp>(loc, pred, bwval, count);
491 auto operand = op.getOperand();
492 Type operandTy = operand.getType();
493 Type resultTy = op.getType();
497 if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
501 Type fTy = operandTy;
503 if (
auto shapedTy = dyn_cast<ShapedType>(fTy)) {
504 iTy = shapedTy.clone(iTy);
509 unsigned mantissaWidth =
510 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
511 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
528 Value operandBitcast = b.
create<arith::BitcastOp>(iTy, operand);
534 b.
create<arith::ShRUIOp>(operandBitcast, c23), expMask);
535 Value operandBiasedExp = b.
create<arith::SubIOp>(operandExp, c127);
537 b.
create<arith::ShRUIOp>(roundBitcast, c23), expMask);
538 Value roundBiasedExp = b.
create<arith::SubIOp>(roundExp, c127);
542 Value clampedShift = b.
create<arith::MaxSIOp>(shift, c0);
543 clampedShift = b.
create<arith::MinSIOp>(clampedShift, c31);
544 return b.
create<arith::ShRUIOp>(x, clampedShift);
547 auto maskMantissa = [&](
Value mantissa,
549 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
550 return b.
create<arith::AndIOp>(mantissa, shiftedMantissaMask);
567 Value roundBiasedExpEq0 =
568 b.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
569 Value roundBiasedExpMinus1 = b.
create<arith::SubIOp>(roundBiasedExp, c1);
570 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
571 Value roundIsNotEvenOrSpecialVal = b.
create<arith::CmpIOp>(
572 arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
573 roundIsNotEvenOrSpecialVal =
574 b.
create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
583 Value operandBiasedExpEqNeg1 = b.
create<arith::CmpIOp>(
584 arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
585 Value expectedOperandMaskedMantissa = b.
create<arith::SelectOp>(
586 operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
587 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
588 Value operandIsHalfway =
589 b.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
590 expectedOperandMaskedMantissa);
592 Value operandBiasedExpGeNeg1 = b.
create<arith::CmpIOp>(
593 arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
594 Value operandBiasedExpLt23 =
595 b.
create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
597 b.
create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
599 b.
create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
603 Value sign = b.
create<math::CopySignOp>(c1Float, operand);
608 b.
create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
613 result = b.
create<math::CopySignOp>(result, operand);
622 auto operand = op.getOperand();
623 auto operandTy = operand.getType();
625 if (!isa<FloatType>(eTy))
630 auto sqrtOp = rewriter.
create<math::SqrtOp>(loc, operand);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertFloorOp(math::FloorOp op, PatternRewriter &rewriter)
static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)
Create a float constant.
static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
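Expands tanh op into (1 - exp^{-2x}) / (1 + exp^{-2x}). To avoid overflow we exploit the reflection symmetry tanh(-x) = -tanh(x): the expansion is evaluated on |x| and the sign of x is reapplied to the result.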
static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void populateExpandSinhPattern(RewritePatternSet &patterns)
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
void populateExpandRsqrtPattern(RewritePatternSet &patterns)
void populateExpandTanhPattern(RewritePatternSet &patterns)
void populateExpandFmaFPattern(RewritePatternSet &patterns)
void populateExpandAcoshPattern(RewritePatternSet &patterns)
void populateExpandFPowIPattern(RewritePatternSet &patterns)
void populateExpandPowFPattern(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateExpandTanPattern(RewritePatternSet &patterns)
void populateExpandCoshPattern(RewritePatternSet &patterns)
void populateExpandRoundFPattern(RewritePatternSet &patterns)
void populateExpandExp2FPattern(RewritePatternSet &patterns)
void populateExpandCeilFPattern(RewritePatternSet &patterns)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void populateExpandCtlzPattern(RewritePatternSet &patterns)
void populateExpandAsinhPattern(RewritePatternSet &patterns)
void populateExpandRoundEvenPattern(RewritePatternSet &patterns)
void populateExpandAtanhPattern(RewritePatternSet &patterns)
void populateExpandFloorFPattern(RewritePatternSet &patterns)