MLIR  21.0.0git
OpToFuncCallLowering.h
Go to the documentation of this file.
1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
10 
15 #include "mlir/IR/Builders.h"
16 
17 namespace mlir {
18 
namespace {
/// Detection trait for the `getFastmath` instance method.
///
/// Names the type returned by calling `getFastmath()` on a value of type
/// `OpT`. The alias is ill-formed when `OpT` has no such callable member,
/// which makes it usable with detection idioms such as `llvm::is_detected`.
template <typename OpT>
using has_get_fastmath_t = decltype(std::declval<OpT>().getFastmath());
} // namespace
24 
25 /// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
26 /// `f32ApproxFunc` or `f16Func` or `i32Type` depending on the element type and
27 /// the fastMathFlag of that Op, if present. The function declaration is added
28 /// in case it was not added before.
29 ///
30 /// If the input values are of bf16 type (or f16 type if f16Func is empty), the
31 /// value is first casted to f32, the function called and then the result casted
32 /// back.
33 ///
34 /// Example with NVVM:
35 /// %exp_f32 = math.exp %arg_f32 : f32
36 ///
37 /// will be transformed into
38 /// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
39 ///
40 /// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
41 /// to the approximate calculation function.
42 ///
43 /// Also example with NVVM:
44 /// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
45 ///
46 /// will be transformed into
47 /// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
48 ///
49 /// Final example with NVVM:
50 /// %pow_f32 = math.fpowi %arg_f32, %arg_i32
51 ///
52 /// will be transformed into
53 /// llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32
54 template <typename SourceOp>
55 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
56 public:
57  explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
58  StringRef f32Func, StringRef f64Func,
59  StringRef f32ApproxFunc, StringRef f16Func,
60  StringRef i32Func = "",
61  PatternBenefit benefit = 1)
62  : ConvertOpToLLVMPattern<SourceOp>(lowering, benefit), f32Func(f32Func),
64  i32Func(i32Func) {}
65 
66  LogicalResult
67  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
68  ConversionPatternRewriter &rewriter) const override {
69  using LLVM::LLVMFuncOp;
70 
71  static_assert(
72  std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
73  "expected single result op");
74 
75  bool isResultBool = op->getResultTypes().front().isInteger(1);
76  if constexpr (!std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
77  SourceOp>::value) {
78  assert(op->getNumOperands() > 0 &&
79  "expected op to take at least one operand");
80  assert((op->getResultTypes().front() == op->getOperand(0).getType() ||
81  isResultBool) &&
82  "expected op with same operand and result types");
83  }
84 
85  if (!op->template getParentOfType<FunctionOpInterface>()) {
86  return rewriter.notifyMatchFailure(
87  op, "expected op to be within a function region");
88  }
89 
90  SmallVector<Value, 1> castedOperands;
91  for (Value operand : adaptor.getOperands())
92  castedOperands.push_back(maybeCast(operand, rewriter));
93 
94  Type castedOperandType = castedOperands.front().getType();
95 
96  // At ABI level, booleans are treated as i32.
97  Type resultType =
98  isResultBool ? rewriter.getIntegerType(32) : castedOperandType;
99  Type funcType = getFunctionType(resultType, castedOperands);
100  StringRef funcName = getFunctionName(castedOperandType, op);
101  if (funcName.empty())
102  return failure();
103 
104  LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
105  auto callOp =
106  rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
107 
108  if (resultType == adaptor.getOperands().front().getType()) {
109  rewriter.replaceOp(op, {callOp.getResult()});
110  return success();
111  }
112 
113  // Boolean result are mapping to i32 at the ABI level with zero values being
114  // interpreted as false and non-zero values being interpreted as true. Since
115  // there is no guarantee of a specific value being used to indicate true,
116  // compare for inequality with zero (rather than truncate or shift).
117  if (isResultBool) {
118  Value zero = rewriter.create<LLVM::ConstantOp>(
119  op->getLoc(), rewriter.getIntegerType(32),
120  rewriter.getI32IntegerAttr(0));
121  Value truncated = rewriter.create<LLVM::ICmpOp>(
122  op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero);
123  rewriter.replaceOp(op, {truncated});
124  return success();
125  }
126 
127  assert(callOp.getResult().getType().isF32() &&
128  "only f32 types are supposed to be truncated back");
129  Value truncated = rewriter.create<LLVM::FPTruncOp>(
130  op->getLoc(), adaptor.getOperands().front().getType(),
131  callOp.getResult());
132  rewriter.replaceOp(op, {truncated});
133  return success();
134  }
135 
136  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
137  Type type = operand.getType();
138  if (!isa<Float16Type, BFloat16Type>(type))
139  return operand;
140 
141  // If there's an f16 function, no need to cast f16 values.
142  if (!f16Func.empty() && isa<Float16Type>(type))
143  return operand;
144 
145  return rewriter.create<LLVM::FPExtOp>(
146  operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
147  }
148 
149  Type getFunctionType(Type resultType, ValueRange operands) const {
150  SmallVector<Type> operandTypes(operands.getTypes());
151  return LLVM::LLVMFunctionType::get(resultType, operandTypes);
152  }
153 
154  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
155  Operation *op) const {
156  using LLVM::LLVMFuncOp;
157 
158  auto funcAttr = StringAttr::get(op->getContext(), funcName);
159  auto funcOp =
160  SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
161  if (funcOp)
162  return funcOp;
163 
164  auto parentFunc = op->getParentOfType<FunctionOpInterface>();
165  assert(parentFunc && "expected there to be a parent function");
166  OpBuilder b(parentFunc);
167  return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
168  }
169 
170  StringRef getFunctionName(Type type, SourceOp op) const {
171  bool useApprox = false;
172  if constexpr (llvm::is_detected<has_get_fastmath_t, SourceOp>::value) {
173  arith::FastMathFlags flag = op.getFastmath();
174  useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
175  !f32ApproxFunc.empty();
176  }
177 
178  if (isa<Float16Type>(type))
179  return f16Func;
180  if (isa<Float32Type>(type)) {
181  if (useApprox)
182  return f32ApproxFunc;
183  return f32Func;
184  }
185  if (isa<Float64Type>(type))
186  return f64Func;
187 
188  if (type.isInteger(32))
189  return i32Func;
190  return "";
191  }
192 
193  const std::string f32Func;
194  const std::string f64Func;
195  const std::string f32ApproxFunc;
196  const std::string f16Func;
197  const std::string i32Func;
198 };
199 
200 } // namespace mlir
201 
202 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:148
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class provides return value APIs for ops that are known to have a single result.
Definition: OpDefinition.h:669
This class provides verification for ops that are known to have the same operand and result type.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func or f32ApproxFunc or f16Func or i...
StringRef getFunctionName(Type type, SourceOp op) const
LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, Operation *op) const
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
OpToFuncCallLowering(const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func, StringRef i32Func="", PatternBenefit benefit=1)
Type getFunctionType(Type resultType, ValueRange operands) const
Value maybeCast(Value operand, PatternRewriter &rewriter) const