1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
15 #include "mlir/IR/Builders.h"
17 namespace mlir {
19 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func`
20 /// depending on the element type that Op operates upon. The function
21 /// declaration is added in case it was not added before.
22 ///
23 /// If the input values are of f16 type, the value is first casted to f32, the
24 /// function called and then the result casted back.
25 ///
26 /// Example with NVVM:
27 /// %exp_f32 = std.exp %arg_f32 : f32
28 ///
29 /// will be transformed into
30 /// @__nv_expf(%arg_f32) : (f32) -> f32
31 template <typename SourceOp>
32 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
33 public:
34  explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
35  StringRef f64Func)
36  : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
37  f64Func(f64Func) {}
40  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
41  ConversionPatternRewriter &rewriter) const override {
42  using LLVM::LLVMFuncOp;
44  static_assert(
45  std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
46  "expected single result op");
48  static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
49  SourceOp>::value,
50  "expected op with same operand and result types");
52  SmallVector<Value, 1> castedOperands;
53  for (Value operand : adaptor.getOperands())
54  castedOperands.push_back(maybeCast(operand, rewriter));
56  Type resultType = castedOperands.front().getType();
57  Type funcType = getFunctionType(resultType, castedOperands);
58  StringRef funcName = getFunctionName(
60  if (funcName.empty())
61  return failure();
63  LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
64  auto callOp = rewriter.create<LLVM::CallOp>(
65  op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands);
67  if (resultType == adaptor.getOperands().front().getType()) {
68  rewriter.replaceOp(op, {callOp.getResult(0)});
69  return success();
70  }
72  Value truncated = rewriter.create<LLVM::FPTruncOp>(
73  op->getLoc(), adaptor.getOperands().front().getType(),
74  callOp.getResult(0));
75  rewriter.replaceOp(op, {truncated});
76  return success();
77  }
79 private:
80  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
81  Type type = operand.getType();
82  if (!type.isa<Float16Type>())
83  return operand;
85  return rewriter.create<LLVM::FPExtOp>(
86  operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
87  }
89  Type getFunctionType(Type resultType, ValueRange operands) const {
90  SmallVector<Type> operandTypes(operands.getTypes());
91  return LLVM::LLVMFunctionType::get(resultType, operandTypes);
92  }
94  StringRef getFunctionName(Type type) const {
95  if (type.isa<Float32Type>())
96  return f32Func;
97  if (type.isa<Float64Type>())
98  return f64Func;
99  return "";
100  }
102  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
103  Operation *op) const {
104  using LLVM::LLVMFuncOp;
106  auto funcAttr = StringAttr::get(op->getContext(), funcName);
107  Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
108  if (funcOp)
109  return cast<LLVMFuncOp>(*funcOp);
111  mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
112  return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
113  }
115  const std::string f32Func;
116  const std::string f64Func;
117 };
119 } // namespace mlir
