MLIR 22.0.0git
ArithToAPFloat.cpp
Go to the documentation of this file.
1//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/IR/Verifier.h"
18
19namespace mlir {
20#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
21#include "mlir/Conversion/Passes.h.inc"
22} // namespace mlir
23
24using namespace mlir;
25using namespace mlir::func;
26
27static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
28 StringRef name, FunctionType funcT, bool setPrivate,
29 SymbolTableCollection *symbolTables = nullptr) {
31 assert(!symTable->getRegion(0).empty() && "expected non-empty region");
32 b.setInsertionPointToStart(&symTable->getRegion(0).front());
33 FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
34 if (setPrivate)
35 funcOp.setPrivate();
36 if (symbolTables) {
37 SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
38 symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
39 }
40 return funcOp;
41}
42
43/// Helper function to look up or create the symbol for a runtime library
44/// function for a binary arithmetic operation.
45///
46/// Parameter 1: APFloat semantics
47/// Parameter 2: Left-hand side operand
48/// Parameter 3: Right-hand side operand
49///
50/// This function will return a failure if the function is found but has an
51/// unexpected signature.
52///
53static FailureOr<FuncOp>
54lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
55 SymbolTableCollection *symbolTables = nullptr) {
56 auto i32Type = IntegerType::get(symTable->getContext(), 32);
57 auto i64Type = IntegerType::get(symTable->getContext(), 64);
58
59 std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
60 FunctionType funcT =
61 FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type});
62 FailureOr<FuncOp> func =
63 lookupFnDecl(symTable, funcName, funcT, symbolTables);
64 // Failed due to type mismatch.
65 if (failed(func))
66 return func;
67 // Successfully matched existing decl.
68 if (*func)
69 return *func;
70
71 return createFnDecl(b, symTable, funcName, funcT,
72 /*setPrivate=*/true, symbolTables);
73}
74
75/// Rewrite a binary arithmetic operation to an APFloat function call.
76template <typename OpTy>
79 const char *APFloatName,
80 SymbolOpInterface symTable,
81 PatternBenefit benefit = 1)
82 : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
84
85 LogicalResult matchAndRewrite(OpTy op,
86 PatternRewriter &rewriter) const override {
87 // Get APFloat function from runtime library.
88 FailureOr<FuncOp> fn =
90 if (failed(fn))
91 return fn;
92
93 rewriter.setInsertionPoint(op);
94 // Cast operands to 64-bit integers.
95 Location loc = op.getLoc();
96 auto floatTy = cast<FloatType>(op.getType());
97 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
98 auto int64Type = rewriter.getI64Type();
99 Value lhsBits = arith::ExtUIOp::create(
100 rewriter, loc, int64Type,
101 arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
102 Value rhsBits = arith::ExtUIOp::create(
103 rewriter, loc, int64Type,
104 arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
105
106 // Call APFloat function.
107 int32_t sem =
108 llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
109 Value semValue = arith::ConstantOp::create(
110 rewriter, loc, rewriter.getI32Type(),
111 rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
112 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
113 auto resultOp =
114 func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
115 SymbolRefAttr::get(*fn), params);
116
117 // Truncate result to the original width.
118 Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
119 resultOp->getResult(0));
120 rewriter.replaceOp(
121 op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
122 return success();
123 }
124
125 SymbolOpInterface symTable;
126 const char *APFloatName;
127};
128
129namespace {
130struct ArithToAPFloatConversionPass final
131 : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
132 using Base::Base;
133
134 void runOnOperation() override;
135};
136
137void ArithToAPFloatConversionPass::runOnOperation() {
138 MLIRContext *context = &getContext();
139 RewritePatternSet patterns(context);
140 patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context, "add",
141 getOperation());
142 patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>(
143 context, "subtract", getOperation());
144 patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>(
145 context, "multiply", getOperation());
146 patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>(
147 context, "divide", getOperation());
148 patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
149 context, "remainder", getOperation());
150 LogicalResult result = success();
151 ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
152 if (diag.getSeverity() == DiagnosticSeverity::Error) {
153 result = failure();
154 }
155 // NB: if you don't return failure, no other diag handlers will fire (see
156 // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
157 return failure();
158 });
159 walkAndApplyPatterns(getOperation(), std::move(patterns));
160 if (failed(result))
161 return signalPassFailure();
162}
163} // namespace
return success()
static FailureOr< FuncOp > lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, SymbolTableCollection *symbolTables=nullptr)
Helper function to look up or create the symbol for a runtime library function for a binary arithmeti...
static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name, FunctionType funcT, bool setPrivate, SymbolTableCollection *symbolTables=nullptr)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static std::string diag(const llvm::Value &value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getI32Type()
Definition Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents a collection of SymbolTables.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
FailureOr< FuncOp > lookupFnDecl(SymbolOpInterface symTable, StringRef name, FunctionType funcT, SymbolTableCollection *symbolTables=nullptr)
Look up a FuncOp with signature resultTypes(paramTypes) and name / name`.
Definition Utils.cpp:259
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
BinaryArithOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit=1)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})