MLIR  22.0.0git
GreedyPatternRewriteDriver.cpp
Go to the documentation of this file.
1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements mlir::applyPatternsGreedily.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "mlir/Config/mlir-config.h"
16 #include "mlir/IR/Action.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Operation.h"
20 #include "mlir/IR/Verifier.h"
25 #include "llvm/ADT/BitVector.h"
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/ScopeExit.h"
28 #include "llvm/Support/DebugLog.h"
29 #include "llvm/Support/ScopedPrinter.h"
30 #include "llvm/Support/raw_ostream.h"
31 
32 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
33 #include <random>
34 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
35 
36 using namespace mlir;
37 
38 #define DEBUG_TYPE "greedy-rewriter"
39 
40 namespace {
41 
42 //===----------------------------------------------------------------------===//
43 // Debugging Infrastructure
44 //===----------------------------------------------------------------------===//
45 
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// A helper struct that performs various "expensive checks" to detect broken
/// rewrite patterns that use the rewriter API incorrectly. A rewrite pattern is
/// broken if:
/// * IR does not verify after pattern application / folding.
/// * Pattern returns "failure" but the IR has changed.
/// * Pattern returns "success" but the IR has not changed.
///
/// This struct stores finger prints of ops to determine whether the IR has
/// changed or not. It is installed as the rewriter's listener in front of the
/// driver: every IR notification first invalidates the affected finger prints
/// and is then forwarded to the wrapped listener (the driver), so the driver
/// still observes every change.
struct ExpensiveChecks : public RewriterBase::ForwardingListener {
  ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
      : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}

  /// Compute finger prints of the given op and its nested ops.
  void computeFingerPrints(Operation *topLevel) {
    this->topLevel = topLevel;
    this->topLevelFingerPrint.emplace(topLevel);
    topLevel->walk([&](Operation *op) {
      fingerprints.try_emplace(op, op, /*includeNested=*/false);
    });
  }

  /// Clear all finger prints.
  void clear() {
    topLevel = nullptr;
    topLevelFingerPrint.reset();
    fingerprints.clear();
  }

  /// Checks run after a pattern reported success: the IR must verify, the
  /// top-level finger print must have changed, and every op whose finger print
  /// was not invalidated through a notification must be unchanged.
  void notifyRewriteSuccess() {
    if (!topLevel)
      return;

    // Make sure that the IR still verifies.
    if (failed(verify(topLevel)))
      llvm::report_fatal_error("IR failed to verify after pattern application");

    // Pattern application success => IR must have changed.
    OperationFingerPrint afterFingerPrint(topLevel);
    if (*topLevelFingerPrint == afterFingerPrint) {
      // Note: Run "mlir-opt -debug" to see which pattern is broken.
      llvm::report_fatal_error(
          "pattern returned success but IR did not change");
    }
    for (const auto &it : fingerprints) {
      // Skip top-level op, its finger print is never invalidated.
      if (it.first == topLevel)
        continue;
      // Note: Finger print computation may crash when an op was erased
      // without notifying the rewriter. (Run with ASAN to see where the op was
      // erased; the op was probably erased directly, bypassing the rewriter
      // API.) Finger print computation may not crash if a new op was
      // created at the same memory location. (But then the finger print should
      // have changed.)
      if (it.second !=
          OperationFingerPrint(it.first, /*includeNested=*/false)) {
        // Note: Run "mlir-opt -debug" to see which pattern is broken.
        llvm::report_fatal_error("operation finger print changed");
      }
    }
  }

  /// Checks run after all patterns failed: the IR must be unchanged.
  void notifyRewriteFailure() {
    if (!topLevel)
      return;

    // Pattern application failure => IR must not have changed.
    OperationFingerPrint afterFingerPrint(topLevel);
    if (*topLevelFingerPrint != afterFingerPrint) {
      // Note: Run "mlir-opt -debug" to see which pattern is broken.
      llvm::report_fatal_error("pattern returned failure but IR did change");
    }
  }

  /// Checks run after a successful fold: the IR must still verify.
  void notifyFoldingSuccess() {
    if (!topLevel)
      return;

    // Make sure that the IR still verifies.
    if (failed(verify(topLevel)))
      llvm::report_fatal_error("IR failed to verify after folding");
  }

protected:
  /// Invalidate the finger print of the given op, i.e., remove it from the map.
  void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }

  void notifyBlockErased(Block *block) override {
    // Forward first so the driver keeps seeing every notification.
    RewriterBase::ForwardingListener::notifyBlockErased(block);

    // The block structure (number of blocks, types of block arguments, etc.)
    // is part of the fingerprint of the parent op.
    // TODO: The parent op fingerprint should also be invalidated when modifying
    // the block arguments of a block, but we do not have a
    // `notifyBlockModified` callback yet.
    invalidateFingerPrint(block->getParentOp());
  }

  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override {
    RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
    // Inserting an op changes the nested structure of its parent.
    invalidateFingerPrint(op->getParentOp());
  }

  void notifyOperationModified(Operation *op) override {
    RewriterBase::ForwardingListener::notifyOperationModified(op);
    invalidateFingerPrint(op);
  }

  void notifyOperationErased(Operation *op) override {
    RewriterBase::ForwardingListener::notifyOperationErased(op);
    // The erased op and everything nested in it disappear; drop all of their
    // finger prints so that they are not recomputed on freed memory.
    op->walk([this](Operation *nested) { invalidateFingerPrint(nested); });
  }

  /// Operation finger prints to detect invalid pattern API usage. IR is checked
  /// against these finger prints after pattern application to detect cases
  /// where IR was modified directly, bypassing the rewriter API.
  DenseMap<Operation *, OperationFingerPrint> fingerprints;

  /// Top-level operation of the current greedy rewrite.
  Operation *topLevel = nullptr;

  /// Finger print of the top-level operation.
  std::optional<OperationFingerPrint> topLevelFingerPrint;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
