15 #include "mlir/Config/mlir-config.h"
25 #include "llvm/ADT/BitVector.h"
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/ScopeExit.h"
28 #include "llvm/Support/DebugLog.h"
29 #include "llvm/Support/ScopedPrinter.h"
30 #include "llvm/Support/raw_ostream.h"
32 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
38 #define DEBUG_TYPE "greedy-rewriter"
46 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
61 void computeFingerPrints(
Operation *topLevel) {
62 this->topLevel = topLevel;
63 this->topLevelFingerPrint.emplace(topLevel);
65 fingerprints.try_emplace(op, op,
false);
72 topLevelFingerPrint.reset();
76 void notifyRewriteSuccess() {
82 llvm::report_fatal_error(
"IR failed to verify after pattern application");
86 if (*topLevelFingerPrint == afterFingerPrint) {
88 llvm::report_fatal_error(
89 "pattern returned success but IR did not change");
91 for (
const auto &it : fingerprints) {
93 if (it.first == topLevel)
104 llvm::report_fatal_error(
"operation finger print changed");
109 void notifyRewriteFailure() {
115 if (*topLevelFingerPrint != afterFingerPrint) {
117 llvm::report_fatal_error(
"pattern returned failure but IR did change");
121 void notifyFoldingSuccess() {
127 llvm::report_fatal_error(
"IR failed to verify after folding");
132 void invalidateFingerPrint(
Operation *op) { fingerprints.erase(op); }
134 void notifyBlockErased(
Block *block)
override {
145 void notifyOperationInserted(
Operation *op,
151 void notifyOperationModified(
Operation *op)
override {
153 invalidateFingerPrint(op);
156 void notifyOperationErased(
Operation *op)
override {
158 op->
walk([
this](
Operation *op) { invalidateFingerPrint(op); });
170 std::optional<OperationFingerPrint> topLevelFingerPrint;
182 static void logSuccessfulFolding(
Operation *op) {
183 LDBG() <<
"// *** IR Dump After Successful Folding ***\n"
224 std::vector<Operation *> list;
230 Worklist::Worklist() { list.reserve(64); }
232 void Worklist::clear() {
237 bool Worklist::empty()
const {
239 return !llvm::any_of(list,
240 [](
Operation *op) {
return static_cast<bool>(op); });
244 assert(op &&
"cannot push nullptr to worklist");
246 if (!map.insert({op, list.size()}).second)
252 assert(!empty() &&
"cannot pop from empty worklist");
260 while (!list.empty() && !list.back())
266 assert(op &&
"cannot remove nullptr from worklist");
267 auto it = map.find(op);
268 if (it != map.end()) {
269 assert(list[it->second] == op &&
"malformed worklist data structure");
270 list[it->second] =
nullptr;
275 void Worklist::reverse() {
276 std::reverse(list.begin(), list.end());
277 for (
size_t i = 0, e = list.size(); i != e; ++i)
281 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
286 class RandomizedWorklist :
public Worklist {
288 RandomizedWorklist() : Worklist() {
289 generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
296 assert(!list.empty() &&
"cannot pop from empty worklist");
299 list.
erase(list.begin() + pos);
300 for (int64_t i = pos, e = list.size(); i < e; ++i)
324 explicit GreedyPatternRewriteDriver(
MLIRContext *ctx,
329 void addSingleOpToWorklist(
Operation *op);
336 void notifyOperationModified(
Operation *op)
override;
341 void notifyOperationInserted(
Operation *op,
347 void notifyOperationErased(
Operation *op)
override;
355 bool processWorklist();
363 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364 RandomizedWorklist worklist;
376 llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
383 void addOperandsToWorklist(
Operation *op);
386 void notifyBlockInserted(
Block *block,
Region *previous,
390 void notifyBlockErased(
Block *block)
override;
400 llvm::impl::raw_ldbg_ostream os{(Twine(
"[") +
DEBUG_TYPE +
":1] ").str(),
403 llvm::ScopedPrinter logger{os};
409 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
410 ExpensiveChecks expensiveChecks;
415 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
419 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
429 matcher.applyDefaultCostModel();
432 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
435 rewriter.setListener(&expensiveChecks);
437 rewriter.setListener(
this);
441 bool GreedyPatternRewriteDriver::processWorklist() {
443 const char *logLineComment =
444 "//===-------------------------------------------===//\n";
447 auto logResult = [&](StringRef result,
const llvm::Twine &msg = {}) {
449 logger.startLine() <<
"} -> " << result;
450 if (!msg.isTriviallyEmpty())
451 logger.getOStream() <<
" : " << msg;
452 logger.getOStream() <<
"\n";
454 auto logResultWithLine = [&](StringRef result,
const llvm::Twine &msg = {}) {
455 logResult(result, msg);
456 logger.startLine() << logLineComment;
461 int64_t numRewrites = 0;
462 while (!worklist.empty() &&
463 (numRewrites <
config.getMaxNumRewrites() ||
464 config.getMaxNumRewrites() == GreedyRewriteConfig::kNoLimit)) {
465 auto *op = worklist.pop();
468 logger.getOStream() <<
"\n";
469 logger.startLine() << logLineComment;
470 logger.startLine() <<
"Processing operation : '" << op->getName() <<
"'("
475 if (op->getNumRegions() == 0) {
478 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
479 logger.getOStream() <<
"\n\n";
485 rewriter.eraseOp(op);
488 LLVM_DEBUG(logResultWithLine(
"success",
"operation is trivially dead"));
498 if (succeeded(op->fold(foldResults))) {
499 LLVM_DEBUG(logResultWithLine(
"success",
"operation was folded"));
503 if (foldResults.empty()) {
505 notifyOperationModified(op);
507 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
508 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
509 expensiveChecks.notifyFoldingSuccess();
515 assert(foldResults.size() == op->getNumResults() &&
516 "folder produced incorrect number of results");
518 rewriter.setInsertionPoint(op);
520 bool materializationSucceeded =
true;
521 for (
auto [ofr, resultType] :
522 llvm::zip_equal(foldResults, op->getResultTypes())) {
523 if (
auto value = dyn_cast<Value>(ofr)) {
524 assert(value.getType() == resultType &&
525 "folder produced value of incorrect type");
526 replacements.push_back(value);
531 rewriter, cast<Attribute>(ofr), resultType, op->getLoc());
536 llvm::SmallDenseSet<Operation *> replacementOps;
537 for (
Value replacement : replacements) {
538 assert(replacement.use_empty() &&
539 "folder reused existing op for one result but constant "
540 "materialization failed for another result");
541 replacementOps.insert(replacement.getDefiningOp());
544 rewriter.eraseOp(op);
547 materializationSucceeded =
false;
552 "materializeConstant produced op that is not a ConstantLike");
554 "materializeConstant produced incorrect result type");
555 replacements.push_back(constOp->
getResult(0));
558 if (materializationSucceeded) {
559 rewriter.replaceOp(op, replacements);
561 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
562 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
563 expensiveChecks.notifyFoldingSuccess();
573 auto canApplyCallback = [&](
const Pattern &pattern) {
575 logger.getOStream() <<
"\n";
576 logger.startLine() <<
"* Pattern " << pattern.getDebugName() <<
" : '"
577 << op->getName() <<
" -> (";
578 llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
579 logger.getOStream() <<
")' {\n";
583 listener->notifyPatternBegin(pattern, op);
587 auto onFailureCallback = [&](
const Pattern &pattern) {
588 LLVM_DEBUG(logResult(
"failure",
"pattern failed to match"));
590 listener->notifyPatternEnd(pattern, failure());
593 auto onSuccessCallback = [&](
const Pattern &pattern) {
594 LLVM_DEBUG(logResult(
"success",
"pattern applied successfully"));
596 listener->notifyPatternEnd(pattern, success());
604 if (!
config.getListener()) {
611 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
613 expensiveChecks.computeFingerPrints(
config.getScope()->getParentOp());
615 auto clearFingerprints =
616 llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
619 LogicalResult matchResult =
620 matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
622 if (succeeded(matchResult)) {
623 LLVM_DEBUG(logResultWithLine(
"success",
"at least one pattern matched"));
624 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
625 expensiveChecks.notifyRewriteSuccess();
630 LLVM_DEBUG(logResultWithLine(
"failure",
"all patterns failed to match"));
631 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
632 expensiveChecks.notifyRewriteFailure();
640 void GreedyPatternRewriteDriver::addToWorklist(
Operation *op) {
641 assert(op &&
"expected valid op");
646 ancestors.push_back(op);
648 if (
config.getScope() == region) {
651 addSingleOpToWorklist(op);
654 if (region ==
nullptr)
659 void GreedyPatternRewriteDriver::addSingleOpToWorklist(
Operation *op) {
660 if (
config.getStrictness() == GreedyRewriteStrictness::AnyOp ||
661 strictModeFilteredOps.contains(op))
665 void GreedyPatternRewriteDriver::notifyBlockInserted(
668 listener->notifyBlockInserted(block, previous, previousIt);
671 void GreedyPatternRewriteDriver::notifyBlockErased(
Block *block) {
673 listener->notifyBlockErased(block);
676 void GreedyPatternRewriteDriver::notifyOperationInserted(
679 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"'(" << op
683 listener->notifyOperationInserted(op, previous);
684 if (
config.getStrictness() == GreedyRewriteStrictness::ExistingAndNewOps)
685 strictModeFilteredOps.insert(op);
689 void GreedyPatternRewriteDriver::notifyOperationModified(
Operation *op) {
691 logger.startLine() <<
"** Modified: '" << op->
getName() <<
"'(" << op
695 listener->notifyOperationModified(op);
699 void GreedyPatternRewriteDriver::addOperandsToWorklist(
Operation *op) {
709 auto *defOp = operand.getDefiningOp();
714 bool hasMoreThanTwoUses =
false;
715 for (
auto user : operand.getUsers()) {
716 if (user == op || user == otherUser)
722 hasMoreThanTwoUses =
true;
725 if (hasMoreThanTwoUses)
728 addToWorklist(defOp);
732 void GreedyPatternRewriteDriver::notifyOperationErased(
Operation *op) {
734 logger.startLine() <<
"** Erase : '" << op->
getName() <<
"'(" << op
744 if (
Region *scope =
config.getScope(); scope->getParentOp() == op)
746 "scope region must not be erased during greedy pattern rewrite");
750 listener->notifyOperationErased(op);
752 addOperandsToWorklist(op);
755 if (
config.getStrictness() != GreedyRewriteStrictness::AnyOp)
756 strictModeFilteredOps.erase(op);
759 void GreedyPatternRewriteDriver::notifyOperationReplaced(
762 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
766 listener->notifyOperationReplaced(op, replacement);
769 void GreedyPatternRewriteDriver::notifyMatchFailure(
773 reasonCallback(
diag);
774 logger.startLine() <<
"** Match Failure : " <<
diag.str() <<
"\n";
777 listener->notifyMatchFailure(loc, reasonCallback);
786 class RegionPatternRewriteDriver :
public GreedyPatternRewriteDriver {
788 explicit RegionPatternRewriteDriver(
MLIRContext *ctx,
795 LogicalResult simplify(
bool *
changed) &&;
803 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
806 : GreedyPatternRewriteDriver(ctx,
patterns,
config), region(region) {
808 if (
config.getStrictness() != GreedyRewriteStrictness::AnyOp) {
809 region.
walk([&](
Operation *op) { strictModeFilteredOps.insert(op); });
814 class GreedyPatternRewriteIteration
819 : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
820 iteration(iteration) {}
821 static constexpr StringLiteral tag =
"GreedyPatternRewriteIteration";
822 void print(raw_ostream &os)
const override {
823 os <<
"GreedyPatternRewriteIteration(" << iteration <<
")";
827 int64_t iteration = 0;
831 LogicalResult RegionPatternRewriteDriver::simplify(
bool *
changed) && {
832 bool continueRewrites =
false;
833 int64_t iteration = 0;
837 if (++iteration >
config.getMaxIterations() &&
838 config.getMaxIterations() != GreedyRewriteConfig::kNoLimit)
847 auto insertKnownConstant = [&](
Operation *op) {
852 if (!folder.insertKnownConstant(op, constValue))
857 if (!
config.getUseTopDownTraversal()) {
860 if (!
config.isConstantCSEEnabled() || !insertKnownConstant(op))
866 if (!
config.isConstantCSEEnabled() || !insertKnownConstant(op)) {
868 return WalkResult::advance();
870 return WalkResult::skip();
879 continueRewrites =
false;
890 continueRewrites |= processWorklist();
894 if (
config.getRegionSimplificationLevel() !=
895 GreedySimplifyRegionLevel::Disabled) {
898 config.getRegionSimplificationLevel() ==
899 GreedySimplifyRegionLevel::Aggressive));
902 {&region}, iteration);
903 }
while (continueRewrites);
909 return success(!continueRewrites);
920 "patterns can only be applied to operations IsolatedFromAbove");
926 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
928 llvm::report_fatal_error(
929 "greedy pattern rewriter input IR failed to verify");
935 LogicalResult converged = std::move(driver).simplify(
changed);
937 LDBG() <<
"The pattern rewrite did not converge after scanning "
938 <<
config.getMaxIterations() <<
" times";
948 class MultiOpPatternRewriteDriver :
public GreedyPatternRewriteDriver {
950 explicit MultiOpPatternRewriteDriver(
953 llvm::SmallDenseSet<Operation *, 4> *survivingOps =
nullptr);
959 void notifyOperationErased(
Operation *op)
override {
960 GreedyPatternRewriteDriver::notifyOperationErased(op);
962 survivingOps->erase(op);
968 llvm::SmallDenseSet<Operation *, 4> *
const survivingOps =
nullptr;
972 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
975 llvm::SmallDenseSet<Operation *, 4> *survivingOps)
977 survivingOps(survivingOps) {
979 strictModeFilteredOps.insert_range(ops);
982 survivingOps->clear();
983 survivingOps->insert_range(ops);
991 addSingleOpToWorklist(op);
994 bool result = processWorklist();
998 return success(worklist.empty());
1006 assert(!ops.empty() &&
"expected at least one op");
1008 if (ops.size() == 1)
1009 return ops.front()->getParentRegion();
1011 Region *region = ops.front()->getParentRegion();
1012 ops = ops.drop_front();
1013 int sz = ops.size();
1014 llvm::BitVector remainingOps(sz,
true);
1018 while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
1021 remainingOps.reset(pos);
1023 if (remainingOps.none())
1042 if (!
config.getScope()) {
1049 bool allOpsInScope = llvm::all_of(ops, [&](
Operation *op) {
1052 assert(allOpsInScope &&
"ops must be within the specified scope");
1056 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1058 llvm::report_fatal_error(
1059 "greedy pattern rewriter input IR failed to verify");
1063 llvm::SmallDenseSet<Operation *, 4> surviving;
1064 MultiOpPatternRewriteDriver driver(ops.front()->getContext(),
patterns,
1066 allErased ? &surviving :
nullptr);
1067 LogicalResult converged = std::move(driver).simplify(ops,
changed);
1069 *allErased = surviving.empty();
1071 LDBG() <<
"The pattern rewrite did not converge after "
1072 <<
config.getMaxNumRewrites() <<
" rewrites";
static Region * findCommonAncestor(ArrayRef< Operation * > ops)
Find the region that is the closest common ancestor of all given ops.
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
This class represents a saved insertion point.
RAII guard to reset the insertion point of the builder when destroyed.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
A utility class for folding operations, and unifying duplicated constants generated along the way.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
void erase()
Remove this operation from its parent block and delete it.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this r...
Operation * getParentOp()
Return the parent operation this region is attached to.
MLIRContext * getContext()
Return the context this region is inserted in.
BlockListType::iterator iterator
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
CRTP Implementation of an action.
virtual void print(raw_ostream &os) const
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Erase the unreachable blocks within the provided regions.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
@ AnyOp
No restrictions wrt. which ops are processed.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
A listener that forwards all notifications to another listener.
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.