24#include "llvm/Support/DebugLog.h"
30#define GEN_PASS_DEF_CONVERTMATHTOROCDL
31#include "mlir/Conversion/Passes.h.inc"
36#define DEBUG_TYPE "math-to-rocdl"
38template <
typename OpTy>
41 StringRef f64Func, StringRef f16Func,
42 StringRef f32ApproxFunc =
"") {
45 f32ApproxFunc, f16Func);
54 ConversionPatternRewriter &rewriter)
const override {
56 Type opTy = op.getType();
59 if (
auto vectorType = dyn_cast<VectorType>(opTy))
60 opTy = vectorType.getElementType();
62 if (!isa<Float16Type, Float32Type>(opTy))
63 return rewriter.notifyMatchFailure(
64 op,
"fmed3 only supports f16 and f32 types");
67 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType))
71 math::ClampFOp::Adaptor adaptor(operands);
72 return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
73 adaptor.getValue(), adaptor.getMin(),
79 rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
80 op.getMin(), op.getMax());
87 std::optional<amdgpu::Chipset> chipset) {
103 "__ocml_acos_f64",
"__ocml_acos_f16");
105 "__ocml_acosh_f64",
"__ocml_acosh_f16");
107 "__ocml_asin_f64",
"__ocml_asin_f16");
109 "__ocml_asinh_f64",
"__ocml_asinh_f16");
111 "__ocml_atan_f64",
"__ocml_atan_f16");
113 "__ocml_atanh_f64",
"__ocml_atanh_f16");
115 "__ocml_atan2_f64",
"__ocml_atan2_f16");
117 "__ocml_cbrt_f64",
"__ocml_cbrt_f16");
119 "__ocml_ceil_f64",
"__ocml_ceil_f16");
121 "__ocml_cos_f64",
"__ocml_cos_f16");
123 "__ocml_cosh_f64",
"__ocml_cosh_f16");
125 "__ocml_sinh_f64",
"__ocml_sinh_f16");
129 "__ocml_exp2_f64",
"__ocml_exp2_f16");
131 "__ocml_expm1_f64",
"__ocml_expm1_f16");
133 "__ocml_floor_f64",
"__ocml_floor_f16");
137 "__ocml_log10_f64",
"__ocml_log10_f16");
139 "__ocml_log1p_f64",
"__ocml_log1p_f16");
141 "__ocml_log2_f64",
"__ocml_log2_f16");
143 "__ocml_pow_f64",
"__ocml_pow_f16");
145 "__ocml_rsqrt_f64",
"__ocml_rsqrt_f16");
147 "__ocml_sin_f64",
"__ocml_sin_f16");
149 "__ocml_tanh_f64",
"__ocml_tanh_f16");
151 "__ocml_tan_f64",
"__ocml_tan_f16");
153 "__ocml_erf_f64",
"__ocml_erf_f16");
155 "__ocml_erfc_f64",
"__ocml_erfc_f16");
157 "__ocml_pown_f64",
"__ocml_pown_f16");
161 "__ocml_fmod_f64",
"__ocml_fmod_f16");
163 if (chipset.has_value() && chipset->majorVersion >= 9) {
166 LDBG() <<
"Chipset dependent patterns were not added";
171 : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
172 using impl::ConvertMathToROCDLBase<
179 auto m = getOperation();
186 FailureOr<amdgpu::Chipset> maybeChipset;
187 if (!chipset.empty()) {
189 if (failed(maybeChipset))
190 return signalPassFailure();
194 succeeded(maybeChipset) ? std::optional(*maybeChipset) : std::nullopt);
198 .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
199 LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
200 target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
201 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
202 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
204 if (failed(applyPartialConversion(m,
target, std::move(
patterns))))
static void populateOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef f32Func, StringRef f64Func, StringRef f16Func, StringRef f32ApproxFunc="")
static llvm::ManagedStatic< PassManagerOptions > options
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename math::ClampFOp::Adaptor OpAdaptor
const LLVMTypeConverter * getTypeConverter() const
The main mechanism for performing data layout queries.
Conversion from types to the LLVM IR dialect.
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
Include the generated interface declarations.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, std::optional< amdgpu::Chipset > chipset)
Populate the given list with patterns that convert from Math to ROCDL calls.
const FrozenRewritePatternSet & patterns
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void runOnOperation() override
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func or f32ApproxFunc or f16Func or i32Func, depending on the element type and the fastMathFlag of that Op.
Unrolls SourceOp to array/vector elements.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.