15 #include "mlir/Config/mlir-config.h"
23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/ScopeExit.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/ScopedPrinter.h"
29 #include "llvm/Support/raw_ostream.h"
31 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
37 #define DEBUG_TYPE "greedy-rewriter"
45 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
60 void computeFingerPrints(
Operation *topLevel) {
61 this->topLevel = topLevel;
62 this->topLevelFingerPrint.emplace(topLevel);
64 fingerprints.try_emplace(op, op,
false);
71 topLevelFingerPrint.reset();
75 void notifyRewriteSuccess() {
80 if (failed(
verify(topLevel)))
81 llvm::report_fatal_error(
"IR failed to verify after pattern application");
85 if (*topLevelFingerPrint == afterFingerPrint) {
87 llvm::report_fatal_error(
88 "pattern returned success but IR did not change");
90 for (
const auto &it : fingerprints) {
92 if (it.first == topLevel)
103 llvm::report_fatal_error(
"operation finger print changed");
108 void notifyRewriteFailure() {
114 if (*topLevelFingerPrint != afterFingerPrint) {
116 llvm::report_fatal_error(
"pattern returned failure but IR did change");
120 void notifyFoldingSuccess() {
125 if (failed(
verify(topLevel)))
126 llvm::report_fatal_error(
"IR failed to verify after folding");
/// Drop the fingerprint recorded for `op` (a no-op if none was recorded).
/// Called from notifyOperationModified and, for every nested op, from
/// notifyOperationErased, so that ops the rewriter reports as changed or
/// erased are presumably excluded from the post-rewrite "operation finger
/// print changed" comparison in notifyRewriteSuccess.
131 void invalidateFingerPrint(
Operation *op) { fingerprints.erase(op); }
133 void notifyBlockErased(
Block *block)
override {
144 void notifyOperationInserted(
Operation *op,
150 void notifyOperationModified(
Operation *op)
override {
152 invalidateFingerPrint(op);
155 void notifyOperationErased(
Operation *op)
override {
157 op->
walk([
this](
Operation *op) { invalidateFingerPrint(op); });
169 std::optional<OperationFingerPrint> topLevelFingerPrint;
181 static void logSuccessfulFolding(
Operation *op) {
182 llvm::dbgs() <<
"// *** IR Dump After Successful Folding ***\n";
184 llvm::dbgs() <<
"\n\n";
224 std::vector<Operation *> list;
/// Construct an empty worklist. Capacity for 64 entries is reserved up front
/// in the backing `list` vector, presumably so that small worklists do not
/// reallocate while being filled.
230 Worklist::Worklist() { list.reserve(64); }
232 void Worklist::clear() {
237 bool Worklist::empty()
const {
239 return !llvm::any_of(list,
240 [](
Operation *op) {
return static_cast<bool>(op); });
244 assert(op &&
"cannot push nullptr to worklist");
246 if (!map.insert({op, list.size()}).second)
252 assert(!empty() &&
"cannot pop from empty worklist");
260 while (!list.empty() && !list.back())
266 assert(op &&
"cannot remove nullptr from worklist");
267 auto it = map.find(op);
268 if (it != map.end()) {
269 assert(list[it->second] == op &&
"malformed worklist data structure");
270 list[it->second] =
nullptr;
275 void Worklist::reverse() {
276 std::reverse(list.begin(), list.end());
277 for (
size_t i = 0, e = list.size(); i != e; ++i)
281 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
286 class RandomizedWorklist :
public Worklist {
288 RandomizedWorklist() : Worklist() {
289 generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
296 assert(!list.empty() &&
"cannot pop from empty worklist");
299 list.
erase(list.begin() + pos);
300 for (int64_t i = pos, e = list.size(); i < e; ++i)
324 explicit GreedyPatternRewriteDriver(
MLIRContext *ctx,
329 void addSingleOpToWorklist(
Operation *op);
336 void notifyOperationModified(
Operation *op)
override;
341 void notifyOperationInserted(
Operation *op,
347 void notifyOperationErased(
Operation *op)
override;
355 bool processWorklist();
363 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364 RandomizedWorklist worklist;
376 llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
383 void addOperandsToWorklist(
Operation *op);
386 void notifyBlockInserted(
Block *block,
Region *previous,
390 void notifyBlockErased(
Block *block)
override;
399 llvm::ScopedPrinter logger{llvm::dbgs()};
405 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
406 ExpensiveChecks expensiveChecks;
411 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
415 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
424 matcher.applyDefaultCostModel();
427 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
430 rewriter.setListener(&expensiveChecks);
432 rewriter.setListener(
this);
436 bool GreedyPatternRewriteDriver::processWorklist() {
438 const char *logLineComment =
439 "//===-------------------------------------------===//\n";
442 auto logResult = [&](StringRef result,
const llvm::Twine &msg = {}) {
444 logger.startLine() <<
"} -> " << result;
445 if (!msg.isTriviallyEmpty())
446 logger.getOStream() <<
" : " << msg;
447 logger.getOStream() <<
"\n";
449 auto logResultWithLine = [&](StringRef result,
const llvm::Twine &msg = {}) {
450 logResult(result, msg);
451 logger.startLine() << logLineComment;
456 int64_t numRewrites = 0;
457 while (!worklist.empty() &&
458 (numRewrites <
config.maxNumRewrites ||
459 config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
460 auto *op = worklist.pop();
463 logger.getOStream() <<
"\n";
464 logger.startLine() << logLineComment;
465 logger.startLine() <<
"Processing operation : '" << op->getName() <<
"'("
470 if (op->getNumRegions() == 0) {
473 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
474 logger.getOStream() <<
"\n\n";
480 rewriter.eraseOp(op);
483 LLVM_DEBUG(logResultWithLine(
"success",
"operation is trivially dead"));
493 if (succeeded(op->fold(foldResults))) {
494 LLVM_DEBUG(logResultWithLine(
"success",
"operation was folded"));
498 if (foldResults.empty()) {
500 notifyOperationModified(op);
502 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
503 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
504 expensiveChecks.notifyFoldingSuccess();
510 assert(foldResults.size() == op->getNumResults() &&
511 "folder produced incorrect number of results");
513 rewriter.setInsertionPoint(op);
515 bool materializationSucceeded =
true;
516 for (
auto [ofr, resultType] :
517 llvm::zip_equal(foldResults, op->getResultTypes())) {
518 if (
auto value = ofr.dyn_cast<
Value>()) {
519 assert(value.getType() == resultType &&
520 "folder produced value of incorrect type");
521 replacements.push_back(value);
526 rewriter, ofr.get<
Attribute>(), resultType, op->getLoc());
531 llvm::SmallDenseSet<Operation *> replacementOps;
532 for (
Value replacement : replacements) {
533 assert(replacement.use_empty() &&
534 "folder reused existing op for one result but constant "
535 "materialization failed for another result");
536 replacementOps.insert(replacement.getDefiningOp());
539 rewriter.eraseOp(op);
542 materializationSucceeded =
false;
547 "materializeConstant produced op that is not a ConstantLike");
549 "materializeConstant produced incorrect result type");
550 replacements.push_back(constOp->
getResult(0));
553 if (materializationSucceeded) {
554 rewriter.replaceOp(op, replacements);
556 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
557 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
558 expensiveChecks.notifyFoldingSuccess();
568 auto canApplyCallback = [&](
const Pattern &pattern) {
570 logger.getOStream() <<
"\n";
571 logger.startLine() <<
"* Pattern " << pattern.getDebugName() <<
" : '"
572 << op->getName() <<
" -> (";
573 llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
574 logger.getOStream() <<
")' {\n";
578 config.listener->notifyPatternBegin(pattern, op);
582 auto onFailureCallback = [&](
const Pattern &pattern) {
583 LLVM_DEBUG(logResult(
"failure",
"pattern failed to match"));
585 config.listener->notifyPatternEnd(pattern, failure());
588 auto onSuccessCallback = [&](
const Pattern &pattern) {
589 LLVM_DEBUG(logResult(
"success",
"pattern applied successfully"));
591 config.listener->notifyPatternEnd(pattern, success());
606 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
608 expensiveChecks.computeFingerPrints(
config.scope->getParentOp());
610 auto clearFingerprints =
611 llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
614 LogicalResult matchResult =
615 matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
617 if (succeeded(matchResult)) {
618 LLVM_DEBUG(logResultWithLine(
"success",
"pattern matched"));
619 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
620 expensiveChecks.notifyRewriteSuccess();
625 LLVM_DEBUG(logResultWithLine(
"failure",
"pattern failed to match"));
626 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
627 expensiveChecks.notifyRewriteFailure();
635 void GreedyPatternRewriteDriver::addToWorklist(
Operation *op) {
636 assert(op &&
"expected valid op");
641 ancestors.push_back(op);
643 if (
config.scope == region) {
646 addSingleOpToWorklist(op);
649 if (region ==
nullptr)
654 void GreedyPatternRewriteDriver::addSingleOpToWorklist(
Operation *op) {
655 if (
config.strictMode == GreedyRewriteStrictness::AnyOp ||
656 strictModeFilteredOps.contains(op))
660 void GreedyPatternRewriteDriver::notifyBlockInserted(
663 config.listener->notifyBlockInserted(block, previous, previousIt);
666 void GreedyPatternRewriteDriver::notifyBlockErased(
Block *block) {
668 config.listener->notifyBlockErased(block);
671 void GreedyPatternRewriteDriver::notifyOperationInserted(
674 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"'(" << op
678 config.listener->notifyOperationInserted(op, previous);
679 if (
config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
680 strictModeFilteredOps.insert(op);
684 void GreedyPatternRewriteDriver::notifyOperationModified(
Operation *op) {
686 logger.startLine() <<
"** Modified: '" << op->
getName() <<
"'(" << op
690 config.listener->notifyOperationModified(op);
694 void GreedyPatternRewriteDriver::addOperandsToWorklist(
Operation *op) {
704 auto *defOp = operand.getDefiningOp();
709 bool hasMoreThanTwoUses =
false;
710 for (
auto user : operand.getUsers()) {
711 if (user == op || user == otherUser)
717 hasMoreThanTwoUses =
true;
720 if (hasMoreThanTwoUses)
723 addToWorklist(defOp);
727 void GreedyPatternRewriteDriver::notifyOperationErased(
Operation *op) {
729 logger.startLine() <<
"** Erase : '" << op->
getName() <<
"'(" << op
741 "scope region must not be erased during greedy pattern rewrite");
745 config.listener->notifyOperationErased(op);
747 addOperandsToWorklist(op);
750 if (
config.strictMode != GreedyRewriteStrictness::AnyOp)
751 strictModeFilteredOps.erase(op);
754 void GreedyPatternRewriteDriver::notifyOperationReplaced(
757 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
761 config.listener->notifyOperationReplaced(op, replacement);
764 void GreedyPatternRewriteDriver::notifyMatchFailure(
768 reasonCallback(
diag);
769 logger.startLine() <<
"** Match Failure : " <<
diag.str() <<
"\n";
772 config.listener->notifyMatchFailure(loc, reasonCallback);
781 class RegionPatternRewriteDriver :
public GreedyPatternRewriteDriver {
783 explicit RegionPatternRewriteDriver(
MLIRContext *ctx,
790 LogicalResult simplify(
bool *
changed) &&;
798 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
801 : GreedyPatternRewriteDriver(ctx,
patterns,
config), region(region) {
803 if (
config.strictMode != GreedyRewriteStrictness::AnyOp) {
804 region.
walk([&](
Operation *op) { strictModeFilteredOps.insert(op); });
809 class GreedyPatternRewriteIteration
814 : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
815 iteration(iteration) {}
816 static constexpr StringLiteral tag =
"GreedyPatternRewriteIteration";
817 void print(raw_ostream &os)
const override {
818 os <<
"GreedyPatternRewriteIteration(" << iteration <<
")";
822 int64_t iteration = 0;
826 LogicalResult RegionPatternRewriteDriver::simplify(
bool *
changed) && {
827 bool continueRewrites =
false;
828 int64_t iteration = 0;
832 if (++iteration >
config.maxIterations &&
833 config.maxIterations != GreedyRewriteConfig::kNoLimit)
842 auto insertKnownConstant = [&](
Operation *op) {
847 if (!folder.insertKnownConstant(op, constValue))
852 if (!
config.useTopDownTraversal) {
855 if (!
config.cseConstants || !insertKnownConstant(op))
861 if (!
config.cseConstants || !insertKnownConstant(op)) {
863 return WalkResult::advance();
865 return WalkResult::skip();
874 continueRewrites = processWorklist();
878 if (
config.enableRegionSimplification !=
879 GreedySimplifyRegionLevel::Disabled) {
882 config.enableRegionSimplification ==
883 GreedySimplifyRegionLevel::Aggressive));
886 {®ion}, iteration);
887 }
while (continueRewrites);
893 return success(!continueRewrites);
904 "patterns can only be applied to operations IsolatedFromAbove");
910 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
912 llvm::report_fatal_error(
913 "greedy pattern rewriter input IR failed to verify");
919 LogicalResult converged = std::move(driver).simplify(
changed);
920 LLVM_DEBUG(
if (failed(converged)) {
921 llvm::dbgs() <<
"The pattern rewrite did not converge after scanning "
922 <<
config.maxIterations <<
" times\n";
933 class MultiOpPatternRewriteDriver :
public GreedyPatternRewriteDriver {
935 explicit MultiOpPatternRewriteDriver(
938 llvm::SmallDenseSet<Operation *, 4> *survivingOps =
nullptr);
944 void notifyOperationErased(
Operation *op)
override {
945 GreedyPatternRewriteDriver::notifyOperationErased(op);
947 survivingOps->erase(op);
953 llvm::SmallDenseSet<Operation *, 4> *
const survivingOps =
nullptr;
957 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
960 llvm::SmallDenseSet<Operation *, 4> *survivingOps)
962 survivingOps(survivingOps) {
964 strictModeFilteredOps.insert(ops.begin(), ops.end());
967 survivingOps->clear();
968 survivingOps->insert(ops.begin(), ops.end());
976 addSingleOpToWorklist(op);
979 bool result = processWorklist();
983 return success(worklist.empty());
991 assert(!ops.empty() &&
"expected at least one op");
994 return ops.front()->getParentRegion();
996 Region *region = ops.front()->getParentRegion();
997 ops = ops.drop_front();
999 llvm::BitVector remainingOps(sz,
true);
1003 while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
1006 remainingOps.reset(pos);
1008 if (remainingOps.none())
1034 bool allOpsInScope = llvm::all_of(ops, [&](
Operation *op) {
1035 return static_cast<bool>(
config.scope->findAncestorOpInRegion(*op));
1037 assert(allOpsInScope &&
"ops must be within the specified scope");
1041 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1043 llvm::report_fatal_error(
1044 "greedy pattern rewriter input IR failed to verify");
1048 llvm::SmallDenseSet<Operation *, 4> surviving;
1049 MultiOpPatternRewriteDriver driver(ops.front()->getContext(),
patterns,
1051 allErased ? &surviving :
nullptr);
1052 LogicalResult converged = std::move(driver).simplify(ops,
changed);
1054 *allErased = surviving.empty();
1055 LLVM_DEBUG(
if (failed(converged)) {
1056 llvm::dbgs() <<
"The pattern rewrite did not converge after "
1057 <<
config.maxNumRewrites <<
" rewrites";
static Region * findCommonAncestor(ArrayRef< Operation * > ops)
Find the region that is the closest common ancestor of all given ops.
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
This class represents a saved insertion point.
RAII guard to reset the insertion point of the builder when destroyed.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
A utility class for folding operations, and unifying duplicated constants generated along the way.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
void erase()
Remove this operation from its parent block and delete it.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this r...
Operation * getParentOp()
Return the parent operation this region is attached to.
MLIRContext * getContext()
Return the context this region is inserted in.
BlockListType::iterator iterator
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
CRTP Implementation of an action.
virtual void print(raw_ostream &os) const
RewritePatternSet & patterns
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
@ AnyOp
No restrictions wrt. which ops are processed.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
A listener that forwards all notifications to another listener.
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.