MLIR  19.0.0git
GreedyPatternRewriteDriver.cpp
Go to the documentation of this file.
1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements mlir::applyPatternsAndFoldGreedily.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "mlir/Config/mlir-config.h"
16 #include "mlir/IR/Action.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Verifier.h"
23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/ScopeExit.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/ScopedPrinter.h"
29 #include "llvm/Support/raw_ostream.h"
30 
31 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
32 #include <random>
33 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
34 
35 using namespace mlir;
36 
37 #define DEBUG_TYPE "greedy-rewriter"
38 
39 namespace {
40 
41 //===----------------------------------------------------------------------===//
42 // Debugging Infrastructure
43 //===----------------------------------------------------------------------===//
44 
45 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
46 /// A helper struct that performs various "expensive checks" to detect broken
47 /// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
48 /// broken if:
49 /// * IR does not verify after pattern application / folding.
50 /// * Pattern returns "failure" but the IR has changed.
51 /// * Pattern returns "success" but the IR has not changed.
52 ///
53 /// This struct stores finger prints of ops to determine whether the IR has
54 /// changed or not.
55 struct ExpensiveChecks : public RewriterBase::ForwardingListener {
56  ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
57  : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
58 
59  /// Compute finger prints of the given op and its nested ops.
60  void computeFingerPrints(Operation *topLevel) {
61  this->topLevel = topLevel;
62  this->topLevelFingerPrint.emplace(topLevel);
63  topLevel->walk([&](Operation *op) {
64  fingerprints.try_emplace(op, op, /*includeNested=*/false);
65  });
66  }
67 
68  /// Clear all finger prints.
69  void clear() {
70  topLevel = nullptr;
71  topLevelFingerPrint.reset();
72  fingerprints.clear();
73  }
74 
75  void notifyRewriteSuccess() {
76  if (!topLevel)
77  return;
78 
79  // Make sure that the IR still verifies.
80  if (failed(verify(topLevel)))
81  llvm::report_fatal_error("IR failed to verify after pattern application");
82 
83  // Pattern application success => IR must have changed.
84  OperationFingerPrint afterFingerPrint(topLevel);
85  if (*topLevelFingerPrint == afterFingerPrint) {
86  // Note: Run "mlir-opt -debug" to see which pattern is broken.
87  llvm::report_fatal_error(
88  "pattern returned success but IR did not change");
89  }
90  for (const auto &it : fingerprints) {
91  // Skip top-level op, its finger print is never invalidated.
92  if (it.first == topLevel)
93  continue;
94  // Note: Finger print computation may crash when an op was erased
95  // without notifying the rewriter. (Run with ASAN to see where the op was
96  // erased; the op was probably erased directly, bypassing the rewriter
97  // API.) Finger print computation does may not crash if a new op was
98  // created at the same memory location. (But then the finger print should
99  // have changed.)
100  if (it.second !=
101  OperationFingerPrint(it.first, /*includeNested=*/false)) {
102  // Note: Run "mlir-opt -debug" to see which pattern is broken.
103  llvm::report_fatal_error("operation finger print changed");
104  }
105  }
106  }
107 
108  void notifyRewriteFailure() {
109  if (!topLevel)
110  return;
111 
112  // Pattern application failure => IR must not have changed.
113  OperationFingerPrint afterFingerPrint(topLevel);
114  if (*topLevelFingerPrint != afterFingerPrint) {
115  // Note: Run "mlir-opt -debug" to see which pattern is broken.
116  llvm::report_fatal_error("pattern returned failure but IR did change");
117  }
118  }
119 
120  void notifyFoldingSuccess() {
121  if (!topLevel)
122  return;
123 
124  // Make sure that the IR still verifies.
125  if (failed(verify(topLevel)))
126  llvm::report_fatal_error("IR failed to verify after folding");
127  }
128 
129 protected:
130  /// Invalidate the finger print of the given op, i.e., remove it from the map.
131  void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
132 
133  void notifyBlockErased(Block *block) override {
135 
136  // The block structure (number of blocks, types of block arguments, etc.)
137  // is part of the fingerprint of the parent op.
138  // TODO: The parent op fingerprint should also be invalidated when modifying
139  // the block arguments of a block, but we do not have a
140  // `notifyBlockModified` callback yet.
141  invalidateFingerPrint(block->getParentOp());
142  }
143 
144  void notifyOperationInserted(Operation *op,
145  OpBuilder::InsertPoint previous) override {
147  invalidateFingerPrint(op->getParentOp());
148  }
149 
150  void notifyOperationModified(Operation *op) override {
152  invalidateFingerPrint(op);
153  }
154 
155  void notifyOperationErased(Operation *op) override {
157  op->walk([this](Operation *op) { invalidateFingerPrint(op); });
158  }
159 
160  /// Operation finger prints to detect invalid pattern API usage. IR is checked
161  /// against these finger prints after pattern application to detect cases
162  /// where IR was modified directly, bypassing the rewriter API.
164 
165  /// Top-level operation of the current greedy rewrite.
166  Operation *topLevel = nullptr;
167 
168  /// Finger print of the top-level operation.
169  std::optional<OperationFingerPrint> topLevelFingerPrint;
170 };
171 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
172 
173 #ifndef NDEBUG
174 static Operation *getDumpRootOp(Operation *op) {
175  // Dump the parent op so that materialized constants are visible. If the op
176  // is a top-level op, dump it directly.
177  if (Operation *parentOp = op->getParentOp())
178  return parentOp;
179  return op;
180 }
181 static void logSuccessfulFolding(Operation *op) {
182  llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
183  op->dump();
184  llvm::dbgs() << "\n\n";
185 }
186 #endif // NDEBUG
187 
188 //===----------------------------------------------------------------------===//
189 // Worklist
190 //===----------------------------------------------------------------------===//
191 
192 /// A LIFO worklist of operations with efficient removal and set semantics.
193 ///
194 /// This class maintains a vector of operations and a mapping of operations to
195 /// positions in the vector, so that operations can be removed efficiently at
196 /// random. When an operation is removed, it is replaced with nullptr. Such
197 /// nullptr are skipped when pop'ing elements.
198 class Worklist {
199 public:
200  Worklist();
201 
202  /// Clear the worklist.
203  void clear();
204 
205  /// Return whether the worklist is empty.
206  bool empty() const;
207 
208  /// Push an operation to the end of the worklist, unless the operation is
209  /// already on the worklist.
210  void push(Operation *op);
211 
212  /// Pop the an operation from the end of the worklist. Only allowed on
213  /// non-empty worklists.
214  Operation *pop();
215 
216  /// Remove an operation from the worklist.
217  void remove(Operation *op);
218 
219  /// Reverse the worklist.
220  void reverse();
221 
222 protected:
223  /// The worklist of operations.
224  std::vector<Operation *> list;
225 
226  /// A mapping of operations to positions in `list`.
228 };
229 
230 Worklist::Worklist() { list.reserve(64); }
231 
232 void Worklist::clear() {
233  list.clear();
234  map.clear();
235 }
236 
237 bool Worklist::empty() const {
238  // Skip all nullptr.
239  return !llvm::any_of(list,
240  [](Operation *op) { return static_cast<bool>(op); });
241 }
242 
243 void Worklist::push(Operation *op) {
244  assert(op && "cannot push nullptr to worklist");
245  // Check to see if the worklist already contains this op.
246  if (!map.insert({op, list.size()}).second)
247  return;
248  list.push_back(op);
249 }
250 
251 Operation *Worklist::pop() {
252  assert(!empty() && "cannot pop from empty worklist");
253  // Skip and remove all trailing nullptr.
254  while (!list.back())
255  list.pop_back();
256  Operation *op = list.back();
257  list.pop_back();
258  map.erase(op);
259  // Cleanup: Remove all trailing nullptr.
260  while (!list.empty() && !list.back())
261  list.pop_back();
262  return op;
263 }
264 
265 void Worklist::remove(Operation *op) {
266  assert(op && "cannot remove nullptr from worklist");
267  auto it = map.find(op);
268  if (it != map.end()) {
269  assert(list[it->second] == op && "malformed worklist data structure");
270  list[it->second] = nullptr;
271  map.erase(it);
272  }
273 }
274 
275 void Worklist::reverse() {
276  std::reverse(list.begin(), list.end());
277  for (size_t i = 0, e = list.size(); i != e; ++i)
278  map[list[i]] = i;
279 }
280 
281 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
282 /// A worklist that pops elements at a random position. This worklist is for
283 /// testing/debugging purposes only. It can be used to ensure that lowering
284 /// pipelines work correctly regardless of the order in which ops are processed
285 /// by the GreedyPatternRewriteDriver.
286 class RandomizedWorklist : public Worklist {
287 public:
288  RandomizedWorklist() : Worklist() {
289  generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
290  }
291 
292  /// Pop a random non-empty op from the worklist.
293  Operation *pop() {
294  Operation *op = nullptr;
295  do {
296  assert(!list.empty() && "cannot pop from empty worklist");
297  int64_t pos = generator() % list.size();
298  op = list[pos];
299  list.erase(list.begin() + pos);
300  for (int64_t i = pos, e = list.size(); i < e; ++i)
301  map[list[i]] = i;
302  map.erase(op);
303  } while (!op);
304  return op;
305  }
306 
307 private:
308  std::minstd_rand0 generator;
309 };
310 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
311 
312 //===----------------------------------------------------------------------===//
313 // GreedyPatternRewriteDriver
314 //===----------------------------------------------------------------------===//
315 
316 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
317 /// applies the locally optimal patterns.
318 ///
319 /// This abstract class manages the worklist and contains helper methods for
320 /// rewriting ops on the worklist. Derived classes specify how ops are added
321 /// to the worklist in the beginning.
322 class GreedyPatternRewriteDriver : public PatternRewriter,
323  public RewriterBase::Listener {
324 protected:
325  explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
326  const FrozenRewritePatternSet &patterns,
327  const GreedyRewriteConfig &config);
328 
329  /// Add the given operation to the worklist.
330  void addSingleOpToWorklist(Operation *op);
331 
332  /// Add the given operation and its ancestors to the worklist.
333  void addToWorklist(Operation *op);
334 
335  /// Notify the driver that the specified operation may have been modified
336  /// in-place. The operation is added to the worklist.
337  void notifyOperationModified(Operation *op) override;
338 
339  /// Notify the driver that the specified operation was inserted. Update the
340  /// worklist as needed: The operation is enqueued depending on scope and
341  /// strict mode.
342  void notifyOperationInserted(Operation *op, InsertPoint previous) override;
343 
344  /// Notify the driver that the specified operation was removed. Update the
345  /// worklist as needed: The operation and its children are removed from the
346  /// worklist.
347  void notifyOperationErased(Operation *op) override;
348 
349  /// Notify the driver that the specified operation was replaced. Update the
350  /// worklist as needed: New users are added enqueued.
351  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
352 
353  /// Process ops until the worklist is empty or `config.maxNumRewrites` is
354  /// reached. Return `true` if any IR was changed.
355  bool processWorklist();
356 
357  /// The worklist for this transformation keeps track of the operations that
358  /// need to be (re)visited.
359 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
360  RandomizedWorklist worklist;
361 #else
362  Worklist worklist;
363 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364 
365  /// Configuration information for how to simplify.
366  const GreedyRewriteConfig config;
367 
368  /// The list of ops we are restricting our rewrites to. These include the
369  /// supplied set of ops as well as new ops created while rewriting those ops
370  /// depending on `strictMode`. This set is not maintained when
371  /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
372  llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
373 
374 private:
375  /// Look over the provided operands for any defining operations that should
376  /// be re-added to the worklist. This function should be called when an
377  /// operation is modified or removed, as it may trigger further
378  /// simplifications.
379  void addOperandsToWorklist(Operation *op);
380 
381  /// Notify the driver that the given block was inserted.
382  void notifyBlockInserted(Block *block, Region *previous,
383  Region::iterator previousIt) override;
384 
385  /// Notify the driver that the given block is about to be removed.
386  void notifyBlockErased(Block *block) override;
387 
388  /// For debugging only: Notify the driver of a pattern match failure.
389  void
390  notifyMatchFailure(Location loc,
391  function_ref<void(Diagnostic &)> reasonCallback) override;
392 
393 #ifndef NDEBUG
394  /// A logger used to emit information during the application process.
395  llvm::ScopedPrinter logger{llvm::dbgs()};
396 #endif
397 
398  /// The low-level pattern applicator.
399  PatternApplicator matcher;
400 
401 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
402  ExpensiveChecks expensiveChecks;
403 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
404 };
405 } // namespace
406 
407 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
408  MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
409  const GreedyRewriteConfig &config)
410  : PatternRewriter(ctx), config(config), matcher(patterns)
411 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
412  // clang-format off
413  , expensiveChecks(
414  /*driver=*/this,
415  /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
416 // clang-format on
417 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
418 {
419  // Apply a simple cost model based solely on pattern benefit.
420  matcher.applyDefaultCostModel();
421 
422  // Set up listener.
423 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
424  // Send IR notifications to the debug handler. This handler will then forward
425  // all notifications to this GreedyPatternRewriteDriver.
426  setListener(&expensiveChecks);
427 #else
428  setListener(this);
429 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
430 }
431 
432 bool GreedyPatternRewriteDriver::processWorklist() {
433 #ifndef NDEBUG
434  const char *logLineComment =
435  "//===-------------------------------------------===//\n";
436 
437  /// A utility function to log a process result for the given reason.
438  auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
439  logger.unindent();
440  logger.startLine() << "} -> " << result;
441  if (!msg.isTriviallyEmpty())
442  logger.getOStream() << " : " << msg;
443  logger.getOStream() << "\n";
444  };
445  auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
446  logResult(result, msg);
447  logger.startLine() << logLineComment;
448  };
449 #endif
450 
451  bool changed = false;
452  int64_t numRewrites = 0;
453  while (!worklist.empty() &&
454  (numRewrites < config.maxNumRewrites ||
455  config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
456  auto *op = worklist.pop();
457 
458  LLVM_DEBUG({
459  logger.getOStream() << "\n";
460  logger.startLine() << logLineComment;
461  logger.startLine() << "Processing operation : '" << op->getName() << "'("
462  << op << ") {\n";
463  logger.indent();
464 
465  // If the operation has no regions, just print it here.
466  if (op->getNumRegions() == 0) {
467  op->print(
468  logger.startLine(),
469  OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
470  logger.getOStream() << "\n\n";
471  }
472  });
473 
474  // If the operation is trivially dead - remove it.
475  if (isOpTriviallyDead(op)) {
476  eraseOp(op);
477  changed = true;
478 
479  LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
480  continue;
481  }
482 
483  // Try to fold this op. Do not fold constant ops. That would lead to an
484  // infinite folding loop, as every constant op would be folded to an
485  // Attribute and then immediately be rematerialized as a constant op, which
486  // is then put on the worklist.
487  if (!op->hasTrait<OpTrait::ConstantLike>()) {
488  SmallVector<OpFoldResult> foldResults;
489  if (succeeded(op->fold(foldResults))) {
490  LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
491 #ifndef NDEBUG
492  Operation *dumpRootOp = getDumpRootOp(op);
493 #endif // NDEBUG
494  if (foldResults.empty()) {
495  // Op was modified in-place.
496  notifyOperationModified(op);
497  changed = true;
498  LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
499 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
500  expensiveChecks.notifyFoldingSuccess();
501 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
502  continue;
503  }
504 
505  // Op results can be replaced with `foldResults`.
506  assert(foldResults.size() == op->getNumResults() &&
507  "folder produced incorrect number of results");
508  OpBuilder::InsertionGuard g(*this);
509  setInsertionPoint(op);
510  SmallVector<Value> replacements;
511  bool materializationSucceeded = true;
512  for (auto [ofr, resultType] :
513  llvm::zip_equal(foldResults, op->getResultTypes())) {
514  if (auto value = ofr.dyn_cast<Value>()) {
515  assert(value.getType() == resultType &&
516  "folder produced value of incorrect type");
517  replacements.push_back(value);
518  continue;
519  }
520  // Materialize Attributes as SSA values.
521  Operation *constOp = op->getDialect()->materializeConstant(
522  *this, ofr.get<Attribute>(), resultType, op->getLoc());
523 
524  if (!constOp) {
525  // If materialization fails, cleanup any operations generated for
526  // the previous results.
527  llvm::SmallDenseSet<Operation *> replacementOps;
528  for (Value replacement : replacements) {
529  assert(replacement.use_empty() &&
530  "folder reused existing op for one result but constant "
531  "materialization failed for another result");
532  replacementOps.insert(replacement.getDefiningOp());
533  }
534  for (Operation *op : replacementOps) {
535  eraseOp(op);
536  }
537 
538  materializationSucceeded = false;
539  break;
540  }
541 
542  assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
543  "materializeConstant produced op that is not a ConstantLike");
544  assert(constOp->getResultTypes()[0] == resultType &&
545  "materializeConstant produced incorrect result type");
546  replacements.push_back(constOp->getResult(0));
547  }
548 
549  if (materializationSucceeded) {
550  replaceOp(op, replacements);
551  changed = true;
552  LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
553 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
554  expensiveChecks.notifyFoldingSuccess();
555 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
556  continue;
557  }
558  }
559  }
560 
561  // Try to match one of the patterns. The rewriter is automatically
562  // notified of any necessary changes, so there is nothing else to do
563  // here.
564  auto canApplyCallback = [&](const Pattern &pattern) {
565  LLVM_DEBUG({
566  logger.getOStream() << "\n";
567  logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
568  << op->getName() << " -> (";
569  llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
570  logger.getOStream() << ")' {\n";
571  logger.indent();
572  });
573  if (config.listener)
574  config.listener->notifyPatternBegin(pattern, op);
575  return true;
576  };
577  function_ref<bool(const Pattern &)> canApply = canApplyCallback;
578  auto onFailureCallback = [&](const Pattern &pattern) {
579  LLVM_DEBUG(logResult("failure", "pattern failed to match"));
580  if (config.listener)
581  config.listener->notifyPatternEnd(pattern, failure());
582  };
583  function_ref<void(const Pattern &)> onFailure = onFailureCallback;
584  auto onSuccessCallback = [&](const Pattern &pattern) {
585  LLVM_DEBUG(logResult("success", "pattern applied successfully"));
586  if (config.listener)
587  config.listener->notifyPatternEnd(pattern, success());
588  return success();
589  };
590  function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
591 
592 #ifdef NDEBUG
593  // Optimization: PatternApplicator callbacks are not needed when running in
594  // optimized mode and without a listener.
595  if (!config.listener) {
596  canApply = nullptr;
597  onFailure = nullptr;
598  onSuccess = nullptr;
599  }
600 #endif // NDEBUG
601 
602 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
603  if (config.scope) {
604  expensiveChecks.computeFingerPrints(config.scope->getParentOp());
605  }
606  auto clearFingerprints =
607  llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
608 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
609 
610  LogicalResult matchResult =
611  matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
612 
613  if (succeeded(matchResult)) {
614  LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
615 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
616  expensiveChecks.notifyRewriteSuccess();
617 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
618  changed = true;
619  ++numRewrites;
620  } else {
621  LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
622 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
623  expensiveChecks.notifyRewriteFailure();
624 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
625  }
626  }
627 
628  return changed;
629 }
630 
631 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
632  assert(op && "expected valid op");
633  // Gather potential ancestors while looking for a "scope" parent region.
634  SmallVector<Operation *, 8> ancestors;
635  Region *region = nullptr;
636  do {
637  ancestors.push_back(op);
638  region = op->getParentRegion();
639  if (config.scope == region) {
640  // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
641  for (Operation *op : ancestors)
642  addSingleOpToWorklist(op);
643  return;
644  }
645  if (region == nullptr)
646  return;
647  } while ((op = region->getParentOp()));
648 }
649 
650 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
651  if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
652  strictModeFilteredOps.contains(op))
653  worklist.push(op);
654 }
655 
656 void GreedyPatternRewriteDriver::notifyBlockInserted(
657  Block *block, Region *previous, Region::iterator previousIt) {
658  if (config.listener)
659  config.listener->notifyBlockInserted(block, previous, previousIt);
660 }
661 
662 void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
663  if (config.listener)
664  config.listener->notifyBlockErased(block);
665 }
666 
667 void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
668  InsertPoint previous) {
669  LLVM_DEBUG({
670  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
671  << ")\n";
672  });
673  if (config.listener)
674  config.listener->notifyOperationInserted(op, previous);
675  if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
676  strictModeFilteredOps.insert(op);
677  addToWorklist(op);
678 }
679 
680 void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
681  LLVM_DEBUG({
682  logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
683  << ")\n";
684  });
685  if (config.listener)
686  config.listener->notifyOperationModified(op);
687  addToWorklist(op);
688 }
689 
690 void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
691  for (Value operand : op->getOperands()) {
692  // If this operand currently has at most 2 users, add its defining op to the
693  // worklist. Indeed, after the op is deleted, then the operand will have at
694  // most 1 user left. If it has 0 users left, it can be deleted too,
695  // and if it has 1 user left, there may be further canonicalization
696  // opportunities.
697  if (!operand)
698  continue;
699 
700  auto *defOp = operand.getDefiningOp();
701  if (!defOp)
702  continue;
703 
704  Operation *otherUser = nullptr;
705  bool hasMoreThanTwoUses = false;
706  for (auto user : operand.getUsers()) {
707  if (user == op || user == otherUser)
708  continue;
709  if (!otherUser) {
710  otherUser = user;
711  continue;
712  }
713  hasMoreThanTwoUses = true;
714  break;
715  }
716  if (hasMoreThanTwoUses)
717  continue;
718 
719  addToWorklist(defOp);
720  }
721 }
722 
723 void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
724  LLVM_DEBUG({
725  logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
726  << ")\n";
727  });
728 
729 #ifndef NDEBUG
730  // Only ops that are within the configured scope are added to the worklist of
731  // the greedy pattern rewriter. Moreover, the parent op of the scope region is
732  // the part of the IR that is taken into account for the "expensive checks".
733  // A greedy pattern rewrite is not allowed to erase the parent op of the scope
734  // region, as that would break the worklist handling and the expensive checks.
735  if (config.scope && config.scope->getParentOp() == op)
736  llvm_unreachable(
737  "scope region must not be erased during greedy pattern rewrite");
738 #endif // NDEBUG
739 
740  if (config.listener)
741  config.listener->notifyOperationErased(op);
742 
743  addOperandsToWorklist(op);
744  worklist.remove(op);
745 
746  if (config.strictMode != GreedyRewriteStrictness::AnyOp)
747  strictModeFilteredOps.erase(op);
748 }
749 
750 void GreedyPatternRewriteDriver::notifyOperationReplaced(
751  Operation *op, ValueRange replacement) {
752  LLVM_DEBUG({
753  logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
754  << ")\n";
755  });
756  if (config.listener)
757  config.listener->notifyOperationReplaced(op, replacement);
758 }
759 
760 void GreedyPatternRewriteDriver::notifyMatchFailure(
761  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
762  LLVM_DEBUG({
763  Diagnostic diag(loc, DiagnosticSeverity::Remark);
764  reasonCallback(diag);
765  logger.startLine() << "** Match Failure : " << diag.str() << "\n";
766  });
767  if (config.listener)
768  config.listener->notifyMatchFailure(loc, reasonCallback);
769 }
770 
771 //===----------------------------------------------------------------------===//
772 // RegionPatternRewriteDriver
773 //===----------------------------------------------------------------------===//
774 
775 namespace {
776 /// This driver simplfies all ops in a region.
777 class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
778 public:
779  explicit RegionPatternRewriteDriver(MLIRContext *ctx,
780  const FrozenRewritePatternSet &patterns,
781  const GreedyRewriteConfig &config,
782  Region &regions);
783 
784  /// Simplify ops inside `region` and simplify the region itself. Return
785  /// success if the transformation converged.
786  LogicalResult simplify(bool *changed) &&;
787 
788 private:
789  /// The region that is simplified.
790  Region &region;
791 };
792 } // namespace
793 
794 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
795  MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
796  const GreedyRewriteConfig &config, Region &region)
797  : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
798  // Populate strict mode ops.
799  if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
800  region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
801  }
802 }
803 
804 namespace {
805 class GreedyPatternRewriteIteration
806  : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
807 public:
808  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
809  GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
810  : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
811  iteration(iteration) {}
812  static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
813  void print(raw_ostream &os) const override {
814  os << "GreedyPatternRewriteIteration(" << iteration << ")";
815  }
816 
817 private:
818  int64_t iteration = 0;
819 };
820 } // namespace
821 
822 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
823  bool continueRewrites = false;
824  int64_t iteration = 0;
825  MLIRContext *ctx = getContext();
826  do {
827  // Check if the iteration limit was reached.
828  if (++iteration > config.maxIterations &&
829  config.maxIterations != GreedyRewriteConfig::kNoLimit)
830  break;
831 
832  // New iteration: start with an empty worklist.
833  worklist.clear();
834 
835  // `OperationFolder` CSE's constant ops (and may move them into parents
836  // regions to enable more aggressive CSE'ing).
837  OperationFolder folder(getContext(), this);
838  auto insertKnownConstant = [&](Operation *op) {
839  // Check for existing constants when populating the worklist. This avoids
840  // accidentally reversing the constant order during processing.
841  Attribute constValue;
842  if (matchPattern(op, m_Constant(&constValue)))
843  if (!folder.insertKnownConstant(op, constValue))
844  return true;
845  return false;
846  };
847 
848  if (!config.useTopDownTraversal) {
849  // Add operations to the worklist in postorder.
850  region.walk([&](Operation *op) {
851  if (!insertKnownConstant(op))
852  addToWorklist(op);
853  });
854  } else {
855  // Add all nested operations to the worklist in preorder.
856  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
857  if (!insertKnownConstant(op)) {
858  addToWorklist(op);
859  return WalkResult::advance();
860  }
861  return WalkResult::skip();
862  });
863 
864  // Reverse the list so our pop-back loop processes them in-order.
865  worklist.reverse();
866  }
867 
868  ctx->executeAction<GreedyPatternRewriteIteration>(
869  [&] {
870  continueRewrites = processWorklist();
871 
872  // After applying patterns, make sure that the CFG of each of the
873  // regions is kept up to date.
874  if (config.enableRegionSimplification)
875  continueRewrites |= succeeded(simplifyRegions(*this, region));
876  },
877  {&region}, iteration);
878  } while (continueRewrites);
879 
880  if (changed)
881  *changed = iteration > 1;
882 
883  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
884  return success(!continueRewrites);
885 }
886 
889  const FrozenRewritePatternSet &patterns,
890  GreedyRewriteConfig config, bool *changed) {
891  // The top-level operation must be known to be isolated from above to
892  // prevent performing canonicalizations on operations defined at or above
893  // the region containing 'op'.
894  assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
895  "patterns can only be applied to operations IsolatedFromAbove");
896 
897  // Set scope if not specified.
898  if (!config.scope)
899  config.scope = &region;
900 
901 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
902  if (failed(verify(config.scope->getParentOp())))
903  llvm::report_fatal_error(
904  "greedy pattern rewriter input IR failed to verify");
905 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
906 
907  // Start the pattern driver.
908  RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
909  region);
910  LogicalResult converged = std::move(driver).simplify(changed);
911  LLVM_DEBUG(if (failed(converged)) {
912  llvm::dbgs() << "The pattern rewrite did not converge after scanning "
913  << config.maxIterations << " times\n";
914  });
915  return converged;
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // MultiOpPatternRewriteDriver
920 //===----------------------------------------------------------------------===//
921 
922 namespace {
923 /// This driver simplfies a list of ops.
924 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
925 public:
926  explicit MultiOpPatternRewriteDriver(
927  MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
928  const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
929  llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
930 
931  /// Simplify `ops`. Return `success` if the transformation converged.
932  LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
933 
934 private:
935  void notifyOperationErased(Operation *op) override {
936  GreedyPatternRewriteDriver::notifyOperationErased(op);
937  if (survivingOps)
938  survivingOps->erase(op);
939  }
940 
941  /// An optional set of ops that survived the rewrite. This set is populated
942  /// at the beginning of `simplifyLocally` with the inititally provided list
943  /// of ops.
944  llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
945 };
946 } // namespace
947 
948 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
949  MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
950  const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
951  llvm::SmallDenseSet<Operation *, 4> *survivingOps)
952  : GreedyPatternRewriteDriver(ctx, patterns, config),
953  survivingOps(survivingOps) {
955  strictModeFilteredOps.insert(ops.begin(), ops.end());
956 
957  if (survivingOps) {
958  survivingOps->clear();
959  survivingOps->insert(ops.begin(), ops.end());
960  }
961 }
962 
963 LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
964  bool *changed) && {
965  // Populate the initial worklist.
966  for (Operation *op : ops)
967  addSingleOpToWorklist(op);
968 
969  // Process ops on the worklist.
970  bool result = processWorklist();
971  if (changed)
972  *changed = result;
973 
974  return success(worklist.empty());
975 }
976 
977 /// Find the region that is the closest common ancestor of all given ops.
978 ///
979 /// Note: This function returns `nullptr` if there is a top-level op among the
980 /// given list of ops.
982  assert(!ops.empty() && "expected at least one op");
983  // Fast path in case there is only one op.
984  if (ops.size() == 1)
985  return ops.front()->getParentRegion();
986 
987  Region *region = ops.front()->getParentRegion();
988  ops = ops.drop_front();
989  int sz = ops.size();
990  llvm::BitVector remainingOps(sz, true);
991  while (region) {
992  int pos = -1;
993  // Iterate over all remaining ops.
994  while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
995  // Is this op contained in `region`?
996  if (region->findAncestorOpInRegion(*ops[pos]))
997  remainingOps.reset(pos);
998  }
999  if (remainingOps.none())
1000  break;
1001  region = region->getParentRegion();
1002  }
1003  return region;
1004 }
1005 
1007  ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
1008  GreedyRewriteConfig config, bool *changed, bool *allErased) {
1009  if (ops.empty()) {
1010  if (changed)
1011  *changed = false;
1012  if (allErased)
1013  *allErased = true;
1014  return success();
1015  }
1016 
1017  // Determine scope of rewrite.
1018  if (!config.scope) {
1019  // Compute scope if none was provided. The scope will remain `nullptr` if
1020  // there is a top-level op among `ops`.
1021  config.scope = findCommonAncestor(ops);
1022  } else {
1023  // If a scope was provided, make sure that all ops are in scope.
1024 #ifndef NDEBUG
1025  bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
1026  return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
1027  });
1028  assert(allOpsInScope && "ops must be within the specified scope");
1029 #endif // NDEBUG
1030  }
1031 
1032 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1033  if (config.scope && failed(verify(config.scope->getParentOp())))
1034  llvm::report_fatal_error(
1035  "greedy pattern rewriter input IR failed to verify");
1036 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1037 
1038  // Start the pattern driver.
1039  llvm::SmallDenseSet<Operation *, 4> surviving;
1040  MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
1041  config, ops,
1042  allErased ? &surviving : nullptr);
1043  LogicalResult converged = std::move(driver).simplify(ops, changed);
1044  if (allErased)
1045  *allErased = surviving.empty();
1046  LLVM_DEBUG(if (failed(converged)) {
1047  llvm::dbgs() << "The pattern rewrite did not converge after "
1048  << config.maxNumRewrites << " rewrites";
1049  });
1050  return converged;
1051 }
static Region * findCommonAncestor(ArrayRef< Operation * > ops)
Find the region that is the closest common ancestor of all given ops.
static MLIRContext * getContext(OpFoldResult val)
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:86
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
int64_t maxIterations
This specifies the maximum number of times the rewriter will iterate between applying patterns and si...
Region * scope
Only ops within the scope are added to the worklist.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriter's worklist.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
int64_t maxNumRewrites
This specifies the maximum number of rewrites within an iteration.
bool enableRegionSimplification
Perform control flow optimizations to the region tree after applying all patterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:259
This class represents a saved insertion point.
Definition: Builders.h:329
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
A utility class for folding operations, and unifying duplicated constants generated along the way.
Definition: FoldUtils.h:33
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:632
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this r...
Definition: Region.cpp:168
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
MLIRContext * getContext()
Return the context this region is inserted in.
Definition: Region.cpp:24
BlockListType::iterator iterator
Definition: Region.h:52
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition: Region.h:285
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
CRTP Implementation of an action.
Definition: Action.h:77
virtual void print(raw_ostream &os) const
Definition: Action.h:51
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Run a set of structural simplifications over the given regions.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
@ AnyOp
No restrictions wrt. which ops are processed.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:310
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:300
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:463
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Definition: PatternMatch.h:466
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:477
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:490
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:473
virtual void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
Definition: PatternMatch.h:454
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:411
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:435
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:419
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:408
virtual void notifyPatternEnd(const Pattern &pattern, LogicalResult status)
Notify the listener that a pattern application finished with the specified status.
Definition: PatternMatch.h:446
virtual void notifyPatternBegin(const Pattern &pattern, Operation *op)
Notify the listener that the specified pattern is about to be applied at the specified root operation...
Definition: PatternMatch.h:439