29 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
30 return b.
create<arith::ConstantOp>(loc,
34 return b.
create<arith::ConstantOp>(loc, attr);
41 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
42 return b.
create<arith::ConstantOp>(loc,
46 return b.
create<arith::ConstantOp>(loc, attr);
52 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
53 i64Ty = shapedTy.clone(i64Ty);
54 Value fixedConvert = b.
create<arith::FPToSIOp>(i64Ty, operand);
55 Value fpFixedConvert = b.
create<arith::SIToFPOp>(opType, fixedConvert);
58 return b.
create<math::CopySignOp>(fpFixedConvert, operand);
72 Value negDoubledX = rewriter.
create<arith::NegFOp>(loc, doubledX);
73 Value exp2x = rewriter.
create<math::ExpOp>(loc, negDoubledX);
74 Value dividend = rewriter.
create<arith::SubFOp>(loc, one, exp2x);
75 Value divisor = rewriter.
create<arith::AddFOp>(loc, one, exp2x);
76 Value positiveRes = rewriter.
create<arith::DivFOp>(loc, dividend, divisor);
79 exp2x = rewriter.
create<math::ExpOp>(loc, doubledX);
80 dividend = rewriter.
create<arith::SubFOp>(loc, exp2x, one);
81 divisor = rewriter.
create<arith::AddFOp>(loc, exp2x, one);
82 Value negativeRes = rewriter.
create<arith::DivFOp>(loc, dividend, divisor);
86 Value cmpRes = rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
100 Value div = b.
create<arith::DivFOp>(type, sin, cos);
110 Type type = op.getType();
111 Value mult = b.
create<arith::MulFOp>(type, operandA, operandB);
112 Value add = b.
create<arith::AddFOp>(type, mult, operandC);
134 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
136 b.
create<arith::SelectOp>(op->
getLoc(), negCheck, negOne, zero);
137 Value ret = b.
create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
157 Value gtCheck = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
161 Value ret = b.
create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
174 Value opASquared = b.
create<arith::MulFOp>(opType, operandA, operandA);
175 Value opBHalf = b.
create<arith::DivFOp>(opType, operandB, two);
177 Value logA = b.
create<math::LogOp>(opType, opASquared);
178 Value mult = b.
create<arith::MulFOp>(opType, opBHalf, logA);
179 Value expResult = b.
create<math::ExpOp>(opType, mult);
180 Value negExpResult = b.
create<arith::MulFOp>(opType, expResult, negOne);
181 Value remainder = b.
create<arith::RemFOp>(opType, operandB, two);
183 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
185 b.
create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
204 Value mult = b.
create<arith::MulFOp>(opType, operand, ln2);
218 if (!opEType.
isF32()) {
223 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
224 i32Ty = shapedTy.clone(i32Ty);
231 Value incrValue = b.
create<math::CopySignOp>(half, operand);
232 Value add = b.
create<arith::AddFOp>(opType, operand, incrValue);
255 Value operandBitcast = b.
create<arith::BitcastOp>(i32Ty, operand);
257 b.
create<arith::ShRUIOp>(operandBitcast, c23), expMask);
258 Value operandBiasedExp = b.
create<arith::SubIOp>(operandExp, c127);
259 Value isSpecialValOrLargeVal =
260 b.
create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
262 Value result = b.
create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
273 auto operandTy = operand.
getType();
277 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
281 uint64_t allbits = -1;
283 allbits = allbits >> (64 - bitwidth);
288 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
291 auto mask =
createIntConst(loc, operandTy, allbits >> half, rewriter);
294 rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
295 Value add = rewriter.
create<arith::AddIOp>(loc, count, bits);
296 Value shift = rewriter.
create<arith::ShLIOp>(loc, x, bits);
298 x = rewriter.
create<arith::SelectOp>(loc, pred, shift, x);
299 count = rewriter.
create<arith::SelectOp>(loc, pred, add, count);
303 Value pred = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
307 Value sel = rewriter.
create<arith::SelectOp>(loc, pred, bwval, count);
318 Type operandTy = operand.getType();
319 Type resultTy = op.getType();
323 if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
327 Type fTy = operandTy;
329 if (
auto shapedTy = dyn_cast<ShapedType>(fTy)) {
330 iTy = shapedTy.clone(iTy);
335 unsigned mantissaWidth =
336 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
337 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
354 Value operandBitcast = b.
create<arith::BitcastOp>(iTy, operand);
356 Value roundBitcast = b.
create<arith::BitcastOp>(iTy, round);
360 b.
create<arith::ShRUIOp>(operandBitcast, c23), expMask);
361 Value operandBiasedExp = b.
create<arith::SubIOp>(operandExp, c127);
363 b.
create<arith::ShRUIOp>(roundBitcast, c23), expMask);
364 Value roundBiasedExp = b.
create<arith::SubIOp>(roundExp, c127);
368 Value clampedShift = b.
create<arith::MaxSIOp>(shift, c0);
369 clampedShift = b.
create<arith::MinSIOp>(clampedShift, c31);
370 return b.
create<arith::ShRUIOp>(x, clampedShift);
373 auto maskMantissa = [&](
Value mantissa,
375 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
376 return b.
create<arith::AndIOp>(mantissa, shiftedMantissaMask);
393 Value roundBiasedExpEq0 =
394 b.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
395 Value roundBiasedExpMinus1 = b.
create<arith::SubIOp>(roundBiasedExp, c1);
396 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
397 Value roundIsNotEvenOrSpecialVal = b.
create<arith::CmpIOp>(
398 arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
399 roundIsNotEvenOrSpecialVal =
400 b.
create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
409 Value operandBiasedExpEqNeg1 = b.
create<arith::CmpIOp>(
410 arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
411 Value expectedOperandMaskedMantissa = b.
create<arith::SelectOp>(
412 operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
413 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
414 Value operandIsHalfway =
415 b.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
416 expectedOperandMaskedMantissa);
418 Value operandBiasedExpGeNeg1 = b.
create<arith::CmpIOp>(
419 arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
420 Value operandBiasedExpLt23 =
421 b.
create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
423 b.
create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
425 b.
create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
429 Value sign = b.
create<math::CopySignOp>(c1Float, operand);
430 Value roundShifted = b.
create<arith::SubFOp>(round, sign);
434 b.
create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
435 Value result = b.
create<arith::SelectOp>(needsShift, roundShifted, round);
439 result = b.
create<math::CopySignOp>(result, operand);
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
static Value createFloatConst(Location loc, Type type, double value, OpBuilder &b)
Create a float constant.
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertFloorOp(math::FloorOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into 1) (1-exp^{-2x}) / (1+exp^{-2x}), if x >= 0; 2) (exp^{2x}-1) / (exp^{2x}+1), if x < 0.
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateExpandTanhPattern(RewritePatternSet &patterns)
void populateExpandFmaFPattern(RewritePatternSet &patterns)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateExpandPowFPattern(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateExpandTanPattern(RewritePatternSet &patterns)
void populateExpandRoundFPattern(RewritePatternSet &patterns)
void populateExpandExp2FPattern(RewritePatternSet &patterns)
void populateExpandCeilFPattern(RewritePatternSet &patterns)
void populateExpandCtlzPattern(RewritePatternSet &patterns)
void populateExpandRoundEvenPattern(RewritePatternSet &patterns)
void populateExpandFloorFPattern(RewritePatternSet &patterns)
This class represents an efficient way to signal success or failure.