MLIR  20.0.0git
GreedyPatternRewriteDriver.cpp
Go to the documentation of this file.
1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements mlir::applyPatternsGreedily.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "mlir/Config/mlir-config.h"
16 #include "mlir/IR/Action.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Verifier.h"
23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/ScopeExit.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/ScopedPrinter.h"
29 #include "llvm/Support/raw_ostream.h"
30 
31 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
32 #include <random>
33 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
34 
35 using namespace mlir;
36 
37 #define DEBUG_TYPE "greedy-rewriter"
38 
39 namespace {
40 
41 //===----------------------------------------------------------------------===//
42 // Debugging Infrastructure
43 //===----------------------------------------------------------------------===//
44 
45 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
46 /// A helper struct that performs various "expensive checks" to detect broken
47 /// rewrite patterns that use the rewriter API incorrectly. A rewrite pattern is
48 /// broken if:
49 /// * IR does not verify after pattern application / folding.
50 /// * Pattern returns "failure" but the IR has changed.
51 /// * Pattern returns "success" but the IR has not changed.
52 ///
53 /// This struct stores finger prints of ops to determine whether the IR has
54 /// changed or not.
55 struct ExpensiveChecks : public RewriterBase::ForwardingListener {
56  ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
57  : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
58 
59  /// Compute finger prints of the given op and its nested ops.
60  void computeFingerPrints(Operation *topLevel) {
61  this->topLevel = topLevel;
62  this->topLevelFingerPrint.emplace(topLevel);
63  topLevel->walk([&](Operation *op) {
64  fingerprints.try_emplace(op, op, /*includeNested=*/false);
65  });
66  }
67 
68  /// Clear all finger prints.
69  void clear() {
70  topLevel = nullptr;
71  topLevelFingerPrint.reset();
72  fingerprints.clear();
73  }
74 
75  void notifyRewriteSuccess() {
76  if (!topLevel)
77  return;
78 
79  // Make sure that the IR still verifies.
80  if (failed(verify(topLevel)))
81  llvm::report_fatal_error("IR failed to verify after pattern application");
82 
83  // Pattern application success => IR must have changed.
84  OperationFingerPrint afterFingerPrint(topLevel);
85  if (*topLevelFingerPrint == afterFingerPrint) {
86  // Note: Run "mlir-opt -debug" to see which pattern is broken.
87  llvm::report_fatal_error(
88  "pattern returned success but IR did not change");
89  }
90  for (const auto &it : fingerprints) {
91  // Skip top-level op, its finger print is never invalidated.
92  if (it.first == topLevel)
93  continue;
94  // Note: Finger print computation may crash when an op was erased
95  // without notifying the rewriter. (Run with ASAN to see where the op was
96  // erased; the op was probably erased directly, bypassing the rewriter
97 // API.) Finger print computation may not crash if a new op was
98  // created at the same memory location. (But then the finger print should
99  // have changed.)
100  if (it.second !=
101  OperationFingerPrint(it.first, /*includeNested=*/false)) {
102  // Note: Run "mlir-opt -debug" to see which pattern is broken.
103  llvm::report_fatal_error("operation finger print changed");
104  }
105  }
106  }
107 
108  void notifyRewriteFailure() {
109  if (!topLevel)
110  return;
111 
112  // Pattern application failure => IR must not have changed.
113  OperationFingerPrint afterFingerPrint(topLevel);
114  if (*topLevelFingerPrint != afterFingerPrint) {
115  // Note: Run "mlir-opt -debug" to see which pattern is broken.
116  llvm::report_fatal_error("pattern returned failure but IR did change");
117  }
118  }
119 
120  void notifyFoldingSuccess() {
121  if (!topLevel)
122  return;
123 
124  // Make sure that the IR still verifies.
125  if (failed(verify(topLevel)))
126  llvm::report_fatal_error("IR failed to verify after folding");
127  }
128 
129 protected:
130  /// Invalidate the finger print of the given op, i.e., remove it from the map.
131  void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
132 
133  void notifyBlockErased(Block *block) override {
135 
136  // The block structure (number of blocks, types of block arguments, etc.)
137  // is part of the fingerprint of the parent op.
138  // TODO: The parent op fingerprint should also be invalidated when modifying
139  // the block arguments of a block, but we do not have a
140  // `notifyBlockModified` callback yet.
141  invalidateFingerPrint(block->getParentOp());
142  }
143 
144  void notifyOperationInserted(Operation *op,
145  OpBuilder::InsertPoint previous) override {
147  invalidateFingerPrint(op->getParentOp());
148  }
149 
150  void notifyOperationModified(Operation *op) override {
152  invalidateFingerPrint(op);
153  }
154 
155  void notifyOperationErased(Operation *op) override {
157  op->walk([this](Operation *op) { invalidateFingerPrint(op); });
158  }
159 
160  /// Operation finger prints to detect invalid pattern API usage. IR is checked
161  /// against these finger prints after pattern application to detect cases
162  /// where IR was modified directly, bypassing the rewriter API.
164 
165  /// Top-level operation of the current greedy rewrite.
166  Operation *topLevel = nullptr;
167 
168  /// Finger print of the top-level operation.
169  std::optional<OperationFingerPrint> topLevelFingerPrint;
170 };
171 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
172 
173 #ifndef NDEBUG
174 static Operation *getDumpRootOp(Operation *op) {
175  // Dump the parent op so that materialized constants are visible. If the op
176  // is a top-level op, dump it directly.
177  if (Operation *parentOp = op->getParentOp())
178  return parentOp;
179  return op;
180 }
181 static void logSuccessfulFolding(Operation *op) {
182  llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
183  op->dump();
184  llvm::dbgs() << "\n\n";
185 }
186 #endif // NDEBUG
187 
188 //===----------------------------------------------------------------------===//
189 // Worklist
190 //===----------------------------------------------------------------------===//
191 
192 /// A LIFO worklist of operations with efficient removal and set semantics.
193 ///
194 /// This class maintains a vector of operations and a mapping of operations to
195 /// positions in the vector, so that operations can be removed efficiently at
196 /// random. When an operation is removed, it is replaced with nullptr. Such
197 /// nullptr are skipped when pop'ing elements.
198 class Worklist {
199 public:
200  Worklist();
201 
202  /// Clear the worklist.
203  void clear();
204 
205  /// Return whether the worklist is empty.
206  bool empty() const;
207 
208  /// Push an operation to the end of the worklist, unless the operation is
209  /// already on the worklist.
210  void push(Operation *op);
211 
212 /// Pop an operation from the end of the worklist. Only allowed on
213  /// non-empty worklists.
214  Operation *pop();
215 
216  /// Remove an operation from the worklist.
217  void remove(Operation *op);
218 
219  /// Reverse the worklist.
220  void reverse();
221 
222 protected:
223  /// The worklist of operations.
224  std::vector<Operation *> list;
225 
226  /// A mapping of operations to positions in `list`.
228 };
229 
230 Worklist::Worklist() { list.reserve(64); }
231 
232 void Worklist::clear() {
233  list.clear();
234  map.clear();
235 }
236 
237 bool Worklist::empty() const {
238  // Skip all nullptr.
239  return !llvm::any_of(list,
240  [](Operation *op) { return static_cast<bool>(op); });
241 }
242 
243 void Worklist::push(Operation *op) {
244  assert(op && "cannot push nullptr to worklist");
245  // Check to see if the worklist already contains this op.
246  if (!map.insert({op, list.size()}).second)
247  return;
248  list.push_back(op);
249 }
250 
251 Operation *Worklist::pop() {
252  assert(!empty() && "cannot pop from empty worklist");
253  // Skip and remove all trailing nullptr.
254  while (!list.back())
255  list.pop_back();
256  Operation *op = list.back();
257  list.pop_back();
258  map.erase(op);
259  // Cleanup: Remove all trailing nullptr.
260  while (!list.empty() && !list.back())
261  list.pop_back();
262  return op;
263 }
264 
265 void Worklist::remove(Operation *op) {
266  assert(op && "cannot remove nullptr from worklist");
267  auto it = map.find(op);
268  if (it != map.end()) {
269  assert(list[it->second] == op && "malformed worklist data structure");
270  list[it->second] = nullptr;
271  map.erase(it);
272  }
273 }
274 
275 void Worklist::reverse() {
276  std::reverse(list.begin(), list.end());
277  for (size_t i = 0, e = list.size(); i != e; ++i)
278  map[list[i]] = i;
279 }
280 
281 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
282 /// A worklist that pops elements at a random position. This worklist is for
283 /// testing/debugging purposes only. It can be used to ensure that lowering
284 /// pipelines work correctly regardless of the order in which ops are processed
285 /// by the GreedyPatternRewriteDriver.
286 class RandomizedWorklist : public Worklist {
287 public:
288  RandomizedWorklist() : Worklist() {
289  generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
290  }
291 
292  /// Pop a random non-empty op from the worklist.
293  Operation *pop() {
294  Operation *op = nullptr;
295  do {
296  assert(!list.empty() && "cannot pop from empty worklist");
297  int64_t pos = generator() % list.size();
298  op = list[pos];
299  list.erase(list.begin() + pos);
300  for (int64_t i = pos, e = list.size(); i < e; ++i)
301  map[list[i]] = i;
302  map.erase(op);
303  } while (!op);
304  return op;
305  }
306 
307 private:
308  std::minstd_rand0 generator;
309 };
310 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
311 
312 //===----------------------------------------------------------------------===//
313 // GreedyPatternRewriteDriver
314 //===----------------------------------------------------------------------===//
315 
316 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
317 /// applies the locally optimal patterns.
318 ///
319 /// This abstract class manages the worklist and contains helper methods for
320 /// rewriting ops on the worklist. Derived classes specify how ops are added
321 /// to the worklist in the beginning.
322 class GreedyPatternRewriteDriver : public RewriterBase::Listener {
323 protected:
324  explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
326  const GreedyRewriteConfig &config);
327 
328  /// Add the given operation to the worklist.
329  void addSingleOpToWorklist(Operation *op);
330 
331  /// Add the given operation and its ancestors to the worklist.
332  void addToWorklist(Operation *op);
333 
334  /// Notify the driver that the specified operation may have been modified
335  /// in-place. The operation is added to the worklist.
336  void notifyOperationModified(Operation *op) override;
337 
338  /// Notify the driver that the specified operation was inserted. Update the
339  /// worklist as needed: The operation is enqueued depending on scope and
340  /// strict mode.
341  void notifyOperationInserted(Operation *op,
342  OpBuilder::InsertPoint previous) override;
343 
344  /// Notify the driver that the specified operation was removed. Update the
345  /// worklist as needed: The operation and its children are removed from the
346  /// worklist.
347  void notifyOperationErased(Operation *op) override;
348 
349  /// Notify the driver that the specified operation was replaced. Update the
350 /// worklist as needed: New users are enqueued.
351  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
352 
353  /// Process ops until the worklist is empty or `config.maxNumRewrites` is
354  /// reached. Return `true` if any IR was changed.
355  bool processWorklist();
356 
357  /// The pattern rewriter that is used for making IR modifications and is
358  /// passed to rewrite patterns.
359  PatternRewriter rewriter;
360 
361  /// The worklist for this transformation keeps track of the operations that
362  /// need to be (re)visited.
363 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364  RandomizedWorklist worklist;
365 #else
366  Worklist worklist;
367 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
368 
369  /// Configuration information for how to simplify.
371 
372  /// The list of ops we are restricting our rewrites to. These include the
373  /// supplied set of ops as well as new ops created while rewriting those ops
374  /// depending on `strictMode`. This set is not maintained when
375  /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
376  llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
377 
378 private:
379  /// Look over the provided operands for any defining operations that should
380  /// be re-added to the worklist. This function should be called when an
381  /// operation is modified or removed, as it may trigger further
382  /// simplifications.
383  void addOperandsToWorklist(Operation *op);
384 
385  /// Notify the driver that the given block was inserted.
386  void notifyBlockInserted(Block *block, Region *previous,
387  Region::iterator previousIt) override;
388 
389  /// Notify the driver that the given block is about to be removed.
390  void notifyBlockErased(Block *block) override;
391 
392  /// For debugging only: Notify the driver of a pattern match failure.
393  void
394  notifyMatchFailure(Location loc,
395  function_ref<void(Diagnostic &)> reasonCallback) override;
396 
397 #ifndef NDEBUG
398  /// A logger used to emit information during the application process.
399  llvm::ScopedPrinter logger{llvm::dbgs()};
400 #endif
401 
402  /// The low-level pattern applicator.
403  PatternApplicator matcher;
404 
405 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
406  ExpensiveChecks expensiveChecks;
407 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
408 };
409 } // namespace
410 
411 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
414  : rewriter(ctx), config(config), matcher(patterns)
415 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
416  // clang-format off
417  , expensiveChecks(
418  /*driver=*/this,
419  /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
420 // clang-format on
421 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
422 {
423  // Apply a simple cost model based solely on pattern benefit.
424  matcher.applyDefaultCostModel();
425 
426  // Set up listener.
427 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
428  // Send IR notifications to the debug handler. This handler will then forward
429  // all notifications to this GreedyPatternRewriteDriver.
430  rewriter.setListener(&expensiveChecks);
431 #else
432  rewriter.setListener(this);
433 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
434 }
435 
436 bool GreedyPatternRewriteDriver::processWorklist() {
437 #ifndef NDEBUG
438  const char *logLineComment =
439  "//===-------------------------------------------===//\n";
440 
441  /// A utility function to log a process result for the given reason.
442  auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
443  logger.unindent();
444  logger.startLine() << "} -> " << result;
445  if (!msg.isTriviallyEmpty())
446  logger.getOStream() << " : " << msg;
447  logger.getOStream() << "\n";
448  };
449  auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
450  logResult(result, msg);
451  logger.startLine() << logLineComment;
452  };
453 #endif
454 
455  bool changed = false;
456  int64_t numRewrites = 0;
457  while (!worklist.empty() &&
458  (numRewrites < config.maxNumRewrites ||
459  config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
460  auto *op = worklist.pop();
461 
462  LLVM_DEBUG({
463  logger.getOStream() << "\n";
464  logger.startLine() << logLineComment;
465  logger.startLine() << "Processing operation : '" << op->getName() << "'("
466  << op << ") {\n";
467  logger.indent();
468 
469  // If the operation has no regions, just print it here.
470  if (op->getNumRegions() == 0) {
471  op->print(
472  logger.startLine(),
473  OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
474  logger.getOStream() << "\n\n";
475  }
476  });
477 
478  // If the operation is trivially dead - remove it.
479  if (isOpTriviallyDead(op)) {
480  rewriter.eraseOp(op);
481  changed = true;
482 
483  LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
484  continue;
485  }
486 
487  // Try to fold this op. Do not fold constant ops. That would lead to an
488  // infinite folding loop, as every constant op would be folded to an
489  // Attribute and then immediately be rematerialized as a constant op, which
490  // is then put on the worklist.
491  if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
492  SmallVector<OpFoldResult> foldResults;
493  if (succeeded(op->fold(foldResults))) {
494  LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
495 #ifndef NDEBUG
496  Operation *dumpRootOp = getDumpRootOp(op);
497 #endif // NDEBUG
498  if (foldResults.empty()) {
499  // Op was modified in-place.
500  notifyOperationModified(op);
501  changed = true;
502  LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
503 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
504  expensiveChecks.notifyFoldingSuccess();
505 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
506  continue;
507  }
508 
509  // Op results can be replaced with `foldResults`.
510  assert(foldResults.size() == op->getNumResults() &&
511  "folder produced incorrect number of results");
512  OpBuilder::InsertionGuard g(rewriter);
513  rewriter.setInsertionPoint(op);
514  SmallVector<Value> replacements;
515  bool materializationSucceeded = true;
516  for (auto [ofr, resultType] :
517  llvm::zip_equal(foldResults, op->getResultTypes())) {
518  if (auto value = ofr.dyn_cast<Value>()) {
519  assert(value.getType() == resultType &&
520  "folder produced value of incorrect type");
521  replacements.push_back(value);
522  continue;
523  }
524  // Materialize Attributes as SSA values.
525  Operation *constOp = op->getDialect()->materializeConstant(
526  rewriter, ofr.get<Attribute>(), resultType, op->getLoc());
527 
528  if (!constOp) {
529  // If materialization fails, cleanup any operations generated for
530  // the previous results.
531  llvm::SmallDenseSet<Operation *> replacementOps;
532  for (Value replacement : replacements) {
533  assert(replacement.use_empty() &&
534  "folder reused existing op for one result but constant "
535  "materialization failed for another result");
536  replacementOps.insert(replacement.getDefiningOp());
537  }
538  for (Operation *op : replacementOps) {
539  rewriter.eraseOp(op);
540  }
541 
542  materializationSucceeded = false;
543  break;
544  }
545 
546  assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
547  "materializeConstant produced op that is not a ConstantLike");
548  assert(constOp->getResultTypes()[0] == resultType &&
549  "materializeConstant produced incorrect result type");
550  replacements.push_back(constOp->getResult(0));
551  }
552 
553  if (materializationSucceeded) {
554  rewriter.replaceOp(op, replacements);
555  changed = true;
556  LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
557 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
558  expensiveChecks.notifyFoldingSuccess();
559 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
560  continue;
561  }
562  }
563  }
564 
565  // Try to match one of the patterns. The rewriter is automatically
566  // notified of any necessary changes, so there is nothing else to do
567  // here.
568  auto canApplyCallback = [&](const Pattern &pattern) {
569  LLVM_DEBUG({
570  logger.getOStream() << "\n";
571  logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
572  << op->getName() << " -> (";
573  llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
574  logger.getOStream() << ")' {\n";
575  logger.indent();
576  });
577  if (config.listener)
578  config.listener->notifyPatternBegin(pattern, op);
579  return true;
580  };
581  function_ref<bool(const Pattern &)> canApply = canApplyCallback;
582  auto onFailureCallback = [&](const Pattern &pattern) {
583  LLVM_DEBUG(logResult("failure", "pattern failed to match"));
584  if (config.listener)
585  config.listener->notifyPatternEnd(pattern, failure());
586  };
587  function_ref<void(const Pattern &)> onFailure = onFailureCallback;
588  auto onSuccessCallback = [&](const Pattern &pattern) {
589  LLVM_DEBUG(logResult("success", "pattern applied successfully"));
590  if (config.listener)
591  config.listener->notifyPatternEnd(pattern, success());
592  return success();
593  };
594  function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
595 
596 #ifdef NDEBUG
597  // Optimization: PatternApplicator callbacks are not needed when running in
598  // optimized mode and without a listener.
599  if (!config.listener) {
600  canApply = nullptr;
601  onFailure = nullptr;
602  onSuccess = nullptr;
603  }
604 #endif // NDEBUG
605 
606 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
607  if (config.scope) {
608  expensiveChecks.computeFingerPrints(config.scope->getParentOp());
609  }
610  auto clearFingerprints =
611  llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
612 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
613 
614  LogicalResult matchResult =
615  matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
616 
617  if (succeeded(matchResult)) {
618  LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
619 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
620  expensiveChecks.notifyRewriteSuccess();
621 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
622  changed = true;
623  ++numRewrites;
624  } else {
625  LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
626 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
627  expensiveChecks.notifyRewriteFailure();
628 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
629  }
630  }
631 
632  return changed;
633 }
634 
635 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
636  assert(op && "expected valid op");
637  // Gather potential ancestors while looking for a "scope" parent region.
638  SmallVector<Operation *, 8> ancestors;
639  Region *region = nullptr;
640  do {
641  ancestors.push_back(op);
642  region = op->getParentRegion();
643  if (config.scope == region) {
644  // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
645  for (Operation *op : ancestors)
646  addSingleOpToWorklist(op);
647  return;
648  }
649  if (region == nullptr)
650  return;
651  } while ((op = region->getParentOp()));
652 }
653 
654 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
655  if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
656  strictModeFilteredOps.contains(op))
657  worklist.push(op);
658 }
659 
660 void GreedyPatternRewriteDriver::notifyBlockInserted(
661  Block *block, Region *previous, Region::iterator previousIt) {
662  if (config.listener)
663  config.listener->notifyBlockInserted(block, previous, previousIt);
664 }
665 
666 void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
667  if (config.listener)
668  config.listener->notifyBlockErased(block);
669 }
670 
671 void GreedyPatternRewriteDriver::notifyOperationInserted(
672  Operation *op, OpBuilder::InsertPoint previous) {
673  LLVM_DEBUG({
674  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
675  << ")\n";
676  });
677  if (config.listener)
678  config.listener->notifyOperationInserted(op, previous);
679  if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
680  strictModeFilteredOps.insert(op);
681  addToWorklist(op);
682 }
683 
684 void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
685  LLVM_DEBUG({
686  logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
687  << ")\n";
688  });
689  if (config.listener)
690  config.listener->notifyOperationModified(op);
691  addToWorklist(op);
692 }
693 
694 void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
695  for (Value operand : op->getOperands()) {
696  // If this operand currently has at most 2 users, add its defining op to the
697  // worklist. Indeed, after the op is deleted, then the operand will have at
698  // most 1 user left. If it has 0 users left, it can be deleted too,
699  // and if it has 1 user left, there may be further canonicalization
700  // opportunities.
701  if (!operand)
702  continue;
703 
704  auto *defOp = operand.getDefiningOp();
705  if (!defOp)
706  continue;
707 
708  Operation *otherUser = nullptr;
709  bool hasMoreThanTwoUses = false;
710  for (auto user : operand.getUsers()) {
711  if (user == op || user == otherUser)
712  continue;
713  if (!otherUser) {
714  otherUser = user;
715  continue;
716  }
717  hasMoreThanTwoUses = true;
718  break;
719  }
720  if (hasMoreThanTwoUses)
721  continue;
722 
723  addToWorklist(defOp);
724  }
725 }
726 
727 void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
728  LLVM_DEBUG({
729  logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
730  << ")\n";
731  });
732 
733 #ifndef NDEBUG
734  // Only ops that are within the configured scope are added to the worklist of
735  // the greedy pattern rewriter. Moreover, the parent op of the scope region is
736  // the part of the IR that is taken into account for the "expensive checks".
737  // A greedy pattern rewrite is not allowed to erase the parent op of the scope
738  // region, as that would break the worklist handling and the expensive checks.
739  if (config.scope && config.scope->getParentOp() == op)
740  llvm_unreachable(
741  "scope region must not be erased during greedy pattern rewrite");
742 #endif // NDEBUG
743 
744  if (config.listener)
745  config.listener->notifyOperationErased(op);
746 
747  addOperandsToWorklist(op);
748  worklist.remove(op);
749 
750  if (config.strictMode != GreedyRewriteStrictness::AnyOp)
751  strictModeFilteredOps.erase(op);
752 }
753 
754 void GreedyPatternRewriteDriver::notifyOperationReplaced(
755  Operation *op, ValueRange replacement) {
756  LLVM_DEBUG({
757  logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
758  << ")\n";
759  });
760  if (config.listener)
761  config.listener->notifyOperationReplaced(op, replacement);
762 }
763 
764 void GreedyPatternRewriteDriver::notifyMatchFailure(
765  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
766  LLVM_DEBUG({
767  Diagnostic diag(loc, DiagnosticSeverity::Remark);
768  reasonCallback(diag);
769  logger.startLine() << "** Match Failure : " << diag.str() << "\n";
770  });
771  if (config.listener)
772  config.listener->notifyMatchFailure(loc, reasonCallback);
773 }
774 
775 //===----------------------------------------------------------------------===//
776 // RegionPatternRewriteDriver
777 //===----------------------------------------------------------------------===//
778 
779 namespace {
780 /// This driver simplifies all ops in a region.
781 class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
782 public:
783  explicit RegionPatternRewriteDriver(MLIRContext *ctx,
786  Region &regions);
787 
788  /// Simplify ops inside `region` and simplify the region itself. Return
789  /// success if the transformation converged.
790  LogicalResult simplify(bool *changed) &&;
791 
792 private:
793  /// The region that is simplified.
794  Region &region;
795 };
796 } // namespace
797 
798 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
800  const GreedyRewriteConfig &config, Region &region)
801  : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
802  // Populate strict mode ops.
803  if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
804  region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
805  }
806 }
807 
808 namespace {
809 class GreedyPatternRewriteIteration
810  : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
811 public:
812  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
813  GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
814  : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
815  iteration(iteration) {}
816  static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
817  void print(raw_ostream &os) const override {
818  os << "GreedyPatternRewriteIteration(" << iteration << ")";
819  }
820 
821 private:
822  int64_t iteration = 0;
823 };
824 } // namespace
825 
826 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
827  bool continueRewrites = false;
828  int64_t iteration = 0;
829  MLIRContext *ctx = rewriter.getContext();
830  do {
831  // Check if the iteration limit was reached.
832  if (++iteration > config.maxIterations &&
833  config.maxIterations != GreedyRewriteConfig::kNoLimit)
834  break;
835 
836  // New iteration: start with an empty worklist.
837  worklist.clear();
838 
839  // `OperationFolder` CSE's constant ops (and may move them into parents
840  // regions to enable more aggressive CSE'ing).
841  OperationFolder folder(ctx, this);
842  auto insertKnownConstant = [&](Operation *op) {
843  // Check for existing constants when populating the worklist. This avoids
844  // accidentally reversing the constant order during processing.
845  Attribute constValue;
846  if (matchPattern(op, m_Constant(&constValue)))
847  if (!folder.insertKnownConstant(op, constValue))
848  return true;
849  return false;
850  };
851 
852  if (!config.useTopDownTraversal) {
853  // Add operations to the worklist in postorder.
854  region.walk([&](Operation *op) {
855  if (!config.cseConstants || !insertKnownConstant(op))
856  addToWorklist(op);
857  });
858  } else {
859  // Add all nested operations to the worklist in preorder.
860  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
861  if (!config.cseConstants || !insertKnownConstant(op)) {
862  addToWorklist(op);
863  return WalkResult::advance();
864  }
865  return WalkResult::skip();
866  });
867 
868  // Reverse the list so our pop-back loop processes them in-order.
869  worklist.reverse();
870  }
871 
872  ctx->executeAction<GreedyPatternRewriteIteration>(
873  [&] {
874  continueRewrites = processWorklist();
875 
876  // After applying patterns, make sure that the CFG of each of the
877  // regions is kept up to date.
878  if (config.enableRegionSimplification !=
879  GreedySimplifyRegionLevel::Disabled) {
880  continueRewrites |= succeeded(simplifyRegions(
881  rewriter, region,
882  /*mergeBlocks=*/config.enableRegionSimplification ==
883  GreedySimplifyRegionLevel::Aggressive));
884  }
885  },
886  {&region}, iteration);
887  } while (continueRewrites);
888 
889  if (changed)
890  *changed = iteration > 1;
891 
892  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
893  return success(!continueRewrites);
894 }
895 
/// Region entry point of the greedy driver: runs a RegionPatternRewriteDriver
/// over `region` and returns the driver's convergence result.
///
/// NOTE(review): this listing jumps from line 896 to line 900, so the rest of
/// the signature is not shown here. The body below uses `region`, `patterns`,
/// `config` and `changed`.
896 LogicalResult
900  // The parent operation of `region` must be known to be isolated from above
901  // to prevent performing canonicalizations on operations defined at or above
902  // `region`. This is only checked in builds with assertions enabled.
903  assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
904  "patterns can only be applied to operations IsolatedFromAbove");
905 
906  // Default the rewrite scope to `region` itself if the caller left it unset.
907  if (!config.scope)
908  config.scope = &region;
909 
     // With expensive checks compiled in, refuse to start on IR that does not
     // verify: the operation owning the scope region is verified and a failure
     // is reported as a fatal error.
