MLIR 22.0.0git
PatternMatch.h
Go to the documentation of this file.
1//===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_PATTERNMATCH_H
10#define MLIR_IR_PATTERNMATCH_H
11
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "llvm/ADT/FunctionExtras.h"
15#include "llvm/Support/TypeName.h"
16#include <optional>
17
19namespace mlir {
20
21class PatternRewriter;
22
23//===----------------------------------------------------------------------===//
24// PatternBenefit class
25//===----------------------------------------------------------------------===//
26
/// This class represents the benefit of a pattern match in a unitless scheme
/// that ranges from 0 (very little benefit) to 65K. The most common unit to
/// use here is the "number of operations matched" by the pattern.
///
/// This also has a sentinel representation that can be used for patterns that
/// fail to match. A default-constructed PatternBenefit holds that sentinel.
///
class PatternBenefit {
  /// Stored representation meaning "this pattern can never match".
  enum { ImpossibleToMatchSentinel = 65535 };

public:
  PatternBenefit() = default;
  PatternBenefit(unsigned benefit);
  PatternBenefit(const PatternBenefit &) = default;
  PatternBenefit &operator=(const PatternBenefit &) = default;

  /// Return the benefit value used for patterns that can never match. This is
  /// the default-constructed state (the sentinel representation).
  static PatternBenefit impossibleToMatch() { return PatternBenefit(); }
  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }

  /// If the corresponding pattern can match, return its benefit. If the
  /// corresponding pattern isImpossibleToMatch() then this aborts.
  unsigned short getBenefit() const;

  bool operator==(const PatternBenefit &rhs) const {
    return representation == rhs.representation;
  }
  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
  bool operator<(const PatternBenefit &rhs) const {
    return representation < rhs.representation;
  }
  bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
  bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); }
  bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }

