MLIR 23.0.0git
MathToXeVM.cpp
Go to the documentation of this file.
1//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "mlir/Pass/Pass.h"
17#include "llvm/Support/FormatVariadic.h"
18
21
22namespace mlir {
23#define GEN_PASS_DEF_CONVERTMATHTOXEVM
24#include "mlir/Conversion/Passes.h.inc"
25} // namespace mlir
26
27using namespace mlir;
28
29#define DEBUG_TYPE "math-to-xevm"
30
31/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
///
/// On a match, the op is replaced by an `llvm.call` to a function declaration
/// named by getMangledNativeFuncName, which LLVM::lookupOrCreateFn finds or
/// creates in the nearest ancestor carrying the SymbolTable trait. The op's
/// arith fastmath flags are converted to LLVM fastmath attributes and set on
/// the call. The pattern does not match when the result type is rejected by
/// isSPIRVCompatibleFloatOrVec, when the `afn` flag is absent, or when an
/// operand type is reported as incompatible.
///
/// `Op` must provide getFastmath() and a single result reachable via
/// getType().
32template <typename Op>
33struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
34
36 PatternBenefit benefit = 1)
37 : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
38
39 LogicalResult
40 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
41 ConversionPatternRewriter &rewriter) const override {
 // Bail out (without a diagnostic) on result types the predicate rejects.
42 if (!isSPIRVCompatibleFloatOrVec(op.getType()))
43 return failure();
44
 // Only ops carrying the `afn` fastmath flag are rewritten.
45 arith::FastMathFlags fastFlags = op.getFastmath();
46 if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
47 return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
48
49 SmallVector<Type, 1> operandTypes;
50 for (auto operand : adaptor.getOperands()) {
51 Type opTy = operand.getType();
52 // This pass only supports operations on vectors that are already in SPIRV
53 // supported vector sizes: Distributing unsupported vector sizes to SPIRV
54 // supported vector sizes is done in other blocking optimization passes.
56 return rewriter.notifyMatchFailure(
57 op, llvm::formatv("incompatible operand type: '{0}'", opTy));
58 operandTypes.push_back(opTy);
59 }
60
 // The callee declaration is looked up, or inserted, in the enclosing
 // symbol table; its signature mirrors the operand and result types.
61 auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
62 auto funcOpRes = LLVM::lookupOrCreateFn(
63 rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
64 operandTypes, op.getType());
 // NOTE(review): lookupOrCreateFn returns FailureOr. With assertions
 // disabled, a failure (presumably an existing symbol of that name with a
 // different signature) reaches .value() unchecked. Consider returning
 // rewriter.notifyMatchFailure instead -- TODO confirm.
65 assert(!failed(funcOpRes));
66 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
67
68 auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
69 op, funcOp, adaptor.getOperands());
70 // Preserve fastmath flags in our MLIR op when converting to llvm function
71 // calls, in order to allow further fastmath optimizations: We thus need to
72 // convert arith fastmath attrs into attrs recognized by llvm.
 // NOTE(review): indexing [0] assumes the converter always yields at least
 // one attribute; the converter's declaration line is not visible here.
