19 extents.emplace_back(shape1.begin(), shape1.end());
20 extents.emplace_back(shape2.begin(), shape2.end());
26 assert(!shapes.empty() &&
"Expected at least one shape");
27 size_t maxRank = shapes[0].size();
28 for (
size_t i = 1; i != shapes.size(); ++i)
29 maxRank =
std::max(maxRank, shapes[i].size());
32 for (
size_t i = 0; i != maxRank; ++i) {
33 bool seenDynamic =
false;
34 std::optional<int64_t> nonOneDim;
36 int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
43 if (ShapedType::isDynamic(dim)) {
44 if (seenDynamic || nonOneDim)
50 if (nonOneDim && dim != *nonOneDim)
71 if (shape1.size() > shape2.size()) {
72 llvm::append_range(resultShape, shape1);
74 llvm::append_range(resultShape, shape2);
77 auto i1 = shape1.rbegin(), e1 = shape1.rend();
78 auto i2 = shape2.rbegin(), e2 = shape2.rend();
79 auto iR = resultShape.rbegin();
82 for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
83 if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
92 }
else if (*i1 == 1) {
94 }
else if (*i2 == 1) {
97 *iR = ShapedType::kDynamic;
100 if (*i1 == *i2 || *i2 == 1) {
102 }
else if (*i1 == 1) {
118 if (
auto sType = dyn_cast<ShapedType>(type))
119 return sType.getShape();
144 if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) {
145 if (isa<VectorType>(type1) || isa<VectorType>(type2))
152 auto getCompositeTypeKind = [](
Type type) -> std::optional<TypeID> {
153 if (isa<VectorType, RankedTensorType>(type))
154 return type.getTypeID();
159 std::optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
160 std::optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
161 std::optional<TypeID> resultCompositeKind;
163 if (compositeKind1 && compositeKind2) {
165 if (compositeKind1 != compositeKind2)
167 resultCompositeKind = compositeKind1;
168 }
else if (compositeKind1) {
169 resultCompositeKind = compositeKind1;
170 }
else if (compositeKind2) {
171 resultCompositeKind = compositeKind2;
180 if (resultCompositeKind == VectorType::getTypeID())
182 if (resultCompositeKind == RankedTensorType::getTypeID())
188 template <
typename iterator_range>
190 return {llvm::any_of(types, llvm::IsaPred<TensorType>),
191 llvm::any_of(types, llvm::IsaPred<VectorType>)};
197 auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
198 return ShapedType::isDynamic(existingDim) ||
199 ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;
201 if (inferred.size() != existing.size())
203 for (
auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
204 if (!isCompatible(inferredDim, existingDim))
213 llvm::raw_string_ostream ss(ret);
218 if (ShapedType::isDynamic(dim))
230 auto operandsHasTensorVectorType =
233 if ((std::get<0>(operandsHasTensorVectorType) ||
234 std::get<0>(resultsHasTensorVectorType)) &&
235 (std::get<1>(operandsHasTensorVectorType) ||
236 std::get<1>(resultsHasTensorVectorType)))
237 return op->
emitError(
"cannot broadcast vector with tensor");
239 auto rankedOperands =
240 make_filter_range(op->
getOperandTypes(), llvm::IsaPred<RankedTensorType>);
243 if (rankedOperands.empty())
252 for (
auto other : make_early_inc_range(rankedOperands)) {
255 return op->
emitOpError(
"operands don't have broadcast-compatible shapes");
259 make_filter_range(op->
getResultTypes(), llvm::IsaPred<RankedTensorType>);
262 if (rankedResults.empty())
265 for (
auto type : rankedResults) {
267 getShape(type).take_back(resultShape.size());
271 <<
" not broadcast compatible with broadcasted operands's shapes "
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::string getShapeString(ArrayRef< int64_t > shape)
static bool isCompatibleInferredReturnShape(ArrayRef< int64_t > inferred, ArrayRef< int64_t > existing)
static std::tuple< bool, bool > hasTensorOrVectorType(iterator_range types)
Returns a tuple corresponding to whether the range contains a tensor type and whether it contains a vector type.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
LogicalResult verifyCompatibleOperandBroadcast(Operation *op)
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an error.
Type getBroadcastedType(Type type1, Type type2, Type elementType=nullptr)
Returns the result broadcast composition type from the two given types by following NumPy broadcast semantics.
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broadcast compatible.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.