MLIR  16.0.0git
FoldUtils.cpp
Go to the documentation of this file.
1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various operation fold utilities. These utilities are
10 // intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Operation.h"
19 
20 using namespace mlir;
21 
22 /// Given an operation, find the parent region that folded constants should be
23 /// inserted into.
24 static Region *
26  Block *insertionBlock) {
27  while (Region *region = insertionBlock->getParent()) {
28  // Insert in this region for any of the following scenarios:
29  // * The parent is unregistered, or is known to be isolated from above.
30  // * The parent is a top-level operation.
31  auto *parentOp = region->getParentOp();
32  if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
33  !parentOp->getBlock())
34  return region;
35 
36  // Otherwise, check if this region is a desired insertion region.
37  auto *interface = interfaces.getInterfaceFor(parentOp);
38  if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
39  return region;
40 
41  // Traverse up the parent looking for an insertion region.
42  insertionBlock = parentOp->getBlock();
43  }
44  llvm_unreachable("expected valid insertion region");
45 }
46 
47 /// A utility function used to materialize a constant for a given attribute and
48 /// type. On success, a valid constant value is returned. Otherwise, null is
49 /// returned
50 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
51  Attribute value, Type type,
52  Location loc) {
53  auto insertPt = builder.getInsertionPoint();
54  (void)insertPt;
55 
56  // Ask the dialect to materialize a constant operation for this value.
57  if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
58  assert(insertPt == builder.getInsertionPoint());
59  assert(matchPattern(constOp, m_Constant()));
60  return constOp;
61  }
62 
63  return nullptr;
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // OperationFolder
68 //===----------------------------------------------------------------------===//
69 
71  Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
72  function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
73  if (inPlaceUpdate)
74  *inPlaceUpdate = false;
75 
76  // If this is a unique'd constant, return failure as we know that it has
77  // already been folded.
78  if (isFolderOwnedConstant(op)) {
79  // Check to see if we should rehoist, i.e. if a non-constant operation was
80  // inserted before this one.
81  Block *opBlock = op->getBlock();
82  if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
83  op->moveBefore(&opBlock->front());
84  return failure();
85  }
86 
87  // Try to fold the operation.
88  SmallVector<Value, 8> results;
89  OpBuilder builder(op);
90  if (failed(tryToFold(builder, op, results, processGeneratedConstants)))
91  return failure();
92 
93  // Check to see if the operation was just updated in place.
94  if (results.empty()) {
95  if (inPlaceUpdate)
96  *inPlaceUpdate = true;
97  return success();
98  }
99 
100  // Constant folding succeeded. We will start replacing this op's uses and
101  // erase this op. Invoke the callback provided by the caller to perform any
102  // pre-replacement action.
103  if (preReplaceAction)
104  preReplaceAction(op);
105 
106  // Replace all of the result values and erase the operation.
107  for (unsigned i = 0, e = results.size(); i != e; ++i)
108  op->getResult(i).replaceAllUsesWith(results[i]);
109  op->erase();
110  return success();
111 }
112 
114  Block *opBlock = op->getBlock();
115 
116  // If this is a constant we unique'd, we don't need to insert, but we can
117  // check to see if we should rehoist it.
118  if (isFolderOwnedConstant(op)) {
119  if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
120  op->moveBefore(&opBlock->front());
121  return true;
122  }
123 
124  // Get the constant value of the op if necessary.
125  if (!constValue) {
126  matchPattern(op, m_Constant(&constValue));
127  assert(constValue && "expected `op` to be a constant");
128  } else {
129  // Ensure that the provided constant was actually correct.
130 #ifndef NDEBUG
131  Attribute expectedValue;
132  matchPattern(op, m_Constant(&expectedValue));
133  assert(
134  expectedValue == constValue &&
135  "provided constant value was not the expected value of the constant");
136 #endif
137  }
138 
139  // Check for an existing constant operation for the attribute value.
140  Region *insertRegion = getInsertionRegion(interfaces, opBlock);
141  auto &uniquedConstants = foldScopes[insertRegion];
142  Operation *&folderConstOp = uniquedConstants[std::make_tuple(
143  op->getDialect(), constValue, *op->result_type_begin())];
144 
145  // If there is an existing constant, replace `op`.
146  if (folderConstOp) {
147  op->replaceAllUsesWith(folderConstOp);
148  op->erase();
149  return false;
150  }
151 
152  // Otherwise, we insert `op`. If `op` is in the insertion block and is either
153  // already at the front of the block, or the previous operation is already a
154  // constant we unique'd (i.e. one we inserted), then we don't need to do
155  // anything. Otherwise, we move the constant to the insertion block.
156  Block *insertBlock = &insertRegion->front();
157  if (opBlock != insertBlock || (&insertBlock->front() != op &&
158  !isFolderOwnedConstant(op->getPrevNode())))
159  op->moveBefore(&insertBlock->front());
160 
161  folderConstOp = op;
162  referencedDialects[op].push_back(op->getDialect());
163  return true;
164 }
165 
166 /// Notifies that the given constant `op` should be remove from this
167 /// OperationFolder's internal bookkeeping.
169  // Check to see if this operation is uniqued within the folder.
170  auto it = referencedDialects.find(op);
171  if (it == referencedDialects.end())
172  return;
173 
174  // Get the constant value for this operation, this is the value that was used
175  // to unique the operation internally.
176  Attribute constValue;
177  matchPattern(op, m_Constant(&constValue));
178  assert(constValue);
179 
180  // Get the constant map that this operation was uniqued in.
181  auto &uniquedConstants =
182  foldScopes[getInsertionRegion(interfaces, op->getBlock())];
183 
184  // Erase all of the references to this operation.
185  auto type = op->getResult(0).getType();
186  for (auto *dialect : it->second)
187  uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
188  referencedDialects.erase(it);
189 }
190 
191 /// Clear out any constants cached inside of the folder.
193  foldScopes.clear();
194  referencedDialects.clear();
195 }
196 
197 /// Get or create a constant using the given builder. On success this returns
198 /// the constant operation, nullptr otherwise.
200  Attribute value, Type type,
201  Location loc) {
202  OpBuilder::InsertionGuard foldGuard(builder);
203 
204  // Use the builder insertion block to find an insertion point for the
205  // constant.
206  auto *insertRegion =
207  getInsertionRegion(interfaces, builder.getInsertionBlock());
208  auto &entry = insertRegion->front();
209  builder.setInsertionPoint(&entry, entry.begin());
210 
211  // Get the constant map for the insertion region of this operation.
212  auto &uniquedConstants = foldScopes[insertRegion];
213  Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
214  builder, value, type, loc);
215  return constOp ? constOp->getResult(0) : Value();
216 }
217 
218 bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
219  return referencedDialects.count(op);
220 }
221 
222 /// Tries to perform folding on the given `op`. If successful, populates
223 /// `results` with the results of the folding.
225  OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
226  function_ref<void(Operation *)> processGeneratedConstants) {
227  SmallVector<Attribute, 8> operandConstants;
228 
229  // If this is a commutative operation, move constants to be trailing operands.
230  bool updatedOpOperands = false;
231  if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
232  auto isNonConstant = [&](OpOperand &o) {
233  return !matchPattern(o.get(), m_Constant());
234  };
235  auto *firstConstantIt =
236  llvm::find_if_not(op->getOpOperands(), isNonConstant);
237  auto *newConstantIt = std::stable_partition(
238  firstConstantIt, op->getOpOperands().end(), isNonConstant);
239 
240  // Remember if we actually moved anything.
241  updatedOpOperands = firstConstantIt != newConstantIt;
242  }
243 
244  // Check to see if any operands to the operation is constant and whether
245  // the operation knows how to constant fold itself.
246  operandConstants.assign(op->getNumOperands(), Attribute());
247  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
248  matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
249 
250  // Attempt to constant fold the operation. If we failed, check to see if we at
251  // least updated the operands of the operation. We treat this as an in-place
252  // fold.
253  SmallVector<OpFoldResult, 8> foldResults;
254  if (failed(op->fold(operandConstants, foldResults)) ||
255  failed(processFoldResults(builder, op, results, foldResults,
256  processGeneratedConstants)))
257  return success(updatedOpOperands);
258  return success();
259 }
260 
261 LogicalResult OperationFolder::processFoldResults(
262  OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
263  ArrayRef<OpFoldResult> foldResults,
264  function_ref<void(Operation *)> processGeneratedConstants) {
265  // Check to see if the operation was just updated in place.
266  if (foldResults.empty())
267  return success();
268  assert(foldResults.size() == op->getNumResults());
269 
270  // Create a builder to insert new operations into the entry block of the
271  // insertion region.
272  auto *insertRegion =
273  getInsertionRegion(interfaces, builder.getInsertionBlock());
274  auto &entry = insertRegion->front();
275  OpBuilder::InsertionGuard foldGuard(builder);
276  builder.setInsertionPoint(&entry, entry.begin());
277 
278  // Get the constant map for the insertion region of this operation.
279  auto &uniquedConstants = foldScopes[insertRegion];
280 
281  // Create the result constants and replace the results.
282  auto *dialect = op->getDialect();
283  for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
284  assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
285 
286  // Check if the result was an SSA value.
287  if (auto repl = foldResults[i].dyn_cast<Value>()) {
288  if (repl.getType() != op->getResult(i).getType()) {
289  results.clear();
290  return failure();
291  }
292  results.emplace_back(repl);
293  continue;
294  }
295 
296  // Check to see if there is a canonicalized version of this constant.
297  auto res = op->getResult(i);
298  Attribute attrRepl = foldResults[i].get<Attribute>();
299  if (auto *constOp =
300  tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
301  res.getType(), op->getLoc())) {
302  // Ensure that this constant dominates the operation we are replacing it
303  // with. This may not automatically happen if the operation being folded
304  // was inserted before the constant within the insertion block.
305  Block *opBlock = op->getBlock();
306  if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
307  constOp->moveBefore(&opBlock->front());
308 
309  results.push_back(constOp->getResult(0));
310  continue;
311  }
312  // If materialization fails, cleanup any operations generated for the
313  // previous results and return failure.
314  for (Operation &op : llvm::make_early_inc_range(
315  llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
316  notifyRemoval(&op);
317  op.erase();
318  }
319  results.clear();
320  return failure();
321  }
322 
323  // Process any newly generated operations.
324  if (processGeneratedConstants) {
325  for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
326  processGeneratedConstants(&*i);
327  }
328 
329  return success();
330 }
331 
332 /// Try to get or create a new constant entry. On success this returns the
333 /// constant operation value, nullptr otherwise.
334 Operation *OperationFolder::tryGetOrCreateConstant(
335  ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
336  Attribute value, Type type, Location loc) {
337  // Check if an existing mapping already exists.
338  auto constKey = std::make_tuple(dialect, value, type);
339  Operation *&constOp = uniquedConstants[constKey];
340  if (constOp)
341  return constOp;
342 
343  // If one doesn't exist, try to materialize one.
344  if (!(constOp = materializeConstant(dialect, builder, value, type, loc)))
345  return nullptr;
346 
347  // Check to see if the generated constant is in the expected dialect.
348  auto *newDialect = constOp->getDialect();
349  if (newDialect == dialect) {
350  referencedDialects[constOp].push_back(dialect);
351  return constOp;
352  }
353 
354  // If it isn't, then we also need to make sure that the mapping for the new
355  // dialect is valid.
356  auto newKey = std::make_tuple(newDialect, value, type);
357 
358  // If an existing operation in the new dialect already exists, delete the
359  // materialized operation in favor of the existing one.
360  if (auto *existingOp = uniquedConstants.lookup(newKey)) {
361  constOp->erase();
362  referencedDialects[existingOp].push_back(dialect);
363  return constOp = existingOp;
364  }
365 
366  // Otherwise, update the new dialect to the materialized operation.
367  referencedDialects[constOp].assign({dialect, newDialect});
368  auto newIt = uniquedConstants.insert({newKey, constOp});
369  return newIt.first->second;
370 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static Region * getInsertionRegion(DialectInterfaceCollection< DialectFoldInterface > &interfaces, Block *insertionBlock)
Given an operation, find the parent region that folded constants should be inserted into.
Definition: FoldUtils.cpp:25
static constexpr const bool value
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation & front()
Definition: Block.h:142
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
Definition: Block.cpp:47
A collection of dialect interfaces within a context, for a given concrete interface type.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:86
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
This class helps build Operations.
Definition: Builders.h:198
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:397
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:350
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:394
This class represents an operand of an operation.
Definition: Value.h:247
This class adds property that the operation is commutative.
This class provides the API for ops that are known to be isolated from above.
LogicalResult tryToFold(Operation *op, function_ref< void(Operation *)> processGeneratedConstants=nullptr, function_ref< void(Operation *)> preReplaceAction=nullptr, bool *inPlaceUpdate=nullptr)
Tries to perform folding on the given op, including unifying deduplicated constants.
Definition: FoldUtils.cpp:70
Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect, Attribute value, Type type, Location loc)
Get or create a constant using the given builder.
Definition: FoldUtils.cpp:199
void clear()
Clear out any constants cached inside of the folder.
Definition: FoldUtils.cpp:192
void notifyRemoval(Operation *op)
Notifies that the given constant op should be removed from this OperationFolder's internal bookkeeping...
Definition: FoldUtils.cpp:168
bool insertKnownConstant(Operation *op, Attribute constValue={})
Tries to fold a pre-existing constant operation.
Definition: FoldUtils.cpp:113
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:490
Value getOperand(unsigned idx)
Definition: Operation.h:267
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:532
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:151
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
unsigned getNumOperands()
Definition: Operation.h:263
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:144
result_type_iterator result_type_begin()
Definition: Operation.h:343
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:300
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:434
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:203
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:418
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:158
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:329
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:255
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26