17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/InterleavedRange.h"
23 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
29 auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
32 LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
38 int64_t resultIdx = 0;
40 auto shapedType = dyn_cast<ShapedType>(result.getType());
43 if (!shapedType.hasRank()) {
49 assert(shapedType.getRank() ==
50 static_cast<int64_t
>(reifiedReturnShapes[resultIdx].size()) &&
51 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
55 assert(resultIdx ==
static_cast<int64_t
>(reifiedReturnShapes.size()) &&
56 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
64 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
65 return cast<ShapedType>(t).hasRank();
66 if (isa<Attribute>(val))
68 return cast<ShapedTypeComponents *>(val)->hasRank();
74 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
75 return cast<ShapedType>(t).getElementType();
76 if (isa<Attribute>(val))
78 return cast<ShapedTypeComponents *>(val)->getElementType();
83 if (
auto t = llvm::dyn_cast_if_present<Type>(val)) {
85 res.assign(vals.begin(), vals.end());
86 }
else if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
87 auto dattr = cast<DenseIntElementsAttr>(attr);
89 res.reserve(dattr.size());
90 for (
auto it : dattr.getValues<APInt>())
91 res.push_back(it.getSExtValue());
93 auto vals = cast<ShapedTypeComponents *>(val)->getDims();
94 res.assign(vals.begin(), vals.end());
106 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
107 return cast<ShapedType>(t).getDimSize(index);
108 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val))
109 return cast<DenseIntElementsAttr>(attr)
110 .getValues<APInt>()[index]
112 auto *stc = cast<ShapedTypeComponents *>(val);
113 return stc->getDims()[index];
118 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
119 return cast<ShapedType>(t).getRank();
120 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val))
121 return cast<DenseIntElementsAttr>(attr).size();
122 return cast<ShapedTypeComponents *>(val)->getDims().size();
129 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
130 return cast<ShapedType>(t).hasStaticShape();
131 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
132 auto dattr = cast<DenseIntElementsAttr>(attr);
133 for (
auto index : dattr.getValues<APInt>())
134 if (ShapedType::isDynamic(index.getSExtValue()))
138 auto *stc = cast<ShapedTypeComponents *>(val);
139 return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
143 assert(
hasStaticShape() &&
"cannot get element count of dynamic shaped type");
145 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
146 return cast<ShapedType>(t).getNumElements();
148 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
149 auto dattr = cast<DenseIntElementsAttr>(attr);
151 for (
auto index : dattr.getValues<APInt>()) {
152 num *= index.getZExtValue();
153 assert(num >= 0 &&
"integer overflow in element count computation");
158 auto *stc = cast<ShapedTypeComponents *>(val);
160 for (int64_t dim : stc->getDims()) {
162 assert(num >= 0 &&
"integer overflow in element count computation");
169 llvm::errs() <<
"<<unranked>>\n";
175 auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
176 if (ShapedType::isDynamic(dim))
178 return llvm::formatv(
"{0}", dim).str();
180 llvm::errs() <<
"rank = " <<
getRank()
181 <<
" dims = " << llvm::interleaved_array(mapped,
"x") <<
"\n";
185 Value val = operator[](index);
193 if (attr.getType().getRank() != 1)
206 if (index < 0 ||
static_cast<size_t>(index) >= size())
214 for (
const auto &shapeAndType : retComponents) {
215 Type elementTy = shapeAndType.getElementType();
216 assert(elementTy &&
"element type required to construct tensor");
218 Attribute attr = shapeAndType.getAttribute();
219 if (shapeAndType.hasRank()) {
220 inferredReturnTypes.push_back(
223 assert(attr ==
nullptr &&
"attribute not supported");
232 auto retTypeFn = cast<InferTypeOpInterface>(op);
233 auto result = retTypeFn.refineReturnTypes(
236 inferredReturnTypes);
238 op->
emitOpError() <<
"failed to infer returned types";
245 llvm::raw_string_ostream os(buffer);
246 os <<
"Failed to infer result type(s):\n"
247 <<
"\"" << state.name <<
"\"(...) "
248 << state.attributes.getDictionary(state.location.getContext()) <<
" : ("
249 << llvm::interleaved(llvm::map_range(
250 state.operands, [](
Value val) { return val.getType(); }))
253 llvm::report_fatal_error(llvm::StringRef(buffer));
Attributes are known-constant values of operations.
An attribute that represents a reference to a dense integer vector or tensor object.
This class helps build Operations.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
void dump() const
Dumps textual representation to stderr.
Type getElementType() const
Returns the element type.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult inferReturnTensorTypes(ArrayRef< ShapedTypeComponents > retComponents, SmallVectorImpl< Type > &inferredReturnTypes)
void reportFatalInferReturnTypesError(OperationState &state)
Report a fatal error indicating that the result types could not be inferred.
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
InFlightDiagnostic emitRemark(Location loc)
Utility method to emit a remark message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.