MLIR 22.0.0git
PatternMatch.cpp
Go to the documentation of this file.
1//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Iterators.h"
12#include "llvm/ADT/SmallPtrSet.h"
13
14using namespace mlir;
15
16//===----------------------------------------------------------------------===//
17// PatternBenefit
18//===----------------------------------------------------------------------===//
19
20PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
21 assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
22 "This pattern match benefit is too large to represent");
23}
24
25unsigned short PatternBenefit::getBenefit() const {
26 assert(!isImpossibleToMatch() && "Pattern doesn't match");
27 return representation;
28}
29
30//===----------------------------------------------------------------------===//
31// Pattern
32//===----------------------------------------------------------------------===//
33
34//===----------------------------------------------------------------------===//
35// OperationName Root Constructors
36//===----------------------------------------------------------------------===//
37
38Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
39 MLIRContext *context, ArrayRef<StringRef> generatedNames)
40 : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
41 RootKind::OperationName, generatedNames, benefit, context) {}
42
43//===----------------------------------------------------------------------===//
44// MatchAnyOpTypeTag Root Constructors
45//===----------------------------------------------------------------------===//
46
48 MLIRContext *context, ArrayRef<StringRef> generatedNames)
49 : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
50
51//===----------------------------------------------------------------------===//
52// MatchInterfaceOpTypeTag Root Constructors
53//===----------------------------------------------------------------------===//
54
56 PatternBenefit benefit, MLIRContext *context,
57 ArrayRef<StringRef> generatedNames)
58 : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
59 generatedNames, benefit, context) {}
60
61//===----------------------------------------------------------------------===//
62// MatchTraitOpTypeTag Root Constructors
63//===----------------------------------------------------------------------===//
64
66 PatternBenefit benefit, MLIRContext *context,
67 ArrayRef<StringRef> generatedNames)
68 : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
69 benefit, context) {}
70
71//===----------------------------------------------------------------------===//
72// General Constructors
73//===----------------------------------------------------------------------===//
74
75Pattern::Pattern(const void *rootValue, RootKind rootKind,
76 ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
77 MLIRContext *context)
78 : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
79 contextAndHasBoundedRecursion(context, false) {
80 if (generatedNames.empty())
81 return;
82 generatedOps.reserve(generatedNames.size());
83 llvm::append_range(generatedOps,
84 llvm::map_range(generatedNames, [context](StringRef name) {
85 return OperationName(name, context);
86 }));
87}
88
89//===----------------------------------------------------------------------===//
90// RewritePattern
91//===----------------------------------------------------------------------===//
92
93/// Out-of-line vtable anchor.
94void RewritePattern::anchor() {}
95
96//===----------------------------------------------------------------------===//
97// RewriterBase
98//===----------------------------------------------------------------------===//
99
103
105 // Out of line to provide a vtable anchor for the class.
106}
107
109 // Notify the listener that we're about to replace this op.
110 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
111 rewriteListener->notifyOperationReplaced(from, to);
112
113 replaceAllUsesWith(from->getResults(), to);
114}
115
117 // Notify the listener that we're about to replace this op.
118 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
119 rewriteListener->notifyOperationReplaced(from, to);
120
122}
123
124/// This method replaces the results of the operation with the specified list of
125/// values. The number of provided values must match the number of results of
126/// the operation. The replaced op is erased.
128 assert(op->getNumResults() == newValues.size() &&
129 "incorrect # of replacement values");
130
131 // Replace all result uses. Also notifies the listener of modifications.
132 replaceAllOpUsesWith(op, newValues);
133
134 // Erase op and notify listener.
135 eraseOp(op);
136}
137
138/// This method replaces the results of the operation with the specified new op
139/// (replacement). The number of results of the two operations must match. The
140/// replaced op is erased.
142 assert(op && newOp && "expected non-null op");
143 assert(op->getNumResults() == newOp->getNumResults() &&
144 "ops have different number of results");
145
146 // Replace all result uses. Also notifies the listener of modifications.
147 replaceAllOpUsesWith(op, newOp->getResults());
148
149 // Erase op and notify listener.
150 eraseOp(op);
151}
152
153/// This method erases an operation that is known to have no uses. The uses of
154/// the given operation *must* be known to be dead.
156 assert(op->use_empty() && "expected 'op' to have no uses");
157 auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
158
159 // If the current insertion point is before the erased operation, we adjust
160 // the insertion point to be after the operation.
161 if (getInsertionPoint() == op->getIterator())
163
164 // Fast path: If no listener is attached, the op can be dropped in one go.
165 if (!rewriteListener) {
166 op->erase();
167 return;
168 }
169
170 // Helper function that erases a single op.
171 auto eraseSingleOp = [&](Operation *op) {
172#ifndef NDEBUG
173 // All nested ops should have been erased already.
174 assert(
175 llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
176 "expected empty regions");
177 // All users should have been erased already if the op is in a region with
178 // SSA dominance.
179 if (!op->use_empty() && op->getParentOp())
181 "expected that op has no uses");
182#endif // NDEBUG
183 rewriteListener->notifyOperationErased(op);
184
185 // Explicitly drop all uses in case the op is in a graph region.
186 op->dropAllUses();
187 op->erase();
188 };
189
190 // Nested ops must be erased one-by-one, so that listeners have a consistent
191 // view of the IR every time a notification is triggered. Users must be
192 // erased before definitions. I.e., post-order, reverse dominance.
193 std::function<void(Operation *)> eraseTree = [&](Operation *op) {
194 // Erase nested ops.
195 for (Region &r : llvm::reverse(op->getRegions())) {
196 // Erase all blocks in the right order. Successors should be erased
197 // before predecessors because successor blocks may use values defined
198 // in predecessor blocks. A post-order traversal of blocks within a
199 // region visits successors before predecessors. Repeat the traversal
200 // until the region is empty. (The block graph could be disconnected.)
201 while (!r.empty()) {
202 SmallVector<Block *> erasedBlocks;
203 // Some blocks may have an invalid (null) successor. Seed the visited set
204 // with nullptr so the traversal never follows a null block pointer.
205 llvm::SmallPtrSet<Block *, 4> visited{nullptr};
206 for (Block *b : llvm::post_order_ext(&r.front(), visited)) {
207 // Visit ops in reverse order.
208 for (Operation &op :
209 llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
210 eraseTree(&op);
211 // Do not erase the block immediately. This is not supported by the
212 // post_order iterator.
213 erasedBlocks.push_back(b);
214 }
215 for (Block *b : erasedBlocks) {
216 // Explicitly drop all uses in case there is a cycle in the block
217 // graph.
218 for (BlockArgument bbArg : b->getArguments())
219 bbArg.dropAllUses();
220 b->dropAllUses();
221 eraseBlock(b);
222 }
223 }
224 }
225 // Then erase the enclosing op.
226 eraseSingleOp(op);
227 };
228
229 eraseTree(op);
230}
231
233 assert(block->use_empty() && "expected 'block' to have no uses");
234
235 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
236 assert(op.use_empty() && "expected 'op' to have no uses");
237 eraseOp(&op);
238 }
239
240 // Notify the listener that the block is about to be removed.
241 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
242 rewriteListener->notifyBlockErased(block);
243
244 block->erase();
245}
246
248 const BitVector &eraseIndices) {
249 assert(op->getNumResults() == eraseIndices.size() &&
250 "number of op results and bitvector size must match");
251
252 // Gather new result types.
253 SmallVector<Type> newResultTypes;
254 newResultTypes.reserve(op->getNumResults() - eraseIndices.count());
255 for (OpResult result : op->getResults())
256 if (!eraseIndices[result.getResultNumber()])
257 newResultTypes.push_back(result.getType());
258
259 // Create a new operation and inline all regions.
260 InsertionGuard g(*this);
262 OperationState state(op->getLoc(), op->getName().getStringRef(),
263 op->getOperands(), newResultTypes, op->getAttrs());
264 for ([[maybe_unused]] auto i : llvm::seq<unsigned>(0, op->getNumRegions()))
265 state.addRegion();
266 Operation *newOp = create(state);
267 for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
268 // Move all blocks of `region` into `newRegion`.
269 Region &newRegion = newOp->getRegion(index);
270 inlineRegionBefore(region, newRegion, newRegion.begin());
271 }
272
273 // Replace the original operation with the new operation.
274 SmallVector<Value> replacements(op->getNumResults(), Value());
275 unsigned nextResultIdx = 0;
276 for (auto i : llvm::seq<unsigned>(0, op->getNumResults()))
277 if (!eraseIndices[i])
278 replacements[i] = newOp->getResult(nextResultIdx++);
279 replaceOp(op, replacements);
280 return newOp;
281}
282
284 // Notify the listener that the operation was modified.
285 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
286 rewriteListener->notifyOperationModified(op);
287}
288
290 Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
291 return replaceUsesWithIf(from, to, [&](OpOperand &use) {
292 Operation *user = use.getOwner();
293 return !preservedUsers.contains(user);
294 });
295}
296
298 function_ref<bool(OpOperand &)> functor,
299 bool *allUsesReplaced) {
300 bool allReplaced = true;
301 for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
302 bool replace = functor(operand);
303 if (replace)
304 modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
305 allReplaced &= replace;
306 }
307 if (allUsesReplaced)
308 *allUsesReplaced = allReplaced;
309}
310
312 function_ref<bool(OpOperand &)> functor,
313 bool *allUsesReplaced) {
314 assert(from.size() == to.size() && "incorrect number of replacements");
315 bool allReplaced = true;
316 for (auto it : llvm::zip_equal(from, to)) {
317 bool r = true;
318 replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
319 /*allUsesReplaced=*/allUsesReplaced ? &r : nullptr);
320 allReplaced &= r;
321 }
322 if (allUsesReplaced)
323 *allUsesReplaced = allReplaced;
324}
325
327 Block::iterator before,
328 ValueRange argValues) {
329 assert(argValues.size() == source->getNumArguments() &&
330 "incorrect # of argument replacement values");
331
332 // The source block will be deleted, so it should not have any users (i.e.,
333 // there should be no predecessors).
334 assert(source->hasNoPredecessors() &&
335 "expected 'source' to have no predecessors");
336
337 if (dest->end() != before) {
338 // The source block will be inserted in the middle of the dest block, so
339 // the source block should have no successors. Otherwise, the remainder of
340 // the dest block would be unreachable.
341 assert(source->hasNoSuccessors() &&
342 "expected 'source' to have no successors");
343 } else {
344 // The source block will be inserted at the end of the dest block, so the
345 // dest block should have no successors. Otherwise, the inserted operations
346 // will be unreachable.
347 assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
348 }
349
350 // Replace all uses of the source block's arguments with the provided values.
351 for (auto it : llvm::zip(source->getArguments(), argValues))
352 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
353
354 // Move operations from the source block to the dest block and erase the
355 // source block.
356 if (!listener) {
357 // Fast path: If no listener is attached, move all operations at once.
358 dest->getOperations().splice(before, source->getOperations());
359 } else {
360 while (!source->empty())
361 moveOpBefore(&source->front(), dest, before);
362 }
363
364 // If the current insertion point is within the source block, adjust the
365 // insertion point to the destination block.
366 if (getInsertionBlock() == source)
368
369 // Erase the source block.
370 assert(source->empty() && "expected 'source' to be empty");
371 eraseBlock(source);
372}
373
375 ValueRange argValues) {
376 inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
377}
378
380 ValueRange argValues) {
381 inlineBlockBefore(source, dest, dest->end(), argValues);
382}
383
384/// Split the operations starting at "before" (inclusive) out of the given
385/// block into a new block, and return it.
387 // Fast path: If no listener is attached, split the block directly.
388 if (!listener)
389 return block->splitBlock(before);
390
391 // `createBlock` sets the insertion point at the beginning of the new block.
392 InsertionGuard g(*this);
393 Block *newBlock =
394 createBlock(block->getParent(), std::next(block->getIterator()));
395
396 // If `before` points to end of the block, no ops should be moved.
397 if (before == block->end())
398 return newBlock;
399
400 // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
401 // Stop when the operation pointed to by `before` has been moved.
402 while (before->getBlock() != newBlock)
403 moveOpBefore(&block->back(), newBlock, newBlock->begin());
404
405 return newBlock;
406}
407
408/// Move the blocks that belong to "region" before the given position in
409/// another region. The two regions must be different. The caller is in
410/// charge of creating or updating the operation that transfers control flow
411/// to the region and of passing it the correct block arguments.
413 Region::iterator before) {
414 // Fast path: If no listener is attached, move all blocks at once.
415 if (!listener) {
416 parent.getBlocks().splice(before, region.getBlocks());
417 return;
418 }
419
420 // Move blocks from the beginning of the region one-by-one.
421 while (!region.empty())
422 moveBlockBefore(&region.front(), &parent, before);
423}
425 inlineRegionBefore(region, *before->getParent(), before->getIterator());
426}
427
428void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
429 moveBlockBefore(block, anotherBlock->getParent(),
430 anotherBlock->getIterator());
431}
432
434 Region::iterator iterator) {
435 Region *currentRegion = block->getParent();
436 Region::iterator nextIterator = std::next(block->getIterator());
437 block->moveBefore(region, iterator);
438 if (listener)
439 listener->notifyBlockInserted(block, /*previous=*/currentRegion,
440 /*previousIt=*/nextIterator);
441}
442
444 moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
445}
446
448 Block::iterator iterator) {
449 Block *currentBlock = op->getBlock();
450 Block::iterator nextIterator = std::next(op->getIterator());
451 op->moveBefore(block, iterator);
452 if (listener)
453 listener->notifyOperationInserted(
454 op, /*previous=*/InsertPoint(currentBlock, nextIterator));
455}
456
458 moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
459}
460
462 Block::iterator iterator) {
463 assert(iterator != block->end() && "cannot move after end of block");
464 moveOpBefore(op, block, std::next(iterator));
465}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
false
Parses a map_entries map type from a string format back into its numeric value.
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
bool empty()
Definition Block.h:158
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition Block.h:258
unsigned getNumArguments()
Definition Block.h:138
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition Block.cpp:323
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
BlockArgListType getArguments()
Definition Block.h:97
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition Block.h:255
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a saved insertion point.
Definition Builders.h:327
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition Builders.h:445
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:442
Listener * listener
The optional listener for events of this builder.
Definition Builders.h:617
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:457
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:852
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:834
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
Definition Operation.h:415
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
bool isImpossibleToMatch() const
PatternBenefit()=default
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
iterator begin()
Definition Region.h:55
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink this block and insert it right before existingBlock.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
Include the generated interface declarations.
bool mayBeGraphRegion(Region &region)
Return "true" if the given region may be a graph region without SSA dominance.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
@ RewriterBaseListener
RewriterBase::Listener or user-derived class.
Definition Builders.h:271
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition Builders.h:285
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.
This class acts as a special tag that makes the desire to match "any" operation type explicit.
This class acts as a special tag that makes the desire to match any operation that implements a given...
This class acts as a special tag that makes the desire to match any operation that implements a given...
static constexpr auto makeIterable(RangeT &&range)
Definition Iterators.h:32
static bool classof(const OpBuilder::Listener *base)