MLIR 22.0.0git
MathToAPFloat.cpp
Go to the documentation of this file.
//===- MathToAPFloat.cpp - Math to APFloat Conversion ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Utils.h"
10
18#include "mlir/IR/Verifier.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::func;
28
29struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
30 AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
31 PatternBenefit benefit = 1)
32 : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
33
34 LogicalResult matchAndRewrite(math::AbsFOp op,
35 PatternRewriter &rewriter) const override {
36 // Cast operands to 64-bit integers.
37 auto operand = op.getOperand();
38 auto floatTy = dyn_cast<FloatType>(operand.getType());
39 if (!floatTy)
40 return rewriter.notifyMatchFailure(op,
41 "only scalar FloatTypes supported");
42 if (floatTy.getIntOrFloatBitWidth() > 64) {
43 return rewriter.notifyMatchFailure(op,
44 "bitwidth > 64 bits is not supported");
45 }
46 // Get APFloat function from runtime library.
47 auto i32Type = IntegerType::get(symTable->getContext(), 32);
48 auto i64Type = IntegerType::get(symTable->getContext(), 64);
49 FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
50 rewriter, symTable, "_mlir_apfloat_abs", {i32Type, i64Type});
51 if (failed(fn))
52 return fn;
53 Location loc = op.getLoc();
54 rewriter.setInsertionPoint(op);
55 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
56 Value operandBits = arith::ExtUIOp::create(
57 rewriter, loc, i64Type,
58 arith::BitcastOp::create(rewriter, loc, intWType, operand));
59
60 // Call APFloat function.
61 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
62 SmallVector<Value> params = {semValue, operandBits};
63 Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
64 SymbolRefAttr::get(*fn), params)
65 ->getResult(0);
66
67 // Truncate result to the original width.
68 Value truncatedBits =
69 arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
70 rewriter.replaceOp(
71 op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
72 return success();
73 }
74
75 SymbolOpInterface symTable;
76};
77
78template <typename OpTy>
81 SymbolOpInterface symTable,
82 PatternBenefit benefit = 1)
83 : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
85
86 LogicalResult matchAndRewrite(OpTy op,
87 PatternRewriter &rewriter) const override {
88 // Cast operands to 64-bit integers.
89 auto operand = op.getOperand();
90 auto floatTy = dyn_cast<FloatType>(operand.getType());
91 if (!floatTy)
92 return rewriter.notifyMatchFailure(op,
93 "only scalar FloatTypes supported");
94 if (floatTy.getIntOrFloatBitWidth() > 64) {
95 return rewriter.notifyMatchFailure(op,
96 "bitwidth > 64 bits is not supported");
97 }
98 // Get APFloat function from runtime library.
99 auto i1 = IntegerType::get(symTable->getContext(), 1);
100 auto i32Type = IntegerType::get(symTable->getContext(), 32);
101 auto i64Type = IntegerType::get(symTable->getContext(), 64);
102 std::string funcName =
103 (llvm::Twine("_mlir_apfloat_is") + APFloatName).str();
104 FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
105 rewriter, symTable, funcName, {i32Type, i64Type}, nullptr, i1);
106 if (failed(fn))
107 return fn;
108 Location loc = op.getLoc();
109 rewriter.setInsertionPoint(op);
110 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
111 Value operandBits = arith::ExtUIOp::create(
112 rewriter, loc, i64Type,
113 arith::BitcastOp::create(rewriter, loc, intWType, operand));
114
115 // Call APFloat function.
116 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
117 SmallVector<Value> params = {semValue, operandBits};
118 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(i1),
119 SymbolRefAttr::get(*fn), params);
120 return success();
121 }
122
123 SymbolOpInterface symTable;
124 const char *APFloatName;
125};
126
127struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
128 FmaOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
129 PatternBenefit benefit = 1)
130 : OpRewritePattern<math::FmaOp>(context, benefit), symTable(symTable) {};
131
132 LogicalResult matchAndRewrite(math::FmaOp op,
133 PatternRewriter &rewriter) const override {
134 // Cast operands to 64-bit integers.
135 auto floatTy = cast<FloatType>(op.getResult().getType());
136 if (!floatTy)
137 return rewriter.notifyMatchFailure(op,
138 "only scalar FloatTypes supported");
139 if (floatTy.getIntOrFloatBitWidth() > 64) {
140 return rewriter.notifyMatchFailure(op,
141 "bitwidth > 64 bits is not supported");
142 }
143
144 auto i32Type = IntegerType::get(symTable->getContext(), 32);
145 auto i64Type = IntegerType::get(symTable->getContext(), 64);
146 FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
147 rewriter, symTable, "_mlir_apfloat_fused_multiply_add",
148 {i32Type, i64Type, i64Type, i64Type});
149 if (failed(fn))
150 return fn;
151 Location loc = op.getLoc();
152 rewriter.setInsertionPoint(op);
153
154 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
155 auto int64Type = rewriter.getI64Type();
156 Value operand = arith::ExtUIOp::create(
157 rewriter, loc, int64Type,
158 arith::BitcastOp::create(rewriter, loc, intWType, op.getA()));
159 Value multiplicand = arith::ExtUIOp::create(
160 rewriter, loc, int64Type,
161 arith::BitcastOp::create(rewriter, loc, intWType, op.getB()));
162 Value addend = arith::ExtUIOp::create(
163 rewriter, loc, int64Type,
164 arith::BitcastOp::create(rewriter, loc, intWType, op.getC()));
165
166 // Call APFloat function.
167 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
168 SmallVector<Value> params = {semValue, operand, multiplicand, addend};
169 auto resultOp =
170 func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
171 SymbolRefAttr::get(*fn), params);
172
173 // Truncate result to the original width.
174 Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
175 resultOp->getResult(0));
176 rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, floatTy, truncatedBits);
177 return success();
178 }
179
180 SymbolOpInterface symTable;
181};
182
183namespace {
184struct MathToAPFloatConversionPass final
185 : impl::MathToAPFloatConversionPassBase<MathToAPFloatConversionPass> {
186 using Base::Base;
187
188 void runOnOperation() override;
189};
190
191void MathToAPFloatConversionPass::runOnOperation() {
192 MLIRContext *context = &getContext();
193 RewritePatternSet patterns(context);
194
195 patterns.add<AbsFOpToAPFloatConversion>(context, getOperation());
196 patterns.add<IsOpToAPFloatConversion<math::IsFiniteOp>>(context, "finite",
197 getOperation());
198 patterns.add<IsOpToAPFloatConversion<math::IsInfOp>>(context, "infinite",
199 getOperation());
200 patterns.add<IsOpToAPFloatConversion<math::IsNaNOp>>(context, "nan",
201 getOperation());
202 patterns.add<IsOpToAPFloatConversion<math::IsNormalOp>>(context, "normal",
203 getOperation());
204 patterns.add<FmaOpToAPFloatConversion>(context, getOperation());
205
206 LogicalResult result = success();
207 ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
208 if (diag.getSeverity() == DiagnosticSeverity::Error) {
209 result = failure();
210 }
211 // NB: if you don't return failure, no other diag handlers will fire (see
212 // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
213 return failure();
214 });
215 walkAndApplyPatterns(getOperation(), std::move(patterns));
216 if (failed(result))
217 return signalPassFailure();
218}
219} // namespace
return success()
b getContext())
static std::string diag(const llvm::Value &value)
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
FailureOr< FuncOp > lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name, TypeRange paramTypes, SymbolTableCollection *symbolTables=nullptr, Type resultType={})
Helper function to look up or create the symbol for a runtime library function with the given paramet...
Definition Utils.cpp:302
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy)
Definition Utils.cpp:17
LogicalResult matchAndRewrite(math::AbsFOp op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(math::FmaOp op, PatternRewriter &rewriter) const override
FmaOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit=1)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})