MLIR  22.0.0git
MathToXeVM.cpp
Go to the documentation of this file.
1 //===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/Support/FormatVariadic.h"
17 
18 namespace mlir {
19 #define GEN_PASS_DEF_CONVERTMATHTOXEVM
20 #include "mlir/Conversion/Passes.h.inc"
21 } // namespace mlir
22 
23 using namespace mlir;
24 
25 #define DEBUG_TYPE "math-to-xevm"
26 
27 /// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
28 template <typename Op>
29 struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
30 
31  ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
32  PatternBenefit benefit = 1)
33  : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
34 
35  LogicalResult
36  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
37  ConversionPatternRewriter &rewriter) const override {
38  if (!isSPIRVCompatibleFloatOrVec(op.getType()))
39  return failure();
40 
41  arith::FastMathFlags fastFlags = op.getFastmath();
42  if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
43  return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
44 
45  SmallVector<Type, 1> operandTypes;
46  for (auto operand : adaptor.getOperands()) {
47  Type opTy = operand.getType();
48  // This pass only supports operations on vectors that are already in SPIRV
49  // supported vector sizes: Distributing unsupported vector sizes to SPIRV
50  // supported vector sizes are done in other blocking optimization passes.
51  if (!isSPIRVCompatibleFloatOrVec(opTy))
52  return rewriter.notifyMatchFailure(
53  op, llvm::formatv("incompatible operand type: '{0}'", opTy));
54  operandTypes.push_back(opTy);
55  }
56 
57  auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
58  auto funcOpRes = LLVM::lookupOrCreateFn(
59  rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
60  operandTypes, op.getType());
61  assert(!failed(funcOpRes));
62  LLVM::LLVMFuncOp funcOp = funcOpRes.value();
63 
64  auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
65  op, funcOp, adaptor.getOperands());
66  // Preserve fastmath flags in our MLIR op when converting to llvm function
67  // calls, in order to allow further fastmath optimizations: We thus need to
68  // convert arith fastmath attrs into attrs recognized by llvm.
69  arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
70  mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
71  callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
72  return success();
73  }
74 
75  inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
76  if (type.isFloat())
77  return true;
78  if (auto vecType = dyn_cast<VectorType>(type)) {
79  if (!vecType.getElementType().isFloat())
80  return false;
81  // SPIRV distinguishes between vectors and matrices: OpenCL native math
82  // intrsinics are not compatible with matrices.
83  ArrayRef<int64_t> shape = vecType.getShape();
84  if (shape.size() != 1)
85  return false;
86  // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
87  if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
88  shape[0] == 16)
89  return true;
90  }
91  return false;
92  }
93 
94  inline std::string
95  getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
96  std::string mangledFuncName =
97  "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
98 
99  auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
100  if (type.isF32())
101  mangledFuncName += "f";
102  else if (type.isF16())
103  mangledFuncName += "Dh";
104  else if (type.isF64())
105  mangledFuncName += "d";
106  };
107 
108  for (auto type : operandTypes) {
109  if (auto vecType = dyn_cast<VectorType>(type)) {
110  mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
111  appendFloatToMangledFunc(vecType.getElementType());
112  } else
113  appendFloatToMangledFunc(type);
114  }
115 
116  return mangledFuncName;
117  }
118 
119  const StringRef nativeFunc;
120 };
121 
123  bool convertArith) {
125  "__spirv_ocl_native_exp");
127  "__spirv_ocl_native_cos");
129  patterns.getContext(), "__spirv_ocl_native_exp2");
131  "__spirv_ocl_native_log");
133  patterns.getContext(), "__spirv_ocl_native_log2");
135  patterns.getContext(), "__spirv_ocl_native_log10");
137  patterns.getContext(), "__spirv_ocl_native_powr");
139  patterns.getContext(), "__spirv_ocl_native_rsqrt");
141  "__spirv_ocl_native_sin");
143  patterns.getContext(), "__spirv_ocl_native_sqrt");
145  "__spirv_ocl_native_tan");
146  if (convertArith)
148  patterns.getContext(), "__spirv_ocl_native_divide");
149 }
150 
151 namespace {
152 struct ConvertMathToXeVMPass
153  : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
154  using Base::Base;
155  void runOnOperation() override;
156 };
157 } // namespace
158 
159 void ConvertMathToXeVMPass::runOnOperation() {
162  ConversionTarget target(getContext());
163  target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>();
164  if (failed(
165  applyPartialConversion(getOperation(), target, std::move(patterns))))
166  signalPassFailure();
167 }
static MLIRContext * getContext(OpFoldResult val)
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This provides public APIs that all operations should have.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isFloat() const
Return true if this is a float type (with the specified width).
Definition: Types.cpp:45
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith)
Populate the given list with patterns that convert from Math to XeVM calls.
Definition: MathToXeVM.cpp:122
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Convert math ops marked with fast (afn) to native OpenCL intrinsics.
Definition: MathToXeVM.cpp:29
const StringRef nativeFunc
Definition: MathToXeVM.cpp:119
ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit=1)
Definition: MathToXeVM.cpp:31
std::string getMangledNativeFuncName(const ArrayRef< Type > operandTypes) const
Definition: MathToXeVM.cpp:95
bool isSPIRVCompatibleFloatOrVec(Type type) const
Definition: MathToXeVM.cpp:75
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Definition: MathToXeVM.cpp:36