MLIR  19.0.0git
FoldUtils.cpp
Go to the documentation of this file.
1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various operation fold utilities. These utilities are
10 // intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Operation.h"
19 
20 using namespace mlir;
21 
22 /// Given an operation, find the parent region that folded constants should be
23 /// inserted into.
24 static Region *
26  Block *insertionBlock) {
27  while (Region *region = insertionBlock->getParent()) {
28  // Insert in this region for any of the following scenarios:
29  // * The parent is unregistered, or is known to be isolated from above.
30  // * The parent is a top-level operation.
31  auto *parentOp = region->getParentOp();
32  if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
33  !parentOp->getBlock())
34  return region;
35 
36  // Otherwise, check if this region is a desired insertion region.
37  auto *interface = interfaces.getInterfaceFor(parentOp);
38  if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
39  return region;
40 
41  // Traverse up the parent looking for an insertion region.
42  insertionBlock = parentOp->getBlock();
43  }
44  llvm_unreachable("expected valid insertion region");
45 }
46 
47 /// A utility function used to materialize a constant for a given attribute and
48 /// type. On success, a valid constant value is returned. Otherwise, null is
49 /// returned
50 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
51  Attribute value, Type type,
52  Location loc) {
53  auto insertPt = builder.getInsertionPoint();
54  (void)insertPt;
55 
56  // Ask the dialect to materialize a constant operation for this value.
57  if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
58  assert(insertPt == builder.getInsertionPoint());
59  assert(matchPattern(constOp, m_Constant()));
60  return constOp;
61  }
62 
63  return nullptr;
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // OperationFolder
68 //===----------------------------------------------------------------------===//
69 
71  if (inPlaceUpdate)
72  *inPlaceUpdate = false;
73 
74  // If this is a unique'd constant, return failure as we know that it has
75  // already been folded.
76  if (isFolderOwnedConstant(op)) {
77  // Check to see if we should rehoist, i.e. if a non-constant operation was
78  // inserted before this one.
79  Block *opBlock = op->getBlock();
80  if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
81  op->moveBefore(&opBlock->front());
82  op->setLoc(erasedFoldedLocation);
83  }
84  return failure();
85  }
86 
87  // Try to fold the operation.
88  SmallVector<Value, 8> results;
89  if (failed(tryToFold(op, results)))
90  return failure();
91 
92  // Check to see if the operation was just updated in place.
93  if (results.empty()) {
94  if (inPlaceUpdate)
95  *inPlaceUpdate = true;
96  if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(
97  rewriter.getListener())) {
98  // Folding API does not notify listeners, so we have to notify manually.
99  rewriteListener->notifyOperationModified(op);
100  }
101  return success();
102  }
103 
104  // Constant folding succeeded. Replace all of the result values and erase the
105  // operation.
106  notifyRemoval(op);
107  rewriter.replaceOp(op, results);
108  return success();
109 }
110 
112  Block *opBlock = op->getBlock();
113 
114  // If this is a constant we unique'd, we don't need to insert, but we can
115  // check to see if we should rehoist it.
116  if (isFolderOwnedConstant(op)) {
117  if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
118  op->moveBefore(&opBlock->front());
119  op->setLoc(erasedFoldedLocation);
120  }
121  return true;
122  }
123 
124  // Get the constant value of the op if necessary.
125  if (!constValue) {
126  matchPattern(op, m_Constant(&constValue));
127  assert(constValue && "expected `op` to be a constant");
128  } else {
129  // Ensure that the provided constant was actually correct.
130 #ifndef NDEBUG
131  Attribute expectedValue;
132  matchPattern(op, m_Constant(&expectedValue));
133  assert(
134  expectedValue == constValue &&
135  "provided constant value was not the expected value of the constant");
136 #endif
137  }
138 
139  // Check for an existing constant operation for the attribute value.
140  Region *insertRegion = getInsertionRegion(interfaces, opBlock);
141  auto &uniquedConstants = foldScopes[insertRegion];
142  Operation *&folderConstOp = uniquedConstants[std::make_tuple(
143  op->getDialect(), constValue, *op->result_type_begin())];
144 
145  // If there is an existing constant, replace `op`.
146  if (folderConstOp) {
147  notifyRemoval(op);
148  rewriter.replaceOp(op, folderConstOp->getResults());
149  folderConstOp->setLoc(erasedFoldedLocation);
150  return false;
151  }
152 
153  // Otherwise, we insert `op`. If `op` is in the insertion block and is either
154  // already at the front of the block, or the previous operation is already a
155  // constant we unique'd (i.e. one we inserted), then we don't need to do
156  // anything. Otherwise, we move the constant to the insertion block.
157  Block *insertBlock = &insertRegion->front();
158  if (opBlock != insertBlock || (&insertBlock->front() != op &&
159  !isFolderOwnedConstant(op->getPrevNode()))) {
160  op->moveBefore(&insertBlock->front());
161  op->setLoc(erasedFoldedLocation);
162  }
163 
164  folderConstOp = op;
165  referencedDialects[op].push_back(op->getDialect());
166  return true;
167 }
168 
169 /// Notifies that the given constant `op` should be remove from this
170 /// OperationFolder's internal bookkeeping.
172  // Check to see if this operation is uniqued within the folder.
173  auto it = referencedDialects.find(op);
174  if (it == referencedDialects.end())
175  return;
176 
177  // Get the constant value for this operation, this is the value that was used
178  // to unique the operation internally.
179  Attribute constValue;
180  matchPattern(op, m_Constant(&constValue));
181  assert(constValue);
182 
183  // Get the constant map that this operation was uniqued in.
184  auto &uniquedConstants =
185  foldScopes[getInsertionRegion(interfaces, op->getBlock())];
186 
187  // Erase all of the references to this operation.
188  auto type = op->getResult(0).getType();
189  for (auto *dialect : it->second)
190  uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
191  referencedDialects.erase(it);
192 }
193 
194 /// Clear out any constants cached inside of the folder.
196  foldScopes.clear();
197  referencedDialects.clear();
198 }
199 
200 /// Get or create a constant using the given builder. On success this returns
201 /// the constant operation, nullptr otherwise.
203  Attribute value, Type type) {
204  // Find an insertion point for the constant.
205  auto *insertRegion = getInsertionRegion(interfaces, block);
206  auto &entry = insertRegion->front();
207  rewriter.setInsertionPoint(&entry, entry.begin());
208 
209  // Get the constant map for the insertion region of this operation.
210  // Use erased location since the op is being built at the front of block.
211  auto &uniquedConstants = foldScopes[insertRegion];
212  Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, value,
213  type, erasedFoldedLocation);
214  return constOp ? constOp->getResult(0) : Value();
215 }
216 
217 bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
218  return referencedDialects.count(op);
219 }
220 
221 /// Tries to perform folding on the given `op`. If successful, populates
222 /// `results` with the results of the folding.
224  SmallVectorImpl<Value> &results) {
225  SmallVector<OpFoldResult, 8> foldResults;
226  if (failed(op->fold(foldResults)) ||
227  failed(processFoldResults(op, results, foldResults)))
228  return failure();
229  return success();
230 }
231 
233 OperationFolder::processFoldResults(Operation *op,
234  SmallVectorImpl<Value> &results,
235  ArrayRef<OpFoldResult> foldResults) {
236  // Check to see if the operation was just updated in place.
237  if (foldResults.empty())
238  return success();
239  assert(foldResults.size() == op->getNumResults());
240 
241  // Create a builder to insert new operations into the entry block of the
242  // insertion region.
243  auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
244  auto &entry = insertRegion->front();
245  rewriter.setInsertionPoint(&entry, entry.begin());
246 
247  // Get the constant map for the insertion region of this operation.
248  auto &uniquedConstants = foldScopes[insertRegion];
249 
250  // Create the result constants and replace the results.
251  auto *dialect = op->getDialect();
252  for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
253  assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
254 
255  // Check if the result was an SSA value.
256  if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
257  results.emplace_back(repl);
258  continue;
259  }
260 
261  // Check to see if there is a canonicalized version of this constant.
262  auto res = op->getResult(i);
263  Attribute attrRepl = foldResults[i].get<Attribute>();
264  if (auto *constOp =
265  tryGetOrCreateConstant(uniquedConstants, dialect, attrRepl,
266  res.getType(), erasedFoldedLocation)) {
267  // Ensure that this constant dominates the operation we are replacing it
268  // with. This may not automatically happen if the operation being folded
269  // was inserted before the constant within the insertion block.
270  Block *opBlock = op->getBlock();
271  if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
272  constOp->moveBefore(&opBlock->front());
273 
274  results.push_back(constOp->getResult(0));
275  continue;
276  }
277  // If materialization fails, cleanup any operations generated for the
278  // previous results and return failure.
279  for (Operation &op : llvm::make_early_inc_range(
280  llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) {
281  notifyRemoval(&op);
282  rewriter.eraseOp(&op);
283  }
284 
285  results.clear();
286  return failure();
287  }
288 
289  return success();
290 }
291 
292 /// Try to get or create a new constant entry. On success this returns the
293 /// constant operation value, nullptr otherwise.
294 Operation *
295 OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
296  Dialect *dialect, Attribute value,
297  Type type, Location loc) {
298  // Check if an existing mapping already exists.
299  auto constKey = std::make_tuple(dialect, value, type);
300  Operation *&constOp = uniquedConstants[constKey];
301  if (constOp) {
302  if (loc != constOp->getLoc())
303  constOp->setLoc(erasedFoldedLocation);
304  return constOp;
305  }
306 
307  // If one doesn't exist, try to materialize one.
308  if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc)))
309  return nullptr;
310 
311  // Check to see if the generated constant is in the expected dialect.
312  auto *newDialect = constOp->getDialect();
313  if (newDialect == dialect) {
314  referencedDialects[constOp].push_back(dialect);
315  return constOp;
316  }
317 
318  // If it isn't, then we also need to make sure that the mapping for the new
319  // dialect is valid.
320  auto newKey = std::make_tuple(newDialect, value, type);
321 
322  // If an existing operation in the new dialect already exists, delete the
323  // materialized operation in favor of the existing one.
324  if (auto *existingOp = uniquedConstants.lookup(newKey)) {
325  notifyRemoval(constOp);
326  rewriter.eraseOp(constOp);
327  referencedDialects[existingOp].push_back(dialect);
328  if (loc != existingOp->getLoc())
329  existingOp->setLoc(erasedFoldedLocation);
330  return constOp = existingOp;
331  }
332 
333  // Otherwise, update the new dialect to the materialized operation.
334  referencedDialects[constOp].assign({dialect, newDialect});
335  auto newIt = uniquedConstants.insert({newKey, constOp});
336  return newIt.first->second;
337 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static Region * getInsertionRegion(DialectInterfaceCollection< DialectFoldInterface > &interfaces, Block *insertionBlock)
Given a block, find the parent region that folded constants should be inserted into.
Definition: FoldUtils.cpp:25
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation & front()
Definition: Block.h:150
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
Definition: Block.cpp:53
A collection of dialect interfaces within a context, for a given concrete interface type.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:86
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:447
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:322
This class provides the API for ops that are known to be isolated from above.
Value getOrCreateConstant(Block *block, Dialect *dialect, Attribute value, Type type)
Get or create a constant for use in the specified block.
Definition: FoldUtils.cpp:202
void clear()
Clear out any constants cached inside of the folder.
Definition: FoldUtils.cpp:195
LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate=nullptr)
Tries to perform folding on the given op, including unifying deduplicated constants.
Definition: FoldUtils.cpp:70
void notifyRemoval(Operation *op)
Notifies that the given constant op should be removed from this OperationFolder's internal bookkeeping...
Definition: FoldUtils.cpp:171
bool insertKnownConstant(Operation *op, Attribute constValue={})
Tries to fold a pre-existing constant operation.
Definition: FoldUtils.cpp:111
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:632
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
result_type_iterator result_type_begin()
Definition: Operation.h:421
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:555
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26