22#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
23#include "mlir/Conversion/Passes.h.inc"
37 auto operand = op.getOperand();
38 auto floatTy = dyn_cast<FloatType>(operand.getType());
41 "only scalar FloatTypes supported");
42 if (floatTy.getIntOrFloatBitWidth() > 64) {
44 "bitwidth > 64 bits is not supported");
47 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
48 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
50 rewriter,
symTable,
"_mlir_apfloat_abs", {i32Type, i64Type});
56 Value operandBits = arith::ExtUIOp::create(
57 rewriter, loc, i64Type,
58 arith::BitcastOp::create(rewriter, loc, intWType, operand));
63 Value negatedBits = func::CallOp::create(rewriter, loc,
TypeRange(i64Type),
64 SymbolRefAttr::get(*fn), params)
69 arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
71 op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
78template <
typename OpTy>
89 auto operand = op.getOperand();
90 auto floatTy = dyn_cast<FloatType>(operand.getType());
93 "only scalar FloatTypes supported");
94 if (floatTy.getIntOrFloatBitWidth() > 64) {
96 "bitwidth > 64 bits is not supported");
99 auto i1 = IntegerType::get(
symTable->getContext(), 1);
100 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
101 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
102 std::string funcName =
103 (llvm::Twine(
"_mlir_apfloat_is") +
APFloatName).str();
105 rewriter,
symTable, funcName, {i32Type, i64Type},
nullptr, i1);
111 Value operandBits = arith::ExtUIOp::create(
112 rewriter, loc, i64Type,
113 arith::BitcastOp::create(rewriter, loc, intWType, operand));
119 SymbolRefAttr::get(*fn), params);
135 auto floatTy = cast<FloatType>(op.getResult().getType());
138 "only scalar FloatTypes supported");
139 if (floatTy.getIntOrFloatBitWidth() > 64) {
141 "bitwidth > 64 bits is not supported");
144 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
145 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
147 rewriter,
symTable,
"_mlir_apfloat_fused_multiply_add",
148 {i32Type, i64Type, i64Type, i64Type});
156 Value operand = arith::ExtUIOp::create(
157 rewriter, loc, int64Type,
158 arith::BitcastOp::create(rewriter, loc, intWType, op.getA()));
159 Value multiplicand = arith::ExtUIOp::create(
160 rewriter, loc, int64Type,
161 arith::BitcastOp::create(rewriter, loc, intWType, op.getB()));
162 Value addend = arith::ExtUIOp::create(
163 rewriter, loc, int64Type,
164 arith::BitcastOp::create(rewriter, loc, intWType, op.getC()));
171 SymbolRefAttr::get(*fn), params);
174 Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
175 resultOp->getResult(0));
184struct MathToAPFloatConversionPass final
185 : impl::MathToAPFloatConversionPassBase<MathToAPFloatConversionPass> {
188 void runOnOperation()
override;
191void MathToAPFloatConversionPass::runOnOperation() {
193 RewritePatternSet
patterns(context);
195 patterns.add<AbsFOpToAPFloatConversion>(context, getOperation());
196 patterns.add<IsOpToAPFloatConversion<math::IsFiniteOp>>(context,
"finite",
198 patterns.add<IsOpToAPFloatConversion<math::IsInfOp>>(context,
"infinite",
200 patterns.add<IsOpToAPFloatConversion<math::IsNaNOp>>(context,
"nan",
202 patterns.add<IsOpToAPFloatConversion<math::IsNormalOp>>(context,
"normal",
204 patterns.add<FmaOpToAPFloatConversion>(context, getOperation());
207 ScopedDiagnosticHandler scopedHandler(context, [&
result](Diagnostic &
diag) {
208 if (
diag.getSeverity() == DiagnosticSeverity::Error) {
217 return signalPassFailure();
static std::string diag(const llvm::Value &value)
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
FailureOr< FuncOp > lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name, TypeRange paramTypes, SymbolTableCollection *symbolTables=nullptr, Type resultType={})
Helper function to look up or create the symbol for a runtime library function with the given paramet...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy)
LogicalResult matchAndRewrite(math::AbsFOp op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(math::FmaOp op, PatternRewriter &rewriter) const override
FmaOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit=1)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})