22 #include "llvm/ADT/APFloat.h"
29 bool losesInfo =
false;
32 value.convert(cast<FloatType>(eltType).getFloatSemantics(),
33 APFloat::rmNearestTiesToEven, &losesInfo);
35 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
36 return b.
create<arith::ConstantOp>(loc,
40 return b.
create<arith::ConstantOp>(loc, attr);
52 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
53 return b.
create<arith::ConstantOp>(loc,
57 return b.
create<arith::ConstantOp>(loc, attr);
63 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
64 i64Ty = shapedTy.clone(i64Ty);
65 Value fixedConvert = b.
create<arith::FPToSIOp>(i64Ty, operand);
66 Value fpFixedConvert = b.
create<arith::SIToFPOp>(opType, fixedConvert);
69 return b.
create<math::CopySignOp>(fpFixedConvert, operand);
75 Value operand = op.getOperand();
91 Value operand = op.getOperand();
113 auto floatType = op.getOperand().getType();
121 loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
122 Value isNegativeFloat =
123 rewriter.
create<arith::UIToFPOp>(loc, floatType, isNegative);
124 Value isNegativeTimesNegTwo =
125 rewriter.
create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
126 Value sign = rewriter.
create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
129 Value positiveX = rewriter.
create<arith::MulFOp>(loc, sign, op.getOperand());
132 Value negDoubledX = rewriter.
create<arith::MulFOp>(loc, negTwo, positiveX);
133 Value exp2x = rewriter.
create<math::ExpOp>(loc, negDoubledX);
134 Value dividend = rewriter.
create<arith::SubFOp>(loc, one, exp2x);
135 Value divisor = rewriter.
create<arith::AddFOp>(loc, one, exp2x);
136 Value positiveRes = rewriter.
create<arith::DivFOp>(loc, dividend, divisor);
147 Value operand = op.getOperand();
151 Value div = b.
create<arith::DivFOp>(type, sin, cos);
160 Value operand = op.getOperand();
164 Value fma = b.
create<math::FmaOp>(operand, operand, one);
166 Value add = b.
create<arith::AddFOp>(operand, sqrt);
176 Value operand = op.getOperand();
180 Value fma = b.
create<math::FmaOp>(operand, operand, negOne);
182 Value add = b.
create<arith::AddFOp>(operand, sqrt);
192 Value operand = op.getOperand();
209 Value operandA = op.getOperand(0);
210 Value operandB = op.getOperand(1);
211 Value operandC = op.getOperand(2);
212 Type type = op.getType();
213 Value mult = b.
create<arith::MulFOp>(type, operandA, operandB);
214 Value add = b.
create<arith::AddFOp>(type, mult, operandC);
226 Value operand = op.getOperand();
234 Value gtCheck = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
236 Value incrValue = b.
create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
238 Value ret = b.
create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
250 Value base = op.getOperand(0);
251 Value power = op.getOperand(1);
254 auto convertFPowItoPowf = [&]() -> LogicalResult {
255 Value castPowerToFp =
256 rewriter.
create<arith::SIToFPOp>(op.getLoc(), baseType, power);
257 Value res = rewriter.
create<math::PowFOp>(op.getLoc(), baseType, base,
265 return convertFPowItoPowf();
269 return convertFPowItoPowf();
271 int64_t powerInt = value.getSExtValue();
272 bool isNegative = powerInt < 0;
273 int64_t absPower =
std::abs(powerInt);
277 while (absPower > 0) {
279 res = b.
create<arith::MulFOp>(baseType, base, res);
281 base = b.
create<arith::MulFOp>(baseType, base, base);
287 .getFloatSemantics();
296 APFloat::getInf(sem,
false), rewriter);
299 APFloat::getInf(sem,
true), rewriter);
301 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
302 Value negZeroEqCheck =
303 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
304 res = b.
create<arith::DivFOp>(baseType, one, res);
306 b.
create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
307 res = b.
create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
320 Value operandA = op.getOperand(0);
321 Value operandB = op.getOperand(1);
322 auto typeA = operandA.
getType();
323 auto typeB = operandB.
getType();
329 return b.
create<arith::MulFOp>(x, y);
332 if (valueB.isZero()) {
338 if (valueB.isExactlyValue(1.0)) {
343 if (valueB.isExactlyValue(-1.0)) {
346 Value div = b.
create<arith::DivFOp>(one, operandA);
350 if (valueB.isExactlyValue(0.5)) {
356 if (valueB.isExactlyValue(-0.5)) {
362 if (valueB.isExactlyValue(2.0)) {
364 rewriter.
replaceOp(op, mulf(operandA, operandA));
367 if (valueB.isExactlyValue(-2.0)) {
371 Value div = b.
create<arith::DivFOp>(one, mulf(operandA, operandA));
375 if (valueB.isExactlyValue(3.0)) {
376 rewriter.
replaceOp(op, mulf(mulf(operandA, operandA), operandA));
382 Value mult = b.
create<arith::MulFOp>(operandB, logA);
395 Value operand = op.getOperand();
398 Value mult = b.
create<arith::MulFOp>(opType, operand, ln2);
399 Value exp = b.
create<math::ExpOp>(op->getLoc(), mult);
408 Value operand = op.getOperand();
412 if (!opEType.
isF32()) {
417 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
418 i32Ty = shapedTy.clone(i32Ty);
425 Value incrValue = b.
create<math::CopySignOp>(half, operand);
426 Value add = b.
create<arith::AddFOp>(opType, operand, incrValue);
449 Value operandBitcast = b.
create<arith::BitcastOp>(i32Ty, operand);
451 b.
create<arith::ShRUIOp>(operandBitcast, c23), expMask);
452 Value operandBiasedExp = b.
create<arith::SubIOp>(operandExp, c127);
453 Value isSpecialValOrLargeVal =
454 b.
create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
456 Value result = b.
create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
466 auto operand = op.getOperand();
467 auto operandTy = operand.getType();
471 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
475 uint64_t allbits = -1;
477 allbits = allbits >> (64 - bitwidth);
482 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
485 auto mask =
createIntConst(loc, operandTy, allbits >> half, rewriter);
488 rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
489 Value add = rewriter.
create<arith::AddIOp>(loc, count, bits);
490 Value shift = rewriter.
create<arith::ShLIOp>(loc, x, bits);
492 x = rewriter.
create<arith::SelectOp>(loc, pred, shift, x);
493 count = rewriter.
create<arith::SelectOp>(loc, pred, add, count);
497 Value pred = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
501 Value sel = rewriter.
create<arith::SelectOp>(loc, pred, bwval, count);
511 auto operand = op.getOperand();
512 Type operandTy = operand.getType();
513 Type resultTy = op.getType();
517 if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
521 Type fTy = operandTy;
523 if (
auto shapedTy = dyn_cast<ShapedType>(fTy)) {
524 iTy = shapedTy.clone(iTy);
529 unsigned mantissaWidth =
530 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
531 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
548 Value operandBitcast = b.
create<arith::BitcastOp>(iTy, operand);
554 b.
create<arith::ShRUIOp>(operandBitcast, c23), expMask);
555 Value operandBiasedExp = b.
create<arith::SubIOp>(operandExp, c127);
557 b.
create<arith::ShRUIOp>(roundBitcast, c23), expMask);
558 Value roundBiasedExp = b.
create<arith::SubIOp>(roundExp, c127);
562 Value clampedShift = b.
create<arith::MaxSIOp>(shift, c0);
563 clampedShift = b.
create<arith::MinSIOp>(clampedShift, c31);
564 return b.
create<arith::ShRUIOp>(x, clampedShift);
567 auto maskMantissa = [&](
Value mantissa,
569 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
570 return b.
create<arith::AndIOp>(mantissa, shiftedMantissaMask);
587 Value roundBiasedExpEq0 =
588 b.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
589 Value roundBiasedExpMinus1 = b.
create<arith::SubIOp>(roundBiasedExp, c1);
590 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
591 Value roundIsNotEvenOrSpecialVal = b.
create<arith::CmpIOp>(
592 arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
593 roundIsNotEvenOrSpecialVal =
594 b.
create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
603 Value operandBiasedExpEqNeg1 = b.
create<arith::CmpIOp>(
604 arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
605 Value expectedOperandMaskedMantissa = b.
create<arith::SelectOp>(
606 operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
607 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
608 Value operandIsHalfway =
609 b.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
610 expectedOperandMaskedMantissa);
612 Value operandBiasedExpGeNeg1 = b.
create<arith::CmpIOp>(
613 arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
614 Value operandBiasedExpLt23 =
615 b.
create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
617 b.
create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
619 b.
create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
623 Value sign = b.
create<math::CopySignOp>(c1Float, operand);
628 b.
create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
633 result = b.
create<math::CopySignOp>(result, operand);
642 auto operand = op.getOperand();
643 auto operandTy = operand.getType();
645 if (!isa<FloatType>(eTy))
650 auto sqrtOp = rewriter.
create<math::SqrtOp>(loc, operand);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)
Create a float constant.
static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into 1-exp^{-2x} / 1+exp^{-2x} To avoid overflow we exploit the reflection symmetry t...
static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void populateExpandSinhPattern(RewritePatternSet &patterns)
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
void populateExpandRsqrtPattern(RewritePatternSet &patterns)
void populateExpandTanhPattern(RewritePatternSet &patterns)
void populateExpandFmaFPattern(RewritePatternSet &patterns)
void populateExpandAcoshPattern(RewritePatternSet &patterns)
void populateExpandFPowIPattern(RewritePatternSet &patterns)
void populateExpandPowFPattern(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateExpandTanPattern(RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
void populateExpandCoshPattern(RewritePatternSet &patterns)
void populateExpandRoundFPattern(RewritePatternSet &patterns)
void populateExpandExp2FPattern(RewritePatternSet &patterns)
void populateExpandCeilFPattern(RewritePatternSet &patterns)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void populateExpandCtlzPattern(RewritePatternSet &patterns)
void populateExpandAsinhPattern(RewritePatternSet &patterns)
void populateExpandRoundEvenPattern(RewritePatternSet &patterns)
void populateExpandAtanhPattern(RewritePatternSet &patterns)
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...