16 #include "llvm/Support/FormatVariadic.h"
19 #define GEN_PASS_DEF_CONVERTMATHTOXEVM
20 #include "mlir/Conversion/Passes.h.inc"
25 #define DEBUG_TYPE "math-to-xevm"
28 template <
typename Op>
38 if (!isSPIRVCompatibleFloatOrVec(op.getType()))
41 arith::FastMathFlags fastFlags = op.getFastmath();
42 if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
46 for (
auto operand : adaptor.getOperands()) {
47 Type opTy = operand.getType();
51 if (!isSPIRVCompatibleFloatOrVec(opTy))
53 op, llvm::formatv(
"incompatible operand type: '{0}'", opTy));
54 operandTypes.push_back(opTy);
57 auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
59 rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
60 operandTypes, op.getType());
61 assert(!
failed(funcOpRes));
62 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
65 op, funcOp, adaptor.getOperands());
69 arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
78 if (
auto vecType = dyn_cast<VectorType>(type)) {
79 if (!vecType.getElementType().isFloat())
84 if (shape.size() != 1)
87 if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
96 std::string mangledFuncName =
97 "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
99 auto appendFloatToMangledFunc = [&mangledFuncName](
Type type) {
101 mangledFuncName +=
"f";
102 else if (type.isF16())
103 mangledFuncName +=
"Dh";
104 else if (type.isF64())
105 mangledFuncName +=
"d";
108 for (
auto type : operandTypes) {
109 if (
auto vecType = dyn_cast<VectorType>(type)) {
110 mangledFuncName +=
"Dv" + std::to_string(vecType.getShape()[0]) +
"_";
111 appendFloatToMangledFunc(vecType.getElementType());
113 appendFloatToMangledFunc(type);
116 return mangledFuncName;
125 "__spirv_ocl_native_exp");
127 "__spirv_ocl_native_cos");
129 patterns.getContext(),
"__spirv_ocl_native_exp2");
131 "__spirv_ocl_native_log");
133 patterns.getContext(),
"__spirv_ocl_native_log2");
135 patterns.getContext(),
"__spirv_ocl_native_log10");
137 patterns.getContext(),
"__spirv_ocl_native_powr");
139 patterns.getContext(),
"__spirv_ocl_native_rsqrt");
141 "__spirv_ocl_native_sin");
143 patterns.getContext(),
"__spirv_ocl_native_sqrt");
145 "__spirv_ocl_native_tan");
148 patterns.getContext(),
"__spirv_ocl_native_divide");
152 struct ConvertMathToXeVMPass
153 :
public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
155 void runOnOperation()
override;
159 void ConvertMathToXeVMPass::runOnOperation() {
163 target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>();
static MLIRContext * getContext(OpFoldResult val)
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This provides public APIs that all operations should have.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat() const
Return true if this is a float type (with the specified width).
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
Include the generated interface declarations.
void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith)
Populate the given list with patterns that convert from Math to XeVM calls.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Convert math ops marked with fast (afn) to native OpenCL intrinsics.
const StringRef nativeFunc
ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit=1)
std::string getMangledNativeFuncName(const ArrayRef< Type > operandTypes) const
bool isSPIRVCompatibleFloatOrVec(Type type) const
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override