910 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
911  if (failed(verify(config.scope->getParentOp())))
912  llvm::report_fatal_error(
913  "greedy pattern rewriter input IR failed to verify");
914 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
915 
916  // Start the pattern driver. `changed` is handed to the driver's `simplify`
     // as-is, and the driver's result is what this function returns.
917  RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
918  region);
919  LogicalResult converged = std::move(driver).simplify(changed);
     // Debug-only diagnostic for the non-converging case.
920  LLVM_DEBUG(if (failed(converged)) {
921  llvm::dbgs() << "The pattern rewrite did not converge after scanning "
922  << config.maxIterations << " times\n";
923  });
924  return converged;
925 }
926 
927 //===----------------------------------------------------------------------===//
928 // MultiOpPatternRewriteDriver
929 //===----------------------------------------------------------------------===//
930 
931 namespace {
932 /// This driver simplifies a list of ops.
933 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
934 public:
     /// Construct the driver. If `survivingOps` is non-null, the constructor
     /// resets that set to the given ops, and ops are removed from it as they
     /// are erased during the rewrite.
     ///
     /// NOTE(review): this listing jumps from line 935 to line 938, so the
     /// leading constructor parameters are not shown here.
935  explicit MultiOpPatternRewriteDriver(
938  llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
939 
940  /// Simplify `ops`. Return `success` if the transformation converged. If
     /// `changed` is non-null, it receives the result of processing the worklist.
941  LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
942 
943 private:
     /// Forward the erasure notification to the base driver, then drop `op` from
     /// the surviving set if one is being tracked.
944  void notifyOperationErased(Operation *op) override {
945  GreedyPatternRewriteDriver::notifyOperationErased(op);
946  if (survivingOps)
947  survivingOps->erase(op);
948  }
949 
950  /// An optional set of ops that survived the rewrite. This set is reset by
951  /// the constructor to the initially provided list of ops; ops are removed
952  /// from it when they are erased.
953  llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
954 };
955 } // namespace
956 
/// Set up a driver for an explicit list of ops.
///
/// NOTE(review): this listing jumps from line 957 to line 960, so the leading
/// parameters are not shown here. The body below uses `ctx`, `patterns`,
/// `config` and `ops`.
957 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
960  llvm::SmallDenseSet<Operation *, 4> *survivingOps)
961  : GreedyPatternRewriteDriver(ctx, patterns, config),
962  survivingOps(survivingOps) {
     // In every strictness mode other than `AnyOp`, seed the strict-mode filter
     // set with the given ops.
     // NOTE(review): presumably the base driver consults `strictModeFilteredOps`
     // to restrict which ops it processes; confirm in GreedyPatternRewriteDriver.
963  if (config.strictMode != GreedyRewriteStrictness::AnyOp)
964  strictModeFilteredOps.insert(ops.begin(), ops.end());
965 
     // If the caller wants to know which ops survive, start from exactly the
     // given ops; any previous content of the set is discarded.
966  if (survivingOps) {
967  survivingOps->clear();
968  survivingOps->insert(ops.begin(), ops.end());
969  }
970 }
971 
/// Put every op of `ops` on the worklist, process the worklist once, and
/// report the outcome. Callable on rvalue drivers only (`&&`-qualified).
///
/// If `changed` is non-null, it receives the value returned by
/// `processWorklist()`. The result is `success` exactly when the worklist is
/// empty afterwards.
972 LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
973  bool *changed) && {
974  // Populate the initial worklist.
975  for (Operation *op : ops)
976  addSingleOpToWorklist(op);
977 
978  // Process ops on the worklist.
979  bool result = processWorklist();
980  if (changed)
981  *changed = result;
982 
     // NOTE(review): a non-empty worklist presumably means processing stopped
     // early (e.g. a rewrite limit was hit); confirm in `processWorklist`.
983  return success(worklist.empty());
984 }
985 
986 /// Find the region that is the closest common ancestor of all given ops.
987 ///
     /// The search starts at the parent region of the first op and walks up
     /// through enclosing regions until every other op is contained in the
     /// current candidate. `ops` must not be empty.
     ///
