17#include "llvm/Support/FormatVariadic.h"
18#include "llvm/Support/InterleavedRange.h"
23#include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
29 auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
32 LogicalResult status = reifiableOp.reifyResultShapes(
b, reifiedReturnShapes);
40 auto shapedType = dyn_cast<ShapedType>(
result.getType());
43 if (!shapedType.hasRank()) {
49 assert(shapedType.getRank() ==
50 static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
51 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
55 assert(resultIdx ==
static_cast<int64_t>(reifiedReturnShapes.size()) &&
56 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
61FailureOr<SmallVector<OpFoldResult>>
63 auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
66 return reifiableOp.reifyShapeOfResult(
b, resultIndex);
70 int resultIndex,
int dim) {
71 auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
74 return reifiableOp.reifyDimOfResult(
b, resultIndex, dim);
80 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
81 return cast<ShapedType>(t).hasRank();
82 if (isa<Attribute>(val))
84 return cast<ShapedTypeComponents *>(val)->hasRank();
90 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
91 return cast<ShapedType>(t).getElementType();
92 if (isa<Attribute>(val))
94 return cast<ShapedTypeComponents *>(val)->getElementType();
99 if (
auto t = llvm::dyn_cast_if_present<Type>(val)) {
101 res.assign(vals.begin(), vals.end());
102 }
else if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
103 auto dattr = cast<DenseIntElementsAttr>(attr);
105 res.reserve(dattr.size());
106 for (
auto it : dattr.getValues<APInt>())
107 res.push_back(it.getSExtValue());
109 auto vals = cast<ShapedTypeComponents *>(val)->getDims();
110 res.assign(vals.begin(), vals.end());
122 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
123 return cast<ShapedType>(t).getDimSize(
index);
124 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val))
125 return cast<DenseIntElementsAttr>(attr)
126 .getValues<APInt>()[
index]
128 auto *stc = cast<ShapedTypeComponents *>(val);
129 return stc->getDims()[
index];
134 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
135 return cast<ShapedType>(t).getRank();
136 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val))
137 return cast<DenseIntElementsAttr>(attr).size();
138 return cast<ShapedTypeComponents *>(val)->getDims().size();
145 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
146 return cast<ShapedType>(t).hasStaticShape();
147 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
148 auto dattr = cast<DenseIntElementsAttr>(attr);
149 for (
auto index : dattr.getValues<APInt>())
150 if (ShapedType::isDynamic(
index.getSExtValue()))
154 auto *stc = cast<ShapedTypeComponents *>(val);
155 return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
159 assert(
hasStaticShape() &&
"cannot get element count of dynamic shaped type");
161 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
162 return cast<ShapedType>(t).getNumElements();
164 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
165 auto dattr = cast<DenseIntElementsAttr>(attr);
167 for (
auto index : dattr.getValues<APInt>()) {
168 num *=
index.getZExtValue();
169 assert(num >= 0 &&
"integer overflow in element count computation");
174 auto *stc = cast<ShapedTypeComponents *>(val);
176 for (
int64_t dim : stc->getDims()) {
178 assert(num >= 0 &&
"integer overflow in element count computation");
185 llvm::errs() <<
"<<unranked>>\n";
191 auto mapped = llvm::map_range(dims, [](
int64_t dim) -> std::string {
192 if (ShapedType::isDynamic(dim))
194 return llvm::formatv(
"{0}", dim).str();
196 llvm::errs() <<
"rank = " <<
getRank()
197 <<
" dims = " << llvm::interleaved_array(mapped,
"x") <<
"\n";
209 if (attr.getType().getRank() != 1)
222 if (
index < 0 ||
static_cast<size_t>(
index) >= size())
230 for (
const auto &shapeAndType : retComponents) {
231 Type elementTy = shapeAndType.getElementType();
232 assert(elementTy &&
"element type required to construct tensor");
234 Attribute attr = shapeAndType.getAttribute();
235 if (shapeAndType.hasRank()) {
236 inferredReturnTypes.push_back(
237 RankedTensorType::get(shapeAndType.getDims(), elementTy, attr));
239 assert(attr ==
nullptr &&
"attribute not supported");
240 inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
248 auto retTypeFn = cast<InferTypeOpInterface>(op);
249 auto result = retTypeFn.refineReturnTypes(
252 inferredReturnTypes);
254 op->
emitOpError() <<
"failed to infer returned types";
261 llvm::raw_string_ostream os(buffer);
262 os <<
"Failed to infer result type(s):\n"
263 <<
"\"" << state.
name <<
"\"(...) "
265 << llvm::interleaved(llvm::map_range(
269 llvm::report_fatal_error(llvm::StringRef(buffer));
Attributes are known-constant values of operations.
An attribute that represents a reference to a dense integer vector or tensor object.
MLIRContext * getContext() const
Return the context this location is uniqued in.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
This class helps build Operations.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
void dump() const
Dumps textual representation to stderr.
Type getElementType() const
Returns the element type.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult inferReturnTensorTypes(ArrayRef< ShapedTypeComponents > retComponents, SmallVectorImpl< Type > &inferredReturnTypes)
void reportFatalInferReturnTypesError(OperationState &state)
Report a fatal error indicating that the result types could not be inferred.
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
FailureOr< OpFoldResult > reifyDimOfResult(OpBuilder &b, Operation *op, int resultIndex, int dim)
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
InFlightDiagnostic emitRemark(Location loc)
Utility method to emit a remark message using this location.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
FailureOr< SmallVector< OpFoldResult > > reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands