MLIR 23.0.0git
MathToNVVM.cpp
Go to the documentation of this file.
1//===-- MathToNVVM.cpp - conversion from Math to CUDA libdevice calls ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17#include "mlir/Pass/Pass.h"
18
21
22namespace mlir {
23#define GEN_PASS_DEF_CONVERTMATHTONVVM
24#include "mlir/Conversion/Passes.h.inc"
25} // namespace mlir
26
27using namespace mlir;
28
29#define DEBUG_TYPE "math-to-nvvm"
30
31template <typename OpTy>
32static void populateOpPatterns(const LLVMTypeConverter &converter,
33 RewritePatternSet &patterns,
34 PatternBenefit benefit, StringRef f32Func,
35 StringRef f64Func, StringRef f32ApproxFunc = "",
36 StringRef f16Func = "") {
37 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
38 patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
39 f32ApproxFunc, f16Func,
40 /*i32Func=*/"", benefit);
41}
42
43template <typename OpTy>
44static void populateIntOpPatterns(const LLVMTypeConverter &converter,
45 RewritePatternSet &patterns,
46 PatternBenefit benefit, StringRef i32Func) {
47 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
48 patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
49 benefit);
50}
51
52template <typename OpTy>
53static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
54 RewritePatternSet &patterns,
55 PatternBenefit benefit,
56 StringRef f32Func, StringRef f64Func) {
57 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
58 patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
59 /*i32Func=*/"", benefit);
60}
61
62// Custom pattern for sincos since it returns two values
63struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
65
66 LogicalResult
67 matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
68 ConversionPatternRewriter &rewriter) const override {
69 Location loc = op.getLoc();
70 Value input = adaptor.getOperand();
71 Type inputType = input.getType();
72 auto convertedInput = maybeExt(input, rewriter);
73 auto computeType = convertedInput.getType();
74
75 StringRef sincosFunc;
76 if (isa<Float32Type>(computeType)) {
77 const arith::FastMathFlags flag = op.getFastmath();
78 const bool useApprox =
79 mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
80 sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
81 } else if (isa<Float64Type>(computeType)) {
82 sincosFunc = "__nv_sincos";
83 } else {
84 return rewriter.notifyMatchFailure(op,
85 "unsupported operand type for sincos");
86 }
87
88 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
89
90 Value sinPtr, cosPtr;
91 {
92 OpBuilder::InsertionGuard guard(rewriter);
93 auto *scope =
94 op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
95 assert(scope && "Expected op to be inside automatic allocation scope");
96 rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
97 auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
98 rewriter.getI32IntegerAttr(1));
99 sinPtr =
100 LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
101 cosPtr =
102 LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
103 }
104
105 createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
106 op);
107
108 auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
109 auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);
110
111 rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
112 maybeTrunc(cosResult, inputType, rewriter)});
113 return success();
114 }
115
116private:
117 Value maybeExt(Value operand, PatternRewriter &rewriter) const {
118 if (isa<Float16Type, BFloat16Type>(operand.getType()))
119 return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
120 Float32Type::get(rewriter.getContext()),
121 operand);
122 return operand;
123 }
124
125 Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
126 if (operand.getType() != type)
127 return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
128 return operand;
129 }
130
131 void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
132 StringRef funcName, Value input, Value sinPtr,
133 Value cosPtr, Operation *op) const {
134 auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
135 auto ptrType = sinPtr.getType();
136
137 SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
138 auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
139
140 auto funcAttr = StringAttr::get(op->getContext(), funcName);
141 auto funcOp =
143
144 if (!funcOp) {
145 auto parentFunc = op->getParentOfType<FunctionOpInterface>();
146 assert(parentFunc && "expected there to be a parent function");
147 OpBuilder b(parentFunc);
148
149 auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
150 funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
151 }
152
153 SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
154 LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
155 }
156};
157
159 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
160 PatternBenefit benefit) {
161 populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
162 "__nv_fmod");
163 populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
164 "__nv_fmaxf", "__nv_fmax");
165 populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
166 "__nv_fminf", "__nv_fmin");
167
168 populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
169 populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
170 "__nv_fabs");
171 populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
172 "__nv_acos");
173 populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
174 "__nv_acosh");
175 populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
176 "__nv_asin");
177 populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
178 "__nv_asinh");
179 populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
180 "__nv_atan");
181 populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
182 "__nv_atan2");
183 populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
184 "__nv_atanh");
185 populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
186 "__nv_cbrt");
187 populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
188 "__nv_ceil");
189 populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
190 "__nv_copysignf", "__nv_copysign");
191 populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
192 "__nv_cos", "__nv_fast_cosf");
193 populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
194 "__nv_cosh");
195 populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
196 "__nv_erf");
197 populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
198 "__nv_erfc");
199 populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
200 "__nv_exp", "__nv_fast_expf");
201 populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
202 "__nv_exp2");
203 populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
204 "__nv_expm1");
205 populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
206 "__nv_floor");
207 populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
208 "__nv_fma");
209 // Note: libdevice uses a different name for 32-bit finite checking
210 populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
211 "__nv_finitef", "__nv_isfinited");
212 populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
213 "__nv_isinfd");
214 populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
215 "__nv_isnand");
216 populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
217 "__nv_log", "__nv_fast_logf");
218 populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
219 "__nv_log10", "__nv_fast_log10f");
220 populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
221 "__nv_log1p");
222 populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
223 "__nv_log2", "__nv_fast_log2f");
224 populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
225 "__nv_pow", "__nv_fast_powf");
226 populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
227 "__nv_powif", "__nv_powi");
228 populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
229 "__nv_round");
230 populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
231 "__nv_rintf", "__nv_rint");
232 populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
233 "__nv_rsqrt");
234 populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
235 "__nv_sin", "__nv_fast_sinf");
236 populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
237 "__nv_sinh");
238 populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
239 "__nv_sqrt");
240 populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
241 "__nv_tan", "__nv_fast_tanf");
242 populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
243 "__nv_tanh");
244
245 // Custom pattern for sincos since it returns two values
246 patterns.add<SincosOpLowering>(converter, benefit);
247}
248
249namespace {
250struct ConvertMathToNVVMPass final
251 : impl::ConvertMathToNVVMBase<ConvertMathToNVVMPass> {
252 using impl::ConvertMathToNVVMBase<
253 ConvertMathToNVVMPass>::ConvertMathToNVVMBase;
254
255 void runOnOperation() override;
256};
257} // namespace
258
259void ConvertMathToNVVMPass::runOnOperation() {
260 auto m = getOperation();
261 MLIRContext *ctx = m.getContext();
262
263 RewritePatternSet patterns(&getContext());
264 LowerToLLVMOptions options(ctx, DataLayout(m));
265 LLVMTypeConverter converter(ctx, options);
266
267 populateLibDeviceConversionPatterns(converter, patterns, /*benefit=*/1);
268
269 ConversionTarget target(getContext());
270 target
271 .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
272 LLVM::LLVMDialect, NVVM::NVVMDialect>();
273 target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
274 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
275 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
276 LLVM::SqrtOp>();
277 if (failed(applyPartialConversion(m, target, std::move(patterns))))
278 signalPassFailure();
279}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit, StringRef f32Func, StringRef f64Func)
static void populateOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc="", StringRef f16Func="")
static void populateIntOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit, StringRef i32Func)
static llvm::ManagedStatic< PassManagerOptions > options
MLIRContext * getContext() const
Definition Builders.h:56
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
typename math::SincosOp::Adaptor OpAdaptor
Definition Pattern.h:229
Conversion from types to the LLVM IR dialect.
LocationAttr findInstanceOfOrUnknown()
Return an instance of the given location type if one is nested under the current location else return...
Definition Location.h:60
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
A trait of region holding operations that define a new scope for automatic allocations,...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the given list with patterns that convert from Math to NVVM libdevice calls.
LogicalResult matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func or f32ApproxFunc or f16Func or i...
Unrolls SourceOp to array/vector elements.