17#include "llvm/Support/FormatVariadic.h"
23#define GEN_PASS_DEF_CONVERTMATHTOXEVM
24#include "mlir/Conversion/Passes.h.inc"
29#define DEBUG_TYPE "math-to-xevm"
41 ConversionPatternRewriter &rewriter)
const override {
45 arith::FastMathFlags fastFlags = op.getFastmath();
46 if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
47 return rewriter.notifyMatchFailure(op,
"not a fastmath `afn` operation");
50 for (
auto operand : adaptor.getOperands()) {
51 Type opTy = operand.getType();
56 return rewriter.notifyMatchFailure(
57 op, llvm::formatv(
"incompatible operand type: '{0}'", opTy));
58 operandTypes.push_back(opTy);
61 auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
64 operandTypes, op.getType());
65 assert(!failed(funcOpRes));
66 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
68 auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
69 op, funcOp, adaptor.getOperands());
82 if (
auto vecType = dyn_cast<VectorType>(type)) {
83 if (!vecType.getElementType().isFloat())
88 if (
shape.size() != 1)
100 std::string mangledFuncName =
103 auto appendFloatToMangledFunc = [&mangledFuncName](
Type type) {
105 mangledFuncName +=
"f";
106 else if (type.isF16())
107 mangledFuncName +=
"Dh";
108 else if (type.isF64())
109 mangledFuncName +=
"d";
112 for (
auto type : operandTypes) {
113 if (
auto vecType = dyn_cast<VectorType>(type)) {
114 mangledFuncName +=
"Dv" + std::to_string(vecType.getShape()[0]) +
"_";
115 appendFloatToMangledFunc(vecType.getElementType());
117 appendFloatToMangledFunc(type);
120 return mangledFuncName;
126template <
typename OpTy>
131 std::string prefix =
"__spirv_ocl_";
132 std::string mangledName =
"_Z" +
133 std::to_string(prefix.size() + opName.size()) +
134 prefix + opName.str();
138 converter, mangledName +
"f", mangledName +
"d",
140 "", benefit, LLVM::cconv::CConv::SPIR_FUNC);
200 patterns.
getContext(),
"__spirv_ocl_native_exp", benefit);
202 patterns.
getContext(),
"__spirv_ocl_native_cos", benefit);
204 patterns.
getContext(),
"__spirv_ocl_native_exp2", benefit);
206 patterns.
getContext(),
"__spirv_ocl_native_log", benefit);
208 patterns.
getContext(),
"__spirv_ocl_native_log2", benefit);
210 patterns.
getContext(),
"__spirv_ocl_native_log10", benefit);
212 patterns.
getContext(),
"__spirv_ocl_native_powr", benefit);
214 patterns.
getContext(),
"__spirv_ocl_native_rsqrt", benefit);
216 patterns.
getContext(),
"__spirv_ocl_native_sin", benefit);
218 patterns.
getContext(),
"__spirv_ocl_native_sqrt", benefit);
220 patterns.
getContext(),
"__spirv_ocl_native_tan", benefit);
223 patterns.
getContext(),
"__spirv_ocl_native_divide", benefit);
227struct ConvertMathToXeVMPass
228 :
public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
230 void runOnOperation()
override;
234void ConvertMathToXeVMPass::runOnOperation() {
235 Operation *op = getOperation();
238 const auto &dl = getAnalysis<DataLayoutAnalysis>();
241 LowerToLLVMOptions
options(ctx, dl.getAtOrAbove(op));
242 LLVMTypeConverter converter(ctx,
options);
252 .addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::LogOp,
253 LLVM::Log10Op, LLVM::Log2Op, LLVM::SinOp, LLVM::SqrtOp>();
255 target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>();
257 applyPartialConversion(getOperation(),
target, std::move(patterns))))
static void populateOCLExtSetOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit, StringRef opName)
static llvm::ManagedStatic< PassManagerOptions > options
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This provides public APIs that all operations should have.
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat() const
Return true if this is an float type (with the specified width).
ArrayRef< NamedAttribute > getAttrs() const
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature resultType(paramTypes) and name name`.
Include the generated interface declarations.
void populateMathToScalarOCLExtSetConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the given list with patterns that convert from Math to OCL LLVM-SPV builtin calls.
void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith, PatternBenefit benefit=1)
Populate the given list with patterns that convert from Math to XeVM calls.
Convert math ops marked with fast (afn) to native OpenCL intrinsics.
const StringRef nativeFunc
ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit=1)
std::string getMangledNativeFuncName(const ArrayRef< Type > operandTypes) const
bool isSPIRVCompatibleFloatOrVec(Type type) const
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func or f32ApproxFunc or f16Func or i...
Unrolls SourceOp to array/vector elements.