MLIR 22.0.0git
WalkPatternRewriteDriver.cpp
Go to the documentation of this file.
1//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implements mlir::walkAndApplyPatterns.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "mlir/IR/MLIRContext.h"
16#include "mlir/IR/Operation.h"
19#include "mlir/IR/Verifier.h"
20#include "mlir/IR/Visitors.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/ErrorHandling.h"
25
26#define DEBUG_TYPE "walk-rewriter"
27
28namespace mlir {
29
30// Find all reachable blocks in the region and add them to the visitedBlocks
31// set.
32static void findReachableBlocks(Region &region,
33 DenseSet<Block *> &reachableBlocks) {
34 Block *entryBlock = &region.front();
35 reachableBlocks.insert(entryBlock);
36 // Traverse the CFG and add all reachable blocks to the blockList.
37 SmallVector<Block *> worklist({entryBlock});
38 while (!worklist.empty()) {
39 Block *block = worklist.pop_back_val();
40 Operation *terminator = &block->back();
41 for (Block *successor : terminator->getSuccessors()) {
42 if (reachableBlocks.contains(successor))
43 continue;
44 worklist.push_back(successor);
45 reachableBlocks.insert(successor);
46 }
47 }
48}
49
50namespace {
51struct WalkAndApplyPatternsAction final
52 : tracing::ActionImpl<WalkAndApplyPatternsAction> {
53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
54 using ActionImpl::ActionImpl;
55 static constexpr StringLiteral tag = "walk-and-apply-patterns";
56 void print(raw_ostream &os) const override { os << tag; }
57};
58
59#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
60// Forwarding listener to guard against unsupported erasures of non-descendant
61// ops/blocks. Because we use walk-based pattern application, erasing the
62// op/block from the *next* iteration (e.g., a user of the visited op) is not
63// valid. Note that this is only used with expensive pattern API checks.
64struct ErasedOpsListener final : RewriterBase::ForwardingListener {
65 using RewriterBase::ForwardingListener::ForwardingListener;
66
67 void notifyOperationErased(Operation *op) override {
68 checkErasure(op);
69 ForwardingListener::notifyOperationErased(op);
70 }
71
72 void notifyBlockErased(Block *block) override {
73 checkErasure(block->getParentOp());
74 ForwardingListener::notifyBlockErased(block);
75 }
76
77 void checkErasure(Operation *op) const {
78 Operation *ancestorOp = op;
79 while (ancestorOp && ancestorOp != visitedOp)
80 ancestorOp = ancestorOp->getParentOp();
81
82 if (ancestorOp != visitedOp)
83 llvm::report_fatal_error(
84 "unsupported erasure in WalkPatternRewriter; "
85 "erasure is only supported for matched ops and their descendants");
86 }
87
88 Operation *visitedOp = nullptr;
89};
90#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
91} // namespace
92
95 RewriterBase::Listener *listener) {
96#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
97 if (failed(verify(op)))
98 llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
99#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
100
101 MLIRContext *ctx = op->getContext();
102 PatternRewriter rewriter(ctx);
103#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
104 ErasedOpsListener erasedListener(listener);
105 rewriter.setListener(&erasedListener);
106#else
107 rewriter.setListener(listener);
108#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
109
110 PatternApplicator applicator(patterns);
111 applicator.applyDefaultCostModel();
112
113 // Iterator on all reachable operations in the region.
114 // Also keep track if we visited the nested regions of the current op
115 // already to drive the post-order traversal.
116 struct RegionReachableOpIterator {
117 RegionReachableOpIterator(Region *region) : region(region) {
118 regionIt = region->begin();
119 if (regionIt != region->end())
120 blockIt = regionIt->begin();
121 if (!llvm::hasSingleElement(*region))
122 findReachableBlocks(*region, reachableBlocks);
123 }
124 // Advance the iterator to the next reachable operation.
125 void advance() {
126 assert(regionIt != region->end());
127 hasVisitedRegions = false;
128 if (blockIt == regionIt->end()) {
129 ++regionIt;
130 while (regionIt != region->end() &&
131 !reachableBlocks.contains(&*regionIt))
132 ++regionIt;
133 if (regionIt != region->end())
134 blockIt = regionIt->begin();
135 return;
136 }
137 ++blockIt;
138 if (blockIt != regionIt->end()) {
139 LDBG() << "Incrementing block iterator, next op: "
140 << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
141 }
142 }
143 // The region we're iterating over.
144 Region *region;
145 // The Block currently being iterated over.
146 Region::iterator regionIt;
147 // The Operation currently being iterated over.
148 Block::iterator blockIt;
149 // The set of blocks that are reachable in the current region.
150 DenseSet<Block *> reachableBlocks;
151 // Whether we've visited the nested regions of the current op already.
152 bool hasVisitedRegions = false;
153 };
154
155 // Worklist of regions to visit to drive the post-order traversal.
157
158 LDBG() << "Starting walk-based pattern rewrite driver";
159 ctx->executeAction<WalkAndApplyPatternsAction>(
160 [&] {
161 // Perform a post-order traversal of the regions, visiting each
162 // reachable operation.
163 for (Region &region : op->getRegions()) {
164 assert(worklist.empty());
165 if (region.empty())
166 continue;
167
168 // Prime the worklist with the entry block of this region.
169 worklist.push_back({&region});
170 while (!worklist.empty()) {
171 RegionReachableOpIterator &it = worklist.back();
172 if (it.regionIt == it.region->end()) {
173 // We're done with this region.
174 worklist.pop_back();
175 continue;
176 }
177 if (it.blockIt == it.regionIt->end()) {
178 // We're done with this block.
179 it.advance();
180 continue;
181 }
182 Operation *op = &*it.blockIt;
183 // If we haven't visited the nested regions of this op yet,
184 // enqueue them.
185 if (!it.hasVisitedRegions) {
186 it.hasVisitedRegions = true;
187 for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
188 if (nestedRegion.empty())
189 continue;
190 worklist.push_back({&nestedRegion});
191 }
192 }
193 // If we're not at the back of the worklist, we've enqueued some
194 // nested region for processing. We'll come back to this op later
195 // (post-order)
196 if (&it != &worklist.back())
197 continue;
198
199 // Preemptively increment the iterator, in case the current op
200 // would be erased.
201 it.advance();
202
203 LDBG() << "Visiting op: "
204 << OpWithFlags(op, OpPrintingFlags().skipRegions());
205#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
206 erasedListener.visitedOp = op;
207#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
208 if (succeeded(applicator.matchAndRewrite(op, rewriter)))
209 LDBG() << "\tOp matched and rewritten";
210 }
211 }
212 },
213 {op});
214
215#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
216 if (failed(verify(op)))
217 llvm::report_fatal_error(
218 "walk pattern rewriter result IR failed to verify");
219#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
220}
221
222} // namespace mlir
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:140
Operation & back()
Definition Block.h:152
This class represents a frozen set of patterns that can be processed by a pattern applicator.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition Builders.h:316
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
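A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream modifier" when printing an operation to a stream, e.g. `OpWithFlags(op, OpPrintingFlags().skipRegions())`.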
Definition Operation.h:1111
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
SuccessorRange getSuccessors()
Definition Operation.h:703
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a pattern can be applied or not, and hooks for if the pattern match was a success or failure.
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator end()
Definition Region.h:56
iterator begin()
Definition Region.h:55
BlockListType::iterator iterator
Definition Region.h:52
Include the generated interface declarations.
static void findReachableBlocks(Region &region, DenseSet< Block * > &reachableBlocks)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition Verifier.cpp:423
A listener that forwards all notifications to another listener.