MLIR 22.0.0git
FoldUtils.cpp
Go to the documentation of this file.
1//===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines various operation fold utilities. These utilities are
10// intended to be used by passes to unify and simplify their logic.
11//
12//===----------------------------------------------------------------------===//
13
15
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/Operation.h"
19#include "llvm/Support/DebugLog.h"
20
21using namespace mlir;
22
23/// Given an operation, find the parent region that folded constants should be
24/// inserted into.
25static Region *
27 Block *insertionBlock) {
28 while (Region *region = insertionBlock->getParent()) {
29 // Insert in this region for any of the following scenarios:
30 // * The parent is unregistered, or is known to be isolated from above.
31 // * The parent is a top-level operation.
32 auto *parentOp = region->getParentOp();
33 if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
34 !parentOp->getBlock())
35 return region;
36
37 // Otherwise, check if this region is a desired insertion region.
38 auto *interface = interfaces.getInterfaceFor(parentOp);
39 if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
40 return region;
41
42 // Traverse up the parent looking for an insertion region.
43 insertionBlock = parentOp->getBlock();
44 }
45 llvm_unreachable("expected valid insertion region");
46}
47
48/// A utility function used to materialize a constant for a given attribute and
49/// type. On success, a valid constant value is returned. Otherwise, null is
50/// returned
52 Attribute value, Type type,
53 Location loc) {
54 auto insertPt = builder.getInsertionPoint();
55 (void)insertPt;
56
57 // Ask the dialect to materialize a constant operation for this value.
58 if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
59 assert(insertPt == builder.getInsertionPoint());
60 assert(matchPattern(constOp, m_Constant()));
61 return constOp;
62 }
63
64 return nullptr;
65}
66
67//===----------------------------------------------------------------------===//
68// OperationFolder
69//===----------------------------------------------------------------------===//
70
71LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate,
72 int maxIterations) {
73 if (inPlaceUpdate)
74 *inPlaceUpdate = false;
75
76 // If this is a unique'd constant, return failure as we know that it has
77 // already been folded.
78 if (isFolderOwnedConstant(op)) {
79 // Check to see if we should rehoist, i.e. if a non-constant operation was
80 // inserted before this one.
81 Block *opBlock = op->getBlock();
82 if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
83 op->moveBefore(&opBlock->front());
84 op->setLoc(erasedFoldedLocation);
85 }
86 return failure();
87 }
88
89 // Try to fold the operation.
91 if (failed(tryToFold(op, results, maxIterations)))
92 return failure();
93
94 // Check to see if the operation was just updated in place.
95 if (results.empty()) {
96 if (inPlaceUpdate)
97 *inPlaceUpdate = true;
98 if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(
99 rewriter.getListener())) {
100 // Folding API does not notify listeners, so we have to notify manually.
101 rewriteListener->notifyOperationModified(op);
102 }
103 return success();
104 }
105
106 // Constant folding succeeded. Replace all of the result values and erase the
107 // operation.
108 notifyRemoval(op);
109 rewriter.replaceOp(op, results);
110 return success();
111}
112
114 Block *opBlock = op->getBlock();
115
116 // If this is a constant we unique'd, we don't need to insert, but we can
117 // check to see if we should rehoist it.
118 if (isFolderOwnedConstant(op)) {
119 if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
120 op->moveBefore(&opBlock->front());
121 op->setLoc(erasedFoldedLocation);
122 }
123 return true;
124 }
125
126 // Get the constant value of the op if necessary.
127 if (!constValue) {
128 matchPattern(op, m_Constant(&constValue));
129 assert(constValue && "expected `op` to be a constant");
130 } else {
131 // Ensure that the provided constant was actually correct.
132#ifndef NDEBUG
133 Attribute expectedValue;
134 matchPattern(op, m_Constant(&expectedValue));
135 assert(
136 expectedValue == constValue &&
137 "provided constant value was not the expected value of the constant");
138#endif
139 }
140
141 // Check for an existing constant operation for the attribute value.
142 Region *insertRegion = getInsertionRegion(interfaces, opBlock);
143 auto &uniquedConstants = foldScopes[insertRegion];
144 Operation *&folderConstOp = uniquedConstants[std::make_tuple(
145 op->getDialect(), constValue, *op->result_type_begin())];
146
147 // If there is an existing constant, replace `op`.
148 if (folderConstOp) {
149 notifyRemoval(op);
150 rewriter.replaceOp(op, folderConstOp->getResults());
151 folderConstOp->setLoc(erasedFoldedLocation);
152 return false;
153 }
154
155 // Otherwise, we insert `op`. If `op` is in the insertion block and is either
156 // already at the front of the block, or the previous operation is already a
157 // constant we unique'd (i.e. one we inserted), then we don't need to do
158 // anything. Otherwise, we move the constant to the insertion block.
159 // The location info is erased if the constant is moved to a different block.
160 Block *insertBlock = &insertRegion->front();
161 if (opBlock != insertBlock) {
162 op->moveBefore(&insertBlock->front());
163 op->setLoc(erasedFoldedLocation);
164 } else if (&insertBlock->front() != op &&
165 !isFolderOwnedConstant(op->getPrevNode())) {
166 op->moveBefore(&insertBlock->front());
167 }
168
169 folderConstOp = op;
170 referencedDialects[op].push_back(op->getDialect());
171 return true;
172}
173
174/// Notifies that the given constant `op` should be remove from this
175/// OperationFolder's internal bookkeeping.
177 // Check to see if this operation is uniqued within the folder.
178 auto it = referencedDialects.find(op);
179 if (it == referencedDialects.end())
180 return;
181
182 // Get the constant value for this operation, this is the value that was used
183 // to unique the operation internally.
184 Attribute constValue;
185 matchPattern(op, m_Constant(&constValue));
186 assert(constValue);
187
188 // Get the constant map that this operation was uniqued in.
189 auto &uniquedConstants =
190 foldScopes[getInsertionRegion(interfaces, op->getBlock())];
191
192 // Erase all of the references to this operation.
193 auto type = op->getResult(0).getType();
194 for (auto *dialect : it->second)
195 uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
196 referencedDialects.erase(it);
197}
198
199/// Clear out any constants cached inside of the folder.
201 foldScopes.clear();
202 referencedDialects.clear();
203}
204
205/// Get or create a constant using the given builder. On success this returns
206/// the constant operation, nullptr otherwise.
208 Attribute value, Type type) {
209 // Find an insertion point for the constant.
210 auto *insertRegion = getInsertionRegion(interfaces, block);
211 auto &entry = insertRegion->front();
212 rewriter.setInsertionPointToStart(&entry);
213
214 // Get the constant map for the insertion region of this operation.
215 // Use erased location since the op is being built at the front of block.
216 auto &uniquedConstants = foldScopes[insertRegion];
217 Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, value,
218 type, erasedFoldedLocation);
219 return constOp ? constOp->getResult(0) : Value();
220}
221
222bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
223 return referencedDialects.count(op);
224}
225
226/// Tries to perform folding on the given `op`. If successful, populates
227/// `results` with the results of the folding.
228LogicalResult OperationFolder::tryToFold(Operation *op,
229 SmallVectorImpl<Value> &results,
230 int maxIterations) {
232 if (failed(op->fold(foldResults)))
233 return failure();
234 int count = 1;
235 do {
236 LDBG() << "Folded in place #" << count
237 << " times: " << OpWithFlags(op, OpPrintingFlags().skipRegions());
238 } while (count++ < maxIterations && foldResults.empty() &&
239 succeeded(op->fold(foldResults)));
240
241 if (failed(processFoldResults(op, results, foldResults)))
242 return failure();
243 return success();
244}
245
246LogicalResult
247OperationFolder::processFoldResults(Operation *op,
248 SmallVectorImpl<Value> &results,
249 ArrayRef<OpFoldResult> foldResults) {
250 // Check to see if the operation was just updated in place.
251 if (foldResults.empty())
252 return success();
253 assert(foldResults.size() == op->getNumResults());
254
255 // Create a builder to insert new operations into the entry block of the
256 // insertion region.
257 auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
258 auto &entry = insertRegion->front();
259 rewriter.setInsertionPointToStart(&entry);
260
261 // Get the constant map for the insertion region of this operation.
262 auto &uniquedConstants = foldScopes[insertRegion];
263
264 // Create the result constants and replace the results.
265 auto *dialect = op->getDialect();
266 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
267 assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
268
269 // Check if the result was an SSA value.
270 if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
271 results.emplace_back(repl);
272 continue;
273 }
274
275 // Check to see if there is a canonicalized version of this constant.
276 auto res = op->getResult(i);
277 Attribute attrRepl = cast<Attribute>(foldResults[i]);
278 if (auto *constOp =
279 tryGetOrCreateConstant(uniquedConstants, dialect, attrRepl,
280 res.getType(), erasedFoldedLocation)) {
281 // Ensure that this constant dominates the operation we are replacing it
282 // with. This may not automatically happen if the operation being folded
283 // was inserted before the constant within the insertion block.
284 Block *opBlock = op->getBlock();
285 if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
286 constOp->moveBefore(&opBlock->front());
287
288 results.push_back(constOp->getResult(0));
289 continue;
290 }
291 // If materialization fails, cleanup any operations generated for the
292 // previous results and return failure.
293 for (Operation &op : llvm::make_early_inc_range(
294 llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) {
295 notifyRemoval(&op);
296 rewriter.eraseOp(&op);
297 }
298
299 results.clear();
300 return failure();
301 }
302
303 return success();
304}
305
306/// Try to get or create a new constant entry. On success this returns the
307/// constant operation value, nullptr otherwise.
308Operation *
309OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
310 Dialect *dialect, Attribute value,
311 Type type, Location loc) {
312 // Check if an existing mapping already exists.
313 auto constKey = std::make_tuple(dialect, value, type);
314 Operation *&constOp = uniquedConstants[constKey];
315 if (constOp) {
316 if (loc != constOp->getLoc())
317 constOp->setLoc(erasedFoldedLocation);
318 return constOp;
319 }
320
321 // If one doesn't exist, try to materialize one.
322 if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc)))
323 return nullptr;
324
325 // Check to see if the generated constant is in the expected dialect.
326 auto *newDialect = constOp->getDialect();
327 if (newDialect == dialect) {
328 referencedDialects[constOp].push_back(dialect);
329 return constOp;
330 }
331
332 // If it isn't, then we also need to make sure that the mapping for the new
333 // dialect is valid.
334 auto newKey = std::make_tuple(newDialect, value, type);
335
336 // If an existing operation in the new dialect already exists, delete the
337 // materialized operation in favor of the existing one.
338 if (auto *existingOp = uniquedConstants.lookup(newKey)) {
339 notifyRemoval(constOp);
340 rewriter.eraseOp(constOp);
341 referencedDialects[existingOp].push_back(dialect);
342 if (loc != existingOp->getLoc())
343 existingOp->setLoc(erasedFoldedLocation);
344 return constOp = existingOp;
345 }
346
347 // Otherwise, update the new dialect to the materialized operation.
348 referencedDialects[constOp].assign({dialect, newDialect});
349 auto newIt = uniquedConstants.insert({newKey, constOp});
350 return newIt.first->second;
351}
return success()
static Region * getInsertionRegion(DialectInterfaceCollection< DialectFoldInterface > &interfaces, Block *insertionBlock)
Given an operation, find the parent region that folded constants should be inserted into.
Definition FoldUtils.cpp:26
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition FoldUtils.cpp:51
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:153
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
Definition Block.cpp:54
A collection of dialect interfaces within a context, for a given concrete interface type.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition Builders.h:445
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
Value getOrCreateConstant(Block *block, Dialect *dialect, Attribute value, Type type)
Get or create a constant for use in the specified block.
void clear()
Clear out any constants cached inside of the folder.
LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate=nullptr, int maxIterations=INT_MAX)
Tries to perform folding on the given op, including unifying deduplicated constants.
Definition FoldUtils.cpp:71
void notifyRemoval(Operation *op)
Notifies that the given constant op should be removed from this OperationFolder's internal bookkeeping...
bool insertKnownConstant(Operation *op, Attribute constValue={})
Tries to fold a pre-existing constant operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:226
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
result_type_iterator result_type_begin()
Definition Operation.h:426
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369