173 
174 #ifndef NDEBUG
175 static Operation *getDumpRootOp(Operation *op) {
176  // Dump the parent op so that materialized constants are visible. If the op
177  // is a top-level op, dump it directly.
178  if (Operation *parentOp = op->getParentOp())
179  return parentOp;
180  return op;
181 }
182 static void logSuccessfulFolding(Operation *op) {
183  LDBG() << "// *** IR Dump After Successful Folding ***\n"
184  << OpWithFlags(op, OpPrintingFlags().elideLargeElementsAttrs());
185 }
186 #endif // NDEBUG
187 
188 //===----------------------------------------------------------------------===//
189 // Worklist
190 //===----------------------------------------------------------------------===//
191 
192 /// A LIFO worklist of operations with efficient removal and set semantics.
193 ///
194 /// This class maintains a vector of operations and a mapping of operations to
195 /// positions in the vector, so that operations can be removed efficiently at
196 /// random. When an operation is removed, it is replaced with nullptr. Such
197 /// nullptr are skipped when pop'ing elements.
198 class Worklist {
199 public:
200  Worklist();
201 
202  /// Clear the worklist.
203  void clear();
204 
205  /// Return whether the worklist is empty.
206  bool empty() const;
207 
208  /// Push an operation to the end of the worklist, unless the operation is
209  /// already on the worklist.
210  void push(Operation *op);
211 
212  /// Pop the an operation from the end of the worklist. Only allowed on
213  /// non-empty worklists.
214  Operation *pop();
215 
216  /// Remove an operation from the worklist.
217  void remove(Operation *op);
218 
219  /// Reverse the worklist.
220  void reverse();
221 
222 protected:
223  /// The worklist of operations.
224  std::vector<Operation *> list;
225 
226  /// A mapping of operations to positions in `list`.
228 };
229 
230 Worklist::Worklist() { list.reserve(64); }
231 
232 void Worklist::clear() {
233  list.clear();
234  map.clear();
235 }
236 
237 bool Worklist::empty() const {
238  // Skip all nullptr.
239  return !llvm::any_of(list,
240  [](Operation *op) { return static_cast<bool>(op); });
241 }
242 
243 void Worklist::push(Operation *op) {
244  assert(op && "cannot push nullptr to worklist");
245  // Check to see if the worklist already contains this op.
246  if (!map.insert({op, list.size()}).second)
247  return;
248  list.push_back(op);
249 }
250 
251 Operation *Worklist::pop() {
252  assert(!empty() && "cannot pop from empty worklist");
253  // Skip and remove all trailing nullptr.
254  while (!list.back())
255  list.pop_back();
256  Operation *op = list.back();
257  list.pop_back();
258  map.erase(op);
259  // Cleanup: Remove all trailing nullptr.
260  while (!list.empty() && !list.back())
261  list.pop_back();
262  return op;
263 }
264 
265 void Worklist::remove(Operation *op) {
266  assert(op && "cannot remove nullptr from worklist");
267  auto it = map.find(op);
268  if (it != map.end()) {
269  assert(list[it->second] == op && "malformed worklist data structure");
270  list[it->second] = nullptr;
271  map.erase(it);
272  }
273 }
274 
275 void Worklist::reverse() {
276  std::reverse(list.begin(), list.end());
277  for (size_t i = 0, e = list.size(); i != e; ++i)
278  map[list[i]] = i;
279 }
280 
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
/// A worklist that pops elements at a random position. This worklist is for
/// testing/debugging purposes only. It can be used to ensure that lowering
/// pipelines work correctly regardless of the order in which ops are processed
/// by the GreedyPatternRewriteDriver.
class RandomizedWorklist : public Worklist {
public:
  RandomizedWorklist() : Worklist() {
    generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
  }

