MLIR  14.0.0git
InferTypeOpInterface.h
Go to the documentation of this file.
1 //===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
15 #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
16 
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/PointerUnion.h"
24 #include "llvm/ADT/SmallVector.h"
25 
26 namespace mlir {
27 
29 
30 /// ShapedTypeComponents represents the components of a ShapedType.
31 /// The components consist of
32 /// - A ranked or unranked shape whose dimension specification matches that
33 /// of ShapedType's getShape() (e.g., a dynamic dimension is represented using
34 /// ShapedType::kDynamicSize)
35 /// - An element type, which may be unset (nullptr)
36 /// - An attribute, which may be unset (nullptr)
37 /// Used by ShapedType type inferences.
39  /// Internal storage type for shape.
41 
42 public:
43  /// Default construction is an unranked shape; `ranked` relies on its
  /// in-class initializer (false).
44  ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
46  : elementType(elementType), attr(nullptr), ranked(false) {}
  /// Constructs from a ShapedType: copies its rank flag and element type, and
  /// copies its dimensions only when it is ranked. The attribute stays unset.
47  ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
48  ranked = shapedType.hasRank();
49  elementType = shapedType.getElementType();
50  if (ranked)
51  dims = llvm::to_vector<4>(shapedType.getShape());
52  }
  /// Constructs a ranked shape by forwarding `arg` into the dimension storage,
  /// with an optional element type and attribute.
  /// NOTE(review): constrained via enable_if, presumably to argument types the
  /// storage can be constructed from -- confirm against the full header.
53  template <typename Arg, typename = typename std::enable_if_t<
55  ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
56  Attribute attr = nullptr)
57  : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
58  ranked(true) {}
  /// Constructs a ranked shape by copying the dimensions in `vec`, with an
  /// optional element type and attribute.
59  ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
60  Attribute attr = nullptr)
61  : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
62  ranked(true) {}
63 
64  /// Return the dimensions of the shape.
65  /// Requires: shape is ranked.
67  assert(ranked && "requires ranked shape");
68  return dims;
69  }
70 
71  /// Return whether the shape has a rank.
72  bool hasRank() const { return ranked; };
73 
74  /// Return the element type component.
75  Type getElementType() const { return elementType; };
76 
77  /// Return the raw attribute component.
78  Attribute getAttribute() const { return attr; };
79 
80 private:
81  friend class ShapeAdaptor;
82 
  /// Dimension sizes; only meaningful when `ranked` is true (getDims asserts it).
83  ShapeStorageT dims;
84  Type elementType;
85  Attribute attr;
86  bool ranked{false};
87 };
88 
89 /// Adaptor class that abstracts over whether a shape comes from
90 /// a ShapedType, a ShapedTypeComponents, or a DenseIntElementsAttr.
91 class ShapeAdaptor {
92 public:
94  if (auto st = t.dyn_cast<ShapedType>())
95  val = st;
96  }
98  if (auto da = t.dyn_cast<DenseIntElementsAttr>())
99  val = da;
100  }
  /// Refers to `components` by pointer (no copy is made), so the components
  /// must outlive this adaptor.
