MLIR  22.0.0git
StridedMetadataRangeAnalysis.cpp
Go to the documentation of this file.
1 //===- StridedMetadataRangeAnalysis.cpp - Strided metadata analysis -*- C++
2 //-*-===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file defines the dataflow analysis class for strided metadata range
11 // inference, which computes ranges for the offsets, sizes, and strides of
12 // memref values.
13 //
14 //===----------------------------------------------------------------------===//
15 
19 #include "mlir/IR/Operation.h"
20 #include "mlir/IR/Value.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/DebugLog.h"
24 
25 #define DEBUG_TYPE "strided-metadata-range-analysis"
26 
27 using namespace mlir;
28 using namespace mlir::dataflow;
29 
30 /// Get the entry state for a value. For any value that is not a ranked memref,
31 /// this function sets the metadata to a top state with no offsets, sizes, or
32 /// strides. For `memref` types, this function will use the metadata in the type
33 /// to try to deduce as much informaiton as possible.
34 static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) {
35  // TODO: generalize this method with a type interface.
36  auto mTy = dyn_cast<BaseMemRefType>(v.getType());
37 
38  // If not a memref or it's un-ranked, don't infer any metadata.
39  if (!mTy || !mTy.hasRank())
40  return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0);
41 
42  // Get the top state.
43  auto metadata =
44  StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank());
45 
46  // Compute the offset and strides.
47  int64_t offset;
48  SmallVector<int64_t> strides;
49  if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset)))
50  return metadata;
51 
52  // Refine the metadata if we know it from the type.
53  if (!ShapedType::isDynamic(offset)) {
54  metadata.getOffsets()[0] =
55  ConstantIntRanges::constant(APInt(indexBitwidth, offset));
56  }
57  for (auto &&[size, range] :
58  llvm::zip_equal(mTy.getShape(), metadata.getSizes())) {
59  if (ShapedType::isDynamic(size))
60  continue;
61  range = ConstantIntRanges::constant(APInt(indexBitwidth, size));
62  }
63  for (auto &&[stride, range] :
64  llvm::zip_equal(strides, metadata.getStrides())) {
65  if (ShapedType::isDynamic(stride))
66  continue;
67  range = ConstantIntRanges::constant(APInt(indexBitwidth, stride));
68  }
69 
70  return metadata;
71 }
72 
74  DataFlowSolver &solver, int32_t indexBitwidth)
75  : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) {
76  assert(indexBitwidth > 0 && "invalid bitwidth");
77 }
78 
80  StridedMetadataRangeLattice *lattice) {
81  propagateIfChanged(lattice, lattice->join(getEntryStateImpl(
82  lattice->getAnchor(), indexBitwidth)));
83 }
84 
88  auto inferrable = dyn_cast<InferStridedMetadataOpInterface>(op);
89 
90  // Bail if we cannot reason about the op.
91  if (!inferrable) {
92  setAllToEntryStates(results);
93  return success();
94  }
95 
96  LDBG() << "Inferring metadata for: "
97  << OpWithFlags(op, OpPrintingFlags().skipRegions());
98 
99  // Helper function to retrieve int range values.
100  auto getIntRange = [&](Value value) -> IntegerValueRange {
101  auto lattice = getOrCreateFor<IntegerValueRangeLattice>(
102  getProgramPointAfter(op), value);
103  return lattice ? lattice->getValue() : IntegerValueRange();
104  };
105 
106  // Convert the arguments lattices to a vector.
107  SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector(
108  operands, [](const StridedMetadataRangeLattice *lattice) {
109  return lattice->getValue();
110  });
111 
112  // Callback to set metadata on a result.
113  auto joinCallback = [&](Value v, const StridedMetadataRange &md) {
114  auto result = cast<OpResult>(v);
115  assert(llvm::is_contained(op->getResults(), result));
116  LDBG() << "- Inferred metadata: " << md;
117  StridedMetadataRangeLattice *lattice = results[result.getResultNumber()];
118  ChangeResult changed = lattice->join(md);
119  LDBG() << "- Joined metadata: " << lattice->getValue();
120  propagateIfChanged(lattice, changed);
121  };
122 
123  // Infer the metadata.
124  inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallback,
125  indexBitwidth);
126  return success();
127 }
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth)
Get the entry state for a value.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
ProgramPoint * getProgramPointAfter(Operation *op)
The general data-flow analysis solver.
This lattice value represents the integer range of an SSA value.
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_range getResults()
Definition: Operation.h:415
A class that represents the strided metadata range information, including offsets,...
static StridedMetadataRange getMaxRanges(int32_t indexBitwidth, int32_t offsetsRank, int32_t sizeRank, int32_t stridedRank)
Returns a strided metadata range with maximum ranges.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Value getAnchor() const
Return the value this lattice is located at.
ValueT & getValue()
Return the value held by this lattice.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
A sparse forward data-flow analysis for propagating SSA value lattices across the IR by implementing ...
void setAllToEntryStates(ArrayRef< StridedMetadataRangeLattice * > lattices)
void setToEntryState(StridedMetadataRangeLattice *lattice) override
At an entry point, we cannot reason about strided metadata ranges unless the type also encodes the da...
StridedMetadataRangeAnalysis(DataFlowSolver &solver, int32_t indexBitwidth=64)
LogicalResult visitOperation(Operation *op, ArrayRef< const StridedMetadataRangeLattice * > operands, ArrayRef< StridedMetadataRangeLattice * > results) override
Visit an operation.
This lattice element represents the strided metadata of an SSA value.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
ChangeResult
A result type used to indicate if a change happened.