MLIR 22.0.0git
PatternMatch.h
Go to the documentation of this file.
1//===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_PATTERNMATCH_H
10#define MLIR_IR_PATTERNMATCH_H
11
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "llvm/ADT/FunctionExtras.h"
15#include "llvm/Support/TypeName.h"
16#include <optional>
17
19namespace mlir {
20
21class PatternRewriter;
22
23//===----------------------------------------------------------------------===//
24// PatternBenefit class
25//===----------------------------------------------------------------------===//
26
27/// This class represents the benefit of a pattern match in a unitless scheme
28/// that ranges from 0 (very little benefit) to 65K. The most common unit to
29/// use here is the "number of operations matched" by the pattern.
30///
31/// This also has a sentinel representation that can be used for patterns that
32/// fail to match.
33///
35 enum { ImpossibleToMatchSentinel = 65535 };
36
37public:
38 PatternBenefit() = default;
39 PatternBenefit(unsigned benefit);
40 PatternBenefit(const PatternBenefit &) = default;
42
44 bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }
45
46 /// If the corresponding pattern can match, return its benefit. If the
47 // corresponding pattern isImpossibleToMatch() then this aborts.
48 unsigned short getBenefit() const;
49
50 bool operator==(const PatternBenefit &rhs) const {
51 return representation == rhs.representation;
52 }
53 bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
54 bool operator<(const PatternBenefit &rhs) const {
55 return representation < rhs.representation;
56 }
57 bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
58 bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); }
59 bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }
60
61private:
62 unsigned short representation{ImpossibleToMatchSentinel};
63};
64
65//===----------------------------------------------------------------------===//
66// Pattern
67//===----------------------------------------------------------------------===//
68
69/// This class contains all of the data related to a pattern, but does not
70/// contain any methods or logic for the actual matching. This class is solely
71/// used to interface with the metadata of a pattern, such as the benefit or
72/// root operation.
73class Pattern {
74 /// This enum represents the kind of value used to select the root operations
75 /// that match this pattern.
76 enum class RootKind {
77 /// The pattern root matches "any" operation.
78 Any,
79 /// The pattern root is matched using a concrete operation name.
80 OperationName,
81 /// The pattern root is matched using an interface ID.
82 InterfaceID,
83 /// The pattern root is matched using a trait ID.
84 TraitID
85 };
86
87public:
88 /// Return a list of operations that may be generated when rewriting an
89 /// operation instance with this pattern.
90 ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
91
92 /// Return the root node that this pattern matches. Patterns that can match
93 /// multiple root types return std::nullopt.
94 std::optional<OperationName> getRootKind() const {
95 if (rootKind == RootKind::OperationName)
97 return std::nullopt;
98 }
99
100 /// Return the interface ID used to match the root operation of this pattern.
101 /// If the pattern does not use an interface ID for deciding the root match,
102 /// this returns std::nullopt.
103 std::optional<TypeID> getRootInterfaceID() const {
104 if (rootKind == RootKind::InterfaceID)
105 return TypeID::getFromOpaquePointer(rootValue);
106 return std::nullopt;
107 }
108
109 /// Return the trait ID used to match the root operation of this pattern.
110 /// If the pattern does not use a trait ID for deciding the root match, this
111 /// returns std::nullopt.
112 std::optional<TypeID> getRootTraitID() const {
113 if (rootKind == RootKind::TraitID)
114 return TypeID::getFromOpaquePointer(rootValue);
115 return std::nullopt;
116 }
117
118 /// Return the benefit (the inverse of "cost") of matching this pattern. The
119 /// benefit of a Pattern is always static - rewrites that may have dynamic
120 /// benefit can be instantiated multiple times (different Pattern instances)
121 /// for each benefit that they may return, and be guarded by different match
122 /// condition predicates.
123 PatternBenefit getBenefit() const { return benefit; }
124
125 /// Returns true if this pattern is known to result in recursive application,
126 /// i.e. this pattern may generate IR that also matches this pattern, but is
127 /// known to bound the recursion. This signals to a rewrite driver that it is
128 /// safe to apply this pattern recursively to generated IR.
130 return contextAndHasBoundedRecursion.getInt();
131 }
132
133 /// Return the MLIRContext used to create this pattern.
135 return contextAndHasBoundedRecursion.getPointer();
136 }
137
138 /// Return a readable name for this pattern. This name should only be used for
139 /// debugging purposes, and may be empty.
140 StringRef getDebugName() const { return debugName; }
141
142 /// Set the human readable debug name used for this pattern. This name will
143 /// only be used for debugging purposes.
144 void setDebugName(StringRef name) { debugName = name; }
145
146 /// Return the set of debug labels attached to this pattern.
147 ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
148
149 /// Add the provided debug labels to this pattern.
151 debugLabels.append(labels.begin(), labels.end());
152 }
153 void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
154
155protected:
156 /// This class acts as a special tag that makes the desire to match "any"
157 /// operation type explicit. This helps to avoid unnecessary usages of this
158 /// feature, and ensures that the user is making a conscious decision.
160 /// This class acts as a special tag that makes the desire to match any
161 /// operation that implements a given interface explicit. This helps to avoid
162 /// unnecessary usages of this feature, and ensures that the user is making a
163 /// conscious decision.
165 /// This class acts as a special tag that makes the desire to match any
166 /// operation that implements a given trait explicit. This helps to avoid
167 /// unnecessary usages of this feature, and ensures that the user is making a
168 /// conscious decision.
170
171 /// Construct a pattern with a certain benefit that matches the operation
172 /// with the given root name.
173 Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
174 ArrayRef<StringRef> generatedNames = {});
175 /// Construct a pattern that may match any operation type. `generatedNames`
176 /// contains the names of operations that may be generated during a successful
177 /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
178 /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
179 /// always be supplied here.
180 Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
181 ArrayRef<StringRef> generatedNames = {});
182 /// Construct a pattern that may match any operation that implements the
183 /// interface defined by the provided `interfaceID`. `generatedNames` contains
184 /// the names of operations that may be generated during a successful rewrite.
185 /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
186 /// interface" behavior is what the user actually desired,
187 /// `MatchInterfaceOpTypeTag()` should always be supplied here.
188 Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
189 PatternBenefit benefit, MLIRContext *context,
190 ArrayRef<StringRef> generatedNames = {});
191 /// Construct a pattern that may match any operation that implements the
192 /// trait defined by the provided `traitID`. `generatedNames` contains the
193 /// names of operations that may be generated during a successful rewrite.
194 /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
195 /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
196 /// always be supplied here.
197 Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
198 MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
199
200 /// Set the flag detailing if this pattern has bounded rewrite recursion or
201 /// not.
202 void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
203 contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
204 }
205
206private:
207 Pattern(const void *rootValue, RootKind rootKind,
208 ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
209 MLIRContext *context);
210
211 /// The value used to match the root operation of the pattern.
212 const void *rootValue;
213 RootKind rootKind;
214
215 /// The expected benefit of matching this pattern.
216 const PatternBenefit benefit;
217
218 /// The context this pattern was created from, and a boolean flag indicating
219 /// whether this pattern has bounded recursion or not.
220 llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
221
222 /// A list of the potential operations that may be generated when rewriting
223 /// an op with this pattern.
225
226 /// A readable name for this pattern. May be empty.
227 StringRef debugName;
228
229 /// The set of debug labels attached to this pattern.
230 SmallVector<StringRef, 0> debugLabels;
231};
232
233//===----------------------------------------------------------------------===//
234// RewritePattern
235//===----------------------------------------------------------------------===//
236
237/// RewritePattern is the common base class for all DAG to DAG replacements.
238class RewritePattern : public Pattern {
239public:
240 virtual ~RewritePattern() = default;
241
242 /// Attempt to match against code rooted at the specified operation,
243 /// which is the same operation code as getRootKind(). If successful, perform
244 /// the rewrite.
245 ///
246 /// Note: Implementations must modify the IR if and only if the function
247 /// returns "success".
248 virtual LogicalResult matchAndRewrite(Operation *op,
249 PatternRewriter &rewriter) const = 0;
250
251 /// This method provides a convenient interface for creating and initializing
252 /// derived rewrite patterns of the given type `T`.
253 template <typename T, typename... Args>
254 static std::unique_ptr<T> create(Args &&...args) {
255 std::unique_ptr<T> pattern =
256 std::make_unique<T>(std::forward<Args>(args)...);
257 initializePattern<T>(*pattern);
258
259 // Set a default debug name if one wasn't provided.
260 if (pattern->getDebugName().empty())
261 pattern->setDebugName(llvm::getTypeName<T>());
262 return pattern;
263 }
264
265protected:
266 /// Inherit the base constructors from `Pattern`.
267 using Pattern::Pattern;
268
269private:
270 /// Trait to check if T provides a `initialize` method.
271 template <typename T, typename... Args>
272 using has_initialize = decltype(std::declval<T>().initialize());
273 template <typename T>
274 using detect_has_initialize = llvm::is_detected<has_initialize, T>;
275
276 /// Initialize the derived pattern by calling its `initialize` method if
277 /// available.
278 template <typename T>
279 static void initializePattern(T &pattern) {
280 if constexpr (detect_has_initialize<T>::value)
281 pattern.initialize();
282 }
283
284 /// An anchor for the virtual table.
285 virtual void anchor();
286};
287
288namespace detail {
289/// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
290/// allows for matching and rewriting against an instance of a derived operation
291/// class or Interface.
292template <typename SourceOp>
294 using RewritePattern::RewritePattern;
295
296 /// Wrapper around the RewritePattern method that passes the derived op type.
297 LogicalResult matchAndRewrite(Operation *op,
298 PatternRewriter &rewriter) const final {
299 return matchAndRewrite(cast<SourceOp>(op), rewriter);
300 }
301
302 /// Method that operates on the SourceOp type. Must be overridden by the
303 /// derived pattern class.
304 virtual LogicalResult matchAndRewrite(SourceOp op,
305 PatternRewriter &rewriter) const = 0;
306};
307} // namespace detail
308
309/// OpRewritePattern is a wrapper around RewritePattern that allows for
310/// matching and rewriting against an instance of a derived operation class as
311/// opposed to a raw Operation.
312template <typename SourceOp>
315 /// Type alias to allow derived classes to inherit constructors with
316 /// `using Base::Base;`.
318
319 /// Patterns must specify the root operation name they match against, and can
320 /// also specify the benefit of the pattern matching and a list of generated
321 /// ops.
323 ArrayRef<StringRef> generatedNames = {})
324 : mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
325 SourceOp::getOperationName(), benefit, context, generatedNames) {}
326};
327
328/// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
329/// matching and rewriting against an instance of an operation interface instead
330/// of a raw Operation.
331template <typename SourceOp>
334 /// Type alias to allow derived classes to inherit constructors with
335 /// `using Base::Base;`.
337
339 : mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
340 Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
341 benefit, context) {}
342};
343
344/// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
345/// matching and rewriting against instances of an operation that possess a
346/// given trait.
347template <template <typename> class TraitType>
349public:
350 /// Type alias to allow derived classes to inherit constructors with
351 /// `using Base::Base;`.
353
356 benefit, context) {}
357};
358
359//===----------------------------------------------------------------------===//
360// RewriterBase
361//===----------------------------------------------------------------------===//
362
363/// This class coordinates the application of a rewrite on a set of IR,
364/// providing a way for clients to track mutations and create new operations.
365/// This class serves as a common API for IR mutation between pattern rewrites
366/// and non-pattern rewrites, and facilitates the development of shared
367/// IR transformation utilities.
368class RewriterBase : public OpBuilder {
369public:
373
374 /// Notify the listener that the specified block is about to be erased.
375 /// At this point, the block has zero uses.
376 virtual void notifyBlockErased(Block *block) {}
377
378 /// Notify the listener that the specified operation was modified in-place.
380
381 /// Notify the listener that all uses of the specified operation's results
382 /// are about to be replaced with the results of another operation. This is
383 /// called before the uses of the old operation have been changed.
384 ///
385 /// By default, this function calls the "operation replaced with values"
386 /// notification.
389 notifyOperationReplaced(op, replacement->getResults());
390 }
391
392 /// Notify the listener that all uses of the specified operation's results
393 /// are about to be replaced with a range of values, potentially
394 /// produced by other operations. This is called before the uses of the
395 /// operation have been changed.
398
399 /// Notify the listener that the specified operation is about to be erased.
400 /// At this point, the operation has zero uses.
401 ///
402 /// Note: This notification is not triggered when unlinking an operation.
403 virtual void notifyOperationErased(Operation *op) {}
404
405 /// Notify the listener that the specified pattern is about to be applied
406 /// at the specified root operation.
407 virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}
408
409 /// Notify the listener that a pattern application finished with the
410 /// specified status. "success" indicates that the pattern was applied
411 /// successfully. "failure" indicates that the pattern could not be
412 /// applied. The pattern may have communicated the reason for the failure
413 /// with `notifyMatchFailure`.
414 virtual void notifyPatternEnd(const Pattern &pattern,
415 LogicalResult status) {}
416
417 /// Notify the listener that the pattern failed to match, and provide a
418 /// callback to populate a diagnostic with the reason why the failure
419 /// occurred. This method allows for derived listeners to optionally hook
420 /// into the reason why a rewrite failed, and display it to users.
421 virtual void
423 function_ref<void(Diagnostic &)> reasonCallback) {}
424
425 static bool classof(const OpBuilder::Listener *base);
426 };
427
428 /// A listener that forwards all notifications to another listener. This
429 /// struct can be used as a base to create listener chains, so that multiple
430 /// listeners can be notified of IR changes.
433 : listener(listener),
434 rewriteListener(
435 dyn_cast_if_present<RewriterBase::Listener>(listener)) {}
436
437 void notifyOperationInserted(Operation *op, InsertPoint previous) override {
438 if (listener)
439 listener->notifyOperationInserted(op, previous);
440 }
441 void notifyBlockInserted(Block *block, Region *previous,
442 Region::iterator previousIt) override {
443 if (listener)
444 listener->notifyBlockInserted(block, previous, previousIt);
445 }
446 void notifyBlockErased(Block *block) override {
447 if (rewriteListener)
448 rewriteListener->notifyBlockErased(block);
449 }
451 if (rewriteListener)
452 rewriteListener->notifyOperationModified(op);
453 }
454 void notifyOperationReplaced(Operation *op, Operation *newOp) override {
455 if (rewriteListener)
456 rewriteListener->notifyOperationReplaced(op, newOp);
457 }
459 ValueRange replacement) override {
460 if (rewriteListener)
461 rewriteListener->notifyOperationReplaced(op, replacement);
462 }
463 void notifyOperationErased(Operation *op) override {
464 if (rewriteListener)
465 rewriteListener->notifyOperationErased(op);
466 }
467 void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
468 if (rewriteListener)
469 rewriteListener->notifyPatternBegin(pattern, op);
470 }
471 void notifyPatternEnd(const Pattern &pattern,
472 LogicalResult status) override {
473 if (rewriteListener)
474 rewriteListener->notifyPatternEnd(pattern, status);
475 }
477 Location loc,
478 function_ref<void(Diagnostic &)> reasonCallback) override {
479 if (rewriteListener)
480 rewriteListener->notifyMatchFailure(loc, reasonCallback);
481 }
482
483 private:
485 RewriterBase::Listener *rewriteListener;
486 };
487
488 /// A listener that logs notification events to llvm::dbgs() before
489 /// forwarding to the base listener.
491 PatternLoggingListener(OpBuilder::Listener *listener, StringRef patternName)
492 : RewriterBase::ForwardingListener(listener), patternName(patternName) {
493 }
494
495 void notifyOperationInserted(Operation *op, InsertPoint previous) override;
496 void notifyOperationModified(Operation *op) override;
497 void notifyOperationReplaced(Operation *op, Operation *newOp) override;
499 ValueRange replacement) override;
500 void notifyOperationErased(Operation *op) override;
501 void notifyPatternBegin(const Pattern &pattern, Operation *op) override;
502
503 private:
504 StringRef patternName;
505 };
506
507 /// Move the blocks that belong to "region" before the given position in
508 /// another region "parent". The two regions must be different. The caller
509 /// is responsible for creating or updating the operation transferring flow
510 /// of control to the region and passing it the correct block arguments.
511 void inlineRegionBefore(Region &region, Region &parent,
512 Region::iterator before);
513 void inlineRegionBefore(Region &region, Block *before);
514
515 /// Replace the results of the given (original) operation with the specified
516 /// list of values (replacements). The result types of the given op and the
517 /// replacements must match. The original op is erased.
518 virtual void replaceOp(Operation *op, ValueRange newValues);
519
520 /// Replace the results of the given (original) operation with the specified
521 /// new op (replacement). The result types of the two ops must match. The
522 /// original op is erased.
523 virtual void replaceOp(Operation *op, Operation *newOp);
524
525 /// Replace the results of the given (original) op with a new op that is
526 /// created without verification (replacement). The result values of the two
527 /// ops must match. The original op is erased.
528 template <typename OpTy, typename... Args>
529 OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
530 auto builder = static_cast<OpBuilder *>(this);
531 auto newOp =
532 OpTy::create(*builder, op->getLoc(), std::forward<Args>(args)...);
533 replaceOp(op, newOp.getOperation());
534 return newOp;
535 }
536
537 /// This method erases an operation that is known to have no uses.
538 ///
539 /// If the current insertion point is before the erased operation, it is
540 /// adjusted to the following operation (or the end of the block). If the
541 /// current insertion point is within the erased operation, the insertion
542 /// point is left in an invalid state.
543 virtual void eraseOp(Operation *op);
544
545 /// This method erases all operations in a block.
546 virtual void eraseBlock(Block *block);
547
548 /// Inline the operations of block 'source' into block 'dest' before the given
549 /// position. The source block will be deleted and must have no uses.
550 /// 'argValues' is used to replace the block arguments of 'source'.
551 ///
552 /// If the source block is inserted at the end of the dest block, the dest
553 /// block must have no successors. Similarly, if the source block is inserted
554 /// somewhere in the middle (or beginning) of the dest block, the source block
555 /// must have no successors. Otherwise, the resulting IR would have
556 /// unreachable operations.
557 ///
558 /// If the insertion point is within the source block, it is adjusted to the
559 /// destination block.
560 virtual void inlineBlockBefore(Block *source, Block *dest,
561 Block::iterator before,
562 ValueRange argValues = {});
563
564 /// Inline the operations of block 'source' before the operation 'op'. The
565 /// source block will be deleted and must have no uses. 'argValues' is used to
566 /// replace the block arguments of 'source'
567 ///
568 /// The source block must have no successors. Otherwise, the resulting IR
569 /// would have unreachable operations.
570 ///
571 /// If the insertion point is within the source block, it is adjusted to the
572 /// destination block.
573 void inlineBlockBefore(Block *source, Operation *op,
574 ValueRange argValues = {});
575
576 /// Inline the operations of block 'source' into the end of block 'dest'. The
577 /// source block will be deleted and must have no uses. 'argValues' is used to
578 /// replace the block arguments of 'source'
579 ///
580 /// The dest block must have no successors. Otherwise, the resulting IR would
581 /// have unreachable operation.
582 ///
583 /// If the insertion point is within the source block, it is adjusted to the
584 /// destination block.
585 void mergeBlocks(Block *source, Block *dest, ValueRange argValues = {});
586
587 /// Split the operations starting at "before" (inclusive) out of the given
588 /// block into a new block, and return it.
589 Block *splitBlock(Block *block, Block::iterator before);
590
591 /// Unlink this operation from its current block and insert it right before
592 /// `existingOp` which may be in the same or another block in the same
593 /// function.
594 void moveOpBefore(Operation *op, Operation *existingOp);
595
596 /// Unlink this operation from its current block and insert it right before
597 /// `iterator` in the specified block.
598 void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
599
600 /// Unlink this operation from its current block and insert it right after
601 /// `existingOp` which may be in the same or another block in the same
602 /// function.
603 void moveOpAfter(Operation *op, Operation *existingOp);
604
605 /// Unlink this operation from its current block and insert it right after
606 /// `iterator` in the specified block.
607 void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
608
609 /// Unlink this block and insert it right before `anotherBlock`.
610 void moveBlockBefore(Block *block, Block *anotherBlock);
611
612 /// Unlink this block and insert it right before the location that the given
613 /// iterator points to in the given region.
614 void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
615
616 /// This method is used to notify the rewriter that an in-place operation
617 /// modification is about to happen. A call to this function *must* be
618 /// followed by a call to either `finalizeOpModification` or
619 /// `cancelOpModification`. This is a minor efficiency win (it avoids creating
620 /// a new operation and removing the old one) but also often allows simpler
621 /// code in the client.
622 virtual void startOpModification(Operation *op) {}
623
624 /// This method is used to signal the end of an in-place modification of the
625 /// given operation. This can only be called on operations that were provided
626 /// to a call to `startOpModification`.
627 virtual void finalizeOpModification(Operation *op);
628
629 /// This method cancels a pending in-place modification. This can only be
630 /// called on operations that were provided to a call to
631 /// `startOpModification`.
632 virtual void cancelOpModification(Operation *op) {}
633
634 /// This method is a utility wrapper around an in-place modification of an
635 /// operation. It wraps calls to `startOpModification` and
636 /// `finalizeOpModification` around the given callable.
637 template <typename CallableT>
638 void modifyOpInPlace(Operation *root, CallableT &&callable) {
640 callable();
642 }
643
644 /// Find uses of `from` and replace them with `to`. Also notify the listener
645 /// about every in-place op modification (for every use that was replaced).
646 virtual void replaceAllUsesWith(Value from, Value to) {
647 for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
648 Operation *op = operand.getOwner();
649 modifyOpInPlace(op, [&]() { operand.set(to); });
650 }
651 }
652 void replaceAllUsesWith(Block *from, Block *to) {
653 for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) {
654 Operation *op = operand.getOwner();
655 modifyOpInPlace(op, [&]() { operand.set(to); });
656 }
657 }
659 assert(from.size() == to.size() && "incorrect number of replacements");
660 for (auto it : llvm::zip(from, to))
661 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
662 }
663
664 /// Find uses of `from` and replace them with `to`. Also notify the listener
665 /// about every in-place op modification (for every use that was replaced)
666 /// and that the `from` operation is about to be replaced.
667 ///
668 /// Note: This function cannot be called `replaceAllUsesWith` because the
669 /// overload resolution, when called with an op that can be implicitly
670 /// converted to a Value, would be ambiguous.
673
674 /// Find uses of `from` and replace them with `to` if the `functor` returns
675 /// true. Also notify the listener about every in-place op modification (for
676 /// every use that was replaced). The optional `allUsesReplaced` flag is set
677 /// to "true" if all uses were replaced.
678 void replaceUsesWithIf(Value from, Value to,
679 function_ref<bool(OpOperand &)> functor,
680 bool *allUsesReplaced = nullptr);
682 function_ref<bool(OpOperand &)> functor,
683 bool *allUsesReplaced = nullptr);
684 // Note: This function cannot be called `replaceOpUsesWithIf` because the
685 // overload resolution, when called with an op that can be implicitly
686 // converted to a Value, would be ambiguous.
688 function_ref<bool(OpOperand &)> functor,
689 bool *allUsesReplaced = nullptr) {
690 replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
691 }
692
693 /// Find uses of `from` within `block` and replace them with `to`. Also notify
694 /// the listener about every in-place op modification (for every use that was
695 /// replaced). The optional `allUsesReplaced` flag is set to "true" if all
696 /// uses were replaced.
698 Block *block, bool *allUsesReplaced = nullptr) {
700 op, newValues,
701 [block](OpOperand &use) {
702 return block->getParentOp()->isProperAncestor(use.getOwner());
703 },
704 allUsesReplaced);
705 }
706
707 /// Find uses of `from` and replace them with `to` except if the user is
708 /// `exceptedUser`. Also notify the listener about every in-place op
709 /// modification (for every use that was replaced).
710 void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
711 return replaceUsesWithIf(from, to, [&](OpOperand &use) {
712 Operation *user = use.getOwner();
713 return user != exceptedUser;
714 });
715 }
716 void replaceAllUsesExcept(Value from, Value to,
717 const SmallPtrSetImpl<Operation *> &preservedUsers);
718
719 /// Used to notify the listener that the IR failed to be rewritten because of
720 /// a match failure, and provide a callback to populate a diagnostic with the
721 /// reason why the failure occurred. This method allows for derived rewriters
722 /// to optionally hook into the reason why a rewrite failed, and display it to
723 /// users.
724 template <typename CallbackT>
725 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
726 notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
727 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
728 rewriteListener->notifyMatchFailure(
729 loc, function_ref<void(Diagnostic &)>(reasonCallback));
730 return failure();
731 }
732 template <typename CallbackT>
733 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
734 notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
735 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
736 rewriteListener->notifyMatchFailure(
737 op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
738 return failure();
739 }
740 template <typename ArgT>
741 LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
742 return notifyMatchFailure(std::forward<ArgT>(arg),
743 [&](Diagnostic &diag) { diag << msg; });
744 }
745 template <typename ArgT>
746 LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
747 return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
748 }
749
750protected:
751 /// Initialize the builder.
753 OpBuilder::Listener *listener = nullptr)
754 : OpBuilder(ctx, listener) {}
755 explicit RewriterBase(const OpBuilder &otherBuilder)
756 : OpBuilder(otherBuilder) {}
758 : OpBuilder(op, listener) {}
759 virtual ~RewriterBase();
760
761private:
762 void operator=(const RewriterBase &) = delete;
763 RewriterBase(const RewriterBase &) = delete;
764};
765
766//===----------------------------------------------------------------------===//
767// IRRewriter
768//===----------------------------------------------------------------------===//
769
770/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
771/// providing a way to keep track of the mutations made to the IR. This class
772/// should only be used in situations where another `RewriterBase` instance,
773/// such as a `PatternRewriter`, is not available.
774class IRRewriter : public RewriterBase {
775public:
777 : RewriterBase(ctx, listener) {}
778 explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
780 : RewriterBase(op, listener) {}
781};
782
783//===----------------------------------------------------------------------===//
784// PatternRewriter
785//===----------------------------------------------------------------------===//
786
787/// A special type of `RewriterBase` that coordinates the application of a
788/// rewrite pattern on the current IR being matched, providing a way to keep
789/// track of any mutations made. This class should be used to perform all
790/// necessary IR mutations within a rewrite pattern, as the pattern driver may
791/// be tracking various state that would be invalidated when a mutation takes
792/// place.
794public:
795 explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
797
798 /// A hook used to indicate if the pattern rewriter can recover from failure
799 /// during the rewrite stage of a pattern. For example, if the pattern
800 /// rewriter supports rollback, it may progress smoothly even if IR was
801 /// changed during the rewrite.
802 virtual bool canRecoverFromRewriteFailure() const { return false; }
803};
804
805} // namespace mlir
806
807// Optionally expose PDL pattern matching methods.
808#include "PDLPatternMatch.h.inc"
809
810namespace mlir {
811
812//===----------------------------------------------------------------------===//
813// RewritePatternSet
814//===----------------------------------------------------------------------===//
815
817 using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
818
819public:
820 RewritePatternSet(MLIRContext *context) : context(context) {}
821
822 /// Construct a RewritePatternSet populated with the given pattern.
824 std::unique_ptr<RewritePattern> pattern)
825 : context(context) {
826 nativePatterns.emplace_back(std::move(pattern));
827 }
828 RewritePatternSet(PDLPatternModule &&pattern)
829 : context(pattern.getContext()), pdlPatterns(std::move(pattern)) {}
830
831 MLIRContext *getContext() const { return context; }
832
833 /// Return the native patterns held in this list.
834 NativePatternListT &getNativePatterns() { return nativePatterns; }
835
836 /// Return the PDL patterns held in this list.
837 PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
838
839 /// Clear out all of the held patterns in this list.
840 void clear() {
841 nativePatterns.clear();
842 pdlPatterns.clear();
843 }
844
845 //===--------------------------------------------------------------------===//
846 // 'add' methods for adding patterns to the set.
847 //===--------------------------------------------------------------------===//
848
849 /// Add an instance of each of the pattern types 'Ts' to the pattern list with
850 /// the given arguments. Return a reference to `this` for chaining insertions.
851 /// Note: ConstructorArg is necessary here to separate the two variadic lists.
852 template <typename... Ts, typename ConstructorArg,
853 typename... ConstructorArgs,
854 typename = std::enable_if_t<sizeof...(Ts) != 0>>
855 RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) {
856 // The following expands a call to emplace_back for each of the pattern
857 // types 'Ts'.
858 (addImpl<Ts>(/*debugLabels=*/{}, std::forward<ConstructorArg>(arg),
859 std::forward<ConstructorArgs>(args)...),
860 ...);
861 return *this;
862 }
863 /// An overload of the above `add` method that allows for attaching a set
864 /// of debug labels to the attached patterns. This is useful for labeling
865 /// groups of patterns that may be shared between multiple different
866 /// passes/users.
867 template <typename... Ts, typename ConstructorArg,
868 typename... ConstructorArgs,
869 typename = std::enable_if_t<sizeof...(Ts) != 0>>
871 ConstructorArg &&arg,
872 ConstructorArgs &&...args) {
873 // The following expands a call to emplace_back for each of the pattern
874 // types 'Ts'.
875 (addImpl<Ts>(debugLabels, arg, args...), ...);
876 return *this;
877 }
878
879 /// Add an instance of each of the pattern types 'Ts'. Return a reference to
880 /// `this` for chaining insertions.
881 template <typename... Ts>
883 (addImpl<Ts>(), ...);
884 return *this;
885 }
886
887 /// Add the given native pattern to the pattern list. Return a reference to
888 /// `this` for chaining insertions.
889 RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
890 nativePatterns.emplace_back(std::move(pattern));
891 return *this;
892 }
893
894 /// Add the given PDL pattern to the pattern list. Return a reference to
895 /// `this` for chaining insertions.
896 RewritePatternSet &add(PDLPatternModule &&pattern) {
897 pdlPatterns.mergeIn(std::move(pattern));
898 return *this;
899 }
900
901 // Add a matchAndRewrite style pattern represented as a C function pointer.
902 template <typename OpType>
904 add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
905 PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
906 struct FnPattern final : public OpRewritePattern<OpType> {
907 FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
908 MLIRContext *context, PatternBenefit benefit,
909 ArrayRef<StringRef> generatedNames)
910 : OpRewritePattern<OpType>(context, benefit, generatedNames),
911 implFn(implFn) {}
912
913 LogicalResult matchAndRewrite(OpType op,
914 PatternRewriter &rewriter) const override {
915 return implFn(op, rewriter);
916 }
917
918 private:
919 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
920 };
921 add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit,
922 generatedNames));
923 return *this;
924 }
925
926 //===--------------------------------------------------------------------===//
927 // Pattern Insertion
928 //===--------------------------------------------------------------------===//
929
930 // TODO: These are soft deprecated in favor of the 'add' methods above.
931
932 /// Add an instance of each of the pattern types 'Ts' to the pattern list with
933 /// the given arguments. Return a reference to `this` for chaining insertions.
934 /// Note: ConstructorArg is necessary here to separate the two variadic lists.
935 template <typename... Ts, typename ConstructorArg,
936 typename... ConstructorArgs,
937 typename = std::enable_if_t<sizeof...(Ts) != 0>>
938 RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) {
939 // The following expands a call to emplace_back for each of the pattern
940 // types 'Ts'.
941 (addImpl<Ts>(/*debugLabels=*/{}, arg, args...), ...);
942 return *this;
943 }
944
945 /// Add an instance of each of the pattern types 'Ts'. Return a reference to
946 /// `this` for chaining insertions.
947 template <typename... Ts>
949 (addImpl<Ts>(), ...);
950 return *this;
951 }
952
953 /// Add the given native pattern to the pattern list. Return a reference to
954 /// `this` for chaining insertions.
955 RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
956 nativePatterns.emplace_back(std::move(pattern));
957 return *this;
958 }
959
960 /// Add the given PDL pattern to the pattern list. Return a reference to
961 /// `this` for chaining insertions.
962 RewritePatternSet &insert(PDLPatternModule &&pattern) {
963 pdlPatterns.mergeIn(std::move(pattern));
964 return *this;
965 }
966
967 // Add a matchAndRewrite style pattern represented as a C function pointer.
968 template <typename OpType>
970 insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
971 struct FnPattern final : public OpRewritePattern<OpType> {
972 FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
973 MLIRContext *context)
974 : OpRewritePattern<OpType>(context), implFn(implFn) {
975 this->setDebugName(llvm::getTypeName<FnPattern>());
976 }
977
978 LogicalResult matchAndRewrite(OpType op,
979 PatternRewriter &rewriter) const override {
980 return implFn(op, rewriter);
981 }
982
983 private:
984 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
985 };
986 add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
987 return *this;
988 }
989
990private:
991 /// Add an instance of the pattern type 'T'. Return a reference to `this` for
992 /// chaining insertions.
993 template <typename T, typename... Args>
994 std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
995 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
996 std::unique_ptr<T> pattern =
997 RewritePattern::create<T>(std::forward<Args>(args)...);
998 pattern->addDebugLabels(debugLabels);
999 nativePatterns.emplace_back(std::move(pattern));
1000 }
1001
1002 template <typename T, typename... Args>
1003 std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
1004 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1005 // TODO: Add the provided labels to the PDL pattern when PDL supports
1006 // labels.
1007 pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
1008 }
1009
1010 MLIRContext *const context;
1011 NativePatternListT nativePatterns;
1012
1013 // Patterns expressed with PDL. This will compile to a stub class when PDL is
1014 // not enabled.
1015 PDLPatternModule pdlPatterns;
1016};
1017
1018} // namespace mlir
1019
1020#endif // MLIR_IR_PATTERNMATCH_H
If copies could not be generated due to yet-unimplemented cases, `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be placed; the output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`…
static std::string diag(const llvm::Value &value)
A block operand represents an operand that holds a reference to a Block, e.g.
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:140
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
IRRewriter(const OpBuilder &builder)
IRRewriter(Operation *op, OpBuilder::Listener *listener=nullptr)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a saved insertion point.
Definition Builders.h:327
This class helps build Operations.
Definition Builders.h:207
OpBuilder(MLIRContext *ctx, Listener *listener=nullptr)
Create a builder with the given context.
Definition Builders.h:213
Listener * listener
The optional listener for events of this builder.
Definition Builders.h:617
This class represents an operand of an operation.
Definition Value.h:257
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpTraitRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
static OperationName getFromOpaquePointer(const void *pointer)
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
result_range getResults()
Definition Operation.h:415
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
bool operator<(const PatternBenefit &rhs) const
bool operator==(const PatternBenefit &rhs) const
static PatternBenefit impossibleToMatch()
bool operator>=(const PatternBenefit &rhs) const
bool operator<=(const PatternBenefit &rhs) const
PatternBenefit & operator=(const PatternBenefit &)=default
PatternBenefit(const PatternBenefit &)=default
bool isImpossibleToMatch() const
bool operator!=(const PatternBenefit &rhs) const
PatternBenefit()=default
bool operator>(const PatternBenefit &rhs) const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
PatternRewriter(MLIRContext *ctx)
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e. it may be applied again to IR it produced, but that recursion is known to be bounded.
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
std::optional< TypeID > getRootInterfaceID() const
Return the interface ID used to match the root operation of this pattern.
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
void setDebugName(StringRef name)
Set the human readable debug name used for this pattern.
void addDebugLabels(StringRef label)
void addDebugLabels(ArrayRef< StringRef > labels)
Add the provided debug labels to this pattern.
StringRef getDebugName() const
Return a readable name for this pattern.
std::optional< TypeID > getRootTraitID() const
Return the trait ID used to match the root operation of this pattern.
ArrayRef< StringRef > getDebugLabels() const
Return the set of debug labels attached to this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
BlockListType::iterator iterator
Definition Region.h:52
NativePatternListT & getNativePatterns()
Return the native patterns held in this list.
RewritePatternSet & add(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet(PDLPatternModule &&pattern)
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & insert(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
RewritePatternSet & add(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter), PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
void clear()
Clear out all of the held patterns in this list.
RewritePatternSet(MLIRContext *context)
RewritePatternSet & insert()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet & addWithLabel(ArrayRef< StringRef > debugLabels, ConstructorArg &&arg, ConstructorArgs &&...args)
An overload of the above add method that allows for attaching a set of debug labels to the attached p...
RewritePatternSet & insert(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
MLIRContext * getContext() const
RewritePatternSet(MLIRContext *context, std::unique_ptr< RewritePattern > pattern)
Construct a RewritePatternSet populated with the given pattern.
RewritePatternSet & add()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet & add(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
RewritePatternSet & insert(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Inherit the base constructors from Pattern.
virtual ~RewritePattern()=default
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const =0
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void replaceOpUsesWithIf(Operation *from, ValueRange to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg)
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
RewriterBase(const OpBuilder &otherBuilder)
void replaceAllUsesWith(ValueRange from, ValueRange to)
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink this block and insert it right before existingBlock.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
RewriterBase(Operation *op, OpBuilder::Listener *listener=nullptr)
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceAllUsesWith(Block *from, Block *to)
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of the results of `op` within `block` and replace them with `newValues`.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
static TypeID getFromOpaquePointer(const void *pointer)
Definition TypeID.h:135
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
AttrTypeReplacer.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
Kind
The kind of listener.
Definition Builders.h:266
@ RewriterBaseListener
RewriterBase::Listener or user-derived class.
Definition Builders.h:271
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition Builders.h:285
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This class acts as a special tag that makes the desire to match "any" operation type explicit.
This class acts as a special tag that makes the desire to match any operation that implements a given...
This class acts as a special tag that makes the desire to match any operation that implements a given...
A listener that forwards all notifications to another listener.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
void notifyPatternEnd(const Pattern &pattern, LogicalResult status) override
Notify the listener that a pattern application finished with the specified status.
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyPatternBegin(const Pattern &pattern, Operation *op) override
Notify the listener that the specified pattern is about to be applied at the specified root operation...
void notifyOperationReplaced(Operation *op, Operation *newOp) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
ForwardingListener(OpBuilder::Listener *listener)
void notifyOperationReplaced(Operation *op, ValueRange replacement) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notify the listener that the specified block was inserted.
virtual void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
virtual void notifyPatternEnd(const Pattern &pattern, LogicalResult status)
Notify the listener that a pattern application finished with the specified status.
virtual void notifyPatternBegin(const Pattern &pattern, Operation *op)
Notify the listener that the specified pattern is about to be applied at the specified root operation...
virtual void notifyOperationReplaced(Operation *op, ValueRange replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
static bool classof(const OpBuilder::Listener *base)
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
PatternLoggingListener(OpBuilder::Listener *listener, StringRef patternName)
void notifyOperationReplaced(Operation *op, Operation *newOp) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
void notifyPatternBegin(const Pattern &pattern, Operation *op) override
Notify the listener that the specified pattern is about to be applied at the specified root operation...
OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that allows for matching and rewri...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
virtual LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const =0
Method that operates on the SourceOp type.