MLIR  21.0.0git
GreedyPatternRewriteDriver.cpp
Go to the documentation of this file.
1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements mlir::applyPatternsGreedily.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_ostream.h"
30 
31 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
32 #include <random>
33 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
34 
35 using namespace mlir;
36 
37 #define DEBUG_TYPE "greedy-rewriter"
38 
39 namespace {
40 
41 //===----------------------------------------------------------------------===//
42 // Debugging Infrastructure
43 //===----------------------------------------------------------------------===//
44 
45 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
46 /// A helper struct that performs various "expensive checks" to detect broken
47 /// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
48 /// broken if:
49 /// * IR does not verify after pattern application / folding.
50 /// * Pattern returns "failure" but the IR has changed.
51 /// * Pattern returns "success" but the IR has not changed.
52 ///
53 /// This struct stores finger prints of ops to determine whether the IR has
54 /// changed or not.
55 struct ExpensiveChecks : public RewriterBase::ForwardingListener {
56  ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
57  : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
58 
59  /// Compute finger prints of the given op and its nested ops.
60  void computeFingerPrints(Operation *topLevel) {
61  this->topLevel = topLevel;
62  this->topLevelFingerPrint.emplace(topLevel);
63  topLevel->walk([&](Operation *op) {
64  fingerprints.try_emplace(op, op, /*includeNested=*/false);
65  });
66  }
67 
68  /// Clear all finger prints.
69  void clear() {
70  topLevel = nullptr;
71  topLevelFingerPrint.reset();
72  fingerprints.clear();
73  }
74 
75  void notifyRewriteSuccess() {
76  if (!topLevel)
77  return;
78 
79  // Make sure that the IR still verifies.
80  if (failed(verify(topLevel)))
81  llvm::report_fatal_error("IR failed to verify after pattern application");
82 
83  // Pattern application success => IR must have changed.
84  OperationFingerPrint afterFingerPrint(topLevel);
85  if (*topLevelFingerPrint == afterFingerPrint) {
86  // Note: Run "mlir-opt -debug" to see which pattern is broken.
87  llvm::report_fatal_error(
88  "pattern returned success but IR did not change");
89  }
90  for (const auto &it : fingerprints) {
91  // Skip top-level op, its finger print is never invalidated.
92  if (it.first == topLevel)
93  continue;
94  // Note: Finger print computation may crash when an op was erased
95  // without notifying the rewriter. (Run with ASAN to see where the op was
96  // erased; the op was probably erased directly, bypassing the rewriter
97  // API.) Finger print computation does may not crash if a new op was
98  // created at the same memory location. (But then the finger print should
99  // have changed.)
100  if (it.second !=
101  OperationFingerPrint(it.first, /*includeNested=*/false)) {
102  // Note: Run "mlir-opt -debug" to see which pattern is broken.
103  llvm::report_fatal_error("operation finger print changed");
104  }
105  }
106  }
107 
108  void notifyRewriteFailure() {
109  if (!topLevel)
110  return;
111 
112  // Pattern application failure => IR must not have changed.
113  OperationFingerPrint afterFingerPrint(topLevel);
114  if (*topLevelFingerPrint != afterFingerPrint) {
115  // Note: Run "mlir-opt -debug" to see which pattern is broken.
116  llvm::report_fatal_error("pattern returned failure but IR did change");
117  }
118  }
119 
120  void notifyFoldingSuccess() {
121  if (!topLevel)
122  return;
123 
124  // Make sure that the IR still verifies.
125  if (failed(verify(topLevel)))
126  llvm::report_fatal_error("IR failed to verify after folding");
127  }
128 
129 protected:
130  /// Invalidate the finger print of the given op, i.e., remove it from the map.
131  void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
132 
133  void notifyBlockErased(Block *block) override {
135 
136  // The block structure (number of blocks, types of block arguments, etc.)
137  // is part of the fingerprint of the parent op.
138  // TODO: The parent op fingerprint should also be invalidated when modifying
139  // the block arguments of a block, but we do not have a
140  // `notifyBlockModified` callback yet.
141  invalidateFingerPrint(block->getParentOp());
142  }
143 
144  void notifyOperationInserted(Operation *op,
145  OpBuilder::InsertPoint previous) override {
147  invalidateFingerPrint(op->getParentOp());
148  }
149 
150  void notifyOperationModified(Operation *op) override {
152  invalidateFingerPrint(op);
153  }
154 
155  void notifyOperationErased(Operation *op) override {
157  op->walk([this](Operation *op) { invalidateFingerPrint(op); });
158  }
159 
160  /// Operation finger prints to detect invalid pattern API usage. IR is checked
161  /// against these finger prints after pattern application to detect cases
162  /// where IR was modified directly, bypassing the rewriter API.
164 
165  /// Top-level operation of the current greedy rewrite.
166  Operation *topLevel = nullptr;
167 
168  /// Finger print of the top-level operation.
169  std::optional<OperationFingerPrint> topLevelFingerPrint;
170 };
171 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
172 
173 #ifndef NDEBUG
174 static Operation *getDumpRootOp(Operation *op) {
175  // Dump the parent op so that materialized constants are visible. If the op
176  // is a top-level op, dump it directly.
177  if (Operation *parentOp = op->getParentOp())
178  return parentOp;
179  return op;
180 }
181 static void logSuccessfulFolding(Operation *op) {
182  llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
183  op->dump();
184  llvm::dbgs() << "\n\n";
185 }
186 #endif // NDEBUG
187 
188 //===----------------------------------------------------------------------===//
189 // Worklist
190 //===----------------------------------------------------------------------===//
191 
192 /// A LIFO worklist of operations with efficient removal and set semantics.
193 ///
194 /// This class maintains a vector of operations and a mapping of operations to
195 /// positions in the vector, so that operations can be removed efficiently at
196 /// random. When an operation is removed, it is replaced with nullptr. Such
197 /// nullptr are skipped when pop'ing elements.
198 class Worklist {
199 public:
200  Worklist();
201 
202  /// Clear the worklist.
203  void clear();
204 
205  /// Return whether the worklist is empty.
206  bool empty() const;
207 
208  /// Push an operation to the end of the worklist, unless the operation is
209  /// already on the worklist.
210  void push(Operation *op);
211 
212  /// Pop the an operation from the end of the worklist. Only allowed on
213  /// non-empty worklists.
214  Operation *pop();
215 
216  /// Remove an operation from the worklist.
217  void remove(Operation *op);
218 
219  /// Reverse the worklist.
220  void reverse();
221 
222 protected:
223  /// The worklist of operations.
224  std::vector<Operation *> list;
225 
226  /// A mapping of operations to positions in `list`.
228 };
229 
230 Worklist::Worklist() { list.reserve(64); }
231 
232 void Worklist::clear() {
233  list.clear();
234  map.clear();
235 }
236 
237 bool Worklist::empty() const {
238  // Skip all nullptr.
239  return !llvm::any_of(list,
240  [](Operation *op) { return static_cast<bool>(op); });
241 }
242 
243 void Worklist::push(Operation *op) {
244  assert(op && "cannot push nullptr to worklist");
245  // Check to see if the worklist already contains this op.
246  if (!map.insert({op, list.size()}).second)
247  return;
248  list.push_back(op);
249 }
250 
251 Operation *Worklist::pop() {
252  assert(!empty() && "cannot pop from empty worklist");
253  // Skip and remove all trailing nullptr.
254  while (!list.back())
255  list.pop_back();
256  Operation *op = list.back();
257  list.pop_back();
258  map.erase(op);
259  // Cleanup: Remove all trailing nullptr.
260  while (!list.empty() && !list.back())
261  list.pop_back();
262  return op;
263 }
264 
265 void Worklist::remove(Operation *op) {
266  assert(op && "cannot remove nullptr from worklist");
267  auto it = map.find(op);
268  if (it != map.end()) {
269  assert(list[it->second] == op && "malformed worklist data structure");
270  list[it->second] = nullptr;
271  map.erase(it);
272  }
273 }
274 
275 void Worklist::reverse() {
276  std::reverse(list.begin(), list.end());
277  for (size_t i = 0, e = list.size(); i != e; ++i)
278  map[list[i]] = i;
279 }
280 
281 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
282 /// A worklist that pops elements at a random position. This worklist is for
283 /// testing/debugging purposes only. It can be used to ensure that lowering
284 /// pipelines work correctly regardless of the order in which ops are processed
285 /// by the GreedyPatternRewriteDriver.
286 class RandomizedWorklist : public Worklist {
287 public:
288  RandomizedWorklist() : Worklist() {
289  generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
290  }
291 
292  /// Pop a random non-empty op from the worklist.
293  Operation *pop() {
294  Operation *op = nullptr;
295  do {
296  assert(!list.empty() && "cannot pop from empty worklist");
297  int64_t pos = generator() % list.size();
298  op = list[pos];
299  list.erase(list.begin() + pos);
300  for (int64_t i = pos, e = list.size(); i < e; ++i)
301  map[list[i]] = i;
302  map.erase(op);
303  } while (!op);
304  return op;
305  }
306 
307 private:
308  std::minstd_rand0 generator;
309 };
310 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
311 
312 //===----------------------------------------------------------------------===//
313 // GreedyPatternRewriteDriver
314 //===----------------------------------------------------------------------===//
315 
316 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
317 /// applies the locally optimal patterns.
318 ///
319 /// This abstract class manages the worklist and contains helper methods for
320 /// rewriting ops on the worklist. Derived classes specify how ops are added
321 /// to the worklist in the beginning.
322 class GreedyPatternRewriteDriver : public RewriterBase::Listener {
323 protected:
324  explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
326  const GreedyRewriteConfig &config);
327 
328  /// Add the given operation to the worklist.
329  void addSingleOpToWorklist(Operation *op);
330 
331  /// Add the given operation and its ancestors to the worklist.
332  void addToWorklist(Operation *op);
333 
334  /// Notify the driver that the specified operation may have been modified
335  /// in-place. The operation is added to the worklist.
336  void notifyOperationModified(Operation *op) override;
337 
338  /// Notify the driver that the specified operation was inserted. Update the
339  /// worklist as needed: The operation is enqueued depending on scope and
340  /// strict mode.
341  void notifyOperationInserted(Operation *op,
342  OpBuilder::InsertPoint previous) override;
343 
344  /// Notify the driver that the specified operation was removed. Update the
345  /// worklist as needed: The operation and its children are removed from the
346  /// worklist.
347  void notifyOperationErased(Operation *op) override;
348 
349  /// Notify the driver that the specified operation was replaced. Update the
350  /// worklist as needed: New users are added enqueued.
351  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
352 
353  /// Process ops until the worklist is empty or `config.maxNumRewrites` is
354  /// reached. Return `true` if any IR was changed.
355  bool processWorklist();
356 
357  /// The pattern rewriter that is used for making IR modifications and is
358  /// passed to rewrite patterns.
359  PatternRewriter rewriter;
360 
361  /// The worklist for this transformation keeps track of the operations that
362  /// need to be (re)visited.
363 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364  RandomizedWorklist worklist;
365 #else
366  Worklist worklist;
367 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
368 
369  /// Configuration information for how to simplify.
371 
372  /// The list of ops we are restricting our rewrites to. These include the
373  /// supplied set of ops as well as new ops created while rewriting those ops
374  /// depending on `strictMode`. This set is not maintained when
375  /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
376  llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
377 
378 private:
379  /// Look over the provided operands for any defining operations that should
380  /// be re-added to the worklist. This function should be called when an
381  /// operation is modified or removed, as it may trigger further
382  /// simplifications.
383  void addOperandsToWorklist(Operation *op);
384 
385  /// Notify the driver that the given block was inserted.
386  void notifyBlockInserted(Block *block, Region *previous,
387  Region::iterator previousIt) override;
388 
389  /// Notify the driver that the given block is about to be removed.
390  void notifyBlockErased(Block *block) override;
391 
392  /// For debugging only: Notify the driver of a pattern match failure.
393  void
394  notifyMatchFailure(Location loc,
395  function_ref<void(Diagnostic &)> reasonCallback) override;
396 
397 #ifndef NDEBUG
398  /// A logger used to emit information during the application process.
399  llvm::ScopedPrinter logger{llvm::dbgs()};
400 #endif
401 
402  /// The low-level pattern applicator.
403  PatternApplicator matcher;
404 
405 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
406  ExpensiveChecks expensiveChecks;
407 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
408 };
409 } // namespace
410 
411 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
414  : rewriter(ctx), config(config), matcher(patterns)
415 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
416  // clang-format off
417  , expensiveChecks(
418  /*driver=*/this,
419  /*topLevel=*/config.getScope() ? config.getScope()->getParentOp()
420  : nullptr)
421 // clang-format on
422 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
423 {
424  // Apply a simple cost model based solely on pattern benefit.
425  matcher.applyDefaultCostModel();
426 
427  // Set up listener.
428 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
429  // Send IR notifications to the debug handler. This handler will then forward
430  // all notifications to this GreedyPatternRewriteDriver.
431  rewriter.setListener(&expensiveChecks);
432 #else
433  rewriter.setListener(this);
434 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
435 }
436 
437 bool GreedyPatternRewriteDriver::processWorklist() {
438 #ifndef NDEBUG
439  const char *logLineComment =
440  "//===-------------------------------------------===//\n";
441 
442  /// A utility function to log a process result for the given reason.
443  auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
444  logger.unindent();
445  logger.startLine() << "} -> " << result;
446  if (!msg.isTriviallyEmpty())
447  logger.getOStream() << " : " << msg;
448  logger.getOStream() << "\n";
449  };
450  auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
451  logResult(result, msg);
452  logger.startLine() << logLineComment;
453  };
454 #endif
455 
456  bool changed = false;
457  int64_t numRewrites = 0;
458  while (!worklist.empty() &&
459  (numRewrites < config.getMaxNumRewrites() ||
460  config.getMaxNumRewrites() == GreedyRewriteConfig::kNoLimit)) {
461  auto *op = worklist.pop();
462 
463  LLVM_DEBUG({
464  logger.getOStream() << "\n";
465  logger.startLine() << logLineComment;
466  logger.startLine() << "Processing operation : '" << op->getName() << "'("
467  << op << ") {\n";
468  logger.indent();
469 
470  // If the operation has no regions, just print it here.
471  if (op->getNumRegions() == 0) {
472  op->print(
473  logger.startLine(),
474  OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
475  logger.getOStream() << "\n\n";
476  }
477  });
478 
479  // If the operation is trivially dead - remove it.
480  if (isOpTriviallyDead(op)) {
481  rewriter.eraseOp(op);
482  changed = true;
483 
484  LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
485  continue;
486  }
487 
488  // Try to fold this op. Do not fold constant ops. That would lead to an
489  // infinite folding loop, as every constant op would be folded to an
490  // Attribute and then immediately be rematerialized as a constant op, which
491  // is then put on the worklist.
492  if (config.isFoldingEnabled() && !op->hasTrait<OpTrait::ConstantLike>()) {
493  SmallVector<OpFoldResult> foldResults;
494  if (succeeded(op->fold(foldResults))) {
495  LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
496 #ifndef NDEBUG
497  Operation *dumpRootOp = getDumpRootOp(op);
498 #endif // NDEBUG
499  if (foldResults.empty()) {
500  // Op was modified in-place.
501  notifyOperationModified(op);
502  changed = true;
503  LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
504 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
505  expensiveChecks.notifyFoldingSuccess();
506 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
507  continue;
508  }
509 
510  // Op results can be replaced with `foldResults`.
511  assert(foldResults.size() == op->getNumResults() &&
512  "folder produced incorrect number of results");
513  OpBuilder::InsertionGuard g(rewriter);
514  rewriter.setInsertionPoint(op);
515  SmallVector<Value> replacements;
516  bool materializationSucceeded = true;
517  for (auto [ofr, resultType] :
518  llvm::zip_equal(foldResults, op->getResultTypes())) {
519  if (auto value = dyn_cast<Value>(ofr)) {
520  assert(value.getType() == resultType &&
521  "folder produced value of incorrect type");
522  replacements.push_back(value);
523  continue;
524  }
525  // Materialize Attributes as SSA values.
526  Operation *constOp = op->getDialect()->materializeConstant(
527  rewriter, cast<Attribute>(ofr), resultType, op->getLoc());
528 
529  if (!constOp) {
530  // If materialization fails, cleanup any operations generated for
531  // the previous results.
532  llvm::SmallDenseSet<Operation *> replacementOps;
533  for (Value replacement : replacements) {
534  assert(replacement.use_empty() &&
535  "folder reused existing op for one result but constant "
536  "materialization failed for another result");
537  replacementOps.insert(replacement.getDefiningOp());
538  }
539  for (Operation *op : replacementOps) {
540  rewriter.eraseOp(op);
541  }
542 
543  materializationSucceeded = false;
544  break;
545  }
546 
547  assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
548  "materializeConstant produced op that is not a ConstantLike");
549  assert(constOp->getResultTypes()[0] == resultType &&
550  "materializeConstant produced incorrect result type");
551  replacements.push_back(constOp->getResult(0));
552  }
553 
554  if (materializationSucceeded) {
555  rewriter.replaceOp(op, replacements);
556  changed = true;
557  LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
558 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
559  expensiveChecks.notifyFoldingSuccess();
560 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
561  continue;
562  }
563  }
564  }
565 
566  // Try to match one of the patterns. The rewriter is automatically
567  // notified of any necessary changes, so there is nothing else to do
568  // here.
569  auto canApplyCallback = [&](const Pattern &pattern) {
570  LLVM_DEBUG({
571  logger.getOStream() << "\n";
572  logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
573  << op->getName() << " -> (";
574  llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
575  logger.getOStream() << ")' {\n";
576  logger.indent();
577  });
578  if (RewriterBase::Listener *listener = config.getListener())
579  listener->notifyPatternBegin(pattern, op);
580  return true;
581  };
582  function_ref<bool(const Pattern &)> canApply = canApplyCallback;
583  auto onFailureCallback = [&](const Pattern &pattern) {
584  LLVM_DEBUG(logResult("failure", "pattern failed to match"));
585  if (RewriterBase::Listener *listener = config.getListener())
586  listener->notifyPatternEnd(pattern, failure());
587  };
588  function_ref<void(const Pattern &)> onFailure = onFailureCallback;
589  auto onSuccessCallback = [&](const Pattern &pattern) {
590  LLVM_DEBUG(logResult("success", "pattern applied successfully"));
591  if (RewriterBase::Listener *listener = config.getListener())
592  listener->notifyPatternEnd(pattern, success());
593  return success();
594  };
595  function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
596 
597 #ifdef NDEBUG
598  // Optimization: PatternApplicator callbacks are not needed when running in
599  // optimized mode and without a listener.
600  if (!config.getListener()) {
601  canApply = nullptr;
602  onFailure = nullptr;
603  onSuccess = nullptr;
604  }
605 #endif // NDEBUG
606 
607 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
608  if (config.getScope()) {
609  expensiveChecks.computeFingerPrints(config.getScope()->getParentOp());
610  }
611  auto clearFingerprints =
612  llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
613 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
614 
615  LogicalResult matchResult =
616  matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
617 
618  if (succeeded(matchResult)) {
619  LLVM_DEBUG(logResultWithLine("success", "at least one pattern matched"));
620 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
621  expensiveChecks.notifyRewriteSuccess();
622 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
623  changed = true;
624  ++numRewrites;
625  } else {
626  LLVM_DEBUG(logResultWithLine("failure", "all patterns failed to match"));
627 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
628  expensiveChecks.notifyRewriteFailure();
629 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
630  }
631  }
632 
633  return changed;
634 }
635 
636 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
637  assert(op && "expected valid op");
638  // Gather potential ancestors while looking for a "scope" parent region.
639  SmallVector<Operation *, 8> ancestors;
640  Region *region = nullptr;
641  do {
642  ancestors.push_back(op);
643  region = op->getParentRegion();
644  if (config.getScope() == region) {
645  // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
646  for (Operation *op : ancestors)
647  addSingleOpToWorklist(op);
648  return;
649  }
650  if (region == nullptr)
651  return;
652  } while ((op = region->getParentOp()));
653 }
654 
655 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
656  if (config.getStrictness() == GreedyRewriteStrictness::AnyOp ||
657  strictModeFilteredOps.contains(op))
658  worklist.push(op);
659 }
660 
661 void GreedyPatternRewriteDriver::notifyBlockInserted(
662  Block *block, Region *previous, Region::iterator previousIt) {
663  if (RewriterBase::Listener *listener = config.getListener())
664  listener->notifyBlockInserted(block, previous, previousIt);
665 }
666 
667 void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
668  if (RewriterBase::Listener *listener = config.getListener())
669  listener->notifyBlockErased(block);
670 }
671 
672 void GreedyPatternRewriteDriver::notifyOperationInserted(
673  Operation *op, OpBuilder::InsertPoint previous) {
674  LLVM_DEBUG({
675  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
676  << ")\n";
677  });
678  if (RewriterBase::Listener *listener = config.getListener())
679  listener->notifyOperationInserted(op, previous);
680  if (config.getStrictness() == GreedyRewriteStrictness::ExistingAndNewOps)
681  strictModeFilteredOps.insert(op);
682  addToWorklist(op);
683 }
684 
685 void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
686  LLVM_DEBUG({
687  logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
688  << ")\n";
689  });
690  if (RewriterBase::Listener *listener = config.getListener())
691  listener->notifyOperationModified(op);
692  addToWorklist(op);
693 }
694 
695 void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
696  for (Value operand : op->getOperands()) {
697  // If this operand currently has at most 2 users, add its defining op to the
698  // worklist. Indeed, after the op is deleted, then the operand will have at
699  // most 1 user left. If it has 0 users left, it can be deleted too,
700  // and if it has 1 user left, there may be further canonicalization
701  // opportunities.
702  if (!operand)
703  continue;
704 
705  auto *defOp = operand.getDefiningOp();
706  if (!defOp)
707  continue;
708 
709  Operation *otherUser = nullptr;
710  bool hasMoreThanTwoUses = false;
711  for (auto user : operand.getUsers()) {
712  if (user == op || user == otherUser)
713  continue;
714  if (!otherUser) {
715  otherUser = user;
716  continue;
717  }
718  hasMoreThanTwoUses = true;
719  break;
720  }
721  if (hasMoreThanTwoUses)
722  continue;
723 
724  addToWorklist(defOp);
725  }
726 }
727 
728 void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
729  LLVM_DEBUG({
730  logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
731  << ")\n";
732  });
733 
734 #ifndef NDEBUG
735  // Only ops that are within the configured scope are added to the worklist of
736  // the greedy pattern rewriter. Moreover, the parent op of the scope region is
737  // the part of the IR that is taken into account for the "expensive checks".
738  // A greedy pattern rewrite is not allowed to erase the parent op of the scope
739  // region, as that would break the worklist handling and the expensive checks.
740  if (Region *scope = config.getScope(); scope->getParentOp() == op)
741  llvm_unreachable(
742  "scope region must not be erased during greedy pattern rewrite");
743 #endif // NDEBUG
744 
745  if (RewriterBase::Listener *listener = config.getListener())
746  listener->notifyOperationErased(op);
747 
748  addOperandsToWorklist(op);
749  worklist.remove(op);
750 
751  if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
752  strictModeFilteredOps.erase(op);
753 }
754 
755 void GreedyPatternRewriteDriver::notifyOperationReplaced(
756  Operation *op, ValueRange replacement) {
757  LLVM_DEBUG({
758  logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
759  << ")\n";
760  });
761  if (RewriterBase::Listener *listener = config.getListener())
762  listener->notifyOperationReplaced(op, replacement);
763 }
764 
765 void GreedyPatternRewriteDriver::notifyMatchFailure(
766  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
767  LLVM_DEBUG({
768  Diagnostic diag(loc, DiagnosticSeverity::Remark);
769  reasonCallback(diag);
770  logger.startLine() << "** Match Failure : " << diag.str() << "\n";
771  });
772  if (RewriterBase::Listener *listener = config.getListener())
773  listener->notifyMatchFailure(loc, reasonCallback);
774 }
775 
776 //===----------------------------------------------------------------------===//
777 // RegionPatternRewriteDriver
778 //===----------------------------------------------------------------------===//
779 
780 namespace {
781 /// This driver simplfies all ops in a region.
782 class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
783 public:
784  explicit RegionPatternRewriteDriver(MLIRContext *ctx,
787  Region &regions);
788 
789  /// Simplify ops inside `region` and simplify the region itself. Return
790  /// success if the transformation converged.
791  LogicalResult simplify(bool *changed) &&;
792 
793 private:
794  /// The region that is simplified.
795  Region &region;
796 };
797 } // namespace
798 
799 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
801  const GreedyRewriteConfig &config, Region &region)
802  : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
803  // Populate strict mode ops.
804  if (config.getStrictness() != GreedyRewriteStrictness::AnyOp) {
805  region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
806  }
807 }
808 
809 namespace {
810 class GreedyPatternRewriteIteration
811  : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
812 public:
813  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
814  GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
815  : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
816  iteration(iteration) {}
817  static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
818  void print(raw_ostream &os) const override {
819  os << "GreedyPatternRewriteIteration(" << iteration << ")";
820  }
821 
822 private:
823  int64_t iteration = 0;
824 };
825 } // namespace
826 
827 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
828  bool continueRewrites = false;
829  int64_t iteration = 0;
830  MLIRContext *ctx = rewriter.getContext();
831  do {
832  // Check if the iteration limit was reached.
833  if (++iteration > config.getMaxIterations() &&
834  config.getMaxIterations() != GreedyRewriteConfig::kNoLimit)
835  break;
836 
837  // New iteration: start with an empty worklist.
838  worklist.clear();
839 
840  // `OperationFolder` CSE's constant ops (and may move them into parents
841  // regions to enable more aggressive CSE'ing).
842  OperationFolder folder(ctx, this);
843  auto insertKnownConstant = [&](Operation *op) {
844  // Check for existing constants when populating the worklist. This avoids
845  // accidentally reversing the constant order during processing.
846  Attribute constValue;
847  if (matchPattern(op, m_Constant(&constValue)))
848  if (!folder.insertKnownConstant(op, constValue))
849  return true;
850  return false;
851  };
852 
853  if (!config.getUseTopDownTraversal()) {
854  // Add operations to the worklist in postorder.
855  region.walk([&](Operation *op) {
856  if (!config.isConstantCSEEnabled() || !insertKnownConstant(op))
857  addToWorklist(op);
858  });
859  } else {
860  // Add all nested operations to the worklist in preorder.
861  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
862  if (!config.isConstantCSEEnabled() || !insertKnownConstant(op)) {
863  addToWorklist(op);
864  return WalkResult::advance();
865  }
866  return WalkResult::skip();
867  });
868 
869  // Reverse the list so our pop-back loop processes them in-order.
870  worklist.reverse();
871  }
872 
873  ctx->executeAction<GreedyPatternRewriteIteration>(
874  [&] {
875  continueRewrites = processWorklist();
876 
877  // After applying patterns, make sure that the CFG of each of the
878  // regions is kept up to date.
879  if (config.getRegionSimplificationLevel() !=
880  GreedySimplifyRegionLevel::Disabled) {
881  continueRewrites |= succeeded(simplifyRegions(
882  rewriter, region,
883  /*mergeBlocks=*/config.getRegionSimplificationLevel() ==
884  GreedySimplifyRegionLevel::Aggressive));
885  }
886  },
887  {&region}, iteration);
888  } while (continueRewrites);
889 
890  if (changed)
891  *changed = iteration > 1;
892 
893  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
894  return success(!continueRewrites);
895 }
896 
897 LogicalResult
901  // The op that owns `region` must be known to be isolated from above to
902  // prevent performing canonicalizations on operations defined at or above
903  // `region`.
904  assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
905  "patterns can only be applied to operations IsolatedFromAbove");
906 
907  // Default the rewrite scope to `region` if the config does not provide one.
908  if (!config.getScope())
909  config.setScope(&region);
910 
911 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
912  if (failed(verify(config.getScope()->getParentOp())))
913  llvm::report_fatal_error(
914  "greedy pattern rewriter input IR failed to verify");
915 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
916 
917  // Run the region driver; it fails if the rewrite did not converge.
918  RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
919  region);
920  LogicalResult converged = std::move(driver).simplify(changed);
921  LLVM_DEBUG(if (failed(converged)) {
922  llvm::dbgs() << "The pattern rewrite did not converge after scanning "
923  << config.getMaxIterations() << " times\n";
924  });
925  return converged;
926 }
927 
928 //===----------------------------------------------------------------------===//
929 // MultiOpPatternRewriteDriver
930 //===----------------------------------------------------------------------===//
931 
932 namespace {
933 /// This driver simplifies a list of ops.
934 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
935 public:
936  explicit MultiOpPatternRewriteDriver(
939  llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
940 
941  /// Simplify `ops`. Return `success` if the transformation converged.
942  LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
943 
944 private:
945  void notifyOperationErased(Operation *op) override {
946  GreedyPatternRewriteDriver::notifyOperationErased(op);
947  if (survivingOps)
948  survivingOps->erase(op);
949  }
950 
951  /// An optional set of ops that survived the rewrite. This set is populated
952  /// by the constructor with the initially provided list of ops; erased ops
953  /// are removed from it in `notifyOperationErased`.
954  llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
955 };
956 } // namespace
957 
/// Construct the driver. `ops` is recorded in `strictModeFilteredOps` unless
/// the configured strictness is `AnyOp`. If `survivingOps` is non-null, it is
/// cleared and then filled with `ops`.
958 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
961  llvm::SmallDenseSet<Operation *, 4> *survivingOps)
962  : GreedyPatternRewriteDriver(ctx, patterns, config),
963  survivingOps(survivingOps) {
964  if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
965  strictModeFilteredOps.insert_range(ops);
966 
967  if (survivingOps) {
968  survivingOps->clear();
969  survivingOps->insert_range(ops);
970  }
971 }
972 
973 LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
974  bool *changed) && {
975  // Seed the worklist with the given ops.
976  for (Operation *op : ops)
977  addSingleOpToWorklist(op);
978 
979  // Process the worklist; its result is reported through `changed`.
980  bool result = processWorklist();
981  if (changed)
982  *changed = result;
983 
984  return success(worklist.empty());
985 }
986 
987 /// Find the region that is the closest common ancestor of all given ops.
988 ///
989 /// Note: This function returns `nullptr` if there is a top-level op among the
990 /// given list of ops.
992  assert(!ops.empty() && "expected at least one op");
993  // Fast path in case there is only one op.
994  if (ops.size() == 1)
995  return ops.front()->getParentRegion();
996 
997  Region *region = ops.front()->getParentRegion();
998  ops = ops.drop_front();
999  int sz = ops.size();
1000  llvm::BitVector remainingOps(sz, true);
1001  while (region) {
1002  int pos = -1;
1003  // Iterate over the ops not yet known to be nested under `region`.
1004  while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
1005  // Is this op (or one of its ancestors) contained in `region`?
1006  if (region->findAncestorOpInRegion(*ops[pos]))
1007  remainingOps.reset(pos);
1008  }
1009  if (remainingOps.none())
1010  break;
1011  region = region->getParentRegion();
1012  }
1013  return region;
1014 }
1015 
1018  GreedyRewriteConfig config, bool *changed, bool *allErased) {
1019  if (ops.empty()) {
1020  if (changed)
1021  *changed = false;
1022  if (allErased)
1023  *allErased = true;
1024  return success();
1025  }
1026 
1027  // Determine scope of rewrite.
1028  if (!config.getScope()) {
1029  // Compute scope if none was provided. The scope will remain `nullptr` if
1030  // there is a top-level op among `ops`.
1031  config.setScope(findCommonAncestor(ops));
1032  } else {
1033  // If a scope was provided, assert (debug builds only) that all ops are in it.
1034 #ifndef NDEBUG
1035  bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
1036  return static_cast<bool>(config.getScope()->findAncestorOpInRegion(*op));
1037  });
1038  assert(allOpsInScope && "ops must be within the specified scope");
1039 #endif // NDEBUG
1040  }
1041 
1042 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1043  if (config.getScope() && failed(verify(config.getScope()->getParentOp())))
1044  llvm::report_fatal_error(
1045  "greedy pattern rewriter input IR failed to verify");
1046 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1047 
1048  // Start the pattern driver; track surviving ops only if `allErased` is set.
1049  llvm::SmallDenseSet<Operation *, 4> surviving;
1050  MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
1051  config, ops,
1052  allErased ? &surviving : nullptr);
1053  LogicalResult converged = std::move(driver).simplify(ops, changed);
1054  if (allErased)
1055  *allErased = surviving.empty();
1056  LLVM_DEBUG(if (failed(converged)) {
1057  llvm::dbgs() << "The pattern rewrite did not converge after "
1058  << config.getMaxNumRewrites() << " rewrites";
1059  });
1060  return converged;
1061 }
static Region * findCommonAncestor(ArrayRef< Operation * > ops)
Find the region that is the closest common ancestor of all given ops.
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
static Operation * findAncestorOpInRegion(Region *region, Operation *op)
Return the ancestor op in the region or nullptr if the region is not an ancestor of the op.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:264
This class represents a saved insertion point.
Definition: Builders.h:325
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
A utility class for folding operations, and unifying duplicated constants generated along the way.
Definition: FoldUtils.h:33
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this r...
Definition: Region.cpp:168
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
MLIRContext * getContext()
Return the context this region is inserted in.
Definition: Region.cpp:24
BlockListType::iterator iterator
Definition: Region.h:52
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition: Region.h:285
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
CRTP Implementation of an action.
Definition: Action.h:76
virtual void print(raw_ostream &os) const
Definition: Action.h:50
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
@ AnyOp
No restrictions wrt. which ops are processed.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:425
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Definition: PatternMatch.h:431
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:444
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:457
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:440