23 #include "../GPUCommon/GPUOpsLowering.h"
24 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
25 #include "../GPUCommon/OpToFuncCallLowering.h"
29 #define GEN_PASS_DEF_CONVERTMATHTOROCDL
30 #include "mlir/Conversion/Passes.h.inc"
35 #define DEBUG_TYPE "math-to-rocdl"
36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
38 template <
typename OpTy>
41 StringRef f64Func, StringRef f16Func,
42 StringRef f32ApproxFunc =
"") {
45 f32ApproxFunc, f16Func);
64 populateOpPatterns<math::AcosOp>(converter,
patterns,
"__ocml_acos_f32",
65 "__ocml_acos_f64",
"__ocml_acos_f16");
66 populateOpPatterns<math::AcoshOp>(converter,
patterns,
"__ocml_acosh_f32",
67 "__ocml_acosh_f64",
"__ocml_acosh_f16");
68 populateOpPatterns<math::AsinOp>(converter,
patterns,
"__ocml_asin_f32",
69 "__ocml_asin_f64",
"__ocml_asin_f16");
70 populateOpPatterns<math::AsinhOp>(converter,
patterns,
"__ocml_asinh_f32",
71 "__ocml_asinh_f64",
"__ocml_asinh_f16");
72 populateOpPatterns<math::AtanOp>(converter,
patterns,
"__ocml_atan_f32",
73 "__ocml_atan_f64",
"__ocml_atan_f16");
74 populateOpPatterns<math::AtanhOp>(converter,
patterns,
"__ocml_atanh_f32",
75 "__ocml_atanh_f64",
"__ocml_atanh_f16");
76 populateOpPatterns<math::Atan2Op>(converter,
patterns,
"__ocml_atan2_f32",
77 "__ocml_atan2_f64",
"__ocml_atan2_f16");
78 populateOpPatterns<math::CbrtOp>(converter,
patterns,
"__ocml_cbrt_f32",
79 "__ocml_cbrt_f64",
"__ocml_cbrt_f16");
80 populateOpPatterns<math::CeilOp>(converter,
patterns,
"__ocml_ceil_f32",
81 "__ocml_ceil_f64",
"__ocml_ceil_f16");
82 populateOpPatterns<math::CosOp>(converter,
patterns,
"__ocml_cos_f32",
83 "__ocml_cos_f64",
"__ocml_cos_f16");
84 populateOpPatterns<math::CoshOp>(converter,
patterns,
"__ocml_cosh_f32",
85 "__ocml_cosh_f64",
"__ocml_cosh_f16");
86 populateOpPatterns<math::SinhOp>(converter,
patterns,
"__ocml_sinh_f32",
87 "__ocml_sinh_f64",
"__ocml_sinh_f16");
88 populateOpPatterns<math::ExpOp>(converter,
patterns,
"",
"__ocml_exp_f64",
90 populateOpPatterns<math::Exp2Op>(converter,
patterns,
"__ocml_exp2_f32",
91 "__ocml_exp2_f64",
"__ocml_exp2_f16");
92 populateOpPatterns<math::ExpM1Op>(converter,
patterns,
"__ocml_expm1_f32",
93 "__ocml_expm1_f64",
"__ocml_expm1_f16");
94 populateOpPatterns<math::FloorOp>(converter,
patterns,
"__ocml_floor_f32",
95 "__ocml_floor_f64",
"__ocml_floor_f16");
96 populateOpPatterns<math::LogOp>(converter,
patterns,
"",
"__ocml_log_f64",
98 populateOpPatterns<math::Log10Op>(converter,
patterns,
"__ocml_log10_f32",
99 "__ocml_log10_f64",
"__ocml_log10_f16");
100 populateOpPatterns<math::Log1pOp>(converter,
patterns,
"__ocml_log1p_f32",
101 "__ocml_log1p_f64",
"__ocml_log1p_f16");
102 populateOpPatterns<math::Log2Op>(converter,
patterns,
"__ocml_log2_f32",
103 "__ocml_log2_f64",
"__ocml_log2_f16");
104 populateOpPatterns<math::PowFOp>(converter,
patterns,
"__ocml_pow_f32",
105 "__ocml_pow_f64",
"__ocml_pow_f16");
106 populateOpPatterns<math::RsqrtOp>(converter,
patterns,
"__ocml_rsqrt_f32",
107 "__ocml_rsqrt_f64",
"__ocml_rsqrt_f16");
108 populateOpPatterns<math::SinOp>(converter,
patterns,
"__ocml_sin_f32",
109 "__ocml_sin_f64",
"__ocml_sin_f16");
110 populateOpPatterns<math::TanhOp>(converter,
patterns,
"__ocml_tanh_f32",
111 "__ocml_tanh_f64",
"__ocml_tanh_f16");
112 populateOpPatterns<math::TanOp>(converter,
patterns,
"__ocml_tan_f32",
113 "__ocml_tan_f64",
"__ocml_tan_f16");
114 populateOpPatterns<math::ErfOp>(converter,
patterns,
"__ocml_erf_f32",
115 "__ocml_erf_f64",
"__ocml_erf_f16");
116 populateOpPatterns<math::FPowIOp>(converter,
patterns,
"__ocml_pown_f32",
117 "__ocml_pown_f64",
"__ocml_pown_f16");
120 populateOpPatterns<arith::RemFOp>(converter,
patterns,
"__ocml_fmod_f32",
121 "__ocml_fmod_f64",
"__ocml_fmod_f16");
125 struct ConvertMathToROCDLPass
126 :
public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
127 ConvertMathToROCDLPass() =
default;
128 void runOnOperation()
override;
132 void ConvertMathToROCDLPass::runOnOperation() {
133 auto m = getOperation();
141 target.addLegalDialect<BuiltinDialect, func::FuncDialect,
142 vector::VectorDialect, LLVM::LLVMDialect>();
143 target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
144 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
145 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
static MLIRContext * getContext(OpFoldResult val)
static void populateOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef f32Func, StringRef f64Func, StringRef f16Func, StringRef f32ApproxFunc="")
static llvm::ManagedStatic< PassManagerOptions > options
This class describes a specific conversion target.
The main mechanism for performing data layout queries.
Conversion from types to the LLVM IR dialect.
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from Math to ROCDL calls.
Rewriting that replaces SourceOp with a CallOp to f32Func, f64Func, f32ApproxFunc, or f16Func, depending on the op's element type.
Rewriting that unrolls SourceOp to scalars if it's operating on vectors.