MLIR  19.0.0git
GreedyPatternRewriteDriver.cpp
Go to the documentation of this file.
1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements mlir::applyPatternsAndFoldGreedily.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "mlir/Config/mlir-config.h"
16 #include "mlir/IR/Action.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Verifier.h"
23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/ScopeExit.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/ScopedPrinter.h"
29 #include "llvm/Support/raw_ostream.h"
30 
31 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
32 #include <random>
33 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
34 
35 using namespace mlir;
36 
37 #define DEBUG_TYPE "greedy-rewriter"
38 
39 namespace {
40 
41 //===----------------------------------------------------------------------===//
42 // Debugging Infrastructure
43 //===----------------------------------------------------------------------===//
44 
45 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
46 /// A helper struct that performs various "expensive checks" to detect broken
47 /// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
48 /// broken if:
49 /// * IR does not verify after pattern application / folding.
50 /// * Pattern returns "failure" but the IR has changed.
51 /// * Pattern returns "success" but the IR has not changed.
52 ///
53 /// This struct stores finger prints of ops to determine whether the IR has
54 /// changed or not.
55 struct ExpensiveChecks : public RewriterBase::ForwardingListener {
56  ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
57  : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
58 
59  /// Compute finger prints of the given op and its nested ops.
60  void computeFingerPrints(Operation *topLevel) {
61  this->topLevel = topLevel;
62  this->topLevelFingerPrint.emplace(topLevel);
63  topLevel->walk([&](Operation *op) {
64  fingerprints.try_emplace(op, op, /*includeNested=*/false);
65  });
66  }
67 
68  /// Clear all finger prints.
69  void clear() {
70  topLevel = nullptr;
71  topLevelFingerPrint.reset();
72  fingerprints.clear();
73  }
74 
75  void notifyRewriteSuccess() {
76  if (!topLevel)
77  return;
78 
79  // Make sure that the IR still verifies.
80  if (failed(verify(topLevel)))
81  llvm::report_fatal_error("IR failed to verify after pattern application");
82 
83  // Pattern application success => IR must have changed.
84  OperationFingerPrint afterFingerPrint(topLevel);
85  if (*topLevelFingerPrint == afterFingerPrint) {
86  // Note: Run "mlir-opt -debug" to see which pattern is broken.
87  llvm::report_fatal_error(
88  "pattern returned success but IR did not change");
89  }
90  for (const auto &it : fingerprints) {
91  // Skip top-level op, its finger print is never invalidated.
92  if (it.first == topLevel)
93  continue;
94  // Note: Finger print computation may crash when an op was erased
95  // without notifying the rewriter. (Run with ASAN to see where the op was
96  // erased; the op was probably erased directly, bypassing the rewriter
97  // API.) Finger print computation does may not crash if a new op was
98  // created at the same memory location. (But then the finger print should
99  // have changed.)
100  if (it.second !=
101  OperationFingerPrint(it.first, /*includeNested=*/false)) {
102  // Note: Run "mlir-opt -debug" to see which pattern is broken.
103  llvm::report_fatal_error("operation finger print changed");
104  }
105  }
106  }
107 
108  void notifyRewriteFailure() {
109  if (!topLevel)
110  return;
111 
112  // Pattern application failure => IR must not have changed.
113  OperationFingerPrint afterFingerPrint(topLevel);
114  if (*topLevelFingerPrint != afterFingerPrint) {
115  // Note: Run "mlir-opt -debug" to see which pattern is broken.
116  llvm::report_fatal_error("pattern returned failure but IR did change");
117  }
118  }
119 
120  void notifyFoldingSuccess() {
121  if (!topLevel)
122  return;
123 
124  // Make sure that the IR still verifies.
125  if (failed(verify(topLevel)))
126  llvm::report_fatal_error("IR failed to verify after folding");
127  }
128 
129 protected:
130  /// Invalidate the finger print of the given op, i.e., remove it from the map.
131  void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
132 
133  void notifyBlockErased(Block *block) override {
135 
136  // The block structure (number of blocks, types of block arguments, etc.)
137  // is part of the fingerprint of the parent op.
138  // TODO: The parent op fingerprint should also be invalidated when modifying
139  // the block arguments of a block, but we do not have a
140  // `notifyBlockModified` callback yet.
141  invalidateFingerPrint(block->getParentOp());
142  }
143 
144  void notifyOperationInserted(Operation *op,
145  OpBuilder::InsertPoint previous) override {
147  invalidateFingerPrint(op->getParentOp());
148  }
149 
150  void notifyOperationModified(Operation *op) override {
152  invalidateFingerPrint(op);
153  }
154 
155  void notifyOperationErased(Operation *op) override {
157  op->walk([this](Operation *op) { invalidateFingerPrint(op); });
158  }
159 
160  /// Operation finger prints to detect invalid pattern API usage. IR is checked
161  /// against these finger prints after pattern application to detect cases
162  /// where IR was modified directly, bypassing the rewriter API.
164 
165  /// Top-level operation of the current greedy rewrite.
166  Operation *topLevel = nullptr;
167 
168  /// Finger print of the top-level operation.
169  std::optional<OperationFingerPrint> topLevelFingerPrint;
170 };
171 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
172 
173 #ifndef NDEBUG
174 static Operation *getDumpRootOp(Operation *op) {
175  // Dump the parent op so that materialized constants are visible. If the op
176  // is a top-level op, dump it directly.
177  if (Operation *parentOp = op->getParentOp())
178  return parentOp;
179  return op;
180 }
181 static void logSuccessfulFolding(Operation *op) {
182  llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
183  op->dump();
184  llvm::dbgs() << "\n\n";
185 }
186 #endif // NDEBUG
187 
188 //===----------------------------------------------------------------------===//
189 // Worklist
190 //===----------------------------------------------------------------------===//
191 
192 /// A LIFO worklist of operations with efficient removal and set semantics.
193 ///
194 /// This class maintains a vector of operations and a mapping of operations to
195 /// positions in the vector, so that operations can be removed efficiently at
196 /// random. When an operation is removed, it is replaced with nullptr. Such
197 /// nullptr are skipped when pop'ing elements.
198 class Worklist {
199 public:
200  Worklist();
201 
202  /// Clear the worklist.
203  void clear();
204 
205  /// Return whether the worklist is empty.
206  bool empty() const;
207 
208  /// Push an operation to the end of the worklist, unless the operation is
209  /// already on the worklist.
210  void push(Operation *op);
211 
212  /// Pop the an operation from the end of the worklist. Only allowed on
213  /// non-empty worklists.
214  Operation *pop();
215 
216  /// Remove an operation from the worklist.
217  void remove(Operation *op);
218 
219  /// Reverse the worklist.
220  void reverse();
221 
222 protected:
223  /// The worklist of operations.
224  std::vector<Operation *> list;
225 
226  /// A mapping of operations to positions in `list`.
228 };
229 
230 Worklist::Worklist() { list.reserve(64); }
231 
232 void Worklist::clear() {
233  list.clear();
234  map.clear();
235 }
236 
237 bool Worklist::empty() const {
238  // Skip all nullptr.
239  return !llvm::any_of(list,
240  [](Operation *op) { return static_cast<bool>(op); });
241 }
242 
243 void Worklist::push(Operation *op) {
244  assert(op && "cannot push nullptr to worklist");
245  // Check to see if the worklist already contains this op.
246  if (!map.insert({op, list.size()}).second)
247  return;
248  list.push_back(op);
249 }
250 
251 Operation *Worklist::pop() {
252  assert(!empty() && "cannot pop from empty worklist");
253  // Skip and remove all trailing nullptr.
254  while (!list.back())
255  list.pop_back();
256  Operation *op = list.back();
257  list.pop_back();
258  map.erase(op);
259  // Cleanup: Remove all trailing nullptr.
260  while (!list.empty() && !list.back())
261  list.pop_back();
262  return op;
263 }
264 
265 void Worklist::remove(Operation *op) {
266  assert(op && "cannot remove nullptr from worklist");
267  auto it = map.find(op);
268  if (it != map.end()) {
269  assert(list[it->second] == op && "malformed worklist data structure");
270  list[it->second] = nullptr;
271  map.erase(it);
272  }
273 }
274 
275 void Worklist::reverse() {
276  std::reverse(list.begin(), list.end());
277  for (size_t i = 0, e = list.size(); i != e; ++i)
278  map[list[i]] = i;
279 }
280 
281 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
282 /// A worklist that pops elements at a random position. This worklist is for
283 /// testing/debugging purposes only. It can be used to ensure that lowering
284 /// pipelines work correctly regardless of the order in which ops are processed
285 /// by the GreedyPatternRewriteDriver.
286 class RandomizedWorklist : public Worklist {
287 public:
288  RandomizedWorklist() : Worklist() {
289  generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
290  }
291 
292  /// Pop a random non-empty op from the worklist.
293  Operation *pop() {
294  Operation *op = nullptr;
295  do {
296  assert(!list.empty() && "cannot pop from empty worklist");
297  int64_t pos = generator() % list.size();
298  op = list[pos];
299  list.erase(list.begin() + pos);
300  for (int64_t i = pos, e = list.size(); i < e; ++i)
301  map[list[i]] = i;
302  map.erase(op);
303  } while (!op);
304  return op;
305  }
306 
307 private:
308  std::minstd_rand0 generator;
309 };
310 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
311 
312 //===----------------------------------------------------------------------===//
313 // GreedyPatternRewriteDriver
314 //===----------------------------------------------------------------------===//
315 
316 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
317 /// applies the locally optimal patterns.
318 ///
319 /// This abstract class manages the worklist and contains helper methods for
320 /// rewriting ops on the worklist. Derived classes specify how ops are added
321 /// to the worklist in the beginning.
322 class GreedyPatternRewriteDriver : public RewriterBase::Listener {
323 protected:
324  explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
325  const FrozenRewritePatternSet &patterns,
326  const GreedyRewriteConfig &config);
327 
328  /// Add the given operation to the worklist.
329  void addSingleOpToWorklist(Operation *op);
330 
331  /// Add the given operation and its ancestors to the worklist.
332  void addToWorklist(Operation *op);
333 
334  /// Notify the driver that the specified operation may have been modified
335  /// in-place. The operation is added to the worklist.
336  void notifyOperationModified(Operation *op) override;
337 
338  /// Notify the driver that the specified operation was inserted. Update the
339  /// worklist as needed: The operation is enqueued depending on scope and
340  /// strict mode.
341  void notifyOperationInserted(Operation *op,
342  OpBuilder::InsertPoint previous) override;
343 
344  /// Notify the driver that the specified operation was removed. Update the
345  /// worklist as needed: The operation and its children are removed from the
346  /// worklist.
347  void notifyOperationErased(Operation *op) override;
348 
349  /// Notify the driver that the specified operation was replaced. Update the
350  /// worklist as needed: New users are added enqueued.
351  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
352 
353  /// Process ops until the worklist is empty or `config.maxNumRewrites` is
354  /// reached. Return `true` if any IR was changed.
355  bool processWorklist();
356 
357  /// The pattern rewriter that is used for making IR modifications and is
358  /// passed to rewrite patterns.
359  PatternRewriter rewriter;
360 
361  /// The worklist for this transformation keeps track of the operations that
362  /// need to be (re)visited.
363 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364  RandomizedWorklist worklist;
365 #else
366  Worklist worklist;
367 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
368 
369  /// Configuration information for how to simplify.
370  const GreedyRewriteConfig config;
371 
372  /// The list of ops we are restricting our rewrites to. These include the
373  /// supplied set of ops as well as new ops created while rewriting those ops
374  /// depending on `strictMode`. This set is not maintained when
375  /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
376  llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
377 
378 private:
379  /// Look over the provided operands for any defining operations that should
380  /// be re-added to the worklist. This function should be called when an
381  /// operation is modified or removed, as it may trigger further
382  /// simplifications.
383  void addOperandsToWorklist(Operation *op);
384 
385  /// Notify the driver that the given block was inserted.
386  void notifyBlockInserted(Block *block, Region *previous,
387  Region::iterator previousIt) override;
388 
389  /// Notify the driver that the given block is about to be removed.
390  void notifyBlockErased(Block *block) override;
391 
392  /// For debugging only: Notify the driver of a pattern match failure.
393  void
394  notifyMatchFailure(Location loc,
395  function_ref<void(Diagnostic &)> reasonCallback) override;
396 
397 #ifndef NDEBUG
398  /// A logger used to emit information during the application process.
399  llvm::ScopedPrinter logger{llvm::dbgs()};
400 #endif
401 
402  /// The low-level pattern applicator.
403  PatternApplicator matcher;
404 
405 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
406  ExpensiveChecks expensiveChecks;
407 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
408 };
409 } // namespace
410 
411 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
412  MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
413  const GreedyRewriteConfig &config)
414  : rewriter(ctx), config(config), matcher(patterns)
415 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
416  // clang-format off
417  , expensiveChecks(
418  /*driver=*/this,
419  /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
420 // clang-format on
421 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
422 {
423  // Apply a simple cost model based solely on pattern benefit.
424  matcher.applyDefaultCostModel();
425 
426  // Set up listener.
427 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
428  // Send IR notifications to the debug handler. This handler will then forward
429  // all notifications to this GreedyPatternRewriteDriver.
430  rewriter.setListener(&expensiveChecks);
431 #else
432  rewriter.setListener(this);
433 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
434 }
435 
/// Pop ops off the worklist one at a time until it is empty or the rewrite
/// budget (`config.maxNumRewrites`, unless it is `kNoLimit`) is used up. Each
/// popped op goes through up to three stages:
///   1. If it is trivially dead, it is erased.
///   2. Unless it is ConstantLike, folding is attempted; fold results that are
///      Attributes are materialized as constant ops.
///   3. Otherwise it is handed to the pattern applicator.
/// Returns `true` if any stage changed the IR.
436 bool GreedyPatternRewriteDriver::processWorklist() {
437 #ifndef NDEBUG
438  const char *logLineComment =
439  "//===-------------------------------------------===//\n";
440 
441  /// A utility function to log a process result for the given reason.
442  auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
443  logger.unindent();
444  logger.startLine() << "} -> " << result;
445  if (!msg.isTriviallyEmpty())
446  logger.getOStream() << " : " << msg;
447  logger.getOStream() << "\n";
448  };
 /// Same as `logResult`, followed by a separator line.
449  auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
450  logResult(result, msg);
451  logger.startLine() << logLineComment;
452  };
453 #endif
454 
 // `changed` records any IR mutation. `numRewrites` is only incremented after
 // a successful pattern application (see the end of the loop), so erasing dead
 // ops and folding do not count towards `config.maxNumRewrites`.
