MLIR  20.0.0git
OpToFuncCallLowering.h
Go to the documentation of this file.
1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
10 
15 #include "mlir/IR/Builders.h"
16 
17 namespace mlir {
18 
19 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
20 /// `f32ApproxFunc` depending on the element type and the fastMathFlag of that
21 /// Op. The function declaration is added in case it was not added before.
22 ///
23 /// If the input values are of f16 type, the value is first casted to f32, the
24 /// function called and then the result casted back.
25 ///
26 /// Example with NVVM:
27 /// %exp_f32 = math.exp %arg_f32 : f32
28 ///
29 /// will be transformed into
30 /// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
31 ///
32 /// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
33 /// to the approximate calculation function.
34 ///
35 /// Also example with NVVM:
36 /// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
37 ///
38 /// will be transformed into
39 /// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
40 template <typename SourceOp>
41 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
42 public:
43  explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
44  StringRef f64Func, StringRef f32ApproxFunc)
45  : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
46  f64Func(f64Func), f32ApproxFunc(f32ApproxFunc) {}
47 
48  LogicalResult
49  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
50  ConversionPatternRewriter &rewriter) const override {
51  using LLVM::LLVMFuncOp;
52 
53  static_assert(
54  std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
55  "expected single result op");
56 
57  static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
58  SourceOp>::value,
59  "expected op with same operand and result types");
60 
61  SmallVector<Value, 1> castedOperands;
62  for (Value operand : adaptor.getOperands())
63  castedOperands.push_back(maybeCast(operand, rewriter));
64 
65  Type resultType = castedOperands.front().getType();
66  Type funcType = getFunctionType(resultType, castedOperands);
67  StringRef funcName =
68  getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
69  op.getFastmath());
70  if (funcName.empty())
71  return failure();
72 
73  LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
74  auto callOp =
75  rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
76 
77  if (resultType == adaptor.getOperands().front().getType()) {
78  rewriter.replaceOp(op, {callOp.getResult()});
79  return success();
80  }
81 
82  Value truncated = rewriter.create<LLVM::FPTruncOp>(
83  op->getLoc(), adaptor.getOperands().front().getType(),
84  callOp.getResult());
85  rewriter.replaceOp(op, {truncated});
86  return success();
87  }
88 
89 private:
90  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
91  Type type = operand.getType();
92  if (!isa<Float16Type>(type))
93  return operand;
94 
95  return rewriter.create<LLVM::FPExtOp>(
96  operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
97  }
98 
99  Type getFunctionType(Type resultType, ValueRange operands) const {
100  SmallVector<Type> operandTypes(operands.getTypes());
101  return LLVM::LLVMFunctionType::get(resultType, operandTypes);
102  }
103 
104  StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
105  if (isa<Float32Type>(type)) {
106  if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
107  !f32ApproxFunc.empty())
108  return f32ApproxFunc;
109  else
110  return f32Func;
111  }
112  if (isa<Float64Type>(type))
113  return f64Func;
114  return "";
115  }
116 
117  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
118  Operation *op) const {
119  using LLVM::LLVMFuncOp;
120 
121  auto funcAttr = StringAttr::get(op->getContext(), funcName);
122  Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
123  if (funcOp)
124  return cast<LLVMFuncOp>(*funcOp);
125 
126  mlir::OpBuilder b(op->getParentOfType<FunctionOpInterface>());
127  return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
128  }
129 
130  const std::string f32Func;
131  const std::string f64Func;
132  const std::string f32ApproxFunc;
133 };
134 
135 } // namespace mlir
136 
137 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class helps build Operations.
Definition: Builders.h:212
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
This class provides return value APIs for ops that are known to have a single result.
Definition: OpDefinition.h:665
This class provides verification for ops that are known to have the same operand and result type.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Rewriting that replace SourceOp with a CallOp to f32Func or f64Func or f32ApproxFunc depending on the...
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc)