MLIR  20.0.0git
ConstantPropagationAnalysis.cpp
Go to the documentation of this file.
1 //===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/OpDefinition.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/Support/Casting.h"
18 #include "llvm/Support/Debug.h"
19 #include <cassert>
20 
21 #define DEBUG_TYPE "constant-propagation"
22 
23 using namespace mlir;
24 using namespace mlir::dataflow;
25 
26 //===----------------------------------------------------------------------===//
27 // ConstantValue
28 //===----------------------------------------------------------------------===//
29 
30 void ConstantValue::print(raw_ostream &os) const {
31  if (isUninitialized()) {
32  os << "<UNINITIALIZED>";
33  return;
34  }
35  if (getConstantValue() == nullptr) {
36  os << "<UNKNOWN>";
37  return;
38  }
39  return getConstantValue().print(os);
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // SparseConstantPropagation
44 //===----------------------------------------------------------------------===//
45 
47  Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
48  ArrayRef<Lattice<ConstantValue> *> results) {
49  LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
50 
51  // Don't try to simulate the results of a region operation as we can't
52  // guarantee that folding will be out-of-place. We don't allow in-place
53  // folds as the desire here is for simulated execution, and not general
54  // folding.
55  if (op->getNumRegions()) {
56  setAllToEntryStates(results);
57  return success();
58  }
59 
60  SmallVector<Attribute, 8> constantOperands;
61  constantOperands.reserve(op->getNumOperands());
62  for (auto *operandLattice : operands) {
63  if (operandLattice->getValue().isUninitialized())
64  return success();
65  constantOperands.push_back(operandLattice->getValue().getConstantValue());
66  }
67 
68  // Save the original operands and attributes just in case the operation
69  // folds in-place. The constant passed in may not correspond to the real
70  // runtime value, so in-place updates are not allowed.
71  SmallVector<Value, 8> originalOperands(op->getOperands());
72  DictionaryAttr originalAttrs = op->getAttrDictionary();
73 
74  // Simulate the result of folding this operation to a constant. If folding
75  // fails or was an in-place fold, mark the results as overdefined.
76  SmallVector<OpFoldResult, 8> foldResults;
77  foldResults.reserve(op->getNumResults());
78  if (failed(op->fold(constantOperands, foldResults))) {
79  setAllToEntryStates(results);
80  return success();
81  }
82 
83  // If the folding was in-place, mark the results as overdefined and reset
84  // the operation. We don't allow in-place folds as the desire here is for
85  // simulated execution, and not general folding.
86  if (foldResults.empty()) {
87  op->setOperands(originalOperands);
88  op->setAttrs(originalAttrs);
89  setAllToEntryStates(results);
90  return success();
91  }
92 
93  // Merge the fold results into the lattice for this operation.
94  assert(foldResults.size() == op->getNumResults() && "invalid result size");
95  for (const auto it : llvm::zip(results, foldResults)) {
96  Lattice<ConstantValue> *lattice = std::get<0>(it);
97 
98  // Merge in the result of the fold, either a constant or a value.
99  OpFoldResult foldResult = std::get<1>(it);
100  if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
101  LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
102  propagateIfChanged(lattice,
103  lattice->join(ConstantValue(attr, op->getDialect())));
104  } else {
105  LLVM_DEBUG(llvm::dbgs()
106  << "Folded to value: " << cast<Value>(foldResult) << "\n");
108  lattice, *getLatticeElement(cast<Value>(foldResult)));
109  }
110  }
111  return success();
112 }
113 
115  Lattice<ConstantValue> *lattice) {
116  propagateIfChanged(lattice,
118 }
Attributes are known-constant values of operations.
Definition: Attributes.h:25
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:296
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:632
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:220
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
unsigned getNumOperands()
Definition: Operation.h:346
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate and update if it changed.
This lattice value represents a known constant value of a lattice.
Attribute getConstantValue() const
Get the constant value. Returns null if no value was determined.
bool isUninitialized() const
Whether the state is uninitialized.
void print(raw_ostream &os) const
Print the constant value.
static ConstantValue getUnknownConstant()
The state where the constant value is unknown.
This class represents a lattice holding a specific value of type ValueT.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
void setToEntryState(Lattice< ConstantValue > *lattice) override
Set the given lattice element(s) at control flow entry point(s).
LogicalResult visitOperation(Operation *op, ArrayRef< const Lattice< ConstantValue > * > operands, ArrayRef< Lattice< ConstantValue > * > results) override
Visit an operation with the lattices of its operands.
Lattice< ConstantValue > * getLatticeElement(Value value) override
Get the lattice element for a value.
void setAllToEntryStates(ArrayRef< Lattice< ConstantValue > * > lattices)
Set each of the given lattice elements to its entry state (see setToEntryState).