8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
40 template <
typename SourceOp>
44 StringRef f64Func, StringRef f32ApproxFunc)
46 f64Func(f64Func), f32ApproxFunc(f32ApproxFunc) {}
51 using LLVM::LLVMFuncOp;
55 "expected single result op");
59 "expected op with same operand and result types");
62 for (
Value operand : adaptor.getOperands())
63 castedOperands.push_back(maybeCast(operand, rewriter));
65 Type resultType = castedOperands.front().getType();
66 Type funcType = getFunctionType(resultType, castedOperands);
68 getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
73 LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
75 rewriter.
create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
77 if (resultType == adaptor.getOperands().front().getType()) {
78 rewriter.
replaceOp(op, {callOp.getResult()});
83 op->getLoc(), adaptor.getOperands().front().getType(),
92 if (!isa<Float16Type>(type))
95 return rewriter.
create<LLVM::FPExtOp>(
99 Type getFunctionType(Type resultType, ValueRange operands)
const {
100 SmallVector<Type> operandTypes(operands.getTypes());
104 StringRef getFunctionName(Type type, arith::FastMathFlags flag)
const {
105 if (isa<Float32Type>(type)) {
106 if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
107 !f32ApproxFunc.empty())
108 return f32ApproxFunc;
112 if (isa<Float64Type>(type))
117 LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
118 Operation *op)
const {
119 using LLVM::LLVMFuncOp;
124 return cast<LLVMFuncOp>(*funcOp);
127 return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
130 const std::string f32Func;
131 const std::string f64Func;
132 const std::string f32ApproxFunc;
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Conversion from types to the LLVM IR dialect.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class provides return value APIs for ops that are known to have a single result.
This class provides verification for ops that are known to have the same operand and result type.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Rewrite pattern that replaces SourceOp with a CallOp to f32Func, f64Func, or f32ApproxFunc, depending on the...
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc)