MLIR  17.0.0git
InferTypeOpInterface.h
Go to the documentation of this file.
1 //===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
15 #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
16 
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/PointerUnion.h"
24 #include "llvm/ADT/SmallVector.h"
25 
26 namespace mlir {
27 
28 class ShapedTypeComponents;
30 
31 /// Reify the shape of the result of an operation (typically in terms of the
32 /// shape of its operands).
35  ReifiedRankedShapedTypeDims &reifiedReturnShapes);
36 
37 /// Adaptor class to abstract the differences between whether a value is from
38 /// a ShapedType, a ShapedTypeComponents, or a DenseIntElementsAttr.
39 class ShapeAdaptor {
40 public:
42  if (auto st = dyn_cast<ShapedType>(t))
43  val = st;
44  }
46  if (auto da = dyn_cast<DenseIntElementsAttr>(t))
47  val = da;
48  }
  /// These wrap `components` by pointer rather than by copy, so the referenced
  /// ShapedTypeComponents must stay alive while this adaptor is used.
49  ShapeAdaptor(ShapedTypeComponents *components) : val(components) {}
50  ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {}
51 
52  /// Returns whether the shape has a rank.
53  bool hasRank() const;
54 
55  /// Returns the element type.
56  Type getElementType() const;
57 
58  /// Populates `res` with the dimensions of the referenced shape.
59  /// Requires: shape is ranked.
60  void getDims(SmallVectorImpl<int64_t> &res) const;
61 
62  /// Populates the dimensions of the ShapedTypeComponents `res`.
63  /// Requires: shape is ranked.
64  void getDims(ShapedTypeComponents &res) const;
65 
66  /// Returns the size of the index'th dimension.
67  /// Requires: shape is ranked.
68  int64_t getDimSize(int index) const;
69 
70  /// Returns whether the index'th dimension is dynamic.
71  /// Requires: shape is ranked.
72  bool isDynamicDim(int index) const {
73  return ShapedType::isDynamic(getDimSize(index));
74  }
75 
76  /// Returns whether the shape is fully static.
77  bool hasStaticShape() const;
78 
79  /// Returns the rank of the shape.
80  /// Requires: shape is ranked.
81  int64_t getRank() const;
82 
83  /// Returns the number of elements in the shape.
84  /// Requires: hasStaticShape
85  int64_t getNumElements() const;
86 
87  /// Returns whether this adaptor holds a valid (non-null) shape.
88  explicit operator bool() const { return !val.isNull(); }
89 
90  /// Dumps a textual representation to stderr.
91  void dump() const;
92 
93 private:
94  // Union storing either ShapedTypeComponents, ShapedType (stored as Type and
95  // cast), or DenseIntElementsAttr (stored as Attribute).
97 };
98 
99 /// ShapedTypeComponents represents the components of a ShapedType.
100 /// The components consist of
101 /// - A ranked or unranked shape whose dimension specification matches that
102 /// of ShapedType's getShape() (e.g., a dynamic dimension is represented using
103 /// ShapedType::kDynamic)
104 /// - An element type, which may be unset (nullptr)
105 /// - An attribute, which may be unset (nullptr)
106 /// Used by ShapedType type inference.
108  /// Internal storage type for shape.
110 
111 public:
112  /// Default construction is an unranked shape with a null element type and
  /// a null attribute.
113  ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
115  : elementType(elementType), attr(nullptr), ranked(false) {}
  /// Copies the rank, the element type and (when ranked) the dimensions of
  /// `shapedType`; the attribute is left null.
116  ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
117  ranked = shapedType.hasRank();
118  elementType = shapedType.getElementType();
119  if (ranked)
120  dims = llvm::to_vector<4>(shapedType.getShape());
121  }
  /// Copies the rank and element type from `adaptor`; when ranked, the
  /// dimensions are filled in via ShapeAdaptor::getDims. The attribute is left
  /// null.
122  ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) {
123  ranked = adaptor.hasRank();
124  elementType = adaptor.getElementType();
125  if (ranked)
126  adaptor.getDims(*this);
127  }
  /// Ranked shape whose dimensions are constructed from `arg` (anything the
  /// shape storage type is constructible from), with optional element type and
  /// attribute.
128  template <typename Arg, typename = std::enable_if_t<
129  std::is_constructible<ShapeStorageT, Arg>::value>>
130  ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
131  Attribute attr = nullptr)
132  : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
133  ranked(true) {}
  /// Ranked shape whose dimensions are copied from `vec`, with optional element
  /// type and attribute.