  /// Pop a random non-empty op from the worklist. Removed (nullptr) slots that
  /// happen to be picked are discarded and another position is drawn.
  Operation *pop() {
    Operation *op = nullptr;
    do {
      assert(!list.empty() && "cannot pop from empty worklist");
      size_t pos = generator() % list.size();
      op = list[pos];
      list.erase(list.begin() + pos);
      // Every entry after `pos` shifted down by one; update its recorded
      // position. Removed slots (nullptr) are not tracked in `map` and must
      // not be inserted into it.
      for (size_t i = pos, e = list.size(); i < e; ++i)
        if (Operation *shifted = list[i])
          map[shifted] = i;
      map.erase(op);
    } while (!op);
    return op;
  }

private:
  std::minstd_rand0 generator;
};
#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
311 
312 //===----------------------------------------------------------------------===//
313 // GreedyPatternRewriteDriver
314 //===----------------------------------------------------------------------===//
315 
316 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
317 /// applies the locally optimal patterns.
318 ///
319 /// This abstract class manages the worklist and contains helper methods for
320 /// rewriting ops on the worklist. Derived classes specify how ops are added
321 /// to the worklist in the beginning.
322 class GreedyPatternRewriteDriver : public RewriterBase::Listener {
323 protected:
324  explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
326  const GreedyRewriteConfig &config);
327 
328  /// Add the given operation to the worklist.
329  void addSingleOpToWorklist(Operation *op);
330 
331  /// Add the given operation and its ancestors to the worklist.
332  void addToWorklist(Operation *op);
333 
334  /// Notify the driver that the specified operation may have been modified
335  /// in-place. The operation is added to the worklist.
336  void notifyOperationModified(Operation *op) override;
337 
338  /// Notify the driver that the specified operation was inserted. Update the
339  /// worklist as needed: The operation is enqueued depending on scope and
340  /// strict mode.
341  void notifyOperationInserted(Operation *op,
342  OpBuilder::InsertPoint previous) override;
343 
344  /// Notify the driver that the specified operation was removed. Update the
345  /// worklist as needed: The operation and its children are removed from the
346  /// worklist.
347  void notifyOperationErased(Operation *op) override;
348 
349  /// Notify the driver that the specified operation was replaced. Update the
350  /// worklist as needed: New users are added enqueued.
351  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
352 
353  /// Process ops until the worklist is empty or `config.maxNumRewrites` is
354  /// reached. Return `true` if any IR was changed.
355  bool processWorklist();
356 
357  /// The pattern rewriter that is used for making IR modifications and is
358  /// passed to rewrite patterns.
359  PatternRewriter rewriter;
360 
361  /// The worklist for this transformation keeps track of the operations that
362  /// need to be (re)visited.
363 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364  RandomizedWorklist worklist;
365 #else
366  Worklist worklist;
367 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
368 
369  /// Configuration information for how to simplify.
371 
372  /// The list of ops we are restricting our rewrites to. These include the
373  /// supplied set of ops as well as new ops created while rewriting those ops
374  /// depending on `strictMode`. This set is not maintained when
375  /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
376  llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
377 
378 private:
379  /// Look over the provided operands for any defining operations that should
380  /// be re-added to the worklist. This function should be called when an
381  /// operation is modified or removed, as it may trigger further
382  /// simplifications.
383  void addOperandsToWorklist(Operation *op);
384 
385  /// Notify the driver that the given block was inserted.
386  void notifyBlockInserted(Block *block, Region *previous,
387  Region::iterator previousIt) override;
388 
389  /// Notify the driver that the given block is about to be removed.
390  void notifyBlockErased(Block *block) override;
391 
392  /// For debugging only: Notify the driver of a pattern match failure.
393  void
394  notifyMatchFailure(Location loc,
395  function_ref<void(Diagnostic &)> reasonCallback) override;
396 
397 #ifndef NDEBUG
398  /// A raw output stream used to prefix the debug log.
399 
400  llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
401  llvm::dbgs()};
402  /// A logger used to emit information during the application process.
403  llvm::ScopedPrinter logger{os};
404 #endif
405 
406  /// The low-level pattern applicator.
407  PatternApplicator matcher;
408 
409 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
410  ExpensiveChecks expensiveChecks;
411 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
412 };
413 } // namespace
414 
415 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
418  : rewriter(ctx), config(config), matcher(patterns)
419 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
420  // clang-format off
421  , expensiveChecks(
422  /*driver=*/this,
423  /*topLevel=*/config.getScope() ? config.getScope()->getParentOp()
424  : nullptr)
425 // clang-format on
426 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
427 {
428  // Apply a simple cost model based solely on pattern benefit.
429  matcher.applyDefaultCostModel();
430 
431  // Set up listener.
432 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
433  // Send IR notifications to the debug handler. This handler will then forward
434  // all notifications to this GreedyPatternRewriteDriver.
435  rewriter.setListener(&expensiveChecks);
436 #else
437  rewriter.setListener(this);
438 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
439 }
440 
441 bool GreedyPatternRewriteDriver::processWorklist() {
442 #ifndef NDEBUG
443  const char *logLineComment =
444  "//===-------------------------------------------===//\n";
445 
446  /// A utility function to log a process result for the given reason.
447  auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
448  logger.unindent();
449  logger.startLine() << "} -> " << result;
450  if (!msg.isTriviallyEmpty())
451  logger.getOStream() << " : " << msg;
452  logger.getOStream() << "\n";
453  };
454  auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
455  logResult(result, msg);
456  logger.startLine() << logLineComment;
457  };
458 #endif
459 
460  bool changed = false;
461  int64_t numRewrites = 0;
462  while (!worklist.empty() &&
463  (numRewrites < config.getMaxNumRewrites() ||
464  config.getMaxNumRewrites() == GreedyRewriteConfig::kNoLimit)) {
465  auto *op = worklist.pop();
466 
467  LLVM_DEBUG({
468  logger.getOStream() << "\n";
469  logger.startLine() << logLineComment;
470  logger.startLine() << "Processing operation : '" << op->getName() << "'("
471  << op << ") {\n";
472  logger.indent();
473 
474  // If the operation has no regions, just print it here.
475  if (op->getNumRegions() == 0) {
476  op->print(
477  logger.startLine(),
478  OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
479  logger.getOStream() << "\n\n";
480  }
481  });
482 
483  // If the operation is trivially dead - remove it.
484  if (isOpTriviallyDead(op)) {
485  rewriter.eraseOp(op);
486  changed = true;
487 
488  LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
489  continue;
490  }
491 
492  // Try to fold this op. Do not fold constant ops. That would lead to an
493  // infinite folding loop, as every constant op would be folded to an
494  // Attribute and then immediately be rematerialized as a constant op, which
495  // is then put on the worklist.
496  if (config.isFoldingEnabled() && !op->hasTrait<OpTrait::ConstantLike>()) {
497  SmallVector<OpFoldResult> foldResults;
498  if (succeeded(op->fold(foldResults))) {
499  LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
500 #ifndef NDEBUG
501  Operation *dumpRootOp = getDumpRootOp(op);
502 #endif // NDEBUG
503  if (foldResults.empty()) {
504  // Op was modified in-place.
505  notifyOperationModified(op);
506  changed = true;
507  LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
508 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
509  expensiveChecks.notifyFoldingSuccess();
510 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
511  continue;
512  }
513 
514  // Op results can be replaced with `foldResults`.
515  assert(foldResults.size() == op->getNumResults() &&
516  "folder produced incorrect number of results");
517  OpBuilder::InsertionGuard g(rewriter);
518  rewriter.setInsertionPoint(op);
519  SmallVector<Value> replacements;
520  bool materializationSucceeded = true;
521  for (auto [ofr, resultType] :
522  llvm::zip_equal(foldResults, op->getResultTypes())) {
523  if (auto value = dyn_cast<Value>(ofr)) {
524  assert(value.getType() == resultType &&
525  "folder produced value of incorrect type");
526  replacements.push_back(value);
527  continue;
528  }
529  // Materialize Attributes as SSA values.
530  Operation *constOp = op->getDialect()->materializeConstant(
531  rewriter, cast<Attribute>(ofr), resultType, op->getLoc());
532 
533  if (!constOp) {
534  // If materialization fails, cleanup any operations generated for
535  // the previous results.
536  llvm::SmallDenseSet<Operation *> replacementOps;
537  for (Value replacement : replacements) {
538  assert(replacement.use_empty() &&
539  "folder reused existing op for one result but constant "
540  "materialization failed for another result");
541  replacementOps.insert(replacement.getDefiningOp());
542  }
543  for (Operation *op : replacementOps) {
544  rewriter.eraseOp(op);
545  }
546 
547  materializationSucceeded = false;
548  break;
549  }
550 
551  assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
552  "materializeConstant produced op that is not a ConstantLike");
553  assert(constOp->getResultTypes()[0] == resultType &&
554  "materializeConstant produced incorrect result type");
555  replacements.push_back(constOp->getResult(0));
556  }
557 
558  if (materializationSucceeded) {
559  rewriter.replaceOp(op, replacements);
560  changed = true;
561  LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
562 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
563  expensiveChecks.notifyFoldingSuccess();
564 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
565  continue;
566  }
567  }
568  }
569 
570  // Try to match one of the patterns. The rewriter is automatically
571  // notified of any necessary changes, so there is nothing else to do
572  // here.
573  auto canApplyCallback = [&](const Pattern &pattern) {
574  LLVM_DEBUG({
575  logger.getOStream() << "\n";
576  logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
577  << op->getName() << " -> (";
578  llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
579  logger.getOStream() << ")' {\n";
580  logger.indent();
581  });
582  if (RewriterBase::Listener *listener = config.getListener())
583  listener->notifyPatternBegin(pattern, op);
584  return true;
585  };
586  function_ref<bool(const Pattern &)> canApply = canApplyCallback;
587  auto onFailureCallback = [&](const Pattern &pattern) {
588  LLVM_DEBUG(logResult("failure", "pattern failed to match"));
589  if (RewriterBase::Listener *listener = config.getListener())
590  listener->notifyPatternEnd(pattern, failure());
591  };
592  function_ref<void(const Pattern &)> onFailure = onFailureCallback;
593  auto onSuccessCallback = [&](const Pattern &pattern) {
594  LLVM_DEBUG(logResult("success", "pattern applied successfully"));
595  if (RewriterBase::Listener *listener = config.getListener())
596  listener->notifyPatternEnd(pattern, success());
597  return success();
598  };
599  function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
600 
601 #ifdef NDEBUG
602  // Optimization: PatternApplicator callbacks are not needed when running in
603  // optimized mode and without a listener.
604  if (!config.getListener()) {
605  canApply = nullptr;
606  onFailure = nullptr;
607  onSuccess = nullptr;
608  }
609 #endif // NDEBUG
610 
611 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
612  if (config.getScope()) {
613  expensiveChecks.computeFingerPrints(config.getScope()->getParentOp());
614  }
615  auto clearFingerprints =
616  llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
617 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
618 
619  LogicalResult matchResult =
620  matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
621 
622  if (succeeded(matchResult)) {
623  LLVM_DEBUG(logResultWithLine("success", "at least one pattern matched"));
624 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
625  expensiveChecks.notifyRewriteSuccess();
626 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
627  changed = true;
628  ++numRewrites;
629  } else {
630  LLVM_DEBUG(logResultWithLine("failure", "all patterns failed to match"));
631 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
632  expensiveChecks.notifyRewriteFailure();
633 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
634  }
635  }
636 
637  return changed;
638 }
639 
640 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
641  assert(op && "expected valid op");
642  // Gather potential ancestors while looking for a "scope" parent region.
643  SmallVector<Operation *, 8> ancestors;
644  Region *region = nullptr;
645  do {
646  ancestors.push_back(op);
647  region = op->getParentRegion();
648  if (config.getScope() == region) {
649  // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
650  for (Operation *op : ancestors)
651  addSingleOpToWorklist(op);
652  return;
653  }
654  if (region == nullptr)
655  return;
656  } while ((op = region->getParentOp()));
657 }
658 
659 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
660  if (config.getStrictness() == GreedyRewriteStrictness::AnyOp ||
661  strictModeFilteredOps.contains(op))
662  worklist.push(op);
663 }
664 
665 void GreedyPatternRewriteDriver::notifyBlockInserted(
666  Block *block, Region *previous, Region::iterator previousIt) {
667  if (RewriterBase::Listener *listener = config.getListener())
668  listener->notifyBlockInserted(block, previous, previousIt);
669 }
670 
671 void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
672  if (RewriterBase::Listener *listener = config.getListener())
673  listener->notifyBlockErased(block);
674 }
675 
676 void GreedyPatternRewriteDriver::notifyOperationInserted(
677  Operation *op, OpBuilder::InsertPoint previous) {
678  LLVM_DEBUG({
679  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
680  << ")\n";
681  });
682  if (RewriterBase::Listener *listener = config.getListener())
683  listener->notifyOperationInserted(op, previous);
684  if (config.getStrictness() == GreedyRewriteStrictness::ExistingAndNewOps)
685  strictModeFilteredOps.insert(op);
686  addToWorklist(op);
687 }
688 
689 void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
690  LLVM_DEBUG({
691  logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
692  << ")\n";
693  });
694  if (RewriterBase::Listener *listener = config.getListener())
695  listener->notifyOperationModified(op);
696  addToWorklist(op);
697 }
698 
699 void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
700  for (Value operand : op->getOperands()) {
701  // If this operand currently has at most 2 users, add its defining op to the
702  // worklist. Indeed, after the op is deleted, then the operand will have at
703  // most 1 user left. If it has 0 users left, it can be deleted too,
704  // and if it has 1 user left, there may be further canonicalization
705  // opportunities.
706  if (!operand)
707  continue;
708 
709  auto *defOp = operand.getDefiningOp();
710  if (!defOp)
711  continue;
712 
713  Operation *otherUser = nullptr;
714  bool hasMoreThanTwoUses = false;
715  for (auto user : operand.getUsers()) {
716  if (user == op || user == otherUser)
717  continue;
718  if (!otherUser) {
719  otherUser = user;
720  continue;
721  }
722  hasMoreThanTwoUses = true;
723  break;
724  }
725  if (hasMoreThanTwoUses)
726  continue;
727 
728  addToWorklist(defOp);
729  }
730 }
731 
732 void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
733  LLVM_DEBUG({
734  logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
735  << ")\n";
736  });
737 
738 #ifndef NDEBUG
739  // Only ops that are within the configured scope are added to the worklist of
740  // the greedy pattern rewriter. Moreover, the parent op of the scope region is
741  // the part of the IR that is taken into account for the "expensive checks".
742  // A greedy pattern rewrite is not allowed to erase the parent op of the scope
743  // region, as that would break the worklist handling and the expensive checks.
744  if (Region *scope = config.getScope(); scope->getParentOp() == op)
745  llvm_unreachable(
746  "scope region must not be erased during greedy pattern rewrite");
747 #endif // NDEBUG
748 
749  if (RewriterBase::Listener *listener = config.getListener())
750  listener->notifyOperationErased(op);
751 
752  addOperandsToWorklist(op);
753  worklist.remove(op);
754 
755  if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
756  strictModeFilteredOps.erase(op);
757 }
758 
759 void GreedyPatternRewriteDriver::notifyOperationReplaced(
760  Operation *op, ValueRange replacement) {
761  LLVM_DEBUG({
762  logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
763  << ")\n";
764  });
765  if (RewriterBase::Listener *listener = config.getListener())
766  listener->notifyOperationReplaced(op, replacement);
767 }
768 
769 void GreedyPatternRewriteDriver::notifyMatchFailure(
770  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
771  LLVM_DEBUG({
772  Diagnostic diag(loc, DiagnosticSeverity::Remark);
773  reasonCallback(diag);
774  logger.startLine() << "** Match Failure : " << diag.str() << "\n";
775  });
776  if (RewriterBase::Listener *listener = config.getListener())
777  listener->notifyMatchFailure(loc, reasonCallback);
778 }
779 
780 //===----------------------------------------------------------------------===//
781 // RegionPatternRewriteDriver
782 //===----------------------------------------------------------------------===//
783 
784 namespace {
785 /// This driver simplfies all ops in a region.
786 class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
787 public:
788  explicit RegionPatternRewriteDriver(MLIRContext *ctx,
791  Region &regions);
792 
793  /// Simplify ops inside `region` and simplify the region itself. Return
794  /// success if the transformation converged.
795  LogicalResult simplify(bool *changed) &&;
796 
797 private:
798  /// The region that is simplified.
799  Region &region;
800 };
801 } // namespace
802 
803 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
805  const GreedyRewriteConfig &config, Region &region)
806  : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
807  // Populate strict mode ops.
808  if (config.getStrictness() != GreedyRewriteStrictness::AnyOp) {
809  region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
810  }
811 }
812 
813 namespace {
814 class GreedyPatternRewriteIteration
815  : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
816 public:
817  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
818  GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
819  : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
820  iteration(iteration) {}
821  static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
822  void print(raw_ostream &os) const override {
823  os << "GreedyPatternRewriteIteration(" << iteration << ")";
824  }
825 
826 private:
827  int64_t iteration = 0;
828 };
829 } // namespace
830 
831 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
832  bool continueRewrites = false;
833  int64_t iteration = 0;
834  MLIRContext *ctx = rewriter.getContext();
835  do {
836  // Check if the iteration limit was reached.
837  if (++iteration > config.getMaxIterations() &&
838  config.getMaxIterations() != GreedyRewriteConfig::kNoLimit)
839  break;
840 
841  // New iteration: start with an empty worklist.
842  worklist.clear();
843 
844  // `OperationFolder` CSE's constant ops (and may move them into parents
845  // regions to enable more aggressive CSE'ing).
846  OperationFolder folder(ctx, this);
847  auto insertKnownConstant = [&](Operation *op) {
848  // Check for existing constants when populating the worklist. This avoids
849  // accidentally reversing the constant order during processing.
850  Attribute constValue;
851  if (matchPattern(op, m_Constant(&constValue)))
852  if (!folder.insertKnownConstant(op, constValue))
853  return true;
854  return false;
855  };
856 
857  if (!config.getUseTopDownTraversal()) {
858  // Add operations to the worklist in postorder.
859  region.walk([&](Operation *op) {
860  if (!config.isConstantCSEEnabled() || !insertKnownConstant(op))
861  addToWorklist(op);
862  });
863  } else {
864  // Add all nested operations to the worklist in preorder.
865  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
866  if (!config.isConstantCSEEnabled() || !insertKnownConstant(op)) {
867  addToWorklist(op);
868  return WalkResult::advance();
869  }
870  return WalkResult::skip();
871  });
872 
873  // Reverse the list so our pop-back loop processes them in-order.
874  worklist.reverse();
875  }
876 
877  ctx->executeAction<GreedyPatternRewriteIteration>(
878  [&] {
879  continueRewrites = false;
880 
881  // Erase unreachable blocks
882  // Operations like:
883  // %add = arith.addi %add, %add : i64
884  // are legal in unreachable code. Unfortunately many patterns would be
885  // unsafe to apply on such IR and can lead to crashes or infinite
886  // loops.
887  continueRewrites |=
888  succeeded(eraseUnreachableBlocks(rewriter, region));
889 
890  continueRewrites |= processWorklist();
891 
892  // After applying patterns, make sure that the CFG of each of the
893  // regions is kept up to date.
894  if (config.getRegionSimplificationLevel() !=
895  GreedySimplifyRegionLevel::Disabled) {
896  continueRewrites |= succeeded(simplifyRegions(
897  rewriter, region,
898  /*mergeBlocks=*/config.getRegionSimplificationLevel() ==
899  GreedySimplifyRegionLevel::Aggressive));
900  }
901  },
902  {&region}, iteration);
903  } while (continueRewrites);
904 
905  if (changed)
906  *changed = iteration > 1;
907 
908  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
909  return success(!continueRewrites);
910 }
911 
/// Greedily applies `patterns` to the operations nested in `region` by running
/// a RegionPatternRewriteDriver, which iterates until an iteration makes no
/// progress or the configured iteration limit is reached. If `config` has no
/// scope, the scope defaults to `region`. Returns the driver's convergence
/// result; `changed` (if non-null) is forwarded to the driver.
912 LogicalResult
916  // The parent op of `region` must be known to be isolated from above to
917  // prevent performing canonicalizations on operations defined at or above
918  // that parent op.
919  assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
920  "patterns can only be applied to operations IsolatedFromAbove");
921 
922  // Set scope if not specified.
923  if (!config.getScope())
924  config.setScope(&region);
925 
926 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
927  if (failed(verify(config.getScope()->getParentOp())))
928  llvm::report_fatal_error(
929  "greedy pattern rewriter input IR failed to verify");
930 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
931 
932  // Start the pattern driver.
933  RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
934  region);
935  LogicalResult converged = std::move(driver).simplify(changed);
 // Non-convergence is only logged here; the failure is still returned.
