MLIR  21.0.0git
BuiltinTypes.cpp
Go to the documentation of this file.
1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/TensorEncoding.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 using namespace mlir;
24 using namespace mlir::detail;
25 
26 //===----------------------------------------------------------------------===//
27 /// Tablegen Type Definitions
28 //===----------------------------------------------------------------------===//
29 
30 #define GET_TYPEDEF_CLASSES
31 #include "mlir/IR/BuiltinTypes.cpp.inc"
32 
33 namespace mlir {
34 #include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
35 } // namespace mlir
36 
37 //===----------------------------------------------------------------------===//
38 // BuiltinDialect
39 //===----------------------------------------------------------------------===//
40 
/// Register every builtin type with this dialect. The type list is produced by
/// TableGen: defining GET_TYPEDEF_LIST before including BuiltinTypes.cpp.inc
/// expands to the comma-separated list of type classes passed to addTypes<>.
41 void BuiltinDialect::registerTypes() {
42  addTypes<
43 #define GET_TYPEDEF_LIST
44 #include "mlir/IR/BuiltinTypes.cpp.inc"
45  >();
46 }
47 
48 //===----------------------------------------------------------------------===//
49 /// ComplexType
50 //===----------------------------------------------------------------------===//
51 
52 /// Verify the construction of a complex type.
54  Type elementType) {
55  if (!elementType.isIntOrFloat())
56  return emitError() << "invalid element type for complex";
57  return success();
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // Integer Type
62 //===----------------------------------------------------------------------===//
63 
64 /// Verify the construction of an integer type.
66  unsigned width,
67  SignednessSemantics signedness) {
68  if (width > IntegerType::kMaxWidth) {
69  return emitError() << "integer bitwidth is limited to "
70  << IntegerType::kMaxWidth << " bits";
71  }
72  return success();
73 }
74 
75 unsigned IntegerType::getWidth() const { return getImpl()->width; }
76 
77 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
78  return getImpl()->signedness;
79 }
80 
81 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
82  if (!scale)
83  return IntegerType();
84  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // Float Types
89 //===----------------------------------------------------------------------===//
90 
91 // Mapping from MLIR FloatType to APFloat semantics.
92 #define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
93  const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
94  return APFloat::SEM(); \
95  }
96 FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
97 FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
98 FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
99 FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
100 FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
101 FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
102 FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
103 FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
104 FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
105 FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
106 FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
107 FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
108 FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
109 FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
110 FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
111 FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
112 FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
113 FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
114 #undef FLOAT_TYPE_SEMANTICS
115 
116 FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
117  if (scale == 2)
118  return Float32Type::get(getContext());
119  if (scale == 4)
120  return Float64Type::get(getContext());
121  return FloatType();
122 }
123 
124 FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
125  if (scale == 2)
126  return Float32Type::get(getContext());
127  if (scale == 4)
128  return Float64Type::get(getContext());
129  return FloatType();
130 }
131 
132 FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
133  if (scale == 2)
134  return Float64Type::get(getContext());
135  return FloatType();
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // FunctionType
140 //===----------------------------------------------------------------------===//
141 
142 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
143 
144 ArrayRef<Type> FunctionType::getInputs() const {
145  return getImpl()->getInputs();
146 }
147 
148 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
149 
150 ArrayRef<Type> FunctionType::getResults() const {
151  return getImpl()->getResults();
152 }
153 
154 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
155  return get(getContext(), inputs, results);
156 }
157 
158 /// Returns a new function type with the specified arguments and results
159 /// inserted.
160 FunctionType FunctionType::getWithArgsAndResults(
161  ArrayRef<unsigned> argIndices, TypeRange argTypes,
162  ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
163  SmallVector<Type> argStorage, resultStorage;
164  TypeRange newArgTypes =
165  insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
166  TypeRange newResultTypes =
167  insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
168  return clone(newArgTypes, newResultTypes);
169 }
170 
171 /// Returns a new function type without the specified arguments and results.
172 FunctionType
173 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
174  const BitVector &resultIndices) {
175  SmallVector<Type> argStorage, resultStorage;
176  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
177  TypeRange newResultTypes =
178  filterTypesOut(getResults(), resultIndices, resultStorage);
179  return clone(newArgTypes, newResultTypes);
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // OpaqueType
184 //===----------------------------------------------------------------------===//
185 
186 /// Verify the construction of an opaque type.
188  StringAttr dialect, StringRef typeData) {
189  if (!Dialect::isValidNamespace(dialect.strref()))
190  return emitError() << "invalid dialect namespace '" << dialect << "'";
191 
192  // Check that the dialect is actually registered.
193  MLIRContext *context = dialect.getContext();
194  if (!context->allowsUnregisteredDialects() &&
195  !context->getLoadedDialect(dialect.strref())) {
196  return emitError()
197  << "`!" << dialect << "<\"" << typeData << "\">"
198  << "` type created with unregistered dialect. If this is "
199  "intended, please call allowUnregisteredDialects() on the "
200  "MLIRContext, or use -allow-unregistered-dialect with "
201  "the MLIR opt tool used";
202  }
203 
204  return success();
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // VectorType
209 //===----------------------------------------------------------------------===//
210 
211 bool VectorType::isValidElementType(Type t) {
212  return isValidVectorTypeElementType(t);
213 }
214 
216  ArrayRef<int64_t> shape, Type elementType,
217  ArrayRef<bool> scalableDims) {
218  if (!isValidElementType(elementType))
219  return emitError()
220  << "vector elements must be int/index/float type but got "
221  << elementType;
222 
223  if (any_of(shape, [](int64_t i) { return i <= 0; }))
224  return emitError()
225  << "vector types must have positive constant sizes but got "
226  << shape;
227 
228  if (scalableDims.size() != shape.size())
229  return emitError() << "number of dims must match, got "
230  << scalableDims.size() << " and " << shape.size();
231 
232  return success();
233 }
234 
235 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
236  if (!scale)
237  return VectorType();
238  if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
239  if (auto scaledEt = et.scaleElementBitwidth(scale))
240  return VectorType::get(getShape(), scaledEt, getScalableDims());
241  if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
242  if (auto scaledEt = et.scaleElementBitwidth(scale))
243  return VectorType::get(getShape(), scaledEt, getScalableDims());
244  return VectorType();
245 }
246 
247 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
248  Type elementType) const {
249  return VectorType::get(shape.value_or(getShape()), elementType,
250  getScalableDims());
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // TensorType
255 //===----------------------------------------------------------------------===//
256 
259  .Case<RankedTensorType, UnrankedTensorType>(
260  [](auto type) { return type.getElementType(); });
261 }
262 
263 bool TensorType::hasRank() const {
264  return !llvm::isa<UnrankedTensorType>(*this);
265 }
266 
268  return llvm::cast<RankedTensorType>(*this).getShape();
269 }
270 
272  Type elementType) const {
273  if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
274  if (shape)
275  return RankedTensorType::get(*shape, elementType);
276  return UnrankedTensorType::get(elementType);
277  }
278 
279  auto rankedTy = llvm::cast<RankedTensorType>(*this);
280  if (!shape)
281  return RankedTensorType::get(rankedTy.getShape(), elementType,
282  rankedTy.getEncoding());
283  return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
284  rankedTy.getEncoding());
285 }
286 
287 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
288  Type elementType) const {
289  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
290 }
291 
292 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
293  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
294 }
295 
296 // Check if "elementType" can be an element type of a tensor.
297 static LogicalResult
299  Type elementType) {
300  if (!TensorType::isValidElementType(elementType))
301  return emitError() << "invalid tensor element type: " << elementType;
302  return success();
303 }
304 
305 /// Return true if the specified element type is ok in a tensor.
307  // Note: Non standard/builtin types are allowed to exist within tensor
308  // types. Dialects are expected to verify that tensor types have a valid
309  // element type within that dialect.
310  return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
311  IndexType>(type) ||
312  !llvm::isa<BuiltinDialect>(type.getDialect());
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // RankedTensorType
317 //===----------------------------------------------------------------------===//
318 
319 LogicalResult
321  ArrayRef<int64_t> shape, Type elementType,
322  Attribute encoding) {
323  for (int64_t s : shape)
324  if (s < 0 && ShapedType::isStatic(s))
325  return emitError() << "invalid tensor dimension size";
326  if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
327  if (failed(v.verifyEncoding(shape, elementType, emitError)))
328  return failure();
329  return checkTensorElementType(emitError, elementType);
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // UnrankedTensorType
334 //===----------------------------------------------------------------------===//
335 
336 LogicalResult
338  Type elementType) {
339  return checkTensorElementType(emitError, elementType);
340 }
341 
342 //===----------------------------------------------------------------------===//
343 // BaseMemRefType
344 //===----------------------------------------------------------------------===//
345 
348  .Case<MemRefType, UnrankedMemRefType>(
349  [](auto type) { return type.getElementType(); });
350 }
351 
353  return !llvm::isa<UnrankedMemRefType>(*this);
354 }
355 
357  return llvm::cast<MemRefType>(*this).getShape();
358 }
359 
361  Type elementType) const {
362  if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
363  if (!shape)
364  return UnrankedMemRefType::get(elementType, getMemorySpace());
365  MemRefType::Builder builder(*shape, elementType);
366  builder.setMemorySpace(getMemorySpace());
367  return builder;
368  }
369 
370  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
371  if (shape)
372  builder.setShape(*shape);
373  builder.setElementType(elementType);
374  return builder;
375 }
376 
377 FailureOr<PtrLikeTypeInterface>
379  std::optional<Type> elementType) const {
380  Type eTy = elementType ? *elementType : getElementType();
381  if (llvm::dyn_cast<UnrankedMemRefType>(*this))
382  return cast<PtrLikeTypeInterface>(
383  UnrankedMemRefType::get(eTy, memorySpace));
384 
385  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
386  builder.setElementType(eTy);
387  builder.setMemorySpace(memorySpace);
388  return cast<PtrLikeTypeInterface>(static_cast<MemRefType>(builder));
389 }
390 
392  Type elementType) const {
393  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
394 }
395 
396 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
397  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
398 }
399 
401  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
402  return rankedMemRefTy.getMemorySpace();
403  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
404 }
405 
407  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
408  return rankedMemRefTy.getMemorySpaceAsInt();
409  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // MemRefType
414 //===----------------------------------------------------------------------===//
415 
416 std::optional<llvm::SmallDenseSet<unsigned>>
418  ArrayRef<int64_t> reducedShape,
419  bool matchDynamic) {
420  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
421  llvm::SmallDenseSet<unsigned> unusedDims;
422  unsigned reducedIdx = 0;
423  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
424  // Greedily insert `originalIdx` if match.
425  int64_t origSize = originalShape[originalIdx];
426  // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
427  if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
428  (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
429  ShapedType::isDynamic(origSize))) {
430  reducedIdx++;
431  continue;
432  }
433  if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
434  reducedIdx++;
435  continue;
436  }
437 
438  unusedDims.insert(originalIdx);
439  // If no match on `originalIdx`, the `originalShape` at this dimension
440  // must be 1, otherwise we bail.
441  if (origSize != 1)
442  return std::nullopt;
443  }
444  // The whole reducedShape must be scanned, otherwise we bail.
445  if (reducedIdx != reducedRank)
446  return std::nullopt;
447  return unusedDims;
448 }
449 
451 mlir::isRankReducedType(ShapedType originalType,
452  ShapedType candidateReducedType) {
453  if (originalType == candidateReducedType)
455 
456  ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
457  ShapedType candidateReducedShapedType =
458  llvm::cast<ShapedType>(candidateReducedType);
459 
460  // Rank and size logic is valid for all ShapedTypes.
461  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
462  ArrayRef<int64_t> candidateReducedShape =
463  candidateReducedShapedType.getShape();
464  unsigned originalRank = originalShape.size(),
465  candidateReducedRank = candidateReducedShape.size();
466  if (candidateReducedRank > originalRank)
468 
469  auto optionalUnusedDimsMask =
470  computeRankReductionMask(originalShape, candidateReducedShape);
471 
472  // Sizes cannot be matched if no rank-reduction mask (std::nullopt) is returned.
473  if (!optionalUnusedDimsMask)
475 
476  if (originalShapedType.getElementType() !=
477  candidateReducedShapedType.getElementType())
479 
481 }
482 
484  // Empty attribute is allowed as default memory space.
485  if (!memorySpace)
486  return true;
487 
488  // Supported built-in attributes.
489  if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
490  return true;
491 
492  // Allow custom dialect attributes.
493  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
494  return true;
495 
496  return false;
497 }
498 
500  MLIRContext *ctx) {
501  if (memorySpace == 0)
502  return nullptr;
503 
504  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
505 }
506 
508  IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
509  if (intMemorySpace && intMemorySpace.getValue() == 0)
510  return nullptr;
511 
512  return memorySpace;
513 }
514 
516  if (!memorySpace)
517  return 0;
518 
519  assert(llvm::isa<IntegerAttr>(memorySpace) &&
520  "Using `getMemorySpaceInteger` with non-Integer attribute");
521 
522  return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
523 }
524 
525 unsigned MemRefType::getMemorySpaceAsInt() const {
526  return detail::getMemorySpaceAsInt(getMemorySpace());
527 }
528 
529 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
530  MemRefLayoutAttrInterface layout,
531  Attribute memorySpace) {
532  // Use default layout for empty attribute.
533  if (!layout)
535  shape.size(), elementType.getContext()));
536 
537  // Drop default memory space value and replace it with empty attribute.
538  memorySpace = skipDefaultMemorySpace(memorySpace);
539 
540  return Base::get(elementType.getContext(), shape, elementType, layout,
541  memorySpace);
542 }
543 
544 MemRefType MemRefType::getChecked(
545  function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
546  Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
547 
548  // Use default layout for empty attribute.
549  if (!layout)
551  shape.size(), elementType.getContext()));
552 
553  // Drop default memory space value and replace it with empty attribute.
554  memorySpace = skipDefaultMemorySpace(memorySpace);
555 
556  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
557  elementType, layout, memorySpace);
558 }
559 
560 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
561  AffineMap map, Attribute memorySpace) {
562 
563  // Use default layout for empty map.
564  if (!map)
565  map = AffineMap::getMultiDimIdentityMap(shape.size(),
566  elementType.getContext());
567 
568  // Wrap AffineMap into Attribute.
569  auto layout = AffineMapAttr::get(map);
570 
571  // Drop default memory space value and replace it with empty attribute.
572  memorySpace = skipDefaultMemorySpace(memorySpace);
573 
574  return Base::get(elementType.getContext(), shape, elementType, layout,
575  memorySpace);
576 }
577 
578 MemRefType
579 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
580  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
581  Attribute memorySpace) {
582 
583  // Use default layout for empty map.
584  if (!map)
585  map = AffineMap::getMultiDimIdentityMap(shape.size(),
586  elementType.getContext());
587 
588  // Wrap AffineMap into Attribute.
589  auto layout = AffineMapAttr::get(map);
590 
591  // Drop default memory space value and replace it with empty attribute.
592  memorySpace = skipDefaultMemorySpace(memorySpace);
593 
594  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
595  elementType, layout, memorySpace);
596 }
597 
598 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
599  AffineMap map, unsigned memorySpaceInd) {
600 
601  // Use default layout for empty map.
602  if (!map)
603  map = AffineMap::getMultiDimIdentityMap(shape.size(),
604  elementType.getContext());
605 
606  // Wrap AffineMap into Attribute.
607  auto layout = AffineMapAttr::get(map);
608 
609  // Convert deprecated integer-like memory space to Attribute.
610  Attribute memorySpace =
611  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
612 
613  return Base::get(elementType.getContext(), shape, elementType, layout,
614  memorySpace);
615 }
616 
617 MemRefType
618 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
619  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
620  unsigned memorySpaceInd) {
621 
622  // Use default layout for empty map.
623  if (!map)
624  map = AffineMap::getMultiDimIdentityMap(shape.size(),
625  elementType.getContext());
626 
627  // Wrap AffineMap into Attribute.
628  auto layout = AffineMapAttr::get(map);
629 
630  // Convert deprecated integer-like memory space to Attribute.
631  Attribute memorySpace =
632  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
633 
634  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
635  elementType, layout, memorySpace);
636 }
637 
639  ArrayRef<int64_t> shape, Type elementType,
640  MemRefLayoutAttrInterface layout,
641  Attribute memorySpace) {
642  if (!BaseMemRefType::isValidElementType(elementType))
643  return emitError() << "invalid memref element type";
644 
645  // Negative sizes are not allowed except for `kDynamic`.
646  for (int64_t s : shape)
647  if (s < 0 && ShapedType::isStatic(s))
648  return emitError() << "invalid memref size";
649 
650  assert(layout && "missing layout specification");
651  if (failed(layout.verifyLayout(shape, emitError)))
652  return failure();
653 
654  if (!isSupportedMemorySpace(memorySpace))
655  return emitError() << "unsupported memory space Attribute";
656 
657  return success();
658 }
659 
660 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
661  assert(n <= getRank() &&
662  "number of dimensions to check must not exceed rank");
663  return n <= getNumContiguousTrailingDims();
664 }
665 
666 int64_t MemRefType::getNumContiguousTrailingDims() {
667  const int64_t n = getRank();
668 
669  // memrefs with identity layout are entirely contiguous.
670  if (getLayout().isIdentity())
671  return n;
672 
673  // Get the strides (if any). Failing to do that, conservatively assume a
674  // non-contiguous layout.
675  int64_t offset;
676  SmallVector<int64_t> strides;
677  if (!succeeded(getStridesAndOffset(strides, offset)))
678  return 0;
679 
680  ArrayRef<int64_t> shape = getShape();
681 
682  // A memref with dimensions `d0, d1, ..., dn-1` and strides
683  // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
684  // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
685  // for `i` in `[k, n-1]`.
686  // Ignore stride elements if the corresponding dimension is 1, as they are
687  // of no consequence.
688  int64_t dimProduct = 1;
689  for (int64_t i = n - 1; i >= 0; --i) {
690  if (shape[i] == 1)
691  continue;
692  if (strides[i] != dimProduct)
693  return n - i - 1;
694  if (shape[i] == ShapedType::kDynamic)
695  return n - i;
696  dimProduct *= shape[i];
697  }
698 
699  return n;
700 }
701 
702 MemRefType MemRefType::canonicalizeStridedLayout() {
703  AffineMap m = getLayout().getAffineMap();
704 
705  // Already in canonical form.
706  if (m.isIdentity())
707  return *this;
708 
709  // Can't reduce to canonical identity form, return in canonical form.
710  if (m.getNumResults() > 1)
711  return *this;
712 
713  // Corner-case for 0-D affine maps.
714  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
715  if (auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0)))
716  if (cst.getValue() == 0)
717  return MemRefType::Builder(*this).setLayout({});
718  return *this;
719  }
720 
721  // 0-D corner case for an empty shape that still has an affine map. Example:
722  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
723  // offset needs to remain, so return this type unchanged.
724  if (getShape().empty())
725  return *this;
726 
727  // If the canonical strided layout for the sizes of this type is equal to
728  // its simplified layout, we can just return an empty layout. Otherwise,
729  // just simplify the existing layout.
731  auto simplifiedLayoutExpr =
733  if (expr != simplifiedLayoutExpr)
734  return MemRefType::Builder(*this).setLayout(
736  simplifiedLayoutExpr)));
737  return MemRefType::Builder(*this).setLayout({});
738 }
739 
740 LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
741  int64_t &offset) const {
742  return getLayout().getStridesAndOffset(getShape(), strides, offset);
743 }
744 
745 std::pair<SmallVector<int64_t>, int64_t>
747  SmallVector<int64_t> strides;
748  int64_t offset;
749  LogicalResult status = getStridesAndOffset(strides, offset);
750  (void)status;
751  assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
752  return {strides, offset};
753 }
754 
755 bool MemRefType::isStrided() {
756  int64_t offset;
757  SmallVector<int64_t, 4> strides;
758  auto res = getStridesAndOffset(strides, offset);
759  return succeeded(res);
760 }
761 
762 bool MemRefType::isLastDimUnitStride() {
763  int64_t offset;
764  SmallVector<int64_t> strides;
765  auto successStrides = getStridesAndOffset(strides, offset);
766  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
767 }
768 
769 //===----------------------------------------------------------------------===//
770 // UnrankedMemRefType
771 //===----------------------------------------------------------------------===//
772 
774  return detail::getMemorySpaceAsInt(getMemorySpace());
775 }
776 
777 LogicalResult
779  Type elementType, Attribute memorySpace) {
780  if (!BaseMemRefType::isValidElementType(elementType))
781  return emitError() << "invalid memref element type";
782 
783  if (!isSupportedMemorySpace(memorySpace))
784  return emitError() << "unsupported memory space Attribute";
785 
786  return success();
787 }
788 
789 //===----------------------------------------------------------------------===//
790 /// TupleType
791 //===----------------------------------------------------------------------===//
792 
793 /// Return the elements types for this tuple.
794 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
795 
796 /// Accumulate the types contained in this tuple and tuples nested within it.
797 /// Note that this only flattens nested tuples, not any other container type,
798 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
799 /// (i32, tensor<i32>, f32, i64)
800 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
801  for (Type type : getTypes()) {
802  if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
803  nestedTuple.getFlattenedTypes(types);
804  else
805  types.push_back(type);
806  }
807 }
808 
809 /// Return the number of element types.
810 size_t TupleType::size() const { return getImpl()->size(); }
811 
812 //===----------------------------------------------------------------------===//
813 // Type Utilities
814 //===----------------------------------------------------------------------===//
815 
817  ArrayRef<AffineExpr> exprs,
818  MLIRContext *context) {
819  // Size 0 corner case is useful for canonicalizations.
820  if (sizes.empty())
821  return getAffineConstantExpr(0, context);
822 
823  assert(!exprs.empty() && "expected exprs");
824  auto maps = AffineMap::inferFromExprList(exprs, context);
825  assert(!maps.empty() && "Expected one non-empty map");
826  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
827 
828  AffineExpr expr;
829  bool dynamicPoisonBit = false;
830  int64_t runningSize = 1;
831  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
832  int64_t size = std::get<1>(en);
833  AffineExpr dimExpr = std::get<0>(en);
834  AffineExpr stride = dynamicPoisonBit
835  ? getAffineSymbolExpr(nSymbols++, context)
836  : getAffineConstantExpr(runningSize, context);
837  expr = expr ? expr + dimExpr * stride : dimExpr * stride;
838  if (size > 0) {
839  runningSize *= size;
840  assert(runningSize > 0 && "integer overflow in size computation");
841  } else {
842  dynamicPoisonBit = true;
843  }
844  }
845  return simplifyAffineExpr(expr, numDims, nSymbols);
846 }
847 
849  MLIRContext *context) {
851  exprs.reserve(sizes.size());
852  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
853  exprs.push_back(getAffineDimExpr(dim, context));
854  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
855 }
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:330
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:394
unsigned getNumDims() const
Definition: AffineMap.cpp:390
unsigned getNumResults() const
Definition: AffineMap.cpp:398
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:341
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:308
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:58
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:104
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Definition: BuiltinTypes.h:413
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns true if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns true if the given string is a valid dialect namespace.
Definition: Dialect.cpp:95
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if operations may be created for unregistered dialects.
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:182
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:203
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:198
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:193
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:208
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns true if this type is ranked, i.e. it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
AttrTypeReplacer.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:356
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:631