455  bool changed = false;
456  int64_t numRewrites = 0;
457  while (!worklist.empty() &&
458  (numRewrites < config.maxNumRewrites ||
459  config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
460  auto *op = worklist.pop();
461 
462  LLVM_DEBUG({
463  logger.getOStream() << "\n";
464  logger.startLine() << logLineComment;
465  logger.startLine() << "Processing operation : '" << op->getName() << "'("
466  << op << ") {\n";
467  logger.indent();
468 
469  // If the operation has no regions, just print it here.
470  if (op->getNumRegions() == 0) {
471  op->print(
472  logger.startLine(),
473  OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
474  logger.getOStream() << "\n\n";
475  }
476  });
477 
478  // If the operation is trivially dead - remove it.
479  if (isOpTriviallyDead(op)) {
480  rewriter.eraseOp(op);
481  changed = true;
482 
483  LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
484  continue;
485  }
486 
487  // Try to fold this op. Do not fold constant ops. That would lead to an
488  // infinite folding loop, as every constant op would be folded to an
489  // Attribute and then immediately be rematerialized as a constant op, which
490  // is then put on the worklist.
491  if (!op->hasTrait<OpTrait::ConstantLike>()) {
492  SmallVector<OpFoldResult> foldResults;
493  if (succeeded(op->fold(foldResults))) {
494  LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
495 #ifndef NDEBUG
 // Determined before the op is replaced below; only used for debug dumps.
496  Operation *dumpRootOp = getDumpRootOp(op);
497 #endif // NDEBUG
498  if (foldResults.empty()) {
499  // Op was modified in-place.
500  notifyOperationModified(op);
501  changed = true;
502  LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
503 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
504  expensiveChecks.notifyFoldingSuccess();
505 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
506  continue;
507  }
508 
509  // Op results can be replaced with `foldResults`.
510  assert(foldResults.size() == op->getNumResults() &&
511  "folder produced incorrect number of results");
 // Materialized constants are created right before `op`.
512  OpBuilder::InsertionGuard g(rewriter);
513  rewriter.setInsertionPoint(op);
514  SmallVector<Value> replacements;
515  bool materializationSucceeded = true;
516  for (auto [ofr, resultType] :
517  llvm::zip_equal(foldResults, op->getResultTypes())) {
 // A fold result that already is a Value is used as-is.
518  if (auto value = ofr.dyn_cast<Value>()) {
519  assert(value.getType() == resultType &&
520  "folder produced value of incorrect type");
521  replacements.push_back(value);
522  continue;
523  }
524  // Materialize Attributes as SSA values.
525  Operation *constOp = op->getDialect()->materializeConstant(
526  rewriter, ofr.get<Attribute>(), resultType, op->getLoc());
527 
528  if (!constOp) {
529  // If materialization fails, cleanup any operations generated for
530  // the previous results.
531  llvm::SmallDenseSet<Operation *> replacementOps;
532  for (Value replacement : replacements) {
533  assert(replacement.use_empty() &&
534  "folder reused existing op for one result but constant "
535  "materialization failed for another result");
536  replacementOps.insert(replacement.getDefiningOp());
537  }
 // Note: this loop variable shadows the `op` that is being processed.
538  for (Operation *op : replacementOps) {
539  rewriter.eraseOp(op);
540  }
541 
542  materializationSucceeded = false;
543  break;
544  }
545 
546  assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
547  "materializeConstant produced op that is not a ConstantLike");
548  assert(constOp->getResultTypes()[0] == resultType &&
549  "materializeConstant produced incorrect result type");
550  replacements.push_back(constOp->getResult(0));
551  }
552 
 // If materialization failed, fall through to pattern matching below.
553  if (materializationSucceeded) {
554  rewriter.replaceOp(op, replacements);
555  changed = true;
556  LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
557 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
558  expensiveChecks.notifyFoldingSuccess();
559 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
560  continue;
561  }
562  }
563  }
564 
565  // Try to match one of the patterns. The rewriter is automatically
566  // notified of any necessary changes, so there is nothing else to do
567  // here.
 // The three callbacks below log in debug builds and relay pattern begin/end
 // events to `config.listener`, if one is set.
568  auto canApplyCallback = [&](const Pattern &pattern) {
569  LLVM_DEBUG({
570  logger.getOStream() << "\n";
571  logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
572  << op->getName() << " -> (";
573  llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
574  logger.getOStream() << ")' {\n";
575  logger.indent();
576  });
577  if (config.listener)
578  config.listener->notifyPatternBegin(pattern, op);
579  return true;
580  };
581  function_ref<bool(const Pattern &)> canApply = canApplyCallback;
582  auto onFailureCallback = [&](const Pattern &pattern) {
583  LLVM_DEBUG(logResult("failure", "pattern failed to match"));
584  if (config.listener)
585  config.listener->notifyPatternEnd(pattern, failure());
586  };
587  function_ref<void(const Pattern &)> onFailure = onFailureCallback;
588  auto onSuccessCallback = [&](const Pattern &pattern) {
589  LLVM_DEBUG(logResult("success", "pattern applied successfully"));
590  if (config.listener)
591  config.listener->notifyPatternEnd(pattern, success());
592  return success();
593  };
594  function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
595 
596 #ifdef NDEBUG
597  // Optimization: PatternApplicator callbacks are not needed when running in
598  // optimized mode and without a listener.
599  if (!config.listener) {
600  canApply = nullptr;
601  onFailure = nullptr;
602  onSuccess = nullptr;
603  }
604 #endif // NDEBUG
605 
606 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 // Take a snapshot of the IR right before the patterns run.
607  if (config.scope) {
608  expensiveChecks.computeFingerPrints(config.scope->getParentOp());
609  }
 // Declared inside the loop body, so the finger prints are cleared again when
 // this iteration ends.
610  auto clearFingerprints =
611  llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
612 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
613 
614  LogicalResult matchResult =
615  matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
616 
617  if (succeeded(matchResult)) {
618  LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
619 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
620  expensiveChecks.notifyRewriteSuccess();
621 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
622  changed = true;
 // Only successful pattern applications count towards the rewrite budget.
623  ++numRewrites;
624  } else {
625  LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
626 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
627  expensiveChecks.notifyRewriteFailure();
628 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
629  }
630  }
631 
632  return changed;
633 }
634 
635 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
636  assert(op && "expected valid op");
637  // Gather potential ancestors while looking for a "scope" parent region.
638  SmallVector<Operation *, 8> ancestors;
639  Region *region = nullptr;
640  do {
641  ancestors.push_back(op);
642  region = op->getParentRegion();
643  if (config.scope == region) {
644  // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
645  for (Operation *op : ancestors)
646  addSingleOpToWorklist(op);
647  return;
648  }
649  if (region == nullptr)
650  return;
651  } while ((op = region->getParentOp()));
652 }
653 
654 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
655  if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
656  strictModeFilteredOps.contains(op))
657  worklist.push(op);
658 }
659 
660 void GreedyPatternRewriteDriver::notifyBlockInserted(
661  Block *block, Region *previous, Region::iterator previousIt) {
662  if (config.listener)
663  config.listener->notifyBlockInserted(block, previous, previousIt);
664 }
665 
666 void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
667  if (config.listener)
668  config.listener->notifyBlockErased(block);
669 }
670 
671 void GreedyPatternRewriteDriver::notifyOperationInserted(
672  Operation *op, OpBuilder::InsertPoint previous) {
673  LLVM_DEBUG({
674  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
675  << ")\n";
676  });
677  if (config.listener)
678  config.listener->notifyOperationInserted(op, previous);
679  if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
680  strictModeFilteredOps.insert(op);
681  addToWorklist(op);
682 }
683 
684 void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
685  LLVM_DEBUG({
686  logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
687  << ")\n";
688  });
689  if (config.listener)
690  config.listener->notifyOperationModified(op);
691  addToWorklist(op);
692 }
693 
694 void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
695  for (Value operand : op->getOperands()) {
696  // If this operand currently has at most 2 users, add its defining op to the
697  // worklist. Indeed, after the op is deleted, then the operand will have at
698  // most 1 user left. If it has 0 users left, it can be deleted too,
699  // and if it has 1 user left, there may be further canonicalization
700  // opportunities.
701  if (!operand)
702  continue;
703 
704  auto *defOp = operand.getDefiningOp();
705  if (!defOp)
706  continue;
707 
708  Operation *otherUser = nullptr;
709  bool hasMoreThanTwoUses = false;
710  for (auto user : operand.getUsers()) {
711  if (user == op || user == otherUser)
712  continue;
713  if (!otherUser) {
714  otherUser = user;
715  continue;
716  }
717  hasMoreThanTwoUses = true;
718  break;
719  }
720  if (hasMoreThanTwoUses)
721  continue;
722 
723  addToWorklist(defOp);
724  }
725 }
726 
727 void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
728  LLVM_DEBUG({
729  logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
730  << ")\n";
731  });
732 
733 #ifndef NDEBUG
734  // Only ops that are within the configured scope are added to the worklist of
735  // the greedy pattern rewriter. Moreover, the parent op of the scope region is
736  // the part of the IR that is taken into account for the "expensive checks".
737  // A greedy pattern rewrite is not allowed to erase the parent op of the scope
738  // region, as that would break the worklist handling and the expensive checks.
739  if (config.scope && config.scope->getParentOp() == op)
740  llvm_unreachable(
741  "scope region must not be erased during greedy pattern rewrite");
742 #endif // NDEBUG
743 
744  if (config.listener)
745  config.listener->notifyOperationErased(op);
746 
747  addOperandsToWorklist(op);
748  worklist.remove(op);
749 
750  if (config.strictMode != GreedyRewriteStrictness::AnyOp)
751  strictModeFilteredOps.erase(op);
752 }
753 
754 void GreedyPatternRewriteDriver::notifyOperationReplaced(
755  Operation *op, ValueRange replacement) {
756  LLVM_DEBUG({
757  logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
758  << ")\n";
759  });
760  if (config.listener)
761  config.listener->notifyOperationReplaced(op, replacement);
762 }
763 
764 void GreedyPatternRewriteDriver::notifyMatchFailure(
765  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
766  LLVM_DEBUG({
767  Diagnostic diag(loc, DiagnosticSeverity::Remark);
768  reasonCallback(diag);
769  logger.startLine() << "** Match Failure : " << diag.str() << "\n";
770  });
771  if (config.listener)
772  config.listener->notifyMatchFailure(loc, reasonCallback);
773 }
774 
775 //===----------------------------------------------------------------------===//
776 // RegionPatternRewriteDriver
777 //===----------------------------------------------------------------------===//
778 
779 namespace {
780 /// This driver simplfies all ops in a region.
781 class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
782 public:
783  explicit RegionPatternRewriteDriver(MLIRContext *ctx,
784  const FrozenRewritePatternSet &patterns,
785  const GreedyRewriteConfig &config,
786  Region &regions);
787 
788  /// Simplify ops inside `region` and simplify the region itself. Return
789  /// success if the transformation converged.
790  LogicalResult simplify(bool *changed) &&;
791 
792 private:
793  /// The region that is simplified.
794  Region &region;
795 };
796 } // namespace
797 
798 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
799  MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
800  const GreedyRewriteConfig &config, Region &region)
801  : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
802  // Populate strict mode ops.
803  if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
804  region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
805  }
806 }
807 
808 namespace {
809 class GreedyPatternRewriteIteration
810  : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
811 public:
812  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
813  GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
814  : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
815  iteration(iteration) {}
816  static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
817  void print(raw_ostream &os) const override {
818  os << "GreedyPatternRewriteIteration(" << iteration << ")";
819  }
820 
821 private:
822  int64_t iteration = 0;
823 };
824 } // namespace
825 
826 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
827  bool continueRewrites = false;
828  int64_t iteration = 0;
829  MLIRContext *ctx = rewriter.getContext();
830  do {
831  // Check if the iteration limit was reached.
832  if (++iteration > config.maxIterations &&
833  config.maxIterations != GreedyRewriteConfig::kNoLimit)
834  break;
835 
836  // New iteration: start with an empty worklist.
837  worklist.clear();
838 
839  // `OperationFolder` CSE's constant ops (and may move them into parents
840  // regions to enable more aggressive CSE'ing).
841  OperationFolder folder(ctx, this);
842  auto insertKnownConstant = [&](Operation *op) {
843  // Check for existing constants when populating the worklist. This avoids
844  // accidentally reversing the constant order during processing.
845  Attribute constValue;
846  if (matchPattern(op, m_Constant(&constValue)))
847  if (!folder.insertKnownConstant(op, constValue))
848  return true;
849  return false;
850  };
851 
852  if (!config.useTopDownTraversal) {
853  // Add operations to the worklist in postorder.
854  region.walk([&](Operation *op) {
855  if (!insertKnownConstant(op))
856  addToWorklist(op);
857  });
858  } else {
859  // Add all nested operations to the worklist in preorder.
860  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
861  if (!insertKnownConstant(op)) {
862  addToWorklist(op);
863  return WalkResult::advance();
864  }
865  return WalkResult::skip();
866  });
867 
868  // Reverse the list so our pop-back loop processes them in-order.
869  worklist.reverse();
870  }
871 
872  ctx->executeAction<GreedyPatternRewriteIteration>(
873  [&] {
874  continueRewrites = processWorklist();
875 
876  // After applying patterns, make sure that the CFG of each of the
877  // regions is kept up to date.
878  if (config.enableRegionSimplification !=
879  GreedySimplifyRegionLevel::Disabled) {
880  continueRewrites |= succeeded(simplifyRegions(
881  rewriter, region,
882  /*mergeBlocks=*/config.enableRegionSimplification ==
883  GreedySimplifyRegionLevel::Aggressive));
884  }
885  },
886  {&region}, iteration);
887  } while (continueRewrites);
888 
889  if (changed)
890  *changed = iteration > 1;
891 
892  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
893  return success(!continueRewrites);
894 }
895 
896 LogicalResult
898  const FrozenRewritePatternSet &patterns,
899  GreedyRewriteConfig config, bool *changed) {
900  // The top-level operation must be known to be isolated from above to
901  // prevent performing canonicalizations on operations defined at or above
902  // the region containing 'op'.
903  assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
904  "patterns can only be applied to operations IsolatedFromAbove");
905 
906  // Set scope if not specified.
907  if (!config.scope)
908  config.scope = &region;
909 
910 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
911  if (failed(verify(config.scope->getParentOp())))
912  llvm::report_fatal_error(
913  "greedy pattern rewriter input IR failed to verify");
914 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
915 
916  // Start the pattern driver.
917  RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
918  region);
919  LogicalResult converged = std::move(driver).simplify(changed);
920  LLVM_DEBUG(if (failed(converged)) {
921  llvm::dbgs() << "The pattern rewrite did not converge after scanning "
922  << config.maxIterations << " times\n";
923  });
924  return converged;
925 }
926 
927 //===----------------------------------------------------------------------===//
928 // MultiOpPatternRewriteDriver
929 //===----------------------------------------------------------------------===//
930 
931 namespace {
932 /// This driver simplfies a list of ops.
933 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
934 public:
935  explicit MultiOpPatternRewriteDriver(
936  MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
937  const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
938  llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
939 
940  /// Simplify `ops`. Return `success` if the transformation converged.
941  LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
942 
943 private:
944  void notifyOperationErased(Operation *op) override {
945  GreedyPatternRewriteDriver::notifyOperationErased(op);
946  if (survivingOps)
947  survivingOps->erase(op);
948  }
949 
950  /// An optional set of ops that survived the rewrite. This set is populated
951  /// at the beginning of `simplifyLocally` with the inititally provided list
952  /// of ops.
953  llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
954 };
955 } // namespace
956 
957 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
958  MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
959  const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
960  llvm::SmallDenseSet<Operation *, 4> *survivingOps)
961  : GreedyPatternRewriteDriver(ctx, patterns, config),
962  survivingOps(survivingOps) {
964  strictModeFilteredOps.insert(ops.begin(), ops.end());
965 
966  if (survivingOps) {
967  survivingOps->clear();
968  survivingOps->insert(ops.begin(), ops.end());
969  }
970 }
971 
972 LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
973  bool *changed) && {
974  // Populate the initial worklist.
975  for (Operation *op : ops)
976  addSingleOpToWorklist(op);
977 
978  // Process ops on the worklist.
979  bool result = processWorklist();
980  if (changed)
981  *changed = result;
982 
983  return success(worklist.empty());
984 }
985 
986 /// Find the region that is the closest common ancestor of all given ops.
987 ///
988 /// Note: This function returns `nullptr` if there is a top-level op among the
989 /// given list of ops.
991  assert(!ops.empty() && "expected at least one op");
992  // Fast path in case there is only one op.
993  if (ops.size() == 1)
994  return ops.front()->getParentRegion();
995 
996  Region *region = ops.front()->getParentRegion();
997  ops = ops.drop_front();
998  int sz = ops.size();
999  llvm::BitVector remainingOps(sz, true);
1000  while (region) {
1001  int pos = -1;
1002  // Iterate over all remaining ops.
1003  while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
1004  // Is this op contained in `region`?
1005  if (region->findAncestorOpInRegion(*ops[pos]))
1006  remainingOps.reset(pos);
1007  }
1008  if (remainingOps.none())
1009  break;
1010  region = region->getParentRegion();
1011  }
1012  return region;
1013 }
1014 
1016  ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
1017  GreedyRewriteConfig config, bool *changed, bool *allErased) {
1018  if (ops.empty()) {
1019  if (changed)
1020  *changed = false;
1021  if (allErased)
1022  *allErased = true;
1023  return success();
1024  }
1025 
1026  // Determine scope of rewrite.
1027  if (!config.scope) {
1028  // Compute scope if none was provided. The scope will remain `nullptr` if
1029  // there is a top-level op among `ops`.
1030  config.scope = findCommonAncestor(ops);
1031  } else {
1032  // If a scope was provided, make sure that all ops are in scope.
1033 #ifndef NDEBUG
1034  bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
1035  return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
1036  });
1037  assert(allOpsInScope && "ops must be within the specified scope");
1038 #endif // NDEBUG
1039  }
1040 
1041 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1042  if (config.scope && failed(verify(config.scope->getParentOp())))
1043  llvm::report_fatal_error(
1044  "greedy pattern rewriter input IR failed to verify");
1045 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1046 
1047  // Start the pattern driver.
1048  llvm::SmallDenseSet<Operation *, 4> surviving;
1049  MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
1050  config, ops,
1051  allErased ? &surviving : nullptr);
1052  LogicalResult converged = std::move(driver).simplify(ops, changed);
1053  if (allErased)
1054  *allErased = surviving.empty();
1055  LLVM_DEBUG(if (failed(converged)) {
1056  llvm::dbgs() << "The pattern rewrite did not converge after "
1057  << config.maxNumRewrites << " rewrites";
1058  });
1059  return converged;
1060 }
static Region * findCommonAncestor(ArrayRef< Operation * > ops)
Find the region that is the closest common ancestor of all given ops.
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
int64_t maxIterations
This specifies the maximum number of times the rewriter will iterate between applying patterns and si...
Region * scope
Only ops within the scope are added to the worklist.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriter's worklist.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
int64_t maxNumRewrites
This specifies the maximum number of rewrites within an iteration.
GreedySimplifyRegionLevel enableRegionSimplification
Perform control flow optimizations to the region tree after applying all patterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:259
This class represents a saved insertion point.
Definition: Builders.h:329
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
A utility class for folding operations, and unifying duplicated constants generated along the way.
Definition: FoldUtils.h:33
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:632
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this r...
Definition: Region.cpp:168
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
MLIRContext * getContext()
Return the context this region is inserted in.
Definition: Region.cpp:24
BlockListType::iterator iterator
Definition: Region.h:52
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition: Region.h:285
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
CRTP Implementation of an action.
Definition: Action.h:76
virtual void print(raw_ostream &os) const
Definition: Action.h:50
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
@ AnyOp
No restrictions wrt. which ops are processed.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:310
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:300
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:463
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Definition: PatternMatch.h:466
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:477
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:490
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:473
virtual void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
Definition: PatternMatch.h:454
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:411
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:435
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:419
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:408
virtual void notifyPatternEnd(const Pattern &pattern, LogicalResult status)
Notify the listener that a pattern application finished with the specified status.
Definition: PatternMatch.h:446
virtual void notifyPatternBegin(const Pattern &pattern, Operation *op)
Notify the listener that the specified pattern is about to be applied at the specified root operation...
Definition: PatternMatch.h:439