MLIR  15.0.0git
SCCP.cpp
Go to the documentation of this file.
1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass performs a sparse conditional constant propagation
10 // in MLIR. It identifies values known to be constant, propagates that
11 // information throughout the IR, and replaces them. This is done with an
12 // optimistic dataflow analysis that assumes that all values are constant until
13 // proven otherwise.
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "PassDetail.h"
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
27 
28 #define DEBUG_TYPE "sccp"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // SCCP Analysis
34 //===----------------------------------------------------------------------===//
35 
36 namespace {
37 struct SCCPLatticeValue {
38  SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr)
39  : constant(constant), constantDialect(dialect) {}
40 
41  /// The pessimistic state of SCCP is non-constant.
42  static SCCPLatticeValue getPessimisticValueState(MLIRContext *context) {
43  return SCCPLatticeValue();
44  }
45  static SCCPLatticeValue getPessimisticValueState(Value value) {
46  return SCCPLatticeValue();
47  }
48 
49  /// Equivalence for SCCP only accounts for the constant, not the originating
50  /// dialect.
51  bool operator==(const SCCPLatticeValue &rhs) const {
52  return constant == rhs.constant;
53  }
54 
55  /// To join the state of two values, we simply check for equivalence.
56  static SCCPLatticeValue join(const SCCPLatticeValue &lhs,
57  const SCCPLatticeValue &rhs) {
58  return lhs == rhs ? lhs : SCCPLatticeValue();
59  }
60 
61  /// The constant attribute value.
62  Attribute constant;
63 
64  /// The dialect the constant originated from. This is not used as part of the
65  /// key, and is only needed to materialize the held constant if necessary.
66  Dialect *constantDialect;
67 };
68 
69 struct SCCPAnalysis : public ForwardDataFlowAnalysis<SCCPLatticeValue> {
71  ~SCCPAnalysis() override = default;
72 
74  visitOperation(Operation *op,
75  ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) final {
76 
77  LLVM_DEBUG(llvm::dbgs() << "SCCP: Visiting operation: " << *op << "\n");
78 
79  // Don't try to simulate the results of a region operation as we can't
80  // guarantee that folding will be out-of-place. We don't allow in-place
81  // folds as the desire here is for simulated execution, and not general
82  // folding.
83  if (op->getNumRegions())
84  return markAllPessimisticFixpoint(op->getResults());
85 
86  SmallVector<Attribute> constantOperands(
87  llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
88  return value->getValue().constant;
89  }));
90 
91  // Save the original operands and attributes just in case the operation
92  // folds in-place. The constant passed in may not correspond to the real
93  // runtime value, so in-place updates are not allowed.
94  SmallVector<Value, 8> originalOperands(op->getOperands());
95  DictionaryAttr originalAttrs = op->getAttrDictionary();
96 
97  // Simulate the result of folding this operation to a constant. If folding
98  // fails or was an in-place fold, mark the results as overdefined.
99  SmallVector<OpFoldResult, 8> foldResults;
100  foldResults.reserve(op->getNumResults());
101  if (failed(op->fold(constantOperands, foldResults)))
102  return markAllPessimisticFixpoint(op->getResults());
103 
104  // If the folding was in-place, mark the results as overdefined and reset
105  // the operation. We don't allow in-place folds as the desire here is for
106  // simulated execution, and not general folding.
107  if (foldResults.empty()) {
108  op->setOperands(originalOperands);
109  op->setAttrs(originalAttrs);
110  return markAllPessimisticFixpoint(op->getResults());
111  }
112 
113  // Merge the fold results into the lattice for this operation.
114  assert(foldResults.size() == op->getNumResults() && "invalid result size");
115  Dialect *dialect = op->getDialect();
117  for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
119  getLatticeElement(op->getResult(i));
120 
121  // Merge in the result of the fold, either a constant or a value.
122  OpFoldResult foldResult = foldResults[i];
123  if (Attribute attr = foldResult.dyn_cast<Attribute>())
124  result |= lattice.join(SCCPLatticeValue(attr, dialect));
125  else
126  result |= lattice.join(getLatticeElement(foldResult.get<Value>()));
127  }
128  return result;
129  }
130 
131  /// Implementation of `getSuccessorsForOperands` that uses constant operands
132  /// to potentially remove dead successors.
133  LogicalResult getSuccessorsForOperands(
134  BranchOpInterface branch,
136  SmallVectorImpl<Block *> &successors) final {
137  SmallVector<Attribute> constantOperands(
138  llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
139  return value->getValue().constant;
140  }));
141  if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) {
142  successors.push_back(singleSucc);
143  return success();
144  }
145  return failure();
146  }
147 
148  /// Implementation of `getSuccessorsForOperands` that uses constant operands
149  /// to potentially remove dead region successors.
150  void getSuccessorsForOperands(
151  RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
153  SmallVectorImpl<RegionSuccessor> &successors) final {
154  SmallVector<Attribute> constantOperands(
155  llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
156  return value->getValue().constant;
157  }));
158  branch.getSuccessorRegions(sourceIndex, constantOperands, successors);
159  }
160 };
161 } // namespace
162 
163 //===----------------------------------------------------------------------===//
164 // SCCP Rewrites
165 //===----------------------------------------------------------------------===//
166 
167 /// Replace the given value with a constant if the corresponding lattice
168 /// represents a constant. Returns success if the value was replaced, failure
169 /// otherwise.
170 static LogicalResult replaceWithConstant(SCCPAnalysis &analysis,
171  OpBuilder &builder,
172  OperationFolder &folder, Value value) {
174  analysis.lookupLatticeElement(value);
175  if (!lattice)
176  return failure();
177  SCCPLatticeValue &latticeValue = lattice->getValue();
178  if (!latticeValue.constant)
179  return failure();
180 
181  // Attempt to materialize a constant for the given value.
182  Dialect *dialect = latticeValue.constantDialect;
183  Value constant = folder.getOrCreateConstant(
184  builder, dialect, latticeValue.constant, value.getType(), value.getLoc());
185  if (!constant)
186  return failure();
187 
188  value.replaceAllUsesWith(constant);
189  return success();
190 }
191 
192 /// Rewrite the given regions using the computing analysis. This replaces the
193 /// uses of all values that have been computed to be constant, and erases as
194 /// many newly dead operations.
195 static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
196  MutableArrayRef<Region> initialRegions) {
197  SmallVector<Block *> worklist;
198  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
199  for (Region &region : regions)
200  for (Block &block : llvm::reverse(region))
201  worklist.push_back(&block);
202  };
203 
204  // An operation folder used to create and unique constants.
205  OperationFolder folder(context);
206  OpBuilder builder(context);
207 
208  addToWorklist(initialRegions);
209  while (!worklist.empty()) {
210  Block *block = worklist.pop_back_val();
211 
212  for (Operation &op : llvm::make_early_inc_range(*block)) {
213  builder.setInsertionPoint(&op);
214 
215  // Replace any result with constants.
216  bool replacedAll = op.getNumResults() != 0;
217  for (Value res : op.getResults())
218  replacedAll &=
219  succeeded(replaceWithConstant(analysis, builder, folder, res));
220 
221  // If all of the results of the operation were replaced, try to erase
222  // the operation completely.
223  if (replacedAll && wouldOpBeTriviallyDead(&op)) {
224  assert(op.use_empty() && "expected all uses to be replaced");
225  op.erase();
226  continue;
227  }
228 
229  // Add any the regions of this operation to the worklist.
230  addToWorklist(op.getRegions());
231  }
232 
233  // Replace any block arguments with constants.
234  builder.setInsertionPointToStart(block);
235  for (BlockArgument arg : block->getArguments())
236  (void)replaceWithConstant(analysis, builder, folder, arg);
237  }
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // SCCP Pass
242 //===----------------------------------------------------------------------===//
243 
244 namespace {
245 struct SCCP : public SCCPBase<SCCP> {
246  void runOnOperation() override;
247 };
248 } // namespace
249 
250 void SCCP::runOnOperation() {
251  Operation *op = getOperation();
252 
253  SCCPAnalysis analysis(op->getContext());
254  analysis.run(op);
255  rewrite(analysis, op->getContext(), op->getRegions());
256 }
257 
258 std::unique_ptr<Pass> mlir::createSCCPPass() {
259  return std::make_unique<SCCP>();
260 }
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:478
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:302
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:475
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:200
Block represents an ordered list of Operations.
Definition: Block.h:29
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:336
This class represents a single result from folding an operation.
Definition: OpDefinition.h:229
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
This class represents a lattice holding a specific value of type ValueT.
static LogicalResult replaceWithConstant(SCCPAnalysis &analysis, OpBuilder &builder, OperationFolder &folder, Value value)
Replace the given value with a constant if the corresponding lattice represents a constant...
Definition: SCCP.cpp:170
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:162
static constexpr const bool value
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:154
std::unique_ptr< Pass > createSCCPPass()
Creates a pass which performs sparse conditional constant propagation over nested operations...
Definition: SCCP.cpp:258
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect, Attribute value, Type type, Location loc)
Get or create a constant using the given builder.
Definition: FoldUtils.cpp:199
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:331
static void rewrite(SCCPAnalysis &analysis, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:195
BlockArgListType getArguments()
Definition: Block.h:76
This class represents an argument of a Block.
Definition: Value.h:300
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.h:366
ChangeResult join(const detail::AbstractLatticeElement &rhs) final
Join the information contained in the 'rhs' lattice into this lattice.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
A utility class for folding operations, and unifying duplicated constants generated along the way...
Definition: FoldUtils.h:32
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:369
Type getType() const
Return the type of this value.
Definition: Value.h:118
U dyn_cast() const
Definition: Attributes.h:124
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:158
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:496
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:328
ValueT & getValue()
Return the value held by this lattice.
ChangeResult
A result type used to indicate if a change happened.
void setAttrs(DictionaryAttr newAttrs)
Set the attribute dictionary on this operation.
Definition: Operation.h:369
result_range getResults()
Definition: Operation.h:339
This class helps build Operations.
Definition: Builders.h:184
This class provides a general forward dataflow analysis driver utilizing the lattice classes defined ...
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...