MLIR  14.0.0git
FoldUtils.cpp
Go to the documentation of this file.
1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various operation fold utilities. These utilities are
10 // intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Operation.h"
19 
20 using namespace mlir;
21 
22 /// Given an operation, find the parent region that folded constants should be
23 /// inserted into.
24 static Region *
26  Block *insertionBlock) {
27  while (Region *region = insertionBlock->getParent()) {
28  // Insert in this region for any of the following scenarios:
29  // * The parent is unregistered, or is known to be isolated from above.
30  // * The parent is a top-level operation.
31  auto *parentOp = region->getParentOp();
32  if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
33  !parentOp->getBlock())
34  return region;
35 
36  // Otherwise, check if this region is a desired insertion region.
37  auto *interface = interfaces.getInterfaceFor(parentOp);
38  if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
39  return region;
40 
41  // Traverse up the parent looking for an insertion region.
42  insertionBlock = parentOp->getBlock();
43  }
44  llvm_unreachable("expected valid insertion region");
45 }
46 
47 /// A utility function used to materialize a constant for a given attribute and
48 /// type. On success, a valid constant value is returned. Otherwise, null is
49 /// returned
50 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
51  Attribute value, Type type,
52  Location loc) {
53  auto insertPt = builder.getInsertionPoint();
54  (void)insertPt;
55 
56  // Ask the dialect to materialize a constant operation for this value.
57  if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
58  assert(insertPt == builder.getInsertionPoint());
59  assert(matchPattern(constOp, m_Constant()));
60  return constOp;
61  }
62 
63  return nullptr;
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // OperationFolder
68 //===----------------------------------------------------------------------===//
69 
71  Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
72  function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
73  if (inPlaceUpdate)
74  *inPlaceUpdate = false;
75 
76  // If this is a unique'd constant, return failure as we know that it has
77  // already been folded.
78  if (referencedDialects.count(op))
79  return failure();
80 
81  // Try to fold the operation.
82  SmallVector<Value, 8> results;
83  OpBuilder builder(op);
84  if (failed(tryToFold(builder, op, results, processGeneratedConstants)))
85  return failure();
86 
87  // Check to see if the operation was just updated in place.
88  if (results.empty()) {
89  if (inPlaceUpdate)
90  *inPlaceUpdate = true;
91  return success();
92  }
93 
94  // Constant folding succeeded. We will start replacing this op's uses and
95  // erase this op. Invoke the callback provided by the caller to perform any
96  // pre-replacement action.
97  if (preReplaceAction)
98  preReplaceAction(op);
99 
100  // Replace all of the result values and erase the operation.
101  for (unsigned i = 0, e = results.size(); i != e; ++i)
102  op->getResult(i).replaceAllUsesWith(results[i]);
103  op->erase();
104  return success();
105 }
106 
107 /// Notifies that the given constant `op` should be remove from this
108 /// OperationFolder's internal bookkeeping.
110  // Check to see if this operation is uniqued within the folder.
111  auto it = referencedDialects.find(op);
112  if (it == referencedDialects.end())
113  return;
114 
115  // Get the constant value for this operation, this is the value that was used
116  // to unique the operation internally.
117  Attribute constValue;
118  matchPattern(op, m_Constant(&constValue));
119  assert(constValue);
120 
121  // Get the constant map that this operation was uniqued in.
122  auto &uniquedConstants =
123  foldScopes[getInsertionRegion(interfaces, op->getBlock())];
124 
125  // Erase all of the references to this operation.
126  auto type = op->getResult(0).getType();
127  for (auto *dialect : it->second)
128  uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
129  referencedDialects.erase(it);
130 }
131 
132 /// Clear out any constants cached inside of the folder.
134  foldScopes.clear();
135  referencedDialects.clear();
136 }
137 
138 /// Get or create a constant using the given builder. On success this returns
139 /// the constant operation, nullptr otherwise.
141  Attribute value, Type type,
142  Location loc) {
143  OpBuilder::InsertionGuard foldGuard(builder);
144 
145  // Use the builder insertion block to find an insertion point for the
146  // constant.
147  auto *insertRegion =
148  getInsertionRegion(interfaces, builder.getInsertionBlock());
149  auto &entry = insertRegion->front();
150  builder.setInsertionPoint(&entry, entry.begin());
151 
152  // Get the constant map for the insertion region of this operation.
153  auto &uniquedConstants = foldScopes[insertRegion];
154  Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
155  builder, value, type, loc);
156  return constOp ? constOp->getResult(0) : Value();
157 }
158 
159 /// Tries to perform folding on the given `op`. If successful, populates
160 /// `results` with the results of the folding.
162  OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
163  function_ref<void(Operation *)> processGeneratedConstants) {
164  SmallVector<Attribute, 8> operandConstants;
165  SmallVector<OpFoldResult, 8> foldResults;
166 
167  // If this is a commutative operation, move constants to be trailing operands.
168  if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
169  std::stable_partition(
170  op->getOpOperands().begin(), op->getOpOperands().end(),
171  [&](OpOperand &o) { return !matchPattern(o.get(), m_Constant()); });
172  }
173 
174  // Check to see if any operands to the operation is constant and whether
175  // the operation knows how to constant fold itself.
176  operandConstants.assign(op->getNumOperands(), Attribute());
177  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
178  matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
179 
180  // Attempt to constant fold the operation.
181  if (failed(op->fold(operandConstants, foldResults)))
182  return failure();
183 
184  // Check to see if the operation was just updated in place.
185  if (foldResults.empty())
186  return success();
187  assert(foldResults.size() == op->getNumResults());
188 
189  // Create a builder to insert new operations into the entry block of the
190  // insertion region.
191  auto *insertRegion =
192  getInsertionRegion(interfaces, builder.getInsertionBlock());
193  auto &entry = insertRegion->front();
194  OpBuilder::InsertionGuard foldGuard(builder);
195  builder.setInsertionPoint(&entry, entry.begin());
196 
197  // Get the constant map for the insertion region of this operation.
198  auto &uniquedConstants = foldScopes[insertRegion];
199 
200  // Create the result constants and replace the results.
201  auto *dialect = op->getDialect();
202  for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
203  assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
204 
205  // Check if the result was an SSA value.
206  if (auto repl = foldResults[i].dyn_cast<Value>()) {
207  if (repl.getType() != op->getResult(i).getType())
208  return failure();
209  results.emplace_back(repl);
210  continue;
211  }
212 
213  // Check to see if there is a canonicalized version of this constant.
214  auto res = op->getResult(i);
215  Attribute attrRepl = foldResults[i].get<Attribute>();
216  if (auto *constOp =
217  tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
218  res.getType(), op->getLoc())) {
219  // Ensure that this constant dominates the operation we are replacing it
220  // with. This may not automatically happen if the operation being folded
221  // was inserted before the constant within the insertion block.
222  Block *opBlock = op->getBlock();
223  if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
224  constOp->moveBefore(&opBlock->front());
225 
226  results.push_back(constOp->getResult(0));
227  continue;
228  }
229  // If materialization fails, cleanup any operations generated for the
230  // previous results and return failure.
231  for (Operation &op : llvm::make_early_inc_range(
232  llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
233  notifyRemoval(&op);
234  op.erase();
235  }
236  return failure();
237  }
238 
239  // Process any newly generated operations.
240  if (processGeneratedConstants) {
241  for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
242  processGeneratedConstants(&*i);
243  }
244 
245  return success();
246 }
247 
248 /// Try to get or create a new constant entry. On success this returns the
249 /// constant operation value, nullptr otherwise.
250 Operation *OperationFolder::tryGetOrCreateConstant(
251  ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
252  Attribute value, Type type, Location loc) {
253  // Check if an existing mapping already exists.
254  auto constKey = std::make_tuple(dialect, value, type);
255  Operation *&constOp = uniquedConstants[constKey];
256  if (constOp)
257  return constOp;
258 
259  // If one doesn't exist, try to materialize one.
260  if (!(constOp = materializeConstant(dialect, builder, value, type, loc)))
261  return nullptr;
262 
263  // Check to see if the generated constant is in the expected dialect.
264  auto *newDialect = constOp->getDialect();
265  if (newDialect == dialect) {
266  referencedDialects[constOp].push_back(dialect);
267  return constOp;
268  }
269 
270  // If it isn't, then we also need to make sure that the mapping for the new
271  // dialect is valid.
272  auto newKey = std::make_tuple(newDialect, value, type);
273 
274  // If an existing operation in the new dialect already exists, delete the
275  // materialized operation in favor of the existing one.
276  if (auto *existingOp = uniquedConstants.lookup(newKey)) {
277  constOp->erase();
278  referencedDialects[existingOp].push_back(dialect);
279  return constOp = existingOp;
280  }
281 
282  // Otherwise, update the new dialect to the materialized operation.
283  referencedDialects[constOp].assign({dialect, newDialect});
284  auto newIt = uniquedConstants.insert({newKey, constOp});
285  return newIt.first->second;
286 }
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:373
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
This class adds property that the operation is commutative.
Block represents an ordered list of Operations.
Definition: Block.h:29
Block & front()
Definition: Region.h:65
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:329
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block...
Definition: Block.cpp:47
Value getOperand(unsigned idx)
Definition: Operation.h:219
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumOperands()
Definition: Operation.h:215
A collection of dialect interfaces within a context, for a given concrete interface type...
Operation & front()
Definition: Block.h:144
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:96
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:161
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
static constexpr const bool value
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:424
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:252
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect, Attribute value, Type type, Location loc)
Get or create a constant using the given builder.
Definition: FoldUtils.cpp:140
Attributes are known-constant values of operations.
Definition: Attributes.h:24
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:470
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:42
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:276
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:231
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
Type getType() const
Return the type of this value.
Definition: Value.h:117
This class provides the API for ops that are known to be isolated from above.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:266
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:103
This class represents an operand of an operation.
Definition: Value.h:249
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:496
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
void notifyRemoval(Operation *op)
Notifies that the given constant op should be removed from this OperationFolder's internal bookkeeping...
Definition: FoldUtils.cpp:109
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:376
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:87
This class helps build Operations.
Definition: Builders.h:177
void clear()
Clear out any constants cached inside of the folder.
Definition: FoldUtils.cpp:133
LogicalResult tryToFold(Operation *op, function_ref< void(Operation *)> processGeneratedConstants=nullptr, function_ref< void(Operation *)> preReplaceAction=nullptr, bool *inPlaceUpdate=nullptr)
Tries to perform folding on the given op, including unifying deduplicated constants.
Definition: FoldUtils.cpp:70
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type. ...
Definition: FoldUtils.cpp:50
static Region * getInsertionRegion(DialectInterfaceCollection< DialectFoldInterface > &interfaces, Block *insertionBlock)
Given a block, find the parent region that folded constants should be inserted into...
Definition: FoldUtils.cpp:25