MLIR 22.0.0git
ConstantPropagationAnalysis.cpp
Go to the documentation of this file.
1//===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Operation.h"
14#include "mlir/IR/Value.h"
15#include "mlir/Support/LLVM.h"
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/Support/Casting.h"
18#include "llvm/Support/Debug.h"
19#include "llvm/Support/DebugLog.h"
20#include <cassert>
21
22#define DEBUG_TYPE "constant-propagation"
23
24using namespace mlir;
25using namespace mlir::dataflow;
26
27//===----------------------------------------------------------------------===//
28// ConstantValue
29//===----------------------------------------------------------------------===//
30
// ConstantValue::print: writes a textual form of this lattice value to `os`.
// Emits "<UNINITIALIZED>" when no state has been set, "<UNKNOWN>" when the
// held constant attribute is null, and otherwise delegates to
// Attribute::print.
// NOTE(review): this listing appears to omit the function header (the
// embedded numbering jumps from 30 to 32); confirm against the real source.
32 if (isUninitialized()) {
33 os << "<UNINITIALIZED>";
34 return;
35 }
// A null constant attribute means the value is known to be non-constant.
36 if (getConstantValue() == nullptr) {
37 os << "<UNKNOWN>";
38 return;
39 }
40 return getConstantValue().print(os);
41}
42
43//===----------------------------------------------------------------------===//
44// SparseConstantPropagation
45//===----------------------------------------------------------------------===//
46
// SparseConstantPropagation::visitOperation: transfer function for a single
// operation. It folds `op` using the constant values held in its operand
// lattices, then joins each fold result into the matching result lattice.
// Region-holding operations, failed folds and in-place folds instead set
// every result lattice to its entry state. Every path visible here returns
// success().
// NOTE(review): this listing appears to omit embedded source lines 47, 49, 77
// and 107. From the surrounding code these are presumably the start of the
// signature, the `results` parameter, the `foldResults` declaration, and the
// head of the join call on the value path. Confirm against the real source
// before relying on this fragment.
48 Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
50 LDBG() << "SCP: Visiting operation: " << *op;
51
52 // Don't try to simulate the results of a region operation as we can't
53 // guarantee that folding will be out-of-place. We don't allow in-place
54 // folds as the desire here is for simulated execution, and not general
55 // folding.
56 if (op->getNumRegions()) {
57 setAllToEntryStates(results);
58 return success();
59 }
60
// Gather one constant per operand. A null attribute stands for an operand
// whose constant value is unknown.
61 SmallVector<Attribute, 8> constantOperands;
62 constantOperands.reserve(op->getNumOperands());
63 for (auto *operandLattice : operands) {
// If any operand lattice is still uninitialized, bail out without touching
// the result lattices.
// NOTE(review): presumably the solver revisits this op once the operand is
// initialized; confirm.
64 if (operandLattice->getValue().isUninitialized())
65 return success();
66 constantOperands.push_back(operandLattice->getValue().getConstantValue());
67 }
68
69 // Save the original operands and attributes just in case the operation
70 // folds in-place. The constant passed in may not correspond to the real
71 // runtime value, so in-place updates are not allowed.
72 SmallVector<Value, 8> originalOperands(op->getOperands());
73 DictionaryAttr originalAttrs = op->getAttrDictionary();
74
75 // Simulate the result of folding this operation to a constant. If folding
76 // fails or was an in-place fold, mark the results as overdefined.
78 foldResults.reserve(op->getNumResults());
79 if (failed(op->fold(constantOperands, foldResults))) {
80 setAllToEntryStates(results);
81 return success();
82 }
83
84 // If the folding was in-place, mark the results as overdefined and reset
85 // the operation. We don't allow in-place folds as the desire here is for
86 // simulated execution, and not general folding.
87 if (foldResults.empty()) {
88 op->setOperands(originalOperands);
89 op->setAttrs(originalAttrs);
90 setAllToEntryStates(results);
91 return success();
92 }
93
94 // Merge the fold results into the lattice for this operation.
95 assert(foldResults.size() == op->getNumResults() && "invalid result size");
96 for (const auto it : llvm::zip(results, foldResults)) {
97 Lattice<ConstantValue> *lattice = std::get<0>(it);
98
99 // Merge in the result of the fold, either a constant or a value.
100 OpFoldResult foldResult = std::get<1>(it);
// Attribute result: join a ConstantValue that also records the dialect of
// the folded operation.
101 if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
102 LDBG() << "Folded to constant: " << attr;
103 propagateIfChanged(lattice,
104 lattice->join(ConstantValue(attr, op->getDialect())));
105 } else {
// Value result: join in the existing lattice element of the value the op
// folded to.
106 LDBG() << "Folded to value: " << cast<Value>(foldResult);
108 lattice, *getLatticeElement(cast<Value>(foldResult)));
109 }
110 }
111 return success();
112}
113
return success()
Attributes are known-constant values of operations.
Definition Attributes.h:25
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition Operation.h:220
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
unsigned getNumOperands()
Definition Operation.h:346
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate and update if it changed.
This lattice value represents a known constant value of a lattice.
Attribute getConstantValue() const
Get the constant value. Returns null if no value was determined.
bool isUninitialized() const
Whether the state is uninitialized.
void print(raw_ostream &os) const
Print the constant value.
static ConstantValue getUnknownConstant()
The state where the constant value is unknown.
This class represents a lattice holding a specific value of type ValueT.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
void setToEntryState(Lattice< ConstantValue > *lattice) override
Set the given lattice element(s) at control flow entry point(s).
LogicalResult visitOperation(Operation *op, ArrayRef< const Lattice< ConstantValue > * > operands, ArrayRef< Lattice< ConstantValue > * > results) override
Visit an operation with the lattices of its operands.
Lattice< ConstantValue > * getLatticeElement(Value value) override
void setAllToEntryStates(ArrayRef< Lattice< ConstantValue > * > lattices)
Include the generated interface declarations.