MLIR 22.0.0git

Rewrite vector.transpose into AVX2-specific operation sequences for the supported cases, depending on the TransposeLoweringOptions.
Public Member Functions
  TransposeOpLowering (LoweringOptions loweringOptions, MLIRContext *context, int benefit)
  LogicalResult matchAndRewrite (vector::TransposeOp op, PatternRewriter &rewriter) const override
Public Member Functions inherited from mlir::OpRewritePattern< vector::TransposeOp >
  OpRewritePattern (MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
    Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Public Member Functions inherited from mlir::detail::OpOrInterfaceRewritePatternBase< vector::TransposeOp >
  LogicalResult matchAndRewrite (Operation *op, PatternRewriter &rewriter) const final
    Wrapper around the RewritePattern method that passes the derived op type.
Public Member Functions inherited from mlir::RewritePattern
  virtual ~RewritePattern ()=default
Public Member Functions inherited from mlir::Pattern
  ArrayRef< OperationName > getGeneratedOps () const
    Return a list of operations that may be generated when rewriting an operation instance with this pattern.
  std::optional< OperationName > getRootKind () const
    Return the root node that this pattern matches.
  std::optional< TypeID > getRootInterfaceID () const
    Return the interface ID used to match the root operation of this pattern.
  std::optional< TypeID > getRootTraitID () const
    Return the trait ID used to match the root operation of this pattern.
  PatternBenefit getBenefit () const
    Return the benefit (the inverse of "cost") of matching this pattern.
  bool hasBoundedRewriteRecursion () const
    Returns true if this pattern is known to result in recursive application, i.e. this pattern may generate IR that also matches this pattern, but is known to bound the recursion.
  MLIRContext * getContext () const
    Return the MLIRContext used to create this pattern.
  StringRef getDebugName () const
    Return a readable name for this pattern.
  void setDebugName (StringRef name)
    Set the human readable debug name used for this pattern.
  ArrayRef< StringRef > getDebugLabels () const
    Return the set of debug labels attached to this pattern.
  void addDebugLabels (ArrayRef< StringRef > labels)
    Add the provided debug labels to this pattern.
  void addDebugLabels (StringRef label)
Public Member Functions inherited from mlir::ConvertOpToLLVMPattern< memref::TransposeOp >
  ConvertOpToLLVMPattern (const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
  LogicalResult matchAndRewrite (Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
    Wrappers around the RewritePattern methods that pass the derived op type.
Public Member Functions inherited from mlir::ConvertToLLVMPattern
  ConvertToLLVMPattern (StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Additional Inherited Members
Public Types inherited from mlir::OpRewritePattern< vector::TransposeOp >
  using Base
    Type alias to allow derived classes to inherit constructors with using Base::Base;.
Public Types inherited from mlir::ConvertOpToLLVMPattern< memref::TransposeOp >
  using OpAdaptor
  using OneToNOpAdaptor
Static Public Member Functions inherited from mlir::RewritePattern
  template<typename T, typename... Args> static std::unique_ptr< T > create (Args &&...args)
    This method provides a convenient interface for creating and initializing derived rewrite patterns of the given type T.
Protected Member Functions inherited from mlir::RewritePattern
  Pattern (StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
    Inherit the base constructors from Pattern.
  Pattern (MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
    Inherit the base constructors from Pattern.
  Pattern (MatchInterfaceOpTypeTag tag, TypeID interfaceID, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
    Inherit the base constructors from Pattern.
  Pattern (MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
    Inherit the base constructors from Pattern.
Protected Member Functions inherited from mlir::Pattern
  Pattern (StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
    Construct a pattern with a certain benefit that matches the operation with the given root name.
  Pattern (MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
    Construct a pattern that may match any operation type.
  Pattern (MatchInterfaceOpTypeTag tag, TypeID interfaceID, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
    Construct a pattern that may match any operation that implements the interface defined by the provided interfaceID.
  Pattern (MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
    Construct a pattern that may match any operation that implements the trait defined by the provided traitID.
  void setHasBoundedRewriteRecursion (bool hasBoundedRecursionArg=true)
    Set the flag detailing if this pattern has bounded rewrite recursion or not.
Protected Member Functions inherited from mlir::ConvertToLLVMPattern
  LLVM::LLVMDialect & getDialect () const
    Returns the LLVM dialect.
  const LLVMTypeConverter * getTypeConverter () const
  Type getIndexType () const
    Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter.
  Type getIntPtrType (unsigned addressSpace=0) const
    Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM pointer type.
  Type getVoidType () const
    Gets the MLIR type wrapping the LLVM void type.
  Type getVoidPtrType () const
    Get the MLIR type wrapping the LLVM i8* type.
  Type getPtrType (unsigned addressSpace=0) const
    Get the MLIR type wrapping the LLVM ptr type.
  Value getStridedElementPtr (ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
    Convenience wrapper for the corresponding helper utility.
  bool isConvertibleAndHasIdentityMaps (MemRefType type) const
    Returns true if the given memref type is convertible to LLVM and has an identity layout map.
  Type getElementPtrType (MemRefType type) const
    Returns the type of a pointer to an element of the memref.
  void getMemRefDescriptorSizes (Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
    Computes sizes, strides and buffer size of memRefType with identity layout.
  Value getSizeInBytes (Location loc, Type type, ConversionPatternRewriter &rewriter) const
    Computes the size of type in bytes.
  Value getNumElements (Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
    Computes total number of elements for the given MemRef and dynamicSizes.
  MemRefDescriptor createMemRefDescriptor (Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
    Creates and populates a canonical memref descriptor struct.
  Value copyUnrankedDescriptor (OpBuilder &builder, Location loc, UnrankedMemRefType memRefType, Value operand, bool toDynamic) const
    Copies the given unranked memory descriptor to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise) and returns the new descriptor.
  LogicalResult copyUnrankedDescriptors (OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
    Copies the memory descriptor for any operands that were unranked descriptors originally to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise).
Static Protected Member Functions inherited from mlir::ConvertToLLVMPattern
  static Value createIndexAttrConstant (OpBuilder &builder, Location loc, Type resultType, int64_t value)
    Create a constant Op producing a value of resultType from an index-typed integer attribute.
Rewrite vector.transpose into AVX2-specific operation sequences for the supported cases, depending on the TransposeLoweringOptions.
The lowering supports 2-D transpose cases and n-D cases that have been decomposed into 2-D transposition slices. For example, a 3-D transpose:
%0 = vector.transpose %arg0, [2, 0, 1] : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
could be sliced into 2-D transposition slices by tiling it so that only two dimensions are larger than one, with those two sized to a shape supported by the AVX2 patterns (e.g., 4x8):
%0 = vector.transpose %arg0, [2, 0, 1] : vector<1x4x8xf32> to vector<8x1x4xf32>
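The shapes in both examples follow the vector.transpose rule that result dimension i takes the size of source dimension perm[i]. The standalone C++ sketch below only checks that shape arithmetic; it does not use MLIR, and the helper name transposedShape is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not an MLIR API): computes the result shape of a
// transpose, where result dimension i takes the size of source dimension
// perm[i]. Throws std::invalid_argument if perm is not a permutation of
// [0, rank).
std::vector<int64_t> transposedShape(const std::vector<int64_t> &srcShape,
                                     const std::vector<int64_t> &perm) {
  if (perm.size() != srcShape.size())
    throw std::invalid_argument("permutation rank does not match source rank");
  std::vector<bool> seen(srcShape.size(), false);
  std::vector<int64_t> result;
  result.reserve(perm.size());
  for (int64_t p : perm) {
    // Reject out-of-range or repeated indices.
    if (p < 0 || static_cast<std::size_t>(p) >= srcShape.size() || seen[p])
      throw std::invalid_argument("perm is not a permutation");
    seen[p] = true;
    result.push_back(srcShape[p]);
  }
  return result;
}
```

For the 3-D example, transposedShape({1024, 2048, 4096}, {2, 0, 1}) yields {4096, 1024, 2048}, and the 1x4x8 tile yields {8, 1, 4}, matching the result types in the IR.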
This lowering analyzes the n-D vector.transpose and determines whether it is a supported 2-D transposition slice to which one of the AVX2 patterns can be applied.
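A minimal sketch of that analysis, under a simplified model that is an assumption rather than the actual implementation: a transpose counts as a 2-D slice when, after dropping unit dimensions, the permutation swaps exactly two dimensions, and the AVX2 path is taken only for f32 slices of size 4x8 or 8x8 whose option is enabled. The 4x8 and 8x8 shapes are inferred from the helper names transpose4x8xf32 and transpose8x8xf32 referenced by matchAndRewrite. findTransposed2DDims, Avx2TransposeOptions, Avx2Kernel, and selectKernel are hypothetical stand-ins; the real pattern references mlir::vector::isTranspose2DSlice and the LoweringOptions type instead.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical helper: if the transpose, ignoring unit dimensions, swaps
// exactly two dimensions larger than one, returns those two source dimensions
// in source order; otherwise returns std::nullopt. Assumes perm has no
// repeated entries.
std::optional<std::pair<int64_t, int64_t>>
findTransposed2DDims(const std::vector<int64_t> &srcShape,
                     const std::vector<int64_t> &perm) {
  if (perm.size() != srcShape.size())
    return std::nullopt;
  // Non-unit source dimensions, in source order.
  std::vector<int64_t> srcOrder;
  for (std::size_t d = 0; d < srcShape.size(); ++d)
    if (srcShape[d] > 1)
      srcOrder.push_back(static_cast<int64_t>(d));
  // The same dimensions, in the order they appear in the result.
  std::vector<int64_t> resOrder;
  for (int64_t p : perm) {
    if (p < 0 || static_cast<std::size_t>(p) >= srcShape.size())
      return std::nullopt;
    if (srcShape[p] > 1)
      resOrder.push_back(p);
  }
  if (srcOrder.size() != 2 || resOrder.size() != 2)
    return std::nullopt;
  // A genuine 2-D transposition reverses the relative order of the two dims.
  if (resOrder[0] != srcOrder[1] || resOrder[1] != srcOrder[0])
    return std::nullopt;
  return std::make_pair(srcOrder[0], srcOrder[1]);
}

// Hypothetical stand-in for the transpose lowering options.
struct Avx2TransposeOptions {
  bool lower4x8xf32 = false;
  bool lower8x8xf32 = false;
};

enum class Avx2Kernel { None, Transpose4x8xf32, Transpose8x8xf32 };

// Chooses a kernel for an m x n slice (m, n: sizes of the two swapped source
// dimensions, in source order). Only f32 slices with an enabled shape qualify.
Avx2Kernel selectKernel(int64_t m, int64_t n, bool isF32,
                        const Avx2TransposeOptions &opts) {
  if (!isF32)
    return Avx2Kernel::None;
  if (opts.lower4x8xf32 && m == 4 && n == 8)
    return Avx2Kernel::Transpose4x8xf32;
  if (opts.lower8x8xf32 && m == 8 && n == 8)
    return Avx2Kernel::Transpose8x8xf32;
  return Avx2Kernel::None;
}
```

For the vector<1x4x8xf32> example, findTransposed2DDims returns source dimensions (1, 2), giving m = 4 and n = 8; selectKernel then picks the 4x8 kernel only if that option is enabled, and otherwise the pattern would report a match failure.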
Definition at line 207 of file AVXTranspose.cpp.
TransposeOpLowering (LoweringOptions loweringOptions, MLIRContext *context, int benefit) [inline]
Definition at line 211 of file AVXTranspose.cpp.
References mlir::OpRewritePattern< vector::TransposeOp >::OpRewritePattern().
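The benefit argument accepted by the constructor ranks this pattern against other patterns that match the same vector.transpose: MLIR rewrite drivers try higher-benefit patterns first, so registering the AVX2 pattern with a larger benefit than a generic lowering lets it take precedence where it applies, with the generic pattern as a fallback. The toy C++ model below illustrates only that ordering; ToyPattern and applyFirstMatching are hypothetical and are not MLIR classes.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Toy model, not MLIR's API: a pattern is a name, a benefit, and a callback
// that reports whether it matched (and rewrote) the given op.
struct ToyPattern {
  std::string name;
  int benefit;
  std::function<bool(const std::string &)> matchAndRewrite;
};

// Tries patterns from highest to lowest benefit (stable for ties) and returns
// the name of the first one that succeeds, or "" if none does.
std::string applyFirstMatching(std::vector<ToyPattern> patterns,
                               const std::string &op) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &a, const ToyPattern &b) {
                     return a.benefit > b.benefit;
                   });
  for (const ToyPattern &p : patterns)
    if (p.matchAndRewrite(op))
      return p.name;
  return "";
}
```

In this model, a high-benefit pattern that only accepts 4x8 f32 slices wins on those slices, and a low-benefit catch-all handles everything else.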
LogicalResult matchAndRewrite (vector::TransposeOp op, PatternRewriter &rewriter) const [inline, override]
Definition at line 216 of file AVXTranspose.cpp.
References getElementType(), mlir::Builder::getZeroAttr(), mlir::vector::isTranspose2DSlice(), mlir::RewriterBase::notifyMatchFailure(), mlir::RewriterBase::replaceOp(), success(), mlir::x86vector::avx2::transpose4x8xf32(), and mlir::x86vector::avx2::transpose8x8xf32().
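Whatever shuffle sequence transpose4x8xf32 or transpose8x8xf32 emits, the rewritten IR must produce the same values as a plain 2-D transpose of the slice. A scalar reference such as the hypothetical transpose2DReference below is a convenient oracle when testing such a lowering; it models the slice as an m x n row-major f32 buffer and makes no attempt to reproduce the AVX2 instruction sequence.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical scalar reference: transposes an m x n row-major f32 block into
// an n x m row-major block, i.e. out[j * m + i] = in[i * n + j].
std::vector<float> transpose2DReference(const std::vector<float> &in,
                                        std::size_t m, std::size_t n) {
  assert(in.size() == m * n && "input must hold m * n elements");
  std::vector<float> out(m * n);
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j)
      out[j * m + i] = in[i * n + j];
  return out;
}
```

Applying the reference twice with swapped dimensions (m x n, then n x m) returns the original buffer, which is a cheap sanity check on the oracle itself.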