MLIR  16.0.0git
PatternMatch.cpp
Go to the documentation of this file.
1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
11 
12 using namespace mlir;
13 
14 //===----------------------------------------------------------------------===//
15 // PatternBenefit
16 //===----------------------------------------------------------------------===//
17 
18 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
19  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
20  "This pattern match benefit is too large to represent");
21 }
22 
23 unsigned short PatternBenefit::getBenefit() const {
24  assert(!isImpossibleToMatch() && "Pattern doesn't match");
25  return representation;
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // Pattern
30 //===----------------------------------------------------------------------===//
31 
32 //===----------------------------------------------------------------------===//
33 // OperationName Root Constructors
34 
35 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
36  MLIRContext *context, ArrayRef<StringRef> generatedNames)
37  : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
38  RootKind::OperationName, generatedNames, benefit, context) {}
39 
40 //===----------------------------------------------------------------------===//
41 // MatchAnyOpTypeTag Root Constructors
42 
44  MLIRContext *context, ArrayRef<StringRef> generatedNames)
45  : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
46 
47 //===----------------------------------------------------------------------===//
48 // MatchInterfaceOpTypeTag Root Constructors
49 
51  PatternBenefit benefit, MLIRContext *context,
52  ArrayRef<StringRef> generatedNames)
53  : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
54  generatedNames, benefit, context) {}
55 
56 //===----------------------------------------------------------------------===//
57 // MatchTraitOpTypeTag Root Constructors
58 
60  PatternBenefit benefit, MLIRContext *context,
61  ArrayRef<StringRef> generatedNames)
62  : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
63  benefit, context) {}
64 
65 //===----------------------------------------------------------------------===//
66 // General Constructors
67 
68 Pattern::Pattern(const void *rootValue, RootKind rootKind,
69  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
70  MLIRContext *context)
71  : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
72  contextAndHasBoundedRecursion(context, false) {
73  if (generatedNames.empty())
74  return;
75  generatedOps.reserve(generatedNames.size());
76  std::transform(generatedNames.begin(), generatedNames.end(),
77  std::back_inserter(generatedOps), [context](StringRef name) {
78  return OperationName(name, context);
79  });
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // RewritePattern
84 //===----------------------------------------------------------------------===//
85 
87  llvm_unreachable("need to implement either matchAndRewrite or one of the "
88  "rewrite functions!");
89 }
90 
92  llvm_unreachable("need to implement either match or matchAndRewrite!");
93 }
94 
95 /// Out-of-line vtable anchor.
96 void RewritePattern::anchor() {}
97 
98 //===----------------------------------------------------------------------===//
99 // PDLValue
100 //===----------------------------------------------------------------------===//
101 
102 void PDLValue::print(raw_ostream &os) const {
103  if (!value) {
104  os << "<NULL-PDLValue>";
105  return;
106  }
107  switch (kind) {
108  case Kind::Attribute:
109  os << cast<Attribute>();
110  break;
111  case Kind::Operation:
112  os << *cast<Operation *>();
113  break;
114  case Kind::Type:
115  os << cast<Type>();
116  break;
117  case Kind::TypeRange:
118  llvm::interleaveComma(cast<TypeRange>(), os);
119  break;
120  case Kind::Value:
121  os << cast<Value>();
122  break;
123  case Kind::ValueRange:
124  llvm::interleaveComma(cast<ValueRange>(), os);
125  break;
126  }
127 }
128 
129 void PDLValue::print(raw_ostream &os, Kind kind) {
130  switch (kind) {
131  case Kind::Attribute:
132  os << "Attribute";
133  break;
134  case Kind::Operation:
135  os << "Operation";
136  break;
137  case Kind::Type:
138  os << "Type";
139  break;
140  case Kind::TypeRange:
141  os << "TypeRange";
142  break;
143  case Kind::Value:
144  os << "Value";
145  break;
146  case Kind::ValueRange:
147  os << "ValueRange";
148  break;
149  }
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // PDLPatternModule
154 //===----------------------------------------------------------------------===//
155 
157  // Ignore the other module if it has no patterns.
158  if (!other.pdlModule)
159  return;
160 
161  // Steal the functions and config of the other module.
162  for (auto &it : other.constraintFunctions)
163  registerConstraintFunction(it.first(), std::move(it.second));
164  for (auto &it : other.rewriteFunctions)
165  registerRewriteFunction(it.first(), std::move(it.second));
166  for (auto &it : other.configs)
167  configs.emplace_back(std::move(it));
168  for (auto &it : other.configMap)
169  configMap.insert(it);
170 
171  // Steal the other state if we have no patterns.
172  if (!pdlModule) {
173  pdlModule = std::move(other.pdlModule);
174  return;
175  }
176 
177  // Merge the pattern operations from the other module into this one.
178  Block *block = pdlModule->getBody();
179  block->getOperations().splice(block->end(),
180  other.pdlModule->getBody()->getOperations());
181 }
182 
183 void PDLPatternModule::attachConfigToPatterns(ModuleOp module,
184  PDLPatternConfigSet &configSet) {
185  // Attach the configuration to the symbols within the module. We only add
186  // to symbols to avoid hardcoding any specific operation names here (given
187  // that we don't depend on any PDL dialect). We can't use
188  // cast<SymbolOpInterface> here because patterns may be optional symbols.
189  module->walk([&](Operation *op) {
190  if (op->hasTrait<SymbolOpInterface::Trait>())
191  configMap[op] = &configSet;
192  });
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // Function Registry
197 
199  StringRef name, PDLConstraintFunction constraintFn) {
200  // TODO: Is it possible to diagnose when `name` is already registered to
201  // a function that is not equivalent to `constraintFn`?
202  // Allow existing mappings in the case multiple patterns depend on the same
203  // constraint.
204  constraintFunctions.try_emplace(name, std::move(constraintFn));
205 }
206 
208  PDLRewriteFunction rewriteFn) {
209  // TODO: Is it possible to diagnose when `name` is already registered to
210  // a function that is not equivalent to `rewriteFn`?
211  // Allow existing mappings in the case multiple patterns depend on the same
212  // rewrite.
213  rewriteFunctions.try_emplace(name, std::move(rewriteFn));
214 }
215 
216 //===----------------------------------------------------------------------===//
217 // RewriterBase
218 //===----------------------------------------------------------------------===//
219 
221  // Out of line to provide a vtable anchor for the class.
222 }
223 
224 /// This method replaces the uses of the results of `op` with the values in
225 /// `newValues` when the provided `functor` returns true for a specific use.
226 /// The number of values in `newValues` is required to match the number of
227 /// results of `op`.
229  Operation *op, ValueRange newValues, bool *allUsesReplaced,
230  llvm::unique_function<bool(OpOperand &) const> functor) {
231  assert(op->getNumResults() == newValues.size() &&
232  "incorrect number of values to replace operation");
233 
234  // Notify the rewriter subclass that we're about to replace this root.
235  notifyRootReplaced(op, newValues);
236 
237  // Replace each use of the results when the functor is true.
238  bool replacedAllUses = true;
239  for (auto it : llvm::zip(op->getResults(), newValues)) {
240  std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor);
241  replacedAllUses &= std::get<0>(it).use_empty();
242  }
243  if (allUsesReplaced)
244  *allUsesReplaced = replacedAllUses;
245 }
246 
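/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when the user of a use is properly nested within the parent
/// operation of the given `block`. This also matches uses in other blocks or
/// regions of that parent operation, not only in `block`. The number of
/// values in `newValues` is required to match the number of results of `op`.
/// `op` itself is not erased, even if all of its uses are replaced. When
/// `allUsesReplaced` is non-null, it is set to whether every use was replaced.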
252  Block *block, bool *allUsesReplaced) {
253  replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
254  return block->getParentOp()->isProperAncestor(use.getOwner());
255  });
256 }
257 
258 /// This method replaces the results of the operation with the specified list of
259 /// values. The number of provided values must match the number of results of
260 /// the operation.
262  // Notify the rewriter subclass that we're about to replace this root.
263  notifyRootReplaced(op, newValues);
264 
265  assert(op->getNumResults() == newValues.size() &&
266  "incorrect # of replacement values");
267  op->replaceAllUsesWith(newValues);
268 
270  op->erase();
271 }
272 
273 /// This method erases an operation that is known to have no uses. The uses of
274 /// the given operation *must* be known to be dead.
276  assert(op->use_empty() && "expected 'op' to have no uses");
278  op->erase();
279 }
280 
282  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
283  assert(op.use_empty() && "expected 'op' to have no uses");
284  eraseOp(&op);
285  }
286  block->erase();
287 }
288 
/// Merge the operations of block `source` into the end of block `dest`.
/// `source`'s predecessors must be empty or only contain `dest`.
/// All uses of `source`'s block arguments are replaced with `argValues`
/// before the operations are moved; `source` is then erased.
294  ValueRange argValues) {
295  assert(llvm::all_of(source->getPredecessors(),
296  [dest](Block *succ) { return succ == dest; }) &&
297  "expected 'source' to have no predecessors or only 'dest'");
298  assert(argValues.size() == source->getNumArguments() &&
299  "incorrect # of argument replacement values");
300 
  // Replace all uses of the source block's arguments with the provided values.
302  for (auto it : llvm::zip(source->getArguments(), argValues))
303  std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
304 
305  // Splice the operations of the 'source' block into the 'dest' block and erase
306  // it.
307  dest->getOperations().splice(dest->end(), source->getOperations());
308  source->dropAllUses();
309  source->erase();
310 }
311 
/// Find uses of `from` and replace them with `to`.
314  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
315  Operation *op = operand.getOwner();
316  updateRootInPlace(op, [&]() { operand.set(to); });
317  }
318 }
319 
/// Find uses of `from` and replace them with `to`, except where the user is
/// `exceptedUser`. Each modified use is updated through updateRootInPlace, so
/// the rewriter is notified of every in-place operation modification.
324  Operation *exceptedUser) {
325  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
326  Operation *user = operand.getOwner();
327  if (user != exceptedUser)
328  updateRootInPlace(user, [&]() { operand.set(to); });
329  }
330 }
331 
332 // Merge the operations of block 'source' before the operation 'op'. Source
333 // block should not have existing predecessors or successors.
335  ValueRange argValues) {
336  assert(source->hasNoPredecessors() &&
337  "expected 'source' to have no predecessors");
338  assert(source->hasNoSuccessors() &&
339  "expected 'source' to have no successors");
340 
341  // Split the block containing 'op' into two, one containing all operations
342  // before 'op' (prologue) and another (epilogue) containing 'op' and all
343  // operations after it.
344  Block *prologue = op->getBlock();
345  Block *epilogue = splitBlock(prologue, op->getIterator());
346 
347  // Merge the source block at the end of the prologue.
348  mergeBlocks(source, prologue, argValues);
349 
  // Merge the epilogue at the end of the prologue.
351  mergeBlocks(epilogue, prologue);
352 }
353 
354 /// Split the operations starting at "before" (inclusive) out of the given
355 /// block into a new block, and return it.
357  return block->splitBlock(before);
358 }
359 
360 /// 'op' and 'newOp' are known to have the same number of results, replace the
361 /// uses of op with uses of newOp
362 void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op,
363  Operation *newOp) {
364  assert(op->getNumResults() == newOp->getNumResults() &&
365  "replacement op doesn't match results of original op");
366  if (op->getNumResults() == 1)
367  return replaceOp(op, newOp->getResult(0));
368  return replaceOp(op, newOp->getResults());
369 }
370 
/// Move the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller is
/// responsible for creating or updating the operation that transfers control
/// flow to the region and for passing it the correct block arguments.
376  Region::iterator before) {
377  parent.getBlocks().splice(before, region.getBlocks());
378 }
380  inlineRegionBefore(region, *before->getParent(), before->getIterator());
381 }
382 
383 /// Clone the blocks that belong to "region" before the given position in
384 /// another region "parent". The two regions must be different. The caller is
385 /// responsible for creating or updating the operation transferring flow of
386 /// control to the region and passing it the correct block arguments.
388  Region::iterator before,
389  BlockAndValueMapping &mapping) {
390  region.cloneInto(&parent, before, mapping);
391 }
393  Region::iterator before) {
394  BlockAndValueMapping mapping;
395  cloneRegionBefore(region, parent, before, mapping);
396 }
398  cloneRegionBefore(region, *before->getParent(), before->getIterator());
399 }
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:129
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition: Block.h:231
unsigned getNumArguments()
Definition: Block.h:117
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:54
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:291
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:223
OpListType & getOperations()
Definition: Block.h:126
BlockArgListType getArguments()
Definition: Block.h:76
iterator end()
Definition: Block.h:133
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:228
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
void dropAllUses()
Drop all uses of this object from their respective owners.
Definition: UseDefLists.h:179
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
This class represents an operand of an operation.
Definition: Value.h:247
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:633
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:532
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:144
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:203
result_range getResults()
Definition: Operation.h:332
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:418
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
This class contains a set of configurations for a specific pattern.
Definition: PatternMatch.h:862
This class contains all of the necessary data for a set of PDL patterns, or pattern rewrites specifie...
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
void mergeIn(PDLPatternModule &&other)
Merge the state in other into this pattern module.
void registerConstraintFunction(StringRef name, PDLConstraintFunction constraintFn)
Register a constraint function with PDL.
void print(raw_ostream &os) const
Print this value to the provided output stream.
Kind
The underlying kind of a PDL value.
Definition: PatternMatch.h:633
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
bool isImpossibleToMatch() const
Definition: PatternMatch.h:42
PatternBenefit()=default
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the pattern can never match (isImpossibleToMatch()), this asserts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:610
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:71
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
void cloneInto(Region *dest, BlockAndValueMapping &mapper)
Clone the internal blocks from this region into dest.
Definition: Region.cpp:70
BlockListType::iterator iterator
Definition: Region.h:52
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
virtual void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor)
This method replaces the uses of the results of op with the values in newValues when the provided fun...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void notifyRootReplaced(Operation *op, ValueRange replacement)
These are the callback methods that subclasses can choose to implement if they would like to be notif...
Definition: PatternMatch.h:561
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, BlockAndValueMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:499
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Merge the operations of block 'source' into the end of block 'dest'.
virtual void notifyOperationRemoved(Operation *op)
This is called on an operation that a rewrite is removing, right before the operation is deleted.
Definition: PatternMatch.h:565
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
This method replaces the uses of the results of op with the values in newValues when a use is nested ...
void mergeBlockBefore(Block *source, Operation *op, ValueRange argValues=std::nullopt)
~RewriterBase() override
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:193
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:40
Include the generated interface declarations.
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
Definition: PatternMatch.h:924
std::function< LogicalResult(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
Definition: PatternMatch.h:932
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:157
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:162
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:167