MLIR 23.0.0git
MathToROCDL.cpp
Go to the documentation of this file.
1//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
22#include "mlir/Pass/Pass.h"
24#include "llvm/Support/DebugLog.h"
25
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTMATHTOROCDL
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34using namespace mlir;
35
36#define DEBUG_TYPE "math-to-rocdl"
37
38template <typename OpTy>
39static void populateOpPatterns(const LLVMTypeConverter &converter,
40 RewritePatternSet &patterns, StringRef f32Func,
41 StringRef f64Func, StringRef f16Func,
42 StringRef f32ApproxFunc = "") {
43 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
44 patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
45 f32ApproxFunc, f16Func);
46}
47
49 : public ConvertOpToLLVMPattern<math::ClampFOp> {
51
52 LogicalResult
53 matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
54 ConversionPatternRewriter &rewriter) const override {
55 // Only f16 and f32 types are supported by fmed3
56 Type opTy = op.getType();
57 Type resultType = getTypeConverter()->convertType(opTy);
58
59 if (auto vectorType = dyn_cast<VectorType>(opTy))
60 opTy = vectorType.getElementType();
61
62 if (!isa<Float16Type, Float32Type>(opTy))
63 return rewriter.notifyMatchFailure(
64 op, "fmed3 only supports f16 and f32 types");
65
66 // Handle multi-dimensional vectors (converted to LLVM arrays)
67 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType))
69 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
70 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
71 math::ClampFOp::Adaptor adaptor(operands);
72 return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
73 adaptor.getValue(), adaptor.getMin(),
74 adaptor.getMax());
75 },
76 rewriter);
77
78 // Handle 1D vectors and scalars directly
79 rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
80 op.getMin(), op.getMax());
81 return success();
82 }
83};
84
86 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
87 std::optional<amdgpu::Chipset> chipset) {
// Populates `patterns` with lowerings of math ops (plus arith.remf) to OCML
// library calls. Each populateOpPatterns call registers a vector
// scalarization pattern and a call-lowering pattern that picks the
// __ocml_*_f32 / _f64 / _f16 function by element type. The fmed3-based
// clampf lowering is added only for a known chipset (see the end).
88 // Handled by mathToLLVM: math::AbsIOp
89 // Handled by mathToLLVM: math::AbsFOp
90 // Handled by mathToLLVM: math::CopySignOp
91 // Handled by mathToLLVM: math::CountLeadingZerosOp
92 // Handled by mathToLLVM: math::CountTrailingZerosOp
93 // Handled by mathToLLVM: math::CtPopOp
94 // Handled by mathToLLVM: math::ExpOp (32-bit only)
95 // Handled by mathToLLVM: math::FmaOp
96 // Handled by mathToLLVM: math::LogOp (32-bit only)
97 // FIXME: math::IPowIOp
98 // Handled by mathToLLVM: math::RoundEvenOp
99 // Handled by mathToLLVM: math::RoundOp
100 // Handled by mathToLLVM: math::SqrtOp
101 // Handled by mathToLLVM: math::TruncOp
102 populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
103 "__ocml_acos_f64", "__ocml_acos_f16");
104 populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
105 "__ocml_acosh_f64", "__ocml_acosh_f16");
106 populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
107 "__ocml_asin_f64", "__ocml_asin_f16");
108 populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
109 "__ocml_asinh_f64", "__ocml_asinh_f16");
110 populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
111 "__ocml_atan_f64", "__ocml_atan_f16");
112 populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
113 "__ocml_atanh_f64", "__ocml_atanh_f16");
114 populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
115 "__ocml_atan2_f64", "__ocml_atan2_f16");
116 populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
117 "__ocml_cbrt_f64", "__ocml_cbrt_f16");
118 populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
119 "__ocml_ceil_f64", "__ocml_ceil_f16");
120 populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
121 "__ocml_cos_f64", "__ocml_cos_f16");
122 populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
123 "__ocml_cosh_f64", "__ocml_cosh_f16");
124 populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
125 "__ocml_sinh_f64", "__ocml_sinh_f16");
// exp and log pass "" as the f32 name: their 32-bit forms are handled by
// mathToLLVM (see the list above).
126 populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
127 "__ocml_exp_f16");
128 populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
129 "__ocml_exp2_f64", "__ocml_exp2_f16");
130 populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
131 "__ocml_expm1_f64", "__ocml_expm1_f16");
132 populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
133 "__ocml_floor_f64", "__ocml_floor_f16");
134 populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
135 "__ocml_log_f16");
136 populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
137 "__ocml_log10_f64", "__ocml_log10_f16");
138 populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
139 "__ocml_log1p_f64", "__ocml_log1p_f16");
140 populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
141 "__ocml_log2_f64", "__ocml_log2_f16");
142 populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
143 "__ocml_pow_f64", "__ocml_pow_f16");
144 populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
145 "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
146 populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
147 "__ocml_sin_f64", "__ocml_sin_f16");
148 populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
149 "__ocml_tanh_f64", "__ocml_tanh_f16");
150 populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
151 "__ocml_tan_f64", "__ocml_tan_f16");
152 populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
153 "__ocml_erf_f64", "__ocml_erf_f16");
154 populateOpPatterns<math::ErfcOp>(converter, patterns, "__ocml_erfc_f32",
155 "__ocml_erfc_f64", "__ocml_erfc_f16");
156 populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
157 "__ocml_pown_f64", "__ocml_pown_f16");
158 // Single arith pattern that needs a ROCDL call, probably not
159 // worth creating a separate pass for it.
// (arith.remf lowers to __ocml_fmod_*.)
160 populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
161 "__ocml_fmod_f64", "__ocml_fmod_f16");
162
// math.clampf -> rocdl.fmed3 is registered only when a chipset is given and
// its major version is >= 9; otherwise only a debug message is logged.
163 if (chipset.has_value() && chipset->majorVersion >= 9) {
164 patterns.add<ClampFOpConversion>(converter);
165 } else {
166 LDBG() << "Chipset dependent patterns were not added";
167 }
168}
169
171 : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
// Inherit the constructors of the generated pass base class.
// NOTE(review): the `chipset` member read in runOnOperation is not declared
// in this struct, so it presumably is a pass option provided by the generated
// base (GEN_PASS_DEF_CONVERTMATHTOROCDL) — confirm in Passes.td.
172 using impl::ConvertMathToROCDLBase<
173 ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
174
// Declared here, defined out of line.
175 void runOnOperation() override;
176};
177
// Builds an LLVM type converter and pattern set, parses the optional
// `chipset` option, and applies a partial conversion to the root operation.
179 auto m = getOperation();
180 MLIRContext *ctx = m.getContext();
181
182 RewritePatternSet patterns(&getContext());
184 LLVMTypeConverter converter(ctx, options);
185
// An unparsable `chipset` string fails the pass. When the option is empty,
// `maybeChipset` stays default-constructed (a failure state), so the ternary
// below forwards std::nullopt instead of a chipset.
186 FailureOr<amdgpu::Chipset> maybeChipset;
187 if (!chipset.empty()) {
188 maybeChipset = amdgpu::Chipset::parse(chipset);
189 if (failed(maybeChipset))
190 return signalPassFailure();
191 }
193 converter, patterns,
194 succeeded(maybeChipset) ? std::optional(*maybeChipset) : std::nullopt);
195
// Builtin, func, vector, LLVM and ROCDL ops are legal; the listed LLVM math
// intrinsic ops are explicitly illegal. NOTE(review): the populated patterns
// appear to target math/arith ops only, so any of these LLVM ops already in
// the input would make the partial conversion fail — presumably intended to
// force OCML calls; confirm.
197 target
198 .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
199 LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
200 target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
201 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
202 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
203 LLVM::SqrtOp>();
204 if (failed(applyPartialConversion(m, target, std::move(patterns))))
205 signalPassFailure();
206}
return success()
b getContext())
static void populateOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef f32Func, StringRef f64Func, StringRef f16Func, StringRef f32ApproxFunc="")
static llvm::ManagedStatic< PassManagerOptions > options
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
typename math::ClampFOp::Adaptor OpAdaptor
Definition Pattern.h:229
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:29
The main mechanism for performing data layout queries.
Conversion from types to the LLVM IR dialect.
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
Include the generated interface declarations.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, std::optional< amdgpu::Chipset > chipset)
Populate the given list with patterns that convert from Math to ROCDL calls.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
LogicalResult matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void runOnOperation() override
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func or f32ApproxFunc or f16Func or i...
Unrolls SourceOp to array/vector elements.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14