MLIR  19.0.0git
SCCP.cpp
Go to the documentation of this file.
1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass performs a sparse conditional constant propagation
10 // in MLIR. It identifies values known to be constant, propagates that
11 // information throughout the IR, and replaces them. This is done with an
12 // optimistic dataflow analysis that assumes that all values are constant until
13 // proven otherwise.
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "mlir/Transforms/Passes.h"

#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_SCCP
29 #include "mlir/Transforms/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 using namespace mlir::dataflow;
34 
35 //===----------------------------------------------------------------------===//
36 // SCCP Rewrites
37 //===----------------------------------------------------------------------===//
38 
39 /// Replace the given value with a constant if the corresponding lattice
40 /// represents a constant. Returns success if the value was replaced, failure
41 /// otherwise.
43  OpBuilder &builder,
44  OperationFolder &folder, Value value) {
45  auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
46  if (!lattice || lattice->getValue().isUninitialized())
47  return failure();
48  const ConstantValue &latticeValue = lattice->getValue();
49  if (!latticeValue.getConstantValue())
50  return failure();
51 
52  // Attempt to materialize a constant for the given value.
53  Dialect *dialect = latticeValue.getConstantDialect();
54  Value constant = folder.getOrCreateConstant(
55  builder.getInsertionBlock(), dialect, latticeValue.getConstantValue(),
56  value.getType());
57  if (!constant)
58  return failure();
59 
60  value.replaceAllUsesWith(constant);
61  return success();
62 }
63 
64 /// Rewrite the given regions using the computing analysis. This replaces the
65 /// uses of all values that have been computed to be constant, and erases as
66 /// many newly dead operations.
67 static void rewrite(DataFlowSolver &solver, MLIRContext *context,
68  MutableArrayRef<Region> initialRegions) {
69  SmallVector<Block *> worklist;
70  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
71  for (Region &region : regions)
72  for (Block &block : llvm::reverse(region))
73  worklist.push_back(&block);
74  };
75 
76  // An operation folder used to create and unique constants.
77  OperationFolder folder(context);
78  OpBuilder builder(context);
79 
80  addToWorklist(initialRegions);
81  while (!worklist.empty()) {
82  Block *block = worklist.pop_back_val();
83 
84  for (Operation &op : llvm::make_early_inc_range(*block)) {
85  builder.setInsertionPoint(&op);
86 
87  // Replace any result with constants.
88  bool replacedAll = op.getNumResults() != 0;
89  for (Value res : op.getResults())
90  replacedAll &=
91  succeeded(replaceWithConstant(solver, builder, folder, res));
92 
93  // If all of the results of the operation were replaced, try to erase
94  // the operation completely.
95  if (replacedAll && wouldOpBeTriviallyDead(&op)) {
96  assert(op.use_empty() && "expected all uses to be replaced");
97  op.erase();
98  continue;
99  }
100 
101  // Add any the regions of this operation to the worklist.
102  addToWorklist(op.getRegions());
103  }
104 
105  // Replace any block arguments with constants.
106  builder.setInsertionPointToStart(block);
107  for (BlockArgument arg : block->getArguments())
108  (void)replaceWithConstant(solver, builder, folder, arg);
109  }
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // SCCP Pass
114 //===----------------------------------------------------------------------===//
115 
116 namespace {
117 struct SCCP : public impl::SCCPBase<SCCP> {
118  void runOnOperation() override;
119 };
120 } // namespace
121 
122 void SCCP::runOnOperation() {
123  Operation *op = getOperation();
124 
125  DataFlowSolver solver;
126  solver.load<DeadCodeAnalysis>();
128  if (failed(solver.initializeAndRun(op)))
129  return signalPassFailure();
130  rewrite(solver, op->getContext(), op->getRegions());
131 }
132 
133 std::unique_ptr<Pass> mlir::createSCCPPass() {
134  return std::make_unique<SCCP>();
135 }
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:67
static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &builder, OperationFolder &folder, Value value)
Replace the given value with a constant if the corresponding lattice represents a constant.
Definition: SCCP.cpp:42
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:84
The general data-flow analysis solver.
const StateT * lookupState(PointT point) const
Lookup an analysis state for the given program point.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:444
A utility class for folding operations, and unifying duplicated constants generated along the way.
Definition: FoldUtils.h:33
Value getOrCreateConstant(Block *block, Dialect *dialect, Attribute value, Type type)
Get or create a constant for use in the specified block.
Definition: FoldUtils.cpp:202
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
result_range getResults()
Definition: Operation.h:410
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:169
This lattice value represents a known constant value of a lattice.
Attribute getConstantValue() const
Get the constant value. Returns null if no value was determined.
Dialect * getConstantDialect() const
Get the dialect instance that can be used to materialize the constant.
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
This class represents a lattice holding a specific value of type ValueT.
This analysis implements sparse constant propagation, which attempts to determine constant-valued res...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::unique_ptr< Pass > createSCCPPass()
Creates a pass which performs sparse conditional constant propagation over nested operations.
Definition: SCCP.cpp:133
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26