74 mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
75 callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
76 return success();
77 }
78
79 inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
80 if (type.isFloat())
81 return true;
82 if (auto vecType = dyn_cast<VectorType>(type)) {
83 if (!vecType.getElementType().isFloat())
84 return false;
85 // SPIRV distinguishes between vectors and matrices: OpenCL native math
86 // intrsinics are not compatible with matrices.
87 ArrayRef<int64_t> shape = vecType.getShape();
88 if (shape.size() != 1)
89 return false;
90 // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
91 if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
92 shape[0] == 16)
93 return true;
94 }
95 return false;
96 }
97
98 inline std::string
99 getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
100 std::string mangledFuncName =
101 "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
102
103 auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
104 if (type.isF32())
105 mangledFuncName += "f";
106 else if (type.isF16())
107 mangledFuncName += "Dh";
108 else if (type.isF64())
109 mangledFuncName += "d";
110 };
111
112 for (auto type : operandTypes) {
113 if (auto vecType = dyn_cast<VectorType>(type)) {
114 mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
115 appendFloatToMangledFunc(vecType.getElementType());
116 } else
117 appendFloatToMangledFunc(type);
118 }
119
120 return mangledFuncName;
121 }
122
/// Unmangled name of the native builtin (e.g. `__spirv_ocl_native_exp`).
/// StringRef does not own its characters. The registrations in this file pass
/// string literals, so the storage outlives the pattern.
123 const StringRef nativeFunc;
124};
125
/// Registers lowering patterns for `OpTy` onto the OpenCL extended
/// instruction-set builtin `__spirv_ocl_<opName>`.
///
/// Vector ops are first unrolled to their elements
/// (ScalarizeVectorOpLowering). A call-lowering pattern (its type line is not
/// visible in this listing) then targets the SPIR_FUNC symbols
/// `_Z<len>__spirv_ocl_<opName>f` for f32 and `..d` for f64. No approximate
/// f32, f16 or i32 variant is provided; those names are passed as empty
/// strings.
///
/// NOTE(review): only a single parameter-type code is appended. The
/// two-operand ops registered through this helper (atan2, copysign, pow)
/// therefore get a mangled name that encodes one parameter, unlike
/// getMangledNativeFuncName, which appends one code per operand. Confirm the
/// downstream SPIR-V consumer accepts this.
126template <typename OpTy>
128 RewritePatternSet &patterns,
129 PatternBenefit benefit,
130 StringRef opName) {
131 std::string prefix = "__spirv_ocl_";
 // Itanium source-name: decimal length followed by the identifier.
132 std::string mangledName = "_Z" +
133 std::to_string(prefix.size() + opName.size()) +
134 prefix + opName.str();
135
136 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
138 converter, mangledName + "f", mangledName + "d",
139 /*f32ApproxFunc=*/"", /*f16Func=*/"",
140 /*i32Func=*/"", benefit, LLVM::cconv::CConv::SPIR_FUNC);
141}
142
144 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
145 PatternBenefit benefit) {
 // One registration per math op that has an OpenCL extended-instruction-set
 // builtin; the string is the builtin name without its `__spirv_ocl_`
 // prefix. Note that math.powf maps to `pow`.
146 populateOCLExtSetOpPatterns<math::AcosOp>(converter, patterns, benefit,
147 "acos");
148 populateOCLExtSetOpPatterns<math::AcoshOp>(converter, patterns, benefit,
149 "acosh");
150 populateOCLExtSetOpPatterns<math::AsinOp>(converter, patterns, benefit,
151 "asin");
152 populateOCLExtSetOpPatterns<math::AsinhOp>(converter, patterns, benefit,
153 "asinh");
154 populateOCLExtSetOpPatterns<math::AtanOp>(converter, patterns, benefit,
155 "atan");
156 populateOCLExtSetOpPatterns<math::Atan2Op>(converter, patterns, benefit,
157 "atan2");
158 populateOCLExtSetOpPatterns<math::AtanhOp>(converter, patterns, benefit,
159 "atanh");
160 populateOCLExtSetOpPatterns<math::CbrtOp>(converter, patterns, benefit,
161 "cbrt");
162 populateOCLExtSetOpPatterns<math::CopySignOp>(converter, patterns, benefit,
163 "copysign");
164 populateOCLExtSetOpPatterns<math::CosOp>(converter, patterns, benefit, "cos");
165 populateOCLExtSetOpPatterns<math::CoshOp>(converter, patterns, benefit,
166 "cosh");
167 populateOCLExtSetOpPatterns<math::ErfOp>(converter, patterns, benefit, "erf");
168 populateOCLExtSetOpPatterns<math::ErfcOp>(converter, patterns, benefit,
169 "erfc");
170 populateOCLExtSetOpPatterns<math::ExpOp>(converter, patterns, benefit, "exp");
171 populateOCLExtSetOpPatterns<math::Exp2Op>(converter, patterns, benefit,
172 "exp2");
173 populateOCLExtSetOpPatterns<math::ExpM1Op>(converter, patterns, benefit,
174 "expm1");
175 populateOCLExtSetOpPatterns<math::LogOp>(converter, patterns, benefit, "log");
176 populateOCLExtSetOpPatterns<math::Log10Op>(converter, patterns, benefit,
177 "log10");
178 populateOCLExtSetOpPatterns<math::Log1pOp>(converter, patterns, benefit,
179 "log1p");
180 populateOCLExtSetOpPatterns<math::Log2Op>(converter, patterns, benefit,
181 "log2");
182 populateOCLExtSetOpPatterns<math::PowFOp>(converter, patterns, benefit,
183 "pow");
184 populateOCLExtSetOpPatterns<math::RsqrtOp>(converter, patterns, benefit,
185 "rsqrt");
186 populateOCLExtSetOpPatterns<math::SinOp>(converter, patterns, benefit, "sin");
187 populateOCLExtSetOpPatterns<math::SinhOp>(converter, patterns, benefit,
188 "sinh");
189 populateOCLExtSetOpPatterns<math::SqrtOp>(converter, patterns, benefit,
190 "sqrt");
191 populateOCLExtSetOpPatterns<math::TanOp>(converter, patterns, benefit, "tan");
192 populateOCLExtSetOpPatterns<math::TanhOp>(converter, patterns, benefit,
193 "tanh");
194}
195
197 bool convertArith,
198 PatternBenefit benefit) {
 // Each registration passes the unmangled `__spirv_ocl_native_*` builtin name,
 // presumably to ConvertNativeFuncPattern instances. Those only rewrite ops
 // carrying the `afn` fastmath flag. The pattern-type lines are not visible
 // in this listing.
200 patterns.getContext(), "__spirv_ocl_native_exp", benefit);
202 patterns.getContext(), "__spirv_ocl_native_cos", benefit);
204 patterns.getContext(), "__spirv_ocl_native_exp2", benefit);
206 patterns.getContext(), "__spirv_ocl_native_log", benefit);
208 patterns.getContext(), "__spirv_ocl_native_log2", benefit);
210 patterns.getContext(), "__spirv_ocl_native_log10", benefit);
212 patterns.getContext(), "__spirv_ocl_native_powr", benefit);
214 patterns.getContext(), "__spirv_ocl_native_rsqrt", benefit);
216 patterns.getContext(), "__spirv_ocl_native_sin", benefit);
218 patterns.getContext(), "__spirv_ocl_native_sqrt", benefit);
220 patterns.getContext(), "__spirv_ocl_native_tan", benefit);
 // `convertArith` additionally enables the native division builtin.