134  ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
135  Attribute attr = nullptr)
136  : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
137  ranked(true) {}
138 
139  /// Return the dimensions of the shape.
140  /// Requires: shape is ranked.
142  assert(ranked && "requires ranked shape");
143  return dims;
144  }
145 
146  /// Return whether the shape has a rank.
147  bool hasRank() const { return ranked; };
148 
149  /// Return the element type component.
150  Type getElementType() const { return elementType; };
151 
152  /// Return the raw attribute component.
153  Attribute getAttribute() const { return attr; };
154 
155 private:
156  friend class ShapeAdaptor;
157 
158  ShapeStorageT dims;
159  Type elementType;
160  Attribute attr;
161  bool ranked{false};
162 };
163 
164 /// Range of values and shapes (corresponding effectively to Shapes dialect's
165 /// ValueShape type concept).
166 // Currently this exposes the Value (of operands) and Type of the Value. This is
167 // not ideal as then one can accidentally reference an out-of-date shape. This
168 // is done both to enable a gradual switch and because OpAdaptor doesn't
169 // currently allow returning anything other than Value.
170 class ValueShapeRange : public ValueRange::RangeBaseT {
171 public:
173 
  /// Constructs a range over `values`. Both shape mapping functions default to
  /// null (unset).
174  ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr,
175  ValueShapeMapFn valueToShape = nullptr)
176  : RangeBaseT(values), operandShape(operandShape),
177  valueToShape(valueToShape) {}
178  ValueShapeRange(const std::initializer_list<Value> &values)
179  : ValueShapeRange(ValueRange(values)) {}
180 
181  ValueShapeRange(const ValueShapeRange &) = default;
182 
183  /// Sets the Value to ShapeAdaptor mapping function and returns this.
185  valueToShape = fn;
186  return *this;
187  }
188 
190  operandShape = fn;
191  return *this;
192  }
193 
194  /// Returns the set Value to ShapeAdaptor mapping function.
195  ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; }
  /// Returns the set operand-shape mapping function.
196  ValueShapeMapFn getOperandShapeMapping() const { return operandShape; }
197 
198  // Accessors.
199 
200  /// Returns the types of the values within this range.
201  /// Note: This returns only the types of Values in the ValueRange and not a
202  /// more refined type.
205  type_range getTypes() const { return {begin(), end()}; }
  /// Equivalent to getTypes().
206  auto getType() const { return getTypes(); }
207 
208  /// Returns the Values in the ValueRange.
209  /// To query the most up to date shape of a Value, query the shape
210  /// using getShape below rather than using the type of the Value.
211  ValueRange getValues() const { return ValueRange(begin(), end()); };
212 
213  /// Returns an argument as shape. If the argument is not constant or not a
214  /// shape, then the function returns a null ShapeAdaptor.
215  /// This will first query the valueToShape mapping (if set), before querying
216  /// the ValueRange.
217  ShapeAdaptor getValueAsShape(int index);
218 
219  /// Returns the shape of index'th operand.
220  // TODO: Update so that operator[] references these instead to avoid
221  // accidentally referring to a less refined shape.
222  ShapeAdaptor getShape(int index) const;
223 
224  /// Returns the shape of the given Value.
225  ShapeAdaptor getShape(Value val) const;
226 
227 private:
  // NOTE(review): ValueShapeMapFn is a function_ref, which presumably does not
  // own its callable -- confirm callers keep the callable alive for as long as
  // this range (or a copy of it) is used.
