8#ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9#define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
22using has_get_fastmath_t =
decltype(std::declval<T>().getFastmath());
54template <
typename SourceOp>
68 ConversionPatternRewriter &rewriter)
const override {
69 using LLVM::LLVMFuncOp;
72 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
73 "expected single result op");
75 bool isResultBool = op->getResultTypes().front().isInteger(1);
76 if constexpr (!std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
78 assert(op->getNumOperands() > 0 &&
79 "expected op to take at least one operand");
80 assert((op->getResultTypes().front() == op->getOperand(0).getType() ||
82 "expected op with same operand and result types");
85 if (!op->template getParentOfType<FunctionOpInterface>()) {
86 return rewriter.notifyMatchFailure(
87 op,
"expected op to be within a function region");
91 for (
Value operand : adaptor.getOperands())
92 castedOperands.push_back(
maybeCast(operand, rewriter));
94 Type castedOperandType = castedOperands.front().getType();
98 isResultBool ? rewriter.getIntegerType(32) : castedOperandType;
101 if (funcName.empty())
106 LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands);
108 if (resultType == adaptor.getOperands().front().getType()) {
109 rewriter.replaceOp(op, {callOp.getResult()});
118 Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
119 rewriter.getIntegerType(32),
120 rewriter.getI32IntegerAttr(0));
122 LLVM::ICmpOp::create(rewriter, op->getLoc(), LLVM::ICmpPredicate::ne,
123 callOp.getResult(), zero);
124 rewriter.replaceOp(op, {truncated});
128 assert(callOp.getResult().getType().isF32() &&
129 "only f32 types are supposed to be truncated back");
130 Value truncated = LLVM::FPTruncOp::create(
131 rewriter, op->getLoc(), adaptor.getOperands().front().getType(),
133 rewriter.replaceOp(op, {truncated});
139 if (!isa<Float16Type, BFloat16Type>(type))
143 if (!
f16Func.empty() && isa<Float16Type>(type))
146 return LLVM::FPExtOp::create(rewriter, operand.
getLoc(),
153 return LLVM::LLVMFunctionType::get(resultType, operandTypes);
158 using LLVM::LLVMFuncOp;
160 auto funcAttr = StringAttr::get(op->
getContext(), funcName);
167 assert(parentFunc &&
"expected there to be a parent function");
174 return LLVMFuncOp::create(
b, globalloc, funcName, funcType);
178 bool useApprox =
false;
179 if constexpr (llvm::is_detected<has_get_fastmath_t, SourceOp>::value) {
180 arith::FastMathFlags flag = op.getFastmath();
181 useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
185 if (isa<Float16Type>(type))
187 if (isa<Float32Type>(type)) {
192 if (isa<Float64Type>(type))
MLIRContext * getContext() const
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
An instance of this location represents a tuple of file, line number, and column number.
Conversion from types to the LLVM IR dialect.
LocationAttr findInstanceOfOrUnknown()
Return an instance of the given location type if one is nested under the current location; otherwise return an unknown location.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Include the generated interface declarations.
const std::string f64Func
const std::string f32ApproxFunc
StringRef getFunctionName(Type type, SourceOp op) const
const std::string f32Func
LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, Operation *op) const
const std::string f16Func
const std::string i32Func
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
OpToFuncCallLowering(const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func, StringRef i32Func="", PatternBenefit benefit=1)
Type getFunctionType(Type resultType, ValueRange operands) const
Value maybeCast(Value operand, PatternRewriter &rewriter) const