MLIR  18.0.0git
GreedyPatternRewriteDriver.cpp
Go to the documentation of this file.
1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements mlir::applyPatternsAndFoldGreedily.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "mlir/Config/mlir-config.h"
16 #include "mlir/IR/Action.h"
17 #include "mlir/IR/Matchers.h"
22 #include "llvm/ADT/BitVector.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/ScopedPrinter.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
31 #include <random>
32 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
33 
34 using namespace mlir;
35 
36 #define DEBUG_TYPE "greedy-rewriter"
37 
38 namespace {
39 
//===----------------------------------------------------------------------===//
// Debugging Infrastructure
//===----------------------------------------------------------------------===//

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// A helper struct that stores finger prints of ops in order to detect broken
/// RewritePatterns. A rewrite pattern is broken if it modifies IR without
/// using the rewriter API or if it returns an inconsistent return value.
///
/// This listener is installed between the rewriter and the driver: every IR
/// notification is forwarded to the wrapped listener (the driver) and is
/// additionally used to invalidate the finger prints of the affected ops, so
/// that only changes that bypassed the rewriter API are reported.
struct DebugFingerPrints : public RewriterBase::ForwardingListener {
  DebugFingerPrints(RewriterBase::Listener *driver)
      : RewriterBase::ForwardingListener(driver) {}

  /// Compute finger prints of the given op and its nested ops.
  void computeFingerPrints(Operation *topLevel) {
    this->topLevel = topLevel;
    this->topLevelFingerPrint.emplace(topLevel);
    topLevel->walk([&](Operation *op) { fingerprints.try_emplace(op, op); });
  }

  /// Clear all finger prints.
  void clear() {
    topLevel = nullptr;
    topLevelFingerPrint.reset();
    fingerprints.clear();
  }

  /// To be called after a pattern reported success. Aborts if the top-level
  /// finger print did not change, or if an op whose finger print was not
  /// invalidated through a rewriter notification has a different finger print.
  void notifyRewriteSuccess() {
    // Pattern application success => IR must have changed.
    OperationFingerPrint afterFingerPrint(topLevel);
    if (*topLevelFingerPrint == afterFingerPrint) {
      // Note: Run "mlir-opt -debug" to see which pattern is broken.
      llvm::report_fatal_error(
          "pattern returned success but IR did not change");
    }
    for (const auto &it : fingerprints) {
      // Skip top-level op, its finger print is never invalidated.
      if (it.first == topLevel)
        continue;
      // Note: Finger print computation may crash when an op was erased
      // without notifying the rewriter. (Run with ASAN to see where the op was
      // erased; the op was probably erased directly, bypassing the rewriter
      // API.) Finger print computation may not crash if a new op was created
      // at the same memory location. (But then the finger print should have
      // changed.)
      if (it.second != OperationFingerPrint(it.first)) {
        // Note: Run "mlir-opt -debug" to see which pattern is broken.
        llvm::report_fatal_error("operation finger print changed");
      }
    }
  }

  /// To be called after a pattern reported failure. Aborts if the top-level
  /// finger print changed.
  void notifyRewriteFailure() {
    // Pattern application failure => IR must not have changed.
    OperationFingerPrint afterFingerPrint(topLevel);
    if (*topLevelFingerPrint != afterFingerPrint) {
      // Note: Run "mlir-opt -debug" to see which pattern is broken.
      llvm::report_fatal_error("pattern returned failure but IR did change");
    }
  }

protected:
  /// Invalidate the finger print of the given op, i.e., remove it from the map.
  void invalidateFingerPrint(Operation *op) {
    // Invalidate all finger prints until the top level.
    while (op && op != topLevel) {
      fingerprints.erase(op);
      op = op->getParentOp();
    }
  }

  void notifyOperationInserted(Operation *op) override {
    // Forward first so that the driver sees every notification.
    RewriterBase::ForwardingListener::notifyOperationInserted(op);
    invalidateFingerPrint(op->getParentOp());
  }

  void notifyOperationModified(Operation *op) override {
    RewriterBase::ForwardingListener::notifyOperationModified(op);
    invalidateFingerPrint(op);
  }

  void notifyOperationRemoved(Operation *op) override {
    RewriterBase::ForwardingListener::notifyOperationRemoved(op);
    op->walk([this](Operation *op) { invalidateFingerPrint(op); });
  }

  /// Operation finger prints to detect invalid pattern API usage. IR is checked
  /// against these finger prints after pattern application to detect cases
  /// where IR was modified directly, bypassing the rewriter API.
  DenseMap<Operation *, OperationFingerPrint> fingerprints;

  /// Top-level operation of the current greedy rewrite.
  Operation *topLevel = nullptr;

  /// Finger print of the top-level operation.
  std::optional<OperationFingerPrint> topLevelFingerPrint;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
137 
138 //===----------------------------------------------------------------------===//
139 // Worklist
140 //===----------------------------------------------------------------------===//
141 
142 /// A LIFO worklist of operations with efficient removal and set semantics.
143 ///
144 /// This class maintains a vector of operations and a mapping of operations to
145 /// positions in the vector, so that operations can be removed efficiently at
146 /// random. When an operation is removed, it is replaced with nullptr. Such
147 /// nullptr are skipped when pop'ing elements.
148 class Worklist {
149 public:
150  Worklist();
151 
152  /// Clear the worklist.
153  void clear();
154 
155  /// Return whether the worklist is empty.
156  bool empty() const;
157 
158  /// Push an operation to the end of the worklist, unless the operation is
159  /// already on the worklist.
160  void push(Operation *op);
161 
162  /// Pop the an operation from the end of the worklist. Only allowed on
163  /// non-empty worklists.
164  Operation *pop();
165 
166  /// Remove an operation from the worklist.
167  void remove(Operation *op);
168 
169  /// Reverse the worklist.
170  void reverse();
171 
172 protected:
173  /// The worklist of operations.
174  std::vector<Operation *> list;
175 
176  /// A mapping of operations to positions in `list`.
178 };
179 
180 Worklist::Worklist() { list.reserve(64); }
181 
182 void Worklist::clear() {
183  list.clear();
184  map.clear();
185 }
186 
187 bool Worklist::empty() const {
188  // Skip all nullptr.
189  return !llvm::any_of(list,
190  [](Operation *op) { return static_cast<bool>(op); });
191 }
192 
193 void Worklist::push(Operation *op) {
194  assert(op && "cannot push nullptr to worklist");
195  // Check to see if the worklist already contains this op.
196  if (map.count(op))
197  return;
198  map[op] = list.size();
199  list.push_back(op);
200 }
201 
202 Operation *Worklist::pop() {
203  assert(!empty() && "cannot pop from empty worklist");
204  // Skip and remove all trailing nullptr.
205  while (!list.back())
206  list.pop_back();
207  Operation *op = list.back();
208  list.pop_back();
209  map.erase(op);
210  // Cleanup: Remove all trailing nullptr.
211  while (!list.empty() && !list.back())
212  list.pop_back();
213  return op;
214 }
215 
216 void Worklist::remove(Operation *op) {
217  assert(op && "cannot remove nullptr from worklist");
218  auto it = map.find(op);
219  if (it != map.end()) {
220  assert(list[it->second] == op && "malformed worklist data structure");
221  list[it->second] = nullptr;
222  map.erase(it);
223  }
224 }
225 
226 void Worklist::reverse() {
227  std::reverse(list.begin(), list.end());
228  for (size_t i = 0, e = list.size(); i != e; ++i)
229  map[list[i]] = i;
230 }
231 
232 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
233 /// A worklist that pops elements at a random position. This worklist is for
234 /// testing/debugging purposes only. It can be used to ensure that lowering
235 /// pipelines work correctly regardless of the order in which ops are processed
236 /// by the GreedyPatternRewriteDriver.
237 class RandomizedWorklist : public Worklist {
238 public:
239  RandomizedWorklist() : Worklist() {
240  generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
241  }
242 
243  /// Pop a random non-empty op from the worklist.
244  Operation *pop() {
245  Operation *op = nullptr;
246  do {
247  assert(!list.empty() && "cannot pop from empty worklist");
248  int64_t pos = generator() % list.size();
249  op = list[pos];
250  list.erase(list.begin() + pos);
251  for (int64_t i = pos, e = list.size(); i < e; ++i)
252  map[list[i]] = i;
253  map.erase(op);
254  } while (!op);
255  return op;
256  }
257 
258 private:
259  std::minstd_rand0 generator;
260 };
261 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
262 
263 //===----------------------------------------------------------------------===//
264 // GreedyPatternRewriteDriver
265 //===----------------------------------------------------------------------===//
266 
267 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
268 /// applies the locally optimal patterns.
269 ///
270 /// This abstract class manages the worklist and contains helper methods for
271 /// rewriting ops on the worklist. Derived classes specify how ops are added
272 /// to the worklist in the beginning.
273 class GreedyPatternRewriteDriver : public PatternRewriter,
274  public RewriterBase::Listener {
275 protected:
276  explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
277  const FrozenRewritePatternSet &patterns,
278  const GreedyRewriteConfig &config);
279 
280  /// Add the given operation to the worklist.
281  void addSingleOpToWorklist(Operation *op);
282 
283  /// Add the given operation and its ancestors to the worklist.
284  void addToWorklist(Operation *op);
285 
286  /// Notify the driver that the specified operation may have been modified
287  /// in-place. The operation is added to the worklist.
288  void notifyOperationModified(Operation *op) override;
289 
290  /// Notify the driver that the specified operation was inserted. Update the
291  /// worklist as needed: The operation is enqueued depending on scope and
292  /// strict mode.
293  void notifyOperationInserted(Operation *op) override;
294 
295  /// Notify the driver that the specified operation was removed. Update the
296  /// worklist as needed: The operation and its children are removed from the
297  /// worklist.
298  void notifyOperationRemoved(Operation *op) override;
299 
300  /// Notify the driver that the specified operation was replaced. Update the
301  /// worklist as needed: New users are added enqueued.
302  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
303 
304  /// Process ops until the worklist is empty or `config.maxNumRewrites` is
305  /// reached. Return `true` if any IR was changed.
306  bool processWorklist();
307 
308  /// The worklist for this transformation keeps track of the operations that
309  /// need to be (re)visited.
310 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
311  RandomizedWorklist worklist;
312 #else
313  Worklist worklist;
314 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
315 
316  /// Non-pattern based folder for operations.
317  OperationFolder folder;
318 
319  /// Configuration information for how to simplify.
320  const GreedyRewriteConfig config;
321 
322  /// The list of ops we are restricting our rewrites to. These include the
323  /// supplied set of ops as well as new ops created while rewriting those ops
324  /// depending on `strictMode`. This set is not maintained when
325  /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
326  llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
327 
328 private:
329  /// Look over the provided operands for any defining operations that should
330  /// be re-added to the worklist. This function should be called when an
331  /// operation is modified or removed, as it may trigger further
332  /// simplifications.
333  void addOperandsToWorklist(ValueRange operands);
334 
335  /// Notify the driver that the given block was created.
336  void notifyBlockCreated(Block *block) override;
337 
338  /// For debugging only: Notify the driver of a pattern match failure.
340  notifyMatchFailure(Location loc,
341  function_ref<void(Diagnostic &)> reasonCallback) override;
342 
343 #ifndef NDEBUG
344  /// A logger used to emit information during the application process.
345  llvm::ScopedPrinter logger{llvm::dbgs()};
346 #endif
347 
348  /// The low-level pattern applicator.
349  PatternApplicator matcher;
350 
351 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
352  DebugFingerPrints debugFingerPrints;
353 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
354 };
355 } // namespace
356 
357 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
358  MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
359  const GreedyRewriteConfig &config)
360  : PatternRewriter(ctx), folder(ctx, this), config(config), matcher(patterns)
361 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
362  // clang-format off
363  , debugFingerPrints(this)
364 // clang-format on
365 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
366 {
367  // Apply a simple cost model based solely on pattern benefit.
368  matcher.applyDefaultCostModel();
369 
370  // Set up listener.
371 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
372  // Send IR notifications to the debug handler. This handler will then forward
373  // all notifications to this GreedyPatternRewriteDriver.
374  setListener(&debugFingerPrints);
375 #else
376  setListener(this);
377 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
378 }
379 
380 bool GreedyPatternRewriteDriver::processWorklist() {
381 #ifndef NDEBUG
382  const char *logLineComment =
383  "//===-------------------------------------------===//\n";
384 
385  /// A utility function to log a process result for the given reason.
386  auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
387  logger.unindent();
388  logger.startLine() << "} -> " << result;
389  if (!msg.isTriviallyEmpty())
390  logger.getOStream() << " : " << msg;
391  logger.getOStream() << "\n";
392  };
393  auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
394  logResult(result, msg);
395  logger.startLine() << logLineComment;
396  };
397 #endif
398 
399  bool changed = false;
400  int64_t numRewrites = 0;
401  while (!worklist.empty() &&
402  (numRewrites < config.maxNumRewrites ||
403  config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
404  auto *op = worklist.pop();
405 
406  LLVM_DEBUG({
407  logger.getOStream() << "\n";
408  logger.startLine() << logLineComment;
409  logger.startLine() << "Processing operation : '" << op->getName() << "'("
410  << op << ") {\n";
411  logger.indent();
412 
413  // If the operation has no regions, just print it here.
414  if (op->getNumRegions() == 0) {
415  op->print(
416  logger.startLine(),
417  OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
418  logger.getOStream() << "\n\n";
419  }
420  });
421 
422  // If the operation is trivially dead - remove it.
423  if (isOpTriviallyDead(op)) {
424  eraseOp(op);
425  changed = true;
426 
427  LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
428  continue;
429  }
430 
431  // Try to fold this op.
432  if (succeeded(folder.tryToFold(op))) {
433  LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
434  changed = true;
435  continue;
436  }
437 
438  // Try to match one of the patterns. The rewriter is automatically
439  // notified of any necessary changes, so there is nothing else to do
440  // here.
441 #ifndef NDEBUG
442  auto canApply = [&](const Pattern &pattern) {
443  LLVM_DEBUG({
444  logger.getOStream() << "\n";
445  logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
446  << op->getName() << " -> (";
447  llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
448  logger.getOStream() << ")' {\n";
449  logger.indent();
450  });
451  return true;
452  };
453  auto onFailure = [&](const Pattern &pattern) {
454  LLVM_DEBUG(logResult("failure", "pattern failed to match"));
455  };
456  auto onSuccess = [&](const Pattern &pattern) {
457  LLVM_DEBUG(logResult("success", "pattern applied successfully"));
458  return success();
459  };
460 #else
461  function_ref<bool(const Pattern &)> canApply = {};
462  function_ref<void(const Pattern &)> onFailure = {};
463  function_ref<LogicalResult(const Pattern &)> onSuccess = {};
464 #endif
465 
466 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
467  debugFingerPrints.computeFingerPrints(
468  /*topLevel=*/config.scope ? config.scope->getParentOp() : op);
469  auto clearFingerprints =
470  llvm::make_scope_exit([&]() { debugFingerPrints.clear(); });
471 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
472 
473  LogicalResult matchResult =
474  matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
475 
476  if (succeeded(matchResult)) {
477  LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
478 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
479  debugFingerPrints.notifyRewriteSuccess();
480 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
481  changed = true;
482  ++numRewrites;
483  } else {
484  LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
485 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
486  debugFingerPrints.notifyRewriteFailure();
487 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
488  }
489  }
490 
491  return changed;
492 }
493 
494 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
495  assert(op && "expected valid op");
496  // Gather potential ancestors while looking for a "scope" parent region.
497  SmallVector<Operation *, 8> ancestors;
498  Region *region = nullptr;
499  do {
500  ancestors.push_back(op);
501  region = op->getParentRegion();
502  if (config.scope == region) {
503  // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
504  for (Operation *op : ancestors)
505  addSingleOpToWorklist(op);
506  return;
507  }
508  if (region == nullptr)
509  return;
510  } while ((op = region->getParentOp()));
511 }
512 
513 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
514  if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
515  strictModeFilteredOps.contains(op))
516  worklist.push(op);
517 }
518 
519 void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
520  if (config.listener)
521  config.listener->notifyBlockCreated(block);
522 }
523 
524 void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
525  LLVM_DEBUG({
526  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
527  << ")\n";
528  });
529  if (config.listener)
530  config.listener->notifyOperationInserted(op);
531  if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
532  strictModeFilteredOps.insert(op);
533  addToWorklist(op);
534 }
535 
536 void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
537  LLVM_DEBUG({
538  logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
539  << ")\n";
540  });
541  if (config.listener)
542  config.listener->notifyOperationModified(op);
543  addToWorklist(op);
544 }
545 
546 void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
547  for (Value operand : operands) {
548  // If the use count of this operand is now < 2, we re-add the defining
549  // operation to the worklist.
550  // TODO: This is based on the fact that zero use operations
551  // may be deleted, and that single use values often have more
552  // canonicalization opportunities.
553  if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
554  continue;
555  if (auto *defOp = operand.getDefiningOp())
556  addToWorklist(defOp);
557  }
558 }
559 
560 void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
561  LLVM_DEBUG({
562  logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
563  << ")\n";
564  });
565  if (config.listener)
566  config.listener->notifyOperationRemoved(op);
567 
568  addOperandsToWorklist(op->getOperands());
569  worklist.remove(op);
570  folder.notifyRemoval(op);
571 
572  if (config.strictMode != GreedyRewriteStrictness::AnyOp)
573  strictModeFilteredOps.erase(op);
574 }
575 
576 void GreedyPatternRewriteDriver::notifyOperationReplaced(
577  Operation *op, ValueRange replacement) {
578  LLVM_DEBUG({
579  logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
580  << ")\n";
581  });
582  if (config.listener)
583  config.listener->notifyOperationReplaced(op, replacement);
584  for (auto result : op->getResults())
585  for (auto *user : result.getUsers())
586  addToWorklist(user);
587 }
588 
589 LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
590  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
591  LLVM_DEBUG({
592  Diagnostic diag(loc, DiagnosticSeverity::Remark);
593  reasonCallback(diag);
594  logger.startLine() << "** Failure : " << diag.str() << "\n";
595  });
596  if (config.listener)
597  return config.listener->notifyMatchFailure(loc, reasonCallback);
598  return failure();
599 }
600 
601 //===----------------------------------------------------------------------===//
602 // RegionPatternRewriteDriver
603 //===----------------------------------------------------------------------===//
604 
605 namespace {
606 /// This driver simplfies all ops in a region.
607 class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
608 public:
609  explicit RegionPatternRewriteDriver(MLIRContext *ctx,
610  const FrozenRewritePatternSet &patterns,
611  const GreedyRewriteConfig &config,
612  Region &regions);
613 
614  /// Simplify ops inside `region` and simplify the region itself. Return
615  /// success if the transformation converged.
616  LogicalResult simplify(bool *changed) &&;
617 
618 private:
619  /// The region that is simplified.
620  Region &region;
621 };
622 } // namespace
623 
624 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
625  MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
626  const GreedyRewriteConfig &config, Region &region)
627  : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
628  // Populate strict mode ops.
629  if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
630  region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
631  }
632 }
633 
634 namespace {
635 class GreedyPatternRewriteIteration
636  : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
637 public:
638  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
639  GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
640  : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
641  iteration(iteration) {}
642  static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
643  void print(raw_ostream &os) const override {
644  os << "GreedyPatternRewriteIteration(" << iteration << ")";
645  }
646 
647 private:
648  int64_t iteration = 0;
649 };
650 } // namespace
651 
652 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
653  auto insertKnownConstant = [&](Operation *op) {
654  // Check for existing constants when populating the worklist. This avoids
655  // accidentally reversing the constant order during processing.
656  Attribute constValue;
657  if (matchPattern(op, m_Constant(&constValue)))
658  if (!folder.insertKnownConstant(op, constValue))
659  return true;
660  return false;
661  };
662 
663  bool continueRewrites = false;
664  int64_t iteration = 0;
665  MLIRContext *ctx = getContext();
666  do {
667  // Check if the iteration limit was reached.
668  if (++iteration > config.maxIterations &&
669  config.maxIterations != GreedyRewriteConfig::kNoLimit)
670  break;
671 
672  worklist.clear();
673 
674  if (!config.useTopDownTraversal) {
675  // Add operations to the worklist in postorder.
676  region.walk([&](Operation *op) {
677  if (!insertKnownConstant(op))
678  addToWorklist(op);
679  });
680  } else {
681  // Add all nested operations to the worklist in preorder.
682  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
683  if (!insertKnownConstant(op)) {
684  addToWorklist(op);
685  return WalkResult::advance();
686  }
687  return WalkResult::skip();
688  });
689 
690  // Reverse the list so our pop-back loop processes them in-order.
691  worklist.reverse();
692  }
693 
694  ctx->executeAction<GreedyPatternRewriteIteration>(
695  [&] {
696  continueRewrites = processWorklist();
697 
698  // After applying patterns, make sure that the CFG of each of the
699  // regions is kept up to date.
700  if (config.enableRegionSimplification)
701  continueRewrites |= succeeded(simplifyRegions(*this, region));
702  },
703  {&region}, iteration);
704  } while (continueRewrites);
705 
706  if (changed)
707  *changed = iteration > 1;
708 
709  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
710  return success(!continueRewrites);
711 }
712 
715  const FrozenRewritePatternSet &patterns,
716  GreedyRewriteConfig config, bool *changed) {
717  // The top-level operation must be known to be isolated from above to
718  // prevent performing canonicalizations on operations defined at or above
719  // the region containing 'op'.
720  assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
721  "patterns can only be applied to operations IsolatedFromAbove");
722 
723  // Set scope if not specified.
724  if (!config.scope)
725  config.scope = &region;
726 
727  // Start the pattern driver.
728  RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
729  region);
730  LogicalResult converged = std::move(driver).simplify(changed);
731  LLVM_DEBUG(if (failed(converged)) {
732  llvm::dbgs() << "The pattern rewrite did not converge after scanning "
733  << config.maxIterations << " times\n";
734  });
735  return converged;
736 }
737 
738 //===----------------------------------------------------------------------===//
739 // MultiOpPatternRewriteDriver
740 //===----------------------------------------------------------------------===//
741 
742 namespace {
743 /// This driver simplfies a list of ops.
744 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
745 public:
746  explicit MultiOpPatternRewriteDriver(
747  MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
748  const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
749  llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
750 
751  /// Simplify `ops`. Return `success` if the transformation converged.
752  LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
753 
754 private:
755  void notifyOperationRemoved(Operation *op) override {
756  GreedyPatternRewriteDriver::notifyOperationRemoved(op);
757  if (survivingOps)
758  survivingOps->erase(op);
759  }
760 
761  /// An optional set of ops that survived the rewrite. This set is populated
762  /// at the beginning of `simplifyLocally` with the inititally provided list
763  /// of ops.
764  llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
765 };
766 } // namespace
767 
768 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
769  MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
770  const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
771  llvm::SmallDenseSet<Operation *, 4> *survivingOps)
772  : GreedyPatternRewriteDriver(ctx, patterns, config),
773  survivingOps(survivingOps) {
775  strictModeFilteredOps.insert(ops.begin(), ops.end());
776 
777  if (survivingOps) {
778  survivingOps->clear();
779  survivingOps->insert(ops.begin(), ops.end());
780  }
781 }
782 
783 LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
784  bool *changed) && {
785  // Populate the initial worklist.
786  for (Operation *op : ops)
787  addSingleOpToWorklist(op);
788 
789  // Process ops on the worklist.
790  bool result = processWorklist();
791  if (changed)
792  *changed = result;
793 
794  return success(worklist.empty());
795 }
796 
797 /// Find the region that is the closest common ancestor of all given ops.
798 ///
799 /// Note: This function returns `nullptr` if there is a top-level op among the
800 /// given list of ops.
802  assert(!ops.empty() && "expected at least one op");
803  // Fast path in case there is only one op.
804  if (ops.size() == 1)
805  return ops.front()->getParentRegion();
806 
807  Region *region = ops.front()->getParentRegion();
808  ops = ops.drop_front();
809  int sz = ops.size();
810  llvm::BitVector remainingOps(sz, true);
811  while (region) {
812  int pos = -1;
813  // Iterate over all remaining ops.
814  while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
815  // Is this op contained in `region`?
816  if (region->findAncestorOpInRegion(*ops[pos]))
817  remainingOps.reset(pos);
818  }
819  if (remainingOps.none())
820  break;
821  region = region->getParentRegion();
822  }
823  return region;
824 }
825 
827  ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
828  GreedyRewriteConfig config, bool *changed, bool *allErased) {
829  if (ops.empty()) {
830  if (changed)
831  *changed = false;
832  if (allErased)
833  *allErased = true;
834  return success();
835  }
836 
837  // Determine scope of rewrite.
838  if (!config.scope) {
839  // Compute scope if none was provided. The scope will remain `nullptr` if
840  // there is a top-level op among `ops`.
841  config.scope = findCommonAncestor(ops);
842  } else {
843  // If a scope was provided, make sure that all ops are in scope.
844 #ifndef NDEBUG
845  bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
846  return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
847  });
848  assert(allOpsInScope && "ops must be within the specified scope");
849 #endif // NDEBUG
850  }
851 
852  // Start the pattern driver.
853  llvm::SmallDenseSet<Operation *, 4> surviving;
854  MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
855  config, ops,
856  allErased ? &surviving : nullptr);
857  LogicalResult converged = std::move(driver).simplify(ops, changed);
858  if (allErased)
859  *allErased = surviving.empty();
860  LLVM_DEBUG(if (failed(converged)) {
861  llvm::dbgs() << "The pattern rewrite did not converge after "
862  << config.maxNumRewrites << " rewrites";
863  });
864  return converged;
865 }
static Region * findCommonAncestor(ArrayRef< Operation * > ops)
Find the region that is the closest common ancestor of all given ops.
static MLIRContext * getContext(OpFoldResult val)
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
int64_t maxIterations
This specifies the maximum number of times the rewriter will iterate between applying patterns and si...
Region * scope
Only ops within the scope are added to the worklist.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriter's worklist.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
int64_t maxNumRewrites
This specifies the maximum number of rewrites within an iteration.
bool enableRegionSimplification
Perform control flow optimizations to the region tree after applying all patterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:259
This class provides the API for ops that are known to be isolated from above.
A unique fingerprint for a specific operation and all of its internal operations.
A utility class for folding operations, and unifying duplicated constants generated along the way.
Definition: FoldUtils.h:33
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:728
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block, or region, depending on the callback provided.
Definition: Operation.h:776
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:652
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:373
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:410
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
Definition: PatternMatch.h:727
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
Definition: PatternMatch.h:72
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region, or nullptr if the region is attached to a top-level operation.
Definition: Region.cpp:45
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this region.
Definition: Region.cpp:168
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
std::enable_if_t< std::is_same< RetT, void >::value, RetT > walk(FnT &&callback)
Walk the operations in this region.
Definition: Region.h:279
MLIRContext * getContext()
Return the context this region is inserted in.
Definition: Region.cpp:24
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
CRTP Implementation of an action.
Definition: Action.h:77
virtual void print(raw_ostream &os) const
Definition: Action.h:51
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Applies the specified rewrite patterns on ops while also trying to fold these ops.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Run a set of structural simplifications over the given regions.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner until a fixpoint is reached.
@ AnyOp
No restrictions wrt. which ops are processed.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
virtual void notifyBlockCreated(Block *block)
Notification handler for when a block is created using the builder.
Definition: Builders.h:294
virtual void notifyOperationInserted(Operation *op)
Notification handler for when an operation is inserted into the builder.
Definition: Builders.h:290
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:446
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:455
void notifyOperationInserted(Operation *op) override
Notification handler for when an operation is inserted into the builder.
Definition: PatternMatch.h:449
void notifyOperationRemoved(Operation *op) override
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:468
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:406
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that the specified operation is about to be replaced with another operation.
Definition: PatternMatch.h:414
virtual LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match the given operation, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:435
virtual void notifyOperationRemoved(Operation *op)
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:427