MLIR  22.0.0git
PatternMatch.cpp
Go to the documentation of this file.
1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/IR/Iterators.h"
12 #include "llvm/ADT/SmallPtrSet.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // PatternBenefit
18 //===----------------------------------------------------------------------===//
19 
20 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
21  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
22  "This pattern match benefit is too large to represent");
23 }
24 
25 unsigned short PatternBenefit::getBenefit() const {
26  assert(!isImpossibleToMatch() && "Pattern doesn't match");
27  return representation;
28 }
29 
30 //===----------------------------------------------------------------------===//
31 // Pattern
32 //===----------------------------------------------------------------------===//
33 
34 //===----------------------------------------------------------------------===//
35 // OperationName Root Constructors
36 //===----------------------------------------------------------------------===//
37 
38 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
39  MLIRContext *context, ArrayRef<StringRef> generatedNames)
40  : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
41  RootKind::OperationName, generatedNames, benefit, context) {}
42 
43 //===----------------------------------------------------------------------===//
44 // MatchAnyOpTypeTag Root Constructors
45 //===----------------------------------------------------------------------===//
46 
48  MLIRContext *context, ArrayRef<StringRef> generatedNames)
49  : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
50 
51 //===----------------------------------------------------------------------===//
52 // MatchInterfaceOpTypeTag Root Constructors
53 //===----------------------------------------------------------------------===//
54 
56  PatternBenefit benefit, MLIRContext *context,
57  ArrayRef<StringRef> generatedNames)
58  : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
59  generatedNames, benefit, context) {}
60 
61 //===----------------------------------------------------------------------===//
62 // MatchTraitOpTypeTag Root Constructors
63 //===----------------------------------------------------------------------===//
64 
66  PatternBenefit benefit, MLIRContext *context,
67  ArrayRef<StringRef> generatedNames)
68  : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
69  benefit, context) {}
70 
71 //===----------------------------------------------------------------------===//
72 // General Constructors
73 //===----------------------------------------------------------------------===//
74 
75 Pattern::Pattern(const void *rootValue, RootKind rootKind,
76  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
77  MLIRContext *context)
78  : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
79  contextAndHasBoundedRecursion(context, false) {
80  if (generatedNames.empty())
81  return;
82  generatedOps.reserve(generatedNames.size());
83  std::transform(generatedNames.begin(), generatedNames.end(),
84  std::back_inserter(generatedOps), [context](StringRef name) {
85  return OperationName(name, context);
86  });
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // RewritePattern
91 //===----------------------------------------------------------------------===//
92 
93 /// Out-of-line vtable anchor.
94 void RewritePattern::anchor() {}
95 
96 //===----------------------------------------------------------------------===//
97 // RewriterBase
98 //===----------------------------------------------------------------------===//
99 
102 }
103 
105  // Out of line to provide a vtable anchor for the class.
106 }
107 
109  // Notify the listener that we're about to replace this op.
110  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
111  rewriteListener->notifyOperationReplaced(from, to);
112 
113  replaceAllUsesWith(from->getResults(), to);
114 }
115 
117  // Notify the listener that we're about to replace this op.
118  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
119  rewriteListener->notifyOperationReplaced(from, to);
120 
121  replaceAllUsesWith(from->getResults(), to->getResults());
122 }
123 
124 /// This method replaces the results of the operation with the specified list of
125 /// values. The number of provided values must match the number of results of
126 /// the operation. The replaced op is erased.
128  assert(op->getNumResults() == newValues.size() &&
129  "incorrect # of replacement values");
130 
131  // Replace all result uses. Also notifies the listener of modifications.
132  replaceAllOpUsesWith(op, newValues);
133 
134  // Erase op and notify listener.
135  eraseOp(op);
136 }
137 
138 /// This method replaces the results of the operation with the specified new op
139 /// (replacement). The number of results of the two operations must match. The
140 /// replaced op is erased.
142  assert(op && newOp && "expected non-null op");
143  assert(op->getNumResults() == newOp->getNumResults() &&
144  "ops have different number of results");
145 
146  // Replace all result uses. Also notifies the listener of modifications.
147  replaceAllOpUsesWith(op, newOp->getResults());
148 
149  // Erase op and notify listener.
150  eraseOp(op);
151 }
152 
153 /// This method erases an operation that is known to have no uses. The uses of
154 /// the given operation *must* be known to be dead.
156  assert(op->use_empty() && "expected 'op' to have no uses");
157  auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
158 
159  // If the current insertion point is before the erased operation, we adjust
160  // the insertion point to be after the operation.
161  if (getInsertionPoint() == op->getIterator())
163 
164  // Fast path: If no listener is attached, the op can be dropped in one go.
165  if (!rewriteListener) {
166  op->erase();
167  return;
168  }
169 
170  // Helper function that erases a single op.
171  auto eraseSingleOp = [&](Operation *op) {
172 #ifndef NDEBUG
173  // All nested ops should have been erased already.
174  assert(
175  llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
176  "expected empty regions");
177  // All users should have been erased already if the op is in a region with
178  // SSA dominance.
179  if (!op->use_empty() && op->getParentOp())
181  "expected that op has no uses");
182 #endif // NDEBUG
183  rewriteListener->notifyOperationErased(op);
184 
185  // Explicitly drop all uses in case the op is in a graph region.
186  op->dropAllUses();
187  op->erase();
188  };
189 
190  // Nested ops must be erased one-by-one, so that listeners have a consistent
191  // view of the IR every time a notification is triggered. Users must be
192  // erased before definitions. I.e., post-order, reverse dominance.
193  std::function<void(Operation *)> eraseTree = [&](Operation *op) {
194  // Erase nested ops.
195  for (Region &r : llvm::reverse(op->getRegions())) {
196  // Erase all blocks in the right order. Successors should be erased
197  // before predecessors because successor blocks may use values defined
198  // in predecessor blocks. A post-order traversal of blocks within a
199  // region visits successors before predecessors. Repeat the traversal
200  // until the region is empty. (The block graph could be disconnected.)
201  while (!r.empty()) {
202  SmallVector<Block *> erasedBlocks;
203  // Some blocks may have invalid successor, use a set including nullptr
204  // to avoid null pointer.
205  llvm::SmallPtrSet<Block *, 4> visited{nullptr};
206  for (Block *b : llvm::post_order_ext(&r.front(), visited)) {
207  // Visit ops in reverse order.
208  for (Operation &op :
209  llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
210  eraseTree(&op);
211  // Do not erase the block immediately. This is not supprted by the
212  // post_order iterator.
213  erasedBlocks.push_back(b);
214  }
215  for (Block *b : erasedBlocks) {
216  // Explicitly drop all uses in case there is a cycle in the block
217  // graph.
218  for (BlockArgument bbArg : b->getArguments())
219  bbArg.dropAllUses();
220  b->dropAllUses();
221  eraseBlock(b);
222  }
223  }
224  }
225  // Then erase the enclosing op.
226  eraseSingleOp(op);
227  };
228 
229  eraseTree(op);
230 }
231 
233  assert(block->use_empty() && "expected 'block' to have no uses");
234 
235  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
236  assert(op.use_empty() && "expected 'op' to have no uses");
237  eraseOp(&op);
238  }
239 
240  // Notify the listener that the block is about to be removed.
241  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
242  rewriteListener->notifyBlockErased(block);
243 
244  block->erase();
245 }
246 
248  // Notify the listener that the operation was modified.
249  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
250  rewriteListener->notifyOperationModified(op);
251 }
252 
254  Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
255  return replaceUsesWithIf(from, to, [&](OpOperand &use) {
256  Operation *user = use.getOwner();
257  return !preservedUsers.contains(user);
258  });
259 }
260 
262  function_ref<bool(OpOperand &)> functor,
263  bool *allUsesReplaced) {
264  bool allReplaced = true;
265  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
266  bool replace = functor(operand);
267  if (replace)
268  modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
269  allReplaced &= replace;
270  }
271  if (allUsesReplaced)
272  *allUsesReplaced = allReplaced;
273 }
274 
276  function_ref<bool(OpOperand &)> functor,
277  bool *allUsesReplaced) {
278  assert(from.size() == to.size() && "incorrect number of replacements");
279  bool allReplaced = true;
280  for (auto it : llvm::zip_equal(from, to)) {
281  bool r;
282  replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
283  /*allUsesReplaced=*/&r);
284  allReplaced &= r;
285  }
286  if (allUsesReplaced)
287  *allUsesReplaced = allReplaced;
288 }
289 
291  Block::iterator before,
292  ValueRange argValues) {
293  assert(argValues.size() == source->getNumArguments() &&
294  "incorrect # of argument replacement values");
295 
296  // The source block will be deleted, so it should not have any users (i.e.,
297  // there should be no predecessors).
298  assert(source->hasNoPredecessors() &&
299  "expected 'source' to have no predecessors");
300 
301  if (dest->end() != before) {
302  // The source block will be inserted in the middle of the dest block, so
303  // the source block should have no successors. Otherwise, the remainder of
304  // the dest block would be unreachable.
305  assert(source->hasNoSuccessors() &&
306  "expected 'source' to have no successors");
307  } else {
308  // The source block will be inserted at the end of the dest block, so the
309  // dest block should have no successors. Otherwise, the inserted operations
310  // will be unreachable.
311  assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
312  }
313 
314  // Replace all of the successor arguments with the provided values.
315  for (auto it : llvm::zip(source->getArguments(), argValues))
316  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
317 
318  // Move operations from the source block to the dest block and erase the
319  // source block.
320  if (!listener) {
321  // Fast path: If no listener is attached, move all operations at once.
322  dest->getOperations().splice(before, source->getOperations());
323  } else {
324  while (!source->empty())
325  moveOpBefore(&source->front(), dest, before);
326  }
327 
328  // If the current insertion point is within the source block, adjust the
329  // insertion point to the destination block.
330  if (getInsertionBlock() == source)
332 
333  // Erase the source block.
334  assert(source->empty() && "expected 'source' to be empty");
335  eraseBlock(source);
336 }
337 
339  ValueRange argValues) {
340  inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
341 }
342 
344  ValueRange argValues) {
345  inlineBlockBefore(source, dest, dest->end(), argValues);
346 }
347 
348 /// Split the operations starting at "before" (inclusive) out of the given
349 /// block into a new block, and return it.
351  // Fast path: If no listener is attached, split the block directly.
352  if (!listener)
353  return block->splitBlock(before);
354 
355  // `createBlock` sets the insertion point at the beginning of the new block.
356  InsertionGuard g(*this);
357  Block *newBlock =
358  createBlock(block->getParent(), std::next(block->getIterator()));
359 
360  // If `before` points to end of the block, no ops should be moved.
361  if (before == block->end())
362  return newBlock;
363 
364  // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
365  // Stop when the operation pointed to by `before` has been moved.
366  while (before->getBlock() != newBlock)
367  moveOpBefore(&block->back(), newBlock, newBlock->begin());
368 
369  return newBlock;
370 }
371 
372 /// Move the blocks that belong to "region" before the given position in
373 /// another region. The two regions must be different. The caller is
374 /// responsible for creating the operation that transfers control flow to the
375 /// region and for passing it the correct block arguments.
377  Region::iterator before) {
378  // Fast path: If no listener is attached, move all blocks at once.
379  if (!listener) {
380  parent.getBlocks().splice(before, region.getBlocks());
381  return;
382  }
383 
384  // Move blocks from the beginning of the region one-by-one.
385  while (!region.empty())
386  moveBlockBefore(&region.front(), &parent, before);
387 }
389  inlineRegionBefore(region, *before->getParent(), before->getIterator());
390 }
391 
392 void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
393  moveBlockBefore(block, anotherBlock->getParent(),
394  anotherBlock->getIterator());
395 }
396 
398  Region::iterator iterator) {
399  Region *currentRegion = block->getParent();
400  Region::iterator nextIterator = std::next(block->getIterator());
401  block->moveBefore(region, iterator);
402  if (listener)
403  listener->notifyBlockInserted(block, /*previous=*/currentRegion,
404  /*previousIt=*/nextIterator);
405 }
406 
408  moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
409 }
410 
412  Block::iterator iterator) {
413  Block *currentBlock = op->getBlock();
414  Block::iterator nextIterator = std::next(op->getIterator());
415  op->moveBefore(block, iterator);
416  if (listener)
418  op, /*previous=*/InsertPoint(currentBlock, nextIterator));
419 }
420 
422  moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
423 }
424 
426  Block::iterator iterator) {
427  assert(iterator != block->end() && "cannot move after end of block");
428  moveOpBefore(op, block, std::next(iterator));
429 }
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
bool empty()
Definition: Block.h:148
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition: Block.h:248
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:66
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:318
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specified block.
Definition: Block.cpp:54
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:245
bool use_empty() const
Returns true if this value has no uses.
Definition: UseDefLists.h:261
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class represents a saved insertion point.
Definition: Builders.h:327
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:445
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:610
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:412
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:442
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:834
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block.
Definition: Operation.cpp:554
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:34
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
PatternBenefit()=default
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the pattern is impossible to match, this triggers an assertion failure.
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
Definition: PatternMatch.h:73
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
virtual void eraseBlock(Block *block)
This method erases all operations in a block, and then the block itself.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The number of values must match the number of results, and the original operation is erased.
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink this block and insert it right before anotherBlock.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:700
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp, which may be in the same or another block.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
Include the generated interface declarations.
bool mayBeGraphRegion(Region &region)
Return "true" if the given region may be a graph region without SSA dominance.
@ RewriterBaseListener
RewriterBase::Listener or user-derived class.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:285
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:308
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:298
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:159
This class acts as a special tag that makes the desire to match any operation that implements a given interface explicit.
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given trait explicit.
Definition: PatternMatch.h:169
static constexpr auto makeIterable(RangeT &&range)
Definition: Iterators.h:32
static bool classof(const OpBuilder::Listener *base)