12 #include "llvm/Support/Debug.h"
19 template <
typename ReshapeOp>
22 bool foldSingleUseOnly =
false)
24 foldSingleUseOnly(foldSingleUseOnly) {}
26 LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
29 auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
34 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
41 !llvm::hasSingleElement(resultShapes))
47 loc, resultShapes[0], reshapeOp.getResultType().getElementType());
48 if (emptyTensor.
getType() != reshapeOp.getResultType()) {
50 reshapeOp, reshapeOp.getResultType(), emptyTensor);
52 rewriter.
replaceOp(reshapeOp, emptyTensor);
58 bool foldSingleUseOnly =
false;
63 struct FoldEmptyTensorWithExtractSliceOp
67 bool foldSingleUseOnly =
false)
69 foldSingleUseOnly(foldSingleUseOnly) {}
71 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
74 auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
79 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
85 sliceOp.getType().getElementType(),
86 sliceOp.getType().getEncoding());
93 bool foldSingleUseOnly =
false;
101 LogicalResult matchAndRewrite(PackOp packOp,
104 auto emptyOp = packOp.getSource().getDefiningOp<EmptyOp>();
110 if (packOp.getPaddingValue())
114 rewriter.
replaceOp(packOp, packOp.getDest());
125 LogicalResult matchAndRewrite(UnPackOp unPackOp,
128 auto emptyOp = unPackOp.getSource().getDefiningOp<EmptyOp>();
133 rewriter.
replaceOp(unPackOp, unPackOp.getDest());
143 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
145 auto concatOperands = concatOp.getInputs();
146 if (concatOperands.empty()) {
149 auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
153 auto isDefinedByEmptyOp = [](
Value v) ->
bool {
154 return v.getDefiningOp<tensor::EmptyOp>();
156 if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
158 concatOp,
"not all operands are defined by an empty op");
161 if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) {
163 "failed to get result shape");
166 concatOp, resultShape[0], concatOp.getResultType().getElementType());
174 bool foldSingleUseOnly) {
175 patterns.
add<FoldEmptyTensorWithExtractSliceOp,
176 FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
177 FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
179 patterns.
add<FoldConcatsOfEmpty, FoldEmptyTensorWithPackOp,
180 FoldEmptyTensorWithUnPackOp>(patterns.
getContext(),
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...