MLIR  21.0.0git
PatternMatch.cpp
Go to the documentation of this file.
1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/Iterators.h"
14 #include "llvm/ADT/SmallPtrSet.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // PatternBenefit
20 //===----------------------------------------------------------------------===//
21 
22 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
23  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
24  "This pattern match benefit is too large to represent");
25 }
26 
27 unsigned short PatternBenefit::getBenefit() const {
28  assert(!isImpossibleToMatch() && "Pattern doesn't match");
29  return representation;
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // Pattern
34 //===----------------------------------------------------------------------===//
35 
36 //===----------------------------------------------------------------------===//
37 // OperationName Root Constructors
38 //===----------------------------------------------------------------------===//
39 
40 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
41  MLIRContext *context, ArrayRef<StringRef> generatedNames)
42  : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
43  RootKind::OperationName, generatedNames, benefit, context) {}
44 
45 //===----------------------------------------------------------------------===//
46 // MatchAnyOpTypeTag Root Constructors
47 //===----------------------------------------------------------------------===//
48 
50  MLIRContext *context, ArrayRef<StringRef> generatedNames)
51  : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
52 
53 //===----------------------------------------------------------------------===//
54 // MatchInterfaceOpTypeTag Root Constructors
55 //===----------------------------------------------------------------------===//
56 
58  PatternBenefit benefit, MLIRContext *context,
59  ArrayRef<StringRef> generatedNames)
60  : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
61  generatedNames, benefit, context) {}
62 
63 //===----------------------------------------------------------------------===//
64 // MatchTraitOpTypeTag Root Constructors
65 //===----------------------------------------------------------------------===//
66 
68  PatternBenefit benefit, MLIRContext *context,
69  ArrayRef<StringRef> generatedNames)
70  : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
71  benefit, context) {}
72 
73 //===----------------------------------------------------------------------===//
74 // General Constructors
75 //===----------------------------------------------------------------------===//
76 
77 Pattern::Pattern(const void *rootValue, RootKind rootKind,
78  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
79  MLIRContext *context)
80  : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
81  contextAndHasBoundedRecursion(context, false) {
82  if (generatedNames.empty())
83  return;
84  generatedOps.reserve(generatedNames.size());
85  std::transform(generatedNames.begin(), generatedNames.end(),
86  std::back_inserter(generatedOps), [context](StringRef name) {
87  return OperationName(name, context);
88  });
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // RewritePattern
93 //===----------------------------------------------------------------------===//
94 
95 /// Out-of-line vtable anchor.
96 void RewritePattern::anchor() {}
97 
98 //===----------------------------------------------------------------------===//
99 // RewriterBase
100 //===----------------------------------------------------------------------===//
101 
104 }
105 
107  // Out of line to provide a vtable anchor for the class.
108 }
109 
111  // Notify the listener that we're about to replace this op.
112  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
113  rewriteListener->notifyOperationReplaced(from, to);
114 
115  replaceAllUsesWith(from->getResults(), to);
116 }
117 
119  // Notify the listener that we're about to replace this op.
120  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
121  rewriteListener->notifyOperationReplaced(from, to);
122 
123  replaceAllUsesWith(from->getResults(), to->getResults());
124 }
125 
126 /// This method replaces the results of the operation with the specified list of
127 /// values. The number of provided values must match the number of results of
128 /// the operation. The replaced op is erased.
130  assert(op->getNumResults() == newValues.size() &&
131  "incorrect # of replacement values");
132 
133  // Replace all result uses. Also notifies the listener of modifications.
134  replaceAllOpUsesWith(op, newValues);
135 
136  // Erase op and notify listener.
137  eraseOp(op);
138 }
139 
140 /// This method replaces the results of the operation with the specified new op
141 /// (replacement). The number of results of the two operations must match. The
142 /// replaced op is erased.
144  assert(op && newOp && "expected non-null op");
145  assert(op->getNumResults() == newOp->getNumResults() &&
146  "ops have different number of results");
147 
148  // Replace all result uses. Also notifies the listener of modifications.
149  replaceAllOpUsesWith(op, newOp->getResults());
150 
151  // Erase op and notify listener.
152  eraseOp(op);
153 }
154 
155 /// This method erases an operation that is known to have no uses. The uses of
156 /// the given operation *must* be known to be dead.
158  assert(op->use_empty() && "expected 'op' to have no uses");
159  auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
160 
161  // Fast path: If no listener is attached, the op can be dropped in one go.
162  if (!rewriteListener) {
163  op->erase();
164  return;
165  }
166 
167  // Helper function that erases a single op.
168  auto eraseSingleOp = [&](Operation *op) {
169 #ifndef NDEBUG
170  // All nested ops should have been erased already.
171  assert(
172  llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
173  "expected empty regions");
174  // All users should have been erased already if the op is in a region with
175  // SSA dominance.
176  if (!op->use_empty() && op->getParentOp())
177  assert(mayBeGraphRegion(*op->getParentRegion()) &&
178  "expected that op has no uses");
179 #endif // NDEBUG
180  rewriteListener->notifyOperationErased(op);
181 
182  // Explicitly drop all uses in case the op is in a graph region.
183  op->dropAllUses();
184  op->erase();
185  };
186 
187  // Nested ops must be erased one-by-one, so that listeners have a consistent
188  // view of the IR every time a notification is triggered. Users must be
189  // erased before definitions. I.e., post-order, reverse dominance.
190  std::function<void(Operation *)> eraseTree = [&](Operation *op) {
191  // Erase nested ops.
192  for (Region &r : llvm::reverse(op->getRegions())) {
193  // Erase all blocks in the right order. Successors should be erased
194  // before predecessors because successor blocks may use values defined
195  // in predecessor blocks. A post-order traversal of blocks within a
196  // region visits successors before predecessors. Repeat the traversal
197  // until the region is empty. (The block graph could be disconnected.)
198  while (!r.empty()) {
199  SmallVector<Block *> erasedBlocks;
200  // Some blocks may have invalid (null) successors; seeding the visited
201  // set with nullptr keeps the traversal from visiting a null block.
202  llvm::SmallPtrSet<Block *, 4> visited{nullptr};
203  for (Block *b : llvm::post_order_ext(&r.front(), visited)) {
204  // Visit ops in reverse order.
205  for (Operation &op :
206  llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
207  eraseTree(&op);
208  // Do not erase the block immediately. This is not supported by the
209  // post_order iterator.
210  erasedBlocks.push_back(b);
211  }
212  for (Block *b : erasedBlocks) {
213  // Explicitly drop all uses in case there is a cycle in the block
214  // graph.
215  for (BlockArgument bbArg : b->getArguments())
216  bbArg.dropAllUses();
217  b->dropAllUses();
218  eraseBlock(b);
219  }
220  }
221  }
222  // Then erase the enclosing op.
223  eraseSingleOp(op);
224  };
225 
226  eraseTree(op);
227 }
228 
230  assert(block->use_empty() && "expected 'block' to have no uses");
231 
232  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
233  assert(op.use_empty() && "expected 'op' to have no uses");
234  eraseOp(&op);
235  }
236 
237  // Notify the listener that the block is about to be removed.
238  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
239  rewriteListener->notifyBlockErased(block);
240 
241  block->erase();
242 }
243 
245  // Notify the listener that the operation was modified.
246  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
247  rewriteListener->notifyOperationModified(op);
248 }
249 
251  Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
252  return replaceUsesWithIf(from, to, [&](OpOperand &use) {
253  Operation *user = use.getOwner();
254  return !preservedUsers.contains(user);
255  });
256 }
257 
259  function_ref<bool(OpOperand &)> functor,
260  bool *allUsesReplaced) {
261  bool allReplaced = true;
262  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
263  bool replace = functor(operand);
264  if (replace)
265  modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
266  allReplaced &= replace;
267  }
268  if (allUsesReplaced)
269  *allUsesReplaced = allReplaced;
270 }
271 
273  function_ref<bool(OpOperand &)> functor,
274  bool *allUsesReplaced) {
275  assert(from.size() == to.size() && "incorrect number of replacements");
276  bool allReplaced = true;
277  for (auto it : llvm::zip_equal(from, to)) {
278  bool r;
279  replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
280  /*allUsesReplaced=*/&r);
281  allReplaced &= r;
282  }
283  if (allUsesReplaced)
284  *allUsesReplaced = allReplaced;
285 }
286 
288  Block::iterator before,
289  ValueRange argValues) {
290  assert(argValues.size() == source->getNumArguments() &&
291  "incorrect # of argument replacement values");
292 
293  // The source block will be deleted, so it should not have any users (i.e.,
294  // there should be no predecessors).
295  assert(source->hasNoPredecessors() &&
296  "expected 'source' to have no predecessors");
297 
298  if (dest->end() != before) {
299  // The source block will be inserted in the middle of the dest block, so
300  // the source block should have no successors. Otherwise, the remainder of
301  // the dest block would be unreachable.
302  assert(source->hasNoSuccessors() &&
303  "expected 'source' to have no successors");
304  } else {
305  // The source block will be inserted at the end of the dest block, so the
306  // dest block should have no successors. Otherwise, the inserted operations
307  // will be unreachable.
308  assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
309  }
310 
311  // Replace all of the successor arguments with the provided values.
312  for (auto it : llvm::zip(source->getArguments(), argValues))
313  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
314 
315  // Move operations from the source block to the dest block and erase the
316  // source block.
317  if (!listener) {
318  // Fast path: If no listener is attached, move all operations at once.
319  dest->getOperations().splice(before, source->getOperations());
320  } else {
321  while (!source->empty())
322  moveOpBefore(&source->front(), dest, before);
323  }
324 
325  // Erase the source block.
326  assert(source->empty() && "expected 'source' to be empty");
327  eraseBlock(source);
328 }
329 
331  ValueRange argValues) {
332  inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
333 }
334 
336  ValueRange argValues) {
337  inlineBlockBefore(source, dest, dest->end(), argValues);
338 }
339 
340 /// Split the operations starting at "before" (inclusive) out of the given
341 /// block into a new block, and return it.
343  // Fast path: If no listener is attached, split the block directly.
344  if (!listener)
345  return block->splitBlock(before);
346 
347  // `createBlock` sets the insertion point at the beginning of the new block.
348  InsertionGuard g(*this);
349  Block *newBlock =
350  createBlock(block->getParent(), std::next(block->getIterator()));
351 
352  // If `before` points to end of the block, no ops should be moved.
353  if (before == block->end())
354  return newBlock;
355 
356  // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
357  // Stop when the operation pointed to by `before` has been moved.
358  while (before->getBlock() != newBlock)
359  moveOpBefore(&block->back(), newBlock, newBlock->begin());
360 
361  return newBlock;
362 }
363 
364 /// Move the blocks that belong to "region" before the given position in
365 /// another region. The two regions must be different. The caller is
366 /// responsible for creating the operation that transfers control flow to the
367 /// region and for passing it the correct block arguments.
369  Region::iterator before) {
370  // Fast path: If no listener is attached, move all blocks at once.
371  if (!listener) {
372  parent.getBlocks().splice(before, region.getBlocks());
373  return;
374  }
375 
376  // Move blocks from the beginning of the region one-by-one.
377  while (!region.empty())
378  moveBlockBefore(&region.front(), &parent, before);
379 }
381  inlineRegionBefore(region, *before->getParent(), before->getIterator());
382 }
383 
384 void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
385  moveBlockBefore(block, anotherBlock->getParent(),
386  anotherBlock->getIterator());
387 }
388 
390  Region::iterator iterator) {
391  Region *currentRegion = block->getParent();
392  Region::iterator nextIterator = std::next(block->getIterator());
393  block->moveBefore(region, iterator);
394  if (listener)
395  listener->notifyBlockInserted(block, /*previous=*/currentRegion,
396  /*previousIt=*/nextIterator);
397 }
398 
400  moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
401 }
402 
404  Block::iterator iterator) {
405  Block *currentBlock = op->getBlock();
406  Block::iterator nextIterator = std::next(op->getIterator());
407  op->moveBefore(block, iterator);
408  if (listener)
410  op, /*previous=*/InsertPoint(currentBlock, nextIterator));
411 }
412 
414  moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
415 }
416 
418  Block::iterator iterator) {
419  assert(iterator != block->end() && "cannot move after end of block");
420  moveOpBefore(op, block, std::next(iterator));
421 }
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
bool empty()
Definition: Block.h:148
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition: Block.h:245
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:68
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:310
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specified block.
Definition: Block.cpp:56
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:242
bool use_empty() const
Returns true if this value has no uses.
Definition: UseDefLists.h:261
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a saved insertion point.
Definition: Builders.h:324
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:608
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:834
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:555
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
PatternBenefit()=default
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit; asserts if the pattern is impossible to match.
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
virtual void eraseBlock(Block *block)
This method erases all operations in a block, then erases the block itself.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink this block and insert it right before anotherBlock.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:602
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:666
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
Include the generated interface declarations.
bool mayBeGraphRegion(Region &region)
Return "true" if the given region may be a graph region without SSA dominance.
@ RewriterBaseListener
RewriterBase::Listener or user-derived class.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:282
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:305
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:295
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:159
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:169
static constexpr auto makeIterable(RangeT &&range)
Definition: Iterators.h:32
static bool classof(const OpBuilder::Listener *base)