MLIR  19.0.0git
Traits.cpp
Go to the documentation of this file.
1 //===- Traits.cpp - Common op traits shared by dialects -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Traits.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/IR/TypeUtilities.h"
12 #include "llvm/Support/FormatVariadic.h"
13 #include <optional>
14 
15 using namespace mlir;
16 
18  ArrayRef<int64_t> shape2) {
20  extents.emplace_back(shape1.begin(), shape1.end());
21  extents.emplace_back(shape2.begin(), shape2.end());
22  return staticallyKnownBroadcastable(extents);
23 }
24 
27  assert(!shapes.empty() && "Expected at least one shape");
28  size_t maxRank = shapes[0].size();
29  for (size_t i = 1; i != shapes.size(); ++i)
30  maxRank = std::max(maxRank, shapes[i].size());
31 
32  // We look backwards through every column of `shapes`.
33  for (size_t i = 0; i != maxRank; ++i) {
34  bool seenDynamic = false;
35  std::optional<int64_t> nonOneDim;
36  for (ArrayRef<int64_t> extent : shapes) {
37  int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
38 
39  if (dim == 1)
40  continue;
41 
42  // Dimensions are compatible when
43  //. 1. One is dynamic, the rest are 1
44  if (ShapedType::isDynamic(dim)) {
45  if (seenDynamic || nonOneDim)
46  return false;
47  seenDynamic = true;
48  }
49 
50  // 2. All are 1 or a specific constant.
51  if (nonOneDim && dim != *nonOneDim)
52  return false;
53 
54  nonOneDim = dim;
55  }
56  }
57  return true;
58 }
59 
61  ArrayRef<int64_t> shape2,
62  SmallVectorImpl<int64_t> &resultShape) {
63  // To compute the result broadcasted shape, we compare operand shapes
64  // element-wise: starting with the trailing dimensions, and working the
65  // way backward. Two dimensions are compatible when
66  // 1. they are equal, or
67  // 2. one of them is 1
68  // The result shape has the maximum among the two inputs at every
69  // dimension index.
70 
71  resultShape.clear();
72  if (shape1.size() > shape2.size()) {
73  std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
74  } else {
75  std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
76  }
77 
78  auto i1 = shape1.rbegin(), e1 = shape1.rend();
79  auto i2 = shape2.rbegin(), e2 = shape2.rend();
80  auto iR = resultShape.rbegin();
81 
82  // Check each dimension is consistent.
83  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
84  if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
85  // One or both dimensions is unknown. Follow TensorFlow behavior:
86  // - If either dimension is greater than 1, we assume that the program is
87  // correct, and the other dimension will be broadcast to match it.
88  // - If either dimension is 1, the other dimension is the output.
89  if (*i1 > 1) {
90  *iR = *i1;
91  } else if (*i2 > 1) {
92  *iR = *i2;
93  } else if (*i1 == 1) {
94  *iR = *i2;
95  } else if (*i2 == 1) {
96  *iR = *i1;
97  } else {
98  *iR = ShapedType::kDynamic;
99  }
100  } else {
101  if (*i1 == *i2 || *i2 == 1) {
102  *iR = *i1;
103  } else if (*i1 == 1) {
104  *iR = *i2;
105  } else {
106  // This dimension of the two operand types is incompatible.
107  resultShape.clear();
108  return false;
109  }
110  }
111  }
112 
113  return true;
114 }
115 
116 /// Returns the shape of the given type. Scalars will be considered as having a
117 /// shape with zero dimensions.
119  if (auto sType = dyn_cast<ShapedType>(type))
120  return sType.getShape();
121  return {};
122 }
123 
124 /// Returns the result broadcast composition type from the two given types by
125 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
126 /// either of the input types has dynamic shape. Returns null type if the two
127 /// given types are not broadcast-compatible.
128 ///
129 /// elementType, if specified, will be used as the element type of the
130 /// broadcasted result type. Otherwise it is required that the element type of
131 /// type1 and type2 is the same and this element type will be used as the
132 /// resultant element type.
134  Type elementType) {
135  // If the elementType is not specified, then the use the common element type
136  // of the inputs or fail if there is no common element type.
137  if (!elementType) {
138  elementType = getElementTypeOrSelf(type1);
139  if (elementType != getElementTypeOrSelf(type2))
140  return {};
141  }
142 
143  // If one of the types is unranked tensor, then the other type shouldn't be
144  // vector and the result should have unranked tensor type.
145  if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) {
146  if (isa<VectorType>(type1) || isa<VectorType>(type2))
147  return {};
148  return UnrankedTensorType::get(elementType);
149  }
150 
151  // Returns the type kind if the given type is a vector or ranked tensor type.
152  // Returns std::nullopt otherwise.
153  auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> {
154  if (isa<VectorType, RankedTensorType>(type))
155  return type.getTypeID();
156  return std::nullopt;
157  };
158 
159  // Make sure the composite type, if has, is consistent.
160  std::optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
161  std::optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
162  std::optional<TypeID> resultCompositeKind;
163 
164  if (compositeKind1 && compositeKind2) {
165  // Disallow mixing vector and tensor.
166  if (compositeKind1 != compositeKind2)
167  return {};
168  resultCompositeKind = compositeKind1;
169  } else if (compositeKind1) {
170  resultCompositeKind = compositeKind1;
171  } else if (compositeKind2) {
172  resultCompositeKind = compositeKind2;
173  }
174 
175  // Get the shape of each type.
176  SmallVector<int64_t, 4> resultShape;
177  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
178  return {};
179 
180  // Compose the final broadcasted type
181  if (resultCompositeKind == VectorType::getTypeID())
182  return VectorType::get(resultShape, elementType);
183  if (resultCompositeKind == RankedTensorType::getTypeID())
184  return RankedTensorType::get(resultShape, elementType);
185  return elementType;
186 }
187 
188 /// Returns a tuple corresponding to whether range has tensor or vector type.
189 template <typename iterator_range>
190 static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
191  return {llvm::any_of(types, llvm::IsaPred<TensorType>),
192  llvm::any_of(types, llvm::IsaPred<VectorType>)};
193 }
194 
196  ArrayRef<int64_t> existing) {
197  // If both interred and existing dimensions are static, they must be equal.
198  auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
199  return ShapedType::isDynamic(existingDim) ||
200  ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;
201  };
202  if (inferred.size() != existing.size())
203  return false;
204  for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
205  if (!isCompatible(inferredDim, existingDim))
206  return false;
207  return true;
208 }
209 
210 static std::string getShapeString(ArrayRef<int64_t> shape) {
211  // TODO: should replace with printing shape more uniformly across here and
212  // when in type.
213  std::string ret;
214  llvm::raw_string_ostream ss(ret);
215  ss << '\'';
216  llvm::interleave(
217  shape, ss,
218  [&](int64_t dim) {
219  if (ShapedType::isDynamic(dim))
220  ss << '?';
221  else
222  ss << dim;
223  },
224  "x");
225  ss << '\'';
226  return ss.str();
227 }
228 
230  // Ensure broadcasting only tensor or only vector types.
231  auto operandsHasTensorVectorType =
233  auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
234  if ((std::get<0>(operandsHasTensorVectorType) ||
235  std::get<0>(resultsHasTensorVectorType)) &&
236  (std::get<1>(operandsHasTensorVectorType) ||
237  std::get<1>(resultsHasTensorVectorType)))
238  return op->emitError("cannot broadcast vector with tensor");
239 
240  auto rankedOperands =
241  make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
242 
243  // If all operands are unranked, then all result shapes are possible.
244  if (rankedOperands.empty())
245  return success();
246 
247  // Compute broadcasted shape of operands (which requires that operands are
248  // broadcast compatible). The results need to be broadcast compatible with
249  // this result shape.
250  SmallVector<int64_t, 4> resultShape;
251  (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
252  resultShape);
253  for (auto other : make_early_inc_range(rankedOperands)) {
254  SmallVector<int64_t, 4> temp = resultShape;
255  if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
256  return op->emitOpError("operands don't have broadcast-compatible shapes");
257  }
258 
259  auto rankedResults =
260  make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);
261 
262  // If all of the results are unranked then no further verification.
263  if (rankedResults.empty())
264  return success();
265 
266  for (auto type : rankedResults) {
267  ArrayRef<int64_t> actualSuffix =
268  getShape(type).take_back(resultShape.size());
269  if (!isCompatibleInferredReturnShape(resultShape, actualSuffix))
270  return op->emitOpError()
271  << "result type " << getShapeString(getShape(type))
272  << " not broadcast compatible with broadcasted operands's shapes "
273  << getShapeString(resultShape);
274  }
275  return success();
276 }
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::string getShapeString(ArrayRef< int64_t > shape)
Definition: Traits.cpp:210
static bool isCompatibleInferredReturnShape(ArrayRef< int64_t > inferred, ArrayRef< int64_t > existing)
Definition: Traits.cpp:195
static std::tuple< bool, bool > hasTensorOrVectorType(iterator_range types)
Returns a tuple corresponding to whether range has tensor or vector type.
Definition: Traits.cpp:190
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
LogicalResult verifyCompatibleOperandBroadcast(Operation *op)
Definition: Traits.cpp:229
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an error.
Definition: Traits.cpp:25
Type getBroadcastedType(Type type1, Type type2, Type elementType=nullptr)
Returns the result broadcast composition type from the two given types by following NumPy broadcast semantics.
Definition: Traits.cpp:133
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broadcast compatible.
Definition: Traits.cpp:60
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26