MLIR 22.0.0git
InferTypeOpInterface.h
Go to the documentation of this file.
1//===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the definitions of the infer op interfaces defined in
10// `InferTypeOpInterface.td`.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
15#define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
16
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Builders.h"
20#include "mlir/IR/Location.h"
22#include "mlir/Support/LLVM.h"
23#include "llvm/ADT/PointerUnion.h"
24#include "llvm/ADT/SmallVector.h"
25
26namespace mlir {
27
30
31/// Reify the shape of the result of an operation (typically in terms of the
32/// shape of its operands).
33LogicalResult
35 ReifiedRankedShapedTypeDims &reifiedReturnShapes);
36FailureOr<SmallVector<OpFoldResult>>
37reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex);
38FailureOr<OpFoldResult> reifyDimOfResult(OpBuilder &b, Operation *op,
39 int resultIndex, int dim);
40
41/// Adaptor class to abstract over whether the value comes from a ShapedType,
42/// a ShapedTypeComponents, or a DenseIntElementsAttr.
44public:
46 if (auto st = dyn_cast<ShapedType>(t))
47 val = st;
48 }
50 if (auto da = dyn_cast<DenseIntElementsAttr>(t))
51 val = da;
52 }
53 ShapeAdaptor(ShapedTypeComponents *components) : val(components) {}
54 ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {}
55
56 /// Returns whether the shape has a rank.
57 bool hasRank() const;
58
59 /// Returns the element type.
60 Type getElementType() const;
61
62 /// Populates the dimensions from shape referenced.
63 /// Requires: shape is ranked.
64 void getDims(SmallVectorImpl<int64_t> &res) const;
65
66 /// Populates the dimensions of the ShapedTypeComponents.
67 /// Requires: shape is ranked.
68 void getDims(ShapedTypeComponents &res) const;
69
70 /// Returns the size of the index'th dimension.
71 /// Requires: shape is ranked.
72 int64_t getDimSize(int index) const;
73
74 /// Returns whether the index'th dimension is dynamic.
75 /// Requires: shape is ranked.
76 bool isDynamicDim(int index) const {
77 return ShapedType::isDynamic(getDimSize(index));
78 }
79
80 /// Returns whether the shape is fully static.
81 bool hasStaticShape() const;
82
83 /// Returns the rank of the shape.
84 /// Requires: shape is ranked.
85 int64_t getRank() const;
86
87 /// Returns the number of elements in the shape.
88 /// Requires: hasStaticShape
89 int64_t getNumElements() const;
90
91 /// Returns whether valid (non-null) shape.
92 explicit operator bool() const { return !val.isNull(); }
93
94 /// Dumps textual representation to stderr.
95 void dump() const;
96
97private:
98 // Union storing either ShapedTypeComponents, ShapedType (stored as Type and
99 // cast), or DenseIntElementsAttr (stored as Attribute).
101};
102
103/// ShapedTypeComponents that represents the components of a ShapedType.
104/// The components consist of
105/// - A ranked or unranked shape with the dimension specification matching
106/// those of ShapedType's getShape() (e.g., dynamic dimension represented
107/// using ShapedType::kDynamic)
108/// - An element type, may be unset (nullptr)
109/// - An attribute, may be unset (nullptr)
110/// Used by ShapedType type inferences.
112 /// Internal storage type for shape.
113 using ShapeStorageT = SmallVector<int64_t, 3>;
114
115public:
116 /// Default construction is an unranked shape.
117 ShapedTypeComponents() : elementType(nullptr), attr(nullptr) {};
119 : elementType(elementType), attr(nullptr), ranked(false) {}
120 ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
121 ranked = shapedType.hasRank();
122 elementType = shapedType.getElementType();
123 if (ranked)
124 dims = llvm::to_vector<4>(shapedType.getShape());
125 }
127 ranked = adaptor.hasRank();
128 elementType = adaptor.getElementType();
129 if (ranked)
130 adaptor.getDims(*this);
131 }
132 template <typename Arg, typename = std::enable_if_t<
133 std::is_constructible<ShapeStorageT, Arg>::value>>
134 ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
135 Attribute attr = nullptr)
136 : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
137 ranked(true) {}
138 ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
139 Attribute attr = nullptr)
140 : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
141 ranked(true) {}
142
143 /// Return the dimensions of the shape.
144 /// Requires: shape is ranked.
146 assert(ranked && "requires ranked shape");
147 return dims;
148 }
149
150 /// Return whether the shape has a rank.
151 bool hasRank() const { return ranked; };
152
153 /// Return the element type component.
154 Type getElementType() const { return elementType; };
155
156 /// Return the raw attribute component.
157 Attribute getAttribute() const { return attr; };
158
159private:
160 friend class ShapeAdaptor;
161
162 ShapeStorageT dims;
163 Type elementType;
164 Attribute attr;
165 bool ranked{false};
166};
167
168/// Range of values and shapes (corresponding effectively to Shapes dialect's
169/// ValueShape type concept).
170// Currently this exposes the Value (of operands) and Type of the Value. This is
171// not ideal as then one can accidentally reference an out of date shape. This
172// is done to both enable gradual switch and also as OpAdaptor doesn't currently
173// allow returning anything other than Value.
174class ValueShapeRange : public ValueRange::RangeBaseT {
175public:
177
178 ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr,
179 ValueShapeMapFn valueToShape = nullptr)
180 : RangeBaseT(values), operandShape(operandShape),
181 valueToShape(valueToShape) {}
182 ValueShapeRange(const std::initializer_list<Value> &values)
183 : ValueShapeRange(ValueRange(values)) {}
184
186
187 /// Sets the Value to ShapeAdaptor mapping function and returns this.
189 valueToShape = fn;
190 return *this;
191 }
192
194 operandShape = fn;
195 return *this;
196 }
197
198 /// Returns the set Value to ShapeAdaptor mapping function.
199 ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; }
200 ValueShapeMapFn getOperandShapeMapping() const { return operandShape; }
201
202 // Accessors.
203
204 /// Returns the types of the values within this range.
205 /// Note: This returns only the types of Values in the ValueRange and not a
206 /// more refined type.
209 type_range getTypes() const { return {begin(), end()}; }
210 auto getType() const { return getTypes(); }
211
212 /// Returns the Values in the ValueRange.
213 /// To query the most up to date shape of a Value, query the shape
214 /// using getShape below rather than using the type of the Value.
215 ValueRange getValues() const { return ValueRange(begin(), end()); };
216
217 /// Returns an argument as shape. If the argument is not constant or not a
218 /// shape, then the function returns a nullptr.
219 /// This will first query the valueToShape mapping (if set), before querying
220 /// the ValueRange.
222
223 /// Returns the shape of index'th operand.
224 // TODO: Update so that operator[] references these instead to avoid
225 // accidentally referring to a less refined shape.
226 ShapeAdaptor getShape(int index) const;
227
228 /// Returns the shape of the given Value.
229 ShapeAdaptor getShape(Value val) const;
230
231private:
232 // Mapping from Value to ShapedTypeComponents corresponding to shape of type
233 // of Value.
234 ValueShapeMapFn operandShape;
235
236 // Mapping from Value to ShapedTypeComponents corresponding to constant Value
237 // if interpreted as shape.
238 ValueShapeMapFn valueToShape;
239};
240
241namespace detail {
242// Helper function to infer return tensor types given element and
243// shape inference function.
244LogicalResult
245inferReturnTensorTypes(ArrayRef<ShapedTypeComponents> retComponents,
246 SmallVectorImpl<Type> &inferredReturnTypes);
247
248/// Verifies that the inferred result types match the actual result types for
249/// the op. Precondition: op implements InferTypeOpInterface.
250LogicalResult verifyInferredResultTypes(Operation *op);
251
252/// Report a fatal error indicating that the result types could not be
253/// inferred.
254void reportFatalInferReturnTypesError(OperationState &state);
255} // namespace detail
256
257namespace OpTrait {
258template <typename ConcreteType>
259class InferTensorType;
260} // namespace OpTrait
261} // namespace mlir
262
263/// Include the generated interface declarations.
264#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
265
266namespace mlir {
267namespace OpTrait {
268
269template <typename ConcreteType>
270class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> {
271};
272
273template <typename ConcreteType>
275 : public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {};
276
277/// Tensor type inference trait that constructs a tensor from the inferred
278/// shape and elemental types.
279/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
280/// A less strict requirement is possible (e.g., the op implements
281/// inferReturnTypeComponents and this always populates all element types and
282/// shapes or fails), but this trait is currently only used where the
283/// interfaces are, so keep it restricted for now.
284template <typename ConcreteType>
285class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {};
286
287} // namespace OpTrait
288} // namespace mlir
289
290#endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
true
Given two iterators into the same block, return "true" if `a` is before `b`.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
false
Parses a map_entries map type from a string format back into its numeric value.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class helps build Operations.
Definition Builders.h:207
Tensor type inference trait that constructs a tensor from the inferred shape and elemental types.
Helper class for implementing traits.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
void dump() const
Dumps textual representation to stderr.
Type getElementType() const
Returns the element type.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
ShapeAdaptor(ShapedTypeComponents &components)
ShapeAdaptor(ShapedTypeComponents *components)
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
ShapedTypeComponents()
Default construction is an unranked shape.
ShapedTypeComponents(Arg &&arg, Type elementType=nullptr, Attribute attr=nullptr)
ShapedTypeComponents(ShapedType shapedType)
ShapedTypeComponents(ArrayRef< int64_t > vec, Type elementType=nullptr, Attribute attr=nullptr)
ShapedTypeComponents(ShapeAdaptor adaptor)
bool hasRank() const
Return whether the shape has a rank.
Type getElementType() const
Return the element type component.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Attribute getAttribute() const
Return the raw attribute component.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
ValueTypeIterator< iterator > type_iterator
Returns the types of the values within this range.
ValueShapeRange & setOperandShapeMapping(ValueShapeMapFn fn)
ValueTypeRange< ValueRange > type_range
ValueShapeRange & setValueToShapeMapping(ValueShapeMapFn fn)
Sets the Value to ShapeAdaptor mapping function and returns this.
ValueShapeMapFn getOperandShapeMapping() const
ValueShapeMapFn getValueToShapeMapping() const
Returns the set Value to ShapeAdaptor mapping function.
ValueShapeRange(const ValueShapeRange &)=default
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
type_range getTypes() const
ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape=nullptr, ValueShapeMapFn valueToShape=nullptr)
ValueShapeRange(const std::initializer_list< Value > &values)
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
function_ref< ShapeAdaptor(Value)> ValueShapeMapFn
ValueRange getValues() const
Returns the Values in the ValueRange.
This class implements iteration on the types of a given range of values.
Definition TypeRange.h:122
This class implements iteration on the types of a given range of values.
Definition TypeRange.h:135
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
LogicalResult inferReturnTensorTypes(ArrayRef< ShapedTypeComponents > retComponents, SmallVectorImpl< Type > &inferredReturnTypes)
void reportFatalInferReturnTypesError(OperationState &state)
Report a fatal error indicating that the result types could not be inferred.
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op.
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
FailureOr< OpFoldResult > reifyDimOfResult(OpBuilder &b, Operation *op, int resultIndex, int dim)
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
FailureOr< SmallVector< OpFoldResult > > reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex)