MLIR  20.0.0git
RemoveShapeConstraints.cpp
Go to the documentation of this file.
1 //===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 
16 namespace mlir {
17 #define GEN_PASS_DEF_REMOVESHAPECONSTRAINTS
18 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
19 } // namespace mlir
20 
21 using namespace mlir;
22 
23 namespace {
24 /// Removal patterns.
25 class RemoveCstrBroadcastableOp
26  : public OpRewritePattern<shape::CstrBroadcastableOp> {
27 public:
29 
30  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
31  PatternRewriter &rewriter) const override {
32  rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
33  return success();
34  }
35 };
36 
37 class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> {
38 public:
40 
41  LogicalResult matchAndRewrite(shape::CstrEqOp op,
42  PatternRewriter &rewriter) const override {
43  rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
44  return success();
45  }
46 };
47 
48 /// Removal pass.
49 class RemoveShapeConstraintsPass
50  : public impl::RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> {
51 
52  void runOnOperation() override {
53  MLIRContext &ctx = getContext();
54 
55  RewritePatternSet patterns(&ctx);
57 
58  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
59  }
60 };
61 
62 } // namespace
63 
65  patterns.add<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(
66  patterns.getContext());
67 }
68 
69 std::unique_ptr<OperationPass<func::FuncOp>>
71  return std::make_unique<RemoveShapeConstraintsPass>();
72 }
static MLIRContext * getContext(OpFoldResult val)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Include the generated interface declarations.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner.
void populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns)
std::unique_ptr< OperationPass< func::FuncOp > > createRemoveShapeConstraintsPass()
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Definition: PatternMatch.h:362