MLIR  19.0.0git
PatternMatch.cpp
Go to the documentation of this file.
1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/Iterators.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // PatternBenefit
19 //===----------------------------------------------------------------------===//
20 
21 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
22  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
23  "This pattern match benefit is too large to represent");
24 }
25 
26 unsigned short PatternBenefit::getBenefit() const {
27  assert(!isImpossibleToMatch() && "Pattern doesn't match");
28  return representation;
29 }
30 
31 //===----------------------------------------------------------------------===//
32 // Pattern
33 //===----------------------------------------------------------------------===//
34 
35 //===----------------------------------------------------------------------===//
36 // OperationName Root Constructors
37 
38 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
39  MLIRContext *context, ArrayRef<StringRef> generatedNames)
40  : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
41  RootKind::OperationName, generatedNames, benefit, context) {}
42 
43 //===----------------------------------------------------------------------===//
44 // MatchAnyOpTypeTag Root Constructors
45 
47  MLIRContext *context, ArrayRef<StringRef> generatedNames)
48  : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
49 
50 //===----------------------------------------------------------------------===//
51 // MatchInterfaceOpTypeTag Root Constructors
52 
54  PatternBenefit benefit, MLIRContext *context,
55  ArrayRef<StringRef> generatedNames)
56  : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
57  generatedNames, benefit, context) {}
58 
59 //===----------------------------------------------------------------------===//
60 // MatchTraitOpTypeTag Root Constructors
61 
63  PatternBenefit benefit, MLIRContext *context,
64  ArrayRef<StringRef> generatedNames)
65  : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
66  benefit, context) {}
67 
68 //===----------------------------------------------------------------------===//
69 // General Constructors
70 
71 Pattern::Pattern(const void *rootValue, RootKind rootKind,
72  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
73  MLIRContext *context)
74  : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
75  contextAndHasBoundedRecursion(context, false) {
76  if (generatedNames.empty())
77  return;
78  generatedOps.reserve(generatedNames.size());
79  std::transform(generatedNames.begin(), generatedNames.end(),
80  std::back_inserter(generatedOps), [context](StringRef name) {
81  return OperationName(name, context);
82  });
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // RewritePattern
87 //===----------------------------------------------------------------------===//
88 
90  llvm_unreachable("need to implement either matchAndRewrite or one of the "
91  "rewrite functions!");
92 }
93 
95  llvm_unreachable("need to implement either match or matchAndRewrite!");
96 }
97 
98 /// Out-of-line vtable anchor.
99 void RewritePattern::anchor() {}
100 
101 //===----------------------------------------------------------------------===//
102 // RewriterBase
103 //===----------------------------------------------------------------------===//
104 
107 }
108 
110  // Out of line to provide a vtable anchor for the class.
111 }
112 
113 /// This method replaces the results of the operation with the specified list of
114 /// values. The number of provided values must match the number of results of
115 /// the operation. The replaced op is erased.
117  assert(op->getNumResults() == newValues.size() &&
118  "incorrect # of replacement values");
119 
120  // Notify the listener that we're about to replace this op.
121  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
122  rewriteListener->notifyOperationReplaced(op, newValues);
123 
124  // Replace all result uses. Also notifies the listener of modifications.
125  replaceAllOpUsesWith(op, newValues);
126 
127  // Erase op and notify listener.
128  eraseOp(op);
129 }
130 
131 /// This method replaces the results of the operation with the specified new op
132 /// (replacement). The number of results of the two operations must match. The
133 /// replaced op is erased.
135  assert(op && newOp && "expected non-null op");
136  assert(op->getNumResults() == newOp->getNumResults() &&
137  "ops have different number of results");
138 
139  // Notify the listener that we're about to replace this op.
140  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
141  rewriteListener->notifyOperationReplaced(op, newOp);
142 
143  // Replace all result uses. Also notifies the listener of modifications.
144  replaceAllOpUsesWith(op, newOp->getResults());
145 
146  // Erase op and notify listener.
147  eraseOp(op);
148 }
149 
150 /// This method erases an operation that is known to have no uses. The uses of
151 /// the given operation *must* be known to be dead.
153  assert(op->use_empty() && "expected 'op' to have no uses");
154  auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
155 
156  // Fast path: If no listener is attached, the op can be dropped in one go.
157  if (!rewriteListener) {
158  op->erase();
159  return;
160  }
161 
162  // Helper function that erases a single op.
163  auto eraseSingleOp = [&](Operation *op) {
164 #ifndef NDEBUG
165  // All nested ops should have been erased already.
166  assert(
167  llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
168  "expected empty regions");
169  // All users should have been erased already if the op is in a region with
170  // SSA dominance.
171  if (!op->use_empty() && op->getParentOp())
173  "expected that op has no uses");
174 #endif // NDEBUG
175  rewriteListener->notifyOperationErased(op);
176 
177  // Explicitly drop all uses in case the op is in a graph region.
178  op->dropAllUses();
179  op->erase();
180  };
181 
182  // Nested ops must be erased one-by-one, so that listeners have a consistent
183  // view of the IR every time a notification is triggered. Users must be
184  // erased before definitions. I.e., post-order, reverse dominance.
185  std::function<void(Operation *)> eraseTree = [&](Operation *op) {
186  // Erase nested ops.
187  for (Region &r : llvm::reverse(op->getRegions())) {
188  // Erase all blocks in the right order. Successors should be erased
189  // before predecessors because successor blocks may use values defined
190  // in predecessor blocks. A post-order traversal of blocks within a
191  // region visits successors before predecessors. Repeat the traversal
192  // until the region is empty. (The block graph could be disconnected.)
193  while (!r.empty()) {
194  SmallVector<Block *> erasedBlocks;
195  // Some blocks may have invalid successor, use a set including nullptr
196  // to avoid null pointer.
197  llvm::SmallPtrSet<Block *, 4> visited{nullptr};
198  for (Block *b : llvm::post_order_ext(&r.front(), visited)) {
199  // Visit ops in reverse order.
200  for (Operation &op :
201  llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
202  eraseTree(&op);
203  // Do not erase the block immediately. This is not supprted by the
204  // post_order iterator.
205  erasedBlocks.push_back(b);
206  }
207  for (Block *b : erasedBlocks) {
208  // Explicitly drop all uses in case there is a cycle in the block
209  // graph.
210  for (BlockArgument bbArg : b->getArguments())
211  bbArg.dropAllUses();
212  b->dropAllUses();
213  eraseBlock(b);
214  }
215  }
216  }
217  // Then erase the enclosing op.
218  eraseSingleOp(op);
219  };
220 
221  eraseTree(op);
222 }
223 
225  assert(block->use_empty() && "expected 'block' to have no uses");
226 
227  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
228  assert(op.use_empty() && "expected 'op' to have no uses");
229  eraseOp(&op);
230  }
231 
232  // Notify the listener that the block is about to be removed.
233  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
234  rewriteListener->notifyBlockErased(block);
235 
236  block->erase();
237 }
238 
240  // Notify the listener that the operation was modified.
241  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
242  rewriteListener->notifyOperationModified(op);
243 }
244 
246  function_ref<bool(OpOperand &)> functor,
247  bool *allUsesReplaced) {
248  bool allReplaced = true;
249  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
250  bool replace = functor(operand);
251  if (replace)
252  modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
253  allReplaced &= replace;
254  }
255  if (allUsesReplaced)
256  *allUsesReplaced = allReplaced;
257 }
258 
260  function_ref<bool(OpOperand &)> functor,
261  bool *allUsesReplaced) {
262  assert(from.size() == to.size() && "incorrect number of replacements");
263  bool allReplaced = true;
264  for (auto it : llvm::zip_equal(from, to)) {
265  bool r;
266  replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
267  /*allUsesReplaced=*/&r);
268  allReplaced &= r;
269  }
270  if (allUsesReplaced)
271  *allUsesReplaced = allReplaced;
272 }
273 
275  Block::iterator before,
276  ValueRange argValues) {
277  assert(argValues.size() == source->getNumArguments() &&
278  "incorrect # of argument replacement values");
279 
280  // The source block will be deleted, so it should not have any users (i.e.,
281  // there should be no predecessors).
282  assert(source->hasNoPredecessors() &&
283  "expected 'source' to have no predecessors");
284 
285  if (dest->end() != before) {
286  // The source block will be inserted in the middle of the dest block, so
287  // the source block should have no successors. Otherwise, the remainder of
288  // the dest block would be unreachable.
289  assert(source->hasNoSuccessors() &&
290  "expected 'source' to have no successors");
291  } else {
292  // The source block will be inserted at the end of the dest block, so the
293  // dest block should have no successors. Otherwise, the inserted operations
294  // will be unreachable.
295  assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
296  }
297 
298  // Replace all of the successor arguments with the provided values.
299  for (auto it : llvm::zip(source->getArguments(), argValues))
300  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
301 
302  // Move operations from the source block to the dest block and erase the
303  // source block.
304  if (!listener) {
305  // Fast path: If no listener is attached, move all operations at once.
306  dest->getOperations().splice(before, source->getOperations());
307  } else {
308  while (!source->empty())
309  moveOpBefore(&source->front(), dest, before);
310  }
311 
312  // Erase the source block.
313  assert(source->empty() && "expected 'source' to be empty");
314  eraseBlock(source);
315 }
316 
318  ValueRange argValues) {
319  inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
320 }
321 
323  ValueRange argValues) {
324  inlineBlockBefore(source, dest, dest->end(), argValues);
325 }
326 
327 /// Split the operations starting at "before" (inclusive) out of the given
328 /// block into a new block, and return it.
330  // Fast path: If no listener is attached, split the block directly.
331  if (!listener)
332  return block->splitBlock(before);
333 
334  // `createBlock` sets the insertion point at the beginning of the new block.
335  InsertionGuard g(*this);
336  Block *newBlock =
337  createBlock(block->getParent(), std::next(block->getIterator()));
338 
339  // If `before` points to end of the block, no ops should be moved.
340  if (before == block->end())
341  return newBlock;
342 
343  // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
344  // Stop when the operation pointed to by `before` has been moved.
345  while (before->getBlock() != newBlock)
346  moveOpBefore(&block->back(), newBlock, newBlock->begin());
347 
348  return newBlock;
349 }
350 
351 /// Move the blocks that belong to "region" before the given position in
352 /// another region. The two regions must be different. The caller is in
353 /// charge of creating the operation that transfers control flow to the
354 /// region and of passing it the correct block arguments.
356  Region::iterator before) {
357  // Fast path: If no listener is attached, move all blocks at once.
358  if (!listener) {
359  parent.getBlocks().splice(before, region.getBlocks());
360  return;
361  }
362 
363  // Move blocks from the beginning of the region one-by-one.
364  while (!region.empty())
365  moveBlockBefore(&region.front(), &parent, before);
366 }
368  inlineRegionBefore(region, *before->getParent(), before->getIterator());
369 }
370 
371 void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
372  moveBlockBefore(block, anotherBlock->getParent(),
373  anotherBlock->getIterator());
374 }
375 
377  Region::iterator iterator) {
378  Region *currentRegion = block->getParent();
379  Region::iterator nextIterator = std::next(block->getIterator());
380  block->moveBefore(region, iterator);
381  if (listener)
382  listener->notifyBlockInserted(block, /*previous=*/currentRegion,
383  /*previousIt=*/nextIterator);
384 }
385 
387  moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
388 }
389 
391  Block::iterator iterator) {
392  Block *currentBlock = op->getBlock();
393  Block::iterator nextIterator = std::next(op->getIterator());
394  op->moveBefore(block, iterator);
395  if (listener)
397  op, /*previous=*/InsertPoint(currentBlock, nextIterator));
398 }
399 
401  moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
402 }
403 
405  Block::iterator iterator) {
406  assert(iterator != block->end() && "cannot move after end of block");
407  moveOpBefore(op, block, std::next(iterator));
408 }
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:137
bool empty()
Definition: Block.h:145
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition: Block.h:242
unsigned getNumArguments()
Definition: Block.h:125
Operation & back()
Definition: Block.h:149
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:65
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:307
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
OpListType & getOperations()
Definition: Block.h:134
BlockArgListType getArguments()
Definition: Block.h:84
Operation & front()
Definition: Block.h:150
iterator end()
Definition: Block.h:141
iterator begin()
Definition: Block.h:140
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specified block.
Definition: Block.cpp:53
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:239
bool use_empty() const
Returns true if this value has no uses.
Definition: UseDefLists.h:261
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a saved insertion point.
Definition: Builders.h:329
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:601
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:830
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in the same or another block in the same function.
Definition: Operation.cpp:555
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:410
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:33
bool isImpossibleToMatch() const
Definition: PatternMatch.h:43
PatternBenefit()=default
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the pattern is impossible to match, this asserts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:775
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
Definition: PatternMatch.h:72
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new operations with the given rewriter.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.
virtual void replaceOp(Operation *op, ValueRange newValues)
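Replace the results of the given (original) operation with the specified list of values (replacements). The number of provided values must match the number of results of the operation. The replaced op is erased.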
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink this block and insert it right before existingBlock.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in the same or another block in the same function.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the same or another block in the same function.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Definition: PatternMatch.h:654
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:208
Include the generated interface declarations.
bool mayBeGraphRegion(Region &region)
Return "true" if the given region may be a graph region without SSA dominance.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
@ RewriterBaseListener
RewriterBase::Listener or user-derived class.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:287
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:310
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:300
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:158
This class acts as a special tag that makes the desire to match any operation that implements a given interface explicit.
Definition: PatternMatch.h:163
This class acts as a special tag that makes the desire to match any operation that implements a given trait explicit.
Definition: PatternMatch.h:168
static constexpr auto makeIterable(RangeT &&range)
Definition: Iterators.h:32
static bool classof(const OpBuilder::Listener *base)