101  ShapeAdaptor(ShapedTypeComponents *components) : val(components) {}
102  ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {}
103 
104  /// Returns whether the shape has a rank.
105  bool hasRank() const;
106 
107  /// Returns the element type.
108  Type getElementType() const;
109 
110  /// Populates `res` with the dimensions of the referenced shape.
111  /// Requires: shape is ranked.
112  void getDims(SmallVectorImpl<int64_t> &res) const;
113 
114  /// Populates the dimensions of the ShapedTypeComponents `res`.
115  /// Requires: shape is ranked.
116  void getDims(ShapedTypeComponents &res) const;
117 
118  /// Returns the size of the index'th dimension.
119  /// Requires: shape is ranked.
120  int64_t getDimSize(int index) const;
121 
122  /// Returns whether the index'th dimension is dynamic.
123  /// Requires: shape is ranked.
124  bool isDynamicDim(int index) const {
125  return ShapedType::isDynamic(getDimSize(index));
126  }
127 
128  /// Returns whether the shape is fully static.
129  bool hasStaticShape() const;
130 
131  /// Returns the rank of the shape.
132  /// Requires: shape is ranked.
133  int64_t getRank() const;
134 
135  /// Returns the number of elements in the shape.
136  /// Requires: hasStaticShape
137  int64_t getNumElements() const;
138 
139  /// Returns whether this adaptor refers to a valid (non-null) shape.
140  operator bool() const { return !val.isNull(); }
141 
142  /// Dumps a textual representation to stderr.
143  void dump() const;
144 
145 private:
146  // Union storing either a ShapedTypeComponents pointer, a ShapedType (stored
147  // as Type and cast back), or a DenseIntElementsAttr (stored as Attribute).
149 };
150 
151 /// Range of values and shapes (corresponding effectively to Shapes dialect's
152 /// ValueShape type concept).
153 // Currently this exposes the Value (of operands) and Type of the Value. This is
154 // not ideal as then one can accidentally reference an out-of-date shape. This
155 // is done both to enable a gradual switch and because OpAdaptor doesn't currently
156 // allow returning anything other than Value.
157 class ValueShapeRange : public ValueRange::RangeBaseT {
158 public:
160 
  /// Constructs a range over `values`. Both mapping functions are optional and
  /// default to nullptr (unset).
161  ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr,
162  ValueShapeMapFn valueToShape = nullptr)
163  : RangeBaseT(values), operandShape(operandShape),
164  valueToShape(valueToShape) {}
  /// Constructs a range from a braced list of Values, with no mappings set.
165  ValueShapeRange(const std::initializer_list<Value> &values)
166  : ValueShapeRange(ValueRange(values)) {}
167 
168  ValueShapeRange(const ValueShapeRange &) = default;
169 
170  /// Sets the Value to ShapeAdaptor mapping function and returns this.
172  valueToShape = fn;
173  return *this;
174  }
175 
177  operandShape = fn;
178  return *this;
179  }
180 
181  /// Returns the set Value to ShapeAdaptor mapping function.
182  ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; }
  /// Returns the set operand shape mapping function.
