MLIR 22.0.0git
mlir::OpToFuncCallLowering< SourceOp >
Rewriting pattern that replaces SourceOp with a CallOp to f32Func, f64Func, f32ApproxFunc, f16Func, or i32Func, depending on the element type and the fastMathFlag of that op, if present.
#include "Conversion/GPUCommon/OpToFuncCallLowering.h"
Public Member Functions | |
| OpToFuncCallLowering (const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func, StringRef i32Func="", PatternBenefit benefit=1) | |
| LogicalResult | matchAndRewrite (SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override |
| Methods that operate on the SourceOp type. | |
| Value | maybeCast (Value operand, PatternRewriter &rewriter) const |
| Type | getFunctionType (Type resultType, ValueRange operands) const |
| LLVM::LLVMFuncOp | appendOrGetFuncOp (StringRef funcName, Type funcType, Operation *op) const |
| StringRef | getFunctionName (Type type, SourceOp op) const |
| Public Member Functions inherited from mlir::ConvertOpToLLVMPattern< SourceOp > | |
| ConvertOpToLLVMPattern (const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1) | |
| LogicalResult | matchAndRewrite (Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final |
| Wrappers around the RewritePattern methods that pass the derived op type. | |
| LogicalResult | matchAndRewrite (Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final |
| virtual LogicalResult | matchAndRewrite (SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const |
| Public Member Functions inherited from mlir::ConvertToLLVMPattern | |
| ConvertToLLVMPattern (StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1) | |
Public Attributes | |
| const std::string | f32Func |
| const std::string | f64Func |
| const std::string | f32ApproxFunc |
| const std::string | f16Func |
| const std::string | i32Func |
Additional Inherited Members | |
| Public Types inherited from mlir::ConvertOpToLLVMPattern< SourceOp > | |
| using | OpAdaptor = typename SourceOp::Adaptor |
| using | OneToNOpAdaptor |
| Protected Member Functions inherited from mlir::ConvertToLLVMPattern | |
| LLVM::LLVMDialect & | getDialect () const |
| Returns the LLVM dialect. | |
| const LLVMTypeConverter * | getTypeConverter () const |
| Type | getIndexType () const |
| Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter. | |
| Type | getIntPtrType (unsigned addressSpace=0) const |
| Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM pointer type. | |
| Type | getVoidType () const |
| Gets the MLIR type wrapping the LLVM void type. | |
| Type | getVoidPtrType () const |
| Get the MLIR type wrapping the LLVM i8* type. | |
| Type | getPtrType (unsigned addressSpace=0) const |
| Get the MLIR type wrapping the LLVM ptr type. | |
| Value | getStridedElementPtr (ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const |
| Convenience wrapper for the corresponding helper utility. | |
| bool | isConvertibleAndHasIdentityMaps (MemRefType type) const |
| Returns whether the given memref type is convertible to LLVM and has an identity layout map. | |
| Type | getElementPtrType (MemRefType type) const |
| Returns the type of a pointer to an element of the memref. | |
| void | getMemRefDescriptorSizes (Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const |
| Computes sizes, strides and buffer size of memRefType with identity layout. | |
| Value | getSizeInBytes (Location loc, Type type, ConversionPatternRewriter &rewriter) const |
| Computes the size of type in bytes. | |
| Value | getNumElements (Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const |
| Computes total number of elements for the given MemRef and dynamicSizes. | |
| MemRefDescriptor | createMemRefDescriptor (Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const |
| Creates and populates a canonical memref descriptor struct. | |
| Value | copyUnrankedDescriptor (OpBuilder &builder, Location loc, UnrankedMemRefType memRefType, Value operand, bool toDynamic) const |
| Copies the given unranked memory descriptor to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise) and returns the new descriptor. | |
| LogicalResult | copyUnrankedDescriptors (OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const |
| Copies the memory descriptor for any operands that were unranked descriptors originally to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise). | |
| Static Protected Member Functions inherited from mlir::ConvertToLLVMPattern | |
| static Value | createIndexAttrConstant (OpBuilder &builder, Location loc, Type resultType, int64_t value) |
| Create a constant Op producing a value of resultType from an index-typed integer attribute. | |
Rewriting pattern that replaces SourceOp with a CallOp to f32Func, f64Func, f32ApproxFunc, f16Func, or i32Func, depending on the element type and the fastMathFlag of that op, if present.
The callee's function declaration is added if it has not been declared before.
If the operands are of bf16 type (or of f16 type when f16Func is empty), each operand is first cast to f32, the function is called on the f32 values, and the result is then cast back to the original type.
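The widening rule above can be modelled without MLIR. In this sketch, ElemType, needsF32Cast, calleeElemType and checkWideningRule are hypothetical names, not part of the MLIR API; the real maybeCast() works on mlir::Value operands rather than on a plain enum.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for the element types the pattern distinguishes.
// This is a plain enum for illustration, not an MLIR type.
enum class ElemType { F16, BF16, F32, F64, I32 };

// Models the widening rule: bf16 is always extended to f32, while f16 is
// extended only when no dedicated f16 callee (f16Func) was supplied.
bool needsF32Cast(ElemType type, const std::string &f16Func) {
  switch (type) {
  case ElemType::BF16:
    return true;
  case ElemType::F16:
    return f16Func.empty();
  default:
    return false;
  }
}

// Element type the callee sees after the optional widening. When it differs
// from `type`, the call result has to be converted back to `type`.
ElemType calleeElemType(ElemType type, const std::string &f16Func) {
  return needsF32Cast(type, f16Func) ? ElemType::F32 : type;
}

// Spot checks of the rule; "some_f16_func" is an arbitrary placeholder name.
void checkWideningRule() {
  assert(needsF32Cast(ElemType::BF16, "some_f16_func"));
  assert(!needsF32Cast(ElemType::F16, "some_f16_func"));
  assert(needsF32Cast(ElemType::F16, ""));
  assert(!needsF32Cast(ElemType::F64, ""));
}
```

Note that bf16 is widened even when an f16 callee exists, since f16Func only covers the f16 element type.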
Example with NVVM:
  %exp_f32 = math.exp %arg_f32 : f32
will be transformed into
  llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
If the fastMathFlag attribute of SourceOp includes afn (or is fast, which implies afn), the op is lowered to the approximate function f32ApproxFunc instead.
Another example with NVVM:
  %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
will be transformed into
  llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
Final example with NVVM:
  %pow_f32 = math.fpowi %arg_f32, %arg_i32 : f32, i32
will be transformed into
  llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32
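The callee selection shown in these examples can be sketched as a plain function. CalleeNames, pickCallee, checkCalleeSelection and the FastMath bit values are hypothetical; the bit encoding in particular is illustrative and is not that of arith::FastMathFlags. The sketch also assumes that an empty f32ApproxFunc disables the approximate path and that an empty name means no lowering applies.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-ins for MLIR element types and fastmath flags.
// The bit values are illustrative and do not match arith::FastMathFlags;
// the only property modelled is that `fast` includes `afn`.
enum class ElemType { F16, F32, F64, I32, Other };
enum FastMath : unsigned {
  kNone = 0,
  kContract = 1u << 0,
  kAfn = 1u << 1,
  kFast = kContract | kAfn
};

// The five callee names a pattern instance is configured with.
struct CalleeNames {
  std::string f32Func, f64Func, f32ApproxFunc, f16Func, i32Func;
};

// Models getFunctionName(): an empty result means no callee is available,
// in which case the rewrite does not apply.
std::string pickCallee(const CalleeNames &n, ElemType type, unsigned fastmath) {
  // Assumption: the approximate variant is chosen only for f32, only when
  // afn is set, and only when an approximate name was actually supplied.
  bool useApprox = (fastmath & kAfn) != 0 && !n.f32ApproxFunc.empty();
  switch (type) {
  case ElemType::F16:
    return n.f16Func;
  case ElemType::F32:
    return useApprox ? n.f32ApproxFunc : n.f32Func;
  case ElemType::F64:
    return n.f64Func;
  case ElemType::I32:
    return n.i32Func;
  default:
    return "";
  }
}

// Spot checks using the NVVM names from the examples above.
void checkCalleeSelection() {
  CalleeNames expNames{"__nv_expf", "", "__nv_fast_expf", "", ""};
  assert(pickCallee(expNames, ElemType::F32, kNone) == "__nv_expf");
  assert(pickCallee(expNames, ElemType::F32, kFast) == "__nv_fast_expf");
  CalleeNames powiNames{"__nv_powif", "", "", "", ""};
  assert(pickCallee(powiNames, ElemType::F32, kAfn) == "__nv_powif");
}
```

In this model, f64 operands never take the approximate path, and a pattern with no f32ApproxFunc silently falls back to f32Func when afn is set.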
Definition at line 55 of file OpToFuncCallLowering.h.
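| mlir::OpToFuncCallLowering< SourceOp >::OpToFuncCallLowering (const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func, StringRef i32Func = "", PatternBenefit benefit = 1) |
inline explicit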
Definition at line 57 of file OpToFuncCallLowering.h.
References mlir::ConvertOpToLLVMPattern< SourceOp >::ConvertOpToLLVMPattern(), f16Func, f32ApproxFunc, f32Func, f64Func, and i32Func.
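| LLVM::LLVMFuncOp mlir::OpToFuncCallLowering< SourceOp >::appendOrGetFuncOp (StringRef funcName, Type funcType, Operation *op) const |
inline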
Definition at line 156 of file OpToFuncCallLowering.h.
References b, mlir::LocationAttr::findInstanceOfOrUnknown(), mlir::Operation::getContext(), mlir::Operation::getLoc(), mlir::Operation::getParentOfType(), and mlir::SymbolTable::lookupNearestSymbolFrom().
Referenced by matchAndRewrite().
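appendOrGetFuncOp() reuses an existing declaration when the symbol lookup finds one and creates a new one otherwise. A minimal model of that lookup-or-declare step, assuming a std::map in place of the MLIR symbol table, is sketched below; FuncDecl, SymbolTable and appendOrGet are hypothetical names. The sketch simply returns an existing entry without comparing signatures, and it ignores where a new declaration would be placed in the IR.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for an llvm.func declaration: a name and a printed
// signature. The std::map below plays the role of the symbol table.
struct FuncDecl {
  std::string name;
  std::string signature;
};

using SymbolTable = std::map<std::string, FuncDecl>;

// Models appendOrGetFuncOp(): reuse the declaration with this name if one
// exists, otherwise add a new one. std::map keeps references stable, so the
// returned reference stays valid across later insertions.
const FuncDecl &appendOrGet(SymbolTable &table, const std::string &name,
                            const std::string &signature) {
  auto it = table.find(name);
  if (it != table.end())
    return it->second; // already declared: reuse it, do not redeclare
  auto result = table.emplace(name, FuncDecl{name, signature});
  assert(result.second && "name was checked to be absent");
  return result.first->second;
}
```

Lowering two math.exp ops in the same module would therefore produce a single __nv_expf declaration.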
| StringRef mlir::OpToFuncCallLowering< SourceOp >::getFunctionName (Type type, SourceOp op) const |
inline
Definition at line 177 of file OpToFuncCallLowering.h.
References f16Func, f32ApproxFunc, f32Func, f64Func, i32Func, and mlir::Type::isInteger().
Referenced by matchAndRewrite().
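| Type mlir::OpToFuncCallLowering< SourceOp >::getFunctionType (Type resultType, ValueRange operands) const |
inline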
Definition at line 151 of file OpToFuncCallLowering.h.
References mlir::ValueRange::getTypes().
Referenced by matchAndRewrite().
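getFunctionType() derives the callee signature from the result type and the types of the operands passed to it (the References line points to ValueRange::getTypes()). The sketch below mirrors that mapping with strings so it can print the (f32, i32) -> f32 form used in the examples; formatFunctionType is a hypothetical helper, and the real member returns an mlir::Type.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Models getFunctionType(): the callee takes exactly the given operand types,
// in order, and returns resultType. Types are strings here purely so the
// result can be printed in the "(f32, i32) -> f32" form.
std::string formatFunctionType(const std::string &resultType,
                               const std::vector<std::string> &operandTypes) {
  assert(!operandTypes.empty() && "this sketch expects at least one operand");
  std::string out = "(";
  for (std::size_t i = 0; i < operandTypes.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += operandTypes[i];
  }
  out += ") -> " + resultType;
  return out;
}
```

For the math.fpowi example, formatFunctionType("f32", {"f32", "i32"}) yields the signature shown on the llvm.call line.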
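| LogicalResult mlir::OpToFuncCallLowering< SourceOp >::matchAndRewrite (SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const |
inline override virtual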
Methods that operate on the SourceOp type.
One of these must be overridden by the derived pattern class.
Reimplemented from mlir::ConvertOpToLLVMPattern< SourceOp >.
Definition at line 67 of file OpToFuncCallLowering.h.
References appendOrGetFuncOp(), getFunctionName(), getFunctionType(), maybeCast(), and success().
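Putting the referenced helpers together, the flow of matchAndRewrite() can be approximated by the MLIR-free model below: widen bf16/f16 operands, pick a callee from the widened type of the first operand, fail if the name is empty, and record whether the result must be cast back. Lowering, LoweredCall, castType, calleeFor and lower are hypothetical names. The model assumes the result type equals the type of operand 0, as the description implies, and omits any additional checks the real pattern may perform.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// What the model "emits": the callee, its signature, and whether the f32
// result must be converted back to the original (bf16/f16) type.
struct LoweredCall {
  std::string callee;
  std::string signature;
  bool truncateBack;
};

// Hypothetical, MLIR-free model of one OpToFuncCallLowering instance.
// Element types are plain strings such as "f32", "bf16" or "i32".
struct Lowering {
  std::string f32Func, f64Func, f32ApproxFunc, f16Func, i32Func;
  bool afn = false; // op carries fastmath<afn> (or fast)

  // Step 1, like maybeCast(): widen bf16, and f16 if there is no f16 callee.
  std::string castType(const std::string &t) const {
    if (t == "bf16" || (t == "f16" && f16Func.empty()))
      return "f32";
    return t;
  }

  // Step 2, like getFunctionName(): choose the callee for the widened type.
  std::string calleeFor(const std::string &t) const {
    if (t == "f16")
      return f16Func;
    if (t == "f32")
      return (afn && !f32ApproxFunc.empty()) ? f32ApproxFunc : f32Func;
    if (t == "f64")
      return f64Func;
    if (t == "i32")
      return i32Func;
    return "";
  }

  // Steps 3-4: build the signature and report the call, or std::nullopt when
  // no callee matches (the real pattern then fails to match).
  // Assumption: the result type equals the (widened) type of operand 0.
  std::optional<LoweredCall> lower(const std::vector<std::string> &operandTypes) const {
    assert(!operandTypes.empty() && "model expects at least one operand");
    std::vector<std::string> cast;
    for (const std::string &t : operandTypes)
      cast.push_back(castType(t));
    const std::string &resultType = cast.front();
    std::string callee = calleeFor(resultType);
    if (callee.empty())
      return std::nullopt;
    std::string sig = "(";
    for (std::size_t i = 0; i < cast.size(); ++i)
      sig += (i ? ", " : "") + cast[i];
    sig += ") -> " + resultType;
    return LoweredCall{callee, sig, resultType != operandTypes.front()};
  }
};
```

Run on {"bf16"} with the exp names from the examples, the model reports __nv_expf with signature (f32) -> f32 and truncateBack set, matching the widen, call, and cast-back sequence described earlier.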
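| Value mlir::OpToFuncCallLowering< SourceOp >::maybeCast (Value operand, PatternRewriter &rewriter) const |
inline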
Definition at line 137 of file OpToFuncCallLowering.h.
References f16Func, mlir::Builder::getContext(), mlir::Value::getLoc(), and mlir::Value::getType().
Referenced by matchAndRewrite().
| const std::string mlir::OpToFuncCallLowering< SourceOp >::f16Func |
Definition at line 203 of file OpToFuncCallLowering.h.
Referenced by getFunctionName(), maybeCast(), and OpToFuncCallLowering().
| const std::string mlir::OpToFuncCallLowering< SourceOp >::f32ApproxFunc |
Definition at line 202 of file OpToFuncCallLowering.h.
Referenced by getFunctionName(), and OpToFuncCallLowering().
| const std::string mlir::OpToFuncCallLowering< SourceOp >::f32Func |
Definition at line 200 of file OpToFuncCallLowering.h.
Referenced by getFunctionName(), and OpToFuncCallLowering().
| const std::string mlir::OpToFuncCallLowering< SourceOp >::f64Func |
Definition at line 201 of file OpToFuncCallLowering.h.
Referenced by getFunctionName(), and OpToFuncCallLowering().
| const std::string mlir::OpToFuncCallLowering< SourceOp >::i32Func |
Definition at line 204 of file OpToFuncCallLowering.h.
Referenced by getFunctionName(), and OpToFuncCallLowering().