MLIR 22.0.0git
MathToXeVM.cpp
Go to the documentation of this file.
1//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/Pass/Pass.h"
16#include "llvm/Support/FormatVariadic.h"
17
18namespace mlir {
19#define GEN_PASS_DEF_CONVERTMATHTOXEVM
20#include "mlir/Conversion/Passes.h.inc"
21} // namespace mlir
22
23using namespace mlir;
24
25#define DEBUG_TYPE "math-to-xevm"
26
27/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
28template <typename Op>
29struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
30
32 PatternBenefit benefit = 1)
33 : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
34
35 LogicalResult
36 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
37 ConversionPatternRewriter &rewriter) const override {
38 if (!isSPIRVCompatibleFloatOrVec(op.getType()))
39 return failure();
40
41 arith::FastMathFlags fastFlags = op.getFastmath();
42 if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
43 return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
44
45 SmallVector<Type, 1> operandTypes;
46 for (auto operand : adaptor.getOperands()) {
47 Type opTy = operand.getType();
48 // This pass only supports operations on vectors that are already in SPIRV
49 // supported vector sizes: Distributing unsupported vector sizes to SPIRV
50 // supported vector sizes are done in other blocking optimization passes.
52 return rewriter.notifyMatchFailure(
53 op, llvm::formatv("incompatible operand type: '{0}'", opTy));
54 operandTypes.push_back(opTy);
55 }
56
57 auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
58 auto funcOpRes = LLVM::lookupOrCreateFn(
59 rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
60 operandTypes, op.getType());
61 assert(!failed(funcOpRes));
62 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
63
64 auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
65 op, funcOp, adaptor.getOperands());
66 // Preserve fastmath flags in our MLIR op when converting to llvm function
67 // calls, in order to allow further fastmath optimizations: We thus need to
68 // convert arith fastmath attrs into attrs recognized by llvm.
70 mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
71 callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
72 return success();
73 }
74
75 inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
76 if (type.isFloat())
77 return true;
78 if (auto vecType = dyn_cast<VectorType>(type)) {
79 if (!vecType.getElementType().isFloat())
80 return false;
81 // SPIRV distinguishes between vectors and matrices: OpenCL native math
82 // intrsinics are not compatible with matrices.
83 ArrayRef<int64_t> shape = vecType.getShape();
84 if (shape.size() != 1)
85 return false;
86 // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
87 if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
88 shape[0] == 16)
89 return true;
90 }
91 return false;
92 }
93
94 inline std::string
95 getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
96 std::string mangledFuncName =
97 "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
98
99 auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
100 if (type.isF32())
101 mangledFuncName += "f";
102 else if (type.isF16())
103 mangledFuncName += "Dh";
104 else if (type.isF64())
105 mangledFuncName += "d";
106 };
107
108 for (auto type : operandTypes) {
109 if (auto vecType = dyn_cast<VectorType>(type)) {
110 mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
111 appendFloatToMangledFunc(vecType.getElementType());
112 } else
113 appendFloatToMangledFunc(type);
114 }
115
116 return mangledFuncName;
117 }
118
119 const StringRef nativeFunc;
120};
121
123 bool convertArith) {
125 "__spirv_ocl_native_exp");
127 "__spirv_ocl_native_cos");
129 patterns.getContext(), "__spirv_ocl_native_exp2");
131 "__spirv_ocl_native_log");
133 patterns.getContext(), "__spirv_ocl_native_log2");
135 patterns.getContext(), "__spirv_ocl_native_log10");
137 patterns.getContext(), "__spirv_ocl_native_powr");
139 patterns.getContext(), "__spirv_ocl_native_rsqrt");
141 "__spirv_ocl_native_sin");
143 patterns.getContext(), "__spirv_ocl_native_sqrt");
145 "__spirv_ocl_native_tan");
146 if (convertArith)
148 patterns.getContext(), "__spirv_ocl_native_divide");
149}
150
151namespace {
152struct ConvertMathToXeVMPass
153 : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
154 using Base::Base;
155 void runOnOperation() override;
156};
157} // namespace
158
159void ConvertMathToXeVMPass::runOnOperation() {
160 RewritePatternSet patterns(&getContext());
162 ConversionTarget target(getContext());
163 target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>();
164 if (failed(
165 applyPartialConversion(getOperation(), target, std::move(patterns))))
166 signalPassFailure();
167}
return success()
b getContext())
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This provides public APIs that all operations should have.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isFloat() const
Return true if this is a float type (with the specified width).
Definition Types.cpp:45
ArrayRef< NamedAttribute > getAttrs() const
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith)
Populate the given list with patterns that convert from Math to XeVM calls.
const FrozenRewritePatternSet & patterns
Convert math ops marked with fast (afn) to native OpenCL intrinsics.
const StringRef nativeFunc
ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit=1)
std::string getMangledNativeFuncName(const ArrayRef< Type > operandTypes) const
bool isSPIRVCompatibleFloatOrVec(Type type) const
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override