MLIR 22.0.0git
InferStridedMetadataInterface.h
Go to the documentation of this file.
1//===- InferStridedMetadataInterface.h - Strided Metadata Inference -C++-*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains definitions of the strided metadata inference interface
10// defined in `InferStridedMetadataInterface.td`
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
15#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
16
18
19namespace mlir {
20/// A class that represents the strided metadata range information, including
21/// offsets, sizes, and strides as integer ranges.
23public:
24 /// Default constructor creates uninitialized ranges.
26
27 /// Returns a ranked strided metadata range.
32 return StridedMetadataRange(std::move(offsets), std::move(sizes),
33 std::move(strides));
34 }
35
36 /// Returns a strided metadata range with maximum ranges.
37 static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
38 int32_t offsetsRank,
39 int32_t sizeRank,
40 int32_t stridedRank) {
43 offsetsRank, ConstantIntRanges::maxRange(indexBitwidth)),
45 sizeRank, ConstantIntRanges::maxRange(indexBitwidth)),
47 stridedRank, ConstantIntRanges::maxRange(indexBitwidth)));
48 }
49
50 static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
51 int32_t rank) {
52 return getMaxRanges(indexBitwidth, 1, rank, rank);
53 }
54
55 /// Returns whether the metadata is uninitialized.
56 bool isUninitialized() const { return !offsets.has_value(); }
57
58 /// Get the offsets range.
60 return offsets ? *offsets : ArrayRef<ConstantIntRanges>();
61 }
65
66 /// Get the sizes ranges.
67 ArrayRef<ConstantIntRanges> getSizes() const { return sizes; }
69
70 /// Get the strides ranges.
71 ArrayRef<ConstantIntRanges> getStrides() const { return strides; }
73
74 /// Compare two strided metadata ranges.
75 bool operator==(const StridedMetadataRange &other) const {
76 return offsets == other.offsets && sizes == other.sizes &&
77 strides == other.strides;
78 }
79
80 /// Print the strided metadata range.
81 void print(raw_ostream &os) const;
82
83 /// Join two strided metadata ranges by taking the element-wise union of their
84 /// metadata.
87 if (lhs.isUninitialized())
88 return rhs;
89 if (rhs.isUninitialized())
90 return lhs;
91
92 // Helper function to compute the range union of constant ranges.
93 auto rangeUnion =
94 +[](const std::tuple<ConstantIntRanges, ConstantIntRanges> &lhsRhs)
96 return std::get<0>(lhsRhs).rangeUnion(std::get<1>(lhsRhs));
97 };
98
99 // Get the element-wise range union. Note that `zip_equal` asserts if the
100 // zipped ranges have different lengths (i.e. mismatched ranks).
101 SmallVector<ConstantIntRanges> offsets = llvm::map_to_vector(
102 llvm::zip_equal(*lhs.offsets, *rhs.offsets), rangeUnion);
104 llvm::map_to_vector(llvm::zip_equal(lhs.sizes, rhs.sizes), rangeUnion);
105 SmallVector<ConstantIntRanges> strides = llvm::map_to_vector(
106 llvm::zip_equal(lhs.strides, rhs.strides), rangeUnion);
107
108 // Return the joined metadata.
109 return StridedMetadataRange(std::move(offsets), std::move(sizes),
110 std::move(strides));
111 }
112
113private:
114 /// Create a strided metadata range with the given offsets, sizes, and strides.
118 : offsets(std::move(offsets)), sizes(std::move(sizes)),
119 strides(std::move(strides)) {}
120
121 /// The offsets range.
122 std::optional<SmallVector<ConstantIntRanges>> offsets;
123
124 /// The sizes ranges.
126
127 /// The strides ranges.
129};
130
131/// Print the strided metadata to `os`.
133 const StridedMetadataRange &range) {
134 range.print(os);
135 return os;
136}
137
138/// Callback function type for setting the strided metadata of a value.
141} // end namespace mlir
142
143#include "mlir/Interfaces/InferStridedMetadataInterface.h.inc"
144
145#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
lhs
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges maxRange(unsigned bitwidth)
Create a ConstantIntRanges with the maximum bounds for the width bitwidth, that is - [0,...
A class that represents the strided metadata range information, including offsets,...
StridedMetadataRange()=default
Default constructor creates uninitialized ranges.
ArrayRef< ConstantIntRanges > getSizes() const
Get the sizes ranges.
static StridedMetadataRange getMaxRanges(int32_t indexBitwidth, int32_t rank)
static StridedMetadataRange getRanked(SmallVectorImpl< ConstantIntRanges > &&offsets, SmallVectorImpl< ConstantIntRanges > &&sizes, SmallVectorImpl< ConstantIntRanges > &&strides)
Returns a ranked strided metadata range.
ArrayRef< ConstantIntRanges > getStrides() const
Get the strides ranges.
bool isUninitialized() const
Returns whether the metadata is uninitialized.
static StridedMetadataRange join(const StridedMetadataRange &lhs, const StridedMetadataRange &rhs)
Join two strided metadata ranges, by taking the element-wise union of the metadata.
MutableArrayRef< ConstantIntRanges > getSizes()
ArrayRef< ConstantIntRanges > getOffsets() const
Get the offsets range.
void print(raw_ostream &os) const
Print the strided metadata range.
bool operator==(const StridedMetadataRange &other) const
Compare two strided metadata ranges.
static StridedMetadataRange getMaxRanges(int32_t indexBitwidth, int32_t offsetsRank, int32_t sizeRank, int32_t stridedRank)
Returns a strided metadata range with maximum ranges.
MutableArrayRef< ConstantIntRanges > getOffsets()
MutableArrayRef< ConstantIntRanges > getStrides()
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
function_ref< void(Value, const StridedMetadataRange &)> SetStridedMetadataRangeFn
Callback function type for setting the strided metadata of a value.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152