MLIR  18.0.0git
PatternMatch.cpp
Go to the documentation of this file.
1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/IR/IRMapping.h"
11 #include "mlir/IR/Iterators.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // PatternBenefit
18 //===----------------------------------------------------------------------===//
19 
20 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
21  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
22  "This pattern match benefit is too large to represent");
23 }
24 
25 unsigned short PatternBenefit::getBenefit() const {
26  assert(!isImpossibleToMatch() && "Pattern doesn't match");
27  return representation;
28 }
29 
30 //===----------------------------------------------------------------------===//
31 // Pattern
32 //===----------------------------------------------------------------------===//
33 
34 //===----------------------------------------------------------------------===//
35 // OperationName Root Constructors
36 
37 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
38  MLIRContext *context, ArrayRef<StringRef> generatedNames)
39  : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
40  RootKind::OperationName, generatedNames, benefit, context) {}
41 
42 //===----------------------------------------------------------------------===//
43 // MatchAnyOpTypeTag Root Constructors
44 
46  MLIRContext *context, ArrayRef<StringRef> generatedNames)
47  : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
48 
49 //===----------------------------------------------------------------------===//
50 // MatchInterfaceOpTypeTag Root Constructors
51 
53  PatternBenefit benefit, MLIRContext *context,
54  ArrayRef<StringRef> generatedNames)
55  : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
56  generatedNames, benefit, context) {}
57 
58 //===----------------------------------------------------------------------===//
59 // MatchTraitOpTypeTag Root Constructors
60 
62  PatternBenefit benefit, MLIRContext *context,
63  ArrayRef<StringRef> generatedNames)
64  : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
65  benefit, context) {}
66 
67 //===----------------------------------------------------------------------===//
68 // General Constructors
69 
70 Pattern::Pattern(const void *rootValue, RootKind rootKind,
71  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
72  MLIRContext *context)
73  : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
74  contextAndHasBoundedRecursion(context, false) {
75  if (generatedNames.empty())
76  return;
77  generatedOps.reserve(generatedNames.size());
78  std::transform(generatedNames.begin(), generatedNames.end(),
79  std::back_inserter(generatedOps), [context](StringRef name) {
80  return OperationName(name, context);
81  });
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // RewritePattern
86 //===----------------------------------------------------------------------===//
87 
89  llvm_unreachable("need to implement either matchAndRewrite or one of the "
90  "rewrite functions!");
91 }
92 
94  llvm_unreachable("need to implement either match or matchAndRewrite!");
95 }
96 
97 /// Out-of-line vtable anchor.
98 void RewritePattern::anchor() {}
99 
100 //===----------------------------------------------------------------------===//
101 // PDLValue
102 //===----------------------------------------------------------------------===//
103 
104 void PDLValue::print(raw_ostream &os) const {
105  if (!value) {
106  os << "<NULL-PDLValue>";
107  return;
108  }
109  switch (kind) {
110  case Kind::Attribute:
111  os << cast<Attribute>();
112  break;
113  case Kind::Operation:
114  os << *cast<Operation *>();
115  break;
116  case Kind::Type:
117  os << cast<Type>();
118  break;
119  case Kind::TypeRange:
120  llvm::interleaveComma(cast<TypeRange>(), os);
121  break;
122  case Kind::Value:
123  os << cast<Value>();
124  break;
125  case Kind::ValueRange:
126  llvm::interleaveComma(cast<ValueRange>(), os);
127  break;
128  }
129 }
130 
131 void PDLValue::print(raw_ostream &os, Kind kind) {
132  switch (kind) {
133  case Kind::Attribute:
134  os << "Attribute";
135  break;
136  case Kind::Operation:
137  os << "Operation";
138  break;
139  case Kind::Type:
140  os << "Type";
141  break;
142  case Kind::TypeRange:
143  os << "TypeRange";
144  break;
145  case Kind::Value:
146  os << "Value";
147  break;
148  case Kind::ValueRange:
149  os << "ValueRange";
150  break;
151  }
152 }
153 
154 //===----------------------------------------------------------------------===//
155 // PDLPatternModule
156 //===----------------------------------------------------------------------===//
157 
159  // Ignore the other module if it has no patterns.
160  if (!other.pdlModule)
161  return;
162 
163  // Steal the functions and config of the other module.
164  for (auto &it : other.constraintFunctions)
165  registerConstraintFunction(it.first(), std::move(it.second));
166  for (auto &it : other.rewriteFunctions)
167  registerRewriteFunction(it.first(), std::move(it.second));
168  for (auto &it : other.configs)
169  configs.emplace_back(std::move(it));
170  for (auto &it : other.configMap)
171  configMap.insert(it);
172 
173  // Steal the other state if we have no patterns.
174  if (!pdlModule) {
175  pdlModule = std::move(other.pdlModule);
176  return;
177  }
178 
179  // Merge the pattern operations from the other module into this one.
180  Block *block = pdlModule->getBody();
181  block->getOperations().splice(block->end(),
182  other.pdlModule->getBody()->getOperations());
183 }
184 
185 void PDLPatternModule::attachConfigToPatterns(ModuleOp module,
186  PDLPatternConfigSet &configSet) {
187  // Attach the configuration to the symbols within the module. We only add
188  // to symbols to avoid hardcoding any specific operation names here (given
189  // that we don't depend on any PDL dialect). We can't use
190  // cast<SymbolOpInterface> here because patterns may be optional symbols.
191  module->walk([&](Operation *op) {
192  if (op->hasTrait<SymbolOpInterface::Trait>())
193  configMap[op] = &configSet;
194  });
195 }
196 
197 //===----------------------------------------------------------------------===//
198 // Function Registry
199 
201  StringRef name, PDLConstraintFunction constraintFn) {
202  // TODO: Is it possible to diagnose when `name` is already registered to
203  // a function that is not equivalent to `constraintFn`?
204  // Allow existing mappings in the case multiple patterns depend on the same
205  // constraint.
206  constraintFunctions.try_emplace(name, std::move(constraintFn));
207 }
208 
210  PDLRewriteFunction rewriteFn) {
211  // TODO: Is it possible to diagnose when `name` is already registered to
212  // a function that is not equivalent to `rewriteFn`?
213  // Allow existing mappings in the case multiple patterns depend on the same
214  // rewrite.
215  rewriteFunctions.try_emplace(name, std::move(rewriteFn));
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // RewriterBase
220 //===----------------------------------------------------------------------===//
221 
224 }
225 
227  // Out of line to provide a vtable anchor for the class.
228 }
229 
230 /// This method replaces the uses of the results of `op` with the values in
231 /// `newValues` when the provided `functor` returns true for a specific use.
232 /// The number of values in `newValues` is required to match the number of
233 /// results of `op`.
235  Operation *op, ValueRange newValues, bool *allUsesReplaced,
236  llvm::unique_function<bool(OpOperand &) const> functor) {
237  assert(op->getNumResults() == newValues.size() &&
238  "incorrect number of values to replace operation");
239 
240  // Notify the listener that we're about to replace this op.
241  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
242  rewriteListener->notifyOperationReplaced(op, newValues);
243 
244  // Replace each use of the results when the functor is true.
245  bool replacedAllUses = true;
246  for (auto it : llvm::zip(op->getResults(), newValues)) {
247  replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
248  replacedAllUses &= std::get<0>(it).use_empty();
249  }
250  if (allUsesReplaced)
251  *allUsesReplaced = replacedAllUses;
252 }
253 
254 /// This method replaces the uses of the results of `op` with the values in
255 /// `newValues` when a use is nested within the given `block`. The number of
256 /// values in `newValues` is required to match the number of results of `op`.
257 /// If all uses of this operation are replaced, the operation is erased.
259  Block *block, bool *allUsesReplaced) {
260  replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
261  return block->getParentOp()->isProperAncestor(use.getOwner());
262  });
263 }
264 
265 /// This method replaces the results of the operation with the specified list of
266 /// values. The number of provided values must match the number of results of
267 /// the operation. The replaced op is erased.
269  assert(op->getNumResults() == newValues.size() &&
270  "incorrect # of replacement values");
271 
272  // Notify the listener that we're about to replace this op.
273  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
274  rewriteListener->notifyOperationReplaced(op, newValues);
275 
276  // Replace results one-by-one. Also notifies the listener of modifications.
277  for (auto it : llvm::zip(op->getResults(), newValues))
278  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
279 
280  // Erase op and notify listener.
281  eraseOp(op);
282 }
283 
284 /// This method replaces the results of the operation with the specified new op
285 /// (replacement). The number of results of the two operations must match. The
286 /// replaced op is erased.
288  assert(op && newOp && "expected non-null op");
289  assert(op->getNumResults() == newOp->getNumResults() &&
290  "ops have different number of results");
291 
292  // Notify the listener that we're about to replace this op.
293  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
294  rewriteListener->notifyOperationReplaced(op, newOp);
295 
296  // Replace results one-by-one. Also notifies the listener of modifications.
297  for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
298  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
299 
300  // Erase op and notify listener.
301  eraseOp(op);
302 }
303 
304 /// This method erases an operation that is known to have no uses. The uses of
305 /// the given operation *must* be known to be dead.
307  assert(op->use_empty() && "expected 'op' to have no uses");
308  auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
309 
310  // Fast path: If no listener is attached, the op can be dropped in one go.
311  if (!rewriteListener) {
312  op->erase();
313  return;
314  }
315 
316  // Helper function that erases a single op.
317  auto eraseSingleOp = [&](Operation *op) {
318 #ifndef NDEBUG
319  // All nested ops should have been erased already.
320  assert(
321  llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
322  "expected empty regions");
323  // All users should have been erased already if the op is in a region with
324  // SSA dominance.
325  if (!op->use_empty() && op->getParentOp())
326  assert(mayBeGraphRegion(*op->getParentRegion()) &&
327  "expected that op has no uses");
328 #endif // NDEBUG
329  rewriteListener->notifyOperationRemoved(op);
330 
331  // Explicitly drop all uses in case the op is in a graph region.
332  op->dropAllUses();
333  op->erase();
334  };
335 
336  // Nested ops must be erased one-by-one, so that listeners have a consistent
337  // view of the IR every time a notification is triggered. Users must be
338  // erased before definitions. I.e., post-order, reverse dominance.
339  std::function<void(Operation *)> eraseTree = [&](Operation *op) {
340  // Erase nested ops.
341  for (Region &r : llvm::reverse(op->getRegions())) {
342  // Erase all blocks in the right order. Successors should be erased
343  // before predecessors because successor blocks may use values defined
344  // in predecessor blocks. A post-order traversal of blocks within a
345  // region visits successors before predecessors. Repeat the traversal
346  // until the region is empty. (The block graph could be disconnected.)
347  while (!r.empty()) {
348  SmallVector<Block *> erasedBlocks;
349  for (Block *b : llvm::post_order(&r.front())) {
350  // Visit ops in reverse order.
351  for (Operation &op :
352  llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
353  eraseTree(&op);
354  // Do not erase the block immediately. This is not supported by the
355  // post_order iterator.
356  erasedBlocks.push_back(b);
357  }
358  for (Block *b : erasedBlocks) {
359  // Explicitly drop all uses in case there is a cycle in the block
360  // graph.
361  for (BlockArgument bbArg : b->getArguments())
362  bbArg.dropAllUses();
363  b->dropAllUses();
364  b->erase();
365  }
366  }
367  }
368  // Then erase the enclosing op.
369  eraseSingleOp(op);
370  };
371 
372  eraseTree(op);
373 }
374 
376  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
377  assert(op.use_empty() && "expected 'op' to have no uses");
378  eraseOp(&op);
379  }
380  block->erase();
381 }
382 
384  // Notify the listener that the operation was modified.
385  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
386  rewriteListener->notifyOperationModified(op);
387 }
388 
389 /// Find uses of `from` and replace them with `to` if the `functor` returns
390 /// true. It also marks every modified use and notifies the rewriter that an
391 /// in-place operation modification is about to happen.
393  function_ref<bool(OpOperand &)> functor) {
394  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
395  if (functor(operand))
396  updateRootInPlace(operand.getOwner(), [&]() { operand.set(to); });
397  }
398 }
399 
401  Block::iterator before,
402  ValueRange argValues) {
403  assert(argValues.size() == source->getNumArguments() &&
404  "incorrect # of argument replacement values");
405 
406  // The source block will be deleted, so it should not have any users (i.e.,
407  // there should be no predecessors).
408  assert(source->hasNoPredecessors() &&
409  "expected 'source' to have no predecessors");
410 
411  if (dest->end() != before) {
412  // The source block will be inserted in the middle of the dest block, so
413  // the source block should have no successors. Otherwise, the remainder of
414  // the dest block would be unreachable.
415  assert(source->hasNoSuccessors() &&
416  "expected 'source' to have no successors");
417  } else {
418  // The source block will be inserted at the end of the dest block, so the
419  // dest block should have no successors. Otherwise, the inserted operations
420  // will be unreachable.
421  assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
422  }
423 
424  // Replace all of the successor arguments with the provided values.
425  for (auto it : llvm::zip(source->getArguments(), argValues))
426  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
427 
428  // Move operations from the source block to the dest block and erase the
429  // source block.
430  dest->getOperations().splice(before, source->getOperations());
431  source->erase();
432 }
433 
435  ValueRange argValues) {
436  inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
437 }
438 
440  ValueRange argValues) {
441  inlineBlockBefore(source, dest, dest->end(), argValues);
442 }
443 
444 /// Split the operations starting at "before" (inclusive) out of the given
445 /// block into a new block, and return it.
447  return block->splitBlock(before);
448 }
449 
450 /// Move the blocks that belong to "region" before the given position in
451 /// another region. The two regions must be different. The caller is
452 /// responsible for creating or updating the operation transferring the
453 /// control flow to the region and passing it the correct block arguments.
455  Region::iterator before) {
456  parent.getBlocks().splice(before, region.getBlocks());
457 }
459  inlineRegionBefore(region, *before->getParent(), before->getIterator());
460 }
461 
462 /// Clone the blocks that belong to "region" before the given position in
463 /// another region "parent". The two regions must be different. The caller is
464 /// responsible for creating or updating the operation transferring flow of
465 /// control to the region and passing it the correct block arguments.
467  Region::iterator before,
468  IRMapping &mapping) {
469  region.cloneInto(&parent, before, mapping);
470 }
472  Region::iterator before) {
473  IRMapping mapping;
474  cloneRegionBefore(region, parent, before, mapping);
475 }
477  cloneRegionBefore(region, *before->getParent(), before->getIterator());
478 }
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:133
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition: Block.h:238
unsigned getNumArguments()
Definition: Block.h:121
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:60
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:302
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
OpListType & getOperations()
Definition: Block.h:130
BlockArgListType getArguments()
Definition: Block.h:80
iterator end()
Definition: Block.h:137
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:235
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:574
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:831
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:728
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:813
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:410
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a set of configurations for a specific pattern.
Definition: PatternMatch.h:979
This class contains all of the necessary data for a set of PDL patterns, or pattern rewrites specifie...
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
void mergeIn(PDLPatternModule &&other)
Merge the state in other into this pattern module.
void registerConstraintFunction(StringRef name, PDLConstraintFunction constraintFn)
Register a constraint function with PDL.
void print(raw_ostream &os) const
Print this value to the provided output stream.
Kind
The underlying kind of a PDL value.
Definition: PatternMatch.h:750
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
bool isImpossibleToMatch() const
Definition: PatternMatch.h:43
PatternBenefit()=default
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the pattern is impossible to match, this asserts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:72
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
Definition: Region.cpp:70
BlockListType & getBlocks()
Definition: Region.h:45
BlockListType::iterator iterator
Definition: Region.h:52
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
virtual void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor)
This method replaces the uses of the results of op with the values in newValues when the provided fun...
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:615
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
This method replaces the uses of the results of op with the values in newValues when a use is nested ...
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:208
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
Include the generated interface declarations.
bool mayBeGraphRegion(Region &region)
Return "true" if the given region may be a graph region without SSA dominance.
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
std::function< LogicalResult(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
@ RewriterBaseListener
RewriterBase::Listener or user-derived class.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:158
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:163
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:168
static constexpr auto makeIterable(RangeT &&range)
Definition: Iterators.h:32
static bool classof(const OpBuilder::Listener *base)