MLIR  17.0.0git
InferTypeOpInterface.h
Go to the documentation of this file.
1 //===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
15 #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
16 
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/PointerUnion.h"
24 #include "llvm/ADT/SmallVector.h"
25 
26 namespace mlir {
27 
28 class ShapedTypeComponents;
30 
31 /// Adaptor class to abstract the differences between whether value is from
32 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttr.
33 class ShapeAdaptor {
34 public: // NOTE(review): listing line 35 is absent from this extract; presumably the signature of a constructor taking `t` - confirm against the header.
36  if (auto st = t.dyn_cast<ShapedType>()) // val is assigned only when t is a ShapedType.
37  val = st;
38  } // NOTE(review): listing line 39 is absent as well; presumably a second constructor signature taking `t` - confirm.
40  if (auto da = t.dyn_cast<DenseIntElementsAttr>()) // val is assigned only when t is a DenseIntElementsAttr.
41  val = da;
42  }
43  ShapeAdaptor(ShapedTypeComponents *components) : val(components) {}
44  ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {}
45 
46  /// Returns whether the shape has a rank.
47  bool hasRank() const;
48 
49  /// Returns the element type.
50  Type getElementType() const;
51 
52  /// Populates the dimensions from shape referenced.
53  /// Requires: shape is ranked.
54  void getDims(SmallVectorImpl<int64_t> &res) const;
55 
56  /// Populates the dimensions of the ShapedTypeComponents.
57  /// Requires: shape is ranked.
58  void getDims(ShapedTypeComponents &res) const;
59 
60  /// Returns the size of the index'th dimension.
61  /// Requires: shape is ranked.
62  int64_t getDimSize(int index) const;
63 
64  /// Returns whether the index'th dimension is dynamic.
65  /// Requires: shape is ranked.
66  bool isDynamicDim(int index) const {
67  return ShapedType::isDynamic(getDimSize(index));
68  }
69 
70  /// Returns whether the shape is fully static.
71  bool hasStaticShape() const;
72 
73  /// Returns the rank of the shape.
74  /// Requires: shape is ranked.
75  int64_t getRank() const;
76 
77  /// Returns the number of elements in the shape.
78  /// Requires: hasStaticShape
79  int64_t getNumElements() const;
80 
81  /// Returns whether valid (non-null) shape.
82  explicit operator bool() const { return !val.isNull(); }
83 
84  /// Dumps textual representation to stderr.
85  void dump() const;
86 
87 private:
88  // Union storing either ShapedTypeComponents, ShapedType (stored as Type and
89  // cast), or DenseIntElementsAttr (stored as Attribute). NOTE(review): listing line 90 (presumably the declaration of `val`) is absent from this extract.
91 };
92 
93 /// ShapedTypeComponents that represents the components of a ShapedType.
94 /// The components consist of
95 /// - A ranked or unranked shape with the dimension specification matching that
96 /// of ShapedType's getShape() (e.g., dynamic dimension represented using
97 /// ShapedType::kDynamic)
98 /// - An element type, may be unset (nullptr)
99 /// - An attribute, may be unset (nullptr)
100 /// Used by ShapedType type inferences.
102  /// Internal storage type for shape. NOTE(review): listing lines 101 (presumably the class header) and 103 (presumably the ShapeStorageT alias) are absent from this extract.
104 
105 public:
106  /// Default construction is an unranked shape.
107  ShapedTypeComponents() : elementType(nullptr), attr(nullptr){}; // NOTE(review): listing line 108 is absent; the initializer list on the next line presumably belongs to the ShapedTypeComponents(Type elementType) constructor named in the page's member index - confirm.
109  : elementType(elementType), attr(nullptr), ranked(false) {}
110  ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
111  ranked = shapedType.hasRank();
112  elementType = shapedType.getElementType();
113  if (ranked) // dims is populated only for ranked shapes.
114  dims = llvm::to_vector<4>(shapedType.getShape());
115  }
116  ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) {
117  ranked = adaptor.hasRank();
118  elementType = adaptor.getElementType();
119  if (ranked) // dims is populated only for ranked shapes.
120  adaptor.getDims(*this);
121  }
122  template <typename Arg, typename = std::enable_if_t<
123  std::is_constructible<ShapeStorageT, Arg>::value>>
124  ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
125  Attribute attr = nullptr)
126  : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
127  ranked(true) {}
128  ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
129  Attribute attr = nullptr)
130  : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
131  ranked(true) {}
132 
133  /// Return the dimensions of the shape.
134  /// Requires: shape is ranked.
136  assert(ranked && "requires ranked shape"); // NOTE(review): listing line 135 is absent; presumably the getDims() signature named in the page's member index.
137  return dims;
138  }
139 
140  /// Return whether the shape has a rank.
141  bool hasRank() const { return ranked; };
142 
143  /// Return the element type component.
144  Type getElementType() const { return elementType; };
145 
146  /// Return the raw attribute component.
147  Attribute getAttribute() const { return attr; };
148 
149 private:
150  friend class ShapeAdaptor;
151 
152  ShapeStorageT dims;
153  Type elementType;
154  Attribute attr;
155  bool ranked{false};
156 };
157 
158 /// Range of values and shapes (corresponding effectively to Shapes dialect's
159 /// ValueShape type concept).
160 // Currently this exposes the Value (of operands) and Type of the Value. This is
161 // not ideal as then one can accidentally reference an out of date shape. This
162 // is done to both enable gradual switch and also as OpAdaptor doesn't currently
163 // allow returning anything other than Value.
164 class ValueShapeRange : public ValueRange::RangeBaseT {
165 public: // NOTE(review): listing line 166 is absent; presumably the ValueShapeMapFn alias named in the page's member index - confirm.
167 
168  ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr,
169  ValueShapeMapFn valueToShape = nullptr)
170  : RangeBaseT(values), operandShape(operandShape),
171  valueToShape(valueToShape) {}
172  ValueShapeRange(const std::initializer_list<Value> &values)
173  : ValueShapeRange(ValueRange(values)) {}
174 
175  ValueShapeRange(const ValueShapeRange &) = default;
176 
177  /// Sets the Value to ShapeAdaptor mapping function and returns this.
179  valueToShape = fn; // NOTE(review): listing line 178 is absent; presumably the setValueToShapeMapping(ValueShapeMapFn fn) signature named in the page's member index.
180  return *this;
181  }
182 
184  operandShape = fn; // NOTE(review): listing line 183 is absent; presumably the setOperandShapeMapping(ValueShapeMapFn fn) signature named in the page's member index.
185  return *this;
186  }
187 
188  /// Returns the set Value to ShapeAdaptor mapping function.
189  ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; }
190  ValueShapeMapFn getOperandShapeMapping() const { return operandShape; }
191 
192  // Accessors.
193 
194  /// Returns the types of the values within this range.
195  /// Note: This returns only the types of Values in the ValueRange and not a
196  /// more refined type.
199  type_range getTypes() const { return {begin(), end()}; } // NOTE(review): listing lines 197-198 are absent; presumably where type_range is declared - confirm.
200  auto getType() const { return getTypes(); }
201 
202  /// Returns the Values in the ValueRange.
203  /// To query the most up to date shape of a Value, query the shape
204  /// using getShape below rather than using the type of the Value.
205  ValueRange getValues() const { return ValueRange(begin(), end()); };
206 
207  /// Returns an argument as shape. If the argument is not constant or not a
208  /// shape, then the function returns a nullptr.
209  /// This will first query the valueToShape mapping (if set), before querying
210  /// the ValueRange.
211  ShapeAdaptor getValueAsShape(int index);
212 
213  /// Returns the shape of index'th operand.
214  // TODO: Update so that operator[] references these instead to avoid
215  // accidentally referring to less refined shape.
216  ShapeAdaptor getShape(int index) const;
217 
218  /// Returns the shape of the given Value.
219  ShapeAdaptor getShape(Value val) const;
220 
221 private:
222  // Mapping from Value to ShapedTypeComponents corresponding to shape of type
223  // of Value.
224  ValueShapeMapFn operandShape;
225 
226  // Mapping from Value to ShapedTypeComponents corresponding to constant Value
227  // if interpreted as shape.
228  ValueShapeMapFn valueToShape;
229 };
230 
231 namespace detail {
232 // Helper function to infer return tensor types given an element and
233 // shape inference function (`componentTypeFn`).
234 //
235 // TODO: Consider generating typedefs for trait member functions if this usage
236 // becomes more common.
237 LogicalResult inferReturnTensorTypes(
238  function_ref<
239  LogicalResult(MLIRContext *, std::optional<Location> location,
240  ValueShapeRange operands, DictionaryAttr attributes,
241  RegionRange regions,
242  SmallVectorImpl<ShapedTypeComponents> &retComponents)>
243  componentTypeFn,
244  MLIRContext *context, std::optional<Location> location, ValueRange operands,
245  DictionaryAttr attributes, RegionRange regions,
246  SmallVectorImpl<Type> &inferredReturnTypes);
247 
248 /// Verifies that the inferred result types match the actual result types for
249 /// the op. Precondition: op implements InferTypeOpInterface.
250 LogicalResult verifyInferredResultTypes(Operation *op);
251 } // namespace detail
252 
253 namespace OpTrait {
254 template <typename ConcreteType>
255 class InferTensorType; // Forward declaration; the class is defined further down, after the generated interface include.
256 } // namespace OpTrait
257 } // namespace mlir
258 
259 /// Include the generated interface declarations.
260 #include "mlir/Interfaces/InferTypeOpInterface.h.inc"
261 
262 namespace mlir {
263 namespace OpTrait {
264 
265 /// Tensor type inference trait that constructs a tensor from the inferred
266 /// shape and elemental types.
267 /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
268 /// Less strict is possible (e.g., implements inferReturnTypeComponents and
269 /// these always populate all element types and shapes or fail, but this
270 /// trait is currently only used where the interfaces are, so keep it
271 /// restricted for now).
272 template <typename ConcreteType>
273 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
274 public:
275  static LogicalResult
276  inferReturnTypes(MLIRContext *context, std::optional<Location> location,
277  ValueRange operands, DictionaryAttr attributes,
278  RegionRange regions,
279  SmallVectorImpl<Type> &inferredReturnTypes) { // The two static_asserts below reject, at compile time, a ConcreteType lacking either interface trait.
280  static_assert(
281  ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
282  "requires InferShapedTypeOpInterface to ensure succesful invocation"); // NOTE(review): "succesful" is misspelled in both messages; left untouched here because it is a string literal.
283  static_assert(
284  ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
285  "requires InferTypeOpInterface to ensure succesful invocation"); // NOTE(review): listing line 286 is absent; the arguments below belong to a call whose opening is not shown (presumably detail::inferReturnTensorTypes, whose parameter list they match) - confirm.
287  ConcreteType::inferReturnTypeComponents, context, location, operands,
288  attributes, regions, inferredReturnTypes);
289  }
290 };
291 
292 } // namespace OpTrait
293 } // namespace mlir
294 
295 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:166
An attribute that represents a reference to a dense integer vector or tensor object.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
Tensor type inference trait that constructs a tensor from the inferred shape and elemental types.
static LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< Type > &inferredReturnTypes)
Helper class for implementing traits.
Definition: OpDefinition.h:310
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:330
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
void dump() const
Dumps textual representation to stderr.
Type getElementType() const
Returns the element type.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
ShapeAdaptor(ShapedTypeComponents &components)
ShapeAdaptor(ShapedTypeComponents *components)
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
ShapedTypeComponents()
Default construction is an unranked shape.
ShapedTypeComponents(Arg &&arg, Type elementType=nullptr, Attribute attr=nullptr)
ShapedTypeComponents(ShapedType shapedType)
ShapedTypeComponents(Type elementType)
ShapedTypeComponents(ArrayRef< int64_t > vec, Type elementType=nullptr, Attribute attr=nullptr)
ShapedTypeComponents(ShapeAdaptor adaptor)
bool hasRank() const
Return whether the shape has a rank.
Type getElementType() const
Return the element type component.
Attribute getAttribute() const
Return the raw attribute component.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U dyn_cast() const
Definition: Types.h:308
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:350
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ValueShapeMapFn getOperandShapeMapping() const
ValueShapeMapFn getValueToShapeMapping() const
Returns the set Value to ShapeAdaptor mapping function.
ValueShapeRange & setValueToShapeMapping(ValueShapeMapFn fn)
Sets the Value to ShapeAdaptor mapping function and returns this.
ValueShapeRange(const ValueShapeRange &)=default
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
type_range getTypes() const
ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape=nullptr, ValueShapeMapFn valueToShape=nullptr)
ValueShapeRange(const std::initializer_list< Value > &values)
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
function_ref< ShapeAdaptor(Value)> ValueShapeMapFn
ValueShapeRange & setOperandShapeMapping(ValueShapeMapFn fn)
ValueRange getValues() const
Returns the Values in the ValueRange.
This class implements iteration on the types of a given range of values.
Definition: TypeRange.h:118
This class implements iteration on the types of a given range of values.
Definition: TypeRange.h:131
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
LogicalResult inferReturnTensorTypes(function_ref< LogicalResult(MLIRContext *, std::optional< Location > location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< ShapedTypeComponents > &retComponents)> componentTypeFn, MLIRContext *context, std::optional< Location > location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< Type > &inferredReturnTypes)
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op.
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:148
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26