8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
22 using has_get_fastmath_t = decltype(std::declval<T>().getFastmath());
54 template <
typename SourceOp>
69 using LLVM::LLVMFuncOp;
73 "expected single result op");
75 bool isResultBool = op->getResultTypes().front().isInteger(1);
78 assert(op->getNumOperands() > 0 &&
79 "expected op to take at least one operand");
80 assert((op->getResultTypes().front() == op->getOperand(0).getType() ||
82 "expected op with same operand and result types");
85 if (!op->template getParentOfType<FunctionOpInterface>()) {
87 op,
"expected op to be within a function region");
91 for (
Value operand : adaptor.getOperands())
92 castedOperands.push_back(
maybeCast(operand, rewriter));
94 Type castedOperandType = castedOperands.front().getType();
101 if (funcName.empty())
106 LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands);
108 if (resultType == adaptor.getOperands().front().getType()) {
109 rewriter.
replaceOp(op, {callOp.getResult()});
118 Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
122 LLVM::ICmpOp::create(rewriter, op->getLoc(), LLVM::ICmpPredicate::ne,
123 callOp.getResult(), zero);
128 assert(callOp.getResult().getType().isF32() &&
129 "only f32 types are supposed to be truncated back");
130 Value truncated = LLVM::FPTruncOp::create(
131 rewriter, op->getLoc(), adaptor.getOperands().front().
getType(),
139 if (!isa<Float16Type, BFloat16Type>(type))
143 if (!
f16Func.empty() && isa<Float16Type>(type))
146 return LLVM::FPExtOp::create(rewriter, operand.
getLoc(),
158 using LLVM::LLVMFuncOp;
162 SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
167 assert(parentFunc &&
"expected there to be a parent function");
174 return LLVMFuncOp::create(b, globalloc, funcName, funcType);
178 bool useApprox =
false;
179 if constexpr (llvm::is_detected<has_get_fastmath_t, SourceOp>::value) {
180 arith::FastMathFlags flag = op.getFastmath();
181 useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
185 if (isa<Float16Type>(type))
187 if (isa<Float32Type>(type)) {
192 if (isa<Float64Type>(type))
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
An instance of this location represents a tuple of file, line number, and column number.
Conversion from types to the LLVM IR dialect.
LocationAttr findInstanceOfOrUnknown()
Return an instance of the given location type if one is nested under the current location else return...
This class helps build Operations.
This class provides return value APIs for ops that are known to have a single result.
This class provides verification for ops that are known to have the same operand and result type.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func or f32ApproxFunc or f16Func or i...
const std::string f64Func
const std::string f32ApproxFunc
StringRef getFunctionName(Type type, SourceOp op) const
const std::string f32Func
LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, Operation *op) const
const std::string f16Func
const std::string i32Func
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
OpToFuncCallLowering(const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func, StringRef i32Func="", PatternBenefit benefit=1)
Type getFunctionType(Type resultType, ValueRange operands) const
Value maybeCast(Value operand, PatternRewriter &rewriter) const