MLIR  20.0.0git
WalkPatternRewriteDriver.cpp
Go to the documentation of this file.
1 //===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implements mlir::walkAndApplyPatterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "mlir/IR/MLIRContext.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/Verifier.h"
19 #include "mlir/IR/Visitors.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/ErrorHandling.h"
23 
24 #define DEBUG_TYPE "walk-rewriter"
25 
26 namespace mlir {
27 
28 namespace {
29 struct WalkAndApplyPatternsAction final
30  : tracing::ActionImpl<WalkAndApplyPatternsAction> {
31  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
32  using ActionImpl::ActionImpl;
33  static constexpr StringLiteral tag = "walk-and-apply-patterns";
34  void print(raw_ostream &os) const override { os << tag; }
35 };
36 
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Forwarding listener to guard against unsupported erasures of non-descendant
// ops/blocks. Because we use walk-based pattern application, erasing the
// op/block from the *next* iteration (e.g., a user of the visited op) is not
// valid. Note that this is only used with expensive pattern API checks.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
  // Inherit ForwardingListener(OpBuilder::Listener *) so that this listener
  // can be constructed around (and forward to) the user-provided listener.
  using RewriterBase::ForwardingListener::ForwardingListener;

  /// Validate that `op` may be erased, then forward the notification.
  void notifyOperationErased(Operation *op) override {
    checkErasure(op);
    ForwardingListener::notifyOperationErased(op);
  }

  /// Validate that `block` may be erased, then forward the notification.
  /// A block is checked through its parent op: it may be erased only if that
  /// op is `visitedOp` or nested inside it.
  void notifyBlockErased(Block *block) override {
    checkErasure(block->getParentOp());
    ForwardingListener::notifyBlockErased(block);
  }

  /// Abort via `llvm::report_fatal_error` unless `op` is `visitedOp` or is
  /// (transitively) nested inside it. Walks up the parent-op chain from `op`;
  /// a null `op` (e.g. a block without a parent op) is rejected unless
  /// `visitedOp` is itself null.
  void checkErasure(Operation *op) const {
    Operation *ancestorOp = op;
    while (ancestorOp && ancestorOp != visitedOp)
      ancestorOp = ancestorOp->getParentOp();

    if (ancestorOp != visitedOp)
      llvm::report_fatal_error(
          "unsupported erasure in WalkPatternRewriter; "
          "erasure is only supported for matched ops and their descendants");
  }

  /// Op currently being visited by the walk; the driver updates it before
  /// attempting to match each op.
  Operation *visitedOp = nullptr;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
69 } // namespace
70 
73  RewriterBase::Listener *listener) {
74 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
75  if (failed(verify(op)))
76  llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
77 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
78 
79  MLIRContext *ctx = op->getContext();
80  PatternRewriter rewriter(ctx);
81 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
82  ErasedOpsListener erasedListener(listener);
83  rewriter.setListener(&erasedListener);
84 #else
85  rewriter.setListener(listener);
86 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
87 
88  PatternApplicator applicator(patterns);
89  applicator.applyDefaultCostModel();
90 
91  ctx->executeAction<WalkAndApplyPatternsAction>(
92  [&] {
93  for (Region &region : op->getRegions()) {
94  region.walk([&](Operation *visitedOp) {
95  LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
96  llvm::dbgs(), OpPrintingFlags().skipRegions());
97  llvm::dbgs() << "\n";);
98 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
99  erasedListener.visitedOp = visitedOp;
100 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
101  if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
102  LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
103  }
104  });
105  }
106  },
107  {op});
108 
109 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
110  if (failed(verify(op)))
111  llvm::report_fatal_error(
112  "walk pattern rewriter result IR failed to verify");
113 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
114 }
115 
116 } // namespace mlir
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
This class represents a frozen set of patterns that can be processed by a pattern applicator.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:264
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:325
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a pattern can be applied or not, and hooks for if the pattern match was a success or failure.
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:425
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:464