MLIR  22.0.0git
WalkPatternRewriteDriver.cpp
Go to the documentation of this file.
1 //===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implements mlir::walkAndApplyPatterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "mlir/IR/MLIRContext.h"
16 #include "mlir/IR/Operation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/Verifier.h"
20 #include "mlir/IR/Visitors.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/ErrorHandling.h"
25 
26 #define DEBUG_TYPE "walk-rewriter"
27 
28 namespace mlir {
29 
30 // Find all reachable blocks in the region and add them to the visitedBlocks
31 // set.
32 static void findReachableBlocks(Region &region,
33  DenseSet<Block *> &reachableBlocks) {
34  Block *entryBlock = &region.front();
35  reachableBlocks.insert(entryBlock);
36  // Traverse the CFG and add all reachable blocks to the blockList.
37  SmallVector<Block *> worklist({entryBlock});
38  while (!worklist.empty()) {
39  Block *block = worklist.pop_back_val();
40  Operation *terminator = &block->back();
41  for (Block *successor : terminator->getSuccessors()) {
42  if (reachableBlocks.contains(successor))
43  continue;
44  worklist.push_back(successor);
45  reachableBlocks.insert(successor);
46  }
47  }
48 }
49 
50 namespace {
51 struct WalkAndApplyPatternsAction final
52  : tracing::ActionImpl<WalkAndApplyPatternsAction> {
53  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
54  using ActionImpl::ActionImpl;
55  static constexpr StringLiteral tag = "walk-and-apply-patterns";
56  void print(raw_ostream &os) const override { os << tag; }
57 };
58 
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Forwarding listener that guards against unsupported erasures of
// non-descendant ops/blocks. Because the driver applies patterns during a
// walk, erasing an op/block that a *later* iteration would visit (e.g. a user
// of the visited op) is not valid. Only the currently visited op
// (`visitedOp`) and ops/blocks nested inside it may be erased; any other
// erasure is reported as a fatal error. Every notification is still forwarded
// to the wrapped listener. This listener is only installed when expensive
// pattern API checks are enabled.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
  // Construct from the listener to forward notifications to. The driver
  // relies on this to build `ErasedOpsListener erasedListener(listener)`.
  using RewriterBase::ForwardingListener::ForwardingListener;

  void notifyOperationErased(Operation *op) override {
    checkErasure(op);
    ForwardingListener::notifyOperationErased(op);
  }

  void notifyBlockErased(Block *block) override {
    // A block is checked through the op owning its region.
    checkErasure(block->getParentOp());
    ForwardingListener::notifyBlockErased(block);
  }

  // Report a fatal error unless `op` is `visitedOp` or one of its (transitive)
  // descendants, found by walking up the parent-op chain.
  void checkErasure(Operation *op) const {
    Operation *ancestorOp = op;
    while (ancestorOp && ancestorOp != visitedOp)
      ancestorOp = ancestorOp->getParentOp();

    if (ancestorOp != visitedOp)
      llvm::report_fatal_error(
          "unsupported erasure in WalkPatternRewriter; "
          "erasure is only supported for matched ops and their descendants");
  }

  // The op currently handed to the pattern applicator; the driver updates it
  // before each match attempt.
  Operation *visitedOp = nullptr;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
91 } // namespace
92 
95  RewriterBase::Listener *listener) {
96 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
97  if (failed(verify(op)))
98  llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
99 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
100 
101  MLIRContext *ctx = op->getContext();
102  PatternRewriter rewriter(ctx);
103 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
104  ErasedOpsListener erasedListener(listener);
105  rewriter.setListener(&erasedListener);
106 #else
107  rewriter.setListener(listener);
108 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
109 
110  PatternApplicator applicator(patterns);
111  applicator.applyDefaultCostModel();
112 
113  // Iterator on all reachable operations in the region.
114  // Also keep track if we visited the nested regions of the current op
115  // already to drive the post-order traversal.
116  struct RegionReachableOpIterator {
117  RegionReachableOpIterator(Region *region) : region(region) {
118  regionIt = region->begin();
119  if (regionIt != region->end())
120  blockIt = regionIt->begin();
121  if (!llvm::hasSingleElement(*region))
122  findReachableBlocks(*region, reachableBlocks);
123  }
124  // Advance the iterator to the next reachable operation.
125  void advance() {
126  assert(regionIt != region->end());
127  hasVisitedRegions = false;
128  if (blockIt == regionIt->end()) {
129  ++regionIt;
130  while (regionIt != region->end() &&
131  !reachableBlocks.contains(&*regionIt))
132  ++regionIt;
133  if (regionIt != region->end())
134  blockIt = regionIt->begin();
135  return;
136  }
137  ++blockIt;
138  if (blockIt != regionIt->end()) {
139  LDBG() << "Incrementing block iterator, next op: "
140  << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
141  }
142  }
143  // The region we're iterating over.
144  Region *region;
145  // The Block currently being iterated over.
146  Region::iterator regionIt;
147  // The Operation currently being iterated over.
148  Block::iterator blockIt;
149  // The set of blocks that are reachable in the current region.
150  DenseSet<Block *> reachableBlocks;
151  // Whether we've visited the nested regions of the current op already.
152  bool hasVisitedRegions = false;
153  };
154 
155  // Worklist of regions to visit to drive the post-order traversal.
157 
158  LDBG() << "Starting walk-based pattern rewrite driver";
159  ctx->executeAction<WalkAndApplyPatternsAction>(
160  [&] {
161  // Perform a post-order traversal of the regions, visiting each
162  // reachable operation.
163  for (Region &region : op->getRegions()) {
164  assert(worklist.empty());
165  if (region.empty())
166  continue;
167 
168  // Prime the worklist with the entry block of this region.
169  worklist.push_back({&region});
170  while (!worklist.empty()) {
171  RegionReachableOpIterator &it = worklist.back();
172  if (it.regionIt == it.region->end()) {
173  // We're done with this region.
174  worklist.pop_back();
175  continue;
176  }
177  if (it.blockIt == it.regionIt->end()) {
178  // We're done with this block.
179  it.advance();
180  continue;
181  }
182  Operation *op = &*it.blockIt;
183  // If we haven't visited the nested regions of this op yet,
184  // enqueue them.
185  if (!it.hasVisitedRegions) {
186  it.hasVisitedRegions = true;
187  for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
188  if (nestedRegion.empty())
189  continue;
190  worklist.push_back({&nestedRegion});
191  }
192  }
193  // If we're not at the back of the worklist, we've enqueued some
194  // nested region for processing. We'll come back to this op later
195  // (post-order)
196  if (&it != &worklist.back())
197  continue;
198 
199  // Preemptively increment the iterator, in case the current op
200  // would be erased.
201  it.advance();
202 
203  LDBG() << "Visiting op: "
204  << OpWithFlags(op, OpPrintingFlags().skipRegions());
205 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
206  erasedListener.visitedOp = op;
207 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
208  if (succeeded(applicator.matchAndRewrite(op, rewriter)))
209  LDBG() << "\tOp matched and rewritten";
210  }
211  }
212  },
213  {op});
214 
215 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
216  if (failed(verify(op)))
217  llvm::report_fatal_error(
218  "walk pattern rewriter result IR failed to verify");
219 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
220 }
221 
222 } // namespace mlir
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
Operation & back()
Definition: Block.h:152
This class represents a frozen set of patterns that can be processed by a pattern applicator.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:274
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream modifier" to customize printing an operation with a stream using the operator<< overload.
Definition: Operation.h:1111
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
SuccessorRange getSuccessors()
Definition: Operation.h:703
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a pattern can be applied or not, and hooks for if the pattern match was a success or failure.
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:783
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
static void findReachableBlocks(Region &region, DenseSet< Block * > &reachableBlocks)
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:422