8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
42 template <
typename SourceOp>
46 StringRef f32Func, StringRef f64Func,
47 StringRef f32ApproxFunc, StringRef f16Func)
49 f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
54 using LLVM::LLVMFuncOp;
58 "expected single result op");
62 "expected op with same operand and result types");
64 if (!op->template getParentOfType<FunctionOpInterface>()) {
66 op,
"expected op to be within a function region");
70 for (
Value operand : adaptor.getOperands())
71 castedOperands.push_back(maybeCast(operand, rewriter));
73 Type resultType = castedOperands.front().getType();
74 Type funcType = getFunctionType(resultType, castedOperands);
76 getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
81 LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
83 rewriter.
create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
85 if (resultType == adaptor.getOperands().front().getType()) {
86 rewriter.
replaceOp(op, {callOp.getResult()});
91 op->getLoc(), adaptor.getOperands().front().getType(),
100 if (!isa<Float16Type, BFloat16Type>(type))
104 if (!f16Func.empty() && isa<Float16Type>(type))
107 return rewriter.
create<LLVM::FPExtOp>(
111 Type getFunctionType(Type resultType, ValueRange operands)
const {
112 SmallVector<Type> operandTypes(operands.getTypes());
116 StringRef getFunctionName(Type type, arith::FastMathFlags flag)
const {
117 if (isa<Float16Type>(type))
119 if (isa<Float32Type>(type)) {
120 if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
121 !f32ApproxFunc.empty())
122 return f32ApproxFunc;
126 if (isa<Float64Type>(type))
131 LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
132 Operation *op)
const {
133 using LLVM::LLVMFuncOp;
138 return cast<LLVMFuncOp>(*funcOp);
141 return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
144 const std::string f32Func;
145 const std::string f64Func;
146 const std::string f32ApproxFunc;
147 const std::string f16Func;
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Conversion from types to the LLVM IR dialect.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class provides return value APIs for ops that are known to have a single result.
This class provides verification for ops that are known to have the same operand and result type.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Rewriting that replace SourceOp with a CallOp to f32Func or f64Func or f32ApproxFunc or f16Func depen...
OpToFuncCallLowering(const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func)
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override