MLIR  22.0.0git
InferStridedMetadataInterface.h
Go to the documentation of this file.
1 //===- InferStridedMetadataInterface.h - Strided Metadata Inference -C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions of the strided metadata inference interface
10 // defined in `InferStridedMetadataInterface.td`
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
15 #define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
16 
18 
19 namespace mlir {
20 /// A class that represents the strided metadata range information, including
21 /// offsets, sizes, and strides as integer ranges.
23 public:
/// Default constructor: creates an uninitialized range with no offsets, sizes,
/// or strides. isUninitialized() returns true for such a range.
StridedMetadataRange() = default;
26 
27  /// Returns a ranked strided metadata range.
32  return StridedMetadataRange(std::move(offsets), std::move(sizes),
33  std::move(strides));
34  }
35 
36  /// Returns a strided metadata range with maximum ranges.
37  static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
38  int32_t offsetsRank,
39  int32_t sizeRank,
40  int32_t stridedRank) {
41  return StridedMetadataRange(
43  offsetsRank, ConstantIntRanges::maxRange(indexBitwidth)),
45  sizeRank, ConstantIntRanges::maxRange(indexBitwidth)),
47  stridedRank, ConstantIntRanges::maxRange(indexBitwidth)));
48  }
49 
50  static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
51  int32_t rank) {
52  return getMaxRanges(indexBitwidth, 1, rank, rank);
53  }
54 
55  /// Returns whether the metadata is uninitialized.
56  bool isUninitialized() const { return !offsets.has_value(); }
57 
58  /// Get the offsets range.
60  return offsets ? *offsets : ArrayRef<ConstantIntRanges>();
61  }
63  return offsets ? *offsets : MutableArrayRef<ConstantIntRanges>();
64  }
65 
66  /// Get the sizes ranges.
67  ArrayRef<ConstantIntRanges> getSizes() const { return sizes; }
69 
70  /// Get the strides ranges.
71  ArrayRef<ConstantIntRanges> getStrides() const { return strides; }
73 
74  /// Compare two strided metadata ranges.
75  bool operator==(const StridedMetadataRange &other) const {
76  return offsets == other.offsets && sizes == other.sizes &&
77  strides == other.strides;
78  }
79 
/// Prints this strided metadata range to `os`. Only the declaration lives in
/// this header; the output format is determined by the out-of-line definition.
void print(raw_ostream &os) const;
82 
83  /// Join two strided metadata ranges, by taking the element-wise union of the
84  /// metadata.
86  const StridedMetadataRange &rhs) {
87  if (lhs.isUninitialized())
88  return rhs;
89  if (rhs.isUninitialized())
90  return lhs;
91 
92  // Helper fuction to compute the range union of constant ranges.
93  auto rangeUnion =
94  +[](const std::tuple<ConstantIntRanges, ConstantIntRanges> &lhsRhs)
96  return std::get<0>(lhsRhs).rangeUnion(std::get<1>(lhsRhs));
97  };
98 
99  // Get the elementwise range union. Note, that `zip_equal` will assert if
100  // sizes are not equal.
101  SmallVector<ConstantIntRanges> offsets = llvm::map_to_vector(
102  llvm::zip_equal(*lhs.offsets, *rhs.offsets), rangeUnion);
104  llvm::map_to_vector(llvm::zip_equal(lhs.sizes, rhs.sizes), rangeUnion);
105  SmallVector<ConstantIntRanges> strides = llvm::map_to_vector(
106  llvm::zip_equal(lhs.strides, rhs.strides), rangeUnion);
107 
108  // Return the joined metadata.
109  return StridedMetadataRange(std::move(offsets), std::move(sizes),
110  std::move(strides));
111  }
112 
113 private:
114  /// Create a strided metadata range with the given offset, sizes, and strides.
118  : offsets(std::move(offsets)), sizes(std::move(sizes)),
119  strides(std::move(strides)) {}
120 
/// The offsets ranges. `std::nullopt` marks an uninitialized range (this is
/// exactly what isUninitialized() tests); a default-constructed range leaves it
/// unset.
std::optional<SmallVector<ConstantIntRanges>> offsets;

/// The sizes ranges. Empty for a default-constructed range; join() requires
/// both operands to have the same number of entries (checked by zip_equal).
SmallVector<ConstantIntRanges> sizes;

/// The strides ranges. Empty for a default-constructed range; join() requires
/// both operands to have the same number of entries (checked by zip_equal).
SmallVector<ConstantIntRanges> strides;
129 };
130 
131 /// Print the strided metadata to `os`.
132 inline raw_ostream &operator<<(raw_ostream &os,
133  const StridedMetadataRange &range) {
134  range.print(os);
135  return os;
136 }
137 
138 /// Callback function type for setting the strided metadata of a value.
140  function_ref<void(Value, const StridedMetadataRange &)>;
141 } // end namespace mlir
142 
143 #include "mlir/Interfaces/InferStridedMetadataInterface.h.inc"
144 
145 #endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges maxRange(unsigned bitwidth)
Create a ConstantIntRanges with the maximum bounds for the width bitwidth, that is - [0,...
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
A class that represents the strided metadata range information, including offsets,...
StridedMetadataRange()=default
Default constructor creates uninitialized ranges.
MutableArrayRef< ConstantIntRanges > getStrides()
ArrayRef< ConstantIntRanges > getStrides() const
Get the strides ranges.
static StridedMetadataRange getMaxRanges(int32_t indexBitwidth, int32_t rank)
MutableArrayRef< ConstantIntRanges > getSizes()
ArrayRef< ConstantIntRanges > getSizes() const
Get the sizes ranges.
static StridedMetadataRange getRanked(SmallVectorImpl< ConstantIntRanges > &&offsets, SmallVectorImpl< ConstantIntRanges > &&sizes, SmallVectorImpl< ConstantIntRanges > &&strides)
Returns a ranked strided metadata range.
bool isUninitialized() const
Returns whether the metadata is uninitialized.
MutableArrayRef< ConstantIntRanges > getOffsets()
static StridedMetadataRange join(const StridedMetadataRange &lhs, const StridedMetadataRange &rhs)
Join two strided metadata ranges, by taking the element-wise union of the metadata.
void print(raw_ostream &os) const
Print the strided metadata range.
bool operator==(const StridedMetadataRange &other) const
Compare two strided metadata ranges.
ArrayRef< ConstantIntRanges > getOffsets() const
Get the offsets range.
static StridedMetadataRange getMaxRanges(int32_t indexBitwidth, int32_t offsetsRank, int32_t sizeRank, int32_t stridedRank)
Returns a strided metadata range with maximum ranges.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78