MLIR 22.0.0git
MathToROCDL.cpp
Go to the documentation of this file.
1//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
22#include "mlir/Pass/Pass.h"
24#include "llvm/Support/DebugLog.h"
25
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTMATHTOROCDL
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34using namespace mlir;
35
36#define DEBUG_TYPE "math-to-rocdl"
37
/// Registers an `OpToFuncCallLowering<OpTy>` pattern in `patterns`.
///
/// \param converter      Type converter forwarded to the pattern constructor.
/// \param patterns       Pattern set that receives the new pattern.
/// \param f32Func        Callee name forwarded for the f32 case. Some callers
///                       in this file pass "" here (math::ExpOp, math::LogOp);
///                       presumably that disables the f32 lowering -- confirm
///                       against OpToFuncCallLowering.
/// \param f64Func        Callee name forwarded for the f64 case.
/// \param f16Func        Callee name forwarded for the f16 case.
/// \param f32ApproxFunc  Optional callee name for an approximate f32 variant;
///                       defaults to "".
///
/// NOTE(review): this listing skips from original line 42 to 44, so one
/// statement of the body may not be visible here.
38template <typename OpTy>
39static void populateOpPatterns(const LLVMTypeConverter &converter,
40 RewritePatternSet &patterns, StringRef f32Func,
41 StringRef f64Func, StringRef f16Func,
42 StringRef f32ApproxFunc = "") {
// The pattern constructor is given f32ApproxFunc *before* f16Func, which is
// not the order of this function's own parameters.
44 patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
45 f32ApproxFunc, f16Func);
46}
47
49 : public ConvertOpToLLVMPattern<math::ClampFOp> {
51
52 LogicalResult
53 matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
54 ConversionPatternRewriter &rewriter) const override {
55 // Only f16 and f32 types are supported by fmed3
56 Type opTy = op.getType();
57 Type resultType = getTypeConverter()->convertType(opTy);
58
59 if (auto vectorType = dyn_cast<VectorType>(opTy))
60 opTy = vectorType.getElementType();
61
62 if (!isa<Float16Type, Float32Type>(opTy))
63 return rewriter.notifyMatchFailure(
64 op, "fmed3 only supports f16 and f32 types");
65
66 // Handle multi-dimensional vectors (converted to LLVM arrays)
67 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType))
69 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
70 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
71 math::ClampFOp::Adaptor adaptor(operands);
72 return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
73 adaptor.getValue(), adaptor.getMin(),
74 adaptor.getMax());
75 },
76 rewriter);
77
78 // Handle 1D vectors and scalars directly
79 rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
80 op.getMin(), op.getMax());
81 return success();
82 }
83};
84
87 std::optional<amdgpu::Chipset> chipset) {
88 // Handled by mathToLLVM: math::AbsIOp
89 // Handled by mathToLLVM: math::AbsFOp
90 // Handled by mathToLLVM: math::CopySignOp
91 // Handled by mathToLLVM: math::CountLeadingZerosOp
92 // Handled by mathToLLVM: math::CountTrailingZerosOp
93 // Handled by mathToLLVM: math::CtPopOp
94 // Handled by mathToLLVM: math::ExpOp (32-bit only)
95 // Handled by mathToLLVM: math::FmaOp
96 // Handled by mathToLLVM: math::LogOp (32-bit only)
97 // FIXME: math::IPowIOp
98 // Handled by mathToLLVM: math::RoundEvenOp
99 // Handled by mathToLLVM: math::RoundOp
100 // Handled by mathToLLVM: math::SqrtOp
101 // Handled by mathToLLVM: math::TruncOp
102 populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
103 "__ocml_acos_f64", "__ocml_acos_f16");
104 populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
105 "__ocml_acosh_f64", "__ocml_acosh_f16");
106 populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
107 "__ocml_asin_f64", "__ocml_asin_f16");
108 populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
109 "__ocml_asinh_f64", "__ocml_asinh_f16");
110 populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
111 "__ocml_atan_f64", "__ocml_atan_f16");
112 populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
113 "__ocml_atanh_f64", "__ocml_atanh_f16");
114 populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
115 "__ocml_atan2_f64", "__ocml_atan2_f16");
116 populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
117 "__ocml_cbrt_f64", "__ocml_cbrt_f16");
118 populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
119 "__ocml_ceil_f64", "__ocml_ceil_f16");
120 populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
121 "__ocml_cos_f64", "__ocml_cos_f16");
122 populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
123 "__ocml_cosh_f64", "__ocml_cosh_f16");
124 populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
125 "__ocml_sinh_f64", "__ocml_sinh_f16");
126 populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
127 "__ocml_exp_f16");
128 populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
129 "__ocml_exp2_f64", "__ocml_exp2_f16");
130 populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
131 "__ocml_expm1_f64", "__ocml_expm1_f16");
132 populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
133 "__ocml_floor_f64", "__ocml_floor_f16");
134 populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
135 "__ocml_log_f16");
136 populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
137 "__ocml_log10_f64", "__ocml_log10_f16");
138 populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
139 "__ocml_log1p_f64", "__ocml_log1p_f16");
140 populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
141 "__ocml_log2_f64", "__ocml_log2_f16");
142 populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
143 "__ocml_pow_f64", "__ocml_pow_f16");
144 populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
145 "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
146 populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
147 "__ocml_sin_f64", "__ocml_sin_f16");
148 populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
149 "__ocml_tanh_f64", "__ocml_tanh_f16");
150 populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
151 "__ocml_tan_f64", "__ocml_tan_f16");
152 populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
153 "__ocml_erf_f64", "__ocml_erf_f16");
154 populateOpPatterns<math::ErfcOp>(converter, patterns, "__ocml_erfc_f32",
155 "__ocml_erfc_f64", "__ocml_erfc_f16");
156 populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
157 "__ocml_pown_f64", "__ocml_pown_f16");
158 // Single arith pattern that needs a ROCDL call, probably not
159 // worth creating a separate pass for it.
160 populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
161 "__ocml_fmod_f64", "__ocml_fmod_f16");
162
163 if (chipset.has_value() && chipset->majorVersion >= 9) {
164 patterns.add<ClampFOpConversion>(converter);
165 } else {
166 LDBG() << "Chipset dependent patterns were not added";
167 }
168}
169
171 : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
172 using impl::ConvertMathToROCDLBase<
173 ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
174
175 void runOnOperation() override;
176};
177
179 auto m = getOperation();
180 MLIRContext *ctx = m.getContext();
181
184 LLVMTypeConverter converter(ctx, options);
185
186 FailureOr<amdgpu::Chipset> maybeChipset;
187 if (!chipset.empty()) {
188 maybeChipset = amdgpu::Chipset::parse(chipset);
189 if (failed(maybeChipset))
190 return signalPassFailure();
191 }
193 converter, patterns,
194 succeeded(maybeChipset) ? std::optional(*maybeChipset) : std::nullopt);
195
197 target
198 .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
199 LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
200 target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
201 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
202 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
203 LLVM::SqrtOp>();
204 if (failed(applyPartialConversion(m, target, std::move(patterns))))
205 signalPassFailure();
206}
return success()
b getContext())
static void populateOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef f32Func, StringRef f64Func, StringRef f16Func, StringRef f32ApproxFunc="")
static llvm::ManagedStatic< PassManagerOptions > options
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:215
typename math::ClampFOp::Adaptor OpAdaptor
Definition Pattern.h:211
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:27
The main mechanism for performing data layout queries.
Conversion from types to the LLVM IR dialect.
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
Include the generated interface declarations.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, std::optional< amdgpu::Chipset > chipset)
Populate the given list with patterns that convert from Math to ROCDL calls.
const FrozenRewritePatternSet & patterns
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:215
LogicalResult matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void runOnOperation() override
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func or f32ApproxFunc or f16Func or i...
Unrolls SourceOp to array/vector elements.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14