183  ValueShapeMapFn getOperandShapeMapping() const { return operandShape; }
184 
185  // Accessors.
186 
187  /// Returns the types of the values within this range.
188  /// Note: This returns only the types of Values in the ValueRange and not a
189  /// more refined type.
192  type_range getTypes() const { return {begin(), end()}; }
193  auto getType() const { return getTypes(); }
194 
195  /// Returns the Values in the ValueRange.
196  /// To query the most up-to-date shape of a Value, query the shape
197  /// using getShape below rather than using the type of the Value.
198  ValueRange getValues() const { return ValueRange(begin(), end()); };
199 
200  /// Returns the index'th argument as a shape. If the argument is not constant
201  /// or not a shape, then the function returns a null ShapeAdaptor.
202  /// This will first query the valueToShape mapping (if set), before querying
203  /// the ValueRange.
204  ShapeAdaptor getValueAsShape(int index);
205 
206  /// Returns the shape of the index'th operand.
207  // TODO: Update so that operator[] references these instead to avoid
208  // accidentally referring to a less refined shape.
209  ShapeAdaptor getShape(int index) const;
210 
211  /// Returns the shape of the given Value.
212  ShapeAdaptor getShape(Value val) const;
213 
214 private:
215  // Mapping from Value to ShapedTypeComponents corresponding to the shape of
216  // the Value's type.
217  ValueShapeMapFn operandShape;
218 
219  // Mapping from Value to ShapedTypeComponents corresponding to the constant
220  // Value when interpreted as a shape.
221  ValueShapeMapFn valueToShape;
222 };
223 
224 namespace detail {
225 // Helper function to infer return tensor types given an element and
226 // shape inference function.
227 //
228 // TODO: Consider generating typedefs for trait member functions if this usage
229 // becomes more common.
232  MLIRContext *, Optional<Location> location, ValueShapeRange operands,
233  DictionaryAttr attributes, RegionRange regions,
235  componentTypeFn,
236  MLIRContext *context, Optional<Location> location, ValueRange operands,
237  DictionaryAttr attributes, RegionRange regions,
238  SmallVectorImpl<Type> &inferredReturnTypes);
239 
240 /// Verifies that the inferred result types match the actual result types for
241 /// the op. Precondition: op implements InferTypeOpInterface.
243 } // namespace detail
244 
245 namespace OpTrait {
246 template <typename ConcreteType>
248 } // namespace OpTrait
249 } // namespace mlir
250 
251 /// Include the generated interface declarations.
252 #include "mlir/Interfaces/InferTypeOpInterface.h.inc"
253 
254 namespace mlir {
255 namespace OpTrait {
256 
257 /// Tensor type inference trait that constructs a tensor from the inferred
258 /// shape and elemental types.
259 /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
260 /// Less strict is possible (e.g., implements inferReturnTypeComponents and
261 /// it always populates all element types and shapes or fails), but this
262 /// trait is currently only used where the interfaces are, so keep it
263 /// restricted for now.
264 template <typename ConcreteType>
265 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
266 public:
  /// Infers the op's result types from the components produced by
  /// ConcreteType::inferReturnTypeComponents; results go to `inferredReturnTypes`.
267  static LogicalResult
269  ValueRange operands, DictionaryAttr attributes,
270  RegionRange regions,
271  SmallVectorImpl<Type> &inferredReturnTypes) {
  // Checked at compile time so that an op missing either interface fails to
  // build instead of misbehaving during inference.
272  static_assert(
273  ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
274  "requires InferShapedTypeOpInterface to ensure succesful invocation");
275  static_assert(
276  ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
277  "requires InferTypeOpInterface to ensure succesful invocation");
279  ConcreteType::inferReturnTypeComponents, context, location, operands,
280  attributes, regions, inferredReturnTypes);
281  }
282 };
283 
284 } // namespace OpTrait
285 } // namespace mlir
286 
287 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
Include the generated interface declarations.
Type getElementType() const
Return the element type component.
bool hasRank() const
Return whether the shape has a rank.
LogicalResult inferReturnTensorTypes(function_ref< LogicalResult(MLIRContext *, Optional< Location > location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< ShapedTypeComponents > &retComponents)> componentTypeFn, MLIRContext *context, Optional< Location > location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< Type > &inferredReturnTypes)
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
ValueShapeMapFn getValueToShapeMapping() const
Returns the set Value to ShapeAdaptor mapping function.
ShapedTypeComponents(ShapedType shapedType)
ShapedTypeComponents represents the components of a ShapedType.
ShapedTypeComponents(Arg &&arg, Type elementType=nullptr, Attribute attr=nullptr)
ShapedTypeComponents()
Default construction is an unranked shape.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept)...
static constexpr const bool value
ValueShapeRange & setValueToShapeMapping(ValueShapeMapFn fn)
Sets the Value to ShapeAdaptor mapping function and returns this.
This class implements iteration on the types of a given range of values.
Definition: TypeRange.h:126
static LogicalResult inferReturnTypes(MLIRContext *context, Optional< Location > location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< Type > &inferredReturnTypes)
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
ValueShapeRange(const std::initializer_list< Value > &values)
U dyn_cast() const
Definition: Types.h:244
ShapeAdaptor(ShapedTypeComponents *components)
Attributes are known-constant values of operations.
Definition: Attributes.h:24
ValueShapeRange & setOperandShapeMapping(ValueShapeMapFn fn)
ShapeAdaptor(ShapedTypeComponents &components)
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
This class implements iteration on the types of a given range of values.
Definition: Block.h:26
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
Helper class for implementing traits.
Definition: OpDefinition.h:291
Attribute getAttribute() const
Return the raw attribute component.
U dyn_cast() const
Definition: Attributes.h:117
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:638
ValueShapeMapFn getOperandShapeMapping() const
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:322
ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape=nullptr, ValueShapeMapFn valueToShape=nullptr)
type_range getTypes() const
ShapedTypeComponents(ArrayRef< int64_t > vec, Type elementType=nullptr, Attribute attr=nullptr)
Tensor type inference trait that constructs a tensor from the inferred shape and elemental types...
ValueRange getValues() const
Returns the Values in the ValueRange.
ShapedTypeComponents(Type elementType)
This class provides an abstraction over the different types of ranges over Values.
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op. ...
An attribute that represents a reference to a dense integer vector or tensor object.