MLIR  14.0.0git
PatternMatch.cpp
Go to the documentation of this file.
1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
11 
12 using namespace mlir;
13 
14 //===----------------------------------------------------------------------===//
15 // PatternBenefit
16 //===----------------------------------------------------------------------===//
17 
/// Construct a benefit from the integer `benefit`, storing it in
/// `representation`.
///
/// The assertion checks two things: that the stored value compares equal to
/// the argument (i.e. nothing was lost when storing it), and that the argument
/// is not `ImpossibleToMatchSentinel`, which is reserved.
/// NOTE(review): `representation` is presumably narrower than `unsigned`
/// (getBenefit() returns `unsigned short`) — confirm against the header.
18 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
19  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
20  "This pattern match benefit is too large to represent");
21 }
22 
/// Return the stored benefit value.
///
/// Must only be called on a benefit that can match: the assertion fires when
/// isImpossibleToMatch() is true.
23 unsigned short PatternBenefit::getBenefit() const {
24  assert(!isImpossibleToMatch() && "Pattern doesn't match");
25  return representation;
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // Pattern
30 //===----------------------------------------------------------------------===//
31 
32 //===----------------------------------------------------------------------===//
33 // OperationName Root Constructors
34 
/// Construct a pattern rooted on the operation named `rootName`.
///
/// Builds an `OperationName` from `rootName` and `context`, and forwards its
/// opaque pointer to the general constructor with
/// `RootKind::OperationName`. `generatedNames`, `benefit` and `context` are
/// passed through unchanged.
35 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
36  MLIRContext *context, ArrayRef<StringRef> generatedNames)
37  : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
38  RootKind::OperationName, generatedNames, benefit, context) {}
39 
40 //===----------------------------------------------------------------------===//
41 // MatchAnyOpTypeTag Root Constructors
42 
44  MLIRContext *context, ArrayRef<StringRef> generatedNames)
45  : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
46 
47 //===----------------------------------------------------------------------===//
48 // MatchInterfaceOpTypeTag Root Constructors
49 
51  PatternBenefit benefit, MLIRContext *context,
52  ArrayRef<StringRef> generatedNames)
53  : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
54  generatedNames, benefit, context) {}
55 
56 //===----------------------------------------------------------------------===//
57 // MatchTraitOpTypeTag Root Constructors
58 
60  PatternBenefit benefit, MLIRContext *context,
61  ArrayRef<StringRef> generatedNames)
62  : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
63  benefit, context) {}
64 
65 //===----------------------------------------------------------------------===//
66 // General Constructors
67 
/// General constructor that the other `Pattern` constructors delegate to.
///
/// Stores `rootValue`, `rootKind` and `benefit`, and initializes
/// `contextAndHasBoundedRecursion` with `context` and `false`. Each entry of
/// `generatedNames` is converted to an `OperationName` in `context` and
/// appended to `generatedOps`.
68 Pattern::Pattern(const void *rootValue, RootKind rootKind,
69  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
70  MLIRContext *context)
71  : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
72  contextAndHasBoundedRecursion(context, false) {
// Nothing to record when no generated operation names were supplied.
73  if (generatedNames.empty())
74  return;
// Reserve up front so the appends below need at most one allocation.
75  generatedOps.reserve(generatedNames.size());
76  std::transform(generatedNames.begin(), generatedNames.end(),
77  std::back_inserter(generatedOps), [context](StringRef name) {
78  return OperationName(name, context);
79  });
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // RewritePattern
84 //===----------------------------------------------------------------------===//
85 
87  llvm_unreachable("need to implement either matchAndRewrite or one of the "
88  "rewrite functions!");
89 }
90 
92  llvm_unreachable("need to implement either match or matchAndRewrite!");
93 }
94 
95 /// Out-of-line vtable anchor.
/// The body is intentionally empty; the function exists only so that it has
/// a definition in this file.
96 void RewritePattern::anchor() {}
97 
98 //===----------------------------------------------------------------------===//
99 // PDLValue
100 //===----------------------------------------------------------------------===//
101 
/// Print the held value to `os`.
///
/// A null `value` is printed as the literal text "<NULL-PDLValue>".
/// Otherwise the value is cast to the type selected by `kind` and streamed;
/// the two range kinds are printed as comma-separated lists.
102 void PDLValue::print(raw_ostream &os) const {
// Handle the empty case first: none of the casts below is attempted on a
// null value.
103  if (!value) {
104  os << "<NULL-PDLValue>";
105  return;
106  }
107  switch (kind) {
108  case Kind::Attribute:
109  os << cast<Attribute>();
110  break;
111  case Kind::Operation:
// Dereference so the operation itself, not the pointer, is streamed.
112  os << *cast<Operation *>();
113  break;
114  case Kind::Type:
115  os << cast<Type>();
116  break;
117  case Kind::TypeRange:
118  llvm::interleaveComma(cast<TypeRange>(), os);
119  break;
120  case Kind::Value:
121  os << cast<Value>();
122  break;
123  case Kind::ValueRange:
124  llvm::interleaveComma(cast<ValueRange>(), os);
125  break;
126  }
127 }
128 
/// Print the name of `kind` to `os`.
///
/// Each enumerator handled below is written as its own identifier, e.g.
/// `Kind::TypeRange` prints "TypeRange". No value is inspected; this overload
/// takes only the kind.
129 void PDLValue::print(raw_ostream &os, Kind kind) {
130  switch (kind) {
131  case Kind::Attribute:
132  os << "Attribute";
133  break;
134  case Kind::Operation:
135  os << "Operation";
136  break;
137  case Kind::Type:
138  os << "Type";
139  break;
140  case Kind::TypeRange:
141  os << "TypeRange";
142  break;
143  case Kind::Value:
144  os << "Value";
145  break;
146  case Kind::ValueRange:
147  os << "ValueRange";
148  break;
149  }
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // PDLPatternModule
154 //===----------------------------------------------------------------------===//
155 
157  // Ignore the other module if it has no patterns.
158  if (!other.pdlModule)
159  return;
160 
161  // Steal the functions of the other module.
162  for (auto &it : other.constraintFunctions)
163  registerConstraintFunction(it.first(), std::move(it.second));
164  for (auto &it : other.rewriteFunctions)
165  registerRewriteFunction(it.first(), std::move(it.second));
166 
167  // Steal the other state if we have no patterns.
168  if (!pdlModule) {
169  pdlModule = std::move(other.pdlModule);
170  return;
171  }
172 
173  // Merge the pattern operations from the other module into this one.
174  Block *block = pdlModule->getBody();
175  block->getOperations().splice(block->end(),
176  other.pdlModule->getBody()->getOperations());
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // Function Registry
181 
183  StringRef name, PDLConstraintFunction constraintFn) {
184  // TODO: Is it possible to diagnose when `name` is already registered to
185  // a function that is not equivalent to `constraintFn`?
186  // Allow existing mappings in the case multiple patterns depend on the same
187  // constraint.
188  constraintFunctions.try_emplace(name, std::move(constraintFn));
189 }
190 
192  PDLRewriteFunction rewriteFn) {
193  // TODO: Is it possible to diagnose when `name` is already registered to
194  // a function that is not equivalent to `rewriteFn`?
195  // Allow existing mappings in the case multiple patterns depend on the same
196  // rewrite.
197  rewriteFunctions.try_emplace(name, std::move(rewriteFn));
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // RewriterBase
202 //===----------------------------------------------------------------------===//
203 
205  // Out of line to provide a vtable anchor for the class.
206 }
207 
208 /// This method replaces the uses of the results of `op` with the values in
209 /// `newValues` when the provided `functor` returns true for a specific use.
210 /// The number of values in `newValues` is required to match the number of
211 /// results of `op`.
213  Operation *op, ValueRange newValues, bool *allUsesReplaced,
214  llvm::unique_function<bool(OpOperand &) const> functor) {
215  assert(op->getNumResults() == newValues.size() &&
216  "incorrect number of values to replace operation");
217 
218  // Notify the rewriter subclass that we're about to replace this root.
219  notifyRootReplaced(op);
220 
221  // Replace each use of the results when the functor is true.
222  bool replacedAllUses = true;
223  for (auto it : llvm::zip(op->getResults(), newValues)) {
224  std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor);
225  replacedAllUses &= std::get<0>(it).use_empty();
226  }
227  if (allUsesReplaced)
228  *allUsesReplaced = replacedAllUses;
229 }
230 
231 /// This method replaces the uses of the results of `op` with the values in
232 /// `newValues` when a use is nested within the given `block`. The number of
233 /// values in `newValues` is required to match the number of results of `op`.
234 /// If all uses of this operation are replaced, the operation is erased.
236  Block *block, bool *allUsesReplaced) {
237  replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
238  return block->getParentOp()->isProperAncestor(use.getOwner());
239  });
240 }
241 
242 /// This method replaces the results of the operation with the specified list of
243 /// values. The number of provided values must match the number of results of
244 /// the operation.
246  // Notify the rewriter subclass that we're about to replace this root.
247  notifyRootReplaced(op);
248 
249  assert(op->getNumResults() == newValues.size() &&
250  "incorrect # of replacement values");
251  op->replaceAllUsesWith(newValues);
252 
253  notifyOperationRemoved(op);
254  op->erase();
255 }
256 
257 /// This method erases an operation that is known to have no uses. The uses of
258 /// the given operation *must* be known to be dead.
260  assert(op->use_empty() && "expected 'op' to have no uses");
261  notifyOperationRemoved(op);
262  op->erase();
263 }
264 
266  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
267  assert(op.use_empty() && "expected 'op' to have no uses");
268  eraseOp(&op);
269  }
270  block->erase();
271 }
272 
273 /// Merge the operations of block 'source' into the end of block 'dest'.
274 /// 'source's predecessors must be empty or only contain 'dest'.
275 /// 'argValues' is used to replace the block arguments of 'source' after
276 /// merging.
278  ValueRange argValues) {
279  assert(llvm::all_of(source->getPredecessors(),
280  [dest](Block *succ) { return succ == dest; }) &&
281  "expected 'source' to have no predecessors or only 'dest'");
282  assert(argValues.size() == source->getNumArguments() &&
283  "incorrect # of argument replacement values");
284 
285  // Replace all of the successor arguments with the provided values.
286  for (auto it : llvm::zip(source->getArguments(), argValues))
287  std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
288 
289  // Splice the operations of the 'source' block into the 'dest' block and erase
290  // it.
291  dest->getOperations().splice(dest->end(), source->getOperations());
292  source->dropAllUses();
293  source->erase();
294 }
295 
296 // Merge the operations of block 'source' before the operation 'op'. Source
297 // block should not have existing predecessors or successors.
299  ValueRange argValues) {
300  assert(source->hasNoPredecessors() &&
301  "expected 'source' to have no predecessors");
302  assert(source->hasNoSuccessors() &&
303  "expected 'source' to have no successors");
304 
305  // Split the block containing 'op' into two, one containing all operations
306  // before 'op' (prologue) and another (epilogue) containing 'op' and all
307  // operations after it.
308  Block *prologue = op->getBlock();
309  Block *epilogue = splitBlock(prologue, op->getIterator());
310 
311  // Merge the source block at the end of the prologue.
312  mergeBlocks(source, prologue, argValues);
313 
314  // Merge the epilogue at the end of the prologue.
315  mergeBlocks(epilogue, prologue);
316 }
317 
318 /// Split the operations starting at "before" (inclusive) out of the given
319 /// block into a new block, and return it.
321  return block->splitBlock(before);
322 }
323 
324 /// Replace the uses of the results of 'op' with the results of 'newOp'.
325 /// Both operations must have the same number of results (asserted).
/// The replacement itself is delegated to replaceOp().
326 void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op,
327  Operation *newOp) {
328  assert(op->getNumResults() == newOp->getNumResults() &&
329  "replacement op doesn't match results of original op");
// With exactly one result, pass that single result; otherwise pass the
// whole result range. Either way the work is done by replaceOp().
330  if (op->getNumResults() == 1)
331  return replaceOp(op, newOp->getResult(0));
332  return replaceOp(op, newOp->getResults());
333 }
334 
335 /// Move the blocks that belong to "region" before the given position in
336 /// another region. The two regions must be different. The caller is
337 /// responsible for creating or updating the operation that transfers control
338 /// flow to the region and for passing it the correct block arguments.
340  Region::iterator before) {
341  parent.getBlocks().splice(before, region.getBlocks());
342 }
344  inlineRegionBefore(region, *before->getParent(), before->getIterator());
345 }
346 
347 /// Clone the blocks that belong to "region" before the given position in
348 /// another region "parent". The two regions must be different. The caller is
349 /// responsible for creating or updating the operation transferring flow of
350 /// control to the region and passing it the correct block arguments.
352  Region::iterator before,
353  BlockAndValueMapping &mapping) {
354  region.cloneInto(&parent, before, mapping);
355 }
357  Region::iterator before) {
358  BlockAndValueMapping mapping;
359  cloneRegionBefore(region, parent, before, mapping);
360 }
362  cloneRegionBefore(region, *before->getParent(), before->getIterator());
363 }
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:162
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
BlockListType & getBlocks()
Definition: Region.h:45
Block represents an ordered list of Operations.
Definition: Block.h:29
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
This method replaces the uses of the results of op with the values in newValues when a use is nested ...
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function.
OpListType & getOperations()
Definition: Block.h:128
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block...
BlockListType::iterator iterator
Definition: Region.h:52
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:225
~RewriterBase() override
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:96
void registerConstraintFunction(StringRef name, PDLConstraintFunction constraintFn)
Register a constraint function.
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
static constexpr const bool value
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:424
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:52
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:71
void dropAllUses()
Drop all uses of this object from their respective owners.
Definition: UseDefLists.h:174
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the pattern is impossible to match, this asserts.
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:54
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
virtual void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor)
This method replaces the uses of the results of op with the values in newValues when the provided fun...
OpListType::iterator iterator
Definition: Block.h:131
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class contains all of the necessary data for a set of PDL patterns, or pattern rewrites specifie...
Definition: PatternMatch.h:610
iterator end()
Definition: Block.h:135
unsigned getNumArguments()
Definition: Block.h:119
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name...
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:154
std::function< void(ArrayRef< PDLValue >, ArrayAttr, PatternRewriter &, PDLResultList &)> PDLRewriteFunction
A native PDL rewrite function.
Definition: PatternMatch.h:598
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:276
void mergeBlockBefore(Block *source, Operation *op, ValueRange argValues=llvm::None)
This class acts as a special tag that makes the desire to match "any" operation type explicit...
Definition: PatternMatch.h:157
PatternBenefit()=default
BlockArgListType getArguments()
Definition: Block.h:76
void cloneInto(Region *dest, BlockAndValueMapping &mapper)
Clone the internal blocks from this region into dest.
Definition: Region.cpp:70
Kind
The underlying kind of a PDL value.
Definition: PatternMatch.h:401
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:570
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
bool isImpossibleToMatch() const
Definition: PatternMatch.h:42
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition: Block.h:233
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:230
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, BlockAndValueMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent"...
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:37
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
void mergeIn(PDLPatternModule &&other)
Merge the state in other into this pattern module.
This class represents an operand of an operation.
Definition: Value.h:249
std::function< LogicalResult(ArrayRef< PDLValue >, ArrayAttr, PatternRewriter &)> PDLConstraintFunction
A generic PDL pattern constraint function.
Definition: PatternMatch.h:591
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:167
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
result_range getResults()
Definition: Operation.h:284
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=llvm::None)
Merge the operations of block 'source' into the end of block 'dest'.
This class provides an abstraction over the different types of ranges over Values.
void print(raw_ostream &os) const
Print this value to the provided output stream.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:182
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:289