MLIR 22.0.0git
GreedyPatternRewriteDriver.cpp
Go to the documentation of this file.
1//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements mlir::applyPatternsGreedily.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "mlir/Config/mlir-config.h"
16#include "mlir/IR/Action.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/Operation.h"
20#include "mlir/IR/Verifier.h"
25#include "llvm/ADT/BitVector.h"
26#include "llvm/ADT/DenseMap.h"
27#include "llvm/ADT/ScopeExit.h"
28#include "llvm/Support/DebugLog.h"
29#include "llvm/Support/ScopedPrinter.h"
30#include "llvm/Support/raw_ostream.h"
31
32#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
33#include <random>
34#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
35
36using namespace mlir;
37
38#define DEBUG_TYPE "greedy-rewriter"
39
40namespace {
41
42//===----------------------------------------------------------------------===//
43// Debugging Infrastructure
44//===----------------------------------------------------------------------===//
45
46#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
47/// A helper struct that performs various "expensive checks" to detect broken
48/// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
49/// broken if:
50/// * IR does not verify after pattern application / folding.
51/// * Pattern returns "failure" but the IR has changed.
52/// * Pattern returns "success" but the IR has not changed.
53///
54/// This struct stores finger prints of ops to determine whether the IR has
55/// changed or not.
56struct ExpensiveChecks : public RewriterBase::ForwardingListener {
57 ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
58 : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
59
60 /// Compute finger prints of the given op and its nested ops.
61 void computeFingerPrints(Operation *topLevel) {
62 this->topLevel = topLevel;
63 this->topLevelFingerPrint.emplace(topLevel);
64 topLevel->walk([&](Operation *op) {
65 fingerprints.try_emplace(op, op, /*includeNested=*/false);
66 });
67 }
68
69 /// Clear all finger prints.
70 void clear() {
71 topLevel = nullptr;
72 topLevelFingerPrint.reset();
73 fingerprints.clear();
74 }
75
76 void notifyRewriteSuccess() {
77 if (!topLevel)
78 return;
79
80 // Make sure that the IR still verifies.
81 if (failed(verify(topLevel)))
82 llvm::report_fatal_error("IR failed to verify after pattern application");
83
84 // Pattern application success => IR must have changed.
85 OperationFingerPrint afterFingerPrint(topLevel);
86 if (*topLevelFingerPrint == afterFingerPrint) {
87 // Note: Run "mlir-opt -debug" to see which pattern is broken.
88 llvm::report_fatal_error(
89 "pattern returned success but IR did not change");
90 }
91 for (const auto &it : fingerprints) {
92 // Skip top-level op, its finger print is never invalidated.
93 if (it.first == topLevel)
94 continue;
95 // Note: Finger print computation may crash when an op was erased
96 // without notifying the rewriter. (Run with ASAN to see where the op was
97 // erased; the op was probably erased directly, bypassing the rewriter
98 // API.) Finger print computation does may not crash if a new op was
99 // created at the same memory location. (But then the finger print should
100 // have changed.)
101 if (it.second !=
102 OperationFingerPrint(it.first, /*includeNested=*/false)) {
103 // Note: Run "mlir-opt -debug" to see which pattern is broken.
104 llvm::report_fatal_error("operation finger print changed");
105 }
106 }
107 }
108
109 void notifyRewriteFailure() {
110 if (!topLevel)
111 return;
112
113 // Pattern application failure => IR must not have changed.
114 OperationFingerPrint afterFingerPrint(topLevel);
115 if (*topLevelFingerPrint != afterFingerPrint) {
116 // Note: Run "mlir-opt -debug" to see which pattern is broken.
117 llvm::report_fatal_error("pattern returned failure but IR did change");
118 }
119 }
120
121 void notifyFoldingSuccess() {
122 if (!topLevel)
123 return;
124
125 // Make sure that the IR still verifies.
126 if (failed(verify(topLevel)))
127 llvm::report_fatal_error("IR failed to verify after folding");
128 }
129
130protected:
131 /// Invalidate the finger print of the given op, i.e., remove it from the map.
132 void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
133
134 void notifyBlockErased(Block *block) override {
136
137 // The block structure (number of blocks, types of block arguments, etc.)
138 // is part of the fingerprint of the parent op.
139 // TODO: The parent op fingerprint should also be invalidated when modifying
140 // the block arguments of a block, but we do not have a
141 // `notifyBlockModified` callback yet.
142 invalidateFingerPrint(block->getParentOp());
143 }
144
146 OpBuilder::InsertPoint previous) override {
148 invalidateFingerPrint(op->getParentOp());
149 }
150
151 void notifyOperationModified(Operation *op) override {
153 invalidateFingerPrint(op);
154 }
155
156 void notifyOperationErased(Operation *op) override {
158 op->walk([this](Operation *op) { invalidateFingerPrint(op); });
159 }
160
161 /// Operation finger prints to detect invalid pattern API usage. IR is checked
162 /// against these finger prints after pattern application to detect cases
163 /// where IR was modified directly, bypassing the rewriter API.
165
166 /// Top-level operation of the current greedy rewrite.
167 Operation *topLevel = nullptr;
168
169 /// Finger print of the top-level operation.
170 std::optional<OperationFingerPrint> topLevelFingerPrint;
171};
172#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
173
174#ifndef NDEBUG
175static Operation *getDumpRootOp(Operation *op) {
176 // Dump the parent op so that materialized constants are visible. If the op
177 // is a top-level op, dump it directly.
178 if (Operation *parentOp = op->getParentOp())
179 return parentOp;
180 return op;
181}
182static void logSuccessfulFolding(Operation *op) {
183 LDBG() << "// *** IR Dump After Successful Folding ***\n"
184 << OpWithFlags(op, OpPrintingFlags().elideLargeElementsAttrs());
185}
186#endif // NDEBUG
187
188//===----------------------------------------------------------------------===//
189// Worklist
190//===----------------------------------------------------------------------===//
191
192/// A LIFO worklist of operations with efficient removal and set semantics.
193///
194/// This class maintains a vector of operations and a mapping of operations to
195/// positions in the vector, so that operations can be removed efficiently at
196/// random. When an operation is removed, it is replaced with nullptr. Such
197/// nullptr are skipped when pop'ing elements.
198class Worklist {
199public:
200 Worklist();
201
202 /// Clear the worklist.
203 void clear();
204
205 /// Return whether the worklist is empty.
206 bool empty() const;
207
208 /// Push an operation to the end of the worklist, unless the operation is
209 /// already on the worklist.
210 void push(Operation *op);
211
212 /// Pop the an operation from the end of the worklist. Only allowed on
213 /// non-empty worklists.
214 Operation *pop();
215
216 /// Remove an operation from the worklist.
217 void remove(Operation *op);
218
219 /// Reverse the worklist.
220 void reverse();
221
222protected:
223 /// The worklist of operations.
224 std::vector<Operation *> list;
225
226 /// A mapping of operations to positions in `list`.
228};
229
230Worklist::Worklist() { list.reserve(64); }
231
232void Worklist::clear() {
233 list.clear();
234 map.clear();
235}
236
237bool Worklist::empty() const {
238 // Skip all nullptr.
239 return !llvm::any_of(list,
240 [](Operation *op) { return static_cast<bool>(op); });
241}
242
243void Worklist::push(Operation *op) {
244 assert(op && "cannot push nullptr to worklist");
245 // Check to see if the worklist already contains this op.
246 if (!map.insert({op, list.size()}).second)
247 return;
248 list.push_back(op);
249}
250
251Operation *Worklist::pop() {
252 assert(!empty() && "cannot pop from empty worklist");
253 // Skip and remove all trailing nullptr.
254 while (!list.back())
255 list.pop_back();
256 Operation *op = list.back();
257 list.pop_back();
258 map.erase(op);
259 // Cleanup: Remove all trailing nullptr.
260 while (!list.empty() && !list.back())
261 list.pop_back();
262 return op;
263}
264
265void Worklist::remove(Operation *op) {
266 assert(op && "cannot remove nullptr from worklist");
267 auto it = map.find(op);
268 if (it != map.end()) {
269 assert(list[it->second] == op && "malformed worklist data structure");
270 list[it->second] = nullptr;
271 map.erase(it);
272 }
273}
274
275void Worklist::reverse() {
276 std::reverse(list.begin(), list.end());
277 for (size_t i = 0, e = list.size(); i != e; ++i)
278 map[list[i]] = i;
279}
280
281#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
282/// A worklist that pops elements at a random position. This worklist is for
283/// testing/debugging purposes only. It can be used to ensure that lowering
284/// pipelines work correctly regardless of the order in which ops are processed
285/// by the GreedyPatternRewriteDriver.
286class RandomizedWorklist : public Worklist {
287public:
288 RandomizedWorklist() : Worklist() {
289 generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
290 }
291
292 /// Pop a random non-empty op from the worklist.
293 Operation *pop() {
294 Operation *op = nullptr;
295 do {
296 assert(!list.empty() && "cannot pop from empty worklist");
297 int64_t pos = generator() % list.size();
298 op = list[pos];
299 list.erase(list.begin() + pos);
300 for (int64_t i = pos, e = list.size(); i < e; ++i)
301 map[list[i]] = i;
302 map.erase(op);
303 } while (!op);
304 return op;
305 }
306
307private:
308 std::minstd_rand0 generator;
309};
310#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
311
312//===----------------------------------------------------------------------===//
313// GreedyPatternRewriteDriver
314//===----------------------------------------------------------------------===//
315
316/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
317/// applies the locally optimal patterns.
318///
319/// This abstract class manages the worklist and contains helper methods for
320/// rewriting ops on the worklist. Derived classes specify how ops are added
321/// to the worklist in the beginning.
322class GreedyPatternRewriteDriver : public RewriterBase::Listener {
323protected:
324 explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
327
328 /// Add the given operation to the worklist.
329 void addSingleOpToWorklist(Operation *op);
330
331 /// Add the given operation and its ancestors to the worklist.
332 void addToWorklist(Operation *op);
333
334 /// Notify the driver that the specified operation may have been modified
335 /// in-place. The operation is added to the worklist.
336 void notifyOperationModified(Operation *op) override;
337
338 /// Notify the driver that the specified operation was inserted. Update the
339 /// worklist as needed: The operation is enqueued depending on scope and
340 /// strict mode.
341 void notifyOperationInserted(Operation *op,
342 OpBuilder::InsertPoint previous) override;
343
344 /// Notify the driver that the specified operation was removed. Update the
345 /// worklist as needed: The operation and its children are removed from the
346 /// worklist.
347 void notifyOperationErased(Operation *op) override;
348
349 /// Notify the driver that the specified operation was replaced. Update the
350 /// worklist as needed: New users are added enqueued.
351 void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
352
353 /// Process ops until the worklist is empty or `config.maxNumRewrites` is
354 /// reached. Return `true` if any IR was changed.
355 bool processWorklist();
356
357 /// The pattern rewriter that is used for making IR modifications and is
358 /// passed to rewrite patterns.
359 PatternRewriter rewriter;
360
361 /// The worklist for this transformation keeps track of the operations that
362 /// need to be (re)visited.
363#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364 RandomizedWorklist worklist;
365#else
366 Worklist worklist;
367#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
368
369 /// Configuration information for how to simplify.
371
372 /// The list of ops we are restricting our rewrites to. These include the
373 /// supplied set of ops as well as new ops created while rewriting those ops
374 /// depending on `strictMode`. This set is not maintained when
375 /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
376 llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
377
378private:
379 /// Look over the provided operands for any defining operations that should
380 /// be re-added to the worklist. This function should be called when an
381 /// operation is modified or removed, as it may trigger further
382 /// simplifications.
383 void addOperandsToWorklist(Operation *op);
384
385 /// Notify the driver that the given block was inserted.
386 void notifyBlockInserted(Block *block, Region *previous,
387 Region::iterator previousIt) override;
388
389 /// Notify the driver that the given block is about to be removed.
390 void notifyBlockErased(Block *block) override;
391
392 /// For debugging only: Notify the driver of a pattern match failure.
393 void
394 notifyMatchFailure(Location loc,
395 function_ref<void(Diagnostic &)> reasonCallback) override;
396
397#ifndef NDEBUG
398 /// A raw output stream used to prefix the debug log.
399
400 llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
401 llvm::dbgs()};
402 /// A logger used to emit information during the application process.
403 llvm::ScopedPrinter logger{os};
404#endif
405
406 /// The low-level pattern applicator.
407 PatternApplicator matcher;
408
409#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
410 ExpensiveChecks expensiveChecks;
411#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
412};
413} // namespace
414
415GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
416 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
417 const GreedyRewriteConfig &config)
418 : rewriter(ctx), config(config), matcher(patterns)
419#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
420 // clang-format off
421 , expensiveChecks(
422 /*driver=*/this,
423 /*topLevel=*/config.getScope() ? config.getScope()->getParentOp()
424 : nullptr)
425// clang-format on
426#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
427{
428 // Apply a simple cost model based solely on pattern benefit.
429 matcher.applyDefaultCostModel();
430
431 // Set up listener.
432#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
433 // Send IR notifications to the debug handler. This handler will then forward
434 // all notifications to this GreedyPatternRewriteDriver.
435 rewriter.setListener(&expensiveChecks);
436#else
437 rewriter.setListener(this);
438#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
439}
440
441bool GreedyPatternRewriteDriver::processWorklist() {
442#ifndef NDEBUG
443 const char *logLineComment =
444 "//===-------------------------------------------===//\n";
445
446 /// A utility function to log a process result for the given reason.
447 auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
448 logger.unindent();
449 logger.startLine() << "} -> " << result;
450 if (!msg.isTriviallyEmpty())
451 logger.getOStream() << " : " << msg;
452 logger.getOStream() << "\n";
453 };
454 auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
455 logResult(result, msg);
456 logger.startLine() << logLineComment;
457 };
458#endif
459
460 bool changed = false;
461 int64_t numRewrites = 0;
462 while (!worklist.empty() &&
463 (numRewrites < config.getMaxNumRewrites() ||
464 config.getMaxNumRewrites() == GreedyRewriteConfig::kNoLimit)) {
465 auto *op = worklist.pop();
466
467 LLVM_DEBUG({
468 logger.getOStream() << "\n";
469 logger.startLine() << logLineComment;
470 logger.startLine() << "Processing operation : '" << op->getName() << "'("
471 << op << ") {\n";
472 logger.indent();
473
474 // If the operation has no regions, just print it here.
475 if (op->getNumRegions() == 0) {
476 op->print(
477 logger.startLine(),
478 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
479 logger.getOStream() << "\n\n";
480 }
481 });
482
483 // If the operation is trivially dead - remove it.
484 if (isOpTriviallyDead(op)) {
485 rewriter.eraseOp(op);
486 changed = true;
487
488 LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
489 continue;
490 }
491
492 // Try to fold this op. Do not fold constant ops. That would lead to an
493 // infinite folding loop, as every constant op would be folded to an
494 // Attribute and then immediately be rematerialized as a constant op, which
495 // is then put on the worklist.
496 if (config.isFoldingEnabled() && !op->hasTrait<OpTrait::ConstantLike>()) {
497 SmallVector<OpFoldResult> foldResults;
498 if (succeeded(op->fold(foldResults))) {
499 LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
500#ifndef NDEBUG
501 Operation *dumpRootOp = getDumpRootOp(op);
502#endif // NDEBUG
503 if (foldResults.empty()) {
504 // Op was modified in-place.
505 notifyOperationModified(op);
506 changed = true;
507 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
508#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
509 expensiveChecks.notifyFoldingSuccess();
510#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
511 continue;
512 }
513
514 // Op results can be replaced with `foldResults`.
515 assert(foldResults.size() == op->getNumResults() &&
516 "folder produced incorrect number of results");
517 OpBuilder::InsertionGuard g(rewriter);
518 rewriter.setInsertionPoint(op);
519 SmallVector<Value> replacements;
520 bool materializationSucceeded = true;
521 for (auto [ofr, resultType] :
522 llvm::zip_equal(foldResults, op->getResultTypes())) {
523 if (auto value = dyn_cast<Value>(ofr)) {
524 assert(value.getType() == resultType &&
525 "folder produced value of incorrect type");
526 replacements.push_back(value);
527 continue;
528 }
529 // Materialize Attributes as SSA values.
530 Operation *constOp = op->getDialect()->materializeConstant(
531 rewriter, cast<Attribute>(ofr), resultType, op->getLoc());
532
533 if (!constOp) {
534 // If materialization fails, cleanup any operations generated for
535 // the previous results.
536 llvm::SmallDenseSet<Operation *> replacementOps;
537 for (Value replacement : replacements) {
538 assert(replacement.use_empty() &&
539 "folder reused existing op for one result but constant "
540 "materialization failed for another result");
541 replacementOps.insert(replacement.getDefiningOp());
542 }
543 for (Operation *op : replacementOps) {
544 rewriter.eraseOp(op);
545 }
546
547 materializationSucceeded = false;
548 break;
549 }
550
551 assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
552 "materializeConstant produced op that is not a ConstantLike");
553 assert(constOp->getResultTypes()[0] == resultType &&
554 "materializeConstant produced incorrect result type");
555 replacements.push_back(constOp->getResult(0));
556 }
557
558 if (materializationSucceeded) {
559 rewriter.replaceOp(op, replacements);
560 changed = true;
561 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
562#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
563 expensiveChecks.notifyFoldingSuccess();
564#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
565 continue;
566 }
567 }
568 }
569
570 // Try to match one of the patterns. The rewriter is automatically
571 // notified of any necessary changes, so there is nothing else to do
572 // here.
573 auto canApplyCallback = [&](const Pattern &pattern) {
574 LLVM_DEBUG({
575 logger.getOStream() << "\n";
576 logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
577 << op->getName() << " -> (";
578 llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
579 logger.getOStream() << ")' {\n";
580 logger.indent();
581 });
582 if (RewriterBase::Listener *listener = config.getListener())
583 listener->notifyPatternBegin(pattern, op);
584 return true;
585 };
586 function_ref<bool(const Pattern &)> canApply = canApplyCallback;
587 auto onFailureCallback = [&](const Pattern &pattern) {
588 LLVM_DEBUG(logResult("failure", "pattern failed to match"));
589 if (RewriterBase::Listener *listener = config.getListener())
590 listener->notifyPatternEnd(pattern, failure());
591 };
592 function_ref<void(const Pattern &)> onFailure = onFailureCallback;
593 auto onSuccessCallback = [&](const Pattern &pattern) {
594 LLVM_DEBUG(logResult("success", "pattern applied successfully"));
595 if (RewriterBase::Listener *listener = config.getListener())
596 listener->notifyPatternEnd(pattern, success());
597 return success();
598 };
599 function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
600
601#ifdef NDEBUG
602 // Optimization: PatternApplicator callbacks are not needed when running in
603 // optimized mode and without a listener.
604 if (!config.getListener()) {
605 canApply = nullptr;
606 onFailure = nullptr;
607 onSuccess = nullptr;
608 }
609#endif // NDEBUG
610
611#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
612 if (config.getScope()) {
613 expensiveChecks.computeFingerPrints(config.getScope()->getParentOp());
614 }
615 llvm::scope_exit clearFingerprints([&]() { expensiveChecks.clear(); });
616#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
617
618 LogicalResult matchResult =
619 matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
620
621 if (succeeded(matchResult)) {
622 LLVM_DEBUG(logResultWithLine("success", "at least one pattern matched"));
623#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
624 expensiveChecks.notifyRewriteSuccess();
625#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
626 changed = true;
627 ++numRewrites;
628 } else {
629 LLVM_DEBUG(logResultWithLine("failure", "all patterns failed to match"));
630#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
631 expensiveChecks.notifyRewriteFailure();
632#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
633 }
634 }
635
636 return changed;
637}
638
639void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
640 assert(op && "expected valid op");
641 // Gather potential ancestors while looking for a "scope" parent region.
642 SmallVector<Operation *, 8> ancestors;
643 Region *region = nullptr;
644 do {
645 ancestors.push_back(op);
646 region = op->getParentRegion();
647 if (config.getScope() == region) {
648 // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
649 for (Operation *op : ancestors)
650 addSingleOpToWorklist(op);
651 return;
652 }
653 if (region == nullptr)
654 return;
655 } while ((op = region->getParentOp()));
656}
657
658void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
659 if (config.getStrictness() == GreedyRewriteStrictness::AnyOp ||
660 strictModeFilteredOps.contains(op))
661 worklist.push(op);
662}
663
664void GreedyPatternRewriteDriver::notifyBlockInserted(
665 Block *block, Region *previous, Region::iterator previousIt) {
666 if (RewriterBase::Listener *listener = config.getListener())
667 listener->notifyBlockInserted(block, previous, previousIt);
668}
669
670void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
671 if (RewriterBase::Listener *listener = config.getListener())
672 listener->notifyBlockErased(block);
673}
674
675void GreedyPatternRewriteDriver::notifyOperationInserted(
676 Operation *op, OpBuilder::InsertPoint previous) {
677 LLVM_DEBUG({
678 logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
679 << ")\n";
680 });
681 if (RewriterBase::Listener *listener = config.getListener())
682 listener->notifyOperationInserted(op, previous);
683 if (config.getStrictness() == GreedyRewriteStrictness::ExistingAndNewOps)
684 strictModeFilteredOps.insert(op);
685 addToWorklist(op);
686}
687
688void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
689 LLVM_DEBUG({
690 logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
691 << ")\n";
692 });
693 if (RewriterBase::Listener *listener = config.getListener())
694 listener->notifyOperationModified(op);
695 addToWorklist(op);
696}
697
698void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
699 for (Value operand : op->getOperands()) {
700 // If this operand currently has at most 2 users, add its defining op to the
701 // worklist. Indeed, after the op is deleted, then the operand will have at
702 // most 1 user left. If it has 0 users left, it can be deleted too,
703 // and if it has 1 user left, there may be further canonicalization
704 // opportunities.
705 if (!operand)
706 continue;
707
708 auto *defOp = operand.getDefiningOp();
709 if (!defOp)
710 continue;
711
712 Operation *otherUser = nullptr;
713 bool hasMoreThanTwoUses = false;
714 for (auto *user : operand.getUsers()) {
715 if (user == op || user == otherUser)
716 continue;
717 if (!otherUser) {
718 otherUser = user;
719 continue;
720 }
721 hasMoreThanTwoUses = true;
722 break;
723 }
724 if (hasMoreThanTwoUses)
725 continue;
726
727 addToWorklist(defOp);
728 }
729}
730
731void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
732 LLVM_DEBUG({
733 logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
734 << ")\n";
735 });
736
737#ifndef NDEBUG
738 // Only ops that are within the configured scope are added to the worklist of
739 // the greedy pattern rewriter. Moreover, the parent op of the scope region is
740 // the part of the IR that is taken into account for the "expensive checks".
741 // A greedy pattern rewrite is not allowed to erase the parent op of the scope
742 // region, as that would break the worklist handling and the expensive checks.
743 if (Region *scope = config.getScope(); scope->getParentOp() == op)
744 llvm_unreachable(
745 "scope region must not be erased during greedy pattern rewrite");
746#endif // NDEBUG
747
748 if (RewriterBase::Listener *listener = config.getListener())
749 listener->notifyOperationErased(op);
750
751 addOperandsToWorklist(op);
752 worklist.remove(op);
753
754 if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
755 strictModeFilteredOps.erase(op);
756}
757
758void GreedyPatternRewriteDriver::notifyOperationReplaced(
759 Operation *op, ValueRange replacement) {
760 LLVM_DEBUG({
761 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
762 << ")\n";
763 });
764 if (RewriterBase::Listener *listener = config.getListener())
765 listener->notifyOperationReplaced(op, replacement);
766}
767
768void GreedyPatternRewriteDriver::notifyMatchFailure(
769 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
770 LLVM_DEBUG({
771 Diagnostic diag(loc, DiagnosticSeverity::Remark);
772 reasonCallback(diag);
773 logger.startLine() << "** Match Failure : " << diag.str() << "\n";
774 });
775 if (RewriterBase::Listener *listener = config.getListener())
776 listener->notifyMatchFailure(loc, reasonCallback);
777}
778
779//===----------------------------------------------------------------------===//
780// RegionPatternRewriteDriver
781//===----------------------------------------------------------------------===//
782
783namespace {
784/// This driver simplfies all ops in a region.
785class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
786public:
787 explicit RegionPatternRewriteDriver(MLIRContext *ctx,
788 const FrozenRewritePatternSet &patterns,
789 const GreedyRewriteConfig &config,
790 Region &regions);
791
792 /// Simplify ops inside `region` and simplify the region itself. Return
793 /// success if the transformation converged.
794 LogicalResult simplify(bool *changed) &&;
795
796private:
797 /// The region that is simplified.
798 Region &region;
799};
800} // namespace
801
802RegionPatternRewriteDriver::RegionPatternRewriteDriver(
803 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
804 const GreedyRewriteConfig &config, Region &region)
805 : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
806 // Populate strict mode ops.
807 if (config.getStrictness() != GreedyRewriteStrictness::AnyOp) {
808 region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
809 }
810}
811
812namespace {
813class GreedyPatternRewriteIteration
814 : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
815public:
816 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
817 GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
818 : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
819 iteration(iteration) {}
820 static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
821 void print(raw_ostream &os) const override {
822 os << "GreedyPatternRewriteIteration(" << iteration << ")";
823 }
824
825private:
826 int64_t iteration = 0;
827};
828} // namespace
829
830LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
831 bool continueRewrites = false;
832 int64_t iteration = 0;
833 MLIRContext *ctx = rewriter.getContext();
834 do {
835 // Check if the iteration limit was reached.
836 if (++iteration > config.getMaxIterations() &&
837 config.getMaxIterations() != GreedyRewriteConfig::kNoLimit)
838 break;
839
840 // New iteration: start with an empty worklist.
841 worklist.clear();
842
843 // `OperationFolder` CSE's constant ops (and may move them into parents
844 // regions to enable more aggressive CSE'ing).
845 OperationFolder folder(ctx, this);
846 auto insertKnownConstant = [&](Operation *op) {
847 // Check for existing constants when populating the worklist. This avoids
848 // accidentally reversing the constant order during processing.
849 Attribute constValue;
850 if (matchPattern(op, m_Constant(&constValue)))
851 if (!folder.insertKnownConstant(op, constValue))
852 return true;
853 return false;
854 };
855
856 if (!config.getUseTopDownTraversal()) {
857 // Add operations to the worklist in postorder.
858 region.walk([&](Operation *op) {
859 if (!config.isConstantCSEEnabled() || !insertKnownConstant(op))
860 addToWorklist(op);
861 });
862 } else {
863 // Add all nested operations to the worklist in preorder.
864 region.walk<WalkOrder::PreOrder>([&](Operation *op) {
865 if (!config.isConstantCSEEnabled() || !insertKnownConstant(op)) {
866 addToWorklist(op);
867 return WalkResult::advance();
868 }
869 return WalkResult::skip();
870 });
871
872 // Reverse the list so our pop-back loop processes them in-order.
873 worklist.reverse();
874 }
875
876 ctx->executeAction<GreedyPatternRewriteIteration>(
877 [&] {
878 continueRewrites = false;
879
880 // Erase unreachable blocks
881 // Operations like:
882 // %add = arith.addi %add, %add : i64
883 // are legal in unreachable code. Unfortunately many patterns would be
884 // unsafe to apply on such IR and can lead to crashes or infinite
885 // loops.
886 continueRewrites |=
887 succeeded(eraseUnreachableBlocks(rewriter, region));
888
889 continueRewrites |= processWorklist();
890
891 // After applying patterns, make sure that the CFG of each of the
892 // regions is kept up to date.
893 if (config.getRegionSimplificationLevel() !=
895 continueRewrites |= succeeded(simplifyRegions(
896 rewriter, region,
897 /*mergeBlocks=*/config.getRegionSimplificationLevel() ==
899 }
900 },
901 {&region}, iteration);
902 } while (continueRewrites);
903
904 if (changed)
905 *changed = iteration > 1;
906
907 // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
908 return success(!continueRewrites);
909}
910
/// Region-based entry point of the greedy driver: configure the scope, run a
/// `RegionPatternRewriteDriver` over `region`, and return whether the rewrite
/// converged. `changed` is forwarded to the driver's `simplify`.
///
/// NOTE(review): this listing skips embedded line numbers 912-914 (presumably
/// the function name and parameter list) and 918 (presumably the opening of
/// the assertion whose message string appears below) -- confirm against the
/// original file.
911LogicalResult
915 // The top-level operation must be known to be isolated from above to
916 // prevent performing canonicalizations on operations defined at or above
917 // the region containing 'op'.
919 "patterns can only be applied to operations IsolatedFromAbove");
920
921 // Set scope if not specified.
// `config` is modified locally here: the scope defaults to the region that is
// being rewritten.
922 if (!config.getScope())
923 config.setScope(&region);
924
// With expensive checks enabled, refuse to start on IR that does not verify;
// the op that owns the scope region is what gets verified.
925#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
926 if (failed(verify(config.getScope()->getParentOp())))
927 llvm::report_fatal_error(
928 "greedy pattern rewriter input IR failed to verify");
929#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
930
931 // Start the pattern driver.
932 RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
933 region);
// `simplify` is invoked on an rvalue, so the driver object is not meant to be
// reused after this call.
934 LogicalResult converged = std::move(driver).simplify(changed);
// Non-convergence is only logged; the failure itself is handed back to the
// caller unchanged.
935 if (failed(converged))
936 LDBG() << "The pattern rewrite did not converge after scanning "
937 << config.getMaxIterations() << " times";
938 return converged;
939}
940
941//===----------------------------------------------------------------------===//
942// MultiOpPatternRewriteDriver
943//===----------------------------------------------------------------------===//
944
945namespace {
946/// This driver simplifies a list of ops.
947class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
948public:
// NOTE(review): embedded listing lines 950-951 (the leading constructor
// parameters) are absent from this listing; only the trailing optional
// `survivingOps` parameter is visible.
949 explicit MultiOpPatternRewriteDriver(
952 llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
953
954 /// Simplify `ops`. Return `success` if the transformation converged.
/// If `changed` is non-null, the result of processing the worklist is stored
/// through it. The function is rvalue-ref-qualified (`&&`).
955 LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
956
957private:
/// Forward the erase notification to the base driver, then drop `op` from
/// `survivingOps` when such a set was provided.
958 void notifyOperationErased(Operation *op) override {
959 GreedyPatternRewriteDriver::notifyOperationErased(op);
960 if (survivingOps)
961 survivingOps->erase(op);
962 }
963
964 /// An optional set of ops that survived the rewrite. This set is populated
965 /// by the constructor with the initially provided list of ops; erased ops
966 /// are removed from it in `notifyOperationErased`.
967 llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
968};
969} // namespace
970
/// Construct the multi-op driver on top of the generic greedy driver and
/// record the initial ops where needed.
///
/// NOTE(review): embedded listing lines 972-973 (the leading parameters; the
/// body uses `ctx`, `patterns`, `config` and `ops`) are absent from this
/// listing -- confirm against the original file.
971MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
974 llvm::SmallDenseSet<Operation *, 4> *survivingOps)
975 : GreedyPatternRewriteDriver(ctx, patterns, config),
976 survivingOps(survivingOps) {
// For any strictness other than `AnyOp`, the provided ops are recorded in
// `strictModeFilteredOps`.
977 if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
978 strictModeFilteredOps.insert_range(ops);
979
// When the caller wants to track surviving ops, start from exactly the
// provided ops; any previous contents of the set are discarded.
980 if (survivingOps) {
981 survivingOps->clear();
982 survivingOps->insert_range(ops);
983 }
984}
985
/// Seed the worklist with `ops`, process it, and report the outcome.
///
/// \param ops      The ops to put on the initial worklist.
/// \param changed  Optional out-parameter; receives the boolean result of
///                 `processWorklist()` when non-null.
/// \returns `success` if the worklist is empty once processing stops,
///          `failure` otherwise.
986LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
987 bool *changed) && {
988 // Populate the initial worklist.
989 for (Operation *op : ops)
990 addSingleOpToWorklist(op);
991
992 // Process ops on the worklist.
993 bool result = processWorklist();
994 if (changed)
995 *changed = result;
996
// Leftover worklist entries mean processing stopped before everything was
// handled, which is reported as non-convergence.
997 return success(worklist.empty());
998}
999
1000/// Find the region that is the closest common ancestor of all given ops.
1001///
1002/// Note: This function returns `nullptr` if there is a top-level op among the
1003/// given list of ops.
// NOTE(review): embedded listing line 1004 (the function signature) is absent
// from this listing; the body below takes a non-empty list `ops`.
1005 assert(!ops.empty() && "expected at least one op");
1006 // Fast path in case there is only one op.
1007 if (ops.size() == 1)
1008 return ops.front()->getParentRegion();
1009
// Start from the first op's parent region as the candidate and check the
// remaining ops against it.
1010 Region *region = ops.front()->getParentRegion();
1011 ops = ops.drop_front();
1012 int sz = ops.size();
// One bit per remaining op; a set bit means the op has not yet been found
// inside the candidate region.
1013 llvm::BitVector remainingOps(sz, true);
1014 while (region) {
1015 int pos = -1;
1016 // Iterate over all remaining ops.
1017 while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
1018 // Is this op contained in `region`?
1019 if (region->findAncestorOpInRegion(*ops[pos]))
1020 remainingOps.reset(pos);
1021 }
// Every op is covered by the current candidate: it is the common ancestor.
1022 if (remainingOps.none())
1023 break;
// Otherwise widen the candidate to the enclosing region. Cleared bits stay
// cleared and are not re-checked against the wider region.
1024 region = region->getParentRegion();
1025 }
// `region` is null here if the walk ran past the outermost region without
// covering all ops.
1026 return region;
1027}
1028
1031 GreedyRewriteConfig config, bool *changed, bool *allErased) {
// Op-list entry point of the greedy driver: determine the rewrite scope, run a
// `MultiOpPatternRewriteDriver` over `ops`, and report convergence. `changed`
// and `allErased` are optional out-parameters.
// NOTE(review): embedded listing lines 1029-1030 (return type, function name
// and leading parameters) are absent from this listing -- confirm against the
// original file.
//
// Trivial case: nothing to rewrite means nothing changed, and "all ops
// erased" holds vacuously.
1032 if (ops.empty()) {
1033 if (changed)
1034 *changed = false;
1035 if (allErased)
1036 *allErased = true;
1037 return success();
1038 }
1039
1040 // Determine scope of rewrite.
1041 if (!config.getScope()) {
1042 // Compute scope if none was provided. The scope will remain `nullptr` if
1043 // there is a top-level op among `ops`.
1044 config.setScope(findCommonAncestor(ops));
1045 } else {
1046 // If a scope was provided, make sure that all ops are in scope.
1047#ifndef NDEBUG
1048 bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
1049 return static_cast<bool>(config.getScope()->findAncestorOpInRegion(*op));
1050 });
1051 assert(allOpsInScope && "ops must be within the specified scope");
1052#endif // NDEBUG
1053 }
1054
// With expensive checks enabled, verify the op owning the scope region before
// rewriting; this is skipped when the scope is null.
1055#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1056 if (config.getScope() && failed(verify(config.getScope()->getParentOp())))
1057 llvm::report_fatal_error(
1058 "greedy pattern rewriter input IR failed to verify");
1059#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1060
1061 // Start the pattern driver.
// The surviving-op set is handed to the driver only when the caller asked for
// `allErased`.
1062 llvm::SmallDenseSet<Operation *, 4> surviving;
1063 MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
1064 config, ops,
1065 allErased ? &surviving : nullptr);
1066 LogicalResult converged = std::move(driver).simplify(ops, changed);
// All ops were erased exactly when none of the initial ops is left in the set.
1067 if (allErased)
1068 *allErased = surviving.empty();
// Non-convergence is only logged; the failure itself is handed back to the
// caller unchanged.
1069 if (failed(converged))
1070 LDBG() << "The pattern rewrite did not converge after "
1071 << config.getMaxNumRewrites() << " rewrites";
1072 return converged;
1073}
return success()
static Region * findCommonAncestor(ArrayRef< Operation * > ops)
Find the region that is the closest common ancestor of all given ops.
#define DEBUG_TYPE
if(!isCopyOut)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
values clear()
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static Operation * findAncestorOpInRegion(Region *region, Operation *op)
Return the ancestor op in the region or nullptr if the region is not an ancestor of the op.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
This class represents a saved insertion point.
Definition Builders.h:327
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
A utility class for folding operations, and unifying duplicated constants generated along the way.
Definition FoldUtils.h:33
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
void erase()
Remove this operation from its parent block and delete it.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this r...
Definition Region.cpp:168
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
MLIRContext * getContext()
Return the context this region is inserted in.
Definition Region.cpp:24
BlockListType::iterator iterator
Definition Region.h:52
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:285
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
CRTP Implementation of an action.
Definition Action.h:76
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
@ Aggressive
Run extra simplifications (e.g.
@ Disabled
Disable region control-flow simplification.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Erase the unreachable blocks within the provided regions.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
@ AnyOp
No restrictions wrt. which ops are processed.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:423
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
A listener that forwards all notifications to another listener.
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.