MLIR 22.0.0git
InferTypeOpInterface.cpp
Go to the documentation of this file.
1//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the definitions of the infer op interfaces defined in
10// `InferTypeOpInterface.td`.
11//
12//===----------------------------------------------------------------------===//
13
16#include "mlir/IR/Matchers.h"
17#include "llvm/Support/FormatVariadic.h"
18#include "llvm/Support/InterleavedRange.h"
19
20using namespace mlir;
21
22namespace mlir {
23#include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
24} // namespace mlir
25
26LogicalResult
28 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
29 auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
30 if (!reifiableOp)
31 return failure();
32 LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
33#ifndef NDEBUG
34 if (failed(status))
35 return failure();
36 // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced
37 // a correct result.
38 int64_t resultIdx = 0;
39 for (OpResult result : op->getResults()) {
40 auto shapedType = dyn_cast<ShapedType>(result.getType());
41 if (!shapedType)
42 continue;
43 if (!shapedType.hasRank()) {
44 // Nothing to check for unranked shaped values.
45 ++resultIdx;
46 continue;
47 }
48 // Assert one OpFoldResult per dimension.
49 assert(shapedType.getRank() ==
50 static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
51 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
52 ++resultIdx;
53 }
54 // Assert that every shaped value result was reified.
55 assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) &&
56 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
57#endif // NDEBUG
58 return status;
59}
60
61FailureOr<SmallVector<OpFoldResult>>
63 auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
64 if (!reifiableOp)
65 return failure();
66 return reifiableOp.reifyShapeOfResult(b, resultIndex);
67}
68
69FailureOr<OpFoldResult> mlir::reifyDimOfResult(OpBuilder &b, Operation *op,
70 int resultIndex, int dim) {
71 auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
72 if (!reifiableOp)
73 return failure();
74 return reifiableOp.reifyDimOfResult(b, resultIndex, dim);
75}
76
78 if (val.isNull())
79 return false;
80 if (auto t = llvm::dyn_cast_if_present<Type>(val))
81 return cast<ShapedType>(t).hasRank();
82 if (isa<Attribute>(val))
83 return true;
84 return cast<ShapedTypeComponents *>(val)->hasRank();
85}
86
88 if (val.isNull())
89 return nullptr;
90 if (auto t = llvm::dyn_cast_if_present<Type>(val))
91 return cast<ShapedType>(t).getElementType();
92 if (isa<Attribute>(val))
93 return nullptr;
94 return cast<ShapedTypeComponents *>(val)->getElementType();
95}
96
98 assert(hasRank());
99 if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
100 ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
101 res.assign(vals.begin(), vals.end());
102 } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
103 auto dattr = cast<DenseIntElementsAttr>(attr);
104 res.clear();
105 res.reserve(dattr.size());
106 for (auto it : dattr.getValues<APInt>())
107 res.push_back(it.getSExtValue());
108 } else {
109 auto vals = cast<ShapedTypeComponents *>(val)->getDims();
110 res.assign(vals.begin(), vals.end());
111 }
112}
113
115 assert(hasRank());
116 res.ranked = true;
117 getDims(res.dims);
118}
119
121 assert(hasRank());
122 if (auto t = llvm::dyn_cast_if_present<Type>(val))
123 return cast<ShapedType>(t).getDimSize(index);
124 if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
125 return cast<DenseIntElementsAttr>(attr)
126 .getValues<APInt>()[index]
127 .getSExtValue();
128 auto *stc = cast<ShapedTypeComponents *>(val);
129 return stc->getDims()[index];
130}
131
133 assert(hasRank());
134 if (auto t = llvm::dyn_cast_if_present<Type>(val))
135 return cast<ShapedType>(t).getRank();
136 if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
137 return cast<DenseIntElementsAttr>(attr).size();
138 return cast<ShapedTypeComponents *>(val)->getDims().size();
139}
140
142 if (!hasRank())
143 return false;
144
145 if (auto t = llvm::dyn_cast_if_present<Type>(val))
146 return cast<ShapedType>(t).hasStaticShape();
147 if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
148 auto dattr = cast<DenseIntElementsAttr>(attr);
149 for (auto index : dattr.getValues<APInt>())
150 if (ShapedType::isDynamic(index.getSExtValue()))
151 return false;
152 return true;
153 }
154 auto *stc = cast<ShapedTypeComponents *>(val);
155 return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
156}
157
159 assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
160
161 if (auto t = llvm::dyn_cast_if_present<Type>(val))
162 return cast<ShapedType>(t).getNumElements();
163
164 if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
165 auto dattr = cast<DenseIntElementsAttr>(attr);
166 int64_t num = 1;
167 for (auto index : dattr.getValues<APInt>()) {
168 num *= index.getZExtValue();
169 assert(num >= 0 && "integer overflow in element count computation");
170 }
171 return num;
172 }
173
174 auto *stc = cast<ShapedTypeComponents *>(val);
175 int64_t num = 1;
176 for (int64_t dim : stc->getDims()) {
177 num *= dim;
178 assert(num >= 0 && "integer overflow in element count computation");
179 }
180 return num;
181}
182
183void ShapeAdaptor::dump() const {
184 if (!hasRank()) {
185 llvm::errs() << "<<unranked>>\n";
186 return;
187 }
188
190 getDims(dims);
191 auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
192 if (ShapedType::isDynamic(dim))
193 return "?";
194 return llvm::formatv("{0}", dim).str();
195 });
196 llvm::errs() << "rank = " << getRank()
197 << " dims = " << llvm::interleaved_array(mapped, "x") << "\n";
198}
199
201 Value val = operator[](index);
202 if (valueToShape)
203 if (ShapeAdaptor ret = valueToShape(val))
204 return ret;
205
207 if (!matchPattern(val, m_Constant(&attr)))
208 return nullptr;
209 if (attr.getType().getRank() != 1)
210 return nullptr;
211 return attr;
212}
213
215 if (operandShape)
216 if (ShapeAdaptor ret = operandShape(val))
217 return ret;
218 return val.getType();
219}
220
222 if (index < 0 || static_cast<size_t>(index) >= size())
223 return nullptr;
224 return getShape(operator[](index));
225}
226
228 ArrayRef<ShapedTypeComponents> retComponents,
229 SmallVectorImpl<Type> &inferredReturnTypes) {
230 for (const auto &shapeAndType : retComponents) {
231 Type elementTy = shapeAndType.getElementType();
232 assert(elementTy && "element type required to construct tensor");
233
234 Attribute attr = shapeAndType.getAttribute();
235 if (shapeAndType.hasRank()) {
236 inferredReturnTypes.push_back(
237 RankedTensorType::get(shapeAndType.getDims(), elementTy, attr));
238 } else {
239 assert(attr == nullptr && "attribute not supported");
240 inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
241 }
242 }
243 return success();
244}
245
247 SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
248 auto retTypeFn = cast<InferTypeOpInterface>(op);
249 auto result = retTypeFn.refineReturnTypes(
250 op->getContext(), op->getLoc(), op->getOperands(),
252 inferredReturnTypes);
253 if (failed(result))
254 op->emitOpError() << "failed to infer returned types";
255
256 return result;
257}
258
260 std::string buffer;
261 llvm::raw_string_ostream os(buffer);
262 os << "Failed to infer result type(s):\n"
263 << "\"" << state.name << "\"(...) "
264 << state.attributes.getDictionary(state.location.getContext()) << " : ("
265 << llvm::interleaved(llvm::map_range(
266 state.operands, [](Value val) { return val.getType(); }))
267 << ") -> ( ??? )";
268 emitRemark(state.location, "location of op");
269 llvm::report_fatal_error(llvm::StringRef(buffer));
270}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
Attributes are known-constant values of operations.
Definition Attributes.h:25
An attribute that represents a reference to a dense integer vector or tensor object.
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition Location.h:86
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
This class helps build Operations.
Definition Builders.h:207
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
Definition Operation.h:509
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition Operation.h:900
Adaptor class to abstract the differences between whether the value is from a ShapedType, a ShapedTypeComponents, or a DenseIntElementsAttr.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
void dump() const
Dumps a textual representation to stderr.
Type getElementType() const
Returns the element type.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
LogicalResult inferReturnTensorTypes(ArrayRef< ShapedTypeComponents > retComponents, SmallVectorImpl< Type > &inferredReturnTypes)
void reportFatalInferReturnTypesError(OperationState &state)
Report a fatal error indicating that the result types could not be inferred.
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
FailureOr< OpFoldResult > reifyDimOfResult(OpBuilder &b, Operation *op, int resultIndex, int dim)
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
InFlightDiagnostic emitRemark(Location loc)
Utility method to emit a remark message using this location.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
FailureOr< SmallVector< OpFoldResult > > reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands