MLIR  19.0.0git
PatternMatch.cpp
Go to the documentation of this file.
1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/Iterators.h"
14 #include "llvm/ADT/SmallPtrSet.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // PatternBenefit
20 //===----------------------------------------------------------------------===//
21 
22 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
23  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
24  "This pattern match benefit is too large to represent");
25 }
26 
27 unsigned short PatternBenefit::getBenefit() const {
28  assert(!isImpossibleToMatch() && "Pattern doesn't match");
29  return representation;
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // Pattern
34 //===----------------------------------------------------------------------===//
35 
36 //===----------------------------------------------------------------------===//
37 // OperationName Root Constructors
38 
39 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
40  MLIRContext *context, ArrayRef<StringRef> generatedNames)
41  : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
42  RootKind::OperationName, generatedNames, benefit, context) {}
43 
44 //===----------------------------------------------------------------------===//
45 // MatchAnyOpTypeTag Root Constructors
46 
48  MLIRContext *context, ArrayRef<StringRef> generatedNames)
49  : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
50 
51 //===----------------------------------------------------------------------===//
52 // MatchInterfaceOpTypeTag Root Constructors
53 
55  PatternBenefit benefit, MLIRContext *context,
56  ArrayRef<StringRef> generatedNames)
57  : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
58  generatedNames, benefit, context) {}
59 
60 //===----------------------------------------------------------------------===//
61 // MatchTraitOpTypeTag Root Constructors
62 
64  PatternBenefit benefit, MLIRContext *context,
65  ArrayRef<StringRef> generatedNames)
66  : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
67  benefit, context) {}
68 
69 //===----------------------------------------------------------------------===//
70 // General Constructors
71 
72 Pattern::Pattern(const void *rootValue, RootKind rootKind,
73  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
74  MLIRContext *context)
75  : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
76  contextAndHasBoundedRecursion(context, false) {
77  if (generatedNames.empty())
78  return;
79  generatedOps.reserve(generatedNames.size());
80  std::transform(generatedNames.begin(), generatedNames.end(),
81  std::back_inserter(generatedOps), [context](StringRef name) {
82  return OperationName(name, context);
83  });
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // RewritePattern
88 //===----------------------------------------------------------------------===//
89 
91  llvm_unreachable("need to implement either matchAndRewrite or one of the "
92  "rewrite functions!");
93 }
94 
96  llvm_unreachable("need to implement either match or matchAndRewrite!");
97 }
98 
99 /// Out-of-line vtable anchor.
100 void RewritePattern::anchor() {}
101 
102 //===----------------------------------------------------------------------===//
103 // RewriterBase
104 //===----------------------------------------------------------------------===//
105 
108 }
109 
111  // Out of line to provide a vtable anchor for the class.
112 }
113 
115  // Notify the listener that we're about to replace this op.
116  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
117  rewriteListener->notifyOperationReplaced(from, to);
118 
119  replaceAllUsesWith(from->getResults(), to);
120 }
121 
123  // Notify the listener that we're about to replace this op.
124  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
125  rewriteListener->notifyOperationReplaced(from, to);
126 
127  replaceAllUsesWith(from->getResults(), to->getResults());
128 }
129 
130 /// This method replaces the results of the operation with the specified list of
131 /// values. The number of provided values must match the number of results of
132 /// the operation. The replaced op is erased.
134  assert(op->getNumResults() == newValues.size() &&
135  "incorrect # of replacement values");
136 
137  // Replace all result uses. Also notifies the listener of modifications.
138  replaceAllOpUsesWith(op, newValues);
139 
140  // Erase op and notify listener.
141  eraseOp(op);
142 }
143 
144 /// This method replaces the results of the operation with the specified new op
145 /// (replacement). The number of results of the two operations must match. The
146 /// replaced op is erased.
148  assert(op && newOp && "expected non-null op");
149  assert(op->getNumResults() == newOp->getNumResults() &&
150  "ops have different number of results");
151 
152  // Replace all result uses. Also notifies the listener of modifications.
153  replaceAllOpUsesWith(op, newOp->getResults());
154 
155  // Erase op and notify listener.
156  eraseOp(op);
157 }
158 
159 /// This method erases an operation that is known to have no uses. The uses of
160 /// the given operation *must* be known to be dead.
162  assert(op->use_empty() && "expected 'op' to have no uses");
163  auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
164 
165  // Fast path: If no listener is attached, the op can be dropped in one go.
166  if (!rewriteListener) {
167  op->erase();
168  return;
169  }
170 
171  // Helper function that erases a single op.
172  auto eraseSingleOp = [&](Operation *op) {
173 #ifndef NDEBUG
174  // All nested ops should have been erased already.
175  assert(
176  llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
177  "expected empty regions");
178  // All users should have been erased already if the op is in a region with
179  // SSA dominance.
180  if (!op->use_empty() && op->getParentOp())
181  assert(mayBeGraphRegion(*op->getParentRegion()) &&
182  "expected that op has no uses");
183 #endif // NDEBUG
184  rewriteListener->notifyOperationErased(op);
185 
186  // Explicitly drop all uses in case the op is in a graph region.
187  op->dropAllUses();
188  op->erase();
189  };
190 
191  // Nested ops must be erased one-by-one, so that listeners have a consistent
192  // view of the IR every time a notification is triggered. Users must be
193  // erased before definitions. I.e., post-order, reverse dominance.
194  std::function<void(Operation *)> eraseTree = [&](Operation *op) {
195  // Erase nested ops.
196  for (Region &r : llvm::reverse(op->getRegions())) {
197  // Erase all blocks in the right order. Successors should be erased
198  // before predecessors because successor blocks may use values defined
199  // in predecessor blocks. A post-order traversal of blocks within a
200  // region visits successors before predecessors. Repeat the traversal
201  // until the region is empty. (The block graph could be disconnected.)
202  while (!r.empty()) {
203  SmallVector<Block *> erasedBlocks;
204  // Some blocks may have invalid successor, use a set including nullptr
205  // to avoid null pointer.
206  llvm::SmallPtrSet<Block *, 4> visited{nullptr};
207  for (Block *b : llvm::post_order_ext(&r.front(), visited)) {
208  // Visit ops in reverse order.
209  for (Operation &op :
210  llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
211  eraseTree(&op);
212  // Do not erase the block immediately. This is not supprted by the
213  // post_order iterator.
214  erasedBlocks.push_back(b);
215  }
216  for (Block *b : erasedBlocks) {
217  // Explicitly drop all uses in case there is a cycle in the block
218  // graph.
219  for (BlockArgument bbArg : b->getArguments())
220  bbArg.dropAllUses();
221  b->dropAllUses();
222  eraseBlock(b);
223  }
224  }
225  }
226  // Then erase the enclosing op.
227  eraseSingleOp(op);
228  };
229 
230  eraseTree(op);
231 }
232 
234  assert(block->use_empty() && "expected 'block' to have no uses");
235 
236  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
237  assert(op.use_empty() && "expected 'op' to have no uses");
238  eraseOp(&op);
239  }
240 
241  // Notify the listener that the block is about to be removed.
242  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
243  rewriteListener->notifyBlockErased(block);
244 
245  block->erase();
246 }
247 
249  // Notify the listener that the operation was modified.
250  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
251  rewriteListener->notifyOperationModified(op);
252 }
253 
255  Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
256  return replaceUsesWithIf(from, to, [&](OpOperand &use) {
257  Operation *user = use.getOwner();
258  return !preservedUsers.contains(user);
259  });
260 }
261 
263  function_ref<bool(OpOperand &)> functor,
264  bool *allUsesReplaced) {
265  bool allReplaced = true;
266  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
267  bool replace = functor(operand);
268  if (replace)
269  modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
270  allReplaced &= replace;
271  }
272  if (allUsesReplaced)
273  *allUsesReplaced = allReplaced;
274 }
275 
277  function_ref<bool(OpOperand &)> functor,
278  bool *allUsesReplaced) {
279  assert(from.size() == to.size() && "incorrect number of replacements");
280  bool allReplaced = true;
281  for (auto it : llvm::zip_equal(from, to)) {
282  bool r;
283  replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
284  /*allUsesReplaced=*/&r);
285  allReplaced &= r;
286  }
287  if (allUsesReplaced)
288  *allUsesReplaced = allReplaced;
289 }
290 
292  Block::iterator before,
293  ValueRange argValues) {
294  assert(argValues.size() == source->getNumArguments() &&
295  "incorrect # of argument replacement values");
296 
297  // The source block will be deleted, so it should not have any users (i.e.,
298  // there should be no predecessors).
299  assert(source->hasNoPredecessors() &&
300  "expected 'source' to have no predecessors");
301 
302  if (dest->end() != before) {
303  // The source block will be inserted in the middle of the dest block, so
304  // the source block should have no successors. Otherwise, the remainder of
305  // the dest block would be unreachable.
306  assert(source->hasNoSuccessors() &&
307  "expected 'source' to have no successors");
308  } else {
309  // The source block will be inserted at the end of the dest block, so the
310  // dest block should have no successors. Otherwise, the inserted operations
311  // will be unreachable.
312  assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
313  }
314 
315  // Replace all of the successor arguments with the provided values.
316  for (auto it : llvm::zip(source->getArguments(), argValues))
317  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
318 
319  // Move operations from the source block to the dest block and erase the
320  // source block.
321  if (!listener) {
322  // Fast path: If no listener is attached, move all operations at once.
323  dest->getOperations().splice(before, source->getOperations());
324  } else {
325  while (!source->empty())
326  moveOpBefore(&source->front(), dest, before);
327  }
328 
329  // Erase the source block.
330  assert(source->empty() && "expected 'source' to be empty");
331  eraseBlock(source);
332 }
333 
335  ValueRange argValues) {
336  inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
337 }
338 
340  ValueRange argValues) {
341  inlineBlockBefore(source, dest, dest->end(), argValues);
342 }
343 
344 /// Split the operations starting at "before" (inclusive) out of the given
345 /// block into a new block, and return it.
347  // Fast path: If no listener is attached, split the block directly.
348  if (!listener)
349  return block->splitBlock(before);
350 
351  // `createBlock` sets the insertion point at the beginning of the new block.
352  InsertionGuard g(*this);
353  Block *newBlock =
354  createBlock(block->getParent(), std::next(block->getIterator()));
355 
356  // If `before` points to end of the block, no ops should be moved.
357  if (before == block->end())
358  return newBlock;
359 
360  // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
361  // Stop when the operation pointed to by `before` has been moved.
362  while (before->getBlock() != newBlock)
363  moveOpBefore(&block->back(), newBlock, newBlock->begin());
364 
365  return newBlock;
366 }
367 
368 /// Move the blocks that belong to "region" before the given position in
369 /// another region. The two regions must be different. The caller is in
370 /// charge to update create the operation transferring the control flow to the
371 /// region and pass it the correct block arguments.
373  Region::iterator before) {
374  // Fast path: If no listener is attached, move all blocks at once.
375  if (!listener) {
376  parent.getBlocks().splice(before, region.getBlocks());
377  return;
378  }
379 
380  // Move blocks from the beginning of the region one-by-one.
381  while (!region.empty())
382  moveBlockBefore(&region.front(), &parent, before);
383 }
385  inlineRegionBefore(region, *before->getParent(), before->getIterator());
386 }
387 
388 void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
389  moveBlockBefore(block, anotherBlock->getParent(),
390  anotherBlock->getIterator());
391 }
392 
394  Region::iterator iterator) {
395  Region *currentRegion = block->getParent();
396  Region::iterator nextIterator = std::next(block->getIterator());
397  block->moveBefore(region, iterator);
398  if (listener)
399  listener->notifyBlockInserted(block, /*previous=*/currentRegion,
400  /*previousIt=*/nextIterator);
401 }
402 
404  moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
405 }
406 
408  Block::iterator iterator) {
409  Block *currentBlock = op->getBlock();
410  Block::iterator nextIterator = std::next(op->getIterator());
411  op->moveBefore(block, iterator);
412  if (listener)
414  op, /*previous=*/InsertPoint(currentBlock, nextIterator));
415 }
416 
418  moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
419 }
420 
422  Block::iterator iterator) {
423  assert(iterator != block->end() && "cannot move after end of block");
424  moveOpBefore(op, block, std::next(iterator));
425 }
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
bool empty()
Definition: Block.h:146
bool hasNoSuccessors()
Returns true if this blocks has no successors.
Definition: Block.h:243
unsigned getNumArguments()
Definition: Block.h:126
Operation & back()
Definition: Block.h:150
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:65
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:307
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator end()
Definition: Block.h:142
iterator begin()
Definition: Block.h:141
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specified block.
Definition: Block.cpp:53
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:240
bool use_empty() const
Returns true if this value has no uses.
Definition: UseDefLists.h:261
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a saved insertion point.
Definition: Builders.h:329
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:609
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:830
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in the same or another block in the same function.
Definition: Operation.cpp:555
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:410
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:34
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
PatternBenefit()=default
unsigned short getBenefit() const
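If the corresponding pattern can match, return its benefit. If the corresponding pattern is impossible to match (isImpossibleToMatch()), this triggers an assertion failure.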
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
Definition: PatternMatch.h:73
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as getRootKind().
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new operations with the specified builder.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
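Replace the results of the given (original) operation with the specified list of values (replacements). The number of provided values must match the number of results of the operation; the replaced op is erased.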
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink this block and insert it right before existingBlock.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:702
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in the same or another block in the same function.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the same or another block in the same function.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
Include the generated interface declarations.
bool mayBeGraphRegion(Region &region)
Return "true" if the given region may be a graph region without SSA dominance.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
@ RewriterBaseListener
RewriterBase::Listener or user-derived class.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:287
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:310
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:300
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:159
This class acts as a special tag that makes the desire to match any operation that implements a given interface explicit.
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given trait explicit.
Definition: PatternMatch.h:169
static constexpr auto makeIterable(RangeT &&range)
Definition: Iterators.h:32
static bool classof(const OpBuilder::Listener *base)