15 #include "mlir/Config/mlir-config.h"
23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/ScopeExit.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/ScopedPrinter.h"
29 #include "llvm/Support/raw_ostream.h"
31 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
37 #define DEBUG_TYPE "greedy-rewriter"
45 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
60 void computeFingerPrints(
Operation *topLevel) {
61 this->topLevel = topLevel;
62 this->topLevelFingerPrint.emplace(topLevel);
64 fingerprints.try_emplace(op, op,
false);
71 topLevelFingerPrint.reset();
75 void notifyRewriteSuccess() {
81 llvm::report_fatal_error(
"IR failed to verify after pattern application");
85 if (*topLevelFingerPrint == afterFingerPrint) {
87 llvm::report_fatal_error(
88 "pattern returned success but IR did not change");
90 for (
const auto &it : fingerprints) {
92 if (it.first == topLevel)
103 llvm::report_fatal_error(
"operation finger print changed");
108 void notifyRewriteFailure() {
114 if (*topLevelFingerPrint != afterFingerPrint) {
116 llvm::report_fatal_error(
"pattern returned failure but IR did change");
120 void notifyFoldingSuccess() {
126 llvm::report_fatal_error(
"IR failed to verify after folding");
131 void invalidateFingerPrint(
Operation *op) { fingerprints.erase(op); }
133 void notifyBlockErased(
Block *block)
override {
144 void notifyOperationInserted(
Operation *op,
150 void notifyOperationModified(
Operation *op)
override {
152 invalidateFingerPrint(op);
155 void notifyOperationErased(
Operation *op)
override {
157 op->
walk([
this](
Operation *op) { invalidateFingerPrint(op); });
169 std::optional<OperationFingerPrint> topLevelFingerPrint;
181 static void logSuccessfulFolding(
Operation *op) {
182 llvm::dbgs() <<
"// *** IR Dump After Successful Folding ***\n";
184 llvm::dbgs() <<
"\n\n";
224 std::vector<Operation *> list;
230 Worklist::Worklist() { list.reserve(64); }
232 void Worklist::clear() {
237 bool Worklist::empty()
const {
239 return !llvm::any_of(list,
240 [](
Operation *op) {
return static_cast<bool>(op); });
244 assert(op &&
"cannot push nullptr to worklist");
246 if (!map.insert({op, list.size()}).second)
252 assert(!empty() &&
"cannot pop from empty worklist");
260 while (!list.empty() && !list.back())
266 assert(op &&
"cannot remove nullptr from worklist");
267 auto it = map.find(op);
268 if (it != map.end()) {
269 assert(list[it->second] == op &&
"malformed worklist data structure");
270 list[it->second] =
nullptr;
275 void Worklist::reverse() {
276 std::reverse(list.begin(), list.end());
277 for (
size_t i = 0, e = list.size(); i != e; ++i)
281 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
286 class RandomizedWorklist :
public Worklist {
288 RandomizedWorklist() : Worklist() {
289 generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
296 assert(!list.empty() &&
"cannot pop from empty worklist");
299 list.
erase(list.begin() + pos);
300 for (int64_t i = pos, e = list.size(); i < e; ++i)
325 explicit GreedyPatternRewriteDriver(
MLIRContext *ctx,
330 void addSingleOpToWorklist(
Operation *op);
337 void notifyOperationModified(
Operation *op)
override;
342 void notifyOperationInserted(
Operation *op, InsertPoint previous)
override;
347 void notifyOperationErased(
Operation *op)
override;
355 bool processWorklist();
359 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
360 RandomizedWorklist worklist;
372 llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
379 void addOperandsToWorklist(
Operation *op);
382 void notifyBlockInserted(
Block *block,
Region *previous,
386 void notifyBlockErased(
Block *block)
override;
395 llvm::ScopedPrinter logger{llvm::dbgs()};
401 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
402 ExpensiveChecks expensiveChecks;
407 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
411 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
415 config.scope ? config.scope->getParentOp() : nullptr)
420 matcher.applyDefaultCostModel();
423 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
426 setListener(&expensiveChecks);
432 bool GreedyPatternRewriteDriver::processWorklist() {
434 const char *logLineComment =
435 "//===-------------------------------------------===//\n";
438 auto logResult = [&](StringRef result,
const llvm::Twine &msg = {}) {
440 logger.startLine() <<
"} -> " << result;
441 if (!msg.isTriviallyEmpty())
442 logger.getOStream() <<
" : " << msg;
443 logger.getOStream() <<
"\n";
445 auto logResultWithLine = [&](StringRef result,
const llvm::Twine &msg = {}) {
446 logResult(result, msg);
447 logger.startLine() << logLineComment;
451 bool changed =
false;
452 int64_t numRewrites = 0;
453 while (!worklist.empty() &&
456 auto *op = worklist.pop();
459 logger.getOStream() <<
"\n";
460 logger.startLine() << logLineComment;
461 logger.startLine() <<
"Processing operation : '" << op->
getName() <<
"'("
469 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
470 logger.getOStream() <<
"\n\n";
479 LLVM_DEBUG(logResultWithLine(
"success",
"operation is trivially dead"));
490 LLVM_DEBUG(logResultWithLine(
"success",
"operation was folded"));
494 if (foldResults.empty()) {
496 notifyOperationModified(op);
498 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
499 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
500 expensiveChecks.notifyFoldingSuccess();
507 "folder produced incorrect number of results");
509 setInsertionPoint(op);
511 bool materializationSucceeded =
true;
512 for (
auto [ofr, resultType] :
514 if (
auto value = ofr.dyn_cast<
Value>()) {
515 assert(value.getType() == resultType &&
516 "folder produced value of incorrect type");
517 replacements.push_back(value);
527 llvm::SmallDenseSet<Operation *> replacementOps;
528 for (
Value replacement : replacements) {
529 assert(replacement.use_empty() &&
530 "folder reused existing op for one result but constant "
531 "materialization failed for another result");
532 replacementOps.insert(replacement.getDefiningOp());
538 materializationSucceeded =
false;
543 "materializeConstant produced op that is not a ConstantLike");
545 "materializeConstant produced incorrect result type");
546 replacements.push_back(constOp->
getResult(0));
549 if (materializationSucceeded) {
550 replaceOp(op, replacements);
552 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
553 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
554 expensiveChecks.notifyFoldingSuccess();
564 auto canApplyCallback = [&](
const Pattern &pattern) {
566 logger.getOStream() <<
"\n";
567 logger.startLine() <<
"* Pattern " << pattern.getDebugName() <<
" : '"
569 llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
570 logger.getOStream() <<
")' {\n";
578 auto onFailureCallback = [&](
const Pattern &pattern) {
579 LLVM_DEBUG(logResult(
"failure",
"pattern failed to match"));
584 auto onSuccessCallback = [&](
const Pattern &pattern) {
585 LLVM_DEBUG(logResult(
"success",
"pattern applied successfully"));
602 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
606 auto clearFingerprints =
607 llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
611 matcher.matchAndRewrite(op, *
this, canApply, onFailure, onSuccess);
614 LLVM_DEBUG(logResultWithLine(
"success",
"pattern matched"));
615 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
616 expensiveChecks.notifyRewriteSuccess();
621 LLVM_DEBUG(logResultWithLine(
"failure",
"pattern failed to match"));
622 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
623 expensiveChecks.notifyRewriteFailure();
631 void GreedyPatternRewriteDriver::addToWorklist(
Operation *op) {
632 assert(op &&
"expected valid op");
637 ancestors.push_back(op);
639 if (config.
scope == region) {
642 addSingleOpToWorklist(op);
645 if (region ==
nullptr)
650 void GreedyPatternRewriteDriver::addSingleOpToWorklist(
Operation *op) {
651 if (config.
strictMode == GreedyRewriteStrictness::AnyOp ||
652 strictModeFilteredOps.contains(op))
656 void GreedyPatternRewriteDriver::notifyBlockInserted(
662 void GreedyPatternRewriteDriver::notifyBlockErased(
Block *block) {
667 void GreedyPatternRewriteDriver::notifyOperationInserted(
Operation *op,
668 InsertPoint previous) {
670 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"'(" << op
675 if (config.
strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
676 strictModeFilteredOps.insert(op);
680 void GreedyPatternRewriteDriver::notifyOperationModified(
Operation *op) {
682 logger.startLine() <<
"** Modified: '" << op->
getName() <<
"'(" << op
690 void GreedyPatternRewriteDriver::addOperandsToWorklist(
Operation *op) {
700 auto *defOp = operand.getDefiningOp();
705 bool hasMoreThanTwoUses =
false;
706 for (
auto user : operand.getUsers()) {
707 if (user == op || user == otherUser)
713 hasMoreThanTwoUses =
true;
716 if (hasMoreThanTwoUses)
719 addToWorklist(defOp);
723 void GreedyPatternRewriteDriver::notifyOperationErased(
Operation *op) {
725 logger.startLine() <<
"** Erase : '" << op->
getName() <<
"'(" << op
737 "scope region must not be erased during greedy pattern rewrite");
743 addOperandsToWorklist(op);
746 if (config.
strictMode != GreedyRewriteStrictness::AnyOp)
747 strictModeFilteredOps.erase(op);
750 void GreedyPatternRewriteDriver::notifyOperationReplaced(
753 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
760 void GreedyPatternRewriteDriver::notifyMatchFailure(
764 reasonCallback(
diag);
765 logger.startLine() <<
"** Match Failure : " <<
diag.str() <<
"\n";
777 class RegionPatternRewriteDriver :
public GreedyPatternRewriteDriver {
779 explicit RegionPatternRewriteDriver(
MLIRContext *ctx,
794 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
797 : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
799 if (config.
strictMode != GreedyRewriteStrictness::AnyOp) {
800 region.
walk([&](
Operation *op) { strictModeFilteredOps.insert(op); });
805 class GreedyPatternRewriteIteration
810 : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
811 iteration(iteration) {}
812 static constexpr StringLiteral tag =
"GreedyPatternRewriteIteration";
813 void print(raw_ostream &os)
const override {
814 os <<
"GreedyPatternRewriteIteration(" << iteration <<
")";
818 int64_t iteration = 0;
822 LogicalResult RegionPatternRewriteDriver::simplify(
bool *changed) && {
823 bool continueRewrites =
false;
824 int64_t iteration = 0;
838 auto insertKnownConstant = [&](
Operation *op) {
843 if (!folder.insertKnownConstant(op, constValue))
851 if (!insertKnownConstant(op))
857 if (!insertKnownConstant(op)) {
859 return WalkResult::advance();
861 return WalkResult::skip();
870 continueRewrites = processWorklist();
877 {®ion}, iteration);
878 }
while (continueRewrites);
881 *changed = iteration > 1;
884 return success(!continueRewrites);
895 "patterns can only be applied to operations IsolatedFromAbove");
899 config.
scope = ®ion;
901 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
903 llvm::report_fatal_error(
904 "greedy pattern rewriter input IR failed to verify");
908 RegionPatternRewriteDriver driver(region.
getContext(), patterns, config,
910 LogicalResult converged = std::move(driver).simplify(changed);
911 LLVM_DEBUG(
if (
failed(converged)) {
912 llvm::dbgs() <<
"The pattern rewrite did not converge after scanning "
924 class MultiOpPatternRewriteDriver :
public GreedyPatternRewriteDriver {
926 explicit MultiOpPatternRewriteDriver(
929 llvm::SmallDenseSet<Operation *, 4> *survivingOps =
nullptr);
935 void notifyOperationErased(
Operation *op)
override {
936 GreedyPatternRewriteDriver::notifyOperationErased(op);
938 survivingOps->erase(op);
944 llvm::SmallDenseSet<Operation *, 4> *
const survivingOps =
nullptr;
948 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
951 llvm::SmallDenseSet<Operation *, 4> *survivingOps)
952 : GreedyPatternRewriteDriver(ctx, patterns, config),
953 survivingOps(survivingOps) {
955 strictModeFilteredOps.insert(ops.begin(), ops.end());
958 survivingOps->clear();
959 survivingOps->insert(ops.begin(), ops.end());
967 addSingleOpToWorklist(op);
970 bool result = processWorklist();
974 return success(worklist.empty());
982 assert(!ops.empty() &&
"expected at least one op");
985 return ops.front()->getParentRegion();
987 Region *region = ops.front()->getParentRegion();
988 ops = ops.drop_front();
990 llvm::BitVector remainingOps(sz,
true);
994 while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
997 remainingOps.reset(pos);
999 if (remainingOps.none())
1018 if (!config.
scope) {
1025 bool allOpsInScope = llvm::all_of(ops, [&](
Operation *op) {
1028 assert(allOpsInScope &&
"ops must be within the specified scope");
1032 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1034 llvm::report_fatal_error(
1035 "greedy pattern rewriter input IR failed to verify");
1039 llvm::SmallDenseSet<Operation *, 4> surviving;
1040 MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
1042 allErased ? &surviving :
nullptr);
1043 LogicalResult converged = std::move(driver).simplify(ops, changed);
1045 *allErased = surviving.empty();
1046 LLVM_DEBUG(
if (
failed(converged)) {
1047 llvm::dbgs() <<
"The pattern rewrite did not converge after "
static Region * findCommonAncestor(ArrayRef< Operation * > ops)
Find the region that is the closest common ancestor of all given ops.
static MLIRContext * getContext(OpFoldResult val)
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
int64_t maxIterations
This specifies the maximum number of times the rewriter will iterate between applying patterns and si...
Region * scope
Only ops within the scope are added to the worklist.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriters worklist.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
int64_t maxNumRewrites
This specifies the maximum number of rewrites within an iteration.
bool enableRegionSimplification
Perform control flow optimizations to the region tree after applying all patterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
This class represents a saved insertion point.
RAII guard to reset the insertion point of the builder when destroyed.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A unique fingerprint for a specific operation, and all of it's internal operations (if includeNested ...
A utility class for folding operations, and unifying duplicated constants generated along the way.
Operation is the basic unit of execution within MLIR.
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this r...
Operation * getParentOp()
Return the parent operation this region is attached to.
MLIRContext * getContext()
Return the context this region is inserted in.
BlockListType::iterator iterator
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
CRTP Implementation of an action.
virtual void print(raw_ostream &os) const
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Run a set of structural simplifications over the given regions.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
@ AnyOp
No restrictions wrt. which ops are processed.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
A listener that forwards all notifications to another listener.
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
virtual void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
virtual void notifyPatternEnd(const Pattern &pattern, LogicalResult status)
Notify the listener that a pattern application finished with the specified status.
virtual void notifyPatternBegin(const Pattern &pattern, Operation *op)
Notify the listener that the specified pattern is about to be applied at the specified root operation...