MLIR  17.0.0git
PatternMatch.cpp
Go to the documentation of this file.
1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/IR/IRMapping.h"
11 
12 using namespace mlir;
13 
14 //===----------------------------------------------------------------------===//
15 // PatternBenefit
16 //===----------------------------------------------------------------------===//
17 
/// Construct a benefit from the raw value `benefit`, storing it in
/// `representation`. The assertion rejects a value that no longer compares
/// equal to `benefit` once stored, as well as the reserved
/// `ImpossibleToMatchSentinel` value.
18 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
19  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
20  "This pattern match benefit is too large to represent");
21 }
22 
/// Return the stored benefit value. Asserts `!isImpossibleToMatch()` before
/// returning, so this must not be called on an impossible-to-match benefit.
23 unsigned short PatternBenefit::getBenefit() const {
24  assert(!isImpossibleToMatch() && "Pattern doesn't match");
25  return representation;
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // Pattern
30 //===----------------------------------------------------------------------===//
31 
32 //===----------------------------------------------------------------------===//
33 // OperationName Root Constructors
34 
/// Construct a pattern rooted at the operation named `rootName`. Builds an
/// `OperationName` from `rootName` and `context`, then delegates to the
/// general constructor with that name's opaque pointer as the root value and
/// `RootKind::OperationName` as the root kind.
35 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
36  MLIRContext *context, ArrayRef<StringRef> generatedNames)
37  : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
38  RootKind::OperationName, generatedNames, benefit, context) {}
39 
40 //===----------------------------------------------------------------------===//
41 // MatchAnyOpTypeTag Root Constructors
42 
44  MLIRContext *context, ArrayRef<StringRef> generatedNames)
45  : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
46 
47 //===----------------------------------------------------------------------===//
48 // MatchInterfaceOpTypeTag Root Constructors
49 
51  PatternBenefit benefit, MLIRContext *context,
52  ArrayRef<StringRef> generatedNames)
53  : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
54  generatedNames, benefit, context) {}
55 
56 //===----------------------------------------------------------------------===//
57 // MatchTraitOpTypeTag Root Constructors
58 
60  PatternBenefit benefit, MLIRContext *context,
61  ArrayRef<StringRef> generatedNames)
62  : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
63  benefit, context) {}
64 
65 //===----------------------------------------------------------------------===//
66 // General Constructors
67 
/// General constructor that the root-kind-specific constructors delegate to.
/// Stores the opaque `rootValue`, its `rootKind` and the `benefit`, and
/// initializes `contextAndHasBoundedRecursion` with `context` and `false`.
/// Each entry of `generatedNames` is converted to an `OperationName` and
/// appended to `generatedOps`.
68 Pattern::Pattern(const void *rootValue, RootKind rootKind,
69  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
70  MLIRContext *context)
71  : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
72  contextAndHasBoundedRecursion(context, false) {
  // Nothing to record when no generated operation names were provided.
73  if (generatedNames.empty())
74  return;
  // Reserve up front, then build one OperationName per generated name.
75  generatedOps.reserve(generatedNames.size());
76  std::transform(generatedNames.begin(), generatedNames.end(),
77  std::back_inserter(generatedOps), [context](StringRef name) {
78  return OperationName(name, context);
79  });
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // RewritePattern
84 //===----------------------------------------------------------------------===//
85 
87  llvm_unreachable("need to implement either matchAndRewrite or one of the "
88  "rewrite functions!");
89 }
90 
92  llvm_unreachable("need to implement either match or matchAndRewrite!");
93 }
94 
95 /// Out-of-line vtable anchor; the body is empty.
96 void RewritePattern::anchor() {}
97 
98 //===----------------------------------------------------------------------===//
99 // PDLValue
100 //===----------------------------------------------------------------------===//
101 
/// Print this value to `os`. When the underlying `value` is null the
/// placeholder "<NULL-PDLValue>" is printed; otherwise the output depends on
/// `kind`: attributes, types and values are streamed directly, an operation
/// is streamed through its dereferenced pointer, and the range kinds are
/// printed with `llvm::interleaveComma`.
102 void PDLValue::print(raw_ostream &os) const {
  // A null value has nothing to cast to, so print a fixed placeholder.
103  if (!value) {
104  os << "<NULL-PDLValue>";
105  return;
106  }
107  switch (kind) {
108  case Kind::Attribute:
109  os << cast<Attribute>();
110  break;
111  case Kind::Operation:
112  os << *cast<Operation *>();
113  break;
114  case Kind::Type:
115  os << cast<Type>();
116  break;
117  case Kind::TypeRange:
118  llvm::interleaveComma(cast<TypeRange>(), os);
119  break;
120  case Kind::Value:
121  os << cast<Value>();
122  break;
123  case Kind::ValueRange:
124  llvm::interleaveComma(cast<ValueRange>(), os);
125  break;
126  }
127 }
128 
/// Print the name of `kind` to `os`: one of "Attribute", "Operation", "Type",
/// "TypeRange", "Value" or "ValueRange".
129 void PDLValue::print(raw_ostream &os, Kind kind) {
130  switch (kind) {
131  case Kind::Attribute:
132  os << "Attribute";
133  break;
134  case Kind::Operation:
135  os << "Operation";
136  break;
137  case Kind::Type:
138  os << "Type";
139  break;
140  case Kind::TypeRange:
141  os << "TypeRange";
142  break;
143  case Kind::Value:
144  os << "Value";
145  break;
146  case Kind::ValueRange:
147  os << "ValueRange";
148  break;
149  }
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // PDLPatternModule
154 //===----------------------------------------------------------------------===//
155 
157  // Ignore the other module if it has no patterns.
158  if (!other.pdlModule)
159  return;
160 
161  // Steal the functions and config of the other module.
162  for (auto &it : other.constraintFunctions)
163  registerConstraintFunction(it.first(), std::move(it.second));
164  for (auto &it : other.rewriteFunctions)
165  registerRewriteFunction(it.first(), std::move(it.second));
166  for (auto &it : other.configs)
167  configs.emplace_back(std::move(it));
168  for (auto &it : other.configMap)
169  configMap.insert(it);
170 
171  // Steal the other state if we have no patterns.
172  if (!pdlModule) {
173  pdlModule = std::move(other.pdlModule);
174  return;
175  }
176 
177  // Merge the pattern operations from the other module into this one.
178  Block *block = pdlModule->getBody();
179  block->getOperations().splice(block->end(),
180  other.pdlModule->getBody()->getOperations());
181 }
182 
/// Record `configSet` in `configMap` for every operation reached by walking
/// `module` that has the `SymbolOpInterface::Trait` trait.
/// NOTE(review): only the address of `configSet` is stored, so the caller
/// presumably keeps it alive for as long as the map is used; confirm.
183 void PDLPatternModule::attachConfigToPatterns(ModuleOp module,
184  PDLPatternConfigSet &configSet) {
185  // Attach the configuration to the symbols within the module. We only add
186  // to symbols to avoid hardcoding any specific operation names here (given
187  // that we don't depend on any PDL dialect). We can't use
188  // cast<SymbolOpInterface> here because patterns may be optional symbols.
189  module->walk([&](Operation *op) {
190  if (op->hasTrait<SymbolOpInterface::Trait>())
191  configMap[op] = &configSet;
192  });
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // Function Registry
197 
199  StringRef name, PDLConstraintFunction constraintFn) {
200  // TODO: Is it possible to diagnose when `name` is already registered to
201  // a function that is not equivalent to `constraintFn`?
202  // Allow existing mappings in the case multiple patterns depend on the same
203  // constraint.
204  constraintFunctions.try_emplace(name, std::move(constraintFn));
205 }
206 
208  PDLRewriteFunction rewriteFn) {
209  // TODO: Is it possible to diagnose when `name` is already registered to
210  // a function that is not equivalent to `rewriteFn`?
211  // Allow existing mappings in the case multiple patterns depend on the same
212  // rewrite.
213  rewriteFunctions.try_emplace(name, std::move(rewriteFn));
214 }
215 
216 //===----------------------------------------------------------------------===//
217 // RewriterBase
218 //===----------------------------------------------------------------------===//
219 
222 }
223 
225  // Out of line to provide a vtable anchor for the class.
226 }
227 
228 /// This method replaces the uses of the results of `op` with the values in
229 /// `newValues` when the provided `functor` returns true for a specific use.
230 /// The number of values in `newValues` is required to match the number of
231 /// results of `op`.
233  Operation *op, ValueRange newValues, bool *allUsesReplaced,
234  llvm::unique_function<bool(OpOperand &) const> functor) {
235  assert(op->getNumResults() == newValues.size() &&
236  "incorrect number of values to replace operation");
237 
238  // Notify the listener that we're about to replace this op.
239  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
240  rewriteListener->notifyOperationReplaced(op, newValues);
241 
242  // Replace each use of the results when the functor is true.
243  bool replacedAllUses = true;
244  for (auto it : llvm::zip(op->getResults(), newValues)) {
245  replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
246  replacedAllUses &= std::get<0>(it).use_empty();
247  }
248  if (allUsesReplaced)
249  *allUsesReplaced = replacedAllUses;
250 }
251 
252 /// This method replaces the uses of the results of `op` with the values in
253 /// `newValues` when a use is nested within the given `block`. The number of
254 /// values in `newValues` is required to match the number of results of `op`.
255 /// If all uses of this operation are replaced, the operation is erased.
257  Block *block, bool *allUsesReplaced) {
258  replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
259  return block->getParentOp()->isProperAncestor(use.getOwner());
260  });
261 }
262 
263 /// This method replaces the results of the operation with the specified list of
264 /// values. The number of provided values must match the number of results of
265 /// the operation.
267  assert(op->getNumResults() == newValues.size() &&
268  "incorrect # of replacement values");
269 
270  // Notify the listener that we're about to remove this op.
271  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
272  rewriteListener->notifyOperationReplaced(op, newValues);
273 
274  // Replace results one-by-one. Also notifies the listener of modifications.
275  for (auto it : llvm::zip(op->getResults(), newValues))
276  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
277 
278  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
279  rewriteListener->notifyOperationRemoved(op);
280  op->erase();
281 }
282 
283 /// This method erases an operation that is known to have no uses. The uses of
284 /// the given operation *must* be known to be dead.
286  assert(op->use_empty() && "expected 'op' to have no uses");
287  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
288  rewriteListener->notifyOperationRemoved(op);
289  op->erase();
290 }
291 
293  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
294  assert(op.use_empty() && "expected 'op' to have no uses");
295  eraseOp(&op);
296  }
297  block->erase();
298 }
299 
301  // Notify the listener that the operation was modified.
302  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
303  rewriteListener->notifyOperationModified(op);
304 }
305 
306 /// Find uses of `from` and replace them with `to` if the `functor` returns
307 /// true. It also marks every modified use and notifies the rewriter that an
308 /// in-place operation modification is about to happen.
310  function_ref<bool(OpOperand &)> functor) {
311  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
312  if (functor(operand))
313  updateRootInPlace(operand.getOwner(), [&]() { operand.set(to); });
314  }
315 }
316 
318  Block::iterator before,
319  ValueRange argValues) {
320  assert(argValues.size() == source->getNumArguments() &&
321  "incorrect # of argument replacement values");
322 
323  // The source block will be deleted, so it should not have any users (i.e.,
324  // there should be no predecessors).
325  assert(source->hasNoPredecessors() &&
326  "expected 'source' to have no predecessors");
327 
328  if (dest->end() != before) {
329  // The source block will be inserted in the middle of the dest block, so
330  // the source block should have no successors. Otherwise, the remainder of
331  // the dest block would be unreachable.
332  assert(source->hasNoSuccessors() &&
333  "expected 'source' to have no successors");
334  } else {
335  // The source block will be inserted at the end of the dest block, so the
336  // dest block should have no successors. Otherwise, the inserted operations
337  // will be unreachable.
338  assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
339  }
340 
341  // Replace all of the successor arguments with the provided values.
342  for (auto it : llvm::zip(source->getArguments(), argValues))
343  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
344 
345  // Move operations from the source block to the dest block and erase the
346  // source block.
347  dest->getOperations().splice(before, source->getOperations());
348  source->erase();
349 }
350 
352  ValueRange argValues) {
353  inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
354 }
355 
357  ValueRange argValues) {
358  inlineBlockBefore(source, dest, dest->end(), argValues);
359 }
360 
361 /// Split the operations starting at "before" (inclusive) out of the given
362 /// block into a new block, and return it.
364  return block->splitBlock(before);
365 }
366 
367 /// Replace `op` with the results of `newOp` by forwarding to `replaceOp`.
368 /// Both operations must have the same number of results (asserted).
369 void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op,
370  Operation *newOp) {
371  assert(op->getNumResults() == newOp->getNumResults() &&
372  "replacement op doesn't match results of original op");
  // Single result: forward the lone result rather than the result range.
373  if (op->getNumResults() == 1)
374  return replaceOp(op, newOp->getResult(0));
375  return replaceOp(op, newOp->getResults());
376 }
377 
378 /// Move the blocks that belong to "region" before the given position in
379 /// another region. The two regions must be different. The caller is in
380 /// charge of creating or updating the operation transferring the control flow
381 /// to the region and of passing it the correct block arguments.
383  Region::iterator before) {
384  parent.getBlocks().splice(before, region.getBlocks());
385 }
387  inlineRegionBefore(region, *before->getParent(), before->getIterator());
388 }
389 
390 /// Clone the blocks that belong to "region" before the given position in
391 /// another region "parent". The two regions must be different. The caller is
392 /// responsible for creating or updating the operation transferring flow of
393 /// control to the region and passing it the correct block arguments.
395  Region::iterator before,
396  IRMapping &mapping) {
397  region.cloneInto(&parent, before, mapping);
398 }
400  Region::iterator before) {
401  IRMapping mapping;
402  cloneRegionBefore(region, parent, before, mapping);
403 }
405  cloneRegionBefore(region, *before->getParent(), before->getIterator());
406 }
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:129
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition: Block.h:231
unsigned getNumArguments()
Definition: Block.h:117
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:54
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:291
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
OpListType & getOperations()
Definition: Block.h:126
BlockArgListType getArguments()
Definition: Block.h:76
iterator end()
Definition: Block.h:133
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:228
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:568
This class represents an operand of an operation.
Definition: Value.h:255
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:695
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:592
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:386
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:197
result_range getResults()
Definition: Operation.h:394
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:427
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:383
This class contains a set of configurations for a specific pattern.
Definition: PatternMatch.h:920
This class contains all of the necessary data for a set of PDL patterns, or pattern rewrites specifie...
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
void mergeIn(PDLPatternModule &&other)
Merge the state in other into this pattern module.
void registerConstraintFunction(StringRef name, PDLConstraintFunction constraintFn)
Register a constraint function with PDL.
void print(raw_ostream &os) const
Print this value to the provided output stream.
Kind
The underlying kind of a PDL value.
Definition: PatternMatch.h:691
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
bool isImpossibleToMatch() const
Definition: PatternMatch.h:43
PatternBenefit()=default
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the pattern is impossible to match, this asserts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:668
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:72
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
Definition: Region.cpp:70
BlockListType & getBlocks()
Definition: Region.h:45
BlockListType::iterator iterator
Definition: Region.h:52
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
virtual void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor)
This method replaces the uses of the results of op with the values in newValues when the provided fun...
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:549
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:558
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
This method replaces the uses of the results of op with the values in newValues when a use is nested ...
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:201
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:40
This header declares functions that assist transformations in the MemRef dialect.
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
Definition: PatternMatch.h:982
std::function< LogicalResult(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
Definition: PatternMatch.h:990
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
@ RewriterBaseListener
RewriterBase::Listener or user-derived class.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:279
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:158
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:163
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:168
static bool classof(const OpBuilder::Listener *base)