MLIR  19.0.0git
ConstantPropagationAnalysis.cpp
Go to the documentation of this file.
1 //===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/OpDefinition.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/Support/Casting.h"
19 #include "llvm/Support/Debug.h"
20 #include <cassert>
21 
22 #define DEBUG_TYPE "constant-propagation"
23 
24 using namespace mlir;
25 using namespace mlir::dataflow;
26 
27 //===----------------------------------------------------------------------===//
28 // ConstantValue
29 //===----------------------------------------------------------------------===//
30 
31 void ConstantValue::print(raw_ostream &os) const {
32  if (isUninitialized()) {
33  os << "<UNINITIALIZED>";
34  return;
35  }
36  if (getConstantValue() == nullptr) {
37  os << "<UNKNOWN>";
38  return;
39  }
40  return getConstantValue().print(os);
41 }
42 
43 //===----------------------------------------------------------------------===//
44 // SparseConstantPropagation
45 //===----------------------------------------------------------------------===//
46 
48  Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
49  ArrayRef<Lattice<ConstantValue> *> results) {
50  LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
51 
52  // Don't try to simulate the results of a region operation as we can't
53  // guarantee that folding will be out-of-place. We don't allow in-place
54  // folds as the desire here is for simulated execution, and not general
55  // folding.
56  if (op->getNumRegions()) {
57  setAllToEntryStates(results);
58  return;
59  }
60 
61  SmallVector<Attribute, 8> constantOperands;
62  constantOperands.reserve(op->getNumOperands());
63  for (auto *operandLattice : operands) {
64  if (operandLattice->getValue().isUninitialized())
65  return;
66  constantOperands.push_back(operandLattice->getValue().getConstantValue());
67  }
68 
69  // Save the original operands and attributes just in case the operation
70  // folds in-place. The constant passed in may not correspond to the real
71  // runtime value, so in-place updates are not allowed.
72  SmallVector<Value, 8> originalOperands(op->getOperands());
73  DictionaryAttr originalAttrs = op->getAttrDictionary();
74 
75  // Simulate the result of folding this operation to a constant. If folding
76  // fails or was an in-place fold, mark the results as overdefined.
77  SmallVector<OpFoldResult, 8> foldResults;
78  foldResults.reserve(op->getNumResults());
79  if (failed(op->fold(constantOperands, foldResults))) {
80  setAllToEntryStates(results);
81  return;
82  }
83 
84  // If the folding was in-place, mark the results as overdefined and reset
85  // the operation. We don't allow in-place folds as the desire here is for
86  // simulated execution, and not general folding.
87  if (foldResults.empty()) {
88  op->setOperands(originalOperands);
89  op->setAttrs(originalAttrs);
90  setAllToEntryStates(results);
91  return;
92  }
93 
94  // Merge the fold results into the lattice for this operation.
95  assert(foldResults.size() == op->getNumResults() && "invalid result size");
96  for (const auto it : llvm::zip(results, foldResults)) {
97  Lattice<ConstantValue> *lattice = std::get<0>(it);
98 
99  // Merge in the result of the fold, either a constant or a value.
100  OpFoldResult foldResult = std::get<1>(it);
101  if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
102  LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
103  propagateIfChanged(lattice,
104  lattice->join(ConstantValue(attr, op->getDialect())));
105  } else {
106  LLVM_DEBUG(llvm::dbgs()
107  << "Folded to value: " << foldResult.get<Value>() << "\n");
109  lattice, *getLatticeElement(foldResult.get<Value>()));
110  }
111  }
112 }
113 
115  Lattice<ConstantValue> *lattice) {
116  propagateIfChanged(lattice,
118 }
Attributes are known-constant values of operations.
Definition: Attributes.h:25
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:296
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:632
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:220
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
unsigned getNumOperands()
Definition: Operation.h:341
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate and update if it changed.
This lattice value represents a known constant value of a lattice.
Attribute getConstantValue() const
Get the constant value. Returns null if no value was determined.
bool isUninitialized() const
Whether the state is uninitialized.
void print(raw_ostream &os) const
Print the constant value.
static ConstantValue getUnknownConstant()
The state where the constant value is unknown.
This class represents a lattice holding a specific value of type ValueT.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
void setToEntryState(Lattice< ConstantValue > *lattice) override
Set the given lattice element(s) at control flow entry point(s).
void visitOperation(Operation *op, ArrayRef< const Lattice< ConstantValue > * > operands, ArrayRef< Lattice< ConstantValue > * > results) override
Visit an operation with the lattices of its operands.
Lattice< ConstantValue > * getLatticeElement(Value value) override
Get the lattice element for a value.
void setAllToEntryStates(ArrayRef< Lattice< ConstantValue > * > lattices)
Set each of the given lattice elements to its entry state.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72