MLIR 22.0.0git
Traits.cpp
Go to the documentation of this file.
//===- Traits.cpp - Common op traits shared by dialects -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>

using namespace mlir;
15
17 ArrayRef<int64_t> shape2) {
19 extents.emplace_back(shape1.begin(), shape1.end());
20 extents.emplace_back(shape2.begin(), shape2.end());
21 return staticallyKnownBroadcastable(extents);
22}
23
26 assert(!shapes.empty() && "Expected at least one shape");
27 size_t maxRank = shapes[0].size();
28 for (size_t i = 1; i != shapes.size(); ++i)
29 maxRank = std::max(maxRank, shapes[i].size());
30
31 // We look backwards through every column of `shapes`.
32 for (size_t i = 0; i != maxRank; ++i) {
33 bool seenDynamic = false;
34 std::optional<int64_t> nonOneDim;
35 for (ArrayRef<int64_t> extent : shapes) {
36 int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
37
38 if (dim == 1)
39 continue;
40
41 // Dimensions are compatible when
42 //. 1. One is dynamic, the rest are 1
43 if (ShapedType::isDynamic(dim)) {
44 if (seenDynamic || nonOneDim)
45 return false;
46 seenDynamic = true;
47 }
48
49 // 2. All are 1 or a specific constant.
50 if (nonOneDim && dim != *nonOneDim)
51 return false;
52
53 nonOneDim = dim;
54 }
55 }
56 return true;
57}
58
60 ArrayRef<int64_t> shape2,
61 SmallVectorImpl<int64_t> &resultShape) {
62 // To compute the result broadcasted shape, we compare operand shapes
63 // element-wise: starting with the trailing dimensions, and working the
64 // way backward. Two dimensions are compatible when
65 // 1. they are equal, or
66 // 2. one of them is 1
67 // The result shape has the maximum among the two inputs at every
68 // dimension index.
69
70 resultShape.clear();
71 if (shape1.size() > shape2.size()) {
72 llvm::append_range(resultShape, shape1);
73 } else {
74 llvm::append_range(resultShape, shape2);
75 }
76
77 auto i1 = shape1.rbegin(), e1 = shape1.rend();
78 auto i2 = shape2.rbegin(), e2 = shape2.rend();
79 auto iR = resultShape.rbegin();
80
81 // Check each dimension is consistent.
82 for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
83 if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
84 // One or both dimensions is unknown. Follow TensorFlow behavior:
85 // - If either dimension is greater than 1, we assume that the program is
86 // correct, and the other dimension will be broadcasted to match it.
87 // - If either dimension is 1, the other dimension is the output.
88 if (*i1 > 1) {
89 *iR = *i1;
90 } else if (*i2 > 1) {
91 *iR = *i2;
92 } else if (*i1 == 1) {
93 *iR = *i2;
94 } else if (*i2 == 1) {
95 *iR = *i1;
96 } else {
97 *iR = ShapedType::kDynamic;
98 }
99 } else {
100 if (*i1 == *i2 || *i2 == 1) {
101 *iR = *i1;
102 } else if (*i1 == 1) {
103 *iR = *i2;
104 } else {
105 // This dimension of the two operand types is incompatible.
106 resultShape.clear();
107 return false;
108 }
109 }
110 }
111
112 return true;
113}
114
115/// Returns the shape of the given type. Scalars will be considered as having a
116/// shape with zero dimensions.
118 if (auto sType = dyn_cast<ShapedType>(type))
119 return sType.getShape();
120 return {};
121}
122
123/// Returns the result broadcast composition type from the two given types by
124/// following NumPy broadcast semantics. Returned type may have dynamic shape if
125/// either of the input types has dynamic shape. Returns null type if the two
126/// given types are not broadcast-compatible.
127///
128/// elementType, if specified, will be used as the element type of the
129/// broadcasted result type. Otherwise it is required that the element type of
130/// type1 and type2 is the same and this element type will be used as the
131/// resultant element type.
133 Type elementType) {
134 // If the elementType is not specified, then the use the common element type
135 // of the inputs or fail if there is no common element type.
136 if (!elementType) {
137 elementType = getElementTypeOrSelf(type1);
138 if (elementType != getElementTypeOrSelf(type2))
139 return {};
140 }
141
142 // If one of the types is unranked tensor, then the other type shouldn't be
143 // vector and the result should have unranked tensor type.
144 if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) {
145 if (isa<VectorType>(type1) || isa<VectorType>(type2))
146 return {};
147 return UnrankedTensorType::get(elementType);
148 }
149
150 // Returns the type kind if the given type is a vector or ranked tensor type.
151 // Returns std::nullopt otherwise.
152 auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> {
153 if (isa<VectorType, RankedTensorType>(type))
154 return type.getTypeID();
155 return std::nullopt;
156 };
157
158 // Make sure the composite type, if has, is consistent.
159 std::optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
160 std::optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
161 std::optional<TypeID> resultCompositeKind;
162
163 if (compositeKind1 && compositeKind2) {
164 // Disallow mixing vector and tensor.
165 if (compositeKind1 != compositeKind2)
166 return {};
167 resultCompositeKind = compositeKind1;
168 } else if (compositeKind1) {
169 resultCompositeKind = compositeKind1;
170 } else if (compositeKind2) {
171 resultCompositeKind = compositeKind2;
172 }
173
174 // Get the shape of each type.
175 SmallVector<int64_t, 4> resultShape;
176 if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
177 return {};
178
179 // Compose the final broadcasted type
180 if (resultCompositeKind == VectorType::getTypeID())
181 return VectorType::get(resultShape, elementType);
182 if (resultCompositeKind == RankedTensorType::getTypeID())
183 return RankedTensorType::get(resultShape, elementType);
184 return elementType;
185}
186
187/// Returns a tuple corresponding to whether range has tensor or vector type.
188template <typename iterator_range>
189static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
190 return {llvm::any_of(types, llvm::IsaPred<TensorType>),
191 llvm::any_of(types, llvm::IsaPred<VectorType>)};
192}
193
195 ArrayRef<int64_t> existing) {
196 // If both interred and existing dimensions are static, they must be equal.
197 auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
198 return ShapedType::isDynamic(existingDim) ||
199 ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;
200 };
201 if (inferred.size() != existing.size())
202 return false;
203 for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
204 if (!isCompatible(inferredDim, existingDim))
205 return false;
206 return true;
207}
208
210 // TODO: should replace with printing shape more uniformly across here and
211 // when in type.
212 std::string ret;
213 llvm::raw_string_ostream ss(ret);
214 ss << '\'';
215 llvm::interleave(
216 shape, ss,
217 [&](int64_t dim) {
218 if (ShapedType::isDynamic(dim))
219 ss << '?';
220 else
221 ss << dim;
222 },
223 "x");
224 ss << '\'';
225 return ret;
226}
227
229 // Ensure broadcasting only tensor or only vector types.
230 auto operandsHasTensorVectorType =
232 auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
233 if ((std::get<0>(operandsHasTensorVectorType) ||
234 std::get<0>(resultsHasTensorVectorType)) &&
235 (std::get<1>(operandsHasTensorVectorType) ||
236 std::get<1>(resultsHasTensorVectorType)))
237 return op->emitError("cannot broadcast vector with tensor");
238
239 auto rankedOperands =
240 make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
241
242 // If all operands are unranked, then all result shapes are possible.
243 if (rankedOperands.empty())
244 return success();
245
246 // Compute broadcasted shape of operands (which requires that operands are
247 // broadcast compatible). The results need to be broadcast compatible with
248 // this result shape.
249 SmallVector<int64_t, 4> resultShape;
250 (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
251 resultShape);
252 for (auto other : make_early_inc_range(rankedOperands)) {
253 SmallVector<int64_t, 4> temp = resultShape;
254 if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
255 return op->emitOpError("operands don't have broadcast-compatible shapes");
256 }
257
258 auto rankedResults =
259 make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);
260
261 // If all of the results are unranked then no further verification.
262 if (rankedResults.empty())
263 return success();
264
265 for (auto type : rankedResults) {
266 ArrayRef<int64_t> actualSuffix =
267 getShape(type).take_back(resultShape.size());
268 if (!isCompatibleInferredReturnShape(resultShape, actualSuffix))
269 return op->emitOpError()
270 << "result type " << getShapeString(getShape(type))
271 << " not broadcast compatible with broadcasted operands's shapes "
272 << getShapeString(resultShape);
273 }
274 return success();
275}
return success()
static std::string getShapeString(ArrayRef< int64_t > shape)
Definition Traits.cpp:209
static bool isCompatibleInferredReturnShape(ArrayRef< int64_t > inferred, ArrayRef< int64_t > existing)
Definition Traits.cpp:194
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static std::tuple< bool, bool > hasTensorOrVectorType(iterator_range types)
Returns a tuple corresponding to whether range has tensor or vector type.
Definition Traits.cpp:189
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
LogicalResult verifyCompatibleOperandBroadcast(Operation *op)
Definition Traits.cpp:228
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition Traits.cpp:24
Type getBroadcastedType(Type type1, Type type2, Type elementType=nullptr)
Returns the result broadcast composition type from the two given types by following NumPy broadcast s...
Definition Traits.cpp:132
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition Traits.cpp:59
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.