MLIR  22.0.0git
BuiltinTypes.cpp
Go to the documentation of this file.
1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/TensorEncoding.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 using namespace mlir;
24 using namespace mlir::detail;
25 
26 //===----------------------------------------------------------------------===//
27 /// Tablegen Type Definitions
28 //===----------------------------------------------------------------------===//
29 
30 #define GET_TYPEDEF_CLASSES
31 #include "mlir/IR/BuiltinTypes.cpp.inc"
32 
33 namespace mlir {
34 #include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
35 } // namespace mlir
36 
37 //===----------------------------------------------------------------------===//
38 // BuiltinDialect
39 //===----------------------------------------------------------------------===//
40 
// Registers every builtin type class with the builtin dialect. The list of
// types is not written by hand: it is expanded from the generated
// BuiltinTypes.cpp.inc with GET_TYPEDEF_LIST defined, which is why the
// #include has to sit inside the template argument list of addTypes<>.
41 void BuiltinDialect::registerTypes() {
42  addTypes<
43 #define GET_TYPEDEF_LIST
44 #include "mlir/IR/BuiltinTypes.cpp.inc"
45  >();
46 }
47 
48 //===----------------------------------------------------------------------===//
49 /// ComplexType
50 //===----------------------------------------------------------------------===//
51 
52 /// Verify the construction of a complex type.
// Only element types for which `isIntOrFloat()` holds are accepted; any other
// element type is reported as "invalid element type for complex".
54  Type elementType) {
55  if (!elementType.isIntOrFloat())
56  return emitError() << "invalid element type for complex";
57  return success();
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // Integer Type
62 //===----------------------------------------------------------------------===//
63 
64 /// Verify the construction of an integer type.
// The only constraint checked is the upper bound on the bit width
// (IntegerType::kMaxWidth); `signedness` is accepted but not inspected here.
66  unsigned width,
67  SignednessSemantics signedness) {
68  if (width > IntegerType::kMaxWidth) {
69  return emitError() << "integer bitwidth is limited to "
70  << IntegerType::kMaxWidth << " bits";
71  }
72  return success();
73 }
74 
75 unsigned IntegerType::getWidth() const { return getImpl()->width; }
76 
77 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
78  return getImpl()->signedness;
79 }
80 
81 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
82  if (!scale)
83  return IntegerType();
84  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // Float Types
89 //===----------------------------------------------------------------------===//
90 
// Each FLOAT_TYPE_SEMANTICS(TYPE, SEM) invocation below defines
// TYPE::getFloatSemantics() so that it returns the llvm::fltSemantics object
// obtained from APFloat::SEM(). The macro is #undef'd right after the list so
// it does not leak into the rest of the file.
91 // Mapping from MLIR FloatType to APFloat semantics.
92 #define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
93  const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
94  return APFloat::SEM(); \
95  }
// One entry per builtin float type: (MLIR type class, APFloat semantics name).
96 FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
97 FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
98 FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
99 FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
100 FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
101 FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
102 FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
103 FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
104 FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
105 FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
106 FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
107 FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
108 FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
109 FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
110 FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
111 FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
112 FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
113 FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
114 #undef FLOAT_TYPE_SEMANTICS
115 
116 FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
117  if (scale == 2)
118  return Float32Type::get(getContext());
119  if (scale == 4)
120  return Float64Type::get(getContext());
121  return FloatType();
122 }
123 
124 FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
125  if (scale == 2)
126  return Float32Type::get(getContext());
127  if (scale == 4)
128  return Float64Type::get(getContext());
129  return FloatType();
130 }
131 
132 FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
133  if (scale == 2)
134  return Float64Type::get(getContext());
135  return FloatType();
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // FunctionType
140 //===----------------------------------------------------------------------===//
141 
142 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
143 
144 ArrayRef<Type> FunctionType::getInputs() const {
145  return getImpl()->getInputs();
146 }
147 
148 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
149 
150 ArrayRef<Type> FunctionType::getResults() const {
151  return getImpl()->getResults();
152 }
153 
154 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
155  return get(getContext(), inputs, results);
156 }
157 
158 /// Returns a new function type with the specified arguments and results
159 /// inserted.
160 FunctionType FunctionType::getWithArgsAndResults(
161  ArrayRef<unsigned> argIndices, TypeRange argTypes,
162  ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
163  SmallVector<Type> argStorage, resultStorage;
164  TypeRange newArgTypes =
165  insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
166  TypeRange newResultTypes =
167  insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
168  return clone(newArgTypes, newResultTypes);
169 }
170 
171 /// Returns a new function type without the specified arguments and results.
172 FunctionType
173 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
174  const BitVector &resultIndices) {
175  SmallVector<Type> argStorage, resultStorage;
176  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
177  TypeRange newResultTypes =
178  filterTypesOut(getResults(), resultIndices, resultStorage);
179  return clone(newArgTypes, newResultTypes);
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // GraphType
184 //===----------------------------------------------------------------------===//
185 
186 unsigned GraphType::getNumInputs() const { return getImpl()->numInputs; }
187 
188 ArrayRef<Type> GraphType::getInputs() const { return getImpl()->getInputs(); }
189 
190 unsigned GraphType::getNumResults() const { return getImpl()->numResults; }
191 
192 ArrayRef<Type> GraphType::getResults() const { return getImpl()->getResults(); }
193 
194 GraphType GraphType::clone(TypeRange inputs, TypeRange results) const {
195  return get(getContext(), inputs, results);
196 }
197 
198 /// Returns a new function type with the specified arguments and results
199 /// inserted.
200 GraphType GraphType::getWithArgsAndResults(ArrayRef<unsigned> argIndices,
201  TypeRange argTypes,
202  ArrayRef<unsigned> resultIndices,
203  TypeRange resultTypes) {
204  SmallVector<Type> argStorage, resultStorage;
205  TypeRange newArgTypes =
206  insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
207  TypeRange newResultTypes =
208  insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
209  return clone(newArgTypes, newResultTypes);
210 }
211 
212 /// Returns a new function type without the specified arguments and results.
213 GraphType GraphType::getWithoutArgsAndResults(const BitVector &argIndices,
214  const BitVector &resultIndices) {
215  SmallVector<Type> argStorage, resultStorage;
216  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
217  TypeRange newResultTypes =
218  filterTypesOut(getResults(), resultIndices, resultStorage);
219  return clone(newArgTypes, newResultTypes);
220 }
221 //===----------------------------------------------------------------------===//
222 // OpaqueType
223 //===----------------------------------------------------------------------===//
224 
225 /// Verify the construction of an opaque type.
// Two checks: the dialect namespace must be syntactically valid, and, unless
// the context allows unregistered dialects, a dialect with that namespace must
// already be loaded in the context.
227  StringAttr dialect, StringRef typeData) {
228  if (!Dialect::isValidNamespace(dialect.strref()))
229  return emitError() << "invalid dialect namespace '" << dialect << "'";
230 
231  // Check that the dialect is loaded in the context (skipped when unregistered dialects are allowed).
232  MLIRContext *context = dialect.getContext();
233  if (!context->allowsUnregisteredDialects() &&
234  !context->getLoadedDialect(dialect.strref())) {
235  return emitError()
236  << "`!" << dialect << "<\"" << typeData << "\">"
237  << "` type created with unregistered dialect. If this is "
238  "intended, please call allowUnregisteredDialects() on the "
239  "MLIRContext, or use -allow-unregistered-dialect with "
240  "the MLIR opt tool used";
241  }
242 
243  return success();
244 }
245 
246 //===----------------------------------------------------------------------===//
247 // VectorType
248 //===----------------------------------------------------------------------===//
249 
250 bool VectorType::isValidElementType(Type t) {
251  return isValidVectorTypeElementType(t);
252 }
253 
255  ArrayRef<int64_t> shape, Type elementType,
256  ArrayRef<bool> scalableDims) {
// Element type must pass VectorType::isValidElementType.
257  if (!isValidElementType(elementType))
258  return emitError()
259  << "vector elements must be int/index/float type but got "
260  << elementType;
261 
// Every dimension must be strictly positive; zero and negative (including
// any dynamic-size sentinel) extents are rejected.
262  if (any_of(shape, [](int64_t i) { return i <= 0; }))
263  return emitError()
264  << "vector types must have positive constant sizes but got "
265  << shape;
266 
// One scalability flag is required per dimension.
267  if (scalableDims.size() != shape.size())
268  return emitError() << "number of dims must match, got "
269  << scalableDims.size() << " and " << shape.size();
270 
271  return success();
272 }
273 
274 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
275  if (!scale)
276  return VectorType();
277  if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
278  if (auto scaledEt = et.scaleElementBitwidth(scale))
279  return VectorType::get(getShape(), scaledEt, getScalableDims());
280  if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
281  if (auto scaledEt = et.scaleElementBitwidth(scale))
282  return VectorType::get(getShape(), scaledEt, getScalableDims());
283  return VectorType();
284 }
285 
286 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
287  Type elementType) const {
288  return VectorType::get(shape.value_or(getShape()), elementType,
289  getScalableDims());
290 }
291 
292 //===----------------------------------------------------------------------===//
293 // TensorType
294 //===----------------------------------------------------------------------===//
295 
// Dispatches on the concrete tensor kind (ranked or unranked); both expose
// getElementType(), so one generic lambda serves both cases.
298  .Case<RankedTensorType, UnrankedTensorType>(
299  [](auto type) { return type.getElementType(); });
300 }
301 
302 bool TensorType::hasRank() const {
303  return !llvm::isa<UnrankedTensorType>(*this);
304 }
305 
// Only meaningful for ranked tensors: llvm::cast to RankedTensorType asserts
// (in assertion-enabled builds) if this is an unranked tensor.
307  return llvm::cast<RankedTensorType>(*this).getShape();
308 }
309 
311  Type elementType) const {
// Unranked source: a provided shape produces a ranked tensor (without an
// encoding); otherwise the result stays unranked.
312  if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
313  if (shape)
314  return RankedTensorType::get(*shape, elementType);
315  return UnrankedTensorType::get(elementType);
316  }
317 
// Ranked source: the encoding is always preserved; the shape is replaced only
// when one is provided.
318  auto rankedTy = llvm::cast<RankedTensorType>(*this);
319  if (!shape)
320  return RankedTensorType::get(rankedTy.getShape(), elementType,
321  rankedTy.getEncoding());
// NOTE(review): `shape` is known to hold a value here, so the value_or
// fallback below is dead code; `*shape` would be equivalent.
322  return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
323  rankedTy.getEncoding());
324 }
325 
326 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
327  Type elementType) const {
328  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
329 }
330 
331 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
332  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
333 }
334 
335 // Check if "elementType" can be an element type of a tensor.
// Shared by the ranked and unranked tensor verifiers; the actual rule lives in
// TensorType::isValidElementType.
336 static LogicalResult
338  Type elementType) {
339  if (!TensorType::isValidElementType(elementType))
340  return emitError() << "invalid tensor element type: " << elementType;
341  return success();
342 }
343 
344 /// Return true if the specified element type is ok in a tensor.
// Accepted: the listed builtin types (complex, float, integer, opaque, vector,
// index), plus any type whose dialect is not the builtin dialect.
346  // Note: Non standard/builtin types are allowed to exist within tensor
347  // types. Dialects are expected to verify that tensor types have a valid
348  // element type within that dialect.
349  return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
350  IndexType>(type) ||
351  !llvm::isa<BuiltinDialect>(type.getDialect());
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // RankedTensorType
356 //===----------------------------------------------------------------------===//
357 
358 LogicalResult
360  ArrayRef<int64_t> shape, Type elementType,
361  Attribute encoding) {
// Negative extents are only allowed when ShapedType::isStatic reports them as
// non-static, i.e. for the dynamic-size marker.
362  for (int64_t s : shape)
363  if (s < 0 && ShapedType::isStatic(s))
364  return emitError() << "invalid tensor dimension size";
// If the (optional) encoding implements VerifiableTensorEncoding, let it
// validate itself against the shape and element type.
365  if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
366  if (failed(v.verifyEncoding(shape, elementType, emitError)))
367  return failure();
368  return checkTensorElementType(emitError, elementType);
369 }
370 
371 //===----------------------------------------------------------------------===//
372 // UnrankedTensorType
373 //===----------------------------------------------------------------------===//
374 
375 LogicalResult
// Unranked tensors have no shape or encoding to check; only the element type
// is validated.
377  Type elementType) {
378  return checkTensorElementType(emitError, elementType);
379 }
380 
381 //===----------------------------------------------------------------------===//
382 // BaseMemRefType
383 //===----------------------------------------------------------------------===//
384 
// Dispatches on the concrete memref kind (ranked or unranked); both expose
// getElementType(), so one generic lambda serves both cases.
387  .Case<MemRefType, UnrankedMemRefType>(
388  [](auto type) { return type.getElementType(); });
389 }
390 
// Every memref that is not an UnrankedMemRefType is ranked.
392  return !llvm::isa<UnrankedMemRefType>(*this);
393 }
394 
// Only meaningful for ranked memrefs: llvm::cast to MemRefType asserts (in
// assertion-enabled builds) on an unranked memref.
396  return llvm::cast<MemRefType>(*this).getShape();
397 }
398 
400  Type elementType) const {
// Unranked source: the memory space is always kept. Without a shape the
// result stays unranked; with a shape a ranked memref is built via
// MemRefType::Builder.
401  if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
402  if (!shape)
403  return UnrankedMemRefType::get(elementType, getMemorySpace());
404  MemRefType::Builder builder(*shape, elementType);
405  builder.setMemorySpace(getMemorySpace());
406  return builder;
407  }
408 
// Ranked source: start from this type (keeping its remaining properties),
// then override the shape (if given) and the element type.
409  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
410  if (shape)
411  builder.setShape(*shape);
412  builder.setElementType(elementType);
413  return builder;
414 }
415 
416 FailureOr<PtrLikeTypeInterface>
418  std::optional<Type> elementType) const {
// An omitted element type means "keep the current one".
419  Type eTy = elementType ? *elementType : getElementType();
420  if (llvm::dyn_cast<UnrankedMemRefType>(*this))
421  return cast<PtrLikeTypeInterface>(
422  UnrankedMemRefType::get(eTy, memorySpace));
423 
// Ranked: keep the remaining properties of this memref (via the Builder) and
// replace only the element type and memory space.
424  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
425  builder.setElementType(eTy);
426  builder.setMemorySpace(memorySpace);
427  return cast<PtrLikeTypeInterface>(static_cast<MemRefType>(builder));
428 }
429 
// A shape is always supplied, so cloneWith yields a ranked MemRefType.
431  Type elementType) const {
432  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
433 }
434 
435 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
436  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
437 }
438 
// Forward to whichever concrete memref kind this is (ranked, else unranked).
440  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
441  return rankedMemRefTy.getMemorySpace();
442  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
443 }
444 
// Forward to whichever concrete memref kind this is (ranked, else unranked).
446  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
447  return rankedMemRefTy.getMemorySpaceAsInt();
448  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
449 }
450 
451 //===----------------------------------------------------------------------===//
452 // MemRefType
453 //===----------------------------------------------------------------------===//
454 
// Greedily matches `reducedShape` against `originalShape` from left to right.
// Returns the set of original dimension indices that were dropped; only
// size-1 dimensions may be dropped. Returns std::nullopt when the shapes
// cannot be matched this way.
455 std::optional<llvm::SmallDenseSet<unsigned>>
457  ArrayRef<int64_t> reducedShape,
458  bool matchDynamic) {
459  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
460  llvm::SmallDenseSet<unsigned> unusedDims;
461  unsigned reducedIdx = 0;
462  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
463  // Greedily match `originalIdx` against the next unmatched reduced dim.
464  int64_t origSize = originalShape[originalIdx];
465  // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
466  if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
467  (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
468  ShapedType::isDynamic(origSize))) {
469  reducedIdx++;
470  continue;
471  }
472  if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
473  reducedIdx++;
474  continue;
475  }
476 
477  unusedDims.insert(originalIdx);
478  // If no match on `originalIdx`, the `originalShape` at this dimension
479  // must be 1, otherwise we bail.
480  if (origSize != 1)
481  return std::nullopt;
482  }
483  // The whole reducedShape must be scanned, otherwise we bail.
484  if (reducedIdx != reducedRank)
485  return std::nullopt;
486  return unusedDims;
487 }
488 
// Checks whether `candidateReducedType` is a rank-reduced form of
// `originalType`: equal types, candidate rank not larger, sizes matchable by
// computeRankReductionMask, and identical element types.
// NOTE(review): the statements returning the verification result (original
// lines 489, 493, 506, 513, 517, 519) are missing from this listing, so the
// exact result values cannot be confirmed here.
490 mlir::isRankReducedType(ShapedType originalType,
491  ShapedType candidateReducedType) {
492  if (originalType == candidateReducedType)
494 
// NOTE(review): both parameters are already ShapedType, so these casts look
// redundant.
495  ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
496  ShapedType candidateReducedShapedType =
497  llvm::cast<ShapedType>(candidateReducedType);
498 
499  // Rank and size logic is valid for all ShapedTypes.
500  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
501  ArrayRef<int64_t> candidateReducedShape =
502  candidateReducedShapedType.getShape();
503  unsigned originalRank = originalShape.size(),
504  candidateReducedRank = candidateReducedShape.size();
505  if (candidateReducedRank > originalRank)
507 
508  auto optionalUnusedDimsMask =
509  computeRankReductionMask(originalShape, candidateReducedShape);
510 
511  // Sizes cannot be matched when no reduction mask (std::nullopt) is returned.
512  if (!optionalUnusedDimsMask)
514 
515  if (originalShapedType.getElementType() !=
516  candidateReducedShapedType.getElementType())
518 
520 }
521 
// Decides whether `memorySpace` may be used as a memref memory space: the
// empty attribute, builtin Integer/String/Dictionary attributes, or any
// attribute from a non-builtin dialect. Other builtin attributes are rejected.
523  // Empty attribute is allowed as default memory space.
524  if (!memorySpace)
525  return true;
526 
527  // Supported built-in attributes.
528  if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
529  return true;
530 
531  // Allow custom dialect attributes.
532  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
533  return true;
534 
535  return false;
536 }
537 
539  MLIRContext *ctx) {
// Memory space 0 is represented by the null (empty) attribute; any other
// value is wrapped as a 64-bit IntegerAttr.
540  if (memorySpace == 0)
541  return nullptr;
542 
543  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
544 }
545 
// Normalizes an integer memory space of value 0 to the null attribute so that
// "default memory space" has a single representation; every other attribute
// (including null) is returned unchanged.
547  IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
548  if (intMemorySpace && intMemorySpace.getValue() == 0)
549  return nullptr;
550 
551  return memorySpace;
552 }
553 
// Converts a memory-space attribute to its legacy integer form: null maps to
// 0, and otherwise the attribute must be an IntegerAttr.
// NOTE(review): the assert message mentions `getMemorySpaceInteger`, which
// does not match the `getMemorySpaceAsInt` name used by callers in this file;
// confirm and update the message.
555  if (!memorySpace)
556  return 0;
557 
558  assert(llvm::isa<IntegerAttr>(memorySpace) &&
559  "Using `getMemorySpaceInteger` with non-Integer attribute");
560 
561  return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
562 }
563 
564 unsigned MemRefType::getMemorySpaceAsInt() const {
565  return detail::getMemorySpaceAsInt(getMemorySpace());
566 }
567 
// Builds a memref type from a layout attribute. A null layout is replaced by a
// default identity layout sized to the shape's rank (see original line 573,
// not shown in this listing), and a default memory space is normalized to the
// empty attribute before uniquing.
568 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
569  MemRefLayoutAttrInterface layout,
570  Attribute memorySpace) {
571  // Use default layout for empty attribute.
572  if (!layout)
574  shape.size(), elementType.getContext()));
575 
576  // Drop default memory space value and replace it with empty attribute.
577  memorySpace = skipDefaultMemorySpace(memorySpace);
578 
579  return Base::get(elementType.getContext(), shape, elementType, layout,
580  memorySpace);
581 }
582 
// Checked variant of the layout-attribute overload of get(): same
// normalization, but construction goes through Base::getChecked, which reports
// problems through `emitErrorFn`.
583 MemRefType MemRefType::getChecked(
584  function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
585  Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
586 
587  // Use default layout for empty attribute.
588  if (!layout)
590  shape.size(), elementType.getContext()));
591 
592  // Drop default memory space value and replace it with empty attribute.
593  memorySpace = skipDefaultMemorySpace(memorySpace);
594 
595  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
596  elementType, layout, memorySpace);
597 }
598 
599 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
600  AffineMap map, Attribute memorySpace) {
601 
602  // Use default layout for empty map.
603  if (!map)
604  map = AffineMap::getMultiDimIdentityMap(shape.size(),
605  elementType.getContext());
606 
607  // Wrap AffineMap into Attribute.
608  auto layout = AffineMapAttr::get(map);
609 
610  // Drop default memory space value and replace it with empty attribute.
611  memorySpace = skipDefaultMemorySpace(memorySpace);
612 
613  return Base::get(elementType.getContext(), shape, elementType, layout,
614  memorySpace);
615 }
616 
617 MemRefType
618 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
619  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
620  Attribute memorySpace) {
621 
622  // Use default layout for empty map.
623  if (!map)
624  map = AffineMap::getMultiDimIdentityMap(shape.size(),
625  elementType.getContext());
626 
627  // Wrap AffineMap into Attribute.
628  auto layout = AffineMapAttr::get(map);
629 
630  // Drop default memory space value and replace it with empty attribute.
631  memorySpace = skipDefaultMemorySpace(memorySpace);
632 
633  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
634  elementType, layout, memorySpace);
635 }
636 
637 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
638  AffineMap map, unsigned memorySpaceInd) {
639 
640  // Use default layout for empty map.
641  if (!map)
642  map = AffineMap::getMultiDimIdentityMap(shape.size(),
643  elementType.getContext());
644 
645  // Wrap AffineMap into Attribute.
646  auto layout = AffineMapAttr::get(map);
647 
648  // Convert deprecated integer-like memory space to Attribute.
649  Attribute memorySpace =
650  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
651 
652  return Base::get(elementType.getContext(), shape, elementType, layout,
653  memorySpace);
654 }
655 
656 MemRefType
657 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
658  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
659  unsigned memorySpaceInd) {
660 
661  // Use default layout for empty map.
662  if (!map)
663  map = AffineMap::getMultiDimIdentityMap(shape.size(),
664  elementType.getContext());
665 
666  // Wrap AffineMap into Attribute.
667  auto layout = AffineMapAttr::get(map);
668 
669  // Convert deprecated integer-like memory space to Attribute.
670  Attribute memorySpace =
671  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
672 
673  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
674  elementType, layout, memorySpace);
675 }
676 
678  ArrayRef<int64_t> shape, Type elementType,
679  MemRefLayoutAttrInterface layout,
680  Attribute memorySpace) {
// Checks, in order: element type, dimension sizes, layout (delegated to the
// layout attribute), and memory space.
681  if (!BaseMemRefType::isValidElementType(elementType))
682  return emitError() << "invalid memref element type";
683 
684  // Negative sizes are not allowed except for `kDynamic`.
685  for (int64_t s : shape)
686  if (s < 0 && ShapedType::isStatic(s))
687  return emitError() << "invalid memref size";
688 
// get()/getChecked() substitute a default layout for a null one, so a null
// layout here is treated as a programming error rather than a user error.
689  assert(layout && "missing layout specification");
690  if (failed(layout.verifyLayout(shape, emitError)))
691  return failure();
692 
693  if (!isSupportedMemorySpace(memorySpace))
694  return emitError() << "unsupported memory space Attribute";
695 
696  return success();
697 }
698 
699 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
700  assert(n <= getRank() &&
701  "number of dimensions to check must not exceed rank");
702  return n <= getNumContiguousTrailingDims();
703 }
704 
705 int64_t MemRefType::getNumContiguousTrailingDims() {
706  const int64_t n = getRank();
707 
708  // memrefs with identity layout are entirely contiguous.
709  if (getLayout().isIdentity())
710  return n;
711 
712  // Get the strides (if any). Failing to do that, conservatively assume a
713  // non-contiguous layout.
714  int64_t offset;
715  SmallVector<int64_t> strides;
716  if (!succeeded(getStridesAndOffset(strides, offset)))
717  return 0;
718 
719  ArrayRef<int64_t> shape = getShape();
720 
721  // A memref with dimensions `d0, d1, ..., dn-1` and strides
722  // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
723  // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
724  // for `i` in `[k, n-1]`.
725  // Ignore stride elements if the corresponding dimension is 1, as they are
726  // of no consequence.
727  int64_t dimProduct = 1;
728  for (int64_t i = n - 1; i >= 0; --i) {
729  if (shape[i] == 1)
730  continue;
731  if (strides[i] != dimProduct)
732  return n - i - 1;
733  if (shape[i] == ShapedType::kDynamic)
734  return n - i;
735  dimProduct *= shape[i];
736  }
737 
738  return n;
739 }
740 
// Returns this memref with its layout canonicalized: the empty (default)
// layout when the layout is equivalent to the canonical strided layout for
// the shape, a simplified single-result layout otherwise, or the type
// unchanged when no canonicalization applies.
// NOTE(review): original lines 769, 771 and 774 (the computation of `expr`,
// the simplification call and the map construction) are not shown here.
741 MemRefType MemRefType::canonicalizeStridedLayout() {
742  AffineMap m = getLayout().getAffineMap();
743 
744  // Already in canonical form.
745  if (m.isIdentity())
746  return *this;
747 
748  // Multi-result maps cannot be reduced to a canonical form; return the type unchanged.
749  if (m.getNumResults() > 1)
750  return *this;
751 
752  // Corner-case for 0-D affine maps.
753  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
754  if (auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0)))
755  if (cst.getValue() == 0)
756  return MemRefType::Builder(*this).setLayout({});
757  return *this;
758  }
759 
760  // 0-D corner case for empty shape that still have an affine map. Example:
761  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
762  // offset needs to remain, just return *this.
763  if (getShape().empty())
764  return *this;
765 
766  // If the canonical strided layout for the sizes of this type is equal to the
767  // simplified layout of this type we can just return an empty layout. Otherwise,
768  // just simplify the existing layout.
770  auto simplifiedLayoutExpr =
772  if (expr != simplifiedLayoutExpr)
773  return MemRefType::Builder(*this).setLayout(
775  simplifiedLayoutExpr)));
776  return MemRefType::Builder(*this).setLayout({});
777 }
778 
779 LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
780  int64_t &offset) const {
781  return getLayout().getStridesAndOffset(getShape(), strides, offset);
782 }
783 
// Convenience overload returning {strides, offset} by value. The caller must
// know the layout is strided: failure is only caught by the assert, and in
// builds without assertions the (void) cast silences the unused-variable
// warning while the possibly unfilled values are returned.
784 std::pair<SmallVector<int64_t>, int64_t>
786  SmallVector<int64_t> strides;
787  int64_t offset;
788  LogicalResult status = getStridesAndOffset(strides, offset);
789  (void)status;
790  assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
791  return {strides, offset};
792 }
793 
794 bool MemRefType::isStrided() {
795  int64_t offset;
796  SmallVector<int64_t, 4> strides;
797  auto res = getStridesAndOffset(strides, offset);
798  return succeeded(res);
799 }
800 
801 bool MemRefType::isLastDimUnitStride() {
802  int64_t offset;
803  SmallVector<int64_t> strides;
804  auto successStrides = getStridesAndOffset(strides, offset);
805  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
806 }
807 
808 //===----------------------------------------------------------------------===//
809 // UnrankedMemRefType
810 //===----------------------------------------------------------------------===//
811 
// Legacy integer view of the memory space, shared with the ranked memref via
// detail::getMemorySpaceAsInt.
813  return detail::getMemorySpaceAsInt(getMemorySpace());
814 }
815 
816 LogicalResult
// Same element-type and memory-space checks as the ranked memref verifier;
// there is no shape or layout to check.
818  Type elementType, Attribute memorySpace) {
819  if (!BaseMemRefType::isValidElementType(elementType))
820  return emitError() << "invalid memref element type";
821 
822  if (!isSupportedMemorySpace(memorySpace))
823  return emitError() << "unsupported memory space Attribute";
824 
825  return success();
826 }
827 
828 //===----------------------------------------------------------------------===//
829 /// TupleType
830 //===----------------------------------------------------------------------===//
831 
832 /// Return the elements types for this tuple.
833 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
834 
835 /// Accumulate the types contained in this tuple and tuples nested within it.
836 /// Note that this only flattens nested tuples, not any other container type,
837 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
838 /// (i32, tensor<i32>, f32, i64)
839 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
840  for (Type type : getTypes()) {
841  if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
842  nestedTuple.getFlattenedTypes(types);
843  else
844  types.push_back(type);
845  }
846 }
847 
848 /// Return the number of element types.
849 size_t TupleType::size() const { return getImpl()->size(); }
850 
851 //===----------------------------------------------------------------------===//
852 // Type Utilities
853 //===----------------------------------------------------------------------===//
854 
856  ArrayRef<AffineExpr> exprs,
857  MLIRContext *context) {
858  // Rank-0 (empty `sizes`) corner case is useful for canonicalizations.
859  if (sizes.empty())
860  return getAffineConstantExpr(0, context);
861 
862  assert(!exprs.empty() && "expected exprs");
863  auto maps = AffineMap::inferFromExprList(exprs, context);
864  assert(!maps.empty() && "Expected one non-empty map");
865  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
866 
// Build sum(expr_i * stride_i), walking from the innermost dimension outwards.
// Strides are constants (the product of the sizes seen so far) until the
// first non-positive size is seen; from then on every further stride is a
// fresh symbol. NOTE(review): llvm::zip stops at the shorter of `exprs` and
// `sizes`; callers presumably pass equal lengths — confirm.
867  AffineExpr expr;
868  bool dynamicPoisonBit = false;
869  int64_t runningSize = 1;
870  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
871  int64_t size = std::get<1>(en);
872  AffineExpr dimExpr = std::get<0>(en);
873  AffineExpr stride = dynamicPoisonBit
874  ? getAffineSymbolExpr(nSymbols++, context)
875  : getAffineConstantExpr(runningSize, context);
876  expr = expr ? expr + dimExpr * stride : dimExpr * stride;
// Any size <= 0 (not just the dynamic marker; a static 0 too) switches
// the remaining outer strides to symbols.
877  if (size > 0) {
// NOTE(review): this is a signed multiply, and signed overflow is UB, so
// the assert below is not a reliable overflow detector; consider
// llvm::MulOverflow or a pre-check.
878  runningSize *= size;
879  assert(runningSize > 0 && "integer overflow in size computation");
880  } else {
881  dynamicPoisonBit = true;
882  }
883  }
884  return simplifyAffineExpr(expr, numDims, nSymbols);
885 }
886 
888  MLIRContext *context) {
// Convenience overload: uses d0 ... d(rank-1) as the per-dimension index
// expressions and forwards to the (sizes, exprs, context) overload. The
// declaration of `exprs` (original line 889) is not shown in this listing.
890  exprs.reserve(sizes.size());
891  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
892  exprs.push_back(getAffineDimExpr(dim, context));
893  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
894 }
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:330
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:394
unsigned getNumDims() const
Definition: AffineMap.cpp:390
unsigned getNumResults() const
Definition: AffineMap.cpp:398
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:341
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:308
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:58
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:104
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Definition: BuiltinTypes.h:413
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns whether this type is ranked, i.e. whether it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns true if the given string is a valid dialect namespace.
Definition: Dialect.cpp:95
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if we allow operations to be created for unregistered dialects.
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:182
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:203
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:198
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:193
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:208
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns whether this type is ranked, i.e. whether it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
AttrTypeReplacer.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:356
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:643
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the `detail` namespace directly.
Definition: AffineExpr.cpp:619
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:629