MLIR 22.0.0git
OpToFuncCallLowering.h
Go to the documentation of this file.
1//===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9#define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
10
15#include "mlir/IR/Builders.h"
16
17namespace mlir {
18
namespace {
/// Detection alias for the `getFastmath` instance method: it names the type
/// produced by calling `getFastmath()` on a value of type `OpT`, and fails to
/// substitute when `OpT` has no such method.
template <typename OpT>
using has_get_fastmath_t = decltype(std::declval<OpT>().getFastmath());
} // namespace
24
25/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
26/// `f32ApproxFunc` or `f16Func` or `i32Func` depending on the element type and
27/// the fastMathFlag of that Op, if present. The function declaration is added
28/// in case it was not added before.
29///
30/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
31/// value is first extended to f32, the function is called, and then the result
32/// is truncated back to the original type.
33///
34/// Example with NVVM:
35/// %exp_f32 = math.exp %arg_f32 : f32
36///
37/// will be transformed into
38/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
39///
40/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, an f32 Op
41/// lowers to the approximate function, provided `f32ApproxFunc` is non-empty.
42///
43/// Also example with NVVM:
44/// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
45///
46/// will be transformed into
47/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
48///
49/// Final example with NVVM:
50/// %pow_f32 = math.fpowi %arg_f32, %arg_i32
51///
52/// will be transformed into
53/// llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32
54template <typename SourceOp>
56public:
57 explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
58 StringRef f32Func, StringRef f64Func,
59 StringRef f32ApproxFunc, StringRef f16Func,
60 StringRef i32Func = "",
61 PatternBenefit benefit = 1)
62 : ConvertOpToLLVMPattern<SourceOp>(lowering, benefit), f32Func(f32Func),
65
66 LogicalResult
67 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
68 ConversionPatternRewriter &rewriter) const override {
69 using LLVM::LLVMFuncOp;
70
71 static_assert(
72 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
73 "expected single result op");
74
75 bool isResultBool = op->getResultTypes().front().isInteger(1);
76 if constexpr (!std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
77 SourceOp>::value) {
78 assert(op->getNumOperands() > 0 &&
79 "expected op to take at least one operand");
80 assert((op->getResultTypes().front() == op->getOperand(0).getType() ||
81 isResultBool) &&
82 "expected op with same operand and result types");
83 }
84
85 if (!op->template getParentOfType<FunctionOpInterface>()) {
86 return rewriter.notifyMatchFailure(
87 op, "expected op to be within a function region");
88 }
89
90 SmallVector<Value, 1> castedOperands;
91 for (Value operand : adaptor.getOperands())
92 castedOperands.push_back(maybeCast(operand, rewriter));
93
94 Type castedOperandType = castedOperands.front().getType();
95
96 // At ABI level, booleans are treated as i32.
97 Type resultType =
98 isResultBool ? rewriter.getIntegerType(32) : castedOperandType;
99 Type funcType = getFunctionType(resultType, castedOperands);
100 StringRef funcName = getFunctionName(castedOperandType, op);
101 if (funcName.empty())
102 return failure();
103
104 LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
105 auto callOp =
106 LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands);
107
108 if (resultType == adaptor.getOperands().front().getType()) {
109 rewriter.replaceOp(op, {callOp.getResult()});
110 return success();
111 }
112
113 // Boolean result are mapping to i32 at the ABI level with zero values being
114 // interpreted as false and non-zero values being interpreted as true. Since
115 // there is no guarantee of a specific value being used to indicate true,
116 // compare for inequality with zero (rather than truncate or shift).
117 if (isResultBool) {
118 Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
119 rewriter.getIntegerType(32),
120 rewriter.getI32IntegerAttr(0));
121 Value truncated =
122 LLVM::ICmpOp::create(rewriter, op->getLoc(), LLVM::ICmpPredicate::ne,
123 callOp.getResult(), zero);
124 rewriter.replaceOp(op, {truncated});
125 return success();
126 }
127
128 assert(callOp.getResult().getType().isF32() &&
129 "only f32 types are supposed to be truncated back");
130 Value truncated = LLVM::FPTruncOp::create(
131 rewriter, op->getLoc(), adaptor.getOperands().front().getType(),
132 callOp.getResult());
133 rewriter.replaceOp(op, {truncated});
134 return success();
135 }
136
137 Value maybeCast(Value operand, PatternRewriter &rewriter) const {
138 Type type = operand.getType();
139 if (!isa<Float16Type, BFloat16Type>(type))
140 return operand;
141
142 // If there's an f16 function, no need to cast f16 values.
143 if (!f16Func.empty() && isa<Float16Type>(type))
144 return operand;
145
146 return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
147 Float32Type::get(rewriter.getContext()),
148 operand);
149 }
150
151 Type getFunctionType(Type resultType, ValueRange operands) const {
152 SmallVector<Type> operandTypes(operands.getTypes());
153 return LLVM::LLVMFunctionType::get(resultType, operandTypes);
154 }
155
156 LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
157 Operation *op) const {
158 using LLVM::LLVMFuncOp;
159
160 auto funcAttr = StringAttr::get(op->getContext(), funcName);
161 auto funcOp =
163 if (funcOp)
164 return funcOp;
165
166 auto parentFunc = op->getParentOfType<FunctionOpInterface>();
167 assert(parentFunc && "expected there to be a parent function");
168 OpBuilder b(parentFunc);
169
170 // Create a valid global location removing any metadata attached to the
171 // location as debug info metadata inside of a function cannot be used
172 // outside of that function.
173 auto globalloc = op->getLoc()->findInstanceOfOrUnknown<FileLineColLoc>();
174 return LLVMFuncOp::create(b, globalloc, funcName, funcType);
175 }
176
177 StringRef getFunctionName(Type type, SourceOp op) const {
178 bool useApprox = false;
179 if constexpr (llvm::is_detected<has_get_fastmath_t, SourceOp>::value) {
180 arith::FastMathFlags flag = op.getFastmath();
181 useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
182 !f32ApproxFunc.empty();
183 }
184
185 if (isa<Float16Type>(type))
186 return f16Func;
187 if (isa<Float32Type>(type)) {
188 if (useApprox)
189 return f32ApproxFunc;
190 return f32Func;
191 }
192 if (isa<Float64Type>(type))
193 return f64Func;
194
195 if (type.isInteger(32))
196 return i32Func;
197 return "";
198 }
199
// Callee names, stored as owned strings. `getFunctionName` returns one of
// them based on the (possibly extended) operand type; `matchAndRewrite`
// fails to match when the selected name is empty.
200 const std::string f32Func; // Returned for f32 operands.
201 const std::string f64Func; // Returned for f64 operands.
202 const std::string f32ApproxFunc; // Returned for f32 when `afn` is set and this is non-empty.
203 const std::string f16Func; // Returned for f16; when empty, f16 operands are extended to f32.
204 const std::string i32Func; // Returned for i32 operands.
205};
206
207} // namespace mlir
208
209#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
MLIRContext * getContext() const
Definition Builders.h:56
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:215
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
Conversion from types to the LLVM IR dialect.
LocationAttr findInstanceOfOrUnknown()
Return an instance of the given location type if one is nested under the current location else return...
Definition Location.h:60
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Include the generated interface declarations.
StringRef getFunctionName(Type type, SourceOp op) const
LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, Operation *op) const
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
OpToFuncCallLowering(const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func, StringRef i32Func="", PatternBenefit benefit=1)
Type getFunctionType(Type resultType, ValueRange operands) const
Value maybeCast(Value operand, PatternRewriter &rewriter) const