MLIR 22.0.0git
StridedMetadataRangeAnalysis.cpp
Go to the documentation of this file.
1//===- StridedMetadataRangeAnalysis.cpp - Strided metadata range analysis -*- C++
2//-*-===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This file defines the dataflow analysis class for strided metadata range
11// inference, which computes conservative ranges for the offset, sizes, and
12// strides of ranked memref values
13//
14//===----------------------------------------------------------------------===//
15
19#include "mlir/IR/Operation.h"
20#include "mlir/IR/Value.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24
25#define DEBUG_TYPE "strided-metadata-range-analysis"
26
27using namespace mlir;
28using namespace mlir::dataflow;
29
30/// Get the entry state for a value. For any value that is not a ranked memref,
31/// this function sets the metadata to a top state with no offsets, sizes, or
32/// strides. For `memref` types, this function will use the metadata in the type
33/// to try to deduce as much informaiton as possible.
34static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) {
35 // TODO: generalize this method with a type interface.
36 auto mTy = dyn_cast<BaseMemRefType>(v.getType());
37
38 // If not a memref or it's un-ranked, don't infer any metadata.
39 if (!mTy || !mTy.hasRank())
40 return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0);
41
42 // Get the top state.
43 auto metadata =
44 StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank());
45
46 // Compute the offset and strides.
47 int64_t offset;
49 if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset)))
50 return metadata;
51
52 // Refine the metadata if we know it from the type.
53 if (!ShapedType::isDynamic(offset)) {
54 metadata.getOffsets()[0] =
55 ConstantIntRanges::constant(APInt(indexBitwidth, offset));
56 }
57 for (auto &&[size, range] :
58 llvm::zip_equal(mTy.getShape(), metadata.getSizes())) {
59 if (ShapedType::isDynamic(size))
60 continue;
61 range = ConstantIntRanges::constant(APInt(indexBitwidth, size));
62 }
63 for (auto &&[stride, range] :
64 llvm::zip_equal(strides, metadata.getStrides())) {
65 if (ShapedType::isDynamic(stride))
66 continue;
67 range = ConstantIntRanges::constant(APInt(indexBitwidth, stride));
68 }
69
70 return metadata;
71}
72
74 DataFlowSolver &solver, int32_t indexBitwidth)
75 : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) {
76 assert(indexBitwidth > 0 && "invalid bitwidth");
77}
78
82 lattice->getAnchor(), indexBitwidth)));
83}
84
88 auto inferrable = dyn_cast<InferStridedMetadataOpInterface>(op);
89
90 // Bail if we cannot reason about the op.
91 if (!inferrable) {
92 setAllToEntryStates(results);
93 return success();
94 }
95
96 LDBG() << "Inferring metadata for: "
97 << OpWithFlags(op, OpPrintingFlags().skipRegions());
98
99 // Helper function to retrieve int range values.
100 auto getIntRange = [&](Value value) -> IntegerValueRange {
102 getProgramPointAfter(op), value);
103 return lattice ? lattice->getValue() : IntegerValueRange();
104 };
105
106 // Convert the arguments lattices to a vector.
107 SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector(
108 operands, [](const StridedMetadataRangeLattice *lattice) {
109 return lattice->getValue();
110 });
111
112 // Callback to set metadata on a result.
113 auto joinCallback = [&](Value v, const StridedMetadataRange &md) {
114 auto result = cast<OpResult>(v);
115 assert(llvm::is_contained(op->getResults(), result));
116 LDBG() << "- Inferred metadata: " << md;
117 StridedMetadataRangeLattice *lattice = results[result.getResultNumber()];
118 ChangeResult changed = lattice->join(md);
119 LDBG() << "- Joined metadata: " << lattice->getValue();
120 propagateIfChanged(lattice, changed);
121 };
122
123 // Infer the metadata.
124 inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallback,
125 indexBitwidth);
126 return success();
127}
return success()
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth)
Get the entry state for a value.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
ProgramPoint * getProgramPointAfter(Operation *op)
friend class DataFlowSolver
Allow the data-flow solver to access the internals of this class.
const StateT * getOrCreateFor(ProgramPoint *dependent, AnchorT anchor)
Get a read-only analysis state for the given point and create a dependency on dependent.
This lattice value represents the integer range of an SSA value.
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
result_range getResults()
Definition Operation.h:415
A class that represents the strided metadata range information, including offsets,...
static StridedMetadataRange getMaxRanges(int32_t indexBitwidth, int32_t offsetsRank, int32_t sizeRank, int32_t stridedRank)
Returns a strided metadata range with maximum ranges.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
ValueT & getValue()
Return the value held by this lattice.
Value getAnchor() const
Return the value this lattice is located at.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
void setAllToEntryStates(ArrayRef< StridedMetadataRangeLattice * > lattices)
void setToEntryState(StridedMetadataRangeLattice *lattice) override
At an entry point, we cannot reason about strided metadata ranges unless the type also encodes the da...
StridedMetadataRangeAnalysis(DataFlowSolver &solver, int32_t indexBitwidth=64)
LogicalResult visitOperation(Operation *op, ArrayRef< const StridedMetadataRangeLattice * > operands, ArrayRef< StridedMetadataRangeLattice * > results) override
Visit an operation.
This lattice element represents the strided metadata of an SSA value.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
ChangeResult
A result type used to indicate if a change happened.