988 /// Note: This function returns `nullptr` if there is a top-level op among the
989 /// given list of ops.
     ///
     /// NOTE(review): this listing skips line 990, which holds the signature.
991  assert(!ops.empty() && "expected at least one op");
992  // Fast path in case there is only one op.
993  if (ops.size() == 1)
994  return ops.front()->getParentRegion();
995 
     // The first op's parent region is the initial candidate; only the other
     // ops still need to be checked against it.
996  Region *region = ops.front()->getParentRegion();
997  ops = ops.drop_front();
998  int sz = ops.size();
     // One bit per remaining op; a set bit means the op has not yet been found
     // inside the candidate region.
999  llvm::BitVector remainingOps(sz, true);
1000  while (region) {
      // Start at -1 so that the first search begins at index 0.
1001  int pos = -1;
1002  // Iterate over all remaining ops.
1003  while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
1004  // Is this op contained in `region`?
      // Cleared bits are never set again: an op nested under `region` is also
      // nested under every region enclosing it.
1005  if (region->findAncestorOpInRegion(*ops[pos]))
1006  remainingOps.reset(pos);
1007  }
      // Every op is covered, so `region` is the closest common ancestor.
1008  if (remainingOps.none())
1009  break;
      // Otherwise widen the candidate to the enclosing region.
