15 #include "mlir/Config/mlir-config.h"
23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/ScopeExit.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/ScopedPrinter.h"
29 #include "llvm/Support/raw_ostream.h"
31 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
37 #define DEBUG_TYPE "greedy-rewriter"
45 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
60 void computeFingerPrints(
Operation *topLevel) {
61 this->topLevel = topLevel;
62 this->topLevelFingerPrint.emplace(topLevel);
64 fingerprints.try_emplace(op, op,
false);
71 topLevelFingerPrint.reset();
75 void notifyRewriteSuccess() {
80 if (failed(
verify(topLevel)))
81 llvm::report_fatal_error(
"IR failed to verify after pattern application");
85 if (*topLevelFingerPrint == afterFingerPrint) {
87 llvm::report_fatal_error(
88 "pattern returned success but IR did not change");
90 for (
const auto &it : fingerprints) {
92 if (it.first == topLevel)
103 llvm::report_fatal_error(
"operation finger print changed");
108 void notifyRewriteFailure() {
114 if (*topLevelFingerPrint != afterFingerPrint) {
116 llvm::report_fatal_error(
"pattern returned failure but IR did change");
120 void notifyFoldingSuccess() {
125 if (failed(
verify(topLevel)))
126 llvm::report_fatal_error(
"IR failed to verify after folding");
131 void invalidateFingerPrint(
Operation *op) { fingerprints.erase(op); }
133 void notifyBlockErased(
Block *block)
override {
144 void notifyOperationInserted(
Operation *op,
150 void notifyOperationModified(
Operation *op)
override {
152 invalidateFingerPrint(op);
155 void notifyOperationErased(
Operation *op)
override {
157 op->
walk([
this](
Operation *op) { invalidateFingerPrint(op); });
169 std::optional<OperationFingerPrint> topLevelFingerPrint;
181 static void logSuccessfulFolding(
Operation *op) {
182 llvm::dbgs() <<
"// *** IR Dump After Successful Folding ***\n";
184 llvm::dbgs() <<
"\n\n";
224 std::vector<Operation *> list;
230 Worklist::Worklist() { list.reserve(64); }
232 void Worklist::clear() {
237 bool Worklist::empty()
const {
239 return !llvm::any_of(list,
240 [](
Operation *op) {
return static_cast<bool>(op); });
244 assert(op &&
"cannot push nullptr to worklist");
246 if (!map.insert({op, list.size()}).second)
252 assert(!empty() &&
"cannot pop from empty worklist");
260 while (!list.empty() && !list.back())
266 assert(op &&
"cannot remove nullptr from worklist");
267 auto it = map.find(op);
268 if (it != map.end()) {
269 assert(list[it->second] == op &&
"malformed worklist data structure");
270 list[it->second] =
nullptr;
275 void Worklist::reverse() {
276 std::reverse(list.begin(), list.end());
277 for (
size_t i = 0, e = list.size(); i != e; ++i)
281 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
286 class RandomizedWorklist :
public Worklist {
288 RandomizedWorklist() : Worklist() {
289 generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
296 assert(!list.empty() &&
"cannot pop from empty worklist");
299 list.
erase(list.begin() + pos);
300 for (int64_t i = pos, e = list.size(); i < e; ++i)
324 explicit GreedyPatternRewriteDriver(
MLIRContext *ctx,
329 void addSingleOpToWorklist(
Operation *op);
336 void notifyOperationModified(
Operation *op)
override;
341 void notifyOperationInserted(
Operation *op,
347 void notifyOperationErased(
Operation *op)
override;
355 bool processWorklist();
363 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364 RandomizedWorklist worklist;
376 llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
383 void addOperandsToWorklist(
Operation *op);
386 void notifyBlockInserted(
Block *block,
Region *previous,
390 void notifyBlockErased(
Block *block)
override;
399 llvm::ScopedPrinter logger{llvm::dbgs()};
405 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
406 ExpensiveChecks expensiveChecks;
411 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
415 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
425 matcher.applyDefaultCostModel();
428 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
431 rewriter.setListener(&expensiveChecks);
433 rewriter.setListener(
this);
437 bool GreedyPatternRewriteDriver::processWorklist() {
439 const char *logLineComment =
440 "//===-------------------------------------------===//\n";
443 auto logResult = [&](StringRef result,
const llvm::Twine &msg = {}) {
445 logger.startLine() <<
"} -> " << result;
446 if (!msg.isTriviallyEmpty())
447 logger.getOStream() <<
" : " << msg;
448 logger.getOStream() <<
"\n";
450 auto logResultWithLine = [&](StringRef result,
const llvm::Twine &msg = {}) {
451 logResult(result, msg);
452 logger.startLine() << logLineComment;
457 int64_t numRewrites = 0;
458 while (!worklist.empty() &&
459 (numRewrites <
config.getMaxNumRewrites() ||
460 config.getMaxNumRewrites() == GreedyRewriteConfig::kNoLimit)) {
461 auto *op = worklist.pop();
464 logger.getOStream() <<
"\n";
465 logger.startLine() << logLineComment;
466 logger.startLine() <<
"Processing operation : '" << op->getName() <<
"'("
471 if (op->getNumRegions() == 0) {
474 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
475 logger.getOStream() <<
"\n\n";
481 rewriter.eraseOp(op);
484 LLVM_DEBUG(logResultWithLine(
"success",
"operation is trivially dead"));
494 if (succeeded(op->fold(foldResults))) {
495 LLVM_DEBUG(logResultWithLine(
"success",
"operation was folded"));
499 if (foldResults.empty()) {
501 notifyOperationModified(op);
503 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
504 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
505 expensiveChecks.notifyFoldingSuccess();
511 assert(foldResults.size() == op->getNumResults() &&
512 "folder produced incorrect number of results");
514 rewriter.setInsertionPoint(op);
516 bool materializationSucceeded =
true;
517 for (
auto [ofr, resultType] :
518 llvm::zip_equal(foldResults, op->getResultTypes())) {
519 if (
auto value = dyn_cast<Value>(ofr)) {
520 assert(value.getType() == resultType &&
521 "folder produced value of incorrect type");
522 replacements.push_back(value);
527 rewriter, cast<Attribute>(ofr), resultType, op->getLoc());
532 llvm::SmallDenseSet<Operation *> replacementOps;
533 for (
Value replacement : replacements) {
534 assert(replacement.use_empty() &&
535 "folder reused existing op for one result but constant "
536 "materialization failed for another result");
537 replacementOps.insert(replacement.getDefiningOp());
540 rewriter.eraseOp(op);
543 materializationSucceeded =
false;
548 "materializeConstant produced op that is not a ConstantLike");
550 "materializeConstant produced incorrect result type");
551 replacements.push_back(constOp->
getResult(0));
554 if (materializationSucceeded) {
555 rewriter.replaceOp(op, replacements);
557 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
558 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
559 expensiveChecks.notifyFoldingSuccess();
569 auto canApplyCallback = [&](
const Pattern &pattern) {
571 logger.getOStream() <<
"\n";
572 logger.startLine() <<
"* Pattern " << pattern.getDebugName() <<
" : '"
573 << op->getName() <<
" -> (";
574 llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
575 logger.getOStream() <<
")' {\n";
579 listener->notifyPatternBegin(pattern, op);
583 auto onFailureCallback = [&](
const Pattern &pattern) {
584 LLVM_DEBUG(logResult(
"failure",
"pattern failed to match"));
586 listener->notifyPatternEnd(pattern, failure());
589 auto onSuccessCallback = [&](
const Pattern &pattern) {
590 LLVM_DEBUG(logResult(
"success",
"pattern applied successfully"));
592 listener->notifyPatternEnd(pattern, success());
600 if (!
config.getListener()) {
607 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
609 expensiveChecks.computeFingerPrints(
config.getScope()->getParentOp());
611 auto clearFingerprints =
612 llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
615 LogicalResult matchResult =
616 matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
618 if (succeeded(matchResult)) {
619 LLVM_DEBUG(logResultWithLine(
"success",
"at least one pattern matched"));
620 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
621 expensiveChecks.notifyRewriteSuccess();
626 LLVM_DEBUG(logResultWithLine(
"failure",
"all patterns failed to match"));
627 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
628 expensiveChecks.notifyRewriteFailure();
636 void GreedyPatternRewriteDriver::addToWorklist(
Operation *op) {
637 assert(op &&
"expected valid op");
642 ancestors.push_back(op);
644 if (
config.getScope() == region) {
647 addSingleOpToWorklist(op);
650 if (region ==
nullptr)
655 void GreedyPatternRewriteDriver::addSingleOpToWorklist(
Operation *op) {
656 if (
config.getStrictness() == GreedyRewriteStrictness::AnyOp ||
657 strictModeFilteredOps.contains(op))
661 void GreedyPatternRewriteDriver::notifyBlockInserted(
664 listener->notifyBlockInserted(block, previous, previousIt);
667 void GreedyPatternRewriteDriver::notifyBlockErased(
Block *block) {
669 listener->notifyBlockErased(block);
672 void GreedyPatternRewriteDriver::notifyOperationInserted(
675 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"'(" << op
679 listener->notifyOperationInserted(op, previous);
680 if (
config.getStrictness() == GreedyRewriteStrictness::ExistingAndNewOps)
681 strictModeFilteredOps.insert(op);
685 void GreedyPatternRewriteDriver::notifyOperationModified(
Operation *op) {
687 logger.startLine() <<
"** Modified: '" << op->
getName() <<
"'(" << op
691 listener->notifyOperationModified(op);
695 void GreedyPatternRewriteDriver::addOperandsToWorklist(
Operation *op) {
705 auto *defOp = operand.getDefiningOp();
710 bool hasMoreThanTwoUses =
false;
711 for (
auto user : operand.getUsers()) {
712 if (user == op || user == otherUser)
718 hasMoreThanTwoUses =
true;
721 if (hasMoreThanTwoUses)
724 addToWorklist(defOp);
728 void GreedyPatternRewriteDriver::notifyOperationErased(
Operation *op) {
730 logger.startLine() <<
"** Erase : '" << op->
getName() <<
"'(" << op
740 if (
Region *scope =
config.getScope(); scope->getParentOp() == op)
742 "scope region must not be erased during greedy pattern rewrite");
746 listener->notifyOperationErased(op);
748 addOperandsToWorklist(op);
751 if (
config.getStrictness() != GreedyRewriteStrictness::AnyOp)
752 strictModeFilteredOps.erase(op);
755 void GreedyPatternRewriteDriver::notifyOperationReplaced(
758 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
762 listener->notifyOperationReplaced(op, replacement);
765 void GreedyPatternRewriteDriver::notifyMatchFailure(
769 reasonCallback(
diag);
770 logger.startLine() <<
"** Match Failure : " <<
diag.str() <<
"\n";
773 listener->notifyMatchFailure(loc, reasonCallback);
782 class RegionPatternRewriteDriver :
public GreedyPatternRewriteDriver {
784 explicit RegionPatternRewriteDriver(
MLIRContext *ctx,
791 LogicalResult simplify(
bool *
changed) &&;
799 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
802 : GreedyPatternRewriteDriver(ctx,
patterns,
config), region(region) {
804 if (
config.getStrictness() != GreedyRewriteStrictness::AnyOp) {
805 region.
walk([&](
Operation *op) { strictModeFilteredOps.insert(op); });
810 class GreedyPatternRewriteIteration
815 : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
816 iteration(iteration) {}
817 static constexpr StringLiteral tag =
"GreedyPatternRewriteIteration";
818 void print(raw_ostream &os)
const override {
819 os <<
"GreedyPatternRewriteIteration(" << iteration <<
")";
823 int64_t iteration = 0;
827 LogicalResult RegionPatternRewriteDriver::simplify(
bool *
changed) && {
828 bool continueRewrites =
false;
829 int64_t iteration = 0;
833 if (++iteration >
config.getMaxIterations() &&
834 config.getMaxIterations() != GreedyRewriteConfig::kNoLimit)
843 auto insertKnownConstant = [&](
Operation *op) {
848 if (!folder.insertKnownConstant(op, constValue))
853 if (!
config.getUseTopDownTraversal()) {
856 if (!
config.isConstantCSEEnabled() || !insertKnownConstant(op))
862 if (!
config.isConstantCSEEnabled() || !insertKnownConstant(op)) {
864 return WalkResult::advance();
866 return WalkResult::skip();
875 continueRewrites = processWorklist();
879 if (
config.getRegionSimplificationLevel() !=
880 GreedySimplifyRegionLevel::Disabled) {
883 config.getRegionSimplificationLevel() ==
884 GreedySimplifyRegionLevel::Aggressive));
887 {®ion}, iteration);
888 }
while (continueRewrites);
894 return success(!continueRewrites);
905 "patterns can only be applied to operations IsolatedFromAbove");
911 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
913 llvm::report_fatal_error(
914 "greedy pattern rewriter input IR failed to verify");
920 LogicalResult converged = std::move(driver).simplify(
changed);
921 LLVM_DEBUG(
if (failed(converged)) {
922 llvm::dbgs() <<
"The pattern rewrite did not converge after scanning "
923 <<
config.getMaxIterations() <<
" times\n";
934 class MultiOpPatternRewriteDriver :
public GreedyPatternRewriteDriver {
936 explicit MultiOpPatternRewriteDriver(
939 llvm::SmallDenseSet<Operation *, 4> *survivingOps =
nullptr);
945 void notifyOperationErased(
Operation *op)
override {
946 GreedyPatternRewriteDriver::notifyOperationErased(op);
948 survivingOps->erase(op);
954 llvm::SmallDenseSet<Operation *, 4> *
const survivingOps =
nullptr;
958 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
961 llvm::SmallDenseSet<Operation *, 4> *survivingOps)
963 survivingOps(survivingOps) {
965 strictModeFilteredOps.insert_range(ops);
968 survivingOps->clear();
969 survivingOps->insert_range(ops);
977 addSingleOpToWorklist(op);
980 bool result = processWorklist();
984 return success(worklist.empty());
992 assert(!ops.empty() &&
"expected at least one op");
995 return ops.front()->getParentRegion();
997 Region *region = ops.front()->getParentRegion();
998 ops = ops.drop_front();
1000 llvm::BitVector remainingOps(sz,
true);
1004 while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
1007 remainingOps.reset(pos);
1009 if (remainingOps.none())
1028 if (!
config.getScope()) {
1035 bool allOpsInScope = llvm::all_of(ops, [&](
Operation *op) {
1038 assert(allOpsInScope &&
"ops must be within the specified scope");
1042 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1044 llvm::report_fatal_error(
1045 "greedy pattern rewriter input IR failed to verify");
1049 llvm::SmallDenseSet<Operation *, 4> surviving;
1050 MultiOpPatternRewriteDriver driver(ops.front()->getContext(),
patterns,
1052 allErased ? &surviving :
nullptr);
1053 LogicalResult converged = std::move(driver).simplify(ops,
changed);
1055 *allErased = surviving.empty();
1056 LLVM_DEBUG(
if (failed(converged)) {
1057 llvm::dbgs() <<
"The pattern rewrite did not converge after "
1058 <<
config.getMaxNumRewrites() <<
" rewrites";
static Region * findCommonAncestor(ArrayRef< Operation * > ops)
Find the region that is the closest common ancestor of all given ops.
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
This class represents a saved insertion point.
RAII guard to reset the insertion point of the builder when destroyed.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
A utility class for folding operations, and unifying duplicated constants generated along the way.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
void erase()
Remove this operation from its parent block and delete it.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this r...
Operation * getParentOp()
Return the parent operation this region is attached to.
MLIRContext * getContext()
Return the context this region is inserted in.
BlockListType::iterator iterator
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
CRTP Implementation of an action.
virtual void print(raw_ostream &os) const
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
@ AnyOp
No restrictions wrt. which ops are processed.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
A listener that forwards all notifications to another listener.
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.