MLIR 22.0.0git
TypeUtilities.cpp
Go to the documentation of this file.
1//===- TypeUtilities.cpp - Helper function for type queries ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines generic type utilities.
10//
11//===----------------------------------------------------------------------===//
12
14#include "mlir/IR/Attributes.h"
16#include "mlir/IR/Types.h"
17#include "mlir/IR/Value.h"
18#include "llvm/ADT/SmallVectorExtras.h"
19#include <numeric>
20
21using namespace mlir;
22
24 if (auto st = llvm::dyn_cast<ShapedType>(type))
25 return st.getElementType();
26 return type;
27}
28
32
34 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr))
35 return getElementTypeOrSelf(typedAttr.getType());
36 return {};
37}
38
41 t.getFlattenedTypes(fTypes);
42 return fTypes;
43}
44
45/// Return true if the specified type is an opaque type with the specified
46/// dialect and typeData.
47bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
48 StringRef typeData) {
49 if (auto opaque = llvm::dyn_cast<mlir::OpaqueType>(type))
50 return opaque.getDialectNamespace() == dialect &&
51 opaque.getTypeData() == typeData;
52 return false;
53}
54
55/// Returns success if the given two shapes are compatible. That is, they have
56/// the same size and each pair of the elements are equal or one of them is
57/// dynamic.
59 ArrayRef<int64_t> shape2) {
60 if (shape1.size() != shape2.size())
61 return failure();
62 for (auto dims : llvm::zip(shape1, shape2)) {
63 int64_t dim1 = std::get<0>(dims);
64 int64_t dim2 = std::get<1>(dims);
65 if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) &&
66 dim1 != dim2)
67 return failure();
68 }
69 return success();
70}
71
72/// Returns success if the given two types have compatible shape. That is,
73/// they are both scalars (not shaped), or they are both shaped types and at
74/// least one is unranked or they have compatible dimensions. Dimensions are
75/// compatible if at least one is dynamic or both are equal. The element type
76/// does not matter.
77LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
78 auto sType1 = llvm::dyn_cast<ShapedType>(type1);
79 auto sType2 = llvm::dyn_cast<ShapedType>(type2);
80
81 // Either both or neither type should be shaped.
82 if (!sType1)
83 return success(!sType2);
84 if (!sType2)
85 return failure();
86
87 if (!sType1.hasRank() || !sType2.hasRank())
88 return success();
89
90 return verifyCompatibleShape(sType1.getShape(), sType2.getShape());
91}
92
93/// Returns success if the given two arrays have the same number of elements and
94/// each pair wise entries have compatible shape.
95LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) {
96 if (types1.size() != types2.size())
97 return failure();
98 for (auto it : llvm::zip_first(types1, types2))
99 if (failed(verifyCompatibleShape(std::get<0>(it), std::get<1>(it))))
100 return failure();
101 return success();
102}
103
105 if (dims.empty())
106 return success();
107 auto staticDim =
108 llvm::accumulate(dims, dims.front(), [](auto fold, auto dim) {
109 return ShapedType::isDynamic(dim) ? fold : dim;
110 });
111 return success(llvm::all_of(dims, [&](auto dim) {
112 return ShapedType::isDynamic(dim) || dim == staticDim;
113 }));
114}
115
116/// Returns success if all given types have compatible shapes. That is, they are
117/// all scalars (not shaped), or they are all shaped types and any ranked shapes
118/// have compatible dimensions. Dimensions are compatible if all non-dynamic
119/// dims are equal. The element type does not matter.
121 auto shapedTypes = llvm::map_to_vector<8>(
122 types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); });
123 // Return failure if some, but not all are not shaped. Return early if none
124 // are shaped also.
125 if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
126 return success();
127 if (!llvm::all_of(shapedTypes, [](auto t) { return t; }))
128 return failure();
129
130 // Return failure if some, but not all, are scalable vectors.
131 bool hasScalableVecTypes = false;
132 bool hasNonScalableVecTypes = false;
133 for (Type t : types) {
134 auto vType = llvm::dyn_cast<VectorType>(t);
135 if (vType && vType.isScalable())
136 hasScalableVecTypes = true;
137 else
138 hasNonScalableVecTypes = true;
139 if (hasScalableVecTypes && hasNonScalableVecTypes)
140 return failure();
141 }
142
143 // Remove all unranked shapes
144 auto shapes = llvm::filter_to_vector<8>(
145 shapedTypes, [](auto shapedType) { return shapedType.hasRank(); });
146 if (shapes.empty())
147 return success();
148
149 // All ranks should be equal
150 auto firstRank = shapes.front().getRank();
151 if (llvm::any_of(shapes,
152 [&](auto shape) { return firstRank != shape.getRank(); }))
153 return failure();
154
155 for (unsigned i = 0; i < firstRank; ++i) {
156 // Retrieve all ranked dimensions
157 auto dims = llvm::map_to_vector<8>(
158 llvm::make_filter_range(
159 shapes, [&](auto shape) { return shape.getRank() >= i; }),
160 [&](auto shape) { return shape.getDimSize(i); });
161 if (verifyCompatibleDims(dims).failed())
162 return failure();
163 }
164
165 return success();
166}
167
169 return llvm::cast<ShapedType>(value.getType()).getElementType();
170}
171
173 return llvm::cast<ShapedType>(value.getType()).getElementType();
174}
175
177 TypeRange newTypes,
178 SmallVectorImpl<Type> &storage) {
179 assert(indices.size() == newTypes.size() &&
180 "mismatch between indice and type count");
181 if (indices.empty())
182 return oldTypes;
183
184 auto fromIt = oldTypes.begin();
185 for (auto it : llvm::zip(indices, newTypes)) {
186 const auto toIt = oldTypes.begin() + std::get<0>(it);
187 storage.append(fromIt, toIt);
188 storage.push_back(std::get<1>(it));
189 fromIt = toIt;
190 }
191 storage.append(fromIt, oldTypes.end());
192 return storage;
193}
194
196 SmallVectorImpl<Type> &storage) {
197 if (indices.none())
198 return types;
199
200 for (unsigned i = 0, e = types.size(); i < e; ++i)
201 if (!indices[i])
202 storage.emplace_back(types[i]);
203 return storage;
204}
return success()
Attributes are known-constant values of operations.
Definition Attributes.h:25
Type mapElement(Value value) const
Map the element to the iterator result type.
Type mapElement(Value value) const
Map the element to the iterator result type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Include the generated interface declarations.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair of corresponding entries has a compatible shape.
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
bool isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData)
Return true if the specified type is an opaque type with the specified dialect and typeData.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.