22 #include "llvm/ADT/APFloat.h"
29 bool losesInfo =
false;
32 value.convert(cast<FloatType>(eltType).getFloatSemantics(),
33 APFloat::rmNearestTiesToEven, &losesInfo);
35 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
36 return b.
create<arith::ConstantOp>(loc,
40 return b.
create<arith::ConstantOp>(loc, attr);
52 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
53 return b.
create<arith::ConstantOp>(loc,
57 return b.
create<arith::ConstantOp>(loc, attr);
63 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
64 i64Ty = shapedTy.clone(i64Ty);
65 Value fixedConvert = b.
create<arith::FPToSIOp>(i64Ty, operand);
66 Value fpFixedConvert = b.
create<arith::SIToFPOp>(opType, fixedConvert);
69 return b.
create<math::CopySignOp>(fpFixedConvert, operand);
75 Value operand = op.getOperand();
91 Value operand = op.getOperand();
113 auto floatType = op.getOperand().getType();
121 loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
122 Value isNegativeFloat =
123 rewriter.
create<arith::UIToFPOp>(loc, floatType, isNegative);
124 Value isNegativeTimesNegTwo =
125 rewriter.
create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
126 Value sign = rewriter.
create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
129 Value positiveX = rewriter.
create<arith::MulFOp>(loc, sign, op.getOperand());
132 Value negDoubledX = rewriter.
create<arith::MulFOp>(loc, negTwo, positiveX);
133 Value exp2x = rewriter.
create<math::ExpOp>(loc, negDoubledX);
134 Value dividend = rewriter.
create<arith::SubFOp>(loc, one, exp2x);
135 Value divisor = rewriter.
create<arith::AddFOp>(loc, one, exp2x);
136 Value positiveRes = rewriter.
create<arith::DivFOp>(loc, dividend, divisor);
147 Value operand = op.getOperand();
151 Value div = b.
create<arith::DivFOp>(type, sin, cos);
160 Value operand = op.getOperand();
164 Value fma = b.
create<math::FmaOp>(operand, operand, one);
166 Value add = b.
create<arith::AddFOp>(operand, sqrt);
176 Value operand = op.getOperand();
180 Value fma = b.
create<math::FmaOp>(operand, operand, negOne);
182 Value add = b.
create<arith::AddFOp>(operand, sqrt);
192 Value operand = op.getOperand();
209 Value operandA = op.getOperand(0);
210 Value operandB = op.getOperand(1);
211 Value operandC = op.getOperand(2);
212 Type type = op.getType();
213 Value mult = b.
create<arith::MulFOp>(type, operandA, operandB);
214 Value add = b.
create<arith::AddFOp>(type, mult, operandC);
226 auto shapedType = dyn_cast<ShapedType>(op.getType());
227 if (shapedType && !shapedType.hasStaticShape())
231 Value operand = op.getOperand();
239 Value gtCheck = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
241 Value incrValue = b.
create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
243 Value ret = b.
create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
255 Value base = op.getOperand(0);
256 Value power = op.getOperand(1);
259 auto convertFPowItoPowf = [&]() -> LogicalResult {
260 Value castPowerToFp =
261 rewriter.
create<arith::SIToFPOp>(op.getLoc(), baseType, power);
262 Value res = rewriter.
create<math::PowFOp>(op.getLoc(), baseType, base,
270 return convertFPowItoPowf();
274 return convertFPowItoPowf();
276 int64_t powerInt = value.getSExtValue();
277 bool isNegative = powerInt < 0;
278 int64_t absPower =
std::abs(powerInt);
282 while (absPower > 0) {
284 res = b.
create<arith::MulFOp>(baseType, base, res);
286 base = b.
create<arith::MulFOp>(baseType, base, base);
292 .getFloatSemantics();
301 APFloat::getInf(sem,
false), rewriter);
304 APFloat::getInf(sem,
true), rewriter);
306 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
307 Value negZeroEqCheck =
308 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
309 res = b.
create<arith::DivFOp>(baseType, one, res);
311 b.
create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
312 res = b.
create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
325 Value operandA = op.getOperand(0);
326 Value operandB = op.getOperand(1);
327 auto typeA = operandA.
getType();
328 auto typeB = operandB.
getType();
334 return b.
create<arith::MulFOp>(x, y);
337 if (valueB.isZero()) {
343 if (valueB.isExactlyValue(1.0)) {
348 if (valueB.isExactlyValue(-1.0)) {
351 Value div = b.
create<arith::DivFOp>(one, operandA);
355 if (valueB.isExactlyValue(0.5)) {
361 if (valueB.isExactlyValue(-0.5)) {
367 if (valueB.isExactlyValue(2.0)) {
369 rewriter.
replaceOp(op, mulf(operandA, operandA));
372 if (valueB.isExactlyValue(-2.0)) {
376 Value div = b.
create<arith::DivFOp>(one, mulf(operandA, operandA));
380 if (valueB.isExactlyValue(3.0)) {
381 rewriter.
replaceOp(op, mulf(mulf(operandA, operandA), operandA));
387 Value mult = b.
create<arith::MulFOp>(operandB, logA);
400 Value operand = op.getOperand();
403 Value mult = b.
create<arith::MulFOp>(opType, operand, ln2);
404 Value exp = b.
create<math::ExpOp>(op->getLoc(), mult);
413 Value operand = op.getOperand();
417 if (!opEType.
isF32()) {
422 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
423 i32Ty = shapedTy.clone(i32Ty);
430 Value incrValue = b.
create<math::CopySignOp>(half, operand);
431 Value add = b.
create<arith::AddFOp>(opType, operand, incrValue);
454 Value operandBitcast = b.
create<arith::BitcastOp>(i32Ty, operand);
456 b.
create<arith::ShRUIOp>(operandBitcast, c23), expMask);
457 Value operandBiasedExp = b.
create<arith::SubIOp>(operandExp, c127);
458 Value isSpecialValOrLargeVal =
459 b.
create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
461 Value result = b.
create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
471 auto operand = op.getOperand();
472 auto operandTy = operand.getType();
476 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
480 uint64_t allbits = -1;
482 allbits = allbits >> (64 - bitwidth);
487 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
490 auto mask =
createIntConst(loc, operandTy, allbits >> half, rewriter);
493 rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
494 Value add = rewriter.
create<arith::AddIOp>(loc, count, bits);
495 Value shift = rewriter.
create<arith::ShLIOp>(loc, x, bits);
497 x = rewriter.
create<arith::SelectOp>(loc, pred, shift, x);
498 count = rewriter.
create<arith::SelectOp>(loc, pred, add, count);
502 Value pred = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
506 Value sel = rewriter.
create<arith::SelectOp>(loc, pred, bwval, count);
516 auto operand = op.getOperand();
517 Type operandTy = operand.getType();
518 Type resultTy = op.getType();
522 if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
526 Type fTy = operandTy;
528 if (
auto shapedTy = dyn_cast<ShapedType>(fTy)) {
529 iTy = shapedTy.clone(iTy);
534 unsigned mantissaWidth =
535 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
536 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
553 Value operandBitcast = b.
create<arith::BitcastOp>(iTy, operand);
559 b.
create<arith::ShRUIOp>(operandBitcast, c23), expMask);
560 Value operandBiasedExp = b.
create<arith::SubIOp>(operandExp, c127);
562 b.
create<arith::ShRUIOp>(roundBitcast, c23), expMask);
563 Value roundBiasedExp = b.
create<arith::SubIOp>(roundExp, c127);
567 Value clampedShift = b.
create<arith::MaxSIOp>(shift, c0);
568 clampedShift = b.
create<arith::MinSIOp>(clampedShift, c31);
569 return b.
create<arith::ShRUIOp>(x, clampedShift);
572 auto maskMantissa = [&](
Value mantissa,
574 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
575 return b.
create<arith::AndIOp>(mantissa, shiftedMantissaMask);
592 Value roundBiasedExpEq0 =
593 b.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
594 Value roundBiasedExpMinus1 = b.
create<arith::SubIOp>(roundBiasedExp, c1);
595 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
596 Value roundIsNotEvenOrSpecialVal = b.
create<arith::CmpIOp>(
597 arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
598 roundIsNotEvenOrSpecialVal =
599 b.
create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
608 Value operandBiasedExpEqNeg1 = b.
create<arith::CmpIOp>(
609 arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
610 Value expectedOperandMaskedMantissa = b.
create<arith::SelectOp>(
611 operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
612 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
613 Value operandIsHalfway =
614 b.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
615 expectedOperandMaskedMantissa);
617 Value operandBiasedExpGeNeg1 = b.
create<arith::CmpIOp>(
618 arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
619 Value operandBiasedExpLt23 =
620 b.
create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
622 b.
create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
624 b.
create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
628 Value sign = b.
create<math::CopySignOp>(c1Float, operand);
633 b.
create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
638 result = b.
create<math::CopySignOp>(result, operand);
647 auto operand = op.getOperand();
648 auto operandTy = operand.getType();
650 auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
651 if (shapedOperandType && !shapedOperandType.hasStaticShape())
655 if (!isa<FloatType>(eTy))
660 auto sqrtOp = rewriter.
create<math::SqrtOp>(loc, operand);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)
Create a float constant.
static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into 1-exp^{-2x} / 1+exp^{-2x} To avoid overflow we exploit the reflection symmetry t...
static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void populateExpandSinhPattern(RewritePatternSet &patterns)
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
void populateExpandRsqrtPattern(RewritePatternSet &patterns)
void populateExpandTanhPattern(RewritePatternSet &patterns)
void populateExpandFmaFPattern(RewritePatternSet &patterns)
void populateExpandAcoshPattern(RewritePatternSet &patterns)
void populateExpandFPowIPattern(RewritePatternSet &patterns)
void populateExpandPowFPattern(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateExpandTanPattern(RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
void populateExpandCoshPattern(RewritePatternSet &patterns)
void populateExpandRoundFPattern(RewritePatternSet &patterns)
void populateExpandExp2FPattern(RewritePatternSet &patterns)
void populateExpandCeilFPattern(RewritePatternSet &patterns)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void populateExpandCtlzPattern(RewritePatternSet &patterns)
void populateExpandAsinhPattern(RewritePatternSet &patterns)
void populateExpandRoundEvenPattern(RewritePatternSet &patterns)
void populateExpandAtanhPattern(RewritePatternSet &patterns)
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...