MLIR  16.0.0git
SCCP.cpp
Go to the documentation of this file.
1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass performs a sparse conditional constant propagation
10 // in MLIR. It identifies values known to be constant, propagates that
11 // information throughout the IR, and replaces them. This is done with an
12 // optimistic dataflow analysis that assumes that all values are constant until
13 // proven otherwise.
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "mlir/Transforms/Passes.h"

#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_SCCP
29 #include "mlir/Transforms/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 using namespace mlir::dataflow;
34 
35 //===----------------------------------------------------------------------===//
36 // SCCP Rewrites
37 //===----------------------------------------------------------------------===//
38 
39 /// Replace the given value with a constant if the corresponding lattice
40 /// represents a constant. Returns success if the value was replaced, failure
41 /// otherwise.
43  OpBuilder &builder,
44  OperationFolder &folder, Value value) {
45  auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
46  if (!lattice || lattice->getValue().isUninitialized())
47  return failure();
48  const ConstantValue &latticeValue = lattice->getValue();
49  if (!latticeValue.getConstantValue())
50  return failure();
51 
52  // Attempt to materialize a constant for the given value.
53  Dialect *dialect = latticeValue.getConstantDialect();
54  Value constant = folder.getOrCreateConstant(builder, dialect,
55  latticeValue.getConstantValue(),
56  value.getType(), value.getLoc());
57  if (!constant)
58  return failure();
59 
60  value.replaceAllUsesWith(constant);
61  return success();
62 }
63 
64 /// Rewrite the given regions using the computing analysis. This replaces the
65 /// uses of all values that have been computed to be constant, and erases as
66 /// many newly dead operations.
67 static void rewrite(DataFlowSolver &solver, MLIRContext *context,
68  MutableArrayRef<Region> initialRegions) {
69  SmallVector<Block *> worklist;
70  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
71  for (Region &region : regions)
72  for (Block &block : llvm::reverse(region))
73  worklist.push_back(&block);
74  };
75 
76  // An operation folder used to create and unique constants.
77  OperationFolder folder(context);
78  OpBuilder builder(context);
79 
80  addToWorklist(initialRegions);
81  while (!worklist.empty()) {
82  Block *block = worklist.pop_back_val();
83 
84  for (Operation &op : llvm::make_early_inc_range(*block)) {
85  builder.setInsertionPoint(&op);
86 
87  // Replace any result with constants.
88  bool replacedAll = op.getNumResults() != 0;
89  for (Value res : op.getResults())
90  replacedAll &=
91  succeeded(replaceWithConstant(solver, builder, folder, res));
92 
93  // If all of the results of the operation were replaced, try to erase
94  // the operation completely.
95  if (replacedAll && wouldOpBeTriviallyDead(&op)) {
96  assert(op.use_empty() && "expected all uses to be replaced");
97  op.erase();
98  continue;
99  }
100 
101  // Add any the regions of this operation to the worklist.
102  addToWorklist(op.getRegions());
103  }
104 
105  // Replace any block arguments with constants.
106  builder.setInsertionPointToStart(block);
107  for (BlockArgument arg : block->getArguments())
108  (void)replaceWithConstant(solver, builder, folder, arg);
109  }
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // SCCP Pass
114 //===----------------------------------------------------------------------===//
115 
116 namespace {
117 struct SCCP : public impl::SCCPBase<SCCP> {
118  void runOnOperation() override;
119 };
120 } // namespace
121 
122 void SCCP::runOnOperation() {
123  Operation *op = getOperation();
124 
125  DataFlowSolver solver;
126  solver.load<DeadCodeAnalysis>();
128  if (failed(solver.initializeAndRun(op)))
129  return signalPassFailure();
130  rewrite(solver, op->getContext(), op->getRegions());
131 }
132 
133 std::unique_ptr<Pass> mlir::createSCCPPass() {
134  return std::make_unique<SCCP>();
135 }
static constexpr const bool value
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:67
static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &builder, OperationFolder &folder, Value value)
Replace the given value with a constant if the corresponding lattice represents a constant.
Definition: SCCP.cpp:42
This class represents an argument of a Block.
Definition: Value.h:296
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:76
The general data-flow analysis solver.
const StateT * lookupState(PointT point) const
Lookup an analysis state for the given program point.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
This class helps build Operations.
Definition: Builders.h:198
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:383
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:350
A utility class for folding operations, and unifying duplicated constants generated along the way.
Definition: FoldUtils.h:32
Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect, Attribute value, Type type, Location loc)
Get or create a constant using the given builder.
Definition: FoldUtils.cpp:199
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:147
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
This lattice value represents a known constant value of a lattice.
Attribute getConstantValue() const
Get the constant value. Returns null if no value was determined.
Dialect * getConstantDialect() const
Get the dialect instance that can be used to materialize the constant.
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
This class represents a lattice holding a specific value of type ValueT.
This analysis implements sparse constant propagation, which attempts to determine constant-valued res...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::unique_ptr< Pass > createSCCPPass()
Creates a pass which performs sparse conditional constant propagation over nested operations.
Definition: SCCP.cpp:133
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26