private:
  unsigned short representation{ImpossibleToMatchSentinel};
};
64
65//===----------------------------------------------------------------------===//
66// Pattern
67//===----------------------------------------------------------------------===//
68
69/// This class contains all of the data related to a pattern, but does not
70/// contain any methods or logic for the actual matching. This class is solely
71/// used to interface with the metadata of a pattern, such as the benefit or
72/// root operation.
73class Pattern {
74 /// This enum represents the kind of value used to select the root operations
75 /// that match this pattern.
76 enum class RootKind {
77 /// The pattern root matches "any" operation.
78 Any,
79 /// The pattern root is matched using a concrete operation name.
80 OperationName,
81 /// The pattern root is matched using an interface ID.
82 InterfaceID,
83 /// The patter root is matched using a trait ID.
84 TraitID
85 };
86
87public:
88 /// Return a list of operations that may be generated when rewriting an
89 /// operation instance with this pattern.
90 ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
91
92 /// Return the root node that this pattern matches. Patterns that can match
93 /// multiple root types return std::nullopt.
94 std::optional<OperationName> getRootKind() const {
95 if (rootKind == RootKind::OperationName)
97 return std::nullopt;
98 }
99
100 /// Return the interface ID used to match the root operation of this pattern.
101 /// If the pattern does not use an interface ID for deciding the root match,
102 /// this returns std::nullopt.
103 std::optional<TypeID> getRootInterfaceID() const {
104 if (rootKind == RootKind::InterfaceID)
105 return TypeID::getFromOpaquePointer(rootValue);
106 return std::nullopt;
107 }
108
109 /// Return the trait ID used to match the root operation of this pattern.
110 /// If the pattern does not use a trait ID for deciding the root match, this
111 /// returns std::nullopt.
112 std::optional<TypeID> getRootTraitID() const {
113 if (rootKind == RootKind::TraitID)
114 return TypeID::getFromOpaquePointer(rootValue);
115 return std::nullopt;
116 }
117
118 /// Return the benefit (the inverse of "cost") of matching this pattern. The
119 /// benefit of a Pattern is always static - rewrites that may have dynamic
120 /// benefit can be instantiated multiple times (different Pattern instances)
121 /// for each benefit that they may return, and be guarded by different match
122 /// condition predicates.
123 PatternBenefit getBenefit() const { return benefit; }
124
125 /// Returns true if this pattern is known to result in recursive application,
126 /// i.e. this pattern may generate IR that also matches this pattern, but is
127 /// known to bound the recursion. This signals to a rewrite driver that it is
128 /// safe to apply this pattern recursively to generated IR.
130 return contextAndHasBoundedRecursion.getInt();
131 }
132
133 /// Return the MLIRContext used to create this pattern.
135 return contextAndHasBoundedRecursion.getPointer();
136 }
137
138 /// Return a readable name for this pattern. This name should only be used for
139 /// debugging purposes, and may be empty.
140 StringRef getDebugName() const { return debugName; }
141
142 /// Set the human readable debug name used for this pattern. This name will
143 /// only be used for debugging purposes.
144 void setDebugName(StringRef name) { debugName = name; }
145
146 /// Return the set of debug labels attached to this pattern.
147 ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
148
149 /// Add the provided debug labels to this pattern.
151 debugLabels.append(labels.begin(), labels.end());
152 }
153 void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
154
155protected:
156 /// This class acts as a special tag that makes the desire to match "any"
157 /// operation type explicit. This helps to avoid unnecessary usages of this
158 /// feature, and ensures that the user is making a conscious decision.
160 /// This class acts as a special tag that makes the desire to match any
161 /// operation that implements a given interface explicit. This helps to avoid
162 /// unnecessary usages of this feature, and ensures that the user is making a
163 /// conscious decision.
165 /// This class acts as a special tag that makes the desire to match any
166 /// operation that implements a given trait explicit. This helps to avoid
167 /// unnecessary usages of this feature, and ensures that the user is making a
168 /// conscious decision.
170
171 /// Construct a pattern with a certain benefit that matches the operation
172 /// with the given root name.
173 Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
174 ArrayRef<StringRef> generatedNames = {});
175 /// Construct a pattern that may match any operation type. `generatedNames`
176 /// contains the names of operations that may be generated during a successful
177 /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
178 /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
179 /// always be supplied here.
180 Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
181 ArrayRef<StringRef> generatedNames = {});
182 /// Construct a pattern that may match any operation that implements the
183 /// interface defined by the provided `interfaceID`. `generatedNames` contains
184 /// the names of operations that may be generated during a successful rewrite.
185 /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
186 /// interface" behavior is what the user actually desired,
187 /// `MatchInterfaceOpTypeTag()` should always be supplied here.
188 Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
189 PatternBenefit benefit, MLIRContext *context,
190 ArrayRef<StringRef> generatedNames = {});
191 /// Construct a pattern that may match any operation that implements the
192 /// trait defined by the provided `traitID`. `generatedNames` contains the
193 /// names of operations that may be generated during a successful rewrite.
194 /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
195 /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
196 /// always be supplied here.
197 Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
198 MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
199
200 /// Set the flag detailing if this pattern has bounded rewrite recursion or
201 /// not.
202 void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
203 contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
204 }
205
206private:
207 Pattern(const void *rootValue, RootKind rootKind,
208 ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
209 MLIRContext *context);
210
211 /// The value used to match the root operation of the pattern.
212 const void *rootValue;
213 RootKind rootKind;
214
215 /// The expected benefit of matching this pattern.
216 const PatternBenefit benefit;
217
218 /// The context this pattern was created from, and a boolean flag indicating
219 /// whether this pattern has bounded recursion or not.
220 llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
221
222 /// A list of the potential operations that may be generated when rewriting
223 /// an op with this pattern.
225
226 /// A readable name for this pattern. May be empty.
227 StringRef debugName;
228
229 /// The set of debug labels attached to this pattern.
230 SmallVector<StringRef, 0> debugLabels;
231};
232
233//===----------------------------------------------------------------------===//
234// RewritePattern
235//===----------------------------------------------------------------------===//
236
237/// RewritePattern is the common base class for all DAG to DAG replacements.
238class RewritePattern : public Pattern {
239public:
240 virtual ~RewritePattern() = default;
241
242 /// Attempt to match against code rooted at the specified operation,
243 /// which is the same operation code as getRootKind(). If successful, perform
244 /// the rewrite.
245 ///
246 /// Note: Implementations must modify the IR if and only if the function
247 /// returns "success".
248 virtual LogicalResult matchAndRewrite(Operation *op,
249 PatternRewriter &rewriter) const = 0;
250
251 /// This method provides a convenient interface for creating and initializing
252 /// derived rewrite patterns of the given type `T`.
253 template <typename T, typename... Args>
254 static std::unique_ptr<T> create(Args &&...args) {
255 std::unique_ptr<T> pattern =
256 std::make_unique<T>(std::forward<Args>(args)...);
257 initializePattern<T>(*pattern);
258
259 // Set a default debug name if one wasn't provided.
260 if (pattern->getDebugName().empty())
261 pattern->setDebugName(llvm::getTypeName<T>());
262 return pattern;
263 }
264
265protected:
266 /// Inherit the base constructors from `Pattern`.
267 using Pattern::Pattern;
268
269private:
270 /// Trait to check if T provides a `initialize` method.
271 template <typename T, typename... Args>
272 using has_initialize = decltype(std::declval<T>().initialize());
273 template <typename T>
274 using detect_has_initialize = llvm::is_detected<has_initialize, T>;
275
276 /// Initialize the derived pattern by calling its `initialize` method if
277 /// available.
278 template <typename T>
279 static void initializePattern(T &pattern) {
280 if constexpr (detect_has_initialize<T>::value)
281 pattern.initialize();
282 }
283
284 /// An anchor for the virtual table.
285 virtual void anchor();
286};
287
288namespace detail {
289/// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
290/// allows for matching and rewriting against an instance of a derived operation
291/// class or Interface.
292template <typename SourceOp>
294 using RewritePattern::RewritePattern;
295
296 /// Wrapper around the RewritePattern method that passes the derived op type.
297 LogicalResult matchAndRewrite(Operation *op,
298 PatternRewriter &rewriter) const final {
299 return matchAndRewrite(cast<SourceOp>(op), rewriter);
300 }
301
302 /// Method that operates on the SourceOp type. Must be overridden by the
303 /// derived pattern class.
304 virtual LogicalResult matchAndRewrite(SourceOp op,
305 PatternRewriter &rewriter) const = 0;
306};
307} // namespace detail
308
309/// OpRewritePattern is a wrapper around RewritePattern that allows for
310/// matching and rewriting against an instance of a derived operation class as
311/// opposed to a raw Operation.
312template <typename SourceOp>
315 /// Type alias to allow derived classes to inherit constructors with
316 /// `using Base::Base;`.
318
319 /// Patterns must specify the root operation name they match against, and can
320 /// also specify the benefit of the pattern matching and a list of generated
321 /// ops.
323 ArrayRef<StringRef> generatedNames = {})
324 : mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
325 SourceOp::getOperationName(), benefit, context, generatedNames) {}
326};
327
328/// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
329/// matching and rewriting against an instance of an operation interface instead
330/// of a raw Operation.
331template <typename SourceOp>
334 /// Type alias to allow derived classes to inherit constructors with
335 /// `using Base::Base;`.
337
339 : mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
340 Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
341 benefit, context) {}
342};
343
344/// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
345/// matching and rewriting against instances of an operation that possess a
346/// given trait.
347template <template <typename> class TraitType>
349public:
350 /// Type alias to allow derived classes to inherit constructors with
351 /// `using Base::Base;`.
353
356 benefit, context) {}
357};
358
359//===----------------------------------------------------------------------===//
360// RewriterBase
361//===----------------------------------------------------------------------===//
362
363/// This class coordinates the application of a rewrite on a set of IR,
364/// providing a way for clients to track mutations and create new operations.
365/// This class serves as a common API for IR mutation between pattern rewrites
366/// and non-pattern rewrites, and facilitates the development of shared
367/// IR transformation utilities.
368class RewriterBase : public OpBuilder {
369public:
373
374 /// Notify the listener that the specified block is about to be erased.
375 /// At this point, the block has zero uses.
376 virtual void notifyBlockErased(Block *block) {}
377
378 /// Notify the listener that the specified operation was modified in-place.
380
381 /// Notify the listener that all uses of the specified operation's results
382 /// are about to be replaced with the results of another operation. This is
383 /// called before the uses of the old operation have been changed.
384 ///
385 /// By default, this function calls the "operation replaced with values"
386 /// notification.
389 notifyOperationReplaced(op, replacement->getResults());
390 }
391
392 /// Notify the listener that all uses of the specified operation's results
393 /// are about to be replaced with the a range of values, potentially
394 /// produced by other operations. This is called before the uses of the
395 /// operation have been changed.
398
399 /// Notify the listener that the specified operation is about to be erased.
400 /// At this point, the operation has zero uses.
401 ///
402 /// Note: This notification is not triggered when unlinking an operation.
403 virtual void notifyOperationErased(Operation *op) {}
404
405 /// Notify the listener that the specified pattern is about to be applied
406 /// at the specified root operation.
407 virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}
408
409 /// Notify the listener that a pattern application finished with the
410 /// specified status. "success" indicates that the pattern was applied
411 /// successfully. "failure" indicates that the pattern could not be
412 /// applied. The pattern may have communicated the reason for the failure
413 /// with `notifyMatchFailure`.
414 virtual void notifyPatternEnd(const Pattern &pattern,
415 LogicalResult status) {}
416
417 /// Notify the listener that the pattern failed to match, and provide a
418 /// callback to populate a diagnostic with the reason why the failure
419 /// occurred. This method allows for derived listeners to optionally hook
420 /// into the reason why a rewrite failed, and display it to users.
421 virtual void
423 function_ref<void(Diagnostic &)> reasonCallback) {}
424
425 static bool classof(const OpBuilder::Listener *base);
426 };
427
428 /// A listener that forwards all notifications to another listener. This
429 /// struct can be used as a base to create listener chains, so that multiple
430 /// listeners can be notified of IR changes.
433 : listener(listener),
434 rewriteListener(
435 dyn_cast_if_present<RewriterBase::Listener>(listener)) {}
436
437 void notifyOperationInserted(Operation *op, InsertPoint previous) override {
438 if (listener)
439 listener->notifyOperationInserted(op, previous);
440 }
441 void notifyBlockInserted(Block *block, Region *previous,
442 Region::iterator previousIt) override {
443 if (listener)
444 listener->notifyBlockInserted(block, previous, previousIt);
445 }
446 void notifyBlockErased(Block *block) override {
447 if (rewriteListener)
448 rewriteListener->notifyBlockErased(block);
449 }
451 if (rewriteListener)
452 rewriteListener->notifyOperationModified(op);
453 }
454 void notifyOperationReplaced(Operation *op, Operation *newOp) override {
455 if (rewriteListener)
456 rewriteListener->notifyOperationReplaced(op, newOp);
457 }
459 ValueRange replacement) override {
460 if (rewriteListener)
461 rewriteListener->notifyOperationReplaced(op, replacement);
462 }
463 void notifyOperationErased(Operation *op) override {
464 if (rewriteListener)
465 rewriteListener->notifyOperationErased(op);
466 }
467 void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
468 if (rewriteListener)
469 rewriteListener->notifyPatternBegin(pattern, op);
470 }
471 void notifyPatternEnd(const Pattern &pattern,
472 LogicalResult status) override {
473 if (rewriteListener)
474 rewriteListener->notifyPatternEnd(pattern, status);
475 }
477 Location loc,
478 function_ref<void(Diagnostic &)> reasonCallback) override {
479 if (rewriteListener)
480 rewriteListener->notifyMatchFailure(loc, reasonCallback);
481 }
482
483 private:
485 RewriterBase::Listener *rewriteListener;
486 };
487
488 /// A listener that logs notification events to llvm::dbgs() before
489 /// forwarding to the base listener.
491 PatternLoggingListener(OpBuilder::Listener *listener, StringRef patternName)
492 : RewriterBase::ForwardingListener(listener), patternName(patternName) {
493 }
494
495 void notifyOperationInserted(Operation *op, InsertPoint previous) override;
496 void notifyOperationModified(Operation *op) override;
497 void notifyOperationReplaced(Operation *op, Operation *newOp) override;
499 ValueRange replacement) override;
500 void notifyOperationErased(Operation *op) override;
501 void notifyPatternBegin(const Pattern &pattern, Operation *op) override;
502
503 private:
504 StringRef patternName;
505 };
506
507 /// Move the blocks that belong to "region" before the given position in
508 /// another region "parent". The two regions must be different. The caller
509 /// is responsible for creating or updating the operation transferring flow
510 /// of control to the region and passing it the correct block arguments.
511 void inlineRegionBefore(Region &region, Region &parent,
512 Region::iterator before);
513 void inlineRegionBefore(Region &region, Block *before);
514
515 /// Replace the results of the given (original) operation with the specified
516 /// list of values (replacements). The result types of the given op and the
517 /// replacements must match. The original op is erased.
518 virtual void replaceOp(Operation *op, ValueRange newValues);
519
520 /// Replace the results of the given (original) operation with the specified
521 /// new op (replacement). The result types of the two ops must match. The
522 /// original op is erased.
523 virtual void replaceOp(Operation *op, Operation *newOp);
524
525 /// Replace the results of the given (original) op with a new op that is
526 /// created without verification (replacement). The result values of the two
527 /// ops must match. The original op is erased.
528 template <typename OpTy, typename... Args>
529 OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
530 auto builder = static_cast<OpBuilder *>(this);
531 auto newOp =
532 OpTy::create(*builder, op->getLoc(), std::forward<Args>(args)...);
533 replaceOp(op, newOp.getOperation());
534 return newOp;
535 }
536
537 /// This method erases an operation that is known to have no uses.
538 ///
539 /// If the current insertion point is before the erased operation, it is
540 /// adjusted to the following operation (or the end of the block). If the
541 /// current insertion point is within the erased operation, the insertion
542 /// point is left in an invalid state.
543 virtual void eraseOp(Operation *op);
544
545 /// This method erases all operations in a block.
546 virtual void eraseBlock(Block *block);
547
548 /// Erase the specified results of the given operation. Results cannot be
549 /// erased directly, so the implementation creates a new replacement
550 /// operation and erases the original operation. The new operation is
551 /// returned.
552 Operation *eraseOpResults(Operation *op, const BitVector &eraseIndices);
553
554 /// Inline the operations of block 'source' into block 'dest' before the given
555 /// position. The source block will be deleted and must have no uses.
556 /// 'argValues' is used to replace the block arguments of 'source'.
557 ///
558 /// If the source block is inserted at the end of the dest block, the dest
559 /// block must have no successors. Similarly, if the source block is inserted
560 /// somewhere in the middle (or beginning) of the dest block, the source block
561 /// must have no successors. Otherwise, the resulting IR would have
562 /// unreachable operations.
563 ///
564 /// If the insertion point is within the source block, it is adjusted to the
565 /// destination block.
566 virtual void inlineBlockBefore(Block *source, Block *dest,
567 Block::iterator before,
568 ValueRange argValues = {});
569
570 /// Inline the operations of block 'source' before the operation 'op'. The
571 /// source block will be deleted and must have no uses. 'argValues' is used to
572 /// replace the block arguments of 'source'
573 ///
574 /// The source block must have no successors. Otherwise, the resulting IR
575 /// would have unreachable operations.
576 ///
577 /// If the insertion point is within the source block, it is adjusted to the
578 /// destination block.
579 void inlineBlockBefore(Block *source, Operation *op,
580 ValueRange argValues = {});
581
582 /// Inline the operations of block 'source' into the end of block 'dest'. The
583 /// source block will be deleted and must have no uses. 'argValues' is used to
584 /// replace the block arguments of 'source'
585 ///
586 /// The dest block must have no successors. Otherwise, the resulting IR would
587 /// have unreachable operation.
588 ///
589 /// If the insertion point is within the source block, it is adjusted to the
590 /// destination block.
591 void mergeBlocks(Block *source, Block *dest, ValueRange argValues = {});
592
593 /// Split the operations starting at "before" (inclusive) out of the given
594 /// block into a new block, and return it.
595 Block *splitBlock(Block *block, Block::iterator before);
596
597 /// Unlink this operation from its current block and insert it right before
598 /// `existingOp` which may be in the same or another block in the same
599 /// function.
600 void moveOpBefore(Operation *op, Operation *existingOp);
601
602 /// Unlink this operation from its current block and insert it right before
603 /// `iterator` in the specified block.
604 void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
605
606 /// Unlink this operation from its current block and insert it right after
607 /// `existingOp` which may be in the same or another block in the same
608 /// function.
609 void moveOpAfter(Operation *op, Operation *existingOp);
610
611 /// Unlink this operation from its current block and insert it right after
612 /// `iterator` in the specified block.
613 void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
614
615 /// Unlink this block and insert it right before `existingBlock`.
616 void moveBlockBefore(Block *block, Block *anotherBlock);
617
618 /// Unlink this block and insert it right before the location that the given
619 /// iterator points to in the given region.
620 void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
621
622 /// This method is used to notify the rewriter that an in-place operation
623 /// modification is about to happen. A call to this function *must* be
624 /// followed by a call to either `finalizeOpModification` or
625 /// `cancelOpModification`. This is a minor efficiency win (it avoids creating
626 /// a new operation and removing the old one) but also often allows simpler
627 /// code in the client.
628 virtual void startOpModification(Operation *op) {}
629
630 /// This method is used to signal the end of an in-place modification of the
631 /// given operation. This can only be called on operations that were provided
632 /// to a call to `startOpModification`.
633 virtual void finalizeOpModification(Operation *op);
634
635 /// This method cancels a pending in-place modification. This can only be
636 /// called on operations that were provided to a call to
637 /// `startOpModification`.
638 virtual void cancelOpModification(Operation *op) {}
639
640 /// This method is a utility wrapper around an in-place modification of an
641 /// operation. It wraps calls to `startOpModification` and
642 /// `finalizeOpModification` around the given callable.
643 template <typename CallableT>
644 void modifyOpInPlace(Operation *root, CallableT &&callable) {
646 callable();
648 }
649
650 /// Find uses of `from` and replace them with `to`. Also notify the listener
651 /// about every in-place op modification (for every use that was replaced).
652 virtual void replaceAllUsesWith(Value from, Value to) {
653 for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
654 Operation *op = operand.getOwner();
655 modifyOpInPlace(op, [&]() { operand.set(to); });
656 }
657 }
658 void replaceAllUsesWith(Block *from, Block *to) {
659 for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) {
660 Operation *op = operand.getOwner();
661 modifyOpInPlace(op, [&]() { operand.set(to); });
662 }
663 }
665 assert(from.size() == to.size() && "incorrect number of replacements");
666 for (auto it : llvm::zip(from, to))
667 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
668 }
669
670 /// Find uses of `from` and replace them with `to`. Also notify the listener
671 /// about every in-place op modification (for every use that was replaced)
672 /// and that the `from` operation is about to be replaced.
673 ///
674 /// Note: This function cannot be called `replaceAllUsesWith` because the
675 /// overload resolution, when called with an op that can be implicitly
676 /// converted to a Value, would be ambiguous.
679
680 /// Find uses of `from` and replace them with `to` if the `functor` returns
681 /// true. Also notify the listener about every in-place op modification (for
682 /// every use that was replaced). The optional `allUsesReplaced` flag is set
683 /// to "true" if all uses were replaced.
684 virtual void replaceUsesWithIf(Value from, Value to,
685 function_ref<bool(OpOperand &)> functor,
686 bool *allUsesReplaced = nullptr);
688 function_ref<bool(OpOperand &)> functor,
689 bool *allUsesReplaced = nullptr);
690 // Note: This function cannot be called `replaceOpUsesWithIf` because the
691 // overload resolution, when called with an op that can be implicitly
692 // converted to a Value, would be ambiguous.
694 function_ref<bool(OpOperand &)> functor,
695 bool *allUsesReplaced = nullptr) {
696 replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
697 }
698
699 /// Find uses of `from` within `block` and replace them with `to`. Also notify
700 /// the listener about every in-place op modification (for every use that was
701 /// replaced). The optional `allUsesReplaced` flag is set to "true" if all
702 /// uses were replaced.
704 Block *block, bool *allUsesReplaced = nullptr) {
706 op, newValues,
707 [block](OpOperand &use) {
708 return block->getParentOp()->isProperAncestor(use.getOwner());
709 },
710 allUsesReplaced);
711 }
712
713 /// Find uses of `from` and replace them with `to` except if the user is
714 /// `exceptedUser`. Also notify the listener about every in-place op
715 /// modification (for every use that was replaced).
716 void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
717 return replaceUsesWithIf(from, to, [&](OpOperand &use) {
718 Operation *user = use.getOwner();
719 return user != exceptedUser;
720 });
721 }
722 void replaceAllUsesExcept(Value from, Value to,
723 const SmallPtrSetImpl<Operation *> &preservedUsers);
724
725 /// Used to notify the listener that the IR failed to be rewritten because of
726 /// a match failure, and provide a callback to populate a diagnostic with the
727 /// reason why the failure occurred. This method allows for derived rewriters
728 /// to optionally hook into the reason why a rewrite failed, and display it to
729 /// users.
730 template <typename CallbackT>
731 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
732 notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
733 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
734 rewriteListener->notifyMatchFailure(
735 loc, function_ref<void(Diagnostic &)>(reasonCallback));
736 return failure();
737 }
738 template <typename CallbackT>
739 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
740 notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
741 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
742 rewriteListener->notifyMatchFailure(
743 op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
744 return failure();
745 }
746 template <typename ArgT>
747 LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
748 return notifyMatchFailure(std::forward<ArgT>(arg),
749 [&](Diagnostic &diag) { diag << msg; });
750 }
751 template <typename ArgT>
752 LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
753 return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
754 }
755
756protected:
757 /// Initialize the builder.
759 OpBuilder::Listener *listener = nullptr)
760 : OpBuilder(ctx, listener) {}
761 explicit RewriterBase(const OpBuilder &otherBuilder)
762 : OpBuilder(otherBuilder) {}
764 : OpBuilder(op, listener) {}
765 virtual ~RewriterBase();
766
767private:
768 void operator=(const RewriterBase &) = delete;
769 RewriterBase(const RewriterBase &) = delete;
770};
771
772//===----------------------------------------------------------------------===//
773// IRRewriter
774//===----------------------------------------------------------------------===//
775
776/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
777/// providing a way to keep track of the mutations made to the IR. This class
778/// should only be used in situations where another `RewriterBase` instance,
779/// such as a `PatternRewriter`, is not available.
780class IRRewriter : public RewriterBase {
781public:
783 : RewriterBase(ctx, listener) {}
784 explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
786 : RewriterBase(op, listener) {}
787};
788
789//===----------------------------------------------------------------------===//
790// PatternRewriter
791//===----------------------------------------------------------------------===//
792
793/// A special type of `RewriterBase` that coordinates the application of a
794/// rewrite pattern on the current IR being matched, providing a way to keep
795/// track of any mutations made. This class should be used to perform all
796/// necessary IR mutations within a rewrite pattern, as the pattern driver may
797/// be tracking various state that would be invalidated when a mutation takes
798/// place.
800public:
801 explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
803
804 /// A hook used to indicate if the pattern rewriter can recover from failure
805 /// during the rewrite stage of a pattern. For example, if the pattern
806 /// rewriter supports rollback, it may progress smoothly even if IR was
807 /// changed during the rewrite.
808 virtual bool canRecoverFromRewriteFailure() const { return false; }
809};
810
811} // namespace mlir
812
813// Optionally expose PDL pattern matching methods.
814#include "PDLPatternMatch.h.inc"
815
816namespace mlir {
817
818//===----------------------------------------------------------------------===//
819// RewritePatternSet
820//===----------------------------------------------------------------------===//
821
823 using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
824
825public:
826 RewritePatternSet(MLIRContext *context) : context(context) {}
827
828 /// Construct a RewritePatternSet populated with the given pattern.
830 std::unique_ptr<RewritePattern> pattern)
831 : context(context) {
832 nativePatterns.emplace_back(std::move(pattern));
833 }
834 RewritePatternSet(PDLPatternModule &&pattern)
835 : context(pattern.getContext()), pdlPatterns(std::move(pattern)) {}
836
837 MLIRContext *getContext() const { return context; }
838
839 /// Return the native patterns held in this list.
840 NativePatternListT &getNativePatterns() { return nativePatterns; }
841
842 /// Return the PDL patterns held in this list.
843 PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
844
845 /// Clear out all of the held patterns in this list.
846 void clear() {
847 nativePatterns.clear();
848 pdlPatterns.clear();
849 }
850
851 //===--------------------------------------------------------------------===//
852 // 'add' methods for adding patterns to the set.
853 //===--------------------------------------------------------------------===//
854
855 /// Add an instance of each of the pattern types 'Ts' to the pattern list with
856 /// the given arguments. Return a reference to `this` for chaining insertions.
857 /// Note: ConstructorArg is necessary here to separate the two variadic lists.
858 template <typename... Ts, typename ConstructorArg,
859 typename... ConstructorArgs,
860 typename = std::enable_if_t<sizeof...(Ts) != 0>>
861 RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) {
862 // The following expands a call to emplace_back for each of the pattern
863 // types 'Ts'.
864 (addImpl<Ts>(/*debugLabels=*/{}, std::forward<ConstructorArg>(arg),
865 std::forward<ConstructorArgs>(args)...),
866 ...);
867 return *this;
868 }
869 /// An overload of the above `add` method that allows for attaching a set
870 /// of debug labels to the attached patterns. This is useful for labeling
871 /// groups of patterns that may be shared between multiple different
872 /// passes/users.
873 template <typename... Ts, typename ConstructorArg,
874 typename... ConstructorArgs,
875 typename = std::enable_if_t<sizeof...(Ts) != 0>>
877 ConstructorArg &&arg,
878 ConstructorArgs &&...args) {
879 // The following expands a call to emplace_back for each of the pattern
880 // types 'Ts'.
881 (addImpl<Ts>(debugLabels, std::forward<ConstructorArg>(arg),
882 std::forward<ConstructorArgs>(args)...),
883 ...);
884 return *this;
885 }
886
887 /// Add an instance of each of the pattern types 'Ts'. Return a reference to
888 /// `this` for chaining insertions.
889 template <typename... Ts>
891 (addImpl<Ts>(), ...);
892 return *this;
893 }
894
895 /// Add the given native pattern to the pattern list. Return a reference to
896 /// `this` for chaining insertions.
897 RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
898 nativePatterns.emplace_back(std::move(pattern));
899 return *this;
900 }
901
902 /// Add the given PDL pattern to the pattern list. Return a reference to
903 /// `this` for chaining insertions.
904 RewritePatternSet &add(PDLPatternModule &&pattern) {
905 pdlPatterns.mergeIn(std::move(pattern));
906 return *this;
907 }
908
909 // Add a matchAndRewrite style pattern represented as a C function pointer.
910 template <typename OpType>
912 add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
913 PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
914 struct FnPattern final : public OpRewritePattern<OpType> {
915 FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
916 MLIRContext *context, PatternBenefit benefit,
917 ArrayRef<StringRef> generatedNames)
918 : OpRewritePattern<OpType>(context, benefit, generatedNames),
919 implFn(implFn) {}
920
921 LogicalResult matchAndRewrite(OpType op,
922 PatternRewriter &rewriter) const override {
923 return implFn(op, rewriter);
924 }
925
926 private:
927 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
928 };
929 add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit,
930 generatedNames));
931 return *this;
932 }
933
934 //===--------------------------------------------------------------------===//
935 // Pattern Insertion
936 //===--------------------------------------------------------------------===//
937
938 // TODO: These are soft deprecated in favor of the 'add' methods above.
939
940 /// Add an instance of each of the pattern types 'Ts' to the pattern list with
941 /// the given arguments. Return a reference to `this` for chaining insertions.
942 /// Note: ConstructorArg is necessary here to separate the two variadic lists.
943 template <typename... Ts, typename ConstructorArg,
944 typename... ConstructorArgs,
945 typename = std::enable_if_t<sizeof...(Ts) != 0>>
946 RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) {
947 // The following expands a call to emplace_back for each of the pattern
948 // types 'Ts'.
949 (addImpl<Ts>(/*debugLabels=*/{}, std::forward<ConstructorArg>(arg),
950 std::forward<ConstructorArgs>(args)...),
951 ...);
952 return *this;
953 }
954
955 /// Add an instance of each of the pattern types 'Ts'. Return a reference to
956 /// `this` for chaining insertions.
957 template <typename... Ts>
959 (addImpl<Ts>(), ...);
960 return *this;
961 }
962
963 /// Add the given native pattern to the pattern list. Return a reference to
964 /// `this` for chaining insertions.
965 RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
966 nativePatterns.emplace_back(std::move(pattern));
967 return *this;
968 }
969
970 /// Add the given PDL pattern to the pattern list. Return a reference to
971 /// `this` for chaining insertions.
972 RewritePatternSet &insert(PDLPatternModule &&pattern) {
973 pdlPatterns.mergeIn(std::move(pattern));
974 return *this;
975 }
976
977 // Add a matchAndRewrite style pattern represented as a C function pointer.
978 template <typename OpType>
980 insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
981 struct FnPattern final : public OpRewritePattern<OpType> {
982 FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
983 MLIRContext *context)
984 : OpRewritePattern<OpType>(context), implFn(implFn) {
985 this->setDebugName(llvm::getTypeName<FnPattern>());
986 }
987
988 LogicalResult matchAndRewrite(OpType op,
989 PatternRewriter &rewriter) const override {
990 return implFn(op, rewriter);
991 }
992
993 private:
994 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
995 };
996 add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
997 return *this;
998 }
999
1000private:
1001 /// Add an instance of the pattern type 'T'. Return a reference to `this` for
1002 /// chaining insertions.
1003 template <typename T, typename... Args>
1004 std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
1005 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1006 std::unique_ptr<T> pattern =
1007 RewritePattern::create<T>(std::forward<Args>(args)...);
1008 pattern->addDebugLabels(debugLabels);
1009 nativePatterns.emplace_back(std::move(pattern));
1010 }
1011
1012 template <typename T, typename... Args>
1013 std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
1014 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1015 // TODO: Add the provided labels to the PDL pattern when PDL supports
1016 // labels.
1017 pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
1018 }
1019
1020 MLIRContext *const context;
1021 NativePatternListT nativePatterns;
1022
1023 // Patterns expressed with PDL. This will compile to a stub class when PDL is
1024 // not enabled.
1025 PDLPatternModule pdlPatterns;
1026};
1027
1028} // namespace mlir
1029
1030#endif // MLIR_IR_PATTERNMATCH_H
...if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and the outgoing copies, respectively, should be inserted. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, ...
static std::string diag(const llvm::Value &value)
A block operand represents an operand that holds a reference to a Block, e.g.
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
IRRewriter(const OpBuilder &builder)
IRRewriter(Operation *op, OpBuilder::Listener *listener=nullptr)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a saved insertion point.
Definition Builders.h:327
This class helps build Operations.
Definition Builders.h:207
OpBuilder(MLIRContext *ctx, Listener *listener=nullptr)
Create a builder with the given context.
Definition Builders.h:213
Listener * listener
The optional listener for events of this builder.
Definition Builders.h:617
This class represents an operand of an operation.
Definition Value.h:257
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpTraitRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
static OperationName getFromOpaquePointer(const void *pointer)
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
result_range getResults()
Definition Operation.h:415
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
bool operator<(const PatternBenefit &rhs) const
bool operator==(const PatternBenefit &rhs) const
static PatternBenefit impossibleToMatch()
bool operator>=(const PatternBenefit &rhs) const
bool operator<=(const PatternBenefit &rhs) const
PatternBenefit & operator=(const PatternBenefit &)=default
PatternBenefit(const PatternBenefit &)=default
bool isImpossibleToMatch() const
bool operator!=(const PatternBenefit &rhs) const
PatternBenefit()=default
bool operator>(const PatternBenefit &rhs) const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
PatternRewriter(MLIRContext *ctx)
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
std::optional< TypeID > getRootInterfaceID() const
Return the interface ID used to match the root operation of this pattern.
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
void setDebugName(StringRef name)
Set the human readable debug name used for this pattern.
void addDebugLabels(StringRef label)
void addDebugLabels(ArrayRef< StringRef > labels)
Add the provided debug labels to this pattern.
StringRef getDebugName() const
Return a readable name for this pattern.
std::optional< TypeID > getRootTraitID() const
Return the trait ID used to match the root operation of this pattern.
ArrayRef< StringRef > getDebugLabels() const
Return the set of debug labels attached to this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
BlockListType::iterator iterator
Definition Region.h:52
NativePatternListT & getNativePatterns()
Return the native patterns held in this list.
RewritePatternSet & add(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet(PDLPatternModule &&pattern)
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & insert(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
RewritePatternSet & add(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter), PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
void clear()
Clear out all of the held patterns in this list.
RewritePatternSet(MLIRContext *context)
RewritePatternSet & insert()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet & addWithLabel(ArrayRef< StringRef > debugLabels, ConstructorArg &&arg, ConstructorArgs &&...args)
An overload of the above add method that allows for attaching a set of debug labels to the attached p...
RewritePatternSet & insert(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
MLIRContext * getContext() const
RewritePatternSet(MLIRContext *context, std::unique_ptr< RewritePattern > pattern)
Construct a RewritePatternSet populated with the given pattern.
RewritePatternSet & add()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet & add(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
RewritePatternSet & insert(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Inherit the base constructors from Pattern.
virtual ~RewritePattern()=default
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const =0
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void replaceOpUsesWithIf(Operation *from, ValueRange to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg)
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
RewriterBase(const OpBuilder &otherBuilder)
void replaceAllUsesWith(ValueRange from, ValueRange to)
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink `block` and insert it right before `anotherBlock`.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
RewriterBase(Operation *op, OpBuilder::Listener *listener=nullptr)
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceAllUsesWith(Block *from, Block *to)
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of the results of `op` within `block` and replace them with `newValues`.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
static TypeID getFromOpaquePointer(const void *pointer)
Definition TypeID.h:135
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
AttrTypeReplacer.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
Kind
The kind of listener.
Definition Builders.h:266
@ RewriterBaseListener
RewriterBase::Listener or user-derived class.
Definition Builders.h:271
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition Builders.h:285
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This class acts as a special tag that makes the desire to match "any" operation type explicit.
This class acts as a special tag that makes the desire to match any operation that implements a given...
This class acts as a special tag that makes the desire to match any operation that implements a given...
A listener that forwards all notifications to another listener.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
void notifyPatternEnd(const Pattern &pattern, LogicalResult status) override
Notify the listener that a pattern application finished with the specified status.
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyPatternBegin(const Pattern &pattern, Operation *op) override
Notify the listener that the specified pattern is about to be applied at the specified root operation...
void notifyOperationReplaced(Operation *op, Operation *newOp) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
ForwardingListener(OpBuilder::Listener *listener)
void notifyOperationReplaced(Operation *op, ValueRange replacement) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notify the listener that the specified block was inserted.
virtual void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
virtual void notifyPatternEnd(const Pattern &pattern, LogicalResult status)
Notify the listener that a pattern application finished with the specified status.
virtual void notifyPatternBegin(const Pattern &pattern, Operation *op)
Notify the listener that the specified pattern is about to be applied at the specified root operation...
virtual void notifyOperationReplaced(Operation *op, ValueRange replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
static bool classof(const OpBuilder::Listener *base)
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
PatternLoggingListener(OpBuilder::Listener *listener, StringRef patternName)
void notifyOperationReplaced(Operation *op, Operation *newOp) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
void notifyPatternBegin(const Pattern &pattern, Operation *op) override
Notify the listener that the specified pattern is about to be applied at the specified root operation...
OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that allows for matching and rewri...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
virtual LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const =0
Method that operates on the SourceOp type.