MLIR 22.0.0git
GreedyPatternRewriteDriver.cpp
Go to the documentation of this file.
1//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements mlir::applyPatternsGreedily.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "mlir/Config/mlir-config.h"
16#include "mlir/IR/Action.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/Operation.h"
20#include "mlir/IR/Verifier.h"
25#include "llvm/ADT/BitVector.h"
26#include "llvm/ADT/DenseMap.h"
27#include "llvm/ADT/ScopeExit.h"
28#include "llvm/Support/DebugLog.h"
29#include "llvm/Support/ScopedPrinter.h"
30#include "llvm/Support/raw_ostream.h"
31
32#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
33#include <random>
34#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
35
36using namespace mlir;
37
38#define DEBUG_TYPE "greedy-rewriter"
39
40namespace {
41
42//===----------------------------------------------------------------------===//
43// Debugging Infrastructure
44//===----------------------------------------------------------------------===//
45
46#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
47/// A helper struct that performs various "expensive checks" to detect broken
48/// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
49/// broken if:
50/// * IR does not verify after pattern application / folding.
51/// * Pattern returns "failure" but the IR has changed.
52/// * Pattern returns "success" but the IR has not changed.
53///
54/// This struct stores finger prints of ops to determine whether the IR has
55/// changed or not.
56struct ExpensiveChecks : public RewriterBase::ForwardingListener {
57 ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
58 : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
59
60 /// Compute finger prints of the given op and its nested ops.
61 void computeFingerPrints(Operation *topLevel) {
62 this->topLevel = topLevel;
63 this->topLevelFingerPrint.emplace(topLevel);
64 topLevel->walk([&](Operation *op) {
65 fingerprints.try_emplace(op, op, /*includeNested=*/false);
66 });
67 }
68
69 /// Clear all finger prints.
70 void clear() {
71 topLevel = nullptr;
72 topLevelFingerPrint.reset();
73 fingerprints.clear();
74 }
75
76 void notifyRewriteSuccess() {
77 if (!topLevel)
78 return;
79
80 // Make sure that the IR still verifies.
81 if (failed(verify(topLevel)))
82 llvm::report_fatal_error("IR failed to verify after pattern application");
83
84 // Pattern application success => IR must have changed.
85 OperationFingerPrint afterFingerPrint(topLevel);
86 if (*topLevelFingerPrint == afterFingerPrint) {
87 // Note: Run "mlir-opt -debug" to see which pattern is broken.
88 llvm::report_fatal_error(
89 "pattern returned success but IR did not change");
90 }
91 for (const auto &it : fingerprints) {
92 // Skip top-level op, its finger print is never invalidated.
93 if (it.first == topLevel)
94 continue;
95 // Note: Finger print computation may crash when an op was erased
96 // without notifying the rewriter. (Run with ASAN to see where the op was
97 // erased; the op was probably erased directly, bypassing the rewriter
98 // API.) Finger print computation does may not crash if a new op was
99 // created at the same memory location. (But then the finger print should
100 // have changed.)
101 if (it.second !=
102 OperationFingerPrint(it.first, /*includeNested=*/false)) {
103 // Note: Run "mlir-opt -debug" to see which pattern is broken.
104 llvm::report_fatal_error("operation finger print changed");
105 }
106 }
107 }
108
109 void notifyRewriteFailure() {
110 if (!topLevel)
111 return;
112
113 // Pattern application failure => IR must not have changed.
114 OperationFingerPrint afterFingerPrint(topLevel);
115 if (*topLevelFingerPrint != afterFingerPrint) {
116 // Note: Run "mlir-opt -debug" to see which pattern is broken.
117 llvm::report_fatal_error("pattern returned failure but IR did change");
118 }
119 }
120
121 void notifyFoldingSuccess() {
122 if (!topLevel)
123 return;
124
125 // Make sure that the IR still verifies.
126 if (failed(verify(topLevel)))
127 llvm::report_fatal_error("IR failed to verify after folding");
128 }
129
130protected:
131 /// Invalidate the finger print of the given op, i.e., remove it from the map.
132 void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
133
134 void notifyBlockErased(Block *block) override {
136
137 // The block structure (number of blocks, types of block arguments, etc.)
138 // is part of the fingerprint of the parent op.
139 // TODO: The parent op fingerprint should also be invalidated when modifying
140 // the block arguments of a block, but we do not have a
141 // `notifyBlockModified` callback yet.
142 invalidateFingerPrint(block->getParentOp());
143 }
144
146 OpBuilder::InsertPoint previous) override {
148 invalidateFingerPrint(op->getParentOp());
149 }
150
151 void notifyOperationModified(Operation *op) override {
153 invalidateFingerPrint(op);
154 }
155
156 void notifyOperationErased(Operation *op) override {
158 op->walk([this](Operation *op) { invalidateFingerPrint(op); });
159 }
160
161 /// Operation finger prints to detect invalid pattern API usage. IR is checked
162 /// against these finger prints after pattern application to detect cases
163 /// where IR was modified directly, bypassing the rewriter API.
165
166 /// Top-level operation of the current greedy rewrite.
167 Operation *topLevel = nullptr;
168
169 /// Finger print of the top-level operation.
170 std::optional<OperationFingerPrint> topLevelFingerPrint;
171};
172#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
173
174#ifndef NDEBUG
175static Operation *getDumpRootOp(Operation *op) {
176 // Dump the parent op so that materialized constants are visible. If the op
177 // is a top-level op, dump it directly.
178 if (Operation *parentOp = op->getParentOp())
179 return parentOp;
180 return op;
181}
182static void logSuccessfulFolding(Operation *op) {
183 LDBG() << "// *** IR Dump After Successful Folding ***\n"
184 << OpWithFlags(op, OpPrintingFlags().elideLargeElementsAttrs());
185}
186#endif // NDEBUG
187
188//===----------------------------------------------------------------------===//
189// Worklist
190//===----------------------------------------------------------------------===//
191
192/// A LIFO worklist of operations with efficient removal and set semantics.
193///
194/// This class maintains a vector of operations and a mapping of operations to
195/// positions in the vector, so that operations can be removed efficiently at
196/// random. When an operation is removed, it is replaced with nullptr. Such
197/// nullptr are skipped when pop'ing elements.
198class Worklist {
199public:
200 Worklist();
201
202 /// Clear the worklist.
203 void clear();
204
205 /// Return whether the worklist is empty.
206 bool empty() const;
207
208 /// Push an operation to the end of the worklist, unless the operation is
209 /// already on the worklist.
210 void push(Operation *op);
211
212 /// Pop the an operation from the end of the worklist. Only allowed on
213 /// non-empty worklists.
214 Operation *pop();
215
216 /// Remove an operation from the worklist.
217 void remove(Operation *op);
218
219 /// Reverse the worklist.
220 void reverse();
221
222protected:
223 /// The worklist of operations.
224 std::vector<Operation *> list;
225
226 /// A mapping of operations to positions in `list`.
228};
229
230Worklist::Worklist() { list.reserve(64); }
231
232void Worklist::clear() {
233 list.clear();
234 map.clear();
235}
236
237bool Worklist::empty() const {
238 // Skip all nullptr.
239 return !llvm::any_of(list,
240 [](Operation *op) { return static_cast<bool>(op); });
241}
242
243void Worklist::push(Operation *op) {
244 assert(op && "cannot push nullptr to worklist");
245 // Check to see if the worklist already contains this op.
246 if (!map.insert({op, list.size()}).second)
247 return;
248 list.push_back(op);
249}
250
251Operation *Worklist::pop() {
252 assert(!empty() && "cannot pop from empty worklist");
253 // Skip and remove all trailing nullptr.
254 while (!list.back())
255 list.pop_back();
256 Operation *op = list.back();
257 list.pop_back();
258 map.erase(op);
259 // Cleanup: Remove all trailing nullptr.
260 while (!list.empty() && !list.back())
261 list.pop_back();
262 return op;
263}
264
265void Worklist::remove(Operation *op) {
266 assert(op && "cannot remove nullptr from worklist");
267 auto it = map.find(op);
268 if (it != map.end()) {
269 assert(list[it->second] == op && "malformed worklist data structure");
270 list[it->second] = nullptr;
271 map.erase(it);
272 }
273}
274
275void Worklist::reverse() {
276 std::reverse(list.begin(), list.end());
277 for (size_t i = 0, e = list.size(); i != e; ++i)
278 map[list[i]] = i;
279}
280
281#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
282/// A worklist that pops elements at a random position. This worklist is for
283/// testing/debugging purposes only. It can be used to ensure that lowering
284/// pipelines work correctly regardless of the order in which ops are processed
285/// by the GreedyPatternRewriteDriver.
286class RandomizedWorklist : public Worklist {
287public:
288 RandomizedWorklist() : Worklist() {
289 generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
290 }
291
292 /// Pop a random non-empty op from the worklist.
293 Operation *pop() {
294 Operation *op = nullptr;
295 do {
296 assert(!list.empty() && "cannot pop from empty worklist");
297 int64_t pos = generator() % list.size();
298 op = list[pos];
299 list.erase(list.begin() + pos);
300 for (int64_t i = pos, e = list.size(); i < e; ++i)
301 map[list[i]] = i;
302 map.erase(op);
303 } while (!op);
304 return op;
305 }
306
307private:
308 std::minstd_rand0 generator;
309};
310#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
311
312//===----------------------------------------------------------------------===//
313// GreedyPatternRewriteDriver
314//===----------------------------------------------------------------------===//
315
316/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
317/// applies the locally optimal patterns.
318///
319/// This abstract class manages the worklist and contains helper methods for
320/// rewriting ops on the worklist. Derived classes specify how ops are added
321/// to the worklist in the beginning.
322class GreedyPatternRewriteDriver : public RewriterBase::Listener {
323protected:
324 explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
327
328 /// Add the given operation to the worklist.
329 void addSingleOpToWorklist(Operation *op);
330
331 /// Add the given operation and its ancestors to the worklist.
332 void addToWorklist(Operation *op);
333
334 /// Notify the driver that the specified operation may have been modified
335 /// in-place. The operation is added to the worklist.
336 void notifyOperationModified(Operation *op) override;
337
338 /// Notify the driver that the specified operation was inserted. Update the
339 /// worklist as needed: The operation is enqueued depending on scope and
340 /// strict mode.
341 void notifyOperationInserted(Operation *op,
342 OpBuilder::InsertPoint previous) override;
343
344 /// Notify the driver that the specified operation was removed. Update the
345 /// worklist as needed: The operation and its children are removed from the
346 /// worklist.
347 void notifyOperationErased(Operation *op) override;
348
349 /// Notify the driver that the specified operation was replaced. Update the
350 /// worklist as needed: New users are added enqueued.
351 void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
352
353 /// Process ops until the worklist is empty or `config.maxNumRewrites` is
354 /// reached. Return `true` if any IR was changed.
355 bool processWorklist();
356
357 /// The pattern rewriter that is used for making IR modifications and is
358 /// passed to rewrite patterns.
359 PatternRewriter rewriter;
360
361 /// The worklist for this transformation keeps track of the operations that
362 /// need to be (re)visited.
363#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364 RandomizedWorklist worklist;
365#else
366 Worklist worklist;
367#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
368
369 /// Configuration information for how to simplify.
371
372 /// The list of ops we are restricting our rewrites to. These include the
373 /// supplied set of ops as well as new ops created while rewriting those ops
374 /// depending on `strictMode`. This set is not maintained when
375 /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
376 llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
377
378private:
379 /// Look over the provided operands for any defining operations that should
380 /// be re-added to the worklist. This function should be called when an
381 /// operation is modified or removed, as it may trigger further
382 /// simplifications.
383 void addOperandsToWorklist(Operation *op);
384
385 /// Notify the driver that the given block was inserted.
386 void notifyBlockInserted(Block *block, Region *previous,
387 Region::iterator previousIt) override;
388
389 /// Notify the driver that the given block is about to be removed.
390 void notifyBlockErased(Block *block) override;
391
392 /// For debugging only: Notify the driver of a pattern match failure.
393 void
394 notifyMatchFailure(Location loc,
395 function_ref<void(Diagnostic &)> reasonCallback) override;
396
397#ifndef NDEBUG
398 /// A raw output stream used to prefix the debug log.
399
400 llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
401 llvm::dbgs()};
402 /// A logger used to emit information during the application process.
403 llvm::ScopedPrinter logger{os};
404#endif
405
406 /// The low-level pattern applicator.
407 PatternApplicator matcher;
408
409#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
410 ExpensiveChecks expensiveChecks;
411#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
412};
413} // namespace
414
415GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
416 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
417 const GreedyRewriteConfig &config)
418 : rewriter(ctx), config(config), matcher(patterns)
419#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
420 // clang-format off
421 , expensiveChecks(
422 /*driver=*/this,
423 /*topLevel=*/config.getScope() ? config.getScope()->getParentOp()
424 : nullptr)
425// clang-format on
426#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
427{
428 // Apply a simple cost model based solely on pattern benefit.
429 matcher.applyDefaultCostModel();
430
431 // Set up listener.
432#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
433 // Send IR notifications to the debug handler. This handler will then forward
434 // all notifications to this GreedyPatternRewriteDriver.
435 rewriter.setListener(&expensiveChecks);
436#else
437 rewriter.setListener(this);
438#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
439}
440
/// Main loop of the driver. Pops operations off the worklist one at a time
/// and, for each of them, in this order:
///   1. erases the op if it is trivially dead,
///   2. tries to fold it (when folding is enabled and the op is not
///      ConstantLike),
///   3. otherwise hands it to the pattern applicator.
/// The loop stops when the worklist is empty or when the number of successful
/// pattern applications reaches `config.getMaxNumRewrites()` (unless that is
/// `kNoLimit`). Only step 3 increments the rewrite counter; dead-op erasure
/// and folding set `changed` but are not counted.
/// Returns true if any of the three steps modified the IR.
441bool GreedyPatternRewriteDriver::processWorklist() {
442#ifndef NDEBUG
443 const char *logLineComment =
444 "//===-------------------------------------------===//\n";
445
446 /// A utility function to log a process result for the given reason.
447 auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
448 logger.unindent();
449 logger.startLine() << "} -> " << result;
450 if (!msg.isTriviallyEmpty())
451 logger.getOStream() << " : " << msg;
452 logger.getOStream() << "\n";
453 };
454 auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
455 logResult(result, msg);
456 logger.startLine() << logLineComment;
457 };
458#endif
459
460 bool changed = false;
461 int64_t numRewrites = 0;
462 while (!worklist.empty() &&
463 (numRewrites < config.getMaxNumRewrites() ||
464 config.getMaxNumRewrites() == GreedyRewriteConfig::kNoLimit)) {
465 auto *op = worklist.pop();
466
467 LLVM_DEBUG({
468 logger.getOStream() << "\n";
469 logger.startLine() << logLineComment;
470 logger.startLine() << "Processing operation : '" << op->getName() << "'("
471 << op << ") {\n";
472 logger.indent();
473
474 // If the operation has no regions, just print it here.
475 if (op->getNumRegions() == 0) {
476 op->print(
477 logger.startLine(),
478 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
479 logger.getOStream() << "\n\n";
480 }
481 });
482
483 // If the operation is trivially dead - remove it.
484 if (isOpTriviallyDead(op)) {
485 rewriter.eraseOp(op);
486 changed = true;
487
488 LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
489 continue;
490 }
491
492 // Try to fold this op. Do not fold constant ops. That would lead to an
493 // infinite folding loop, as every constant op would be folded to an
494 // Attribute and then immediately be rematerialized as a constant op, which
495 // is then put on the worklist.
496 if (config.isFoldingEnabled() && !op->hasTrait<OpTrait::ConstantLike>()) {
497 SmallVector<OpFoldResult> foldResults;
498 if (succeeded(op->fold(foldResults))) {
// Note: this is logged as soon as the folder succeeds, i.e. before the
// results are materialized. If materialization fails further down, control
// falls through to the pattern-matching step for the same op.
499 LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
500#ifndef NDEBUG
// Captured up front because `op` may be replaced (and thus erased) below.
501 Operation *dumpRootOp = getDumpRootOp(op);
502#endif // NDEBUG
503 if (foldResults.empty()) {
504 // Op was modified in-place.
505 notifyOperationModified(op);
506 changed = true;
507 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
508#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
509 expensiveChecks.notifyFoldingSuccess();
510#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
511 continue;
512 }
513
514 // Op results can be replaced with `foldResults`.
515 assert(foldResults.size() == op->getNumResults() &&
516 "folder produced incorrect number of results");
517 OpBuilder::InsertionGuard g(rewriter);
518 rewriter.setInsertionPoint(op);
519 SmallVector<Value> replacements;
520 bool materializationSucceeded = true;
521 for (auto [ofr, resultType] :
522 llvm::zip_equal(foldResults, op->getResultTypes())) {
523 if (auto value = dyn_cast<Value>(ofr)) {
524 assert(value.getType() == resultType &&
525 "folder produced value of incorrect type");
526 replacements.push_back(value);
527 continue;
528 }
529 // Materialize Attributes as SSA values.
530 Operation *constOp = op->getDialect()->materializeConstant(
531 rewriter, cast<Attribute>(ofr), resultType, op->getLoc());
532
533 if (!constOp) {
534 // If materialization fails, cleanup any operations generated for
535 // the previous results.
536 llvm::SmallDenseSet<Operation *> replacementOps;
537 for (Value replacement : replacements) {
538 assert(replacement.use_empty() &&
539 "folder reused existing op for one result but constant "
540 "materialization failed for another result");
541 replacementOps.insert(replacement.getDefiningOp());
542 }
// Note: the loop variable below shadows the outer `op`; only the collected
// replacement ops are erased here, never the op being folded.
543 for (Operation *op : replacementOps) {
544 rewriter.eraseOp(op);
545 }
546
547 materializationSucceeded = false;
548 break;
549 }
550
551 assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
552 "materializeConstant produced op that is not a ConstantLike");
553 assert(constOp->getResultTypes()[0] == resultType &&
554 "materializeConstant produced incorrect result type");
555 replacements.push_back(constOp->getResult(0));
556 }
557
558 if (materializationSucceeded) {
559 rewriter.replaceOp(op, replacements);
560 changed = true;
561 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
562#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
563 expensiveChecks.notifyFoldingSuccess();
564#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
565 continue;
566 }
567 }
568 }
569
570 // Try to match one of the patterns. The rewriter is automatically
571 // notified of any necessary changes, so there is nothing else to do
572 // here.
// The three callbacks below add debug logging around each pattern attempt and
// relay begin/end events to the user-provided listener, if there is one.
573 auto canApplyCallback = [&](const Pattern &pattern) {
574 LLVM_DEBUG({
575 logger.getOStream() << "\n";
576 logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
577 << op->getName() << " -> (";
578 llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
579 logger.getOStream() << ")' {\n";
580 logger.indent();
581 });
582 if (RewriterBase::Listener *listener = config.getListener())
583 listener->notifyPatternBegin(pattern, op);
584 return true;
585 };
586 function_ref<bool(const Pattern &)> canApply = canApplyCallback;
587 auto onFailureCallback = [&](const Pattern &pattern) {
588 LLVM_DEBUG(logResult("failure", "pattern failed to match"));
589 if (RewriterBase::Listener *listener = config.getListener())
590 listener->notifyPatternEnd(pattern, failure());
591 };
592 function_ref<void(const Pattern &)> onFailure = onFailureCallback;
593 auto onSuccessCallback = [&](const Pattern &pattern) {
594 LLVM_DEBUG(logResult("success", "pattern applied successfully"));
595 if (RewriterBase::Listener *listener = config.getListener())
596 listener->notifyPatternEnd(pattern, success());
597 return success();
598 };
599 function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
600
601#ifdef NDEBUG
602 // Optimization: PatternApplicator callbacks are not needed when running in
603 // optimized mode and without a listener.
604 if (!config.getListener()) {
605 canApply = nullptr;
606 onFailure = nullptr;
607 onSuccess = nullptr;
608 }
609#endif // NDEBUG
610
611#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Snapshot the finger prints right before the pattern runs; the scope-exit
// object clears them again at the end of this loop iteration.
612 if (config.getScope()) {
613 expensiveChecks.computeFingerPrints(config.getScope()->getParentOp());
614 }
615 auto clearFingerprints =
616 llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
617#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
618
619 LogicalResult matchResult =
620 matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
621
622 if (succeeded(matchResult)) {
623 LLVM_DEBUG(logResultWithLine("success", "at least one pattern matched"));
624#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
625 expensiveChecks.notifyRewriteSuccess();
626#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
627 changed = true;
// This is the only place where the rewrite budget is consumed.
628 ++numRewrites;
629 } else {
630 LLVM_DEBUG(logResultWithLine("failure", "all patterns failed to match"));
631#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
632 expensiveChecks.notifyRewriteFailure();
633#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
634 }
635 }
636
637 return changed;
638}
639
640void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
641 assert(op && "expected valid op");
642 // Gather potential ancestors while looking for a "scope" parent region.
643 SmallVector<Operation *, 8> ancestors;
644 Region *region = nullptr;
645 do {
646 ancestors.push_back(op);
647 region = op->getParentRegion();
648 if (config.getScope() == region) {
649 // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
650 for (Operation *op : ancestors)
651 addSingleOpToWorklist(op);
652 return;
653 }
654 if (region == nullptr)
655 return;
656 } while ((op = region->getParentOp()));
657}
658
659void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
660 if (config.getStrictness() == GreedyRewriteStrictness::AnyOp ||
661 strictModeFilteredOps.contains(op))
662 worklist.push(op);
663}
664
665void GreedyPatternRewriteDriver::notifyBlockInserted(
666 Block *block, Region *previous, Region::iterator previousIt) {
667 if (RewriterBase::Listener *listener = config.getListener())
668 listener->notifyBlockInserted(block, previous, previousIt);
669}
670
671void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
672 if (RewriterBase::Listener *listener = config.getListener())
673 listener->notifyBlockErased(block);
674}
675
676void GreedyPatternRewriteDriver::notifyOperationInserted(
677 Operation *op, OpBuilder::InsertPoint previous) {
678 LLVM_DEBUG({
679 logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
680 << ")\n";
681 });
682 if (RewriterBase::Listener *listener = config.getListener())
683 listener->notifyOperationInserted(op, previous);
684 if (config.getStrictness() == GreedyRewriteStrictness::ExistingAndNewOps)
685 strictModeFilteredOps.insert(op);
686 addToWorklist(op);
687}
688
689void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
690 LLVM_DEBUG({
691 logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
692 << ")\n";
693 });
694 if (RewriterBase::Listener *listener = config.getListener())
695 listener->notifyOperationModified(op);
696 addToWorklist(op);
697}
698
699void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
700 for (Value operand : op->getOperands()) {
701 // If this operand currently has at most 2 users, add its defining op to the
702 // worklist. Indeed, after the op is deleted, then the operand will have at
703 // most 1 user left. If it has 0 users left, it can be deleted too,
704 // and if it has 1 user left, there may be further canonicalization
705 // opportunities.
706 if (!operand)
707 continue;
708
709 auto *defOp = operand.getDefiningOp();
710 if (!defOp)
711 continue;
712
713 Operation *otherUser = nullptr;
714 bool hasMoreThanTwoUses = false;
715 for (auto *user : operand.getUsers()) {
716 if (user == op || user == otherUser)
717 continue;
718 if (!otherUser) {
719 otherUser = user;
720 continue;
721 }
722 hasMoreThanTwoUses = true;
723 break;
724 }
725 if (hasMoreThanTwoUses)
726 continue;
727
728 addToWorklist(defOp);
729 }
730}
731
732void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
733 LLVM_DEBUG({
734 logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
735 << ")\n";
736 });
737
738#ifndef NDEBUG
739 // Only ops that are within the configured scope are added to the worklist of
740 // the greedy pattern rewriter. Moreover, the parent op of the scope region is
741 // the part of the IR that is taken into account for the "expensive checks".
742 // A greedy pattern rewrite is not allowed to erase the parent op of the scope
743 // region, as that would break the worklist handling and the expensive checks.
744 if (Region *scope = config.getScope(); scope->getParentOp() == op)
745 llvm_unreachable(
746 "scope region must not be erased during greedy pattern rewrite");
747#endif // NDEBUG
748
749 if (RewriterBase::Listener *listener = config.getListener())
750 listener->notifyOperationErased(op);
751
752 addOperandsToWorklist(op);
753 worklist.remove(op);
754
755 if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
756 strictModeFilteredOps.erase(op);
757}
758
759void GreedyPatternRewriteDriver::notifyOperationReplaced(
760 Operation *op, ValueRange replacement) {
761 LLVM_DEBUG({
762 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
763 << ")\n";
764 });
765 if (RewriterBase::Listener *listener = config.getListener())
766 listener->notifyOperationReplaced(op, replacement);
767}
768
769void GreedyPatternRewriteDriver::notifyMatchFailure(
770 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
771 LLVM_DEBUG({
772 Diagnostic diag(loc, DiagnosticSeverity::Remark);
773 reasonCallback(diag);
774 logger.startLine() << "** Match Failure : " << diag.str() << "\n";
775 });
776 if (RewriterBase::Listener *listener = config.getListener())
777 listener->notifyMatchFailure(loc, reasonCallback);
778}
779
780//===----------------------------------------------------------------------===//
781// RegionPatternRewriteDriver
782//===----------------------------------------------------------------------===//
783
784namespace {
785/// This driver simplfies all ops in a region.
786class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
787public:
788 explicit RegionPatternRewriteDriver(MLIRContext *ctx,
789 const FrozenRewritePatternSet &patterns,
790 const GreedyRewriteConfig &config,
791 Region &regions);
792
793 /// Simplify ops inside `region` and simplify the region itself. Return
794 /// success if the transformation converged.
795 LogicalResult simplify(bool *changed) &&;
796
797private:
798 /// The region that is simplified.
799 Region &region;
800};
801} // namespace
802
803RegionPatternRewriteDriver::RegionPatternRewriteDriver(
804 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
805 const GreedyRewriteConfig &config, Region &region)
806 : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
807 // Populate strict mode ops.
808 if (config.getStrictness() != GreedyRewriteStrictness::AnyOp) {
809 region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
810 }
811}
812
813namespace {
814class GreedyPatternRewriteIteration
815 : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
816public:
817 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
818 GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
819 : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
820 iteration(iteration) {}
821 static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
822 void print(raw_ostream &os) const override {
823 os << "GreedyPatternRewriteIteration(" << iteration << ")";
824 }
825
826private:
827 int64_t iteration = 0;
828};
829} // namespace
830
831LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
832 bool continueRewrites = false;
833 int64_t iteration = 0;
834 MLIRContext *ctx = rewriter.getContext();
835 do {
836 // Check if the iteration limit was reached.
837 if (++iteration > config.getMaxIterations() &&
838 config.getMaxIterations() != GreedyRewriteConfig::kNoLimit)
839 break;
840
841 // New iteration: start with an empty worklist.
842 worklist.clear();
843
844 // `OperationFolder` CSE's constant ops (and may move them into parents
845 // regions to enable more aggressive CSE'ing).
846 OperationFolder folder(ctx, this);
847 auto insertKnownConstant = [&](Operation *op) {
848 // Check for existing constants when populating the worklist. This avoids
849 // accidentally reversing the constant order during processing.
850 Attribute constValue;
851 if (matchPattern(op, m_Constant(&constValue)))
852 if (!folder.insertKnownConstant(op, constValue))
853 return true;
854 return false;
855 };
856
857 if (!config.getUseTopDownTraversal()) {
858 // Add operations to the worklist in postorder.
859 region.walk([&](Operation *op) {
860 if (!config.isConstantCSEEnabled() || !insertKnownConstant(op))
861 addToWorklist(op);
862 });
863 } else {
864 // Add all nested operations to the worklist in preorder.
865 region.walk<WalkOrder::PreOrder>([&](Operation *op) {
866 if (!config.isConstantCSEEnabled() || !insertKnownConstant(op)) {
867 addToWorklist(op);
868 return WalkResult::advance();
869 }
870 return WalkResult::skip();
871 });
872
873 // Reverse the list so our pop-back loop processes them in-order.
874 worklist.reverse();
875 }
876
877 ctx->executeAction<GreedyPatternRewriteIteration>(
878 [&] {
879 continueRewrites = false;
880
881 // Erase unreachable blocks
882 // Operations like:
883 // %add = arith.addi %add, %add : i64
884 // are legal in unreachable code. Unfortunately many patterns would be
885 // unsafe to apply on such IR and can lead to crashes or infinite
886 // loops.
887 continueRewrites |=
888 succeeded(eraseUnreachableBlocks(rewriter, region));
889
890 continueRewrites |= processWorklist();
891
892 // After applying patterns, make sure that the CFG of each of the
893 // regions is kept up to date.
894 if (config.getRegionSimplificationLevel() !=
896 continueRewrites |= succeeded(simplifyRegions(
897 rewriter, region,
898 /*mergeBlocks=*/config.getRegionSimplificationLevel() ==
900 }
901 },
902 {&region}, iteration);
903 } while (continueRewrites);
904
905 if (changed)
906 *changed = iteration > 1;
907
908 // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
909 return success(!continueRewrites);
910}
911
912LogicalResult
916 // The top-level operation must be known to be isolated from above to
917 // prevent performing canonicalizations on operations defined at or above
918 // the region containing 'op'.
920 "patterns can only be applied to operations IsolatedFromAbove");
921
922 // Set scope if not specified.
923 if (!config.getScope())
924 config.setScope(&region);
925
926#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
927 if (failed(verify(config.getScope()->getParentOp())))
928 llvm::report_fatal_error(
929 "greedy pattern rewriter input IR failed to verify");
930#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
931
932 // Start the pattern driver.
933 RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
934 region);
935 LogicalResult converged = std::move(driver).simplify(changed);
936 if (failed(converged))
937 LDBG() << "The pattern rewrite did not converge after scanning "
938 << config.getMaxIterations() << " times";
939 return converged;
940}
941
942//===----------------------------------------------------------------------===//
943// MultiOpPatternRewriteDriver
944//===----------------------------------------------------------------------===//
945
946namespace {
947/// This driver simplfies a list of ops.
948class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
949public:
950 explicit MultiOpPatternRewriteDriver(
953 llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
954
955 /// Simplify `ops`. Return `success` if the transformation converged.
956 LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
957
958private:
959 void notifyOperationErased(Operation *op) override {
960 GreedyPatternRewriteDriver::notifyOperationErased(op);
961 if (survivingOps)
962 survivingOps->erase(op);
963 }
964
965 /// An optional set of ops that survived the rewrite. This set is populated
966 /// at the beginning of `simplifyLocally` with the inititally provided list
967 /// of ops.
968 llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
969};
970} // namespace
971
972MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
975 llvm::SmallDenseSet<Operation *, 4> *survivingOps)
976 : GreedyPatternRewriteDriver(ctx, patterns, config),
977 survivingOps(survivingOps) {
978 if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
979 strictModeFilteredOps.insert_range(ops);
980
981 if (survivingOps) {
982 survivingOps->clear();
983 survivingOps->insert_range(ops);
984 }
985}
986
987LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
988 bool *changed) && {
989 // Populate the initial worklist.
990 for (Operation *op : ops)
991 addSingleOpToWorklist(op);
992
993 // Process ops on the worklist.
994 bool result = processWorklist();
995 if (changed)
996 *changed = result;
997
998 return success(worklist.empty());
999}
1000
1001/// Find the region that is the closest common ancestor of all given ops.
1002///
1003/// Note: This function returns `nullptr` if there is a top-level op among the
1004/// given list of ops.
1006 assert(!ops.empty() && "expected at least one op");
1007 // Fast path in case there is only one op.
1008 if (ops.size() == 1)
1009 return ops.front()->getParentRegion();
1010
1011 Region *region = ops.front()->getParentRegion();
1012 ops = ops.drop_front();
1013 int sz = ops.size();
1014 llvm::BitVector remainingOps(sz, true);
1015 while (region) {
1016 int pos = -1;
1017 // Iterate over all remaining ops.
1018 while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
1019 // Is this op contained in `region`?
1020 if (region->findAncestorOpInRegion(*ops[pos]))
1021 remainingOps.reset(pos);
1022 }
1023 if (remainingOps.none())
1024 break;
1025 region = region->getParentRegion();
1026 }
1027 return region;
1028}
1029
1032 GreedyRewriteConfig config, bool *changed, bool *allErased) {
1033 if (ops.empty()) {
1034 if (changed)
1035 *changed = false;
1036 if (allErased)
1037 *allErased = true;
1038 return success();
1039 }
1040
1041 // Determine scope of rewrite.
1042 if (!config.getScope()) {
1043 // Compute scope if none was provided. The scope will remain `nullptr` if
1044 // there is a top-level op among `ops`.
1045 config.setScope(findCommonAncestor(ops));
1046 } else {
1047 // If a scope was provided, make sure that all ops are in scope.
1048#ifndef NDEBUG
1049 bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
1050 return static_cast<bool>(config.getScope()->findAncestorOpInRegion(*op));
1051 });
1052 assert(allOpsInScope && "ops must be within the specified scope");
1053#endif // NDEBUG
1054 }
1055
1056#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1057 if (config.getScope() && failed(verify(config.getScope()->getParentOp())))
1058 llvm::report_fatal_error(
1059 "greedy pattern rewriter input IR failed to verify");
1060#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1061
1062 // Start the pattern driver.
1063 llvm::SmallDenseSet<Operation *, 4> surviving;
1064 MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
1065 config, ops,
1066 allErased ? &surviving : nullptr);
1067 LogicalResult converged = std::move(driver).simplify(ops, changed);
1068 if (allErased)
1069 *allErased = surviving.empty();
1070 if (failed(converged))
1071 LDBG() << "The pattern rewrite did not converge after "
1072 << config.getMaxNumRewrites() << " rewrites";
1073 return converged;
1074}
return success()
static Region * findCommonAncestor(ArrayRef< Operation * > ops)
Find the region that is the closest common ancestor of all given ops.
#define DEBUG_TYPE
if(!isCopyOut)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
values clear()
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static Operation * findAncestorOpInRegion(Region *region, Operation *op)
Return the ancestor op in the region or nullptr if the region is not an ancestor of the op.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
This class represents a saved insertion point.
Definition Builders.h:327
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
A utility class for folding operations, and unifying duplicated constants generated along the way.
Definition FoldUtils.h:33
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
void erase()
Remove this operation from its parent block and delete it.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this r...
Definition Region.cpp:168
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
MLIRContext * getContext()
Return the context this region is inserted in.
Definition Region.cpp:24
BlockListType::iterator iterator
Definition Region.h:52
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:285
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
CRTP Implementation of an action.
Definition Action.h:76
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
@ Aggressive
Run extra simplifications (e.g.
@ Disabled
Disable region control-flow simplification.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Erase the unreachable blocks within the provided regions.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
@ AnyOp
No restrictions wrt. which ops are processed.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:423
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
A listener that forwards all notifications to another listener.
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.