1010  region = region->getParentRegion();
1011  }
      // `region` is null here if the walk ran out of enclosing regions.
1012  return region;
1013 }
1014 
// Op-list entry point of the greedy driver: runs a MultiOpPatternRewriteDriver
// on `ops` and returns the driver's convergence result. `changed` and
// `allErased` are optional outputs and are only written when non-null.
//
// NOTE(review): this listing starts at line 1017; the preceding signature lines
// (return type, function name, leading parameters) are not shown here. The
// body below also uses `ops` and `patterns`.
1017  GreedyRewriteConfig config, bool *changed, bool *allErased) {
      // Nothing to rewrite: report "unchanged" and, vacuously, "all erased".
1018  if (ops.empty()) {
1019  if (changed)
1020  *changed = false;
1021  if (allErased)
1022  *allErased = true;
1023  return success();
1024  }
1025 
1026  // Determine scope of rewrite.
1027  if (!config.scope) {
1028  // Compute scope if none was provided. The scope will remain `nullptr` if
1029  // there is a top-level op among `ops`.
1030  config.scope = findCommonAncestor(ops);
1031  } else {
1032  // If a scope was provided, make sure that all ops are in scope. This check
      // exists only in builds without NDEBUG.
1033 #ifndef NDEBUG
1034  bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
1035  return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
1036  });
1037  assert(allOpsInScope && "ops must be within the specified scope");
1038 #endif // NDEBUG
1039  }
1040 
      // With expensive checks compiled in, verify the operation owning the scope
      // region before rewriting. The check is skipped when the scope is null.
