MLIR  22.0.0git
FoldUtils.cpp
Go to the documentation of this file.
1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various operation fold utilities. These utilities are
10 // intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Operation.h"
19 #include "llvm/Support/DebugLog.h"
20 
21 using namespace mlir;
22 
23 /// Given an operation, find the parent region that folded constants should be
24 /// inserted into.
25 static Region *
27  Block *insertionBlock) {
28  while (Region *region = insertionBlock->getParent()) {
29  // Insert in this region for any of the following scenarios:
30  // * The parent is unregistered, or is known to be isolated from above.
31  // * The parent is a top-level operation.
32  auto *parentOp = region->getParentOp();
33  if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
34  !parentOp->getBlock())
35  return region;
36 
37  // Otherwise, check if this region is a desired insertion region.
38  auto *interface = interfaces.getInterfaceFor(parentOp);
39  if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
40  return region;
41 
42  // Traverse up the parent looking for an insertion region.
43  insertionBlock = parentOp->getBlock();
44  }
45  llvm_unreachable("expected valid insertion region");
46 }
47 
48 /// A utility function used to materialize a constant for a given attribute and
49 /// type. On success, a valid constant value is returned. Otherwise, null is
50 /// returned
51 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
52  Attribute value, Type type,
53  Location loc) {
54  auto insertPt = builder.getInsertionPoint();
55  (void)insertPt;
56 
57  // Ask the dialect to materialize a constant operation for this value.
58  if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
59  assert(insertPt == builder.getInsertionPoint());
60  assert(matchPattern(constOp, m_Constant()));
61  return constOp;
62  }
63 
64  return nullptr;
65 }
66 
67 //===----------------------------------------------------------------------===//
68 // OperationFolder
69 //===----------------------------------------------------------------------===//
70 
71 LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate,
72  int maxIterations) {
73  if (inPlaceUpdate)
74  *inPlaceUpdate = false;
75 
76  // If this is a unique'd constant, return failure as we know that it has
77  // already been folded.
78  if (isFolderOwnedConstant(op)) {
79  // Check to see if we should rehoist, i.e. if a non-constant operation was
80  // inserted before this one.
81  Block *opBlock = op->getBlock();
82  if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
83  op->moveBefore(&opBlock->front());
84  op->setLoc(erasedFoldedLocation);
85  }
86  return failure();
87  }
88 
89  // Try to fold the operation.
90  SmallVector<Value, 8> results;
91  if (failed(tryToFold(op, results, maxIterations)))
92  return failure();
93 
94  // Check to see if the operation was just updated in place.
95  if (results.empty()) {
96  if (inPlaceUpdate)
97  *inPlaceUpdate = true;
98  if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(
99  rewriter.getListener())) {
100  // Folding API does not notify listeners, so we have to notify manually.
101  rewriteListener->notifyOperationModified(op);
102  }
103  return success();
104  }
105 
106  // Constant folding succeeded. Replace all of the result values and erase the
107  // operation.
108  notifyRemoval(op);
109  rewriter.replaceOp(op, results);
110  return success();
111 }
112 
114  Block *opBlock = op->getBlock();
115 
116  // If this is a constant we unique'd, we don't need to insert, but we can
117  // check to see if we should rehoist it.
118  if (isFolderOwnedConstant(op)) {
119  if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
120  op->moveBefore(&opBlock->front());
121  op->setLoc(erasedFoldedLocation);
122  }
123  return true;
124  }
125 
126  // Get the constant value of the op if necessary.
127  if (!constValue) {
128  matchPattern(op, m_Constant(&constValue));
129  assert(constValue && "expected `op` to be a constant");
130  } else {
131  // Ensure that the provided constant was actually correct.
132 #ifndef NDEBUG
133  Attribute expectedValue;
134  matchPattern(op, m_Constant(&expectedValue));
135  assert(
136  expectedValue == constValue &&
137  "provided constant value was not the expected value of the constant");
138 #endif
139  }
140 
141  // Check for an existing constant operation for the attribute value.
142  Region *insertRegion = getInsertionRegion(interfaces, opBlock);
143  auto &uniquedConstants = foldScopes[insertRegion];
144  Operation *&folderConstOp = uniquedConstants[std::make_tuple(
145  op->getDialect(), constValue, *op->result_type_begin())];
146 
147  // If there is an existing constant, replace `op`.
148  if (folderConstOp) {
149  notifyRemoval(op);
150  rewriter.replaceOp(op, folderConstOp->getResults());
151  folderConstOp->setLoc(erasedFoldedLocation);
152  return false;
153  }
154 
155  // Otherwise, we insert `op`. If `op` is in the insertion block and is either
156  // already at the front of the block, or the previous operation is already a
157  // constant we unique'd (i.e. one we inserted), then we don't need to do
158  // anything. Otherwise, we move the constant to the insertion block.
159  // The location info is erased if the constant is moved to a different block.
160  Block *insertBlock = &insertRegion->front();
161  if (opBlock != insertBlock) {
162  op->moveBefore(&insertBlock->front());
163  op->setLoc(erasedFoldedLocation);
164  } else if (&insertBlock->front() != op &&
165  !isFolderOwnedConstant(op->getPrevNode())) {
166  op->moveBefore(&insertBlock->front());
167  }
168 
169  folderConstOp = op;
170  referencedDialects[op].push_back(op->getDialect());
171  return true;
172 }
173 
174 /// Notifies that the given constant `op` should be remove from this
175 /// OperationFolder's internal bookkeeping.
177  // Check to see if this operation is uniqued within the folder.
178  auto it = referencedDialects.find(op);
179  if (it == referencedDialects.end())
180  return;
181 
182  // Get the constant value for this operation, this is the value that was used
183  // to unique the operation internally.
184  Attribute constValue;
185  matchPattern(op, m_Constant(&constValue));
186  assert(constValue);
187 
188  // Get the constant map that this operation was uniqued in.
189  auto &uniquedConstants =
190  foldScopes[getInsertionRegion(interfaces, op->getBlock())];
191 
192  // Erase all of the references to this operation.
193  auto type = op->getResult(0).getType();
194  for (auto *dialect : it->second)
195  uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
196  referencedDialects.erase(it);
197 }
198 
199 /// Clear out any constants cached inside of the folder.
201  foldScopes.clear();
202  referencedDialects.clear();
203 }
204 
205 /// Get or create a constant using the given builder. On success this returns
206 /// the constant operation, nullptr otherwise.
208  Attribute value, Type type) {
209  // Find an insertion point for the constant.
210  auto *insertRegion = getInsertionRegion(interfaces, block);
211  auto &entry = insertRegion->front();
212  rewriter.setInsertionPointToStart(&entry);
213 
214  // Get the constant map for the insertion region of this operation.
215  // Use erased location since the op is being built at the front of block.
216  auto &uniquedConstants = foldScopes[insertRegion];
217  Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, value,
218  type, erasedFoldedLocation);
219  return constOp ? constOp->getResult(0) : Value();
220 }
221 
222 bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
223  return referencedDialects.count(op);
224 }
225 
226 /// Tries to perform folding on the given `op`. If successful, populates
227 /// `results` with the results of the folding.
228 LogicalResult OperationFolder::tryToFold(Operation *op,
229  SmallVectorImpl<Value> &results,
230  int maxIterations) {
231  SmallVector<OpFoldResult, 8> foldResults;
232  if (failed(op->fold(foldResults)))
233  return failure();
234  int count = 1;
235  do {
236  LDBG() << "Folded in place #" << count
237  << " times: " << OpWithFlags(op, OpPrintingFlags().skipRegions());
238  } while (count++ < maxIterations && foldResults.empty() &&
239  succeeded(op->fold(foldResults)));
240 
241  if (failed(processFoldResults(op, results, foldResults)))
242  return failure();
243  return success();
244 }
245 
246 LogicalResult
247 OperationFolder::processFoldResults(Operation *op,
248  SmallVectorImpl<Value> &results,
249  ArrayRef<OpFoldResult> foldResults) {
250  // Check to see if the operation was just updated in place.
251  if (foldResults.empty())
252  return success();
253  assert(foldResults.size() == op->getNumResults());
254 
255  // Create a builder to insert new operations into the entry block of the
256  // insertion region.
257  auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
258  auto &entry = insertRegion->front();
259  rewriter.setInsertionPointToStart(&entry);
260 
261  // Get the constant map for the insertion region of this operation.
262  auto &uniquedConstants = foldScopes[insertRegion];
263 
264  // Create the result constants and replace the results.
265  auto *dialect = op->getDialect();
266  for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
267  assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
268 
269  // Check if the result was an SSA value.
270  if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
271  results.emplace_back(repl);
272  continue;
273  }
274 
275  // Check to see if there is a canonicalized version of this constant.
276  auto res = op->getResult(i);
277  Attribute attrRepl = cast<Attribute>(foldResults[i]);
278  if (auto *constOp =
279  tryGetOrCreateConstant(uniquedConstants, dialect, attrRepl,
280  res.getType(), erasedFoldedLocation)) {
281  // Ensure that this constant dominates the operation we are replacing it
282  // with. This may not automatically happen if the operation being folded
283  // was inserted before the constant within the insertion block.
284  Block *opBlock = op->getBlock();
285  if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
286  constOp->moveBefore(&opBlock->front());
287 
288  results.push_back(constOp->getResult(0));
289  continue;
290  }
291  // If materialization fails, cleanup any operations generated for the
292  // previous results and return failure.
293  for (Operation &op : llvm::make_early_inc_range(
294  llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) {
295  notifyRemoval(&op);
296  rewriter.eraseOp(&op);
297  }
298 
299  results.clear();
300  return failure();
301  }
302 
303  return success();
304 }
305 
306 /// Try to get or create a new constant entry. On success this returns the
307 /// constant operation value, nullptr otherwise.
308 Operation *
309 OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
310  Dialect *dialect, Attribute value,
311  Type type, Location loc) {
312  // Check if an existing mapping already exists.
313  auto constKey = std::make_tuple(dialect, value, type);
314  Operation *&constOp = uniquedConstants[constKey];
315  if (constOp) {
316  if (loc != constOp->getLoc())
317  constOp->setLoc(erasedFoldedLocation);
318  return constOp;
319  }
320 
321  // If one doesn't exist, try to materialize one.
322  if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc)))
323  return nullptr;
324 
325  // Check to see if the generated constant is in the expected dialect.
326  auto *newDialect = constOp->getDialect();
327  if (newDialect == dialect) {
328  referencedDialects[constOp].push_back(dialect);
329  return constOp;
330  }
331 
332  // If it isn't, then we also need to make sure that the mapping for the new
333  // dialect is valid.
334  auto newKey = std::make_tuple(newDialect, value, type);
335 
336  // If an existing operation in the new dialect already exists, delete the
337  // materialized operation in favor of the existing one.
338  if (auto *existingOp = uniquedConstants.lookup(newKey)) {
339  notifyRemoval(constOp);
340  rewriter.eraseOp(constOp);
341  referencedDialects[existingOp].push_back(dialect);
342  if (loc != existingOp->getLoc())
343  existingOp->setLoc(erasedFoldedLocation);
344  return constOp = existingOp;
345  }
346 
347  // Otherwise, update the new dialect to the materialized operation.
348  referencedDialects[constOp].assign({dialect, newDialect});
349  auto newIt = uniquedConstants.insert({newKey, constOp});
350  return newIt.first->second;
351 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:51
static Region * getInsertionRegion(DialectInterfaceCollection< DialectFoldInterface > &interfaces, Block *insertionBlock)
Given a block, find the parent region into which folded constants should be inserted.
Definition: FoldUtils.cpp:26
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
Operation & front()
Definition: Block.h:153
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
Definition: Block.cpp:54
A collection of dialect interfaces within a context, for a given concrete interface type.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:207
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:445
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:320
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
Value getOrCreateConstant(Block *block, Dialect *dialect, Attribute value, Type type)
Get or create a constant for use in the specified block.
Definition: FoldUtils.cpp:207
void clear()
Clear out any constants cached inside of the folder.
Definition: FoldUtils.cpp:200
LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate=nullptr, int maxIterations=INT_MAX)
Tries to perform folding on the given op, including unifying deduplicated constants.
Definition: FoldUtils.cpp:71
void notifyRemoval(Operation *op)
Notifies that the given constant op should be removed from this OperationFolder's internal bookkeeping...
Definition: FoldUtils.cpp:176
bool insertKnownConstant(Operation *op, Attribute constValue={})
Tries to fold a pre-existing constant operation.
Definition: FoldUtils.cpp:113
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:634
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
result_type_iterator result_type_begin()
Definition: Operation.h:426
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:555
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369