22#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
23#include "mlir/Conversion/Passes.h.inc"
39 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
40 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
42 rewriter,
symTable,
"_mlir_apfloat_abs", {i32Type, i64Type});
51 auto floatTy = cast<FloatType>(operand.getType());
52 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
53 Value operandBits = arith::ExtUIOp::create(
54 rewriter, loc, i64Type,
55 arith::BitcastOp::create(rewriter, loc, intWType, operand));
57 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
58 SmallVector<Value> params = {semValue, operandBits};
60 func::CallOp::create(rewriter, loc,
TypeRange(i64Type),
61 SymbolRefAttr::get(*fn), params)
65 arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
66 return arith::BitcastOp::create(rewriter, loc, floatTy,
70 rewriter.replaceOp(op, repl);
77template <
typename OpTy>
90 auto i1 = IntegerType::get(
symTable->getContext(), 1);
91 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
92 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
93 std::string funcName =
94 (llvm::Twine(
"_mlir_apfloat_is") +
APFloatName).str();
96 rewriter,
symTable, funcName, {i32Type, i64Type},
nullptr, i1);
103 rewriter, loc, op.getOperand(),
Value(), op.getType(),
105 auto floatTy = cast<FloatType>(operand.getType());
106 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
107 Value operandBits = arith::ExtUIOp::create(
108 rewriter, loc, i64Type,
109 arith::BitcastOp::create(rewriter, loc, intWType, operand));
112 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
113 Value params[] = {semValue, operandBits};
114 return func::CallOp::create(rewriter, loc,
TypeRange(i1),
115 SymbolRefAttr::get(*fn), params)
118 rewriter.replaceOp(op, repl);
136 mlir::Type resType = op.getResult().getType();
137 auto floatTy = dyn_cast<FloatType>(resType);
139 auto vecTy1 = cast<VectorType>(resType);
140 floatTy = llvm::cast<FloatType>(vecTy1.getElementType());
142 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
143 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
145 rewriter,
symTable,
"_mlir_apfloat_fused_multiply_add",
146 {i32Type, i64Type, i64Type, i64Type});
152 IntegerType intWType = rewriter.
getIntegerType(floatTy.getWidth());
153 IntegerType int64Type = rewriter.
getI64Type();
155 auto scalarFMA = [&rewriter, &loc, &floatTy, &fn, &intWType,
157 Value operand = arith::ExtUIOp::create(
158 rewriter, loc, int64Type,
159 arith::BitcastOp::create(rewriter, loc, intWType, a));
160 Value multiplicand = arith::ExtUIOp::create(
161 rewriter, loc, int64Type,
162 arith::BitcastOp::create(rewriter, loc, intWType,
b));
163 Value addend = arith::ExtUIOp::create(
164 rewriter, loc, int64Type,
165 arith::BitcastOp::create(rewriter, loc, intWType, c));
171 SymbolRefAttr::get(*fn), params);
174 auto trunc = arith::TruncIOp::create(rewriter, loc, intWType,
175 resultOp->getResult(0));
176 return arith::BitcastOp::create(rewriter, loc, floatTy, trunc);
179 if (
auto vecTy1 = dyn_cast<VectorType>(op.getA().getType())) {
181 assert(vecTy1 == dyn_cast<VectorType>(op.getB().getType()) &&
182 "expected same vector types");
183 assert(vecTy1 == dyn_cast<VectorType>(op.getC().getType()) &&
184 "expected same vector types");
187 vector::ToElementsOp::create(rewriter, loc, op.getA())->getResults();
189 vector::ToElementsOp::create(rewriter, loc, op.getB())->getResults();
191 vector::ToElementsOp::create(rewriter, loc, op.getC())->getResults();
194 for (
auto [operand, multiplicand, addend] : llvm::zip_equal(
195 scalarOperands, scalarMultiplicands, scalarAddends)) {
196 results.push_back(scalarFMA(operand, multiplicand, addend));
199 auto fromElements = vector::FromElementsOp::create(
201 vecTy1.cloneWith(std::nullopt, results.front().getType()),
207 Value repl = scalarFMA(op.getA(), op.getB(), op.getC());
216struct MathToAPFloatConversionPass final
220 void runOnOperation()
override;
223void MathToAPFloatConversionPass::runOnOperation() {
225 RewritePatternSet
patterns(context);
227 patterns.add<AbsFOpToAPFloatConversion>(context, getOperation());
228 patterns.add<IsOpToAPFloatConversion<math::IsFiniteOp>>(context,
"finite",
230 patterns.add<IsOpToAPFloatConversion<math::IsInfOp>>(context,
"infinite",
232 patterns.add<IsOpToAPFloatConversion<math::IsNaNOp>>(context,
"nan",
234 patterns.add<IsOpToAPFloatConversion<math::IsNormalOp>>(context,
"normal",
236 patterns.add<FmaOpToAPFloatConversion>(context, getOperation());
239 ScopedDiagnosticHandler scopedHandler(context, [&
result](Diagnostic &
diag) {
240 if (
diag.getSeverity() == DiagnosticSeverity::Error) {
249 return signalPassFailure();
static std::string diag(const llvm::Value &value)
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class implements the result iterators for the Operation class.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
FailureOr< FuncOp > lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name, TypeRange paramTypes, SymbolTableCollection *symbolTables=nullptr, Type resultType={})
Helper function to look up or create the symbol for a runtime library function with the given paramet...
Include the generated interface declarations.
LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op)
Check preconditions for the conversion:
Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy)
Value forEachScalarValue(mlir::RewriterBase &rewriter, Location loc, Value operand1, Value operand2, Type resultType, llvm::function_ref< Value(Value, Value, Type)> fn)
Given two operands of vector type and vector result type (with the same shape), call the given functi...
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
LogicalResult matchAndRewrite(math::AbsFOp op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(math::FmaOp op, PatternRewriter &rewriter) const override
FmaOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit=1)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})