8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
22 using has_get_fastmath_t = decltype(std::declval<T>().getFastmath());
54 template <
typename SourceOp>
69 using LLVM::LLVMFuncOp;
73 "expected single result op");
75 bool isResultBool = op->getResultTypes().front().isInteger(1);
78 assert(op->getNumOperands() > 0 &&
79 "expected op to take at least one operand");
80 assert((op->getResultTypes().front() == op->getOperand(0).getType() ||
82 "expected op with same operand and result types");
85 if (!op->template getParentOfType<FunctionOpInterface>()) {
87 op,
"expected op to be within a function region");
91 for (
Value operand : adaptor.getOperands())
92 castedOperands.push_back(
maybeCast(operand, rewriter));
94 Type castedOperandType = castedOperands.front().getType();
101 if (funcName.empty())
106 rewriter.
create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
108 if (resultType == adaptor.getOperands().front().getType()) {
109 rewriter.
replaceOp(op, {callOp.getResult()});
122 op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero);
127 assert(callOp.getResult().getType().isF32() &&
128 "only f32 types are supposed to be truncated back");
129 Value truncated = rewriter.
create<LLVM::FPTruncOp>(
130 op->getLoc(), adaptor.getOperands().front().getType(),
138 if (!isa<Float16Type, BFloat16Type>(type))
142 if (!
f16Func.empty() && isa<Float16Type>(type))
145 return rewriter.
create<LLVM::FPExtOp>(
156 using LLVM::LLVMFuncOp;
160 SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
165 assert(parentFunc &&
"expected there to be a parent function");
167 return b.
create<LLVMFuncOp>(op->
getLoc(), funcName, funcType);
171 bool useApprox =
false;
172 if constexpr (llvm::is_detected<has_get_fastmath_t, SourceOp>::value) {
173 arith::FastMathFlags flag = op.getFastmath();
174 useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
178 if (isa<Float16Type>(type))
180 if (isa<Float32Type>(type)) {
185 if (isa<Float64Type>(type))
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Conversion from types to the LLVM IR dialect.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class provides return value APIs for ops that are known to have a single result.
This class provides verification for ops that are known to have the same operand and result type.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func or f32ApproxFunc or f16Func or i...
const std::string f64Func
const std::string f32ApproxFunc
StringRef getFunctionName(Type type, SourceOp op) const
const std::string f32Func
LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, Operation *op) const
const std::string f16Func
const std::string i32Func
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
OpToFuncCallLowering(const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func, StringRef i32Func="", PatternBenefit benefit=1)
Type getFunctionType(Type resultType, ValueRange operands) const
Value maybeCast(Value operand, PatternRewriter &rewriter) const