MLIR  15.0.0git
mlir::OpTrait::util Namespace Reference

Functions

bool getBroadcastedShape (ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
 Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broadcast compatible.
 
bool staticallyKnownBroadcastable (ArrayRef< SmallVector< int64_t, 6 >> shapes)
 Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an error.
 
bool staticallyKnownBroadcastable (ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
 
Type getBroadcastedType (Type type1, Type type2, Type elementType=nullptr)
 Returns the result broadcast composition type from the two given types by following NumPy broadcast semantics.
 

Function Documentation

◆ getBroadcastedShape()

bool mlir::OpTrait::util::getBroadcastedShape ( ArrayRef< int64_t >  shape1,
ArrayRef< int64_t >  shape2,
SmallVectorImpl< int64_t > &  resultShape 
)

Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broadcast compatible.

Returns false and clears resultShape otherwise.

The rules for determining the result shape are:

Zip together the dimensions of the two given shapes, prepending 1s to the shape with fewer dimensions. For each dimension pair, deduce the result dimension by applying the following rules in order:

  • If either dimension is unknown, follow the TensorFlow behavior:
    • If either dimension is greater than 1, assume that the program is correct and that the other dimension will be broadcast to match it.
    • If either dimension is 1, the other dimension is the result.
    • Otherwise, the result dimension is unknown.
  • If one of the dimensions is 1, the other dimension is the result.
  • If the two dimensions are the same, that dimension is the result.
  • Otherwise, the shapes are incompatible.
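
The rules above can be modelled outside MLIR as a rough, standalone sketch. It is not the code in Traits.cpp. It uses std::vector in place of ArrayRef and SmallVectorImpl, the name broadcastShapeSketch is hypothetical, and it assumes that -1 marks an unknown dimension.

```cpp
#include <cstdint>
#include <vector>

// Assumption for this sketch: -1 marks an unknown (dynamic) dimension.
constexpr int64_t kUnknown = -1;

// Hypothetical standalone model of the documented rules. Returns false and
// clears `result` when the two shapes are not broadcast compatible.
bool broadcastShapeSketch(const std::vector<int64_t> &shape1,
                          const std::vector<int64_t> &shape2,
                          std::vector<int64_t> &result) {
  // Start from the longer shape. Its extra leading dimensions pass through
  // unchanged, which is equivalent to prepending 1s to the shorter shape.
  result = shape1.size() > shape2.size() ? shape1 : shape2;

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto out = result.rbegin();

  // Walk the dimension pairs from the trailing dimension backwards.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++out) {
    if (*i1 == kUnknown || *i2 == kUnknown) {
      // At least one dimension is unknown.
      if (*i1 > 1)
        *out = *i1;      // known dimension > 1 wins
      else if (*i2 > 1)
        *out = *i2;      // known dimension > 1 wins
      else if (*i1 == 1)
        *out = *i2;      // 1 broadcasts to the other (unknown) dimension
      else if (*i2 == 1)
        *out = *i1;      // 1 broadcasts to the other (unknown) dimension
      else
        *out = kUnknown; // both unknown
    } else if (*i1 == *i2 || *i2 == 1) {
      *out = *i1;        // equal, or the second dimension is 1
    } else if (*i1 == 1) {
      *out = *i2;        // the first dimension is 1
    } else {
      result.clear();    // incompatible pair
      return false;
    }
  }
  return true;
}
```

Under these assumptions, shapes (3, 1) and (1, 4) broadcast to (3, 4), an unknown dimension paired with 7 yields 7, and (2, 3) against (4, 3) fails and leaves the result empty.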

Definition at line 59 of file Traits.cpp.

References copy().

Referenced by eachHasOnlyOneOfTypes(), getBroadcastedType(), and mlir::OpTrait::impl::verifyCompatibleOperandBroadcast().

◆ getBroadcastedType()

Type mlir::OpTrait::util::getBroadcastedType ( Type  type1,
Type  type2,
Type  elementType = nullptr 
)

Returns the result broadcast composition type from the two given types by following NumPy broadcast semantics.

The returned type may have a dynamic shape if either input type has a dynamic shape. Returns a null type if the two given types are not broadcast-compatible.

elementType, if specified, is used as the element type of the broadcasted result type. Otherwise, type1 and type2 must have the same element type, which becomes the element type of the result.
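
The element-type handling described above can be illustrated with a toy model. This sketch does not use mlir::Type. ToyType, toyBroadcastShape and toyBroadcastedType are hypothetical names, an empty string stands in for the nullptr default of elementType, std::nullopt stands in for the null type, and -1 is assumed to mark a dynamic dimension. Returning the null stand-in when the element types differ is also an assumption of the sketch.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Assumption for this sketch: -1 marks a dynamic dimension.
constexpr int64_t kDynamicDim = -1;

// Hypothetical stand-in for a ranked shaped type.
struct ToyType {
  std::string elementType;
  std::vector<int64_t> shape;
};

// Broadcasts two shapes dimension by dimension; std::nullopt on mismatch.
std::optional<std::vector<int64_t>> toyBroadcastShape(std::vector<int64_t> a,
                                                      std::vector<int64_t> b) {
  if (a.size() < b.size())
    std::swap(a, b); // the per-dimension rules are symmetric
  const std::size_t offset = a.size() - b.size();
  for (std::size_t i = 0; i < b.size(); ++i) {
    const int64_t d1 = a[offset + i], d2 = b[i];
    int64_t r;
    if (d1 == kDynamicDim || d2 == kDynamicDim) {
      if (d1 > 1) r = d1;
      else if (d2 > 1) r = d2;
      else if (d1 == 1) r = d2;
      else if (d2 == 1) r = d1;
      else r = kDynamicDim;
    } else if (d1 == d2 || d2 == 1) {
      r = d1;
    } else if (d1 == 1) {
      r = d2;
    } else {
      return std::nullopt;
    }
    a[offset + i] = r;
  }
  return a;
}

// An empty `elementType` plays the role of the nullptr default.
std::optional<ToyType> toyBroadcastedType(const ToyType &t1, const ToyType &t2,
                                          const std::string &elementType = "") {
  std::string elt = elementType;
  if (elt.empty()) {
    // No override given: both inputs must agree on the element type.
    if (t1.elementType != t2.elementType)
      return std::nullopt;
    elt = t1.elementType;
  }
  auto shape = toyBroadcastShape(t1.shape, t2.shape);
  if (!shape)
    return std::nullopt;
  return ToyType{elt, *shape};
}
```

In this model, two f32 inputs with shapes (3, 1) and (4) give an f32 result of shape (3, 4), while mismatched element types only produce a result when an explicit element type is passed.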

Definition at line 132 of file Traits.cpp.

References getBroadcastedShape(), mlir::getElementTypeOrSelf(), getShape(), and mlir::Type::isa().

◆ staticallyKnownBroadcastable() [1/2]

bool mlir::OpTrait::util::staticallyKnownBroadcastable ( ArrayRef< SmallVector< int64_t, 6 >>  shapes)

Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an error.

A false result does not guarantee that the shapes are not broadcastable: it can mean either that they are known to be incompatible or that this function does not have enough information to decide.

Conceptually, this returns true if getBroadcastedShape would have returned true and vice versa, with one exception. If a dimension is unknown in both shapes, getBroadcastedShape returns true and produces a result with an unknown dimension, while this function returns false, because at run time both dimensions could be greater than 1 and different from each other, which would fail to broadcast.
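
A conservative check of this kind can be sketched as a standalone model that scans each dimension column across all shapes. The names knownBroadcastableSketch and knownBroadcastablePair are hypothetical, std::vector replaces ArrayRef and SmallVector, and -1 is assumed to mark an unknown dimension. As a further assumption, the sketch accepts an unknown dimension only when every other dimension in its column is 1, since only that pairing is guaranteed to succeed.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Assumption for this sketch: -1 marks an unknown (dynamic) dimension.
constexpr int64_t kUnknown = -1;

// Hypothetical model: true only if broadcasting all shapes cannot fail.
// An empty list of shapes is treated as trivially broadcastable here.
bool knownBroadcastableSketch(
    const std::vector<std::vector<int64_t>> &shapes) {
  std::size_t maxRank = 0;
  for (const auto &shape : shapes)
    maxRank = std::max(maxRank, shape.size());

  // Column i holds the i-th dimension counted from the back of each shape.
  for (std::size_t i = 0; i < maxRank; ++i) {
    bool seenUnknown = false;
    std::optional<int64_t> nonOneDim;
    for (const auto &shape : shapes) {
      // Shapes shorter than the column index behave as if padded with 1.
      const int64_t dim = i >= shape.size() ? 1 : shape[shape.size() - i - 1];
      if (dim == 1)
        continue; // 1 is compatible with anything

      if (dim == kUnknown) {
        // An unknown dimension is only safe if everything else is 1.
        if (seenUnknown || nonOneDim)
          return false;
        seenUnknown = true;
      }

      // All non-1 dimensions in a column must be the same value.
      if (nonOneDim && dim != *nonOneDim)
        return false;
      nonOneDim = dim;
    }
  }
  return true;
}

// Two-shape convenience wrapper that forwards to the n-shape check.
bool knownBroadcastablePair(const std::vector<int64_t> &shape1,
                            const std::vector<int64_t> &shape2) {
  return knownBroadcastableSketch(
      std::vector<std::vector<int64_t>>{shape1, shape2});
}
```

With this model, (3, 1) against (1, 4) and an unknown dimension against 1 are accepted, while two unknown dimensions in the same column are rejected even though a shape-level broadcast would report success with an unknown result dimension.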

Definition at line 24 of file Traits.cpp.

References max().

Referenced by hasAtMostSingleNonScalar(), and staticallyKnownBroadcastable().

◆ staticallyKnownBroadcastable() [2/2]

bool mlir::OpTrait::util::staticallyKnownBroadcastable ( ArrayRef< int64_t >  shape1,
ArrayRef< int64_t >  shape2 
)

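Convenience overload for exactly two shapes. As the reference list below indicates, it is implemented in terms of the n-shape overload.

Definition at line 16 of file Traits.cpp.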

References staticallyKnownBroadcastable().