MLIR  18.0.0git
FoldUtils.cpp
Go to the documentation of this file.
1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various operation fold utilities. These utilities are
10 // intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Operation.h"
19 
20 using namespace mlir;
21 
22 /// Given an operation, find the parent region that folded constants should be
23 /// inserted into.
24 static Region *
26  Block *insertionBlock) {
27  while (Region *region = insertionBlock->getParent()) {
28  // Insert in this region for any of the following scenarios:
29  // * The parent is unregistered, or is known to be isolated from above.
30  // * The parent is a top-level operation.
31  auto *parentOp = region->getParentOp();
32  if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
33  !parentOp->getBlock())
34  return region;
35 
36  // Otherwise, check if this region is a desired insertion region.
37  auto *interface = interfaces.getInterfaceFor(parentOp);
38  if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
39  return region;
40 
41  // Traverse up the parent looking for an insertion region.
42  insertionBlock = parentOp->getBlock();
43  }
44  llvm_unreachable("expected valid insertion region");
45 }
46 
47 /// A utility function used to materialize a constant for a given attribute and
48 /// type. On success, a valid constant value is returned. Otherwise, null is
49 /// returned
50 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
51  Attribute value, Type type,
52  Location loc) {
53  auto insertPt = builder.getInsertionPoint();
54  (void)insertPt;
55 
56  // Ask the dialect to materialize a constant operation for this value.
57  if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
58  assert(insertPt == builder.getInsertionPoint());
59  assert(matchPattern(constOp, m_Constant()));
60  return constOp;
61  }
62 
63  return nullptr;
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // OperationFolder
68 //===----------------------------------------------------------------------===//
69 
71  if (inPlaceUpdate)
72  *inPlaceUpdate = false;
73 
74  // If this is a unique'd constant, return failure as we know that it has
75  // already been folded.
76  if (isFolderOwnedConstant(op)) {
77  // Check to see if we should rehoist, i.e. if a non-constant operation was
78  // inserted before this one.
79  Block *opBlock = op->getBlock();
80  if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
81  op->moveBefore(&opBlock->front());
82  return failure();
83  }
84 
85  // Try to fold the operation.
86  SmallVector<Value, 8> results;
87  if (failed(tryToFold(op, results)))
88  return failure();
89 
90  // Check to see if the operation was just updated in place.
91  if (results.empty()) {
92  if (inPlaceUpdate)
93  *inPlaceUpdate = true;
94  if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(
95  rewriter.getListener())) {
96  // Folding API does not notify listeners, so we have to notify manually.
97  rewriteListener->notifyOperationModified(op);
98  }
99  return success();
100  }
101 
102  // Constant folding succeeded. Replace all of the result values and erase the
103  // operation.
104  notifyRemoval(op);
105  rewriter.replaceOp(op, results);
106  return success();
107 }
108 
110  Block *opBlock = op->getBlock();
111 
112  // If this is a constant we unique'd, we don't need to insert, but we can
113  // check to see if we should rehoist it.
114  if (isFolderOwnedConstant(op)) {
115  if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
116  op->moveBefore(&opBlock->front());
117  return true;
118  }
119 
120  // Get the constant value of the op if necessary.
121  if (!constValue) {
122  matchPattern(op, m_Constant(&constValue));
123  assert(constValue && "expected `op` to be a constant");
124  } else {
125  // Ensure that the provided constant was actually correct.
126 #ifndef NDEBUG
127  Attribute expectedValue;
128  matchPattern(op, m_Constant(&expectedValue));
129  assert(
130  expectedValue == constValue &&
131  "provided constant value was not the expected value of the constant");
132 #endif
133  }
134 
135  // Check for an existing constant operation for the attribute value.
136  Region *insertRegion = getInsertionRegion(interfaces, opBlock);
137  auto &uniquedConstants = foldScopes[insertRegion];
138  Operation *&folderConstOp = uniquedConstants[std::make_tuple(
139  op->getDialect(), constValue, *op->result_type_begin())];
140 
141  // If there is an existing constant, replace `op`.
142  if (folderConstOp) {
143  notifyRemoval(op);
144  rewriter.replaceOp(op, folderConstOp->getResults());
145  return false;
146  }
147 
148  // Otherwise, we insert `op`. If `op` is in the insertion block and is either
149  // already at the front of the block, or the previous operation is already a
150  // constant we unique'd (i.e. one we inserted), then we don't need to do
151  // anything. Otherwise, we move the constant to the insertion block.
152  Block *insertBlock = &insertRegion->front();
153  if (opBlock != insertBlock || (&insertBlock->front() != op &&
154  !isFolderOwnedConstant(op->getPrevNode())))
155  op->moveBefore(&insertBlock->front());
156 
157  folderConstOp = op;
158  referencedDialects[op].push_back(op->getDialect());
159  return true;
160 }
161 
162 /// Notifies that the given constant `op` should be remove from this
163 /// OperationFolder's internal bookkeeping.
165  // Check to see if this operation is uniqued within the folder.
166  auto it = referencedDialects.find(op);
167  if (it == referencedDialects.end())
168  return;
169 
170  // Get the constant value for this operation, this is the value that was used
171  // to unique the operation internally.
172  Attribute constValue;
173  matchPattern(op, m_Constant(&constValue));
174  assert(constValue);
175 
176  // Get the constant map that this operation was uniqued in.
177  auto &uniquedConstants =
178  foldScopes[getInsertionRegion(interfaces, op->getBlock())];
179 
180  // Erase all of the references to this operation.
181  auto type = op->getResult(0).getType();
182  for (auto *dialect : it->second)
183  uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
184  referencedDialects.erase(it);
185 }
186 
187 /// Clear out any constants cached inside of the folder.
189  foldScopes.clear();
190  referencedDialects.clear();
191 }
192 
193 /// Get or create a constant using the given builder. On success this returns
194 /// the constant operation, nullptr otherwise.
196  Attribute value, Type type,
197  Location loc) {
198  // Find an insertion point for the constant.
199  auto *insertRegion = getInsertionRegion(interfaces, block);
200  auto &entry = insertRegion->front();
201  rewriter.setInsertionPoint(&entry, entry.begin());
202 
203  // Get the constant map for the insertion region of this operation.
204  auto &uniquedConstants = foldScopes[insertRegion];
205  Operation *constOp =
206  tryGetOrCreateConstant(uniquedConstants, dialect, value, type, loc);
207  return constOp ? constOp->getResult(0) : Value();
208 }
209 
210 bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
211  return referencedDialects.count(op);
212 }
213 
214 /// Tries to perform folding on the given `op`. If successful, populates
215 /// `results` with the results of the folding.
217  SmallVectorImpl<Value> &results) {
218  SmallVector<OpFoldResult, 8> foldResults;
219  if (failed(op->fold(foldResults)) ||
220  failed(processFoldResults(op, results, foldResults)))
221  return failure();
222  return success();
223 }
224 
226 OperationFolder::processFoldResults(Operation *op,
227  SmallVectorImpl<Value> &results,
228  ArrayRef<OpFoldResult> foldResults) {
229  // Check to see if the operation was just updated in place.
230  if (foldResults.empty())
231  return success();
232  assert(foldResults.size() == op->getNumResults());
233 
234  // Create a builder to insert new operations into the entry block of the
235  // insertion region.
236  auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
237  auto &entry = insertRegion->front();
238  rewriter.setInsertionPoint(&entry, entry.begin());
239 
240  // Get the constant map for the insertion region of this operation.
241  auto &uniquedConstants = foldScopes[insertRegion];
242 
243  // Create the result constants and replace the results.
244  auto *dialect = op->getDialect();
245  for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
246  assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
247 
248  // Check if the result was an SSA value.
249  if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
250  if (repl.getType() != op->getResult(i).getType()) {
251  results.clear();
252  return failure();
253  }
254  results.emplace_back(repl);
255  continue;
256  }
257 
258  // Check to see if there is a canonicalized version of this constant.
259  auto res = op->getResult(i);
260  Attribute attrRepl = foldResults[i].get<Attribute>();
261  if (auto *constOp = tryGetOrCreateConstant(
262  uniquedConstants, dialect, attrRepl, res.getType(), op->getLoc())) {
263  // Ensure that this constant dominates the operation we are replacing it
264  // with. This may not automatically happen if the operation being folded
265  // was inserted before the constant within the insertion block.
266  Block *opBlock = op->getBlock();
267  if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
268  constOp->moveBefore(&opBlock->front());
269 
270  results.push_back(constOp->getResult(0));
271  continue;
272  }
273  // If materialization fails, cleanup any operations generated for the
274  // previous results and return failure.
275  for (Operation &op : llvm::make_early_inc_range(
276  llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) {
277  notifyRemoval(&op);
278  rewriter.eraseOp(&op);
279  }
280 
281  results.clear();
282  return failure();
283  }
284 
285  return success();
286 }
287 
288 /// Try to get or create a new constant entry. On success this returns the
289 /// constant operation value, nullptr otherwise.
290 Operation *
291 OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
292  Dialect *dialect, Attribute value,
293  Type type, Location loc) {
294  // Check if an existing mapping already exists.
295  auto constKey = std::make_tuple(dialect, value, type);
296  Operation *&constOp = uniquedConstants[constKey];
297  if (constOp)
298  return constOp;
299 
300  // If one doesn't exist, try to materialize one.
301  if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc)))
302  return nullptr;
303 
304  // Check to see if the generated constant is in the expected dialect.
305  auto *newDialect = constOp->getDialect();
306  if (newDialect == dialect) {
307  referencedDialects[constOp].push_back(dialect);
308  return constOp;
309  }
310 
311  // If it isn't, then we also need to make sure that the mapping for the new
312  // dialect is valid.
313  auto newKey = std::make_tuple(newDialect, value, type);
314 
315  // If an existing operation in the new dialect already exists, delete the
316  // materialized operation in favor of the existing one.
317  if (auto *existingOp = uniquedConstants.lookup(newKey)) {
318  notifyRemoval(constOp);
319  rewriter.eraseOp(constOp);
320  referencedDialects[existingOp].push_back(dialect);
321  return constOp = existingOp;
322  }
323 
324  // Otherwise, update the new dialect to the materialized operation.
325  referencedDialects[constOp].assign({dialect, newDialect});
326  auto newIt = uniquedConstants.insert({newKey, constOp});
327  return newIt.first->second;
328 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static Region * getInsertionRegion(DialectInterfaceCollection< DialectFoldInterface > &interfaces, Block *insertionBlock)
Given an operation, find the parent region that folded constants should be inserted into.
Definition: FoldUtils.cpp:25
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation & front()
Definition: Block.h:146
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
Definition: Block.cpp:53
A collection of dialect interfaces within a context, for a given concrete interface type.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:86
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:206
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:430
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:305
This class provides the API for ops that are known to be isolated from above.
Value getOrCreateConstant(Block *block, Dialect *dialect, Attribute value, Type type, Location loc)
Get or create a constant for use in the specified block.
Definition: FoldUtils.cpp:195
void clear()
Clear out any constants cached inside of the folder.
Definition: FoldUtils.cpp:188
LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate=nullptr)
Tries to perform folding on the given op, including unifying deduplicated constants.
Definition: FoldUtils.cpp:70
void notifyRemoval(Operation *op)
Notifies that the given constant op should be removed from this OperationFolder's internal bookkeeping...
Definition: FoldUtils.cpp:164
bool insertKnownConstant(Operation *op, Attribute constValue={})
Tries to fold a pre-existing constant operation.
Definition: FoldUtils.cpp:109
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:610
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
result_type_iterator result_type_begin()
Definition: Operation.h:421
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:554
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26