12 #include "llvm/Support/Debug.h"
19 template <
typename ReshapeOp>
22 bool foldSingleUseOnly =
false)
24 foldSingleUseOnly(foldSingleUseOnly) {}
29 auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
34 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
41 !llvm::hasSingleElement(resultShapes))
47 loc, resultShapes[0], reshapeOp.getResultType().getElementType());
48 if (emptyTensor.
getType() != reshapeOp.getResultType()) {
50 reshapeOp, reshapeOp.getResultType(), emptyTensor);
52 rewriter.
replaceOp(reshapeOp, emptyTensor);
58 bool foldSingleUseOnly =
false;
63 struct FoldEmptyTensorWithExtractSliceOp
67 bool foldSingleUseOnly =
false)
69 foldSingleUseOnly(foldSingleUseOnly) {}
74 auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
79 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
85 sliceOp.getType().getElementType(),
86 sliceOp.getType().getEncoding());
93 bool foldSingleUseOnly =
false;
99 bool foldSingleUseOnly) {
100 patterns.
add<FoldEmptyTensorWithExtractSliceOp,
101 FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
102 FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacement).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with tensor.extract_slice, tensor.expand_shape and tensor.collapse_shape.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.