MLIR 23.0.0git
MathToXeVM.cpp
Go to the documentation of this file.
1//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/Pass/Pass.h"
16#include "llvm/Support/FormatVariadic.h"
17
18namespace mlir {
19#define GEN_PASS_DEF_CONVERTMATHTOXEVM
20#include "mlir/Conversion/Passes.h.inc"
21} // namespace mlir
22
23using namespace mlir;
24
25#define DEBUG_TYPE "math-to-xevm"
26
27/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
28template <typename Op>
29struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
30
// Constructs a pattern that rewrites `Op` into a call to the OpenCL native
// builtin named `nativeFunc` (unmangled; getMangledNativeFuncName builds the
// actual symbol from it and the operand types).
// NOTE(review): the first line of this declaration (source line 31) is missing
// from this rendering; per the declaration it reads
// `ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,`.
32 PatternBenefit benefit = 1)
33 : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
34
/// Replaces `op` with an `llvm.call` to the native builtin `nativeFunc`,
/// mangled for the operand types.
///
/// Fails to match when the result type is not accepted by
/// isSPIRVCompatibleFloatOrVec, when the op's fastmath flags do not contain
/// `afn`, or (presumably, see the note below) when an operand type is not
/// accepted. The callee is looked up in, or declared in, the nearest ancestor
/// that has the SymbolTable trait. The op's fastmath flags are copied onto the
/// call in their LLVM form.
35 LogicalResult
36 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
37 ConversionPatternRewriter &rewriter) const override {
38 if (!isSPIRVCompatibleFloatOrVec(op.getType()))
39 return failure();
40
// Only ops explicitly allowing approximate functions (`afn`) may be routed to
// the native builtins (presumably lower-precision than the default lowering).
41 arith::FastMathFlags fastFlags = op.getFastmath();
42 if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
43 return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
44
45 SmallVector<Type, 1> operandTypes;
46 for (auto operand : adaptor.getOperands()) {
47 Type opTy = operand.getType();
48 // This pass only supports operations on vectors that are already in SPIRV
49 // supported vector sizes: Distributing unsupported vector sizes to SPIRV
50 // supported vector sizes are done in other blocking optimization passes.
// NOTE(review): source line 51, which guards the return below, is missing
// from this rendering; presumably `if (!isSPIRVCompatibleFloatOrVec(opTy))`
// — confirm against upstream.
52 return rewriter.notifyMatchFailure(
53 op, llvm::formatv("incompatible operand type: '{0}'", opTy));
54 operandTypes.push_back(opTy);
55 }
56
// The callee is declared in the closest enclosing symbol table, which need
// not be the top-level module.
57 auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
58 auto funcOpRes = LLVM::lookupOrCreateFn(
59 rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
60 operandTypes, op.getType());
// NOTE(review): lookupOrCreateFn returns FailureOr, but failure is only
// checked by this assert. In builds without asserts, `.value()` below would be
// called on a failed result. Consider returning a match failure instead.
61 assert(!failed(funcOpRes));
62 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
63
64 auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
65 op, funcOp, adaptor.getOperands());
66 // Preserve fastmath flags in our MLIR op when converting to llvm function
67 // calls, in order to allow further fastmath optimizations: We thus need to
68 // convert arith fastmath attrs into attrs recognized by llvm.
// NOTE(review): source line 69, which declares `fastAttrConverter` from `op`,
// is missing from this rendering. `getAttrs()[0]` assumes the converter
// produced at least one attribute — presumably guaranteed because `afn` was
// checked above; confirm.
70 mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
71 callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
72 return success();
73 }
74
75 inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
76 if (type.isFloat())
77 return true;
78 if (auto vecType = dyn_cast<VectorType>(type)) {
79 if (!vecType.getElementType().isFloat())
80 return false;
81 // SPIRV distinguishes between vectors and matrices: OpenCL native math
82 // intrsinics are not compatible with matrices.
83 ArrayRef<int64_t> shape = vecType.getShape();
84 if (shape.size() != 1)
85 return false;
86 // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
87 if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
88 shape[0] == 16)
89 return true;
90 }
91 return false;
92 }
93
94 inline std::string
95 getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
96 std::string mangledFuncName =
97 "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
98
99 auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
100 if (type.isF32())
101 mangledFuncName += "f";
102 else if (type.isF16())
103 mangledFuncName += "Dh";
104 else if (type.isF64())
105 mangledFuncName += "d";
106 };
107
108 for (auto type : operandTypes) {
109 if (auto vecType = dyn_cast<VectorType>(type)) {
110 mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
111 appendFloatToMangledFunc(vecType.getElementType());
112 } else
113 appendFloatToMangledFunc(type);
114 }
115
116 return mangledFuncName;
117 }
118
// Unmangled name of the OpenCL native builtin this pattern calls, e.g.
// "__spirv_ocl_native_exp"; getMangledNativeFuncName wraps it into the
// mangled symbol. StringRef does not own its characters, so the referenced
// string must outlive the pattern. The registrations visible in this file
// pass string literals.
119 const StringRef nativeFunc;
120};
121
123 bool convertArith) {
// Registers one ConvertNativeFuncPattern per supported op, each bound to the
// name of its `__spirv_ocl_native_*` builtin (exp, cos, exp2, log, log2,
// log10, powr, rsqrt, sin, sqrt, tan).
// NOTE(review): source lines 122, 124, 126, ..., 144 and 147 are missing from
// this rendering. They hold the start of the signature
// (`void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,`)
// and the `patterns.add<...>(` heads naming each op type.
125 "__spirv_ocl_native_exp");
127 "__spirv_ocl_native_cos");
129 patterns.getContext(), "__spirv_ocl_native_exp2");
131 "__spirv_ocl_native_log");
133 patterns.getContext(), "__spirv_ocl_native_log2");
135 patterns.getContext(), "__spirv_ocl_native_log10");
137 patterns.getContext(), "__spirv_ocl_native_powr");
139 patterns.getContext(), "__spirv_ocl_native_rsqrt");
141 "__spirv_ocl_native_sin");
143 patterns.getContext(), "__spirv_ocl_native_sqrt");
145 "__spirv_ocl_native_tan");
// Division is only routed to `__spirv_ocl_native_divide` when the caller asks
// for it via `convertArith` (presumably arith float division — the op type is
// on a line missing from this rendering).
146 if (convertArith)
148 patterns.getContext(), "__spirv_ocl_native_divide");
149}
150
151namespace {
152struct ConvertMathToXeVMPass
153 : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
154 using Base::Base;
155 void runOnOperation() override;
156};
157} // namespace
158
159void ConvertMathToXeVMPass::runOnOperation() {
160 RewritePatternSet patterns(&getContext());
161 populateMathToXeVMConversionPatterns(patterns, convertArith);
162 ConversionTarget target(getContext());
163 target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>();
164 if (failed(
165 applyPartialConversion(getOperation(), target, std::move(patterns))))
166 signalPassFailure();
167}
return success()
b getContext())
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This provides public APIs that all operations should have.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
bool isFloat() const
Return true if this is a float type (with the specified width).
Definition Types.cpp:47
ArrayRef< NamedAttribute > getAttrs() const
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith)
Populate the given list with patterns that convert from Math to XeVM calls.
Convert math ops marked with fast (afn) to native OpenCL intrinsics.
const StringRef nativeFunc
ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit=1)
std::string getMangledNativeFuncName(const ArrayRef< Type > operandTypes) const
bool isSPIRVCompatibleFloatOrVec(Type type) const
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override