MLIR  22.0.0git
PatternMatch.cpp
Go to the documentation of this file.
1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/IR/Iterators.h"
12 #include "llvm/ADT/SmallPtrSet.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // PatternBenefit
18 //===----------------------------------------------------------------------===//
19 
20 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
21  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
22  "This pattern match benefit is too large to represent");
23 }
24 
25 unsigned short PatternBenefit::getBenefit() const {
26  assert(!isImpossibleToMatch() && "Pattern doesn't match");
27  return representation;
28 }
29 
30 //===----------------------------------------------------------------------===//
31 // Pattern
32 //===----------------------------------------------------------------------===//
33 
34 //===----------------------------------------------------------------------===//
35 // OperationName Root Constructors
36 //===----------------------------------------------------------------------===//
37 
38 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
39  MLIRContext *context, ArrayRef<StringRef> generatedNames)
40  : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
41  RootKind::OperationName, generatedNames, benefit, context) {}
42 
43 //===----------------------------------------------------------------------===//
44 // MatchAnyOpTypeTag Root Constructors
45 //===----------------------------------------------------------------------===//
46 
48  MLIRContext *context, ArrayRef<StringRef> generatedNames)
49  : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
50 
51 //===----------------------------------------------------------------------===//
52 // MatchInterfaceOpTypeTag Root Constructors
53 //===----------------------------------------------------------------------===//
54 
56  PatternBenefit benefit, MLIRContext *context,
57  ArrayRef<StringRef> generatedNames)
58  : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
59  generatedNames, benefit, context) {}
60 
61 //===----------------------------------------------------------------------===//
62 // MatchTraitOpTypeTag Root Constructors
63 //===----------------------------------------------------------------------===//
64 
66  PatternBenefit benefit, MLIRContext *context,
67  ArrayRef<StringRef> generatedNames)
68  : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
69  benefit, context) {}
70 
71 //===----------------------------------------------------------------------===//
72 // General Constructors
73 //===----------------------------------------------------------------------===//
74 
75 Pattern::Pattern(const void *rootValue, RootKind rootKind,
76  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
77  MLIRContext *context)
78  : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
79  contextAndHasBoundedRecursion(context, false) {
80  if (generatedNames.empty())
81  return;
82  generatedOps.reserve(generatedNames.size());
83  std::transform(generatedNames.begin(), generatedNames.end(),
84  std::back_inserter(generatedOps), [context](StringRef name) {
85  return OperationName(name, context);
86  });
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // RewritePattern
91 //===----------------------------------------------------------------------===//
92 
93 /// Out-of-line vtable anchor.
94 void RewritePattern::anchor() {}
95 
96 //===----------------------------------------------------------------------===//
97 // RewriterBase
98 //===----------------------------------------------------------------------===//
99 
102 }
103 
105  // Out of line to provide a vtable anchor for the class.
106 }
107 
109  // Notify the listener that we're about to replace this op.
110  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
111  rewriteListener->notifyOperationReplaced(from, to);
112 
113  replaceAllUsesWith(from->getResults(), to);
114 }
115 
117  // Notify the listener that we're about to replace this op.
118  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
119  rewriteListener->notifyOperationReplaced(from, to);
120 
121  replaceAllUsesWith(from->getResults(), to->getResults());
122 }
123 
124 /// This method replaces the results of the operation with the specified list of
125 /// values. The number of provided values must match the number of results of
126 /// the operation. The replaced op is erased.
128  assert(op->getNumResults() == newValues.size() &&
129  "incorrect # of replacement values");
130 
131  // Replace all result uses. Also notifies the listener of modifications.
132  replaceAllOpUsesWith(op, newValues);
133 
134  // Erase op and notify listener.
135  eraseOp(op);
136 }
137 
138 /// This method replaces the results of the operation with the specified new op
139 /// (replacement). The number of results of the two operations must match. The
140 /// replaced op is erased.
142  assert(op && newOp && "expected non-null op");
143  assert(op->getNumResults() == newOp->getNumResults() &&
144  "ops have different number of results");
145 
146  // Replace all result uses. Also notifies the listener of modifications.
147  replaceAllOpUsesWith(op, newOp->getResults());
148 
149  // Erase op and notify listener.
150  eraseOp(op);
151 }
152 
153 /// This method erases an operation that is known to have no uses. The uses of
154 /// the given operation *must* be known to be dead.
156  assert(op->use_empty() && "expected 'op' to have no uses");
157  auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
158 
159  // Fast path: If no listener is attached, the op can be dropped in one go.
160  if (!rewriteListener) {
161  op->erase();
162  return;
163  }
164 
165  // Helper function that erases a single op.
166  auto eraseSingleOp = [&](Operation *op) {
167 #ifndef NDEBUG
168  // All nested ops should have been erased already.
169  assert(
170  llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
171  "expected empty regions");
172  // All users should have been erased already if the op is in a region with
173  // SSA dominance.
174  if (!op->use_empty() && op->getParentOp())
175  assert(mayBeGraphRegion(*op->getParentRegion()) &&
176  "expected that op has no uses");
177 #endif // NDEBUG
178  rewriteListener->notifyOperationErased(op);
179 
180  // Explicitly drop all uses in case the op is in a graph region.
181  op->dropAllUses();
182  op->erase();
183  };
184 
185  // Nested ops must be erased one-by-one, so that listeners have a consistent
186  // view of the IR every time a notification is triggered. Users must be
187  // erased before definitions. I.e., post-order, reverse dominance.
188  std::function<void(Operation *)> eraseTree = [&](Operation *op) {
189  // Erase nested ops.
190  for (Region &r : llvm::reverse(op->getRegions())) {
191  // Erase all blocks in the right order. Successors should be erased
192  // before predecessors because successor blocks may use values defined
193  // in predecessor blocks. A post-order traversal of blocks within a
194  // region visits successors before predecessors. Repeat the traversal
195  // until the region is empty. (The block graph could be disconnected.)
196  while (!r.empty()) {
197  SmallVector<Block *> erasedBlocks;
198  // Some blocks may have null (invalid) successors; pre-seed the visited
199  // set with nullptr so the traversal never descends into one.
200  llvm::SmallPtrSet<Block *, 4> visited{nullptr};
201  for (Block *b : llvm::post_order_ext(&r.front(), visited)) {
202  // Visit ops in reverse order.
203  for (Operation &op :
204  llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
205  eraseTree(&op);
206  // Do not erase the block immediately. This is not supported by the
207  // post_order iterator.
208  erasedBlocks.push_back(b);
209  }
210  for (Block *b : erasedBlocks) {
211  // Explicitly drop all uses in case there is a cycle in the block
212  // graph.
213  for (BlockArgument bbArg : b->getArguments())
214  bbArg.dropAllUses();
215  b->dropAllUses();
216  eraseBlock(b);
217  }
218  }
219  }
220  // Then erase the enclosing op.
221  eraseSingleOp(op);
222  };
223 
224  eraseTree(op);
225 }
226 
228  assert(block->use_empty() && "expected 'block' to have no uses");
229 
230  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
231  assert(op.use_empty() && "expected 'op' to have no uses");
232  eraseOp(&op);
233  }
234 
235  // Notify the listener that the block is about to be removed.
236  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
237  rewriteListener->notifyBlockErased(block);
238 
239  block->erase();
240 }
241 
243  // Notify the listener that the operation was modified.
244  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
245  rewriteListener->notifyOperationModified(op);
246 }
247 
249  Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
250  return replaceUsesWithIf(from, to, [&](OpOperand &use) {
251  Operation *user = use.getOwner();
252  return !preservedUsers.contains(user);
253  });
254 }
255 
257  function_ref<bool(OpOperand &)> functor,
258  bool *allUsesReplaced) {
259  bool allReplaced = true;
260  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
261  bool replace = functor(operand);
262  if (replace)
263  modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
264  allReplaced &= replace;
265  }
266  if (allUsesReplaced)
267  *allUsesReplaced = allReplaced;
268 }
269 
271  function_ref<bool(OpOperand &)> functor,
272  bool *allUsesReplaced) {
273  assert(from.size() == to.size() && "incorrect number of replacements");
274  bool allReplaced = true;
275  for (auto it : llvm::zip_equal(from, to)) {
276  bool r;
277  replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
278  /*allUsesReplaced=*/&r);
279  allReplaced &= r;
280  }
281  if (allUsesReplaced)
282  *allUsesReplaced = allReplaced;
283 }
284 
286  Block::iterator before,
287  ValueRange argValues) {
288  assert(argValues.size() == source->getNumArguments() &&
289  "incorrect # of argument replacement values");
290 
291  // The source block will be deleted, so it should not have any users (i.e.,
292  // there should be no predecessors).
293  assert(source->hasNoPredecessors() &&
294  "expected 'source' to have no predecessors");
295 
296  if (dest->end() != before) {
297  // The source block will be inserted in the middle of the dest block, so
298  // the source block should have no successors. Otherwise, the remainder of
299  // the dest block would be unreachable.
300  assert(source->hasNoSuccessors() &&
301  "expected 'source' to have no successors");
302  } else {
303  // The source block will be inserted at the end of the dest block, so the
304  // dest block should have no successors. Otherwise, the inserted operations
305  // will be unreachable.
306  assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
307  }
308 
309  // Replace all of the successor arguments with the provided values.
310  for (auto it : llvm::zip(source->getArguments(), argValues))
311  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
312 
313  // Move operations from the source block to the dest block and erase the
314  // source block.
315  if (!listener) {
316  // Fast path: If no listener is attached, move all operations at once.
317  dest->getOperations().splice(before, source->getOperations());
318  } else {
319  while (!source->empty())
320  moveOpBefore(&source->front(), dest, before);
321  }
322 
323  // Erase the source block.
324  assert(source->empty() && "expected 'source' to be empty");
325  eraseBlock(source);
326 }
327 
329  ValueRange argValues) {
330  inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
331 }
332 
334  ValueRange argValues) {
335  inlineBlockBefore(source, dest, dest->end(), argValues);
336 }
337 
338 /// Split the operations starting at "before" (inclusive) out of the given
339 /// block into a new block, and return it.
341  // Fast path: If no listener is attached, split the block directly.
342  if (!listener)
343  return block->splitBlock(before);
344 
345  // `createBlock` sets the insertion point at the beginning of the new block.
346  InsertionGuard g(*this);
347  Block *newBlock =
348  createBlock(block->getParent(), std::next(block->getIterator()));
349 
350  // If `before` points to end of the block, no ops should be moved.
351  if (before == block->end())
352  return newBlock;
353 
354  // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
355  // Stop when the operation pointed to by `before` has been moved.
356  while (before->getBlock() != newBlock)
357  moveOpBefore(&block->back(), newBlock, newBlock->begin());
358 
359  return newBlock;
360 }
361 
362 /// Move the blocks that belong to "region" before the given position in
363 /// another region. The two regions must be different. The caller is in
364 /// charge of creating the operation that transfers the control flow to the
365 /// region and of passing it the correct block arguments.
367  Region::iterator before) {
368  // Fast path: If no listener is attached, move all blocks at once.
369  if (!listener) {
370  parent.getBlocks().splice(before, region.getBlocks());
371  return;
372  }
373 
374  // Move blocks from the beginning of the region one-by-one.
375  while (!region.empty())
376  moveBlockBefore(&region.front(), &parent, before);
377 }
379  inlineRegionBefore(region, *before->getParent(), before->getIterator());
380 }
381 
382 void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
383  moveBlockBefore(block, anotherBlock->getParent(),
384  anotherBlock->getIterator());
385 }
386 
388  Region::iterator iterator) {
389  Region *currentRegion = block->getParent();
390  Region::iterator nextIterator = std::next(block->getIterator());
391  block->moveBefore(region, iterator);
392  if (listener)
393  listener->notifyBlockInserted(block, /*previous=*/currentRegion,
394  /*previousIt=*/nextIterator);
395 }
396 
398  moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
399 }
400 
402  Block::iterator iterator) {
403  Block *currentBlock = op->getBlock();
404  Block::iterator nextIterator = std::next(op->getIterator());
405  op->moveBefore(block, iterator);
406  if (listener)
408  op, /*previous=*/InsertPoint(currentBlock, nextIterator));
409 }
410 
412  moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
413 }
414 
416  Block::iterator iterator) {
417  assert(iterator != block->end() && "cannot move after end of block");
418  moveOpBefore(op, block, std::next(iterator));
419 }
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
bool empty()
Definition: Block.h:148
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition: Block.h:245
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:66
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:308
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specified block.
Definition: Block.cpp:54
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:242
bool use_empty() const
Returns true if this value has no uses.
Definition: UseDefLists.h:261
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a saved insertion point.
Definition: Builders.h:325
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:608
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:834
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in the same or another block in the same function.
Definition: Operation.cpp:554
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:34
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
PatternBenefit()=default
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the pattern is impossible to match, this triggers an assertion failure.
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
Definition: PatternMatch.h:73
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
virtual void eraseBlock(Block *block)
This method erases all operations in a block, and then erases the block itself.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.
virtual void replaceOp(Operation *op, ValueRange newValues)
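Replace the results of the given (original) operation with the specified list of values (replacements). The number of provided values must match the number of results of the operation. The replaced op is erased.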
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink this block and insert it right before anotherBlock.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:622
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:686
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in the same or another block in the same function.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:614
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the same or another block in the same function.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
Include the generated interface declarations.
bool mayBeGraphRegion(Region &region)
Return "true" if the given region may be a graph region without SSA dominance.
@ RewriterBaseListener
RewriterBase::Listener or user-derived class.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:306
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:296
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:159
This class acts as a special tag that makes the desire to match any operation that implements a given interface explicit.
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given trait explicit.
Definition: PatternMatch.h:169
static constexpr auto makeIterable(RangeT &&range)
Definition: Iterators.h:32
static bool classof(const OpBuilder::Listener *base)