18template <
typename ReshapeOp>
20 FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
21 bool foldSingleUseOnly =
false)
22 : OpRewritePattern<ReshapeOp>(ctx, benefit),
23 foldSingleUseOnly(foldSingleUseOnly) {}
25 LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
26 PatternRewriter &rewriter)
const override {
28 auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
33 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
37 Location loc = reshapeOp.getLoc();
40 !llvm::hasSingleElement(resultShapes))
46 EmptyOp::create(rewriter, loc, resultShapes[0],
47 reshapeOp.getResultType().getElementType());
48 if (emptyTensor.
getType() != reshapeOp.getResultType()) {
50 reshapeOp, reshapeOp.getResultType(), emptyTensor);
52 rewriter.
replaceOp(reshapeOp, emptyTensor);
58 bool foldSingleUseOnly =
false;
63struct FoldEmptyTensorWithExtractSliceOp
65 FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
66 PatternBenefit benefit = 1,
67 bool foldSingleUseOnly =
false)
68 : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
69 foldSingleUseOnly(foldSingleUseOnly) {}
71 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
72 PatternRewriter &rewriter)
const override {
74 auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
79 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
84 auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
85 sliceOp.getType().getElementType(),
86 sliceOp.getType().getEncoding());
93 bool foldSingleUseOnly =
false;
98 using OpRewritePattern<ConcatOp>::OpRewritePattern;
100 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
101 PatternRewriter &rewriter)
const override {
102 auto concatOperands = concatOp.getInputs();
103 if (concatOperands.empty()) {
106 auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
110 auto isDefinedByEmptyOp = [](Value v) ->
bool {
111 return v.getDefiningOp<tensor::EmptyOp>();
113 if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
115 concatOp,
"not all operands are defined by an empty op");
117 SmallVector<SmallVector<OpFoldResult>> resultShape;
118 if (
failed(concatOp.reifyResultShapes(rewriter, resultShape))) {
120 "failed to get result shape");
123 concatOp, resultShape[0], concatOp.getResultType().
getElementType());
131 bool foldSingleUseOnly) {
132 patterns.add<FoldEmptyTensorWithExtractSliceOp,
133 FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
134 FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
135 patterns.getContext(), 1, foldSingleUseOnly);
static Type getElementType(Type type)
Determine the element type of type.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type getType() const
Return the type of this value.
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...