936  if (failed(converged))
937  LDBG() << "The pattern rewrite did not converge after scanning "
938  << config.getMaxIterations() << " times";
939  return converged;
940 }
941 
942 //===----------------------------------------------------------------------===//
943 // MultiOpPatternRewriteDriver
944 //===----------------------------------------------------------------------===//
945 
946 namespace {
947 /// This driver simplfies a list of ops.
948 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
949 public:
950  explicit MultiOpPatternRewriteDriver(
953  llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
954 
955  /// Simplify `ops`. Return `success` if the transformation converged.
956  LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
957 
958 private:
959  void notifyOperationErased(Operation *op) override {
960  GreedyPatternRewriteDriver::notifyOperationErased(op);
961  if (survivingOps)
962  survivingOps->erase(op);
963  }
964 
965  /// An optional set of ops that survived the rewrite. This set is populated
966  /// at the beginning of `simplifyLocally` with the inititally provided list
967  /// of ops.
968  llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
969 };
970 } // namespace
971 
/// Initializes the base driver and, depending on the configuration, records
/// `ops` as the strict-mode filter set and/or the initial surviving-op set.
972 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
975  llvm::SmallDenseSet<Operation *, 4> *survivingOps)
976  : GreedyPatternRewriteDriver(ctx, patterns, config),
977  survivingOps(survivingOps) {
 // With any strictness other than AnyOp, `ops` are recorded in
 // `strictModeFilteredOps`.
978  if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
979  strictModeFilteredOps.insert_range(ops);
980 
 // Start tracking from exactly `ops`; `notifyOperationErased` removes entries
 // as ops are erased.
981  if (survivingOps) {
982  survivingOps->clear();
983  survivingOps->insert_range(ops);
984  }
985 }
986 
987 LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
988  bool *changed) && {
989  // Populate the initial worklist.
990  for (Operation *op : ops)
991  addSingleOpToWorklist(op);
992 
993  // Process ops on the worklist.
994  bool result = processWorklist();
995  if (changed)
996  *changed = result;
997 
998  return success(worklist.empty());
999 }
1000 
1001 /// Find the region that is the closest common ancestor of all given ops.
1002 ///
1003 /// Note: This function returns `nullptr` if there is a top-level op among the
1004 /// given list of ops.
/// Strategy: start at the first op's parent region and widen it outward until
/// every other op is nested in it. Precondition: `ops` is non-empty.
1006  assert(!ops.empty() && "expected at least one op");
1007  // Fast path in case there is only one op.
1008  if (ops.size() == 1)
1009  return ops.front()->getParentRegion();
1010 
1011  Region *region = ops.front()->getParentRegion();
1012  ops = ops.drop_front();
1013  int sz = ops.size();
 // Bit i stays set while ops[i] (of the remaining ops) has not yet been found
 // inside the current candidate `region`.
