MLIR  20.0.0git
OpToFuncCallLowering.h
Go to the documentation of this file.
1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
10 
15 #include "mlir/IR/Builders.h"
16 
17 namespace mlir {
18 
19 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
20 /// `f32ApproxFunc` or `f16Func` depending on the element type and the
21 /// fastMathFlag of that Op. The function declaration is added in case it was
22 /// not added before.
23 ///
24 /// If the input values are of bf16 type (or f16 type if f16Func is empty), the
25 /// value is first casted to f32, the function called and then the result casted
26 /// back.
27 ///
28 /// Example with NVVM:
29 /// %exp_f32 = math.exp %arg_f32 : f32
30 ///
31 /// will be transformed into
32 /// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
33 ///
34 /// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
35 /// to the approximate calculation function.
36 ///
37 /// Also example with NVVM:
38 /// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
39 ///
40 /// will be transformed into
41 /// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
42 template <typename SourceOp>
43 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
44 public:
45  explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
46  StringRef f32Func, StringRef f64Func,
47  StringRef f32ApproxFunc, StringRef f16Func)
48  : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
49  f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
50 
51  LogicalResult
52  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
53  ConversionPatternRewriter &rewriter) const override {
54  using LLVM::LLVMFuncOp;
55 
56  static_assert(
57  std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
58  "expected single result op");
59 
60  static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
61  SourceOp>::value,
62  "expected op with same operand and result types");
63 
64  if (!op->template getParentOfType<FunctionOpInterface>()) {
65  return rewriter.notifyMatchFailure(
66  op, "expected op to be within a function region");
67  }
68 
69  SmallVector<Value, 1> castedOperands;
70  for (Value operand : adaptor.getOperands())
71  castedOperands.push_back(maybeCast(operand, rewriter));
72 
73  Type resultType = castedOperands.front().getType();
74  Type funcType = getFunctionType(resultType, castedOperands);
75  StringRef funcName =
76  getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
77  op.getFastmath());
78  if (funcName.empty())
79  return failure();
80 
81  LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
82  auto callOp =
83  rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
84 
85  if (resultType == adaptor.getOperands().front().getType()) {
86  rewriter.replaceOp(op, {callOp.getResult()});
87  return success();
88  }
89 
90  Value truncated = rewriter.create<LLVM::FPTruncOp>(
91  op->getLoc(), adaptor.getOperands().front().getType(),
92  callOp.getResult());
93  rewriter.replaceOp(op, {truncated});
94  return success();
95  }
96 
97 private:
98  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
99  Type type = operand.getType();
100  if (!isa<Float16Type, BFloat16Type>(type))
101  return operand;
102 
103  // if there's a f16 function, no need to cast f16 values
104  if (!f16Func.empty() && isa<Float16Type>(type))
105  return operand;
106 
107  return rewriter.create<LLVM::FPExtOp>(
108  operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
109  }
110 
111  Type getFunctionType(Type resultType, ValueRange operands) const {
112  SmallVector<Type> operandTypes(operands.getTypes());
113  return LLVM::LLVMFunctionType::get(resultType, operandTypes);
114  }
115 
116  StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
117  if (isa<Float16Type>(type))
118  return f16Func;
119  if (isa<Float32Type>(type)) {
120  if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
121  !f32ApproxFunc.empty())
122  return f32ApproxFunc;
123  else
124  return f32Func;
125  }
126  if (isa<Float64Type>(type))
127  return f64Func;
128  return "";
129  }
130 
131  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
132  Operation *op) const {
133  using LLVM::LLVMFuncOp;
134 
135  auto funcAttr = StringAttr::get(op->getContext(), funcName);
136  Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
137  if (funcOp)
138  return cast<LLVMFuncOp>(*funcOp);
139 
140  mlir::OpBuilder b(op->getParentOfType<FunctionOpInterface>());
141  return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
142  }
143 
144  const std::string f32Func;
145  const std::string f64Func;
146  const std::string f32ApproxFunc;
147  const std::string f16Func;
148 };
149 
150 } // namespace mlir
151 
152 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class helps build Operations.
Definition: Builders.h:215
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class provides return value APIs for ops that are known to have a single result.
Definition: OpDefinition.h:665
This class provides verification for ops that are known to have the same operand and result type.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Rewriting that replace SourceOp with a CallOp to f32Func or f64Func or f32ApproxFunc or f16Func depen...
OpToFuncCallLowering(const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func)
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override