MLIR  21.0.0git
PatternMatch.cpp
Go to the documentation of this file.
1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/Iterators.h"
14 #include "llvm/ADT/SmallPtrSet.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // PatternBenefit
20 //===----------------------------------------------------------------------===//
21 
22 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
23  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
24  "This pattern match benefit is too large to represent");
25 }
26 
27 unsigned short PatternBenefit::getBenefit() const {
28  assert(!isImpossibleToMatch() && "Pattern doesn't match");
29  return representation;
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // Pattern
34 //===----------------------------------------------------------------------===//
35 
36 //===----------------------------------------------------------------------===//
37 // OperationName Root Constructors
38 
39 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
40  MLIRContext *context, ArrayRef<StringRef> generatedNames)
41  : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
42  RootKind::OperationName, generatedNames, benefit, context) {}
43 
44 //===----------------------------------------------------------------------===//
45 // MatchAnyOpTypeTag Root Constructors
46 
48  MLIRContext *context, ArrayRef<StringRef> generatedNames)
49  : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
50 
51 //===----------------------------------------------------------------------===//
52 // MatchInterfaceOpTypeTag Root Constructors
53 
55  PatternBenefit benefit, MLIRContext *context,
56  ArrayRef<StringRef> generatedNames)
57  : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
58  generatedNames, benefit, context) {}
59 
60 //===----------------------------------------------------------------------===//
61 // MatchTraitOpTypeTag Root Constructors
62 
64  PatternBenefit benefit, MLIRContext *context,
65  ArrayRef<StringRef> generatedNames)
66  : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
67  benefit, context) {}
68 
69 //===----------------------------------------------------------------------===//
70 // General Constructors
71 
72 Pattern::Pattern(const void *rootValue, RootKind rootKind,
73  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
74  MLIRContext *context)
75  : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
76  contextAndHasBoundedRecursion(context, false) {
77  if (generatedNames.empty())
78  return;
79  generatedOps.reserve(generatedNames.size());
80  std::transform(generatedNames.begin(), generatedNames.end(),
81  std::back_inserter(generatedOps), [context](StringRef name) {
82  return OperationName(name, context);
83  });
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // RewritePattern
88 //===----------------------------------------------------------------------===//
89 
90 /// Out-of-line vtable anchor.
91 void RewritePattern::anchor() {}
92 
93 //===----------------------------------------------------------------------===//
94 // RewriterBase
95 //===----------------------------------------------------------------------===//
96 
99 }
100 
102  // Out of line to provide a vtable anchor for the class.
103 }
104 
106  // Notify the listener that we're about to replace this op.
107  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
108  rewriteListener->notifyOperationReplaced(from, to);
109 
110  replaceAllUsesWith(from->getResults(), to);
111 }
112 
114  // Notify the listener that we're about to replace this op.
115  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
116  rewriteListener->notifyOperationReplaced(from, to);
117 
118  replaceAllUsesWith(from->getResults(), to->getResults());
119 }
120 
121 /// This method replaces the results of the operation with the specified list of
122 /// values. The number of provided values must match the number of results of
123 /// the operation. The replaced op is erased.
125  assert(op->getNumResults() == newValues.size() &&
126  "incorrect # of replacement values");
127 
128  // Replace all result uses. Also notifies the listener of modifications.
129  replaceAllOpUsesWith(op, newValues);
130 
131  // Erase op and notify listener.
132  eraseOp(op);
133 }
134 
135 /// This method replaces the results of the operation with the specified new op
136 /// (replacement). The number of results of the two operations must match. The
137 /// replaced op is erased.
139  assert(op && newOp && "expected non-null op");
140  assert(op->getNumResults() == newOp->getNumResults() &&
141  "ops have different number of results");
142 
143  // Replace all result uses. Also notifies the listener of modifications.
144  replaceAllOpUsesWith(op, newOp->getResults());
145 
146  // Erase op and notify listener.
147  eraseOp(op);
148 }
149 
150 /// This method erases an operation that is known to have no uses. The uses of
151 /// the given operation *must* be known to be dead.
153  assert(op->use_empty() && "expected 'op' to have no uses");
154  auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
155 
156  // Fast path: If no listener is attached, the op can be dropped in one go.
157  if (!rewriteListener) {
158  op->erase();
159  return;
160  }
161 
162  // Helper function that erases a single op.
163  auto eraseSingleOp = [&](Operation *op) {
164 #ifndef NDEBUG
165  // All nested ops should have been erased already.
166  assert(
167  llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
168  "expected empty regions");
169  // All users should have been erased already if the op is in a region with
170  // SSA dominance.
171  if (!op->use_empty() && op->getParentOp())
172  assert(mayBeGraphRegion(*op->getParentRegion()) &&
173  "expected that op has no uses");
174 #endif // NDEBUG
175  rewriteListener->notifyOperationErased(op);
176 
177  // Explicitly drop all uses in case the op is in a graph region.
178  op->dropAllUses();
179  op->erase();
180  };
181 
182  // Nested ops must be erased one-by-one, so that listeners have a consistent
183  // view of the IR every time a notification is triggered. Users must be
184  // erased before definitions. I.e., post-order, reverse dominance.
185  std::function<void(Operation *)> eraseTree = [&](Operation *op) {
186  // Erase nested ops.
187  for (Region &r : llvm::reverse(op->getRegions())) {
188  // Erase all blocks in the right order. Successors should be erased
189  // before predecessors because successor blocks may use values defined
190  // in predecessor blocks. A post-order traversal of blocks within a
191  // region visits successors before predecessors. Repeat the traversal
192  // until the region is empty. (The block graph could be disconnected.)
193  while (!r.empty()) {
194  SmallVector<Block *> erasedBlocks;
195  // Some blocks may have an invalid (null) successor; seed the visited set
196  // with nullptr so the traversal never follows a null pointer.
198  for (Block *b : llvm::post_order_ext(&r.front(), visited)) {
199  // Visit ops in reverse order.
200  for (Operation &op :
201  llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
202  eraseTree(&op);
203  // Do not erase the block immediately. This is not supported by the
204  // post_order iterator.
205  erasedBlocks.push_back(b);
206  }
207  for (Block *b : erasedBlocks) {
208  // Explicitly drop all uses in case there is a cycle in the block
209  // graph.
210  for (BlockArgument bbArg : b->getArguments())
211  bbArg.dropAllUses();
212  b->dropAllUses();
213  eraseBlock(b);
214  }
215  }
216  }
217  // Then erase the enclosing op.
218  eraseSingleOp(op);
219  };
220 
221  eraseTree(op);
222 }
223 
225  assert(block->use_empty() && "expected 'block' to have no uses");
226 
227  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
228  assert(op.use_empty() && "expected 'op' to have no uses");
229  eraseOp(&op);
230  }
231 
232  // Notify the listener that the block is about to be removed.
233  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
234  rewriteListener->notifyBlockErased(block);
235 
236  block->erase();
237 }
238 
240  // Notify the listener that the operation was modified.
241  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
242  rewriteListener->notifyOperationModified(op);
243 }
244 
246  Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
247  return replaceUsesWithIf(from, to, [&](OpOperand &use) {
248  Operation *user = use.getOwner();
249  return !preservedUsers.contains(user);
250  });
251 }
252 
254  function_ref<bool(OpOperand &)> functor,
255  bool *allUsesReplaced) {
256  bool allReplaced = true;
257  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
258  bool replace = functor(operand);
259  if (replace)
260  modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
261  allReplaced &= replace;
262  }
263  if (allUsesReplaced)
264  *allUsesReplaced = allReplaced;
265 }
266 
268  function_ref<bool(OpOperand &)> functor,
269  bool *allUsesReplaced) {
270  assert(from.size() == to.size() && "incorrect number of replacements");
271  bool allReplaced = true;
272  for (auto it : llvm::zip_equal(from, to)) {
273  bool r;
274  replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
275  /*allUsesReplaced=*/&r);
276  allReplaced &= r;
277  }
278  if (allUsesReplaced)
279  *allUsesReplaced = allReplaced;
280 }
281 
283  Block::iterator before,
284  ValueRange argValues) {
285  assert(argValues.size() == source->getNumArguments() &&
286  "incorrect # of argument replacement values");
287 
288  // The source block will be deleted, so it should not have any users (i.e.,
289  // there should be no predecessors).
290  assert(source->hasNoPredecessors() &&
291  "expected 'source' to have no predecessors");
292 
293  if (dest->end() != before) {
294  // The source block will be inserted in the middle of the dest block, so
295  // the source block should have no successors. Otherwise, the remainder of
296  // the dest block would be unreachable.
297  assert(source->hasNoSuccessors() &&
298  "expected 'source' to have no successors");
299  } else {
300  // The source block will be inserted at the end of the dest block, so the
301  // dest block should have no successors. Otherwise, the inserted operations
302  // will be unreachable.
303  assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
304  }
305 
306  // Replace all of the successor arguments with the provided values.
307  for (auto it : llvm::zip(source->getArguments(), argValues))
308  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
309 
310  // Move operations from the source block to the dest block and erase the
311  // source block.
312  if (!listener) {
313  // Fast path: If no listener is attached, move all operations at once.
314  dest->getOperations().splice(before, source->getOperations());
315  } else {
316  while (!source->empty())
317  moveOpBefore(&source->front(), dest, before);
318  }
319 
320  // Erase the source block.
321  assert(source->empty() && "expected 'source' to be empty");
322  eraseBlock(source);
323 }
324 
326  ValueRange argValues) {
327  inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
328 }
329 
331  ValueRange argValues) {
332  inlineBlockBefore(source, dest, dest->end(), argValues);
333 }
334 
335 /// Split the operations starting at "before" (inclusive) out of the given
336 /// block into a new block, and return it.
338  // Fast path: If no listener is attached, split the block directly.
339  if (!listener)
340  return block->splitBlock(before);
341 
342  // `createBlock` sets the insertion point at the beginning of the new block.
343  InsertionGuard g(*this);
344  Block *newBlock =
345  createBlock(block->getParent(), std::next(block->getIterator()));
346 
347  // If `before` points to end of the block, no ops should be moved.
348  if (before == block->end())
349  return newBlock;
350 
351  // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
352  // Stop when the operation pointed to by `before` has been moved.
353  while (before->getBlock() != newBlock)
354  moveOpBefore(&block->back(), newBlock, newBlock->begin());
355 
356  return newBlock;
357 }
358 
359 /// Move the blocks that belong to "region" before the given position in
360 /// another region. The two regions must be different. The caller is in
361 /// charge of creating the operation that transfers control flow to the
362 /// region and of passing it the correct block arguments.
364  Region::iterator before) {
365  // Fast path: If no listener is attached, move all blocks at once.
366  if (!listener) {
367  parent.getBlocks().splice(before, region.getBlocks());
368  return;
369  }
370 
371  // Move blocks from the beginning of the region one-by-one.
372  while (!region.empty())
373  moveBlockBefore(&region.front(), &parent, before);
374 }
376  inlineRegionBefore(region, *before->getParent(), before->getIterator());
377 }
378 
379 void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
380  moveBlockBefore(block, anotherBlock->getParent(),
381  anotherBlock->getIterator());
382 }
383 
385  Region::iterator iterator) {
386  Region *currentRegion = block->getParent();
387  Region::iterator nextIterator = std::next(block->getIterator());
388  block->moveBefore(region, iterator);
389  if (listener)
390  listener->notifyBlockInserted(block, /*previous=*/currentRegion,
391  /*previousIt=*/nextIterator);
392 }
393 
395  moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
396 }
397 
399  Block::iterator iterator) {
400  Block *currentBlock = op->getBlock();
401  Block::iterator nextIterator = std::next(op->getIterator());
402  op->moveBefore(block, iterator);
403  if (listener)
405  op, /*previous=*/InsertPoint(currentBlock, nextIterator));
406 }
407 
409  moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
410 }
411 
413  Block::iterator iterator) {
414  assert(iterator != block->end() && "cannot move after end of block");
415  moveOpBefore(op, block, std::next(iterator));
416 }
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
bool empty()
Definition: Block.h:148
bool hasNoSuccessors()
Returns true if this blocks has no successors.
Definition: Block.h:245
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:68
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:310
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
Definition: Block.cpp:56
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:242
bool use_empty() const
Returns true if this value has no uses.
Definition: UseDefLists.h:261
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a saved insertion point.
Definition: Builders.h:325
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:605
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:853
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:835
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:555
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
PatternBenefit()=default
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern is impossible to match, this asserts.
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
virtual void eraseBlock(Block *block)
This method erases all operations in a block (last to first), then erases the block itself.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The number of values must match the number of results; the original operation is erased.
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink this block and insert it right before existingBlock.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:656
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:720
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:648
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
Include the generated interface declarations.
bool mayBeGraphRegion(Region &region)
Return "true" if the given region may be a graph region without SSA dominance.
@ RewriterBaseListener
RewriterBase::Listener or user-derived class.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:306
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:296
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:159
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:169
static constexpr auto makeIterable(RangeT &&range)
Definition: Iterators.h:32
static bool classof(const OpBuilder::Listener *base)