1014  llvm::BitVector remainingOps(sz, true);
1015  while (region) {
1016  int pos = -1;
1017  // Iterate over all remaining ops.
1018  while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
1019  // Is this op contained in `region`?
1020  if (region->findAncestorOpInRegion(*ops[pos]))
1021  remainingOps.reset(pos);
1022  }
1023  if (remainingOps.none())
1024  break;
 // Widen the candidate. Ops already found need not be rechecked, since every
 // ancestor of `region` also contains them. This becomes nullptr above the
 // top-level region, which terminates the loop.
1025  region = region->getParentRegion();
1026  }
1027  return region;
1028 }
1029 
1032  GreedyRewriteConfig config, bool *changed, bool *allErased) {
 // Nothing to rewrite: trivially converged, nothing changed, and (vacuously)
 // every op was erased.
1033  if (ops.empty()) {
1034  if (changed)
1035  *changed = false;
1036  if (allErased)
1037  *allErased = true;
1038  return success();
1039  }
1040 
1041  // Determine scope of rewrite.
1042  if (!config.getScope()) {
1043  // Compute scope if none was provided. The scope will remain `nullptr` if
1044  // there is a top-level op among `ops`.
1045  config.setScope(findCommonAncestor(ops));
1046  } else {
1047  // If a scope was provided, make sure that all ops are in scope.
1048 #ifndef NDEBUG
1049  bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
1050  return static_cast<bool>(config.getScope()->findAncestorOpInRegion(*op));
1051  });
1052  assert(allOpsInScope && "ops must be within the specified scope");
1053 #endif // NDEBUG
1054  }
1055 
 // The input verification is skipped when there is no scope (i.e. a top-level
 // op is among `ops`).
