17 #include "llvm/Support/FormatVariadic.h"
22 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
28 auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
31 LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
37 int64_t resultIdx = 0;
39 auto shapedType = dyn_cast<ShapedType>(result.getType());
42 if (!shapedType.hasRank()) {
48 assert(shapedType.getRank() ==
49 static_cast<int64_t
>(reifiedReturnShapes[resultIdx].size()) &&
50 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
54 assert(resultIdx ==
static_cast<int64_t
>(reifiedReturnShapes.size()) &&
55 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
63 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
64 return cast<ShapedType>(t).hasRank();
65 if (isa<Attribute>(val))
67 return cast<ShapedTypeComponents *>(val)->hasRank();
73 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
74 return cast<ShapedType>(t).getElementType();
75 if (isa<Attribute>(val))
77 return cast<ShapedTypeComponents *>(val)->getElementType();
82 if (
auto t = llvm::dyn_cast_if_present<Type>(val)) {
84 res.assign(vals.begin(), vals.end());
85 }
else if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
86 auto dattr = cast<DenseIntElementsAttr>(attr);
88 res.reserve(dattr.size());
89 for (
auto it : dattr.getValues<APInt>())
90 res.push_back(it.getSExtValue());
92 auto vals = cast<ShapedTypeComponents *>(val)->getDims();
93 res.assign(vals.begin(), vals.end());
105 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
106 return cast<ShapedType>(t).getDimSize(index);
107 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val))
108 return cast<DenseIntElementsAttr>(attr)
109 .getValues<APInt>()[index]
111 auto *stc = cast<ShapedTypeComponents *>(val);
112 return stc->getDims()[index];
117 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
118 return cast<ShapedType>(t).getRank();
119 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val))
120 return cast<DenseIntElementsAttr>(attr).size();
121 return cast<ShapedTypeComponents *>(val)->getDims().size();
128 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
129 return cast<ShapedType>(t).hasStaticShape();
130 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
131 auto dattr = cast<DenseIntElementsAttr>(attr);
132 for (
auto index : dattr.getValues<APInt>())
133 if (ShapedType::isDynamic(index.getSExtValue()))
137 auto *stc = cast<ShapedTypeComponents *>(val);
138 return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
142 assert(
hasStaticShape() &&
"cannot get element count of dynamic shaped type");
144 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
145 return cast<ShapedType>(t).getNumElements();
147 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
148 auto dattr = cast<DenseIntElementsAttr>(attr);
150 for (
auto index : dattr.getValues<APInt>()) {
151 num *= index.getZExtValue();
152 assert(num >= 0 &&
"integer overflow in element count computation");
157 auto *stc = cast<ShapedTypeComponents *>(val);
159 for (int64_t dim : stc->getDims()) {
161 assert(num >= 0 &&
"integer overflow in element count computation");
168 llvm::errs() <<
"<<unranked>>\n";
174 auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
175 if (ShapedType::isDynamic(dim))
177 return llvm::formatv(
"{0}", dim).str();
179 llvm::errs() <<
"rank = " <<
getRank() <<
" dims = [";
180 llvm::interleave(mapped, llvm::errs(),
"x");
181 llvm::errs() <<
"]\n";
185 Value val = operator[](index);
193 if (attr.getType().getRank() != 1)
206 if (index < 0 ||
static_cast<size_t>(index) >= size())
214 for (
const auto &shapeAndType : retComponents) {
215 Type elementTy = shapeAndType.getElementType();
216 assert(elementTy &&
"element type required to construct tensor");
218 Attribute attr = shapeAndType.getAttribute();
219 if (shapeAndType.hasRank()) {
220 inferredReturnTypes.push_back(
223 assert(attr ==
nullptr &&
"attribute not supported");
232 auto retTypeFn = cast<InferTypeOpInterface>(op);
233 auto result = retTypeFn.refineReturnTypes(
236 inferredReturnTypes);
238 op->
emitOpError() <<
"failed to infer returned types";
245 llvm::raw_string_ostream os(buffer);
246 os <<
"Failed to infer result type(s):\n";
247 os <<
"\"" << state.name <<
"\"(...) ";
248 os << state.attributes.getDictionary(state.location.getContext());
250 llvm::interleaveComma(state.operands, os,
251 [&](
Value val) { os << val.getType(); });
252 os <<
") -> ( ??? )";
254 llvm::report_fatal_error(llvm::StringRef(buffer));
Attributes are known-constant values of operations.
An attribute that represents a reference to a dense integer vector or tensor object.
This class helps build Operations.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Adaptor class to abstract the differences between whether the value is from a ShapedType, a ShapedTypeComponents, or a DenseIntElementsAttr.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
void dump() const
Dumps a textual representation to stderr.
Type getElementType() const
Returns the element type.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
LogicalResult inferReturnTensorTypes(ArrayRef< ShapedTypeComponents > retComponents, SmallVectorImpl< Type > &inferredReturnTypes)
void reportFatalInferReturnTypesError(OperationState &state)
Report a fatal error indicating that the result types could not be inferred.
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
InFlightDiagnostic emitRemark(Location loc)
Utility method to emit a remark message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.