17 #include "llvm/Support/FormatVariadic.h"
22 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
28 auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
31 LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
37 int64_t resultIdx = 0;
39 auto shapedType = dyn_cast<ShapedType>(result.getType());
42 if (!shapedType.hasRank()) {
48 assert(shapedType.getRank() ==
49 static_cast<int64_t
>(reifiedReturnShapes[resultIdx].size()) &&
50 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
51 for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
55 assert(shapedType.isDynamicDim(dim) ==
56 reifiedReturnShapes[resultIdx][dim].is<
Value>() &&
57 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
62 assert(resultIdx ==
static_cast<int64_t
>(reifiedReturnShapes.size()) &&
63 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
71 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
72 return cast<ShapedType>(t).hasRank();
81 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
82 return cast<ShapedType>(t).getElementType();
90 if (
auto t = llvm::dyn_cast_if_present<Type>(val)) {
92 res.assign(vals.begin(), vals.end());
93 }
else if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
94 auto dattr = cast<DenseIntElementsAttr>(attr);
96 res.reserve(dattr.size());
97 for (
auto it : dattr.getValues<APInt>())
98 res.push_back(it.getSExtValue());
101 res.assign(vals.begin(), vals.end());
113 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
114 return cast<ShapedType>(t).getDimSize(index);
115 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val))
116 return cast<DenseIntElementsAttr>(attr)
117 .getValues<APInt>()[index]
125 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
126 return cast<ShapedType>(t).getRank();
127 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val))
128 return cast<DenseIntElementsAttr>(attr).size();
136 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
137 return cast<ShapedType>(t).hasStaticShape();
138 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
139 auto dattr = cast<DenseIntElementsAttr>(attr);
140 for (
auto index : dattr.getValues<APInt>())
141 if (ShapedType::isDynamic(index.getSExtValue()))
146 return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
150 assert(
hasStaticShape() &&
"cannot get element count of dynamic shaped type");
152 if (
auto t = llvm::dyn_cast_if_present<Type>(val))
153 return cast<ShapedType>(t).getNumElements();
155 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
156 auto dattr = cast<DenseIntElementsAttr>(attr);
158 for (
auto index : dattr.getValues<APInt>()) {
159 num *= index.getZExtValue();
160 assert(num >= 0 &&
"integer overflow in element count computation");
167 for (int64_t dim : stc->getDims()) {
169 assert(num >= 0 &&
"integer overflow in element count computation");
176 llvm::errs() <<
"<<unranked>>\n";
182 auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
183 if (ShapedType::isDynamic(dim))
185 return llvm::formatv(
"{0}", dim).str();
187 llvm::errs() <<
"rank = " <<
getRank() <<
" dims = [";
188 llvm::interleave(mapped, llvm::errs(),
"x");
189 llvm::errs() <<
"]\n";
193 Value val = operator[](index);
201 if (attr.getType().getRank() != 1)
214 if (index < 0 ||
static_cast<size_t>(index) >= size())
222 for (
const auto &shapeAndType : retComponents) {
223 Type elementTy = shapeAndType.getElementType();
224 assert(elementTy &&
"element type required to construct tensor");
226 Attribute attr = shapeAndType.getAttribute();
227 if (shapeAndType.hasRank()) {
228 inferredReturnTypes.push_back(
231 assert(attr ==
nullptr &&
"attribute not supported");
240 auto retTypeFn = cast<InferTypeOpInterface>(op);
241 auto result = retTypeFn.refineReturnTypes(
244 inferredReturnTypes);
246 op->
emitOpError() <<
"failed to infer returned types";
253 llvm::raw_string_ostream os(buffer);
254 os <<
"Failed to infer result type(s):\n";
255 os <<
"\"" << state.name <<
"\"(...) ";
256 os << state.attributes.getDictionary(state.location.getContext());
258 llvm::interleaveComma(state.operands, os,
259 [&](
Value val) { os << val.getType(); });
260 os <<
") -> ( ??? )";
262 llvm::report_fatal_error(llvm::StringRef(buffer));
Attributes are known-constant values of operations.
An attribute that represents a reference to a dense integer vector or tensor object.
This class helps build Operations.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
void dump() const
Dumps a textual representation to stderr.
Type getElementType() const
Returns the element type.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
LogicalResult inferReturnTensorTypes(ArrayRef< ShapedTypeComponents > retComponents, SmallVectorImpl< Type > &inferredReturnTypes)
void reportFatalInferReturnTypesError(OperationState &state)
Report a fatal error indicating that the result types could not be inferred.
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
InFlightDiagnostic emitRemark(Location loc)
Utility method to emit a remark message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.