228  // Mapping from Value to ShapedTypeComponents corresponding to shape of type
229  // of Value.
230  ValueShapeMapFn operandShape;
231 
232  // Mapping from Value to ShapedTypeComponents corresponding to constant Value
233  // if interpreted as shape.
234  ValueShapeMapFn valueToShape;
235 };
236 
237 namespace detail {
238 // Helper function to infer return tensor types from the ShapedTypeComponents
239 // produced by the element and shape inference function `componentTypeFn`.
240 // NOTE(review): results presumably land in `inferredReturnTypes`; confirm.
241 // TODO: Consider generating typedefs for trait member functions if this usage
242 // becomes more common.
243 LogicalResult inferReturnTensorTypes(
244  function_ref<
245  LogicalResult(MLIRContext *, std::optional<Location> location,
246  ValueShapeRange operands, DictionaryAttr attributes,
247  OpaqueProperties properties, RegionRange regions,
248  SmallVectorImpl<ShapedTypeComponents> &retComponents)>
249  componentTypeFn,
250  MLIRContext *context, std::optional<Location> location, ValueRange operands,
251  DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
252  SmallVectorImpl<Type> &inferredReturnTypes);
253 
254 /// Verifies that the inferred result types match the actual result types for
255 /// the op. Precondition: op implements InferTypeOpInterface.
256 LogicalResult verifyInferredResultTypes(Operation *op);
257 } // namespace detail
258 
259 namespace OpTrait {
// Forward declaration only; the full definition of InferTensorType appears
// after the generated interface declarations are included below.
260 template <typename ConcreteType>
261 class InferTensorType;
262 } // namespace OpTrait
263 } // namespace mlir
264 
265 /// Include the generated interface declarations.
266 #include "mlir/Interfaces/InferTypeOpInterface.h.inc"
267 
268 namespace mlir {
269 namespace OpTrait {
270 
271 /// Tensor type inference trait that constructs a tensor from the inferred
272 /// shape and elemental types.
273 /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
274 /// Less strict is possible (e.g., implements inferReturnTypeComponents and
275 /// these always populate all element types and shapes or fail), but this
276 /// trait is currently only used where the interfaces are, so keep it
277 /// restricted for now.
278 template <typename ConcreteType>
279 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
280 public:
  /// Infers the result types using ConcreteType::inferReturnTypeComponents.
  /// The static_asserts reject, at compile time, any op that lacks either of
  /// the two required interfaces.
281  static LogicalResult
282  inferReturnTypes(MLIRContext *context, std::optional<Location> location,
283  ValueRange operands, DictionaryAttr attributes,
284  OpaqueProperties properties, RegionRange regions,
285  SmallVectorImpl<Type> &inferredReturnTypes) {
286  static_assert(
287  ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
288  "requires InferShapedTypeOpInterface to ensure succesful invocation");
289  static_assert(
290  ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
291  "requires InferTypeOpInterface to ensure succesful invocation");
293  ConcreteType::inferReturnTypeComponents, context, location, operands,
294  attributes, properties, regions, inferredReturnTypes);
295  }
296 };
297 
298 } // namespace OpTrait
299 } // namespace mlir
300 
301 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:202
Tensor type inference trait that constructs a tensor from the inferred shape and elemental types.
static LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl< Type > &inferredReturnTypes)
Helper class for implementing traits.
Definition: OpDefinition.h:362
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:334
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
void dump() const
Dumps textual representation to stderr.
Type getElementType() const
Returns the element type.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
ShapeAdaptor(ShapedTypeComponents &components)
ShapeAdaptor(ShapedTypeComponents *components)
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
ShapedTypeComponents()
Default construction is an unranked shape.
ShapedTypeComponents(Arg &&arg, Type elementType=nullptr, Attribute attr=nullptr)
ShapedTypeComponents(ShapedType shapedType)
ShapedTypeComponents(Type elementType)
ShapedTypeComponents(ArrayRef< int64_t > vec, Type elementType=nullptr, Attribute attr=nullptr)
ShapedTypeComponents(ShapeAdaptor adaptor)
bool hasRank() const
Return whether the shape has a rank.
Type getElementType() const
Return the element type component.
Attribute getAttribute() const
Return the raw attribute component.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ValueShapeMapFn getOperandShapeMapping() const
ValueShapeMapFn getValueToShapeMapping() const
Returns the set Value to ShapeAdaptor mapping function.
ValueShapeRange & setValueToShapeMapping(ValueShapeMapFn fn)
Sets the Value to ShapeAdaptor mapping function and returns this.
ValueShapeRange(const ValueShapeRange &)=default
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
type_range getTypes() const
ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape=nullptr, ValueShapeMapFn valueToShape=nullptr)
ValueShapeRange(const std::initializer_list< Value > &values)
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
function_ref< ShapeAdaptor(Value)> ValueShapeMapFn
ValueShapeRange & setOperandShapeMapping(ValueShapeMapFn fn)
ValueRange getValues() const
Returns the Values in the ValueRange.
This class implements iteration on the types of a given range of values.
Definition: TypeRange.h:118
This class implements iteration on the types of a given range of values.
Definition: TypeRange.h:131
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op.
LogicalResult inferReturnTensorTypes(function_ref< LogicalResult(MLIRContext *, std::optional< Location > location, ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl< ShapedTypeComponents > &retComponents)> componentTypeFn, MLIRContext *context, std::optional< Location > location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl< Type > &inferredReturnTypes)
This header declares functions that assist transformations in the MemRef dialect.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:147
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26