MLIR  17.0.0git
BuiltinTypes.h
Go to the documentation of this file.
1 //===- BuiltinTypes.h - MLIR Builtin Type Classes ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_BUILTINTYPES_H
10 #define MLIR_IR_BUILTINTYPES_H
11 
14 
15 namespace llvm {
16 class BitVector;
17 struct fltSemantics;
18 } // namespace llvm
19 
20 //===----------------------------------------------------------------------===//
21 // Tablegen Interface Declarations
22 //===----------------------------------------------------------------------===//
23 
24 namespace mlir {
25 class AffineExpr;
26 class AffineMap;
27 class FloatType;
28 class IndexType;
29 class IntegerType;
30 class StringAttr;
31 class TypeRange;
32 
33 //===----------------------------------------------------------------------===//
34 // FloatType
35 //===----------------------------------------------------------------------===//
36 
37 class FloatType : public Type {
38 public:
39  using Type::Type;
40 
41  // Convenience factories.
42  static FloatType getBF16(MLIRContext *ctx);
43  static FloatType getF16(MLIRContext *ctx);
44  static FloatType getF32(MLIRContext *ctx);
45  static FloatType getF64(MLIRContext *ctx);
46  static FloatType getF80(MLIRContext *ctx);
47  static FloatType getF128(MLIRContext *ctx);
48  static FloatType getFloat8E5M2(MLIRContext *ctx);
50 
51  /// Methods for support type inquiry through isa, cast, and dyn_cast.
52  static bool classof(Type type);
53 
54  /// Return the bitwidth of this float type.
55  unsigned getWidth();
56 
57  /// Return the width of the mantissa of this type.
58  unsigned getFPMantissaWidth();
59 
60  /// Get or create a new FloatType with bitwidth scaled by `scale`.
61  /// Return null if the scaled element type cannot be represented.
62  FloatType scaleElementBitwidth(unsigned scale);
63 
64  /// Return the floating semantics of this float type.
65  const llvm::fltSemantics &getFloatSemantics();
66 };
67 
68 //===----------------------------------------------------------------------===//
69 // TensorType
70 //===----------------------------------------------------------------------===//
71 
72 /// Tensor types represent multi-dimensional arrays, and have two variants:
73 /// RankedTensorType and UnrankedTensorType.
74 /// Note: This class attaches the ShapedType trait to act as a mixin to
75 /// provide many useful utility functions. This inheritance has no effect
76 /// on derived tensor types.
77 class TensorType : public Type, public ShapedType::Trait<TensorType> {
78 public:
79  using Type::Type;
80 
81  /// Returns the element type of this tensor type.
82  Type getElementType() const;
83 
84  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
85  bool hasRank() const;
86 
87  /// Returns the shape of this tensor type.
89 
90  /// Clone this type with the given shape and element type. If the
91  /// provided shape is `None`, the current shape of the type is used.
92  TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
93  Type elementType) const;
94 
95  /// Return true if the specified element type is ok in a tensor.
96  static bool isValidElementType(Type type);
97 
98  /// Methods for support type inquiry through isa, cast, and dyn_cast.
99  static bool classof(Type type);
100 
101  /// Allow implicit conversion to ShapedType.
102  operator ShapedType() const { return cast<ShapedType>(); }
103 };
104 
105 //===----------------------------------------------------------------------===//
106 // BaseMemRefType
107 //===----------------------------------------------------------------------===//
108 
109 /// This class provides a shared interface for ranked and unranked memref types.
110 /// Note: This class attaches the ShapedType trait to act as a mixin to
111 /// provide many useful utility functions. This inheritance has no effect
112 /// on derived memref types.
113 class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
114 public:
115  using Type::Type;
116 
117  /// Returns the element type of this memref type.
118  Type getElementType() const;
119 
120  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
121  bool hasRank() const;
122 
123  /// Returns the shape of this memref type.
124  ArrayRef<int64_t> getShape() const;
125 
126  /// Clone this type with the given shape and element type. If the
127  /// provided shape is `None`, the current shape of the type is used.
128  BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
129  Type elementType) const;
130 
131  /// Return true if the specified element type is ok in a memref.
132  static bool isValidElementType(Type type);
133 
134  /// Methods for support type inquiry through isa, cast, and dyn_cast.
135  static bool classof(Type type);
136 
137  /// Returns the memory space in which data referred to by this memref resides.
138  Attribute getMemorySpace() const;
139 
140  /// [deprecated] Returns the memory space in old raw integer representation.
141  /// New `Attribute getMemorySpace()` method should be used instead.
142  unsigned getMemorySpaceAsInt() const;
143 
144  /// Allow implicit conversion to ShapedType.
145  operator ShapedType() const { return cast<ShapedType>(); }
146 };
147 
148 } // namespace mlir
149 
150 //===----------------------------------------------------------------------===//
151 // Tablegen Type Declarations
152 //===----------------------------------------------------------------------===//
153 
154 #define GET_TYPEDEF_CLASSES
155 #include "mlir/IR/BuiltinTypes.h.inc"
156 
157 namespace mlir {
158 
159 //===----------------------------------------------------------------------===//
160 // MemRefType
161 //===----------------------------------------------------------------------===//
162 
163 /// This is a builder type that keeps local references to arguments. Arguments
164 /// that are passed into the builder must outlive the builder.
166 public:
167  // Build from another MemRefType.
168  explicit Builder(MemRefType other)
169  : shape(other.getShape()), elementType(other.getElementType()),
170  layout(other.getLayout()), memorySpace(other.getMemorySpace()) {}
171 
172  // Build from scratch.
173  Builder(ArrayRef<int64_t> shape, Type elementType)
174  : shape(shape), elementType(elementType) {}
175 
177  shape = newShape;
178  return *this;
179  }
180 
181  Builder &setElementType(Type newElementType) {
182  elementType = newElementType;
183  return *this;
184  }
185 
186  Builder &setLayout(MemRefLayoutAttrInterface newLayout) {
187  layout = newLayout;
188  return *this;
189  }
190 
191  Builder &setMemorySpace(Attribute newMemorySpace) {
192  memorySpace = newMemorySpace;
193  return *this;
194  }
195 
196  operator MemRefType() {
197  return MemRefType::get(shape, elementType, layout, memorySpace);
198  }
199 
200 private:
201  ArrayRef<int64_t> shape;
202  Type elementType;
203  MemRefLayoutAttrInterface layout;
204  Attribute memorySpace;
205 };
206 
207 //===----------------------------------------------------------------------===//
208 // RankedTensorType
209 //===----------------------------------------------------------------------===//
210 
211 /// This is a builder type that keeps local references to arguments. Arguments
212 /// that are passed into the builder must outlive the builder.
214 public:
215  /// Build from another RankedTensorType.
216  explicit Builder(RankedTensorType other)
217  : shape(other.getShape()), elementType(other.getElementType()),
218  encoding(other.getEncoding()) {}
219 
220  /// Build from scratch.
221  Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding)
222  : shape(shape), elementType(elementType), encoding(encoding) {}
223 
225  shape = newShape;
226  return *this;
227  }
228 
229  Builder &setElementType(Type newElementType) {
230  elementType = newElementType;
231  return *this;
232  }
233 
234  Builder &setEncoding(Attribute newEncoding) {
235  encoding = newEncoding;
236  return *this;
237  }
238 
239  /// Erase a dim from shape @pos.
240  Builder &dropDim(unsigned pos) {
241  assert(pos < shape.size() && "overflow");
242  if (storage.empty())
243  storage.append(shape.begin(), shape.end());
244  storage.erase(storage.begin() + pos);
245  shape = {storage.data(), storage.size()};
246  return *this;
247  }
248 
249  /// Insert a val into shape @pos.
250  Builder &insertDim(int64_t val, unsigned pos) {
251  assert(pos <= shape.size() && "overflow");
252  if (storage.empty())
253  storage.append(shape.begin(), shape.end());
254  storage.insert(storage.begin() + pos, val);
255  shape = {storage.data(), storage.size()};
256  return *this;
257  }
258 
259  operator RankedTensorType() {
260  return RankedTensorType::get(shape, elementType, encoding);
261  }
262 
263 private:
264  ArrayRef<int64_t> shape;
265  // Owning shape data for copy-on-write operations.
266  SmallVector<int64_t> storage;
267  Type elementType;
268  Attribute encoding;
269 };
270 
271 //===----------------------------------------------------------------------===//
272 // VectorType
273 //===----------------------------------------------------------------------===//
274 
275 /// This is a builder type that keeps local references to arguments. Arguments
276 /// that are passed into the builder must outlive the builder.
278 public:
279  /// Build from another VectorType.
280  explicit Builder(VectorType other)
281  : shape(other.getShape()), elementType(other.getElementType()),
282  numScalableDims(other.getNumScalableDims()) {}
283 
284  /// Build from scratch.
285  Builder(ArrayRef<int64_t> shape, Type elementType,
286  unsigned numScalableDims = 0)
287  : shape(shape), elementType(elementType),
288  numScalableDims(numScalableDims) {}
289 
291  unsigned newNumScalableDims = 0) {
292  numScalableDims = newNumScalableDims;
293  shape = newShape;
294  return *this;
295  }
296 
297  Builder &setElementType(Type newElementType) {
298  elementType = newElementType;
299  return *this;
300  }
301 
302  /// Erase a dim from shape @pos.
303  Builder &dropDim(unsigned pos) {
304  assert(pos < shape.size() && "overflow");
305  if (pos >= shape.size() - numScalableDims)
306  numScalableDims--;
307  if (storage.empty())
308  storage.append(shape.begin(), shape.end());
309  storage.erase(storage.begin() + pos);
310  shape = {storage.data(), storage.size()};
311  return *this;
312  }
313 
314  /// In the particular case where the vector has a single dimension that we
315  /// drop, return the scalar element type.
316  // TODO: unify once we have a VectorType that supports 0-D.
317  operator Type() {
318  if (shape.empty())
319  return elementType;
320  return VectorType::get(shape, elementType, numScalableDims);
321  }
322 
323 private:
324  ArrayRef<int64_t> shape;
325  // Owning shape data for copy-on-write operations.
326  SmallVector<int64_t> storage;
327  Type elementType;
328  unsigned numScalableDims;
329 };
330 
331 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
332 /// `originalShape` with some `1` entries erased, return the set of indices
333 /// that specifies which of the entries of `originalShape` are dropped to obtain
334 /// `reducedShape`. The returned mask can be applied as a projection to
335 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
336 /// which dimensions must be kept when e.g. compute MemRef strides under
337 /// rank-reducing operations. Return std::nullopt if reducedShape cannot be
338 /// obtained by dropping only `1` entries in `originalShape`.
339 std::optional<llvm::SmallDenseSet<unsigned>>
341  ArrayRef<int64_t> reducedShape);
342 
343 /// Enum that captures information related to verifier error conditions on
344 /// slice insert/extract type of ops.
// NOTE(review): listing lines 345, 349, 351 and 352 are missing from this
// extraction -- the opening line of the enum (presumably
// `enum class SliceVerificationResult {`, given the scoped
// `SliceVerificationResult::Success` used below) and three further
// enumerators. They cannot be reconstructed from this page; restore them
// from the original BuiltinTypes.h.
346  Success,
347  RankTooLarge,
348  SizeMismatch,
350  // Error codes to ops with a memory space and a layout annotation.
353 };
354 
355 /// Check if `originalType` can be rank reduced to `candidateReducedType` type
356 /// by dropping some dimensions with static size `1`.
357 /// Return `SliceVerificationResult::Success` on success or an appropriate error
358 /// code.
359 SliceVerificationResult isRankReducedType(ShapedType originalType,
360  ShapedType candidateReducedType);
361 
362 //===----------------------------------------------------------------------===//
363 // Deferred Method Definitions
364 //===----------------------------------------------------------------------===//
365 
366 inline bool BaseMemRefType::classof(Type type) {
367  return type.isa<MemRefType, UnrankedMemRefType>();
368 }
369 
371  return type.isIntOrIndexOrFloat() ||
372  type.isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>() ||
373  type.isa<MemRefElementTypeInterface>();
374 }
375 
376 inline bool FloatType::classof(Type type) {
377  return type.isa<Float8E5M2Type, Float8E4M3FNType, BFloat16Type, Float16Type,
378  Float32Type, Float64Type, Float80Type, Float128Type>();
379 }
380 
382  return Float8E5M2Type::get(ctx);
383 }
384 
386  return Float8E4M3FNType::get(ctx);
387 }
388 
390  return BFloat16Type::get(ctx);
391 }
392 
394  return Float16Type::get(ctx);
395 }
396 
398  return Float32Type::get(ctx);
399 }
400 
402  return Float64Type::get(ctx);
403 }
404 
406  return Float80Type::get(ctx);
407 }
408 
410  return Float128Type::get(ctx);
411 }
412 
413 inline bool TensorType::classof(Type type) {
414  return type.isa<RankedTensorType, UnrankedTensorType>();
415 }
416 
417 //===----------------------------------------------------------------------===//
418 // Type Utilities
419 //===----------------------------------------------------------------------===//
420 
421 /// Returns the strides of the MemRef if the layout map is in strided form.
422 /// MemRefs with a layout map in strided form include:
423 /// 1. empty or identity layout map, in which case the stride information is
424 /// the canonical form computed from sizes;
425 /// 2. a StridedLayoutAttr layout;
426 /// 3. any other layout that be converted into a single affine map layout of
427 /// the form `K + k0 * d0 + ... kn * dn`, where K and ki's are constants or
428 /// symbols.
429 ///
430 /// A stride specification is a list of integer values that are either static
431 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode
432 /// the distance in the number of elements between successive entries along a
433 /// particular dimension.
435  SmallVectorImpl<int64_t> &strides,
436  int64_t &offset);
437 
438 /// Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t>,
439 /// int64_t) that will assert if the logical result is not succeeded.
440 std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset(MemRefType t);
441 
442 /// Return a version of `t` with identity layout if it can be determined
443 /// statically that the layout is the canonical contiguous strided layout.
444 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
445 /// `t` with simplified layout.
446 MemRefType canonicalizeStridedLayout(MemRefType t);
447 
448 /// Given MemRef `sizes` that are either static or dynamic, returns the
449 /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
450 /// once a dynamic dimension is encountered, all canonical strides become
451 /// dynamic and need to be encoded with a different symbol.
452 /// For canonical strides expressions, the offset is always 0 and and fastest
453 /// varying stride is always `1`.
454 ///
455 /// Examples:
456 /// - memref<3x4x5xf32> has canonical stride expression
457 /// `20*exprs[0] + 5*exprs[1] + exprs[2]`.
458 /// - memref<3x?x5xf32> has canonical stride expression
459 /// `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
460 /// - memref<3x4x?xf32> has canonical stride expression
461 /// `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
463  ArrayRef<AffineExpr> exprs,
464  MLIRContext *context);
465 
466 /// Return the result of makeCanonicalStrudedLayoutExpr for the common case
467 /// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
469  MLIRContext *context);
470 
471 /// Return true if the layout for `t` is compatible with strided semantics.
472 bool isStrided(MemRefType t);
473 
474 } // namespace mlir
475 
476 #endif // MLIR_IR_BUILTINTYPES_H
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:698
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:68
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:113
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Definition: BuiltinTypes.h:370
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
static bool classof(Type type)
Methods for support type inquiry through isa, cast, and dyn_cast.
Definition: BuiltinTypes.h:366
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static FloatType getF64(MLIRContext *ctx)
Definition: BuiltinTypes.h:401
FloatType scaleElementBitwidth(unsigned scale)
Get or create a new FloatType with bitwidth scaled by scale.
static FloatType getFloat8E5M2(MLIRContext *ctx)
Definition: BuiltinTypes.h:381
static FloatType getF80(MLIRContext *ctx)
Definition: BuiltinTypes.h:405
static FloatType getFloat8E4M3FN(MLIRContext *ctx)
Definition: BuiltinTypes.h:385
static FloatType getF16(MLIRContext *ctx)
Definition: BuiltinTypes.h:393
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
static FloatType getBF16(MLIRContext *ctx)
Definition: BuiltinTypes.h:389
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
static FloatType getF128(MLIRContext *ctx)
Definition: BuiltinTypes.h:409
unsigned getWidth()
Return the bitwidth of this float type.
static bool classof(Type type)
Methods for support type inquiry through isa, cast, and dyn_cast.
Definition: BuiltinTypes.h:376
static FloatType getF32(MLIRContext *ctx)
Definition: BuiltinTypes.h:397
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:165
Builder(ArrayRef< int64_t > shape, Type elementType)
Definition: BuiltinTypes.h:173
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:186
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:181
Builder(MemRefType other)
Definition: BuiltinTypes.h:168
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:176
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:191
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:213
Builder(ArrayRef< int64_t > shape, Type elementType, Attribute encoding)
Build from scratch.
Definition: BuiltinTypes.h:221
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:240
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:224
Builder & insertDim(int64_t val, unsigned pos)
Insert a val into shape @pos.
Definition: BuiltinTypes.h:250
Builder(RankedTensorType other)
Build from another RankedTensorType.
Definition: BuiltinTypes.h:216
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:229
Builder & setEncoding(Attribute newEncoding)
Definition: BuiltinTypes.h:234
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:77
static bool classof(Type type)
Methods for support type inquiry through isa, cast, and dyn_cast.
Definition: BuiltinTypes.h:413
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
constexpr Type()=default
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:107
bool isa() const
Definition: Types.h:301
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:277
Builder(ArrayRef< int64_t > shape, Type elementType, unsigned numScalableDims=0)
Build from scratch.
Definition: BuiltinTypes.h:285
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:303
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:297
Builder(VectorType other)
Build from another VectorType.
Definition: BuiltinTypes.h:280
Builder & setShape(ArrayRef< int64_t > newShape, unsigned newNumScalableDims=0)
Definition: BuiltinTypes.h:290
Include the generated interface declarations.
Definition: CallGraph.h:229
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:345
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
bool isStrided(MemRefType t)
Return true if the layout for t is compatible with strided semantics.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26