22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/ErrorHandling.h"
26 #define DEBUG_TYPE "walk-rewriter"
35 reachableBlocks.insert(entryBlock);
38 while (!worklist.empty()) {
39 Block *block = worklist.pop_back_val();
42 if (reachableBlocks.contains(successor))
44 worklist.push_back(successor);
45 reachableBlocks.insert(successor);
51 struct WalkAndApplyPatternsAction final
52 : tracing::ActionImpl<WalkAndApplyPatternsAction> {
54 using ActionImpl::ActionImpl;
55 static constexpr StringLiteral tag =
"walk-and-apply-patterns";
56 void print(raw_ostream &os)
const override { os << tag; }
59 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
64 struct ErasedOpsListener final : RewriterBase::ForwardingListener {
67 void notifyOperationErased(Operation *op)
override {
69 ForwardingListener::notifyOperationErased(op);
72 void notifyBlockErased(Block *block)
override {
73 checkErasure(block->getParentOp());
74 ForwardingListener::notifyBlockErased(block);
77 void checkErasure(Operation *op)
const {
78 Operation *ancestorOp = op;
79 while (ancestorOp && ancestorOp != visitedOp)
80 ancestorOp = ancestorOp->getParentOp();
82 if (ancestorOp != visitedOp)
83 llvm::report_fatal_error(
84 "unsupported erasure in WalkPatternRewriter; "
85 "erasure is only supported for matched ops and their descendants");
88 Operation *visitedOp =
nullptr;
96 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
98 llvm::report_fatal_error(
"walk pattern rewriter input IR failed to verify");
103 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
104 ErasedOpsListener erasedListener(listener);
116 struct RegionReachableOpIterator {
117 RegionReachableOpIterator(
Region *region) : region(region) {
118 regionIt = region->
begin();
119 if (regionIt != region->
end())
120 blockIt = regionIt->begin();
121 if (!llvm::hasSingleElement(*region))
126 assert(regionIt != region->
end());
127 hasVisitedRegions =
false;
128 if (blockIt == regionIt->end()) {
130 while (regionIt != region->
end() &&
131 !reachableBlocks.contains(&*regionIt))
133 if (regionIt != region->
end())
134 blockIt = regionIt->begin();
138 if (blockIt != regionIt->end()) {
139 LDBG() <<
"Incrementing block iterator, next op: "
152 bool hasVisitedRegions =
false;
158 LDBG() <<
"Starting walk-based pattern rewrite driver";
164 assert(worklist.empty());
169 worklist.push_back({®ion});
170 while (!worklist.empty()) {
171 RegionReachableOpIterator &it = worklist.back();
172 if (it.regionIt == it.region->end()) {
177 if (it.blockIt == it.regionIt->end()) {
185 if (!it.hasVisitedRegions) {
186 it.hasVisitedRegions =
true;
188 if (nestedRegion.empty())
190 worklist.push_back({&nestedRegion});
196 if (&it != &worklist.back())
203 LDBG() <<
"Visiting op: "
205 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
206 erasedListener.visitedOp = op;
209 LDBG() <<
"\tOp matched and rewritten";
215 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
217 llvm::report_fatal_error(
218 "walk pattern rewriter result IR failed to verify");
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Block represents an ordered list of Operations.
OpListType::iterator iterator
This class represents a frozen set of patterns that can be processed by a pattern applicator.
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream modifier" to customize printing an operation with a stream using the operator<< overload.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
SuccessorRange getSuccessors()
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a pattern can be applied or not, and hooks for if the pattern match was a success or failure.
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType::iterator iterator
Include the generated interface declarations.
static void findReachableBlocks(Region &region, DenseSet< Block * > &reachableBlocks)
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
ForwardingListener(OpBuilder::Listener *listener)