14 #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
15 #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
23 #include "llvm/ADT/PointerUnion.h"
24 #include "llvm/ADT/SmallVector.h"
28 class ShapedTypeComponents;
36 if (
auto st = t.
dyn_cast<ShapedType>())
67 return ShapedType::isDynamic(
getDimSize(index));
82 explicit operator bool()
const {
// Explicit (never implicit) conversion to bool: true iff the wrapped `val`
// is non-null, i.e. the adaptor currently refers to some shape source.
// NOTE(review): `val` appears to be an llvm::PointerUnion (PointerUnion.h is
// included and `isNull()` is called) — confirm against the class definition.
return !val.isNull(); }
109 : elementType(elementType), attr(nullptr), ranked(false) {}
111 ranked = shapedType.hasRank();
112 elementType = shapedType.getElementType();
114 dims = llvm::to_vector<4>(shapedType.getShape());
122 template <
typename Arg,
typename = std::enable_if_t<
123 std::is_constructible<ShapeStorageT, Arg>::value>>
126 : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
130 : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
136 assert(ranked &&
"requires ranked shape");
170 : RangeBaseT(values), operandShape(operandShape),
171 valueToShape(valueToShape) {}
239 LogicalResult(MLIRContext *, std::optional<Location> location,
240 ValueShapeRange operands, DictionaryAttr attributes,
242 SmallVectorImpl<ShapedTypeComponents> &retComponents)>
244 MLIRContext *context, std::optional<Location> location, ValueRange operands,
245 DictionaryAttr attributes, RegionRange regions,
246 SmallVectorImpl<Type> &inferredReturnTypes);
254 template <
typename ConcreteType>
255 class InferTensorType;
260 #include "mlir/Interfaces/InferTypeOpInterface.h.inc"
272 template <
typename ConcreteType>
277 ValueRange operands, DictionaryAttr attributes,
281 ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
282 "requires InferShapedTypeOpInterface to ensure succesful invocation");
284 ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
285 "requires InferTypeOpInterface to ensure succesful invocation");
287 ConcreteType::inferReturnTypeComponents, context, location, operands,
288 attributes, regions, inferredReturnTypes);
Attributes are known-constant values of operations.
An attribute that represents a reference to a dense integer vector or tensor object.
MLIRContext is the top-level object for a collection of MLIR operations.
Tensor type inference trait that constructs a tensor from the inferred shape and elemental types.
static LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< Type > &inferredReturnTypes)
Helper class for implementing traits.
This class provides an abstraction over the different types of ranges over Regions.
Adaptor class to abstract the differences between whether a value is from a ShapedType, a ShapedTypeComponents, or an Attribute.
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
void dump() const
Dumps a textual representation to stderr.
Type getElementType() const
Returns the element type.
ShapeAdaptor(Attribute t)
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
ShapeAdaptor(ShapedTypeComponents &components)
ShapeAdaptor(ShapedTypeComponents *components)
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
ShapedTypeComponents()
Default construction is an unranked shape.
ShapedTypeComponents(Arg &&arg, Type elementType=nullptr, Attribute attr=nullptr)
ShapedTypeComponents(ShapedType shapedType)
ShapedTypeComponents(Type elementType)
ShapedTypeComponents(ArrayRef< int64_t > vec, Type elementType=nullptr, Attribute attr=nullptr)
ShapedTypeComponents(ShapeAdaptor adaptor)
bool hasRank() const
Return whether the shape has a rank.
Type getElementType() const
Return the element type component.
Attribute getAttribute() const
Return the raw attribute component.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ValueShapeMapFn getOperandShapeMapping() const
ValueShapeMapFn getValueToShapeMapping() const
Returns the set Value to ShapeAdaptor mapping function.
ValueShapeRange & setValueToShapeMapping(ValueShapeMapFn fn)
Sets the Value to ShapeAdaptor mapping function and returns this.
ValueShapeRange(const ValueShapeRange &)=default
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
type_range getTypes() const
ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape=nullptr, ValueShapeMapFn valueToShape=nullptr)
ValueShapeRange(const std::initializer_list< Value > &values)
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
function_ref< ShapeAdaptor(Value)> ValueShapeMapFn
ValueShapeRange & setOperandShapeMapping(ValueShapeMapFn fn)
ValueRange getValues() const
Returns the Values in the ValueRange.
This class implements iteration on the types of a given range of values.
This class implements iteration on the types of a given range of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult inferReturnTensorTypes(function_ref< LogicalResult(MLIRContext *, std::optional< Location > location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< ShapedTypeComponents > &retComponents)> componentTypeFn, MLIRContext *context, std::optional< Location > location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< Type > &inferredReturnTypes)
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op.
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
This class represents an efficient way to signal success or failure.