MLIR 23.0.0git
MathToAPFloat.cpp
Go to the documentation of this file.
//===- MathToAPFloat.cpp - Math to APFloat Conversion ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Utils.h"
10
18#include "mlir/IR/Verifier.h"
20
// Expand the tablegen-generated pass definitions for this pass inside the
// `mlir` namespace. This presumably provides
// `impl::MathToAPFloatConversionPassBase`, which the pass defined at the bottom
// of this file derives from.
namespace mlir {
#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
// Lets `FuncOp` be referenced unqualified in the patterns below.
using namespace mlir::func;
28
29struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
30 AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
31 PatternBenefit benefit = 1)
32 : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
33
34 LogicalResult matchAndRewrite(math::AbsFOp op,
35 PatternRewriter &rewriter) const override {
36 if (failed(checkPreconditions(rewriter, op)))
37 return failure();
38 // Get APFloat function from runtime library.
39 auto i32Type = IntegerType::get(symTable->getContext(), 32);
40 auto i64Type = IntegerType::get(symTable->getContext(), 64);
41 FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
42 rewriter, symTable, "_mlir_apfloat_abs", {i32Type, i64Type});
43 if (failed(fn))
44 return fn;
45 Location loc = op.getLoc();
46 rewriter.setInsertionPoint(op);
47 // Scalarize and convert to APFloat runtime calls.
49 rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
50 [&](Value operand, Value, Type resultType) {
51 auto floatTy = cast<FloatType>(operand.getType());
52 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
53 Value operandBits = arith::ExtUIOp::create(
54 rewriter, loc, i64Type,
55 arith::BitcastOp::create(rewriter, loc, intWType, operand));
56 // Call APFloat function.
57 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
58 SmallVector<Value> params = {semValue, operandBits};
59 Value negatedBits =
60 func::CallOp::create(rewriter, loc, TypeRange(i64Type),
61 SymbolRefAttr::get(*fn), params)
62 ->getResult(0);
63 // Truncate result to the original width.
64 auto truncatedBits =
65 arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
66 return arith::BitcastOp::create(rewriter, loc, floatTy,
67 truncatedBits);
68 });
69
70 rewriter.replaceOp(op, repl);
71 return success();
72 }
73
74 SymbolOpInterface symTable;
75};
76
77template <typename OpTy>
80 SymbolOpInterface symTable,
81 PatternBenefit benefit = 1)
82 : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
84
85 LogicalResult matchAndRewrite(OpTy op,
86 PatternRewriter &rewriter) const override {
87 if (failed(checkPreconditions(rewriter, op)))
88 return failure();
89 // Get APFloat function from runtime library.
90 auto i1 = IntegerType::get(symTable->getContext(), 1);
91 auto i32Type = IntegerType::get(symTable->getContext(), 32);
92 auto i64Type = IntegerType::get(symTable->getContext(), 64);
93 std::string funcName =
94 (llvm::Twine("_mlir_apfloat_is") + APFloatName).str();
95 FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
96 rewriter, symTable, funcName, {i32Type, i64Type}, nullptr, i1);
97 if (failed(fn))
98 return fn;
99 Location loc = op.getLoc();
100 rewriter.setInsertionPoint(op);
101 // Scalarize and convert to APFloat runtime calls.
103 rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
104 [&](Value operand, Value, Type resultType) {
105 auto floatTy = cast<FloatType>(operand.getType());
106 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
107 Value operandBits = arith::ExtUIOp::create(
108 rewriter, loc, i64Type,
109 arith::BitcastOp::create(rewriter, loc, intWType, operand));
110
111 // Call APFloat function.
112 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
113 Value params[] = {semValue, operandBits};
114 return func::CallOp::create(rewriter, loc, TypeRange(i1),
115 SymbolRefAttr::get(*fn), params)
116 .getResult(0);
117 });
118 rewriter.replaceOp(op, repl);
119 return success();
120 }
121
122 SymbolOpInterface symTable;
123 const char *APFloatName;
124};
125
126struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
127 FmaOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
128 PatternBenefit benefit = 1)
129 : OpRewritePattern<math::FmaOp>(context, benefit), symTable(symTable) {};
130
131 LogicalResult matchAndRewrite(math::FmaOp op,
132 PatternRewriter &rewriter) const override {
133 if (failed(checkPreconditions(rewriter, op)))
134 return failure();
135 // Cast operands to 64-bit integers.
136 mlir::Type resType = op.getResult().getType();
137 auto floatTy = dyn_cast<FloatType>(resType);
138 if (!floatTy) {
139 auto vecTy1 = cast<VectorType>(resType);
140 floatTy = llvm::cast<FloatType>(vecTy1.getElementType());
141 }
142 auto i32Type = IntegerType::get(symTable->getContext(), 32);
143 auto i64Type = IntegerType::get(symTable->getContext(), 64);
144 FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
145 rewriter, symTable, "_mlir_apfloat_fused_multiply_add",
146 {i32Type, i64Type, i64Type, i64Type});
147 if (failed(fn))
148 return fn;
149 Location loc = op.getLoc();
150 rewriter.setInsertionPoint(op);
151
152 IntegerType intWType = rewriter.getIntegerType(floatTy.getWidth());
153 IntegerType int64Type = rewriter.getI64Type();
154
155 auto scalarFMA = [&rewriter, &loc, &floatTy, &fn, &intWType,
156 &int64Type](Value a, Value b, Value c) {
157 Value operand = arith::ExtUIOp::create(
158 rewriter, loc, int64Type,
159 arith::BitcastOp::create(rewriter, loc, intWType, a));
160 Value multiplicand = arith::ExtUIOp::create(
161 rewriter, loc, int64Type,
162 arith::BitcastOp::create(rewriter, loc, intWType, b));
163 Value addend = arith::ExtUIOp::create(
164 rewriter, loc, int64Type,
165 arith::BitcastOp::create(rewriter, loc, intWType, c));
166 // Call APFloat function.
167 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
168 SmallVector<Value> params = {semValue, operand, multiplicand, addend};
169 auto resultOp =
170 func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
171 SymbolRefAttr::get(*fn), params);
172
173 // Truncate result to the original width.
174 auto trunc = arith::TruncIOp::create(rewriter, loc, intWType,
175 resultOp->getResult(0));
176 return arith::BitcastOp::create(rewriter, loc, floatTy, trunc);
177 };
178
179 if (auto vecTy1 = dyn_cast<VectorType>(op.getA().getType())) {
180 // Sanity check: Operand types must match.
181 assert(vecTy1 == dyn_cast<VectorType>(op.getB().getType()) &&
182 "expected same vector types");
183 assert(vecTy1 == dyn_cast<VectorType>(op.getC().getType()) &&
184 "expected same vector types");
185 // Prepare scalar operands.
186 ResultRange scalarOperands =
187 vector::ToElementsOp::create(rewriter, loc, op.getA())->getResults();
188 ResultRange scalarMultiplicands =
189 vector::ToElementsOp::create(rewriter, loc, op.getB())->getResults();
190 ResultRange scalarAddends =
191 vector::ToElementsOp::create(rewriter, loc, op.getC())->getResults();
192 // Call the function for each pair of scalar operands.
193 SmallVector<Value> results;
194 for (auto [operand, multiplicand, addend] : llvm::zip_equal(
195 scalarOperands, scalarMultiplicands, scalarAddends)) {
196 results.push_back(scalarFMA(operand, multiplicand, addend));
197 }
198 // Package the results into a vector.
199 auto fromElements = vector::FromElementsOp::create(
200 rewriter, loc,
201 vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
202 results);
203 rewriter.replaceOp(op, fromElements);
204 return success();
205 }
206
207 Value repl = scalarFMA(op.getA(), op.getB(), op.getC());
208 rewriter.replaceOp(op, repl);
209 return success();
210 }
211
212 SymbolOpInterface symTable;
213};
214
215namespace {
216struct MathToAPFloatConversionPass final
217 : impl::MathToAPFloatConversionPassBase<MathToAPFloatConversionPass> {
218 using Base::Base;
219
220 void runOnOperation() override;
221};
222
223void MathToAPFloatConversionPass::runOnOperation() {
224 MLIRContext *context = &getContext();
225 RewritePatternSet patterns(context);
226
227 patterns.add<AbsFOpToAPFloatConversion>(context, getOperation());
228 patterns.add<IsOpToAPFloatConversion<math::IsFiniteOp>>(context, "finite",
229 getOperation());
230 patterns.add<IsOpToAPFloatConversion<math::IsInfOp>>(context, "infinite",
231 getOperation());
232 patterns.add<IsOpToAPFloatConversion<math::IsNaNOp>>(context, "nan",
233 getOperation());
234 patterns.add<IsOpToAPFloatConversion<math::IsNormalOp>>(context, "normal",
235 getOperation());
236 patterns.add<FmaOpToAPFloatConversion>(context, getOperation());
237
238 LogicalResult result = success();
239 ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
240 if (diag.getSeverity() == DiagnosticSeverity::Error) {
241 result = failure();
242 }
243 // NB: if you don't return failure, no other diag handlers will fire (see
244 // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
245 return failure();
246 });
247 walkAndApplyPatterns(getOperation(), std::move(patterns));
248 if (failed(result))
249 return signalPassFailure();
250}
251} // namespace
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static std::string diag(const llvm::Value &value)
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class implements the result iterators for the Operation class.
Definition ValueRange.h:247
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
FailureOr< FuncOp > lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name, TypeRange paramTypes, SymbolTableCollection *symbolTables=nullptr, Type resultType={})
Helper function to look up or create the symbol for a runtime library function with the given paramet...
Definition Utils.cpp:302
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op)
Check preconditions for the conversion:
Definition Utils.cpp:70
Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy)
Definition Utils.cpp:21
Value forEachScalarValue(mlir::RewriterBase &rewriter, Location loc, Value operand1, Value operand2, Type resultType, llvm::function_ref< Value(Value, Value, Type)> fn)
Given two operands of vector type and vector result type (with the same shape), call the given functi...
Definition Utils.cpp:28
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
LogicalResult matchAndRewrite(math::AbsFOp op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(math::FmaOp op, PatternRewriter &rewriter) const override
FmaOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit=1)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})