221 if (convertArith)
223 patterns.getContext(), "__spirv_ocl_native_divide", benefit);
224}
225
226namespace {
227struct ConvertMathToXeVMPass
228 : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
229 using Base::Base;
230 void runOnOperation() override;
231};
232} // namespace
233
/// Runs the math-to-xevm conversion on the current operation.
///
/// `afn` math ops (and, when convertArith is set, division) are rewritten by
/// the patterns from populateMathToXeVMConversionPatterns. When convertToOCL
/// is set, the listed LLVM intrinsic ops are additionally marked illegal.
/// Partial conversion is used, so ops with no applicable pattern that are not
/// illegal are left in place. A failed conversion signals pass failure.
/// convertArith/convertToOCL are presumably pass options from the generated
/// base class.
234void ConvertMathToXeVMPass::runOnOperation() {
235 Operation *op = getOperation();
236 MLIRContext *ctx = op->getContext();
237
 // The type converter is configured with the data layout in effect at `op`.
238 const auto &dl = getAnalysis<DataLayoutAnalysis>();
239
240 RewritePatternSet patterns(&getContext());
241 LowerToLLVMOptions options(ctx, dl.getAtOrAbove(op));
242 LLVMTypeConverter converter(ctx, options);
243 ConversionTarget target(getContext());
244
245 // Native OCL patterns should take precedence for `fast` ops even when
246 // convertToOCL is set.
 // `convertToOCL + 1` evaluates to benefit 2 when convertToOCL is set and to
 // benefit 1 otherwise.
247 populateMathToXeVMConversionPatterns(patterns, convertArith,
248 convertToOCL + 1);
249 if (convertToOCL) {
 // Forbid these LLVM intrinsic ops in the output. Presumably this keeps math
 // ops from being lowered to them when OCL builtin calls are requested, and
 // the op-level illegality is expected to override the LLVM-dialect-wide
 // legality declared below -- verify.
251 target
252 .addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::LogOp,
253 LLVM::Log10Op, LLVM::Log2Op, LLVM::SinOp, LLVM::SqrtOp>();
254 }
255 target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>();
256 if (failed(
257 applyPartialConversion(getOperation(), target, std::move(patterns))))
258 signalPassFailure();
259}
return success()
b getContext())
static void populateOCLExtSetOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit, StringRef opName)
static llvm::ManagedStatic< PassManagerOptions > options
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This provides public APIs that all operations should have.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
bool isFloat() const
Return true if this is a float type (of any width).
Definition Types.cpp:47
ArrayRef< NamedAttribute > getAttrs() const
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature resultType(paramTypes) and name `name`.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void populateMathToScalarOCLExtSetConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the given list with patterns that convert from Math to OCL LLVM-SPV builtin calls.
void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith, PatternBenefit benefit=1)
Populate the given list with patterns that convert from Math to XeVM calls.
Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
const StringRef nativeFunc
ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit=1)
std::string getMangledNativeFuncName(const ArrayRef< Type > operandTypes) const
bool isSPIRVCompatibleFloatOrVec(Type type) const
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func or f32ApproxFunc or f16Func or i32Func.
Unrolls SourceOp to array/vector elements.