MLIR  20.0.0git
PatternMatch.h
Go to the documentation of this file.
1 //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_PATTERNMATCH_H
10 #define MLIR_IR_PATTERNMATCH_H
11 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/Support/TypeName.h"
16 #include <optional>
17 
19 namespace mlir {
20 
21 class PatternRewriter;
22 
23 //===----------------------------------------------------------------------===//
24 // PatternBenefit class
25 //===----------------------------------------------------------------------===//
26 
27 /// This class represents the benefit of a pattern match in a unitless scheme
28 /// that ranges from 0 (very little benefit) to 65K. The most common unit to
29 /// use here is the "number of operations matched" by the pattern.
30 ///
31 /// This also has a sentinel representation that can be used for patterns that
32 /// fail to match.
33 ///
35  enum { ImpossibleToMatchSentinel = 65535 };
36 
37 public:
38  PatternBenefit() = default;
39  PatternBenefit(unsigned benefit);
40  PatternBenefit(const PatternBenefit &) = default;
42 
44  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }
45 
46  /// If the corresponding pattern can match, return its benefit. If the
47  // corresponding pattern isImpossibleToMatch() then this aborts.
48  unsigned short getBenefit() const;
49 
50  bool operator==(const PatternBenefit &rhs) const {
51  return representation == rhs.representation;
52  }
53  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
54  bool operator<(const PatternBenefit &rhs) const {
55  return representation < rhs.representation;
56  }
57  bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
58  bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); }
59  bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }
60 
61 private:
62  unsigned short representation{ImpossibleToMatchSentinel};
63 };
64 
65 //===----------------------------------------------------------------------===//
66 // Pattern
67 //===----------------------------------------------------------------------===//
68 
69 /// This class contains all of the data related to a pattern, but does not
70 /// contain any methods or logic for the actual matching. This class is solely
71 /// used to interface with the metadata of a pattern, such as the benefit or
72 /// root operation.
73 class Pattern {
74  /// This enum represents the kind of value used to select the root operations
75  /// that match this pattern.
76  enum class RootKind {
77  /// The pattern root matches "any" operation.
78  Any,
79  /// The pattern root is matched using a concrete operation name.
81  /// The pattern root is matched using an interface ID.
82  InterfaceID,
83  /// The pattern root is matched using a trait ID.
84  TraitID
85  };
86 
87 public:
88  /// Return a list of operations that may be generated when rewriting an
89  /// operation instance with this pattern.
90  ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
91 
92  /// Return the root node that this pattern matches. Patterns that can match
93  /// multiple root types return std::nullopt.
94  std::optional<OperationName> getRootKind() const {
95  if (rootKind == RootKind::OperationName)
96  return OperationName::getFromOpaquePointer(rootValue);
97  return std::nullopt;
98  }
99 
100  /// Return the interface ID used to match the root operation of this pattern.
101  /// If the pattern does not use an interface ID for deciding the root match,
102  /// this returns std::nullopt.
103  std::optional<TypeID> getRootInterfaceID() const {
104  if (rootKind == RootKind::InterfaceID)
105  return TypeID::getFromOpaquePointer(rootValue);
106  return std::nullopt;
107  }
108 
109  /// Return the trait ID used to match the root operation of this pattern.
110  /// If the pattern does not use a trait ID for deciding the root match, this
111  /// returns std::nullopt.
112  std::optional<TypeID> getRootTraitID() const {
113  if (rootKind == RootKind::TraitID)
114  return TypeID::getFromOpaquePointer(rootValue);
115  return std::nullopt;
116  }
117 
118  /// Return the benefit (the inverse of "cost") of matching this pattern. The
119  /// benefit of a Pattern is always static - rewrites that may have dynamic
120  /// benefit can be instantiated multiple times (different Pattern instances)
121  /// for each benefit that they may return, and be guarded by different match
122  /// condition predicates.
123  PatternBenefit getBenefit() const { return benefit; }
124 
125  /// Returns true if this pattern is known to result in recursive application,
126  /// i.e. this pattern may generate IR that also matches this pattern, but is
127  /// known to bound the recursion. This signals to a rewrite driver that it is
128  /// safe to apply this pattern recursively to generated IR.
130  return contextAndHasBoundedRecursion.getInt();
131  }
132 
133  /// Return the MLIRContext used to create this pattern.
135  return contextAndHasBoundedRecursion.getPointer();
136  }
137 
138  /// Return a readable name for this pattern. This name should only be used for
139  /// debugging purposes, and may be empty.
140  StringRef getDebugName() const { return debugName; }
141 
142  /// Set the human readable debug name used for this pattern. This name will
143  /// only be used for debugging purposes.
144  void setDebugName(StringRef name) { debugName = name; }
145 
146  /// Return the set of debug labels attached to this pattern.
147  ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
148 
149  /// Add the provided debug labels to this pattern.
151  debugLabels.append(labels.begin(), labels.end());
152  }
153  void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
154 
155 protected:
156  /// This class acts as a special tag that makes the desire to match "any"
157  /// operation type explicit. This helps to avoid unnecessary usages of this
158  /// feature, and ensures that the user is making a conscious decision.
159  struct MatchAnyOpTypeTag {};
160  /// This class acts as a special tag that makes the desire to match any
161  /// operation that implements a given interface explicit. This helps to avoid
162  /// unnecessary usages of this feature, and ensures that the user is making a
163  /// conscious decision.
165  /// This class acts as a special tag that makes the desire to match any
166  /// operation that implements a given trait explicit. This helps to avoid
167  /// unnecessary usages of this feature, and ensures that the user is making a
168  /// conscious decision.
170 
171  /// Construct a pattern with a certain benefit that matches the operation
172  /// with the given root name.
173  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
174  ArrayRef<StringRef> generatedNames = {});
175  /// Construct a pattern that may match any operation type. `generatedNames`
176  /// contains the names of operations that may be generated during a successful
177  /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
178  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
179  /// always be supplied here.
180  Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
181  ArrayRef<StringRef> generatedNames = {});
182  /// Construct a pattern that may match any operation that implements the
183  /// interface defined by the provided `interfaceID`. `generatedNames` contains
184  /// the names of operations that may be generated during a successful rewrite.
185  /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
186  /// interface" behavior is what the user actually desired,
187  /// `MatchInterfaceOpTypeTag()` should always be supplied here.
188  Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
189  PatternBenefit benefit, MLIRContext *context,
190  ArrayRef<StringRef> generatedNames = {});
191  /// Construct a pattern that may match any operation that implements the
192  /// trait defined by the provided `traitID`. `generatedNames` contains the
193  /// names of operations that may be generated during a successful rewrite.
194  /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
195  /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
196  /// always be supplied here.
197  Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
198  MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
199 
200  /// Set the flag detailing if this pattern has bounded rewrite recursion or
201  /// not.
202  void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
203  contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
204  }
205 
206 private:
207  Pattern(const void *rootValue, RootKind rootKind,
208  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
209  MLIRContext *context);
210 
211  /// The value used to match the root operation of the pattern.
212  const void *rootValue;
213  RootKind rootKind;
214 
215  /// The expected benefit of matching this pattern.
216  const PatternBenefit benefit;
217 
218  /// The context this pattern was created from, and a boolean flag indicating
219  /// whether this pattern has bounded recursion or not.
220  llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
221 
222  /// A list of the potential operations that may be generated when rewriting
223  /// an op with this pattern.
224  SmallVector<OperationName, 2> generatedOps;
225 
226  /// A readable name for this pattern. May be empty.
227  StringRef debugName;
228 
229  /// The set of debug labels attached to this pattern.
230  SmallVector<StringRef, 0> debugLabels;
231 };
232 
233 //===----------------------------------------------------------------------===//
234 // RewritePattern
235 //===----------------------------------------------------------------------===//
236 
237 /// RewritePattern is the common base class for all DAG to DAG replacements.
238 /// There are two possible usages of this class:
239 /// * Multi-step RewritePattern with "match" and "rewrite"
240 /// - By overloading the "match" and "rewrite" functions, the user can
241 /// separate the concerns of matching and rewriting.
242 /// * Single-step RewritePattern with "matchAndRewrite"
243 /// - By overloading the "matchAndRewrite" function, the user can perform
244 /// the rewrite in the same call as the match.
245 ///
246 class RewritePattern : public Pattern {
247 public:
248  virtual ~RewritePattern() = default;
249 
250  /// Rewrite the IR rooted at the specified operation with the result of
251  /// this pattern, generating any new operations with the specified
252  /// builder. If an unexpected error is encountered (an internal
253  /// compiler error), it is emitted through the normal MLIR diagnostic
254  /// hooks and the IR is left in a valid state.
255  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
256 
257  /// Attempt to match against code rooted at the specified operation,
258  /// which is the same operation code as getRootKind().
259  virtual LogicalResult match(Operation *op) const;
260 
261  /// Attempt to match against code rooted at the specified operation,
262  /// which is the same operation code as getRootKind(). If successful, this
263  /// function will automatically perform the rewrite.
264  virtual LogicalResult matchAndRewrite(Operation *op,
265  PatternRewriter &rewriter) const {
266  if (succeeded(match(op))) {
267  rewrite(op, rewriter);
268  return success();
269  }
270  return failure();
271  }
272 
273  /// This method provides a convenient interface for creating and initializing
274  /// derived rewrite patterns of the given type `T`.
275  template <typename T, typename... Args>
276  static std::unique_ptr<T> create(Args &&...args) {
277  std::unique_ptr<T> pattern =
278  std::make_unique<T>(std::forward<Args>(args)...);
279  initializePattern<T>(*pattern);
280 
281  // Set a default debug name if one wasn't provided.
282  if (pattern->getDebugName().empty())
283  pattern->setDebugName(llvm::getTypeName<T>());
284  return pattern;
285  }
286 
287 protected:
288  /// Inherit the base constructors from `Pattern`.
289  using Pattern::Pattern;
290 
291 private:
292  /// Trait to check if T provides a `initialize` method.
293  template <typename T, typename... Args>
294  using has_initialize = decltype(std::declval<T>().initialize());
295  template <typename T>
296  using detect_has_initialize = llvm::is_detected<has_initialize, T>;
297 
298  /// Initialize the derived pattern by calling its `initialize` method.
299  template <typename T>
300  static std::enable_if_t<detect_has_initialize<T>::value>
301  initializePattern(T &pattern) {
302  pattern.initialize();
303  }
304  /// Empty derived pattern initializer for patterns that do not have an
305  /// initialize method.
306  template <typename T>
307  static std::enable_if_t<!detect_has_initialize<T>::value>
308  initializePattern(T &) {}
309 
310  /// An anchor for the virtual table.
311  virtual void anchor();
312 };
313 
314 namespace detail {
315 /// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
316 /// allows for matching and rewriting against an instance of a derived operation
317 /// class or Interface.
318 template <typename SourceOp>
320  using RewritePattern::RewritePattern;
321 
322  /// Wrappers around the RewritePattern methods that pass the derived op type.
323  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
324  rewrite(cast<SourceOp>(op), rewriter);
325  }
326  LogicalResult match(Operation *op) const final {
327  return match(cast<SourceOp>(op));
328  }
329  LogicalResult matchAndRewrite(Operation *op,
330  PatternRewriter &rewriter) const final {
331  return matchAndRewrite(cast<SourceOp>(op), rewriter);
332  }
333 
334  /// Rewrite and Match methods that operate on the SourceOp type. These must be
335  /// overridden by the derived pattern class.
336  virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
337  llvm_unreachable("must override rewrite or matchAndRewrite");
338  }
339  virtual LogicalResult match(SourceOp op) const {
340  llvm_unreachable("must override match or matchAndRewrite");
341  }
342  virtual LogicalResult matchAndRewrite(SourceOp op,
343  PatternRewriter &rewriter) const {
344  if (succeeded(match(op))) {
345  rewrite(op, rewriter);
346  return success();
347  }
348  return failure();
349  }
350 };
351 } // namespace detail
352 
353 /// OpRewritePattern is a wrapper around RewritePattern that allows for
354 /// matching and rewriting against an instance of a derived operation class as
355 /// opposed to a raw Operation.
356 template <typename SourceOp>
358  : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
359  /// Patterns must specify the root operation name they match against, and can
360  /// also specify the benefit of the pattern matching and a list of generated
361  /// ops.
363  ArrayRef<StringRef> generatedNames = {})
365  SourceOp::getOperationName(), benefit, context, generatedNames) {}
366 };
367 
368 /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
369 /// matching and rewriting against an instance of an operation interface instead
370 /// of a raw Operation.
371 template <typename SourceOp>
373  : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
375  : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
376  Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
377  benefit, context) {}
378 };
379 
380 /// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
381 /// matching and rewriting against instances of an operation that possess a
382 /// given trait.
383 template <template <typename> class TraitType>
385 public:
388  benefit, context) {}
389 };
390 
391 //===----------------------------------------------------------------------===//
392 // RewriterBase
393 //===----------------------------------------------------------------------===//
394 
395 /// This class coordinates the application of a rewrite on a set of IR,
396 /// providing a way for clients to track mutations and create new operations.
397 /// This class serves as a common API for IR mutation between pattern rewrites
398 /// and non-pattern rewrites, and facilitates the development of shared
399 /// IR transformation utilities.
400 class RewriterBase : public OpBuilder {
401 public:
402  struct Listener : public OpBuilder::Listener {
404  : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
405 
406  /// Notify the listener that the specified block is about to be erased.
407  /// At this point, the block has zero uses.
408  virtual void notifyBlockErased(Block *block) {}
409 
410  /// Notify the listener that the specified operation was modified in-place.
411  virtual void notifyOperationModified(Operation *op) {}
412 
413  /// Notify the listener that all uses of the specified operation's results
414  /// are about to be replaced with the results of another operation. This is
415  /// called before the uses of the old operation have been changed.
416  ///
417  /// By default, this function calls the "operation replaced with values"
418  /// notification.
420  Operation *replacement) {
421  notifyOperationReplaced(op, replacement->getResults());
422  }
423 
424  /// Notify the listener that all uses of the specified operation's results
425  /// are about to be replaced with a range of values, potentially
426  /// produced by other operations. This is called before the uses of the
427  /// operation have been changed.
429  ValueRange replacement) {}
430 
431  /// Notify the listener that the specified operation is about to be erased.
432  /// At this point, the operation has zero uses.
433  ///
434  /// Note: This notification is not triggered when unlinking an operation.
435  virtual void notifyOperationErased(Operation *op) {}
436 
437  /// Notify the listener that the specified pattern is about to be applied
438  /// at the specified root operation.
439  virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}
440 
441  /// Notify the listener that a pattern application finished with the
442  /// specified status. "success" indicates that the pattern was applied
443  /// successfully. "failure" indicates that the pattern could not be
444  /// applied. The pattern may have communicated the reason for the failure
445  /// with `notifyMatchFailure`.
446  virtual void notifyPatternEnd(const Pattern &pattern,
447  LogicalResult status) {}
448 
449  /// Notify the listener that the pattern failed to match, and provide a
450  /// callback to populate a diagnostic with the reason why the failure
451  /// occurred. This method allows for derived listeners to optionally hook
452  /// into the reason why a rewrite failed, and display it to users.
453  virtual void
455  function_ref<void(Diagnostic &)> reasonCallback) {}
456 
457  static bool classof(const OpBuilder::Listener *base);
458  };
459 
460  /// A listener that forwards all notifications to another listener. This
461  /// struct can be used as a base to create listener chains, so that multiple
462  /// listeners can be notified of IR changes.
464  ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
465 
466  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
467  listener->notifyOperationInserted(op, previous);
468  }
469  void notifyBlockInserted(Block *block, Region *previous,
470  Region::iterator previousIt) override {
471  listener->notifyBlockInserted(block, previous, previousIt);
472  }
473  void notifyBlockErased(Block *block) override {
474  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
475  rewriteListener->notifyBlockErased(block);
476  }
477  void notifyOperationModified(Operation *op) override {
478  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
479  rewriteListener->notifyOperationModified(op);
480  }
481  void notifyOperationReplaced(Operation *op, Operation *newOp) override {
482  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
483  rewriteListener->notifyOperationReplaced(op, newOp);
484  }
486  ValueRange replacement) override {
487  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
488  rewriteListener->notifyOperationReplaced(op, replacement);
489  }
490  void notifyOperationErased(Operation *op) override {
491  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
492  rewriteListener->notifyOperationErased(op);
493  }
494  void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
495  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
496  rewriteListener->notifyPatternBegin(pattern, op);
497  }
498  void notifyPatternEnd(const Pattern &pattern,
499  LogicalResult status) override {
500  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
501  rewriteListener->notifyPatternEnd(pattern, status);
502  }
504  Location loc,
505  function_ref<void(Diagnostic &)> reasonCallback) override {
506  if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
507  rewriteListener->notifyMatchFailure(loc, reasonCallback);
508  }
509 
510  private:
511  OpBuilder::Listener *listener;
512  };
513 
514  /// Move the blocks that belong to "region" before the given position in
515  /// another region "parent". The two regions must be different. The caller
516  /// is responsible for creating or updating the operation transferring flow
517  /// of control to the region and passing it the correct block arguments.
518  void inlineRegionBefore(Region &region, Region &parent,
519  Region::iterator before);
520  void inlineRegionBefore(Region &region, Block *before);
521 
522  /// Replace the results of the given (original) operation with the specified
523  /// list of values (replacements). The result types of the given op and the
524  /// replacements must match. The original op is erased.
525  virtual void replaceOp(Operation *op, ValueRange newValues);
526 
527  /// Replace the results of the given (original) operation with the specified
528  /// new op (replacement). The result types of the two ops must match. The
529  /// original op is erased.
530  virtual void replaceOp(Operation *op, Operation *newOp);
531 
532  /// Replace the results of the given (original) op with a new op that is
533  /// created without verification (replacement). The result values of the two
534  /// ops must match. The original op is erased.
535  template <typename OpTy, typename... Args>
536  OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
537  auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
538  replaceOp(op, newOp.getOperation());
539  return newOp;
540  }
541 
542  /// This method erases an operation that is known to have no uses.
543  virtual void eraseOp(Operation *op);
544 
545  /// This method erases all operations in a block.
546  virtual void eraseBlock(Block *block);
547 
548  /// Inline the operations of block 'source' into block 'dest' before the given
549  /// position. The source block will be deleted and must have no uses.
550  /// 'argValues' is used to replace the block arguments of 'source'.
551  ///
552  /// If the source block is inserted at the end of the dest block, the dest
553  /// block must have no successors. Similarly, if the source block is inserted
554  /// somewhere in the middle (or beginning) of the dest block, the source block
555  /// must have no successors. Otherwise, the resulting IR would have
556  /// unreachable operations.
557  virtual void inlineBlockBefore(Block *source, Block *dest,
558  Block::iterator before,
559  ValueRange argValues = std::nullopt);
560 
561  /// Inline the operations of block 'source' before the operation 'op'. The
562  /// source block will be deleted and must have no uses. 'argValues' is used to
563  /// replace the block arguments of 'source'
564  ///
565  /// The source block must have no successors. Otherwise, the resulting IR
566  /// would have unreachable operations.
567  void inlineBlockBefore(Block *source, Operation *op,
568  ValueRange argValues = std::nullopt);
569 
570  /// Inline the operations of block 'source' into the end of block 'dest'. The
571  /// source block will be deleted and must have no uses. 'argValues' is used to
572  /// replace the block arguments of 'source'
573  ///
574  /// The dest block must have no successors. Otherwise, the resulting IR would
575  /// have unreachable operation.
576  void mergeBlocks(Block *source, Block *dest,
577  ValueRange argValues = std::nullopt);
578 
579  /// Split the operations starting at "before" (inclusive) out of the given
580  /// block into a new block, and return it.
581  Block *splitBlock(Block *block, Block::iterator before);
582 
583  /// Unlink this operation from its current block and insert it right before
584  /// `existingOp` which may be in the same or another block in the same
585  /// function.
586  void moveOpBefore(Operation *op, Operation *existingOp);
587 
588  /// Unlink this operation from its current block and insert it right before
589  /// `iterator` in the specified block.
590  void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
591 
592  /// Unlink this operation from its current block and insert it right after
593  /// `existingOp` which may be in the same or another block in the same
594  /// function.
595  void moveOpAfter(Operation *op, Operation *existingOp);
596 
597  /// Unlink this operation from its current block and insert it right after
598  /// `iterator` in the specified block.
599  void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
600 
601  /// Unlink this block and insert it right before `existingBlock`.
602  void moveBlockBefore(Block *block, Block *anotherBlock);
603 
604  /// Unlink this block and insert it right before the location that the given
605  /// iterator points to in the given region.
606  void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
607 
608  /// This method is used to notify the rewriter that an in-place operation
609  /// modification is about to happen. A call to this function *must* be
610  /// followed by a call to either `finalizeOpModification` or
611  /// `cancelOpModification`. This is a minor efficiency win (it avoids creating
612  /// a new operation and removing the old one) but also often allows simpler
613  /// code in the client.
614  virtual void startOpModification(Operation *op) {}
615 
616  /// This method is used to signal the end of an in-place modification of the
617  /// given operation. This can only be called on operations that were provided
618  /// to a call to `startOpModification`.
619  virtual void finalizeOpModification(Operation *op);
620 
621  /// This method cancels a pending in-place modification. This can only be
622  /// called on operations that were provided to a call to
623  /// `startOpModification`.
624  virtual void cancelOpModification(Operation *op) {}
625 
626  /// This method is a utility wrapper around an in-place modification of an
627  /// operation. It wraps calls to `startOpModification` and
628  /// `finalizeOpModification` around the given callable.
629  template <typename CallableT>
630  void modifyOpInPlace(Operation *root, CallableT &&callable) {
631  startOpModification(root);
632  callable();
634  }
635 
636  /// Find uses of `from` and replace them with `to`. Also notify the listener
637  /// about every in-place op modification (for every use that was replaced).
638  void replaceAllUsesWith(Value from, Value to) {
639  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
640  Operation *op = operand.getOwner();
641  modifyOpInPlace(op, [&]() { operand.set(to); });
642  }
643  }
644  void replaceAllUsesWith(Block *from, Block *to) {
645  for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) {
646  Operation *op = operand.getOwner();
647  modifyOpInPlace(op, [&]() { operand.set(to); });
648  }
649  }
651  assert(from.size() == to.size() && "incorrect number of replacements");
652  for (auto it : llvm::zip(from, to))
653  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
654  }
655 
656  /// Find uses of `from` and replace them with `to`. Also notify the listener
657  /// about every in-place op modification (for every use that was replaced)
658  /// and that the `from` operation is about to be replaced.
659  ///
660  /// Note: This function cannot be called `replaceAllUsesWith` because the
661  /// overload resolution, when called with an op that can be implicitly
662  /// converted to a Value, would be ambiguous.
663  void replaceAllOpUsesWith(Operation *from, ValueRange to);
664  void replaceAllOpUsesWith(Operation *from, Operation *to);
665 
666  /// Find uses of `from` and replace them with `to` if the `functor` returns
667  /// true. Also notify the listener about every in-place op modification (for
668  /// every use that was replaced). The optional `allUsesReplaced` flag is set
669  /// to "true" if all uses were replaced.
670  void replaceUsesWithIf(Value from, Value to,
671  function_ref<bool(OpOperand &)> functor,
672  bool *allUsesReplaced = nullptr);
674  function_ref<bool(OpOperand &)> functor,
675  bool *allUsesReplaced = nullptr);
676  // Note: This function cannot be called `replaceOpUsesWithIf` because the
677  // overload resolution, when called with an op that can be implicitly
678  // converted to a Value, would be ambiguous.
680  function_ref<bool(OpOperand &)> functor,
681  bool *allUsesReplaced = nullptr) {
682  replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
683  }
684 
685  /// Find uses of `from` within `block` and replace them with `to`. Also notify
686  /// the listener about every in-place op modification (for every use that was
687  /// replaced). The optional `allUsesReplaced` flag is set to "true" if all
688  /// uses were replaced.
690  Block *block, bool *allUsesReplaced = nullptr) {
692  op, newValues,
693  [block](OpOperand &use) {
694  return block->getParentOp()->isProperAncestor(use.getOwner());
695  },
696  allUsesReplaced);
697  }
698 
699  /// Find uses of `from` and replace them with `to` except if the user is
700  /// `exceptedUser`. Also notify the listener about every in-place op
701  /// modification (for every use that was replaced).
702  void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
703  return replaceUsesWithIf(from, to, [&](OpOperand &use) {
704  Operation *user = use.getOwner();
705  return user != exceptedUser;
706  });
707  }
708  void replaceAllUsesExcept(Value from, Value to,
709  const SmallPtrSetImpl<Operation *> &preservedUsers);
710 
711  /// Used to notify the listener that the IR failed to be rewritten because of
712  /// a match failure, and provide a callback to populate a diagnostic with the
713  /// reason why the failure occurred. This method allows for derived rewriters
714  /// to optionally hook into the reason why a rewrite failed, and display it to
715  /// users.
716  template <typename CallbackT>
717  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
718  notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
719  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
720  rewriteListener->notifyMatchFailure(
721  loc, function_ref<void(Diagnostic &)>(reasonCallback));
722  return failure();
723  }
724  template <typename CallbackT>
725  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
726  notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
727  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
728  rewriteListener->notifyMatchFailure(
729  op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
730  return failure();
731  }
732  template <typename ArgT>
733  LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
734  return notifyMatchFailure(std::forward<ArgT>(arg),
735  [&](Diagnostic &diag) { diag << msg; });
736  }
737  template <typename ArgT>
738  LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
739  return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
740  }
741 
742 protected:
743  /// Initialize the builder.
744  explicit RewriterBase(MLIRContext *ctx,
745  OpBuilder::Listener *listener = nullptr)
746  : OpBuilder(ctx, listener) {}
747  explicit RewriterBase(const OpBuilder &otherBuilder)
748  : OpBuilder(otherBuilder) {}
750  : OpBuilder(op, listener) {}
751  virtual ~RewriterBase();
752 
753 private:
754  void operator=(const RewriterBase &) = delete;
755  RewriterBase(const RewriterBase &) = delete;
756 };
757 
758 //===----------------------------------------------------------------------===//
759 // IRRewriter
760 //===----------------------------------------------------------------------===//
761 
762 /// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
763 /// providing a way to keep track of the mutations made to the IR. This class
764 /// should only be used in situations where another `RewriterBase` instance,
765 /// such as a `PatternRewriter`, is not available.
766 class IRRewriter : public RewriterBase {
767 public:
769  : RewriterBase(ctx, listener) {}
770  explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
772  : RewriterBase(op, listener) {}
773 };
774 
775 //===----------------------------------------------------------------------===//
776 // PatternRewriter
777 //===----------------------------------------------------------------------===//
778 
779 /// A special type of `RewriterBase` that coordinates the application of a
780 /// rewrite pattern on the current IR being matched, providing a way to keep
781 /// track of any mutations made. This class should be used to perform all
782 /// necessary IR mutations within a rewrite pattern, as the pattern driver may
783 /// be tracking various state that would be invalidated when a mutation takes
784 /// place.
786 public:
787  explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
789 
790  /// A hook used to indicate if the pattern rewriter can recover from failure
791  /// during the rewrite stage of a pattern. For example, if the pattern
792  /// rewriter supports rollback, it may progress smoothly even if IR was
793  /// changed during the rewrite.
794  virtual bool canRecoverFromRewriteFailure() const { return false; }
795 };
796 
797 } // namespace mlir
798 
799 // Optionally expose PDL pattern matching methods.
800 #include "PDLPatternMatch.h.inc"
801 
802 namespace mlir {
803 
804 //===----------------------------------------------------------------------===//
805 // RewritePatternSet
806 //===----------------------------------------------------------------------===//
807 
809  using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
810 
811 public:
812  RewritePatternSet(MLIRContext *context) : context(context) {}
813 
814  /// Construct a RewritePatternSet populated with the given pattern.
816  std::unique_ptr<RewritePattern> pattern)
817  : context(context) {
818  nativePatterns.emplace_back(std::move(pattern));
819  }
820  RewritePatternSet(PDLPatternModule &&pattern)
821  : context(pattern.getContext()), pdlPatterns(std::move(pattern)) {}
822 
823  MLIRContext *getContext() const { return context; }
824 
825  /// Return the native patterns held in this list.
826  NativePatternListT &getNativePatterns() { return nativePatterns; }
827 
828  /// Return the PDL patterns held in this list.
829  PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
830 
831  /// Clear out all of the held patterns in this list.
832  void clear() {
833  nativePatterns.clear();
834  pdlPatterns.clear();
835  }
836 
837  //===--------------------------------------------------------------------===//
838  // 'add' methods for adding patterns to the set.
839  //===--------------------------------------------------------------------===//
840 
841  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
842  /// the given arguments. Return a reference to `this` for chaining insertions.
843  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
844  template <typename... Ts, typename ConstructorArg,
845  typename... ConstructorArgs,
846  typename = std::enable_if_t<sizeof...(Ts) != 0>>
847  RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) {
848  // The following expands a call to emplace_back for each of the pattern
849  // types 'Ts'.
850  (addImpl<Ts>(/*debugLabels=*/std::nullopt,
851  std::forward<ConstructorArg>(arg),
852  std::forward<ConstructorArgs>(args)...),
853  ...);
854  return *this;
855  }
856  /// An overload of the above `add` method that allows for attaching a set
857  /// of debug labels to the attached patterns. This is useful for labeling
858  /// groups of patterns that may be shared between multiple different
859  /// passes/users.
860  template <typename... Ts, typename ConstructorArg,
861  typename... ConstructorArgs,
862  typename = std::enable_if_t<sizeof...(Ts) != 0>>
864  ConstructorArg &&arg,
865  ConstructorArgs &&...args) {
866  // The following expands a call to emplace_back for each of the pattern
867  // types 'Ts'.
868  (addImpl<Ts>(debugLabels, arg, args...), ...);
869  return *this;
870  }
871 
872  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
873  /// `this` for chaining insertions.
874  template <typename... Ts>
876  (addImpl<Ts>(), ...);
877  return *this;
878  }
879 
880  /// Add the given native pattern to the pattern list. Return a reference to
881  /// `this` for chaining insertions.
882  RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
883  nativePatterns.emplace_back(std::move(pattern));
884  return *this;
885  }
886 
887  /// Add the given PDL pattern to the pattern list. Return a reference to
888  /// `this` for chaining insertions.
889  RewritePatternSet &add(PDLPatternModule &&pattern) {
890  pdlPatterns.mergeIn(std::move(pattern));
891  return *this;
892  }
893 
894  // Add a matchAndRewrite style pattern represented as a C function pointer.
895  template <typename OpType>
897  add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
898  PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
899  struct FnPattern final : public OpRewritePattern<OpType> {
900  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
901  MLIRContext *context, PatternBenefit benefit,
902  ArrayRef<StringRef> generatedNames)
903  : OpRewritePattern<OpType>(context, benefit, generatedNames),
904  implFn(implFn) {}
905 
906  LogicalResult matchAndRewrite(OpType op,
907  PatternRewriter &rewriter) const override {
908  return implFn(op, rewriter);
909  }
910 
911  private:
912  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
913  };
914  add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit,
915  generatedNames));
916  return *this;
917  }
918 
919  //===--------------------------------------------------------------------===//
920  // Pattern Insertion
921  //===--------------------------------------------------------------------===//
922 
923  // TODO: These are soft deprecated in favor of the 'add' methods above.
924 
925  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
926  /// the given arguments. Return a reference to `this` for chaining insertions.
927  /// Note: ConstructorArg is necessary here to separate the two variadic lists.
928  template <typename... Ts, typename ConstructorArg,
929  typename... ConstructorArgs,
930  typename = std::enable_if_t<sizeof...(Ts) != 0>>
931  RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) {
932  // The following expands a call to emplace_back for each of the pattern
933  // types 'Ts'.
934  (addImpl<Ts>(/*debugLabels=*/std::nullopt, arg, args...), ...);
935  return *this;
936  }
937 
938  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
939  /// `this` for chaining insertions.
940  template <typename... Ts>
942  (addImpl<Ts>(), ...);
943  return *this;
944  }
945 
946  /// Add the given native pattern to the pattern list. Return a reference to
947  /// `this` for chaining insertions.
948  RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
949  nativePatterns.emplace_back(std::move(pattern));
950  return *this;
951  }
952 
953  /// Add the given PDL pattern to the pattern list. Return a reference to
954  /// `this` for chaining insertions.
955  RewritePatternSet &insert(PDLPatternModule &&pattern) {
956  pdlPatterns.mergeIn(std::move(pattern));
957  return *this;
958  }
959 
960  // Add a matchAndRewrite style pattern represented as a C function pointer.
961  template <typename OpType>
963  insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
964  struct FnPattern final : public OpRewritePattern<OpType> {
965  FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
966  MLIRContext *context)
967  : OpRewritePattern<OpType>(context), implFn(implFn) {
968  this->setDebugName(llvm::getTypeName<FnPattern>());
969  }
970 
971  LogicalResult matchAndRewrite(OpType op,
972  PatternRewriter &rewriter) const override {
973  return implFn(op, rewriter);
974  }
975 
976  private:
977  LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
978  };
979  add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
980  return *this;
981  }
982 
983 private:
984  /// Add an instance of the pattern type 'T'. Return a reference to `this` for
985  /// chaining insertions.
986  template <typename T, typename... Args>
987  std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
988  addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
989  std::unique_ptr<T> pattern =
990  RewritePattern::create<T>(std::forward<Args>(args)...);
991  pattern->addDebugLabels(debugLabels);
992  nativePatterns.emplace_back(std::move(pattern));
993  }
994 
995  template <typename T, typename... Args>
996  std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
997  addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
998  // TODO: Add the provided labels to the PDL pattern when PDL supports
999  // labels.
1000  pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
1001  }
1002 
1003  MLIRContext *const context;
1004  NativePatternListT nativePatterns;
1005 
1006  // Patterns expressed with PDL. This will compile to a stub class when PDL is
1007  // not enabled.
1008  PDLPatternModule pdlPatterns;
1009 };
1010 
1011 } // namespace mlir
1012 
1013 #endif // MLIR_IR_PATTERNMATCH_H
static std::string diag(const llvm::Value &value)
A block operand represents an operand that holds a reference to a Block, e.g.
Definition: BlockSupport.h:30
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: UseDefLists.h:253
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Definition: PatternMatch.h:768
IRRewriter(const OpBuilder &builder)
Definition: PatternMatch.h:770
IRRewriter(Operation *op, OpBuilder::Listener *listener=nullptr)
Definition: PatternMatch.h:771
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a saved insertion point.
Definition: Builders.h:332
This class helps build Operations.
Definition: Builders.h:212
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:612
This class represents an operand of an operation.
Definition: Value.h:267
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
Definition: PatternMatch.h:384
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:386
static OperationName getFromOpaquePointer(const void *pointer)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_range getResults()
Definition: Operation.h:410
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
PatternBenefit & operator=(const PatternBenefit &)=default
bool operator<(const PatternBenefit &rhs) const
Definition: PatternMatch.h:54
bool operator==(const PatternBenefit &rhs) const
Definition: PatternMatch.h:50
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
bool operator>=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:59
bool operator<=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:58
PatternBenefit(const PatternBenefit &)=default
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
bool operator!=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:53
PatternBenefit()=default
bool operator>(const PatternBenefit &rhs) const
Definition: PatternMatch.h:57
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
PatternRewriter(MLIRContext *ctx)
Definition: PatternMatch.h:787
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
Definition: PatternMatch.h:794
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
std::optional< TypeID > getRootInterfaceID() const
Return the interface ID used to match the root operation of this pattern.
Definition: PatternMatch.h:103
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:134
ArrayRef< StringRef > getDebugLabels() const
Return the set of debug labels attached to this pattern.
Definition: PatternMatch.h:147
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
Definition: PatternMatch.h:202
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
Definition: PatternMatch.h:123
std::optional< TypeID > getRootTraitID() const
Return the trait ID used to match the root operation of this pattern.
Definition: PatternMatch.h:112
void setDebugName(StringRef name)
Set the human readable debug name used for this pattern.
Definition: PatternMatch.h:144
void addDebugLabels(StringRef label)
Definition: PatternMatch.h:153
void addDebugLabels(ArrayRef< StringRef > labels)
Add the provided debug labels to this pattern.
Definition: PatternMatch.h:150
StringRef getDebugName() const
Return a readable name for this pattern.
Definition: PatternMatch.h:140
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType::iterator iterator
Definition: Region.h:52
RewritePatternSet & add(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
Definition: PatternMatch.h:889
RewritePatternSet & add(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter), PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Definition: PatternMatch.h:897
NativePatternListT & getNativePatterns()
Return the native patterns held in this list.
Definition: PatternMatch.h:826
RewritePatternSet & add(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
Definition: PatternMatch.h:882
RewritePatternSet(PDLPatternModule &&pattern)
Definition: PatternMatch.h:820
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:931
MLIRContext * getContext() const
Definition: PatternMatch.h:823
void clear()
Clear out all of the held patterns in this list.
Definition: PatternMatch.h:832
RewritePatternSet(MLIRContext *context)
Definition: PatternMatch.h:812
RewritePatternSet & add()
Add an instance of each of the pattern types 'Ts'.
Definition: PatternMatch.h:875
RewritePatternSet(MLIRContext *context, std::unique_ptr< RewritePattern > pattern)
Construct a RewritePatternSet populated with the given pattern.
Definition: PatternMatch.h:815
RewritePatternSet & insert(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
Definition: PatternMatch.h:955
RewritePatternSet & insert(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
Definition: PatternMatch.h:948
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
RewritePatternSet & insert()
Add an instance of each of the pattern types 'Ts'.
Definition: PatternMatch.h:941
RewritePatternSet & addWithLabel(ArrayRef< StringRef > debugLabels, ConstructorArg &&arg, ConstructorArgs &&...args)
An overload of the above add method that allows for attaching a set of debug labels to the attached p...
Definition: PatternMatch.h:863
RewritePatternSet & insert(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
Definition: PatternMatch.h:963
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
Definition: PatternMatch.h:276
virtual ~RewritePattern()=default
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:264
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
Definition: PatternMatch.h:726
void replaceOpUsesWithIf(Operation *from, ValueRange to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Definition: PatternMatch.h:679
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg)
Definition: PatternMatch.h:733
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
RewriterBase(const OpBuilder &otherBuilder)
Definition: PatternMatch.h:747
void replaceAllUsesWith(ValueRange from, ValueRange to)
Definition: PatternMatch.h:650
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink the given block and insert it right before anotherBlock.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
RewriterBase(Operation *op, OpBuilder::Listener *listener=nullptr)
Definition: PatternMatch.h:749
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
Definition: PatternMatch.h:624
void replaceAllUsesWith(Block *from, Block *to)
Definition: PatternMatch.h:644
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:702
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
Definition: PatternMatch.h:744
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg)
Definition: PatternMatch.h:738
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of the results of op within block and replace them with newValues.
Definition: PatternMatch.h:689
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
static TypeID getFromOpaquePointer(const void *pointer)
Definition: TypeID.h:132
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Base class for listeners.
Definition: Builders.h:269
Kind
The kind of listener.
Definition: Builders.h:271
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:290
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:313
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:303
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:374
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:159
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:169
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:463
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
Definition: PatternMatch.h:503
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Definition: PatternMatch.h:466
void notifyPatternEnd(const Pattern &pattern, LogicalResult status) override
Notify the listener that a pattern application finished with the specified status.
Definition: PatternMatch.h:498
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:477
void notifyPatternBegin(const Pattern &pattern, Operation *op) override
Notify the listener that the specified pattern is about to be applied at the specified root operation...
Definition: PatternMatch.h:494
void notifyOperationReplaced(Operation *op, Operation *newOp) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:481
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:490
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:473
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:464
void notifyOperationReplaced(Operation *op, ValueRange replacement) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:485
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notify the listener that the specified block was inserted.
Definition: PatternMatch.h:469
virtual void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
Definition: PatternMatch.h:454
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:411
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:435
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:419
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:408
virtual void notifyPatternEnd(const Pattern &pattern, LogicalResult status)
Notify the listener that a pattern application finished with the specified status.
Definition: PatternMatch.h:446
virtual void notifyPatternBegin(const Pattern &pattern, Operation *op)
Notify the listener that the specified pattern is about to be applied at the specified root operation...
Definition: PatternMatch.h:439
virtual void notifyOperationReplaced(Operation *op, ValueRange replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:428
static bool classof(const OpBuilder::Listener *base)
OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that allows for matching and rewri...
Definition: PatternMatch.h:319
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
Definition: PatternMatch.h:336
void rewrite(Operation *op, PatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: PatternMatch.h:323
virtual LogicalResult match(SourceOp op) const
Definition: PatternMatch.h:339
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:326
virtual LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const
Definition: PatternMatch.h:342
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:329