MLIR  20.0.0git
MathToROCDL.cpp
Go to the documentation of this file.
1 //===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
22 
23 #include "../GPUCommon/GPUOpsLowering.h"
24 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
25 #include "../GPUCommon/OpToFuncCallLowering.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTMATHTOROCDL
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "math-to-rocdl"
36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
37 
38 template <typename OpTy>
39 static void populateOpPatterns(const LLVMTypeConverter &converter,
40  RewritePatternSet &patterns, StringRef f32Func,
41  StringRef f64Func, StringRef f16Func,
42  StringRef f32ApproxFunc = "") {
43  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
44  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
45  f32ApproxFunc, f16Func);
46 }
47 
49  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
50  // Handled by mathToLLVM: math::AbsIOp
51  // Handled by mathToLLVM: math::AbsFOp
52  // Handled by mathToLLVM: math::CopySignOp
53  // Handled by mathToLLVM: math::CountLeadingZerosOp
54  // Handled by mathToLLVM: math::CountTrailingZerosOp
55  // Handled by mathToLLVM: math::CgPopOp
56  // Handled by mathToLLVM: math::ExpOp (32-bit only)
57  // Handled by mathToLLVM: math::FmaOp
58  // Handled by mathToLLVM: math::LogOp (32-bit only)
59  // FIXME: math::IPowIOp
60  // FIXME: math::FPowIOp
61  // Handled by mathToLLVM: math::RoundEvenOp
62  // Handled by mathToLLVM: math::RoundOp
63  // Handled by mathToLLVM: math::SqrtOp
64  // Handled by mathToLLVM: math::TruncOp
65  populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
66  "__ocml_acos_f64", "__ocml_acos_f16");
67  populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
68  "__ocml_acosh_f64", "__ocml_acosh_f16");
69  populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
70  "__ocml_asin_f64", "__ocml_asin_f16");
71  populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
72  "__ocml_asinh_f64", "__ocml_asinh_f16");
73  populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
74  "__ocml_atan_f64", "__ocml_atan_f16");
75  populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
76  "__ocml_atanh_f64", "__ocml_atanh_f16");
77  populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
78  "__ocml_atan2_f64", "__ocml_atan2_f16");
79  populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
80  "__ocml_cbrt_f64", "__ocml_cbrt_f16");
81  populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
82  "__ocml_ceil_f64", "__ocml_ceil_f16");
83  populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
84  "__ocml_cos_f64", "__ocml_cos_f16");
85  populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
86  "__ocml_cosh_f64", "__ocml_cosh_f16");
87  populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
88  "__ocml_sinh_f64", "__ocml_sinh_f16");
89  populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
90  "__ocml_exp_f16");
91  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
92  "__ocml_exp2_f64", "__ocml_exp2_f16");
93  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
94  "__ocml_expm1_f64", "__ocml_expm1_f16");
95  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
96  "__ocml_floor_f64", "__ocml_floor_f16");
97  populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
98  "__ocml_log_f16");
99  populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
100  "__ocml_log10_f64", "__ocml_log10_f16");
101  populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
102  "__ocml_log1p_f64", "__ocml_log1p_f16");
103  populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
104  "__ocml_log2_f64", "__ocml_log2_f16");
105  populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
106  "__ocml_pow_f64", "__ocml_pow_f16");
107  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
108  "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
109  populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
110  "__ocml_sin_f64", "__ocml_sin_f16");
111  populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
112  "__ocml_tanh_f64", "__ocml_tanh_f16");
113  populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
114  "__ocml_tan_f64", "__ocml_tan_f16");
115  populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
116  "__ocml_erf_f64", "__ocml_erf_f16");
117  // Single arith pattern that needs a ROCDL call, probably not
118  // worth creating a separate pass for it.
119  populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
120  "__ocml_fmod_f64", "__ocml_fmod_f16");
121 }
122 
123 namespace {
124 struct ConvertMathToROCDLPass
125  : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
126  ConvertMathToROCDLPass() = default;
127  void runOnOperation() override;
128 };
129 } // namespace
130 
131 void ConvertMathToROCDLPass::runOnOperation() {
132  auto m = getOperation();
133  MLIRContext *ctx = m.getContext();
134 
135  RewritePatternSet patterns(&getContext());
137  LLVMTypeConverter converter(ctx, options);
138  populateMathToROCDLConversionPatterns(converter, patterns);
139  ConversionTarget target(getContext());
140  target.addLegalDialect<BuiltinDialect, func::FuncDialect,
141  vector::VectorDialect, LLVM::LLVMDialect>();
142  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
143  LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
144  LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
145  LLVM::SqrtOp>();
146  if (failed(applyPartialConversion(m, target, std::move(patterns))))
147  signalPassFailure();
148 }
static MLIRContext * getContext(OpFoldResult val)
static void populateOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef f32Func, StringRef f64Func, StringRef f16Func, StringRef f32ApproxFunc="")
Definition: MathToROCDL.cpp:39
static llvm::ManagedStatic< PassManagerOptions > options
This class describes a specific conversion target.
The main mechanism for performing data layout queries.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from Math to ROCDL calls.
Definition: MathToROCDL.cpp:48
Rewriting that replaces SourceOp with a CallOp to f32Func, f64Func, f32ApproxFunc, or f16Func, depending on the op's type.
Rewriting that unrolls SourceOp to scalars if it's operating on vectors.