MLIR  20.0.0git
SCCP.cpp
Go to the documentation of this file.
1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass performs a sparse conditional constant propagation
10 // in MLIR. It identifies values known to be constant, propagates that
11 // information throughout the IR, and replaces them. This is done with an
12 // optimistic dataflow analysis that assumes that all values are constant until
13 // proven otherwise.
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "mlir/Transforms/Passes.h"

#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_SCCP
29 #include "mlir/Transforms/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 using namespace mlir::dataflow;
34 
35 //===----------------------------------------------------------------------===//
36 // SCCP Rewrites
37 //===----------------------------------------------------------------------===//
38 
39 /// Replace the given value with a constant if the corresponding lattice
40 /// represents a constant. Returns success if the value was replaced, failure
41 /// otherwise.
42 static LogicalResult replaceWithConstant(DataFlowSolver &solver,
43  OpBuilder &builder,
44  OperationFolder &folder, Value value) {
45  auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
46  if (!lattice || lattice->getValue().isUninitialized())
47  return failure();
48  const ConstantValue &latticeValue = lattice->getValue();
49  if (!latticeValue.getConstantValue())
50  return failure();
51 
52  // Attempt to materialize a constant for the given value.
53  Dialect *dialect = latticeValue.getConstantDialect();
54  Value constant = folder.getOrCreateConstant(
55  builder.getInsertionBlock(), dialect, latticeValue.getConstantValue(),
56  value.getType());
57  if (!constant)
58  return failure();
59 
60  value.replaceAllUsesWith(constant);
61  return success();
62 }
63 
64 /// Rewrite the given regions using the computing analysis. This replaces the
65 /// uses of all values that have been computed to be constant, and erases as
66 /// many newly dead operations.
67 static void rewrite(DataFlowSolver &solver, MLIRContext *context,
68  MutableArrayRef<Region> initialRegions) {
69  SmallVector<Block *> worklist;
70  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
71  for (Region &region : regions)
72  for (Block &block : llvm::reverse(region))
73  worklist.push_back(&block);
74  };
75 
76  // An operation folder used to create and unique constants.
77  OperationFolder folder(context);
78  OpBuilder builder(context);
79 
80  addToWorklist(initialRegions);
81  while (!worklist.empty()) {
82  Block *block = worklist.pop_back_val();
83 
84  for (Operation &op : llvm::make_early_inc_range(*block)) {
85  builder.setInsertionPoint(&op);
86 
87  // Replace any result with constants.
88  bool replacedAll = op.getNumResults() != 0;
89  for (Value res : op.getResults())
90  replacedAll &=
91  succeeded(replaceWithConstant(solver, builder, folder, res));
92 
93  // If all of the results of the operation were replaced, try to erase
94  // the operation completely.
95  if (replacedAll && wouldOpBeTriviallyDead(&op)) {
96  assert(op.use_empty() && "expected all uses to be replaced");
97  op.erase();
98  continue;
99  }
100 
101  // Add any the regions of this operation to the worklist.
102  addToWorklist(op.getRegions());
103  }
104 
105  // Replace any block arguments with constants.
106  builder.setInsertionPointToStart(block);
107  for (BlockArgument arg : block->getArguments())
108  (void)replaceWithConstant(solver, builder, folder, arg);
109  }
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // SCCP Pass
114 //===----------------------------------------------------------------------===//
115 
116 namespace {
117 struct SCCP : public impl::SCCPBase<SCCP> {
118  void runOnOperation() override;
119 };
120 } // namespace
121 
122 void SCCP::runOnOperation() {
123  Operation *op = getOperation();
124 
125  DataFlowSolver solver;
126  solver.load<DeadCodeAnalysis>();
128  if (failed(solver.initializeAndRun(op)))
129  return signalPassFailure();
130  rewrite(solver, op->getContext(), op->getRegions());
131 }
132 
133 std::unique_ptr<Pass> mlir::createSCCPPass() {
134  return std::make_unique<SCCP>();
135 }
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:67
static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &builder, OperationFolder &folder, Value value)
Replace the given value with a constant if the corresponding lattice represents a constant.
Definition: SCCP.cpp:42
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgListType getArguments()
Definition: Block.h:87
The general data-flow analysis solver.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:451
A utility class for folding operations, and unifying duplicated constants generated along the way.
Definition: FoldUtils.h:33
Value getOrCreateConstant(Block *block, Dialect *dialect, Attribute value, Type type)
Get or create a constant for use in the specified block.
Definition: FoldUtils.cpp:202
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:173
This lattice value represents a known constant value of a lattice.
Attribute getConstantValue() const
Get the constant value. Returns null if no value was determined.
Dialect * getConstantDialect() const
Get the dialect instance that can be used to materialize the constant.
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
This class represents a lattice holding a specific value of type ValueT.
This analysis implements sparse constant propagation, which attempts to determine constant-valued res...
Include the generated interface declarations.
std::unique_ptr< Pass > createSCCPPass()
Creates a pass which performs sparse conditional constant propagation over nested operations.
Definition: SCCP.cpp:133
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...