12 #include "llvm/Support/FormatVariadic.h"
bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                 ArrayRef<int64_t> shape2) {
  SmallVector<SmallVector<int64_t, 6>, 2> extents;
  extents.emplace_back(shape1.begin(), shape1.end());
  extents.emplace_back(shape2.begin(), shape2.end());
  return staticallyKnownBroadcastable(extents);
}
bool OpTrait::util::staticallyKnownBroadcastable(
    ArrayRef<SmallVector<int64_t, 6>> shapes) {
  assert(!shapes.empty() && "Expected at least one shape");

  size_t maxRank = shapes[0].size();
  for (size_t i = 1; i != shapes.size(); ++i)
    maxRank = std::max(maxRank, shapes[i].size());

  // Walk each dimension column, trailing dimension first; shorter shapes are
  // implicitly padded with leading 1s.
  for (size_t i = 0; i != maxRank; ++i) {
    bool seenDynamic = false;
    std::optional<int64_t> nonOneDim;
    for (ArrayRef<int64_t> extent : shapes) {
      int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
      if (dim == 1)
        continue;

      // A column is broadcastable when either exactly one dimension is
      // dynamic and the rest are 1, or all non-1 dimensions are equal.
      if (ShapedType::isDynamic(dim)) {
        if (seenDynamic || nonOneDim)
          return false;
        seenDynamic = true;
      }
      if (nonOneDim && dim != *nonOneDim)
        return false;
      nonOneDim = dim;
    }
  }
  return true;
}
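// Example behavior (illustrative, not part of the original file). Because the
// query must *guarantee* success, a dynamic dimension paired with any static
// dimension other than 1 is rejected:
//
//   staticallyKnownBroadcastable({1, 2}, {4, 2});               // true
//   staticallyKnownBroadcastable({8, 1, 6, 1}, {7, 1, 5});      // true
//   staticallyKnownBroadcastable({3}, {4});                     // false
//   staticallyKnownBroadcastable({ShapedType::kDynamic}, {4});  // false:
//   // the dynamic extent may turn out to differ from 4 at runtime.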
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // Seed the result with the longer shape; the overlapping trailing
  // dimensions are fixed up below.
  resultShape.clear();
  if (shape1.size() > shape2.size()) {
    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
  } else {
    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
  }

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = resultShape.rbegin();

  // Check that each pair of trailing dimensions is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
      // One or both dimensions is unknown: a static dimension greater than 1
      // wins, a static 1 yields the other dimension, and two unknowns stay
      // dynamic.
      if (*i1 > 1) {
        *iR = *i1;
      } else if (*i2 > 1) {
        *iR = *i2;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else if (*i2 == 1) {
        *iR = *i1;
      } else {
        *iR = ShapedType::kDynamic;
      }
    } else {
      if (*i1 == *i2 || *i2 == 1) {
        *iR = *i1;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else {
        // This pair of static dimensions is incompatible.
        resultShape.clear();
        return false;
      }
    }
  }
  return true;
}
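// A minimal usage sketch (hypothetical helper, not part of the original
// file); the expected values follow from the rules above:
static void exampleGetBroadcastedShape() {
  SmallVector<int64_t, 4> shape;
  // Trailing dimensions align: {5, 1, 4} vs {3, 4} -> {5, 3, 4}.
  bool ok = OpTrait::util::getBroadcastedShape({5, 1, 4}, {3, 4}, shape);
  assert(ok && shape[1] == 3);
  // A static 1 broadcasts even next to a leading dynamic dimension:
  // {?, 1} vs {3} -> {?, 3}.
  ok = OpTrait::util::getBroadcastedShape({ShapedType::kDynamic, 1}, {3},
                                          shape);
  assert(ok && ShapedType::isDynamic(shape[0]) && shape[1] == 3);
  // Mismatched static dimensions fail and clear the result.
  ok = OpTrait::util::getBroadcastedShape({2}, {3}, shape);
  assert(!ok && shape.empty());
  (void)ok;
}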
/// Returns the shape of the given type. Scalars are treated as having a shape
/// with zero dimensions.
static ArrayRef<int64_t> getShape(Type type) {
  if (auto sType = dyn_cast<ShapedType>(type))
    return sType.getShape();
  return {};
}
Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
                                       Type elementType) {
  // If elementType is not specified, use the common element type of the
  // inputs, or fail if they disagree.
  if (!elementType) {
    elementType = getElementTypeOrSelf(type1);
    if (elementType != getElementTypeOrSelf(type2))
      return {};
  }

  // If one of the types is an unranked tensor, the other must not be a
  // vector, and the result is an unranked tensor.
  if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) {
    if (isa<VectorType>(type1) || isa<VectorType>(type2))
      return {};
    return UnrankedTensorType::get(elementType);
  }

  // Returns the type kind if the given type is a vector or ranked tensor
  // type, std::nullopt otherwise.
  auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> {
    if (isa<VectorType, RankedTensorType>(type))
      return type.getTypeID();
    return std::nullopt;
  };

  // Make sure the composite kinds, where present, are consistent.
  std::optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
  std::optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
  std::optional<TypeID> resultCompositeKind;

  if (compositeKind1 && compositeKind2) {
    // Disallow mixing vector and tensor.
    if (compositeKind1 != compositeKind2)
      return {};
    resultCompositeKind = compositeKind1;
  } else if (compositeKind1) {
    resultCompositeKind = compositeKind1;
  } else if (compositeKind2) {
    resultCompositeKind = compositeKind2;
  }

  // Broadcast the shapes and compose the final type.
  SmallVector<int64_t, 4> resultShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return {};
  if (resultCompositeKind == VectorType::getTypeID())
    return VectorType::get(resultShape, elementType);
  if (resultCompositeKind == RankedTensorType::getTypeID())
    return RankedTensorType::get(resultShape, elementType);
  return elementType;
}
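// Illustrative results (assuming an MLIRContext and a Float32Type f32 are in
// scope; not part of the original file):
//
//   auto t1 = RankedTensorType::get({1, 4}, f32);
//   auto t2 = RankedTensorType::get({3, 1}, f32);
//   getBroadcastedType(t1, t2);    // tensor<3x4xf32>
//
//   auto v = VectorType::get({4}, f32);
//   getBroadcastedType(t1, v);     // null Type: tensor and vector don't mix
//
//   auto u = UnrankedTensorType::get(f32);
//   getBroadcastedType(t1, u);     // tensor<*xf32>
//   getBroadcastedType(f32, f32);  // f32: two scalars stay scalar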
/// Returns a tuple corresponding to whether the range has tensor or vector
/// type.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
  return {llvm::any_of(types, llvm::IsaPred<TensorType>),
          llvm::any_of(types, llvm::IsaPred<VectorType>)};
}
static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
                                            ArrayRef<int64_t> existing) {
  // Dimensions are compatible if either one is dynamic; if both are static,
  // they must be equal.
  auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
    return ShapedType::isDynamic(existingDim) ||
           ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;
  };
  if (inferred.size() != existing.size())
    return false;
  for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
    if (!isCompatible(inferredDim, existingDim))
      return false;
  return true;
}
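// For example (illustrative): inferred {2, ?} is compatible with existing
// {2, 3}, since a dynamic dimension matches anything; {2, 4} vs {2, 3} is
// not; and any rank mismatch is immediately incompatible.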
static std::string getShapeString(ArrayRef<int64_t> shape) {
  std::string ret;
  llvm::raw_string_ostream ss(ret);
  ss << '\'';
  llvm::interleave(
      shape, ss,
      [&](int64_t dim) {
        if (ShapedType::isDynamic(dim))
          ss << '?';
        else
          ss << dim;
      },
      "x");
  ss << '\'';
  return ret;
}
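// For instance, getShapeString({2, ShapedType::kDynamic, 4}) yields the
// string '2x?x4' (quotes included), matching how shaped types print their
// dimension lists.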
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  // Ensure broadcasting involves only tensor types or only vector types.
  auto operandsHasTensorVectorType =
      hasTensorOrVectorType(op->getOperandTypes());
  auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
  if ((std::get<0>(operandsHasTensorVectorType) ||
       std::get<0>(resultsHasTensorVectorType)) &&
      (std::get<1>(operandsHasTensorVectorType) ||
       std::get<1>(resultsHasTensorVectorType)))
    return op->emitError("cannot broadcast vector with tensor");

  auto rankedOperands =
      make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);

  // If all operands are unranked, then all result shapes are possible.
  if (rankedOperands.empty())
    return success();

  // Compute the broadcasted shape of the operands (which requires that the
  // operands are broadcast compatible); the results must be broadcast
  // compatible with it.
  SmallVector<int64_t, 4> resultShape;
  (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
                                  resultShape);
  for (auto other : make_early_inc_range(rankedOperands)) {
    SmallVector<int64_t, 4> temp = resultShape;
    if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
      return op->emitOpError("operands don't have broadcast-compatible shapes");
  }

  auto rankedResults =
      make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);

  // If all results are unranked, no further verification is needed.
  if (rankedResults.empty())
    return success();

  for (auto type : rankedResults) {
    ArrayRef<int64_t> actualSuffix =
        getShape(type).take_back(resultShape.size());
    if (!isCompatibleInferredReturnShape(resultShape, actualSuffix))
      return op->emitOpError()
             << "result type " << getShapeString(getShape(type))
             << " not broadcast compatible with broadcasted operands's shapes "
             << getShapeString(resultShape);
  }
  return success();
}
Referenced declarations (briefs from the corresponding MLIR headers):

static std::string getShapeString(ArrayRef<int64_t> shape)
    Returns a quoted string form of the given shape, e.g. '2x?x4'.

static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred, ArrayRef<int64_t> existing)
    Returns true if the inferred shape is compatible with the existing shape of a result.

static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types)
    Returns a tuple corresponding to whether the range has tensor or vector type.

static ArrayRef<int64_t> getShape(Type type)
    Returns the shape of the given type.

class Operation
    Operation is the basic unit of execution within MLIR.

InFlightDiagnostic Operation::emitError(const Twine &message = {})
    Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.

operand_type_range Operation::getOperandTypes()
result_type_range Operation::getResultTypes()

InFlightDiagnostic Operation::emitOpError(const Twine &message = {})
    Emit an error with the op name prefixed, like "'dim' op ", which is convenient for verifiers.

class Type
    Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

LogicalResult verifyCompatibleOperandBroadcast(Operation *op)
    Verifies that the operands and results of an op are broadcast compatible; this is the hook behind OpTrait::ResultsBroadcastableShape.

bool staticallyKnownBroadcastable(ArrayRef<SmallVector<int64_t, 6>> shapes)
    Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an error.

Type getBroadcastedType(Type type1, Type type2, Type elementType = nullptr)
    Returns the result broadcast composition type from the two given types by following NumPy broadcast semantics.

bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2, SmallVectorImpl<int64_t> &resultShape)
    Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broadcast compatible.

Type getElementTypeOrSelf(Type type)
    Return the element type or return the type itself.

auto get(MLIRContext *context, Ts &&...params)
    Helper method that injects context only if needed; this helps unify some of the attribute construction methods.