MLIR  14.0.0git
Go to the documentation of this file.
1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
15 #include "mlir/IR/Builders.h"
17 namespace mlir {
19 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func`
20 /// depending on the element type that Op operates upon. The function
21 /// declaration is added in case it was not added before.
22 ///
23 /// If the input values are of f16 type, the value is first casted to f32, the
24 /// function called and then the result casted back.
25 ///
26 /// Example with NVVM:
27 /// %exp_f32 = std.exp %arg_f32 : f32
28 ///
29 /// will be transformed into
30 /// @__nv_expf(%arg_f32) : (f32) -> f32
31 template <typename SourceOp>
32 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
33 public:
34  explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
35  StringRef f64Func)
36  : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
37  f64Func(f64Func) {}
40  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
41  ConversionPatternRewriter &rewriter) const override {
42  using LLVM::LLVMFuncOp;
44  static_assert(
45  std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
46  "expected single result op");
48  static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
49  SourceOp>::value,
50  "expected op with same operand and result types");
52  SmallVector<Value, 1> castedOperands;
53  for (Value operand : adaptor.getOperands())
54  castedOperands.push_back(maybeCast(operand, rewriter));
56  Type resultType = castedOperands.front().getType();
57  Type funcType = getFunctionType(resultType, castedOperands);
58  StringRef funcName = getFunctionName(
60  if (funcName.empty())
61  return failure();
63  LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
64  auto callOp = rewriter.create<LLVM::CallOp>(
65  op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands);
67  if (resultType == adaptor.getOperands().front().getType()) {
68  rewriter.replaceOp(op, {callOp.getResult(0)});
69  return success();
70  }
72  Value truncated = rewriter.create<LLVM::FPTruncOp>(
73  op->getLoc(), adaptor.getOperands().front().getType(),
74  callOp.getResult(0));
75  rewriter.replaceOp(op, {truncated});
76  return success();
77  }
79 private:
80  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
81  Type type = operand.getType();
82  if (!type.isa<Float16Type>())
83  return operand;
85  return rewriter.create<LLVM::FPExtOp>(
86  operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
87  }
89  Type getFunctionType(Type resultType, ValueRange operands) const {
90  SmallVector<Type> operandTypes(operands.getTypes());
91  return LLVM::LLVMFunctionType::get(resultType, operandTypes);
92  }
94  StringRef getFunctionName(Type type) const {
95  if (type.isa<Float32Type>())
96  return f32Func;
97  if (type.isa<Float64Type>())
98  return f64Func;
99  return "";
100  }
102  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
103  Operation *op) const {
104  using LLVM::LLVMFuncOp;
106  auto funcAttr = StringAttr::get(op->getContext(), funcName);
107  Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
108  if (funcOp)
109  return cast<LLVMFuncOp>(*funcOp);
111  mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
112  return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
113  }
115  const std::string f32Func;
116  const std::string f64Func;
117 };
119 } // namespace mlir
Include the generated interface declarations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:132
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
MLIRContext * getContext() const
Definition: Builders.h:54
OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
This class provides return value APIs for ops that are known to have a single result.
Definition: OpDefinition.h:607
LLVM dialect function type.
Definition: LLVMTypes.h:123
static LLVMFunctionType get(Type result, ArrayRef< Type > arguments, bool isVarArg=false)
Gets or creates an instance of LLVM dialect function in the same context as the result type...
Definition: LLVMTypes.cpp:101
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:120
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:99
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides verification for ops that are known to have the same operand and result type...
Type getReturnType()
Returns the result type of the function.
Definition: LLVMTypes.cpp:122
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
Type getType() const
Return the type of this value.
Definition: Value.h:117
type_range getTypes() const
Conversion from types in the Standard dialect to the LLVM IR dialect.
Definition: TypeConverter.h:30
This class implements a pattern rewriter for use with ConversionPatterns.
bool isa() const
Definition: Types.h:234
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func depending on the element type tha...
U cast() const
Definition: Types.h:250