|
MLIR 22.0.0git
|
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2). More...
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
Public Member Functions | |
| LogicalResult | matchAndRewrite (CollapseOpTy collapseOp, PatternRewriter &rewriter) const override |
| Public Member Functions inherited from mlir::OpRewritePattern< CollapseOpTy > | |
| OpRewritePattern (MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={}) | |
| Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops. | |
| Public Member Functions inherited from mlir::detail::OpOrInterfaceRewritePatternBase< CollapseOpTy > | |
| LogicalResult | matchAndRewrite (Operation *op, PatternRewriter &rewriter) const final |
| Wrapper around the RewritePattern method that passes the derived op type. | |
| Public Member Functions inherited from mlir::RewritePattern | |
| virtual | ~RewritePattern ()=default |
| Public Member Functions inherited from mlir::Pattern | |
| ArrayRef< OperationName > | getGeneratedOps () const |
| Return a list of operations that may be generated when rewriting an operation instance with this pattern. | |
| std::optional< OperationName > | getRootKind () const |
| Return the root node that this pattern matches. | |
| std::optional< TypeID > | getRootInterfaceID () const |
| Return the interface ID used to match the root operation of this pattern. | |
| std::optional< TypeID > | getRootTraitID () const |
| Return the trait ID used to match the root operation of this pattern. | |
| PatternBenefit | getBenefit () const |
| Return the benefit (the inverse of "cost") of matching this pattern. | |
| bool | hasBoundedRewriteRecursion () const |
| Returns true if this pattern is known to result in recursive application, i.e. this pattern may generate IR that also matches this pattern, but is known to bound the recursion. | |
| MLIRContext * | getContext () const |
| Return the MLIRContext used to create this pattern. | |
| StringRef | getDebugName () const |
| Return a readable name for this pattern. | |
| void | setDebugName (StringRef name) |
| Set the human readable debug name used for this pattern. | |
| ArrayRef< StringRef > | getDebugLabels () const |
| Return the set of debug labels attached to this pattern. | |
| void | addDebugLabels (ArrayRef< StringRef > labels) |
| Add the provided debug labels to this pattern. | |
| void | addDebugLabels (StringRef label) |
Additional Inherited Members | |
| Public Types inherited from mlir::OpRewritePattern< CollapseOpTy > | |
| using | Base |
| Type alias to allow derived classes to inherit constructors with using Base::Base;. | |
| Static Public Member Functions inherited from mlir::RewritePattern | |
| template<typename T, typename... Args> | |
| static std::unique_ptr< T > | create (Args &&...args) |
| This method provides a convenient interface for creating and initializing derived rewrite patterns of the given type T. | |
| Protected Member Functions inherited from mlir::RewritePattern | |
| Pattern (StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={}) | |
| Inherit the base constructors from Pattern. | |
| Pattern (MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={}) | |
| Inherit the base constructors from Pattern. | |
| Pattern (MatchInterfaceOpTypeTag tag, TypeID interfaceID, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={}) | |
| Inherit the base constructors from Pattern. | |
| Pattern (MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={}) | |
| Inherit the base constructors from Pattern. | |
| Protected Member Functions inherited from mlir::Pattern | |
| Pattern (StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={}) | |
| Construct a pattern with a certain benefit that matches the operation with the given root name. | |
| Pattern (MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={}) | |
| Construct a pattern that may match any operation type. | |
| Pattern (MatchInterfaceOpTypeTag tag, TypeID interfaceID, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={}) | |
| Construct a pattern that may match any operation that implements the interface defined by the provided interfaceID. | |
| Pattern (MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={}) | |
| Construct a pattern that may match any operation that implements the trait defined by the provided traitID. | |
| void | setHasBoundedRewriteRecursion (bool hasBoundedRecursionArg=true) |
| Set the flag detailing if this pattern has bounded rewrite recursion or not. | |
Common verifier for reshape-like types.
Fills expandedType and collapsedType with the proper src or result type. template <typename Op, typename T> static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion) {
unsigned expandedRank = expandedType.getRank(); unsigned collapsedRank = collapsedType.getRank(); if (expandedRank < collapsedRank) return op.emitOpError("expected the expanded type, ") << expandedType << " to have a higher (or same) rank " << "than the collapsed type, " << collapsedType << '.';
if (collapsedRank != op.getReassociation().size()) return op.emitOpError("expected collapsed rank (") << collapsedRank << ") to equal the number of reassociation maps (" << op.getReassociation().size() << ").";
auto maps = op.getReassociationMaps(); for (auto it : llvm::enumerate(maps)) if (it.value().getNumDims() != expandedRank) return op.emitOpError("expected reassociation map #") << it.index() << " to have size equal to the expanded rank (" << expandedRank << "), but it is " << it.value().getNumDims() << '.';
int invalidIdx = 0; if (!isReassociationValid(maps, &invalidIdx)) return op.emitOpError("expected reassociation map #") << invalidIdx << " to be valid and contiguous.";
return reshapeLikeShapesAreCompatible( [&](const Twine &msg) { return op->emitOpError(msg); }, collapsedType.getShape(), expandedType.getShape(), op.getReassociationIndices(), isExpansion); }
/ Verify that shapes of the reshaped types using following rule: / if a dimension in the collapsed type is static, then the corresponding / dimensions in the expanded shape should be / a) static / b) the product should be same as the collaped shape. LogicalResult reshapeLikeShapesAreCompatible( function_ref<LogicalResult(const Twine &)> emitError, ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape, ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
/ Returns true iff the type is a MemRefType and has a non-identity layout. bool hasNonIdentityLayout(Type type);
/// Selects which flavor of reshape op a reshape-composition pattern is
/// instantiated for.
enum class ReshapeOpKind {
  kExpand,  ///< Reshape that expands dimensions.
  kCollapse ///< Reshape that collapses dimensions.
};
/ Pattern to collapse producer/consumer reshape ops that are both collapsing / dimensions or are both expanding dimensions. template <typename ReshapeOpTy, ReshapeOpKind opKind> struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> { using OpRewritePattern<ReshapeOpTy>::OpRewritePattern; LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, PatternRewriter &rewriter) const override { auto srcReshapeOp = reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>(); if (!srcReshapeOp) return failure();
ShapedType resultType = reshapeOp.getResultType();
if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) || hasNonIdentityLayout(reshapeOp.getSrc().getType()) || hasNonIdentityLayout(reshapeOp.getResult().getType())) return failure();
std::optional<SmallVector<ReassociationIndices>> reassociationIndices = composeReassociationIndices(srcReshapeOp.getReassociationIndices(), reshapeOp.getReassociationIndices(), rewriter.getContext()); if (!reassociationIndices) return failure();
if constexpr (opKind == ReshapeOpKind::kExpand) { SmallVector<OpFoldResult> outputShape( getMixedValues(reshapeOp.getStaticOutputShape(), reshapeOp.getOutputShape(), rewriter)); rewriter.replaceOpWithNewOp<ReshapeOpTy>( reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices, outputShape); } else { rewriter.replaceOpWithNewOp<ReshapeOpTy>( reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices); } return success(); } };
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2). In that case, both srcType and resultType can be expressed as a function of intermediateType. To demonstrate the approach, assume that rank(srcType) > rank(resultType), i.e. the resulting operation should be a collapse_shape. Then we can iterate over every set of indices in reassociation_2 and try to find the ids of the sets of indices in reassociation_1 that cover it completely.
Example:
%0 = tensor.expand_shape %arg [[0], [1], [2, 3]] : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
%1 = tensor.collapse_shape %0 [[0, 1], [2, 3]] : tensor<?x?x?x1xi64> into tensor<?x?xi64>
can be canonicalized into
%0 = tensor.collapse_shape %arg [[0, 1], [2]] : tensor<?x?x?xi64> into tensor<?x?xi64>
because [0] and [1] from the expand_shape reassociation completely cover [0, 1] from the collapse_shape reassociation. If no such union of indices can be found, the pattern fails. When rank(srcType) < rank(resultType), we simply swap reassociation_1 and reassociation_2 and produce an expand_shape instead.
Definition at line 258 of file ReshapeOpsUtils.h.
|
inline override |
Definition at line 260 of file ReshapeOpsUtils.h.
References mlir::arith::ConstantIndexOp::create(), mlir::getConstantIntValue(), mlir::Builder::getIndexAttr(), indices, mlir::RewriterBase::replaceOpWithNewOp(), result, and success().