MLIR  16.0.0git
InferTypeOpInterface.h
Go to the documentation of this file.
1 //===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
15 #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
16 
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/PointerUnion.h"
24 #include "llvm/ADT/SmallVector.h"
25 
26 namespace mlir {
27 
28 class ShapedTypeComponents;
30 
31 /// Adaptor class to abstract over whether the shape value comes from a
32 /// ShapedType, a ShapedTypeComponents, or a DenseIntElementsAttr.
33 class ShapeAdaptor {
34 public:
36  if (auto st = t.dyn_cast<ShapedType>())
37  val = st;
38  }
40  if (auto da = t.dyn_cast<DenseIntElementsAttr>())
41  val = da;
42  }
43  ShapeAdaptor(ShapedTypeComponents *components) : val(components) {}
44  ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {}
45 
46  /// Returns whether the shape has a rank.
47  bool hasRank() const;
48 
49  /// Returns the element type.
50  Type getElementType() const;
51 
52  /// Populates the dimensions from shape referenced.
53  /// Requires: shape is ranked.
54  void getDims(SmallVectorImpl<int64_t> &res) const;
55 
56  /// Populates the dimensions of the ShapedTypeComponents.
57  /// Requires: shape is ranked.
58  void getDims(ShapedTypeComponents &res) const;
59 
60  /// Returns the size of the index'th dimension.
61  /// Requires: shape is ranked.
62  int64_t getDimSize(int index) const;
63 
64  /// Returns whether the index'th dimension is dynamic.
65  /// Requires: shape is ranked.
66  bool isDynamicDim(int index) const {
67  return ShapedType::isDynamic(getDimSize(index));
68  }
69 
70  /// Returns whether the shape is fully static.
71  bool hasStaticShape() const;
72 
73  /// Returns the rank of the shape.
74  /// Requires: shape is ranked.
75  int64_t getRank() const;
76 
77  /// Returns the number of elements in the shape.
78  /// Requires: hasStaticShape
79  int64_t getNumElements() const;
80 
81  /// Returns whether valid (non-null) shape.
82  explicit operator bool() const { return !val.isNull(); }
83 
84  /// Dumps textual representation to stderr.
85  void dump() const;
86 
87 private:
88  // Union storing either ShapedTypeComponents, ShapedType (stored as Type and
89  // cast), or DenseIntElementsAttr (stored as Attribute).
91 };
92 
93 /// ShapedTypeComponents that represents the components of a ShapedType.
94 /// The components consist of
95 /// - A ranked or unranked shape with the dimension specification matching those
96 /// of ShapedType's getShape() (e.g., dynamic dimension represented using
97 /// ShapedType::kDynamicSize)
98 /// - An element type, may be unset (nullptr)
99 /// - An attribute, may be unset (nullptr)
100 /// Used by ShapedType type inferences.
102  /// Internal storage type for shape.
104 
105 public:
106  /// Default construction is an unranked shape.
107  ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
109  : elementType(elementType), attr(nullptr), ranked(false) {}
110  ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
111  ranked = shapedType.hasRank();
112  elementType = shapedType.getElementType();
113  if (ranked)
114  dims = llvm::to_vector<4>(shapedType.getShape());
115  }
116  ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) {
117  ranked = adaptor.hasRank();
118  elementType = adaptor.getElementType();
119  if (ranked)
120  adaptor.getDims(*this);
121  }
122  template <typename Arg, typename = typename std::enable_if_t<
124  ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
125  Attribute attr = nullptr)
126  : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
127  ranked(true) {}
128  ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
129  Attribute attr = nullptr)
130  : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
131  ranked(true) {}
132 
133  /// Return the dimensions of the shape.
134  /// Requires: shape is ranked.
136  assert(ranked && "requires ranked shape");
137  return dims;
138  }
139 
140  /// Return whether the shape has a rank.
141  bool hasRank() const { return ranked; };
142 
143  /// Return the element type component.
144  Type getElementType() const { return elementType; };
145 
146  /// Return the raw attribute component.
147  Attribute getAttribute() const { return attr; };
148 
149 private:
150  friend class ShapeAdaptor;
151 
152  ShapeStorageT dims;
153  Type elementType;
154  Attribute attr;
155  bool ranked{false};
156 };
157 
158 /// Range of values and shapes (corresponding effectively to the Shape dialect's
159 /// ValueShape type concept).
160 // Currently this exposes the Value (of operands) and Type of the Value. This is
161 // not ideal as then one can accidentally reference an out of date shape. This
162 // is done to both enable gradual switch and also as OpAdaptor doesn't currently
163 // allow returning anything other than Value.
164 class ValueShapeRange : public ValueRange::RangeBaseT {
165 public:
167 
168  ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr,
169  ValueShapeMapFn valueToShape = nullptr)
170  : RangeBaseT(values), operandShape(operandShape),
171  valueToShape(valueToShape) {}
172  ValueShapeRange(const std::initializer_list<Value> &values)
173  : ValueShapeRange(ValueRange(values)) {}
174 
175  ValueShapeRange(const ValueShapeRange &) = default;
176 
177  /// Sets the Value to ShapeAdaptor mapping function and returns this.
179  valueToShape = fn;
180  return *this;
181  }
182 
184  operandShape = fn;
185  return *this;
186  }
187 
188  /// Returns the set Value to ShapeAdaptor mapping function.
189  ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; }
190  ValueShapeMapFn getOperandShapeMapping() const { return operandShape; }
191 
192  // Accessors.
193 
194  /// Returns the types of the values within this range.
195  /// Note: This returns only the types of Values in the ValueRange and not a
196  /// more refined type.
199  type_range getTypes() const { return {begin(), end()}; }
200  auto getType() const { return getTypes(); }
201 
202  /// Returns the Values in the ValueRange.
203  /// To query the most up to date shape of a Value, query the shape
204  /// using getShape below rather than using the type of the Value.
205  ValueRange getValues() const { return ValueRange(begin(), end()); };
206 
207  /// Returns an argument as shape. If the argument is not constant or not a
208  /// shape, then the function returns a nullptr.
209  /// This will first query the valueToShape mapping (if set), before querying
210  /// the ValueRange.
211  ShapeAdaptor getValueAsShape(int index);
212 
213  /// Returns the shape of index'th operand.
214  // TODO: Update so that operator[] references these instead to avoid
215  // accidentally referring to a less refined shape.
216  ShapeAdaptor getShape(int index) const;
217 
218  /// Returns the shape of the given Value.
219  ShapeAdaptor getShape(Value val) const;
220 
221 private:
222  // Mapping from Value to ShapedTypeComponents corresponding to shape of type
223  // of Value.
224  ValueShapeMapFn operandShape;
225 
226  // Mapping from Value to ShapedTypeComponents corresponding to constant Value
227  // if interpreted as shape.
228  ValueShapeMapFn valueToShape;
229 };
230 
231 namespace detail {
232 // Helper function to infer the returned tensor types given an element type
233 // and shape inference function.
234 //
235 // TODO: Consider generating typedefs for trait member functions if this usage
236 // becomes more common.
239  MLIRContext *, Optional<Location> location, ValueShapeRange operands,
240  DictionaryAttr attributes, RegionRange regions,
242  componentTypeFn,
243  MLIRContext *context, Optional<Location> location, ValueRange operands,
244  DictionaryAttr attributes, RegionRange regions,
245  SmallVectorImpl<Type> &inferredReturnTypes);
246 
247 /// Verifies that the inferred result types match the actual result types for
248 /// the op. Precondition: op implements InferTypeOpInterface.
250 } // namespace detail
251 
252 namespace OpTrait {
253 template <typename ConcreteType>
255 } // namespace OpTrait
256 } // namespace mlir
257 
258 /// Include the generated interface declarations.
259 #include "mlir/Interfaces/InferTypeOpInterface.h.inc"
260 
261 namespace mlir {
262 namespace OpTrait {
263 
264 /// Tensor type inference trait that constructs a tensor from the inferred
265 /// shape and elemental types.
266 /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
267 /// Less strict is possible (e.g., implements inferReturnTypeComponents and
268 /// these always populate all element types and shapes or fail, but this
269 /// trait is currently only used where the interfaces are, so keep it
270 /// restricted for now).
271 template <typename ConcreteType>
272 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
273 public:
274  static LogicalResult
276  ValueRange operands, DictionaryAttr attributes,
277  RegionRange regions,
278  SmallVectorImpl<Type> &inferredReturnTypes) {
279  static_assert(
280  ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
281  "requires InferShapedTypeOpInterface to ensure succesful invocation");
282  static_assert(
283  ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
284  "requires InferTypeOpInterface to ensure succesful invocation");
286  ConcreteType::inferReturnTypeComponents, context, location, operands,
287  attributes, regions, inferredReturnTypes);
288  }
289 };
290 
291 } // namespace OpTrait
292 } // namespace mlir
293 
294 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
Include the generated interface declarations.
Type getElementType() const
Return the element type component.
bool hasRank() const
Return whether the shape has a rank.
LogicalResult inferReturnTensorTypes(function_ref< LogicalResult(MLIRContext *, Optional< Location > location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< ShapedTypeComponents > &retComponents)> componentTypeFn, MLIRContext *context, Optional< Location > location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< Type > &inferredReturnTypes)
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
ValueShapeMapFn getValueToShapeMapping() const
Returns the set Value to ShapeAdaptor mapping function.
ShapedTypeComponents(ShapedType shapedType)
ShapedTypeComponents(ShapeAdaptor adaptor)
ShapedTypeComponents that represents the components of a ShapedType.
ShapedTypeComponents(Arg &&arg, Type elementType=nullptr, Attribute attr=nullptr)
ShapedTypeComponents()
Default construction is an unranked shape.
int64_t getNumElements() const
Returns the number of elements in the shape.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept)...
static constexpr const bool value
ValueShapeRange & setValueToShapeMapping(ValueShapeMapFn fn)
Sets the Value to ShapeAdaptor mapping function and returns this.
This class implements iteration on the types of a given range of values.
Definition: TypeRange.h:117
static LogicalResult inferReturnTypes(MLIRContext *context, Optional< Location > location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< Type > &inferredReturnTypes)
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
ValueShapeRange(const std::initializer_list< Value > &values)
U dyn_cast() const
Definition: Types.h:269
ShapeAdaptor(ShapedTypeComponents *components)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
ValueShapeRange & setOperandShapeMapping(ValueShapeMapFn fn)
bool hasRank() const
Returns whether the shape has a rank.
Type getElementType() const
Returns the element type.
ShapeAdaptor(ShapedTypeComponents &components)
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
This class implements iteration on the types of a given range of values.
Definition: Block.h:26
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
Helper class for implementing traits.
Definition: OpDefinition.h:316
void dump() const
Dumps textual representation to stderr.
U dyn_cast() const
Definition: Attributes.h:127
ValueShapeMapFn getOperandShapeMapping() const
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:327
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape=nullptr, ValueShapeMapFn valueToShape=nullptr)
type_range getTypes() const
ShapedTypeComponents(ArrayRef< int64_t > vec, Type elementType=nullptr, Attribute attr=nullptr)
Tensor type inference trait that constructs a tensor from the inferred shape and elemental types...
ValueRange getValues() const
Returns the Values in the ValueRange.
ShapedTypeComponents(Type elementType)
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op. ...
An attribute that represents a reference to a dense integer vector or tensor object.
int64_t getRank() const
Returns the rank of the shape.