1056 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1057  if (config.getScope() && failed(verify(config.getScope()->getParentOp())))
1058  llvm::report_fatal_error(
1059  "greedy pattern rewriter input IR failed to verify");
1060 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1061 
1062  // Start the pattern driver.
 // Surviving ops are only tracked when the caller asked for `allErased`; the
 // driver removes ops from `surviving` as they are erased.
1063  llvm::SmallDenseSet<Operation *, 4> surviving;
1064  MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
1065  config, ops,
1066  allErased ? &surviving : nullptr);
1067  LogicalResult converged = std::move(driver).simplify(ops, changed);
1068  if (allErased)
1069  *allErased = surviving.empty();
 // Non-convergence is only logged here; the failure is still returned.
1070  if (failed(converged))
1071  LDBG() << "The pattern rewrite did not converge after "
1072  << config.getMaxNumRewrites() << " rewrites";
1073  return converged;
1074 }
static Region * findCommonAncestor(ArrayRef< Operation * > ops)
Find the region that is the closest common ancestor of all given ops.
#define DEBUG_TYPE
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
static Operation * findAncestorOpInRegion(Region *region, Operation *op)
Return the ancestor op in the region or nullptr if the region is not an ancestor of the op.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:274
This class represents a saved insertion point.
Definition: Builders.h:325
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
A utility class for folding operations, and unifying duplicated constants generated along the way.
Definition: FoldUtils.h:33
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this r...
Definition: Region.cpp:168
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
MLIRContext * getContext()
Return the context this region is inserted in.
Definition: Region.cpp:24
BlockListType::iterator iterator
Definition: Region.h:52
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition: Region.h:285
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
CRTP Implementation of an action.
Definition: Action.h:76
virtual void print(raw_ostream &os) const
Definition: Action.h:50
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Erase the unreachable blocks within the provided regions.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
@ AnyOp
No restrictions wrt. which ops are processed.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:421
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Definition: PatternMatch.h:427
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:440
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:453
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:436