MLIR  16.0.0git
Traits.cpp
Go to the documentation of this file.
//===- Traits.cpp - Common op traits shared by dialects -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Traits.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/IR/TypeUtilities.h"
12 #include "llvm/Support/FormatVariadic.h"
13 
14 using namespace mlir;
15 
17  ArrayRef<int64_t> shape2) {
19  extents.emplace_back(shape1.begin(), shape1.end());
20  extents.emplace_back(shape2.begin(), shape2.end());
21  return staticallyKnownBroadcastable(extents);
22 }
23 
26  assert(!shapes.empty() && "Expected at least one shape");
27  size_t maxRank = shapes[0].size();
28  for (size_t i = 1; i != shapes.size(); ++i)
29  maxRank = std::max(maxRank, shapes[i].size());
30 
31  // We look backwards through every column of `shapes`.
32  for (size_t i = 0; i != maxRank; ++i) {
33  bool seenDynamic = false;
34  Optional<int64_t> nonOneDim;
35  for (ArrayRef<int64_t> extent : shapes) {
36  int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
37 
38  if (dim == 1)
39  continue;
40 
41  // Dimensions are compatible when
42  //. 1. One is dynamic, the rest are 1
43  if (ShapedType::isDynamic(dim)) {
44  if (seenDynamic || nonOneDim)
45  return false;
46  seenDynamic = true;
47  }
48 
49  // 2. All are 1 or a specific constant.
50  if (nonOneDim && dim != *nonOneDim)
51  return false;
52 
53  nonOneDim = dim;
54  }
55  }
56  return true;
57 }
58 
60  ArrayRef<int64_t> shape2,
61  SmallVectorImpl<int64_t> &resultShape) {
62  // To compute the result broadcasted shape, we compare operand shapes
63  // element-wise: starting with the trailing dimensions, and working the
64  // way backward. Two dimensions are compatible when
65  // 1. they are equal, or
66  // 2. one of them is 1
67  // The result shape has the maximum among the two inputs at every
68  // dimension index.
69 
70  resultShape.clear();
71  if (shape1.size() > shape2.size()) {
72  std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
73  } else {
74  std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
75  }
76 
77  auto i1 = shape1.rbegin(), e1 = shape1.rend();
78  auto i2 = shape2.rbegin(), e2 = shape2.rend();
79  auto iR = resultShape.rbegin();
80 
81  // Check each dimension is consistent.
82  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
83  if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
84  // One or both dimensions is unknown. Follow TensorFlow behavior:
85  // - If either dimension is greater than 1, we assume that the program is
86  // correct, and the other dimension will be broadcast to match it.
87  // - If either dimension is 1, the other dimension is the output.
88  if (*i1 > 1) {
89  *iR = *i1;
90  } else if (*i2 > 1) {
91  *iR = *i2;
92  } else if (*i1 == 1) {
93  *iR = *i2;
94  } else if (*i2 == 1) {
95  *iR = *i1;
96  } else {
97  *iR = ShapedType::kDynamic;
98  }
99  } else {
100  if (*i1 == *i2 || *i2 == 1) {
101  *iR = *i1;
102  } else if (*i1 == 1) {
103  *iR = *i2;
104  } else {
105  // This dimension of the two operand types is incompatible.
106  resultShape.clear();
107  return false;
108  }
109  }
110  }
111 
112  return true;
113 }
114 
115 /// Returns the shape of the given type. Scalars will be considered as having a
116 /// shape with zero dimensions.
118  if (auto sType = type.dyn_cast<ShapedType>())
119  return sType.getShape();
120  return {};
121 }
122 
123 /// Returns the result broadcast composition type from the two given types by
124 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
125 /// either of the input types has dynamic shape. Returns null type if the two
126 /// given types are not broadcast-compatible.
127 ///
128 /// elementType, if specified, will be used as the element type of the
129 /// broadcasted result type. Otherwise it is required that the element type of
130 /// type1 and type2 is the same and this element type will be used as the
131 /// resultant element type.
133  Type elementType) {
134  // If the elementType is not specified, then the use the common element type
135  // of the inputs or fail if there is no common element type.
136  if (!elementType) {
137  elementType = getElementTypeOrSelf(type1);
138  if (elementType != getElementTypeOrSelf(type2))
139  return {};
140  }
141 
142  // If one of the types is unranked tensor, then the other type shouldn't be
143  // vector and the result should have unranked tensor type.
144  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
145  if (type1.isa<VectorType>() || type2.isa<VectorType>())
146  return {};
147  return UnrankedTensorType::get(elementType);
148  }
149 
150  // Returns the type kind if the given type is a vector or ranked tensor type.
151  // Returns std::nullopt otherwise.
152  auto getCompositeTypeKind = [](Type type) -> Optional<TypeID> {
153  if (type.isa<VectorType, RankedTensorType>())
154  return type.getTypeID();
155  return std::nullopt;
156  };
157 
158  // Make sure the composite type, if has, is consistent.
159  Optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
160  Optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
161  Optional<TypeID> resultCompositeKind;
162 
163  if (compositeKind1 && compositeKind2) {
164  // Disallow mixing vector and tensor.
165  if (compositeKind1 != compositeKind2)
166  return {};
167  resultCompositeKind = compositeKind1;
168  } else if (compositeKind1) {
169  resultCompositeKind = compositeKind1;
170  } else if (compositeKind2) {
171  resultCompositeKind = compositeKind2;
172  }
173 
174  // Get the shape of each type.
175  SmallVector<int64_t, 4> resultShape;
176  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
177  return {};
178 
179  // Compose the final broadcasted type
180  if (resultCompositeKind == VectorType::getTypeID())
181  return VectorType::get(resultShape, elementType);
182  if (resultCompositeKind == RankedTensorType::getTypeID())
183  return RankedTensorType::get(resultShape, elementType);
184  return elementType;
185 }
186 
187 /// Returns a tuple corresponding to whether range has tensor or vector type.
188 template <typename iterator_range>
189 static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
190  return std::make_tuple(
191  llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }),
192  llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
193 }
194 
196  ArrayRef<int64_t> existing) {
197  auto isCompatible = [](int64_t dim1, int64_t dim2) {
198  // If the inferred and existing dim is the same, or one of them is unknown
199  // then it is compatible, else if the inferred dim is 1 then it is also
200  // compatible. But if the existing dim is 1 and the inferred is greater than
201  // 1 then flag.
202  return dim1 == dim2 || ShapedType::isDynamic(dim1) ||
203  ShapedType::isDynamic(dim2) || dim1 == 1;
204  };
205  if (inferred.size() != existing.size())
206  return false;
207  for (auto p : llvm::zip(inferred, existing))
208  if (!isCompatible(std::get<0>(p), std::get<1>(p)))
209  return false;
210  return true;
211 }
212 
213 static std::string getShapeString(ArrayRef<int64_t> shape) {
214  // TODO: should replace with printing shape more uniformly across here and
215  // when in type.
216  std::string ret;
217  llvm::raw_string_ostream ss(ret);
218  ss << '\'';
219  llvm::interleave(
220  shape, ss,
221  [&](int64_t dim) {
222  if (ShapedType::isDynamic(dim))
223  ss << '?';
224  else
225  ss << dim;
226  },
227  "x");
228  ss << '\'';
229  return ss.str();
230 }
231 
233  // Ensure broadcasting only tensor or only vector types.
234  auto operandsHasTensorVectorType =
236  auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
237  if ((std::get<0>(operandsHasTensorVectorType) ||
238  std::get<0>(resultsHasTensorVectorType)) &&
239  (std::get<1>(operandsHasTensorVectorType) ||
240  std::get<1>(resultsHasTensorVectorType)))
241  return op->emitError("cannot broadcast vector with tensor");
242 
243  auto rankedOperands = make_filter_range(
244  op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
245 
246  // If all operands are unranked, then all result shapes are possible.
247  if (rankedOperands.empty())
248  return success();
249 
250  // Compute broadcasted shape of operands (which requires that operands are
251  // broadcast compatible). The results need to be broadcast compatible with
252  // this result shape.
253  SmallVector<int64_t, 4> resultShape;
254  (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
255  resultShape);
256  for (auto other : make_early_inc_range(rankedOperands)) {
257  SmallVector<int64_t, 4> temp = resultShape;
258  if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
259  return op->emitOpError("operands don't have broadcast-compatible shapes");
260  }
261 
262  auto rankedResults = make_filter_range(
263  op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
264 
265  // If all of the results are unranked then no further verification.
266  if (rankedResults.empty())
267  return success();
268 
269  for (auto type : rankedResults) {
270  ArrayRef<int64_t> actualSuffix =
271  getShape(type).take_back(resultShape.size());
272  if (!isCompatibleInferredReturnShape(resultShape, actualSuffix))
273  return op->emitOpError()
274  << "result type " << getShapeString(getShape(type))
275  << " not broadcast compatible with broadcasted operands's shapes "
276  << getShapeString(resultShape);
277  }
278  return success();
279 }
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::string getShapeString(ArrayRef< int64_t > shape)
Definition: Traits.cpp:213
static bool isCompatibleInferredReturnShape(ArrayRef< int64_t > inferred, ArrayRef< int64_t > existing)
Definition: Traits.cpp:195
static std::tuple< bool, bool > hasTensorOrVectorType(iterator_range types)
Returns a tuple corresponding to whether range has tensor or vector type.
Definition: Traits.cpp:189
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:225
operand_type_range getOperandTypes()
Definition: Operation.h:314
result_type_range getResultTypes()
Definition: Operation.h:345
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:512
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:78
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U dyn_cast() const
Definition: Types.h:270
bool isa() const
Definition: Types.h:260
LogicalResult verifyCompatibleOperandBroadcast(Operation *op)
Definition: Traits.cpp:232
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition: Traits.cpp:24
Type getBroadcastedType(Type type1, Type type2, Type elementType=nullptr)
Returns the result broadcast composition type from the two given types by following NumPy broadcast s...
Definition: Traits.cpp:132
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:59
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26