// Broadcasting utilities from mlir/lib/Dialect/Traits.cpp.
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>

using namespace mlir;

// The two-shape form packs both shapes into a list of extents and defers to
// the n-shape overload below.
bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                 ArrayRef<int64_t> shape2) {
  SmallVector<SmallVector<int64_t, 6>, 2> extents;
  extents.emplace_back(shape1.begin(), shape1.end());
  extents.emplace_back(shape2.begin(), shape2.end());
  return staticallyKnownBroadcastable(extents);
}

bool OpTrait::util::staticallyKnownBroadcastable(
    ArrayRef<SmallVector<int64_t, 6>> shapes) {
  assert(!shapes.empty() && "Expected at least one shape");
  size_t maxRank = shapes[0].size();
  for (size_t i = 1; i != shapes.size(); ++i)
    maxRank = std::max(maxRank, shapes[i].size());

  // Walk each dimension column of `shapes`, trailing dimension first.
  for (size_t i = 0; i != maxRank; ++i) {
    bool seenDynamic = false;
    std::optional<int64_t> nonOneDim;
    for (ArrayRef<int64_t> extent : shapes) {
      // Shapes of lower rank are implicitly padded with leading 1s.
      int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
      if (dim == 1)
        continue;

      // At most one dynamic dimension is allowed per column, and only if no
      // other non-1 dimension appears in that column.
      if (ShapedType::isDynamic(dim)) {
        if (seenDynamic || nonOneDim)
          return false;
        seenDynamic = true;
      }

      // All non-1 dimensions in a column must agree.
      if (nonOneDim && dim != *nonOneDim)
        return false;
      nonOneDim = dim;
    }
  }
  return true;
}

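A minimal usage sketch (hypothetical shape values; ShapedType::kDynamic marks an unknown extent):

  SmallVector<SmallVector<int64_t, 6>> shapes;
  shapes.push_back({3, 1, 7});
  shapes.push_back({5, 7});
  // 3x1x7 and 5x7 are provably broadcast compatible (result 3x5x7).
  bool ok = OpTrait::util::staticallyKnownBroadcastable(shapes); // true
  // ?x7 vs 3x7: a dynamic extent against a static non-1 extent is not
  // statically known to broadcast, even though it may succeed at runtime.
  shapes.assign({{ShapedType::kDynamic, 7}, {3, 7}});
  ok = OpTrait::util::staticallyKnownBroadcastable(shapes); // false
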
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // Seed the result with the higher-ranked shape, then fix up the trailing
  // dimensions in place.
  resultShape.clear();
  if (shape1.size() > shape2.size()) {
    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
  } else {
    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
  }

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = resultShape.rbegin();

  // Check that each pair of trailing dimensions is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
      // At least one side is unknown: a static dimension > 1 wins, a static 1
      // yields the other side, and otherwise the result stays dynamic.
      if (*i1 > 1)
        *iR = *i1;
      else if (*i2 > 1)
        *iR = *i2;
      else if (*i1 == 1)
        *iR = *i2;
      else if (*i2 == 1)
        *iR = *i1;
      else
        *iR = ShapedType::kDynamic;
    } else if (*i1 == *i2 || *i2 == 1) {
      *iR = *i1;
    } else if (*i1 == 1) {
      *iR = *i2;
    } else {
      // Differing static dimensions where neither is 1 cannot broadcast.
      resultShape.clear();
      return false;
    }
  }
  return true;
}

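A minimal usage sketch of the resulting semantics (hypothetical values):

  SmallVector<int64_t, 4> result;
  // 3x1x7 and 5x7 broadcast to 3x5x7.
  bool ok = OpTrait::util::getBroadcastedShape({3, 1, 7}, {5, 7}, result);
  // ?x7 and 1x7 broadcast to ?x7: the static 1 defers to the unknown extent.
  ok = OpTrait::util::getBroadcastedShape({ShapedType::kDynamic, 7}, {1, 7},
                                          result);
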
/// Returns the shape of the given type; non-shaped (scalar) types yield an
/// empty shape.
static ArrayRef<int64_t> getShape(Type type) {
  if (auto sType = dyn_cast<ShapedType>(type))
    return sType.getShape();
  return {};
}

Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
                                       Type elementType) {
  // If elementType was not specified, the inputs must share an element type.
  if (!elementType) {
    elementType = getElementTypeOrSelf(type1);
    if (elementType != getElementTypeOrSelf(type2))
      return {};
  }
  // An unranked tensor can broadcast with anything except a vector, and the
  // result is then an unranked tensor.
  if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) {
    if (isa<VectorType>(type1) || isa<VectorType>(type2))
      return {};
    return UnrankedTensorType::get(elementType);
  }
  // Returns the type kind if the given type is a vector or ranked tensor.
  auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> {
    if (isa<VectorType, RankedTensorType>(type))
      return type.getTypeID();
    return std::nullopt;
  };
  // The composite kinds, when both present, must agree: vectors and tensors
  // cannot be mixed.
  std::optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
  std::optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
  std::optional<TypeID> resultCompositeKind;
  if (compositeKind1 && compositeKind2) {
    if (compositeKind1 != compositeKind2)
      return {};
    resultCompositeKind = compositeKind1;
  } else if (compositeKind1) {
    resultCompositeKind = compositeKind1;
  } else if (compositeKind2) {
    resultCompositeKind = compositeKind2;
  }
  // Broadcast the shapes and compose the final type.
  SmallVector<int64_t, 4> resultShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return {};
  if (resultCompositeKind == VectorType::getTypeID())
    return VectorType::get(resultShape, elementType);
  if (resultCompositeKind == RankedTensorType::getTypeID())
    return RankedTensorType::get(resultShape, elementType);
  return elementType;
}

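A minimal usage sketch (hypothetical types; assumes an MLIRContext):

  MLIRContext ctx;
  Type f32 = Float32Type::get(&ctx);
  Type lhs = RankedTensorType::get({3, 1}, f32);
  Type rhs = RankedTensorType::get({4}, f32);
  // NumPy-style broadcast: tensor<3x1xf32> with tensor<4xf32> gives
  // tensor<3x4xf32>. A null result means the types are not compatible.
  Type result = OpTrait::util::getBroadcastedType(lhs, rhs);
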
/// Returns a tuple corresponding to whether the range has tensor or vector
/// type.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
  return std::make_tuple(
      llvm::any_of(types, [](Type t) { return isa<TensorType>(t); }),
      llvm::any_of(types, [](Type t) { return isa<VectorType>(t); }));
}

static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
                                            ArrayRef<int64_t> existing) {
  // Dimensions are compatible when both are static and equal, or when at
  // least one of them is dynamic.
  auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
    return ShapedType::isDynamic(existingDim) ||
           ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;
  };
  if (inferred.size() != existing.size())
    return false;
  for (auto [inferredDim, existingDim] : llvm::zip(inferred, existing))
    if (!isCompatible(inferredDim, existingDim))
      return false;
  return true;
}

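For instance (hypothetical values), an inferred dimension only conflicts with an existing one when both are static and unequal:

  // {3, ?} is compatible with {3, 4}: a dynamic dimension matches anything.
  bool ok = isCompatibleInferredReturnShape({3, ShapedType::kDynamic}, {3, 4}); // true
  // {3, 5} is not compatible with {3, 4}.
  ok = isCompatibleInferredReturnShape({3, 5}, {3, 4}); // false
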
/// Renders a shape as an 'AxBxC' string, using '?' for dynamic dimensions.
static std::string getShapeString(ArrayRef<int64_t> shape) {
  std::string ret;
  llvm::raw_string_ostream ss(ret);
  ss << '\'';
  llvm::interleave(
      shape, ss,
      [&](int64_t dim) {
        if (ShapedType::isDynamic(dim))
          ss << '?';
        else
          ss << dim;
      },
      "x");
  ss << '\'';
  return ret;
}

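For example, the shape {3, ?, 7} renders as '3x?x7'; the verifier below uses these strings in its diagnostics.
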
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  // Ensure broadcasting only tensor or only vector types.
  auto operandsHasTensorVectorType =
      hasTensorOrVectorType(op->getOperandTypes());
  auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
  if ((std::get<0>(operandsHasTensorVectorType) ||
       std::get<0>(resultsHasTensorVectorType)) &&
      (std::get<1>(operandsHasTensorVectorType) ||
       std::get<1>(resultsHasTensorVectorType)))
    return op->emitError("cannot broadcast vector with tensor");

  auto rankedOperands = make_filter_range(
      op->getOperandTypes(), [](Type t) {
        // Predicate elided in the excerpt; keep ranked shaped types.
        auto shaped = dyn_cast<ShapedType>(t);
        return shaped && shaped.hasRank();
      });

  // If all operands are unranked, then all result shapes are possible.
  if (rankedOperands.empty())
    return success();

  // Compute the broadcasted shape of the operands, which requires that the
  // operands are broadcast compatible with one another.
  SmallVector<int64_t, 4> resultShape;
  (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
                                  resultShape);
  for (auto other : make_early_inc_range(rankedOperands)) {
    SmallVector<int64_t, 4> temp = resultShape;
    if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
      return op->emitOpError("operands don't have broadcast-compatible shapes");
  }

  auto rankedResults = make_filter_range(
      op->getResultTypes(), [](Type t) {
        auto shaped = dyn_cast<ShapedType>(t);
        return shaped && shaped.hasRank();
      });

  // If all of the results are unranked, no further verification is needed.
  if (rankedResults.empty())
    return success();

  // Each ranked result must be broadcast compatible with the operands'
  // broadcasted shape.
  for (auto type : rankedResults) {
    ArrayRef<int64_t> actualSuffix =
        getShape(type).take_back(resultShape.size());
    if (!isCompatibleInferredReturnShape(resultShape, actualSuffix))
      return op->emitOpError()
             << "result type " << getShapeString(getShape(type))
             << " not broadcast compatible with broadcasted operands's shapes "
             << getShapeString(resultShape);
  }
  return success();
}

Referenced declarations:

static std::string getShapeString(ArrayRef<int64_t> shape)
static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred, ArrayRef<int64_t> existing)
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types)
Returns a tuple corresponding to whether the range has tensor or vector type.
static ArrayRef<int64_t> getShape(Type type)
Returns the shape of the given type.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message = {})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message = {})
Emit an error with the op name prefixed, like "'dim' op ", which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
LogicalResult verifyCompatibleOperandBroadcast(Operation *op)
bool staticallyKnownBroadcastable(ArrayRef<SmallVector<int64_t, 6>> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an error.
Type getBroadcastedType(Type type1, Type type2, Type elementType = nullptr)
Returns the result broadcast composition type from the two given types by following NumPy broadcast semantics.
bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2, SmallVectorImpl<int64_t> &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broadcast compatible.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
This class represents an efficient way to signal success or failure.