15 #include "mlir/Config/mlir-config.h"
22 #include "llvm/ADT/BitVector.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/ScopedPrinter.h"
28 #include "llvm/Support/raw_ostream.h"
30 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
36 #define DEBUG_TYPE "greedy-rewriter"
44 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
53 void computeFingerPrints(
Operation *topLevel) {
54 this->topLevel = topLevel;
55 this->topLevelFingerPrint.emplace(topLevel);
56 topLevel->
walk([&](
Operation *op) { fingerprints.try_emplace(op, op); });
62 topLevelFingerPrint.reset();
66 void notifyRewriteSuccess() {
69 if (*topLevelFingerPrint == afterFingerPrint) {
71 llvm::report_fatal_error(
72 "pattern returned success but IR did not change");
74 for (
const auto &it : fingerprints) {
76 if (it.first == topLevel)
86 llvm::report_fatal_error(
"operation finger print changed");
91 void notifyRewriteFailure() {
94 if (*topLevelFingerPrint != afterFingerPrint) {
96 llvm::report_fatal_error(
"pattern returned failure but IR did change");
102 void invalidateFingerPrint(
Operation *op) {
104 while (op && op != topLevel) {
105 fingerprints.erase(op);
110 void notifyOperationInserted(
Operation *op)
override {
115 void notifyOperationModified(
Operation *op)
override {
117 invalidateFingerPrint(op);
120 void notifyOperationRemoved(
Operation *op)
override {
122 op->
walk([
this](
Operation *op) { invalidateFingerPrint(op); });
134 std::optional<OperationFingerPrint> topLevelFingerPrint;
174 std::vector<Operation *> list;
/// Constructs an empty worklist, reserving capacity for 64 entries in `list`
/// up front.
180 Worklist::Worklist() { list.reserve(64); }
182 void Worklist::clear() {
187 bool Worklist::empty()
const {
189 return !llvm::any_of(list,
190 [](
Operation *op) {
return static_cast<bool>(op); });
194 assert(op &&
"cannot push nullptr to worklist");
198 map[op] = list.size();
203 assert(!empty() &&
"cannot pop from empty worklist");
211 while (!list.empty() && !list.back())
217 assert(op &&
"cannot remove nullptr from worklist");
218 auto it = map.find(op);
219 if (it != map.end()) {
220 assert(list[it->second] == op &&
"malformed worklist data structure");
221 list[it->second] =
nullptr;
226 void Worklist::reverse() {
227 std::reverse(list.begin(), list.end());
228 for (
size_t i = 0, e = list.size(); i != e; ++i)
232 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
237 class RandomizedWorklist :
public Worklist {
239 RandomizedWorklist() : Worklist() {
240 generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
247 assert(!list.empty() &&
"cannot pop from empty worklist");
250 list.
erase(list.begin() + pos);
251 for (int64_t i = pos, e = list.size(); i < e; ++i)
276 explicit GreedyPatternRewriteDriver(
MLIRContext *ctx,
281 void addSingleOpToWorklist(
Operation *op);
288 void notifyOperationModified(
Operation *op)
override;
293 void notifyOperationInserted(
Operation *op)
override;
298 void notifyOperationRemoved(
Operation *op)
override;
306 bool processWorklist();
310 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
311 RandomizedWorklist worklist;
326 llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
333 void addOperandsToWorklist(
ValueRange operands);
336 void notifyBlockCreated(
Block *block)
override;
345 llvm::ScopedPrinter logger{llvm::dbgs()};
351 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
352 DebugFingerPrints debugFingerPrints;
357 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
360 :
PatternRewriter(ctx), folder(ctx, this), config(config), matcher(patterns)
361 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
363 , debugFingerPrints(this)
368 matcher.applyDefaultCostModel();
371 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
374 setListener(&debugFingerPrints);
380 bool GreedyPatternRewriteDriver::processWorklist() {
382 const char *logLineComment =
383 "//===-------------------------------------------===//\n";
386 auto logResult = [&](StringRef result,
const llvm::Twine &msg = {}) {
388 logger.startLine() <<
"} -> " << result;
389 if (!msg.isTriviallyEmpty())
390 logger.getOStream() <<
" : " << msg;
391 logger.getOStream() <<
"\n";
393 auto logResultWithLine = [&](StringRef result,
const llvm::Twine &msg = {}) {
394 logResult(result, msg);
395 logger.startLine() << logLineComment;
399 bool changed =
false;
400 int64_t numRewrites = 0;
401 while (!worklist.empty() &&
404 auto *op = worklist.pop();
407 logger.getOStream() <<
"\n";
408 logger.startLine() << logLineComment;
409 logger.startLine() <<
"Processing operation : '" << op->
getName() <<
"'("
417 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
418 logger.getOStream() <<
"\n\n";
427 LLVM_DEBUG(logResultWithLine(
"success",
"operation is trivially dead"));
433 LLVM_DEBUG(logResultWithLine(
"success",
"operation was folded"));
442 auto canApply = [&](
const Pattern &pattern) {
444 logger.getOStream() <<
"\n";
445 logger.startLine() <<
"* Pattern " << pattern.getDebugName() <<
" : '"
447 llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
448 logger.getOStream() <<
")' {\n";
453 auto onFailure = [&](
const Pattern &pattern) {
454 LLVM_DEBUG(logResult(
"failure",
"pattern failed to match"));
456 auto onSuccess = [&](
const Pattern &pattern) {
457 LLVM_DEBUG(logResult(
"success",
"pattern applied successfully"));
466 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
467 debugFingerPrints.computeFingerPrints(
469 auto clearFingerprints =
470 llvm::make_scope_exit([&]() { debugFingerPrints.clear(); });
474 matcher.matchAndRewrite(op, *
this, canApply, onFailure, onSuccess);
477 LLVM_DEBUG(logResultWithLine(
"success",
"pattern matched"));
478 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
479 debugFingerPrints.notifyRewriteSuccess();
484 LLVM_DEBUG(logResultWithLine(
"failure",
"pattern failed to match"));
485 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
486 debugFingerPrints.notifyRewriteFailure();
494 void GreedyPatternRewriteDriver::addToWorklist(
Operation *op) {
495 assert(op &&
"expected valid op");
500 ancestors.push_back(op);
502 if (config.
scope == region) {
505 addSingleOpToWorklist(op);
508 if (region ==
nullptr)
513 void GreedyPatternRewriteDriver::addSingleOpToWorklist(
Operation *op) {
514 if (config.
strictMode == GreedyRewriteStrictness::AnyOp ||
515 strictModeFilteredOps.contains(op))
519 void GreedyPatternRewriteDriver::notifyBlockCreated(
Block *block) {
524 void GreedyPatternRewriteDriver::notifyOperationInserted(
Operation *op) {
526 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"'(" << op
531 if (config.
strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
532 strictModeFilteredOps.insert(op);
536 void GreedyPatternRewriteDriver::notifyOperationModified(
Operation *op) {
538 logger.startLine() <<
"** Modified: '" << op->
getName() <<
"'(" << op
546 void GreedyPatternRewriteDriver::addOperandsToWorklist(
ValueRange operands) {
547 for (
Value operand : operands) {
553 if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
555 if (
auto *defOp = operand.getDefiningOp())
556 addToWorklist(defOp);
560 void GreedyPatternRewriteDriver::notifyOperationRemoved(
Operation *op) {
562 logger.startLine() <<
"** Erase : '" << op->
getName() <<
"'(" << op
570 folder.notifyRemoval(op);
572 if (config.
strictMode != GreedyRewriteStrictness::AnyOp)
573 strictModeFilteredOps.erase(op);
576 void GreedyPatternRewriteDriver::notifyOperationReplaced(
579 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
585 for (
auto *user : result.getUsers())
589 LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
593 reasonCallback(
diag);
594 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
607 class RegionPatternRewriteDriver :
public GreedyPatternRewriteDriver {
609 explicit RegionPatternRewriteDriver(
MLIRContext *ctx,
624 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
627 : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
629 if (config.
strictMode != GreedyRewriteStrictness::AnyOp) {
630 region.
walk([&](
Operation *op) { strictModeFilteredOps.insert(op); });
635 class GreedyPatternRewriteIteration
640 : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
641 iteration(iteration) {}
642 static constexpr StringLiteral tag =
"GreedyPatternRewriteIteration";
643 void print(raw_ostream &os)
const override {
644 os <<
"GreedyPatternRewriteIteration(" << iteration <<
")";
648 int64_t iteration = 0;
652 LogicalResult RegionPatternRewriteDriver::simplify(
bool *changed) && {
653 auto insertKnownConstant = [&](
Operation *op) {
658 if (!folder.insertKnownConstant(op, constValue))
663 bool continueRewrites =
false;
664 int64_t iteration = 0;
677 if (!insertKnownConstant(op))
683 if (!insertKnownConstant(op)) {
685 return WalkResult::advance();
687 return WalkResult::skip();
696 continueRewrites = processWorklist();
703 {&region}, iteration);
704 }
while (continueRewrites);
707 *changed = iteration > 1;
710 return success(!continueRewrites);
721 "patterns can only be applied to operations IsolatedFromAbove");
725 config.
scope = &region;
728 RegionPatternRewriteDriver driver(region.
getContext(), patterns, config,
730 LogicalResult converged = std::move(driver).simplify(changed);
731 LLVM_DEBUG(
if (
failed(converged)) {
732 llvm::dbgs() <<
"The pattern rewrite did not converge after scanning "
744 class MultiOpPatternRewriteDriver :
public GreedyPatternRewriteDriver {
746 explicit MultiOpPatternRewriteDriver(
749 llvm::SmallDenseSet<Operation *, 4> *survivingOps =
nullptr);
755 void notifyOperationRemoved(
Operation *op)
override {
756 GreedyPatternRewriteDriver::notifyOperationRemoved(op);
758 survivingOps->erase(op);
764 llvm::SmallDenseSet<Operation *, 4> *
const survivingOps =
nullptr;
768 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
771 llvm::SmallDenseSet<Operation *, 4> *survivingOps)
772 : GreedyPatternRewriteDriver(ctx, patterns, config),
773 survivingOps(survivingOps) {
775 strictModeFilteredOps.insert(ops.begin(), ops.end());
778 survivingOps->clear();
779 survivingOps->insert(ops.begin(), ops.end());
787 addSingleOpToWorklist(op);
790 bool result = processWorklist();
794 return success(worklist.empty());
802 assert(!ops.empty() &&
"expected at least one op");
805 return ops.front()->getParentRegion();
807 Region *region = ops.front()->getParentRegion();
808 ops = ops.drop_front();
810 llvm::BitVector remainingOps(sz,
true);
814 while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
817 remainingOps.reset(pos);
819 if (remainingOps.none())
845 bool allOpsInScope = llvm::all_of(ops, [&](
Operation *op) {
848 assert(allOpsInScope &&
"ops must be within the specified scope");
853 llvm::SmallDenseSet<Operation *, 4> surviving;
854 MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
856 allErased ? &surviving :
nullptr);
857 LogicalResult converged = std::move(driver).simplify(ops, changed);
859 *allErased = surviving.empty();
860 LLVM_DEBUG(
if (
failed(converged)) {
861 llvm::dbgs() <<
"The pattern rewrite did not converge after "
static Region * findCommonAncestor(ArrayRef< Operation * > ops)
Find the region that is the closest common ancestor of all given ops.
static MLIRContext * getContext(OpFoldResult val)
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
int64_t maxIterations
This specifies the maximum number of times the rewriter will iterate between applying patterns and si...
Region * scope
Only ops within the scope are added to the worklist.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriter's worklist.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
int64_t maxNumRewrites
This specifies the maximum number of rewrites within an iteration.
bool enableRegionSimplification
Perform control flow optimizations to the region tree after applying all patterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
This class provides the API for ops that are known to be isolated from above.
A unique fingerprint for a specific operation, and all of its internal operations.
A utility class for folding operations, and unifying duplicated constants generated along the way.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this r...
Operation * getParentOp()
Return the parent operation this region is attached to.
std::enable_if_t< std::is_same< RetT, void >::value, RetT > walk(FnT &&callback)
Walk the operations in this region.
MLIRContext * getContext()
Return the context this region is inserted in.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
CRTP Implementation of an action.
virtual void print(raw_ostream &os) const
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Applies the specified rewrite patterns on ops while also trying to fold these ops.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Run a set of structural simplifications over the given regions.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
@ AnyOp
No restrictions wrt. which ops are processed.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
virtual void notifyBlockCreated(Block *block)
Notification handler for when a block is created using the builder.
virtual void notifyOperationInserted(Operation *op)
Notification handler for when an operation is inserted into the builder.
A listener that forwards all notifications to another listener.
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyOperationInserted(Operation *op) override
Notification handler for when an operation is inserted into the builder.
void notifyOperationRemoved(Operation *op) override
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that the specified operation is about to be replaced with another operation.
virtual LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
virtual void notifyOperationRemoved(Operation *op)
Notify the listener that the specified operation is about to be erased.