1041 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1042  if (config.scope && failed(verify(config.scope->getParentOp())))
1043  llvm::report_fatal_error(
1044  "greedy pattern rewriter input IR failed to verify");
1045 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1046 
1047  // Start the pattern driver. The surviving-op set is only tracked when the
      // caller asked for `allErased`.
1048  llvm::SmallDenseSet<Operation *, 4> surviving;
1049  MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
1050  config, ops,
1051  allErased ? &surviving : nullptr);
1052  LogicalResult converged = std::move(driver).simplify(ops, changed);
      // All of the given ops were erased iff none of them is left in the set.
1053  if (allErased)
1054  *allErased = surviving.empty();
      // Debug-only diagnostic for the non-converging case.
      // NOTE(review): unlike the matching message in the region overload, this
      // one does not end with a newline.
1055  LLVM_DEBUG(if (failed(converged)) {
1056  llvm::dbgs() << "The pattern rewrite did not converge after "
1057  << config.maxNumRewrites << " rewrites";
1058  });
1059  return converged;
1060 }
static Region * findCommonAncestor(ArrayRef< Operation * > ops)
Find the region that is the closest common ancestor of all given ops.
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:264
This class represents a saved insertion point.
Definition: Builders.h:336
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
A utility class for folding operations, and unifying duplicated constants generated along the way.
Definition: FoldUtils.h:33
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this r...
Definition: Region.cpp:168
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
MLIRContext * getContext()
Return the context this region is inserted in.
Definition: Region.cpp:24
BlockListType::iterator iterator
Definition: Region.h:52
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition: Region.h:285
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
CRTP Implementation of an action.
Definition: Action.h:76
virtual void print(raw_ostream &os) const
Definition: Action.h:50
RewritePatternSet & patterns
Definition: Patterns.h:74
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
@ AnyOp
No restrictions wrt. which ops are processed.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:463
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Definition: PatternMatch.h:469
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:482
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:495
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:478