MLIR  21.0.0git
BuiltinTypes.cpp
Go to the documentation of this file.
1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/TensorEncoding.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 
25 using namespace mlir;
26 using namespace mlir::detail;
27 
28 //===----------------------------------------------------------------------===//
29 /// Tablegen Type Definitions
30 //===----------------------------------------------------------------------===//
31 
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
34 
35 namespace mlir {
36 #include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
37 } // namespace mlir
38 
39 //===----------------------------------------------------------------------===//
40 // BuiltinDialect
41 //===----------------------------------------------------------------------===//
42 
// Registers the builtin types with this dialect. The type list is not spelled
// out here: GET_TYPEDEF_LIST is defined immediately before including
// BuiltinTypes.cpp.inc, and the included text sits between `addTypes<` and
// `>()`, i.e. it supplies the template arguments (presumably a comma-separated
// list of the generated type classes -- confirm against the .inc file).
43 void BuiltinDialect::registerTypes() {
44  addTypes<
45 #define GET_TYPEDEF_LIST
46 #include "mlir/IR/BuiltinTypes.cpp.inc"
47  >();
48 }
49 
50 //===----------------------------------------------------------------------===//
51 /// ComplexType
52 //===----------------------------------------------------------------------===//
53 
54 /// Verify the construction of a complex type.
56  Type elementType) {
57  if (!elementType.isIntOrFloat())
58  return emitError() << "invalid element type for complex";
59  return success();
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // Integer Type
64 //===----------------------------------------------------------------------===//
65 
66 /// Verify the construction of an integer type.
68  unsigned width,
69  SignednessSemantics signedness) {
70  if (width > IntegerType::kMaxWidth) {
71  return emitError() << "integer bitwidth is limited to "
72  << IntegerType::kMaxWidth << " bits";
73  }
74  return success();
75 }
76 
77 unsigned IntegerType::getWidth() const { return getImpl()->width; }
78 
79 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
80  return getImpl()->signedness;
81 }
82 
83 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
84  if (!scale)
85  return IntegerType();
86  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Float Types
91 //===----------------------------------------------------------------------===//
92 
// Each FLOAT_TYPE_SEMANTICS(TYPE, SEM) invocation below defines
// `TYPE::getFloatSemantics()` so that it returns `APFloat::SEM()`. The helper
// macro is #undef'd right after the last invocation so it does not remain
// defined for the rest of the file.
93 // Mapping from MLIR FloatType to APFloat semantics.
94 #define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
95  const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
96  return APFloat::SEM(); \
97  }
98 FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
99 FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
100 FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
101 FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
102 FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
103 FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
104 FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
105 FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
106 FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
107 FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
108 FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
109 FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
110 FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
111 FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
112 FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
113 FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
114 FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
115 FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
116 #undef FLOAT_TYPE_SEMANTICS
117 
118 FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
119  if (scale == 2)
120  return Float32Type::get(getContext());
121  if (scale == 4)
122  return Float64Type::get(getContext());
123  return FloatType();
124 }
125 
126 FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
127  if (scale == 2)
128  return Float32Type::get(getContext());
129  if (scale == 4)
130  return Float64Type::get(getContext());
131  return FloatType();
132 }
133 
134 FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
135  if (scale == 2)
136  return Float64Type::get(getContext());
137  return FloatType();
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // FunctionType
142 //===----------------------------------------------------------------------===//
143 
144 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
145 
146 ArrayRef<Type> FunctionType::getInputs() const {
147  return getImpl()->getInputs();
148 }
149 
150 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
151 
152 ArrayRef<Type> FunctionType::getResults() const {
153  return getImpl()->getResults();
154 }
155 
156 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
157  return get(getContext(), inputs, results);
158 }
159 
160 /// Returns a new function type with the specified arguments and results
161 /// inserted.
162 FunctionType FunctionType::getWithArgsAndResults(
163  ArrayRef<unsigned> argIndices, TypeRange argTypes,
164  ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
165  SmallVector<Type> argStorage, resultStorage;
166  TypeRange newArgTypes =
167  insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
168  TypeRange newResultTypes =
169  insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
170  return clone(newArgTypes, newResultTypes);
171 }
172 
173 /// Returns a new function type without the specified arguments and results.
174 FunctionType
175 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
176  const BitVector &resultIndices) {
177  SmallVector<Type> argStorage, resultStorage;
178  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
179  TypeRange newResultTypes =
180  filterTypesOut(getResults(), resultIndices, resultStorage);
181  return clone(newArgTypes, newResultTypes);
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // OpaqueType
186 //===----------------------------------------------------------------------===//
187 
188 /// Verify the construction of an opaque type.
190  StringAttr dialect, StringRef typeData) {
191  if (!Dialect::isValidNamespace(dialect.strref()))
192  return emitError() << "invalid dialect namespace '" << dialect << "'";
193 
194  // Check that the dialect is actually registered.
195  MLIRContext *context = dialect.getContext();
196  if (!context->allowsUnregisteredDialects() &&
197  !context->getLoadedDialect(dialect.strref())) {
198  return emitError()
199  << "`!" << dialect << "<\"" << typeData << "\">"
200  << "` type created with unregistered dialect. If this is "
201  "intended, please call allowUnregisteredDialects() on the "
202  "MLIRContext, or use -allow-unregistered-dialect with "
203  "the MLIR opt tool used";
204  }
205 
206  return success();
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // VectorType
211 //===----------------------------------------------------------------------===//
212 
213 bool VectorType::isValidElementType(Type t) {
214  return isValidVectorTypeElementType(t);
215 }
216 
218  ArrayRef<int64_t> shape, Type elementType,
219  ArrayRef<bool> scalableDims) {
220  if (!isValidElementType(elementType))
221  return emitError()
222  << "vector elements must be int/index/float type but got "
223  << elementType;
224 
225  if (any_of(shape, [](int64_t i) { return i <= 0; }))
226  return emitError()
227  << "vector types must have positive constant sizes but got "
228  << shape;
229 
230  if (scalableDims.size() != shape.size())
231  return emitError() << "number of dims must match, got "
232  << scalableDims.size() << " and " << shape.size();
233 
234  return success();
235 }
236 
237 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
238  if (!scale)
239  return VectorType();
240  if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
241  if (auto scaledEt = et.scaleElementBitwidth(scale))
242  return VectorType::get(getShape(), scaledEt, getScalableDims());
243  if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
244  if (auto scaledEt = et.scaleElementBitwidth(scale))
245  return VectorType::get(getShape(), scaledEt, getScalableDims());
246  return VectorType();
247 }
248 
249 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
250  Type elementType) const {
251  return VectorType::get(shape.value_or(getShape()), elementType,
252  getScalableDims());
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // TensorType
257 //===----------------------------------------------------------------------===//
258 
261  .Case<RankedTensorType, UnrankedTensorType>(
262  [](auto type) { return type.getElementType(); });
263 }
264 
265 bool TensorType::hasRank() const {
266  return !llvm::isa<UnrankedTensorType>(*this);
267 }
268 
270  return llvm::cast<RankedTensorType>(*this).getShape();
271 }
272 
274  Type elementType) const {
275  if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
276  if (shape)
277  return RankedTensorType::get(*shape, elementType);
278  return UnrankedTensorType::get(elementType);
279  }
280 
281  auto rankedTy = llvm::cast<RankedTensorType>(*this);
282  if (!shape)
283  return RankedTensorType::get(rankedTy.getShape(), elementType,
284  rankedTy.getEncoding());
285  return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
286  rankedTy.getEncoding());
287 }
288 
289 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
290  Type elementType) const {
291  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
292 }
293 
294 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
295  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
296 }
297 
298 // Check if "elementType" can be an element type of a tensor.
299 static LogicalResult
301  Type elementType) {
302  if (!TensorType::isValidElementType(elementType))
303  return emitError() << "invalid tensor element type: " << elementType;
304  return success();
305 }
306 
307 /// Return true if the specified element type is ok in a tensor.
309  // Note: Non standard/builtin types are allowed to exist within tensor
310  // types. Dialects are expected to verify that tensor types have a valid
311  // element type within that dialect.
312  return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
313  IndexType>(type) ||
314  !llvm::isa<BuiltinDialect>(type.getDialect());
315 }
316 
317 //===----------------------------------------------------------------------===//
318 // RankedTensorType
319 //===----------------------------------------------------------------------===//
320 
321 LogicalResult
323  ArrayRef<int64_t> shape, Type elementType,
324  Attribute encoding) {
325  for (int64_t s : shape)
326  if (s < 0 && !ShapedType::isDynamic(s))
327  return emitError() << "invalid tensor dimension size";
328  if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
329  if (failed(v.verifyEncoding(shape, elementType, emitError)))
330  return failure();
331  return checkTensorElementType(emitError, elementType);
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // UnrankedTensorType
336 //===----------------------------------------------------------------------===//
337 
338 LogicalResult
340  Type elementType) {
341  return checkTensorElementType(emitError, elementType);
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // BaseMemRefType
346 //===----------------------------------------------------------------------===//
347 
350  .Case<MemRefType, UnrankedMemRefType>(
351  [](auto type) { return type.getElementType(); });
352 }
353 
355  return !llvm::isa<UnrankedMemRefType>(*this);
356 }
357 
359  return llvm::cast<MemRefType>(*this).getShape();
360 }
361 
363  Type elementType) const {
364  if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
365  if (!shape)
366  return UnrankedMemRefType::get(elementType, getMemorySpace());
367  MemRefType::Builder builder(*shape, elementType);
368  builder.setMemorySpace(getMemorySpace());
369  return builder;
370  }
371 
372  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
373  if (shape)
374  builder.setShape(*shape);
375  builder.setElementType(elementType);
376  return builder;
377 }
378 
380  Type elementType) const {
381  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
382 }
383 
384 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
385  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
386 }
387 
389  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
390  return rankedMemRefTy.getMemorySpace();
391  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
392 }
393 
395  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
396  return rankedMemRefTy.getMemorySpaceAsInt();
397  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
398 }
399 
400 //===----------------------------------------------------------------------===//
401 // MemRefType
402 //===----------------------------------------------------------------------===//
403 
404 std::optional<llvm::SmallDenseSet<unsigned>>
406  ArrayRef<int64_t> reducedShape,
407  bool matchDynamic) {
408  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
409  llvm::SmallDenseSet<unsigned> unusedDims;
410  unsigned reducedIdx = 0;
411  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
412  // Greedily insert `originalIdx` if match.
413  int64_t origSize = originalShape[originalIdx];
414  // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
415  if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
416  (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
417  ShapedType::isDynamic(origSize))) {
418  reducedIdx++;
419  continue;
420  }
421  if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
422  reducedIdx++;
423  continue;
424  }
425 
426  unusedDims.insert(originalIdx);
427  // If no match on `originalIdx`, the `originalShape` at this dimension
428  // must be 1, otherwise we bail.
429  if (origSize != 1)
430  return std::nullopt;
431  }
432  // The whole reducedShape must be scanned, otherwise we bail.
433  if (reducedIdx != reducedRank)
434  return std::nullopt;
435  return unusedDims;
436 }
437 
439 mlir::isRankReducedType(ShapedType originalType,
440  ShapedType candidateReducedType) {
441  if (originalType == candidateReducedType)
443 
444  ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
445  ShapedType candidateReducedShapedType =
446  llvm::cast<ShapedType>(candidateReducedType);
447 
448  // Rank and size logic is valid for all ShapedTypes.
449  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
450  ArrayRef<int64_t> candidateReducedShape =
451  candidateReducedShapedType.getShape();
452  unsigned originalRank = originalShape.size(),
453  candidateReducedRank = candidateReducedShape.size();
454  if (candidateReducedRank > originalRank)
456 
457  auto optionalUnusedDimsMask =
458  computeRankReductionMask(originalShape, candidateReducedShape);
459 
460  // Sizes cannot be matched in case empty vector is returned.
461  if (!optionalUnusedDimsMask)
463 
464  if (originalShapedType.getElementType() !=
465  candidateReducedShapedType.getElementType())
467 
469 }
470 
472  // Empty attribute is allowed as default memory space.
473  if (!memorySpace)
474  return true;
475 
476  // Supported built-in attributes.
477  if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
478  return true;
479 
480  // Allow custom dialect attributes.
481  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
482  return true;
483 
484  return false;
485 }
486 
488  MLIRContext *ctx) {
489  if (memorySpace == 0)
490  return nullptr;
491 
492  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
493 }
494 
496  IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
497  if (intMemorySpace && intMemorySpace.getValue() == 0)
498  return nullptr;
499 
500  return memorySpace;
501 }
502 
504  if (!memorySpace)
505  return 0;
506 
507  assert(llvm::isa<IntegerAttr>(memorySpace) &&
508  "Using `getMemorySpaceInteger` with non-Integer attribute");
509 
510  return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
511 }
512 
513 unsigned MemRefType::getMemorySpaceAsInt() const {
514  return detail::getMemorySpaceAsInt(getMemorySpace());
515 }
516 
517 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
518  MemRefLayoutAttrInterface layout,
519  Attribute memorySpace) {
520  // Use default layout for empty attribute.
521  if (!layout)
523  shape.size(), elementType.getContext()));
524 
525  // Drop default memory space value and replace it with empty attribute.
526  memorySpace = skipDefaultMemorySpace(memorySpace);
527 
528  return Base::get(elementType.getContext(), shape, elementType, layout,
529  memorySpace);
530 }
531 
532 MemRefType MemRefType::getChecked(
533  function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
534  Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
535 
536  // Use default layout for empty attribute.
537  if (!layout)
539  shape.size(), elementType.getContext()));
540 
541  // Drop default memory space value and replace it with empty attribute.
542  memorySpace = skipDefaultMemorySpace(memorySpace);
543 
544  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
545  elementType, layout, memorySpace);
546 }
547 
548 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
549  AffineMap map, Attribute memorySpace) {
550 
551  // Use default layout for empty map.
552  if (!map)
553  map = AffineMap::getMultiDimIdentityMap(shape.size(),
554  elementType.getContext());
555 
556  // Wrap AffineMap into Attribute.
557  auto layout = AffineMapAttr::get(map);
558 
559  // Drop default memory space value and replace it with empty attribute.
560  memorySpace = skipDefaultMemorySpace(memorySpace);
561 
562  return Base::get(elementType.getContext(), shape, elementType, layout,
563  memorySpace);
564 }
565 
566 MemRefType
567 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
568  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
569  Attribute memorySpace) {
570 
571  // Use default layout for empty map.
572  if (!map)
573  map = AffineMap::getMultiDimIdentityMap(shape.size(),
574  elementType.getContext());
575 
576  // Wrap AffineMap into Attribute.
577  auto layout = AffineMapAttr::get(map);
578 
579  // Drop default memory space value and replace it with empty attribute.
580  memorySpace = skipDefaultMemorySpace(memorySpace);
581 
582  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
583  elementType, layout, memorySpace);
584 }
585 
586 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
587  AffineMap map, unsigned memorySpaceInd) {
588 
589  // Use default layout for empty map.
590  if (!map)
591  map = AffineMap::getMultiDimIdentityMap(shape.size(),
592  elementType.getContext());
593 
594  // Wrap AffineMap into Attribute.
595  auto layout = AffineMapAttr::get(map);
596 
597  // Convert deprecated integer-like memory space to Attribute.
598  Attribute memorySpace =
599  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
600 
601  return Base::get(elementType.getContext(), shape, elementType, layout,
602  memorySpace);
603 }
604 
605 MemRefType
606 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
607  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
608  unsigned memorySpaceInd) {
609 
610  // Use default layout for empty map.
611  if (!map)
612  map = AffineMap::getMultiDimIdentityMap(shape.size(),
613  elementType.getContext());
614 
615  // Wrap AffineMap into Attribute.
616  auto layout = AffineMapAttr::get(map);
617 
618  // Convert deprecated integer-like memory space to Attribute.
619  Attribute memorySpace =
620  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
621 
622  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
623  elementType, layout, memorySpace);
624 }
625 
627  ArrayRef<int64_t> shape, Type elementType,
628  MemRefLayoutAttrInterface layout,
629  Attribute memorySpace) {
630  if (!BaseMemRefType::isValidElementType(elementType))
631  return emitError() << "invalid memref element type";
632 
633  // Negative sizes are not allowed except for `kDynamic`.
634  for (int64_t s : shape)
635  if (s < 0 && !ShapedType::isDynamic(s))
636  return emitError() << "invalid memref size";
637 
638  assert(layout && "missing layout specification");
639  if (failed(layout.verifyLayout(shape, emitError)))
640  return failure();
641 
642  if (!isSupportedMemorySpace(memorySpace))
643  return emitError() << "unsupported memory space Attribute";
644 
645  return success();
646 }
647 
648 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
649  if (!isLastDimUnitStride())
650  return false;
651 
652  auto memrefShape = getShape().take_back(n);
653  if (ShapedType::isDynamicShape(memrefShape))
654  return false;
655 
656  if (getLayout().isIdentity())
657  return true;
658 
659  int64_t offset;
660  SmallVector<int64_t> stridesFull;
661  if (!succeeded(getStridesAndOffset(stridesFull, offset)))
662  return false;
663  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
664 
665  if (strides.empty())
666  return true;
667 
668  // Check whether strides match "flattened" dims.
669  SmallVector<int64_t> flattenedDims;
670  auto dimProduct = 1;
671  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
672  dimProduct *= dim;
673  flattenedDims.push_back(dimProduct);
674  }
675 
676  strides = strides.drop_back(1);
677  return llvm::equal(strides, llvm::reverse(flattenedDims));
678 }
679 
680 MemRefType MemRefType::canonicalizeStridedLayout() {
681  AffineMap m = getLayout().getAffineMap();
682 
683  // Already in canonical form.
684  if (m.isIdentity())
685  return *this;
686 
687  // Can't reduce to canonical identity form, return in canonical form.
688  if (m.getNumResults() > 1)
689  return *this;
690 
691  // Corner-case for 0-D affine maps.
692  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
693  if (auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0)))
694  if (cst.getValue() == 0)
695  return MemRefType::Builder(*this).setLayout({});
696  return *this;
697  }
698 
699  // 0-D corner case for empty shape that still have an affine map. Example:
700  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
701  // offset needs to remain, just return t.
702  if (getShape().empty())
703  return *this;
704 
705  // If the canonical strided layout for the sizes of `t` is equal to the
706  // simplified layout of `t` we can just return an empty layout. Otherwise,
707  // just simplify the existing layout.
709  auto simplifiedLayoutExpr =
710  simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
711  if (expr != simplifiedLayoutExpr)
712  return MemRefType::Builder(*this).setLayout(
713  AffineMapAttr::get(AffineMap::get(m.getNumDims(), m.getNumSymbols(),
714  simplifiedLayoutExpr)));
715  return MemRefType::Builder(*this).setLayout({});
716 }
717 
718 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
719 // i.e. single term). Accumulate the AffineExpr into the existing one.
721  AffineExpr multiplicativeFactor,
723  AffineExpr &offset) {
724  if (auto dim = dyn_cast<AffineDimExpr>(e))
725  strides[dim.getPosition()] =
726  strides[dim.getPosition()] + multiplicativeFactor;
727  else
728  offset = offset + e * multiplicativeFactor;
729 }
730 
731 /// Takes a single AffineExpr `e` and populates the `strides` array with the
732 /// strides expressions for each dim position.
733 /// The convention is that the strides for dimensions d0, .. dn appear in
734 /// order to make indexing intuitive into the result.
735 static LogicalResult extractStrides(AffineExpr e,
736  AffineExpr multiplicativeFactor,
738  AffineExpr &offset) {
739  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
740  if (!bin) {
741  extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
742  return success();
743  }
744 
745  if (bin.getKind() == AffineExprKind::CeilDiv ||
746  bin.getKind() == AffineExprKind::FloorDiv ||
747  bin.getKind() == AffineExprKind::Mod)
748  return failure();
749 
750  if (bin.getKind() == AffineExprKind::Mul) {
751  auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
752  if (dim) {
753  strides[dim.getPosition()] =
754  strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
755  return success();
756  }
757  // LHS and RHS may both contain complex expressions of dims. Try one path
758  // and if it fails try the other. This is guaranteed to succeed because
759  // only one path may have a `dim`, otherwise this is not an AffineExpr in
760  // the first place.
761  if (bin.getLHS().isSymbolicOrConstant())
762  return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
763  strides, offset);
764  return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
765  strides, offset);
766  }
767 
768  if (bin.getKind() == AffineExprKind::Add) {
769  auto res1 =
770  extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
771  auto res2 =
772  extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
773  return success(succeeded(res1) && succeeded(res2));
774  }
775 
776  llvm_unreachable("unexpected binary operation");
777 }
778 
779 /// A stride specification is a list of integer values that are either static
780 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode
781 /// the distance in the number of elements between successive entries along a
782 /// particular dimension.
783 ///
784 /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
785 /// non-contiguous memory region of `42` by `16` `f32` elements in which the
786 /// distance between two consecutive elements along the outer dimension is `1`
787 /// and the distance between two consecutive elements along the inner dimension
788 /// is `64`.
789 ///
790 /// The convention is that the strides for dimensions d0, .. dn appear in
791 /// order to make indexing intuitive into the result.
792 static LogicalResult getStridesAndOffset(MemRefType t,
793  SmallVectorImpl<AffineExpr> &strides,
794  AffineExpr &offset) {
795  AffineMap m = t.getLayout().getAffineMap();
796 
797  if (m.getNumResults() != 1 && !m.isIdentity())
798  return failure();
799 
800  auto zero = getAffineConstantExpr(0, t.getContext());
801  auto one = getAffineConstantExpr(1, t.getContext());
802  offset = zero;
803  strides.assign(t.getRank(), zero);
804 
805  // Canonical case for empty map.
806  if (m.isIdentity()) {
807  // 0-D corner case, offset is already 0.
808  if (t.getRank() == 0)
809  return success();
810  auto stridedExpr =
811  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
812  if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
813  return success();
814  assert(false && "unexpected failure: extract strides in canonical layout");
815  }
816 
817  // Non-canonical case requires more work.
818  auto stridedExpr =
819  simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
820  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
821  offset = AffineExpr();
822  strides.clear();
823  return failure();
824  }
825 
826  // Simplify results to allow folding to constants and simple checks.
827  unsigned numDims = m.getNumDims();
828  unsigned numSymbols = m.getNumSymbols();
829  offset = simplifyAffineExpr(offset, numDims, numSymbols);
830  for (auto &stride : strides)
831  stride = simplifyAffineExpr(stride, numDims, numSymbols);
832 
833  return success();
834 }
835 
836 LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
837  int64_t &offset) {
838  // Happy path: the type uses the strided layout directly.
839  if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(getLayout())) {
840  llvm::append_range(strides, strided.getStrides());
841  offset = strided.getOffset();
842  return success();
843  }
844 
845  // Otherwise, defer to the affine fallback as layouts are supposed to be
846  // convertible to affine maps.
847  AffineExpr offsetExpr;
848  SmallVector<AffineExpr, 4> strideExprs;
849  if (failed(::getStridesAndOffset(*this, strideExprs, offsetExpr)))
850  return failure();
851  if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
852  offset = cst.getValue();
853  else
854  offset = ShapedType::kDynamic;
855  for (auto e : strideExprs) {
856  if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
857  strides.push_back(c.getValue());
858  else
859  strides.push_back(ShapedType::kDynamic);
860  }
861  return success();
862 }
863 
864 std::pair<SmallVector<int64_t>, int64_t> MemRefType::getStridesAndOffset() {
865  SmallVector<int64_t> strides;
866  int64_t offset;
867  LogicalResult status = getStridesAndOffset(strides, offset);
868  (void)status;
869  assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
870  return {strides, offset};
871 }
872 
873 bool MemRefType::isStrided() {
874  int64_t offset;
875  SmallVector<int64_t, 4> strides;
876  auto res = getStridesAndOffset(strides, offset);
877  return succeeded(res);
878 }
879 
880 bool MemRefType::isLastDimUnitStride() {
881  int64_t offset;
882  SmallVector<int64_t> strides;
883  auto successStrides = getStridesAndOffset(strides, offset);
884  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
885 }
886 
887 //===----------------------------------------------------------------------===//
888 // UnrankedMemRefType
889 //===----------------------------------------------------------------------===//
890 
892  return detail::getMemorySpaceAsInt(getMemorySpace());
893 }
894 
895 LogicalResult
897  Type elementType, Attribute memorySpace) {
898  if (!BaseMemRefType::isValidElementType(elementType))
899  return emitError() << "invalid memref element type";
900 
901  if (!isSupportedMemorySpace(memorySpace))
902  return emitError() << "unsupported memory space Attribute";
903 
904  return success();
905 }
906 
907 //===----------------------------------------------------------------------===//
908 /// TupleType
909 //===----------------------------------------------------------------------===//
910 
911 /// Return the elements types for this tuple.
912 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
913 
914 /// Accumulate the types contained in this tuple and tuples nested within it.
915 /// Note that this only flattens nested tuples, not any other container type,
916 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
917 /// (i32, tensor<i32>, f32, i64)
918 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
919  for (Type type : getTypes()) {
920  if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
921  nestedTuple.getFlattenedTypes(types);
922  else
923  types.push_back(type);
924  }
925 }
926 
927 /// Return the number of element types.
928 size_t TupleType::size() const { return getImpl()->size(); }
929 
930 //===----------------------------------------------------------------------===//
931 // Type Utilities
932 //===----------------------------------------------------------------------===//
933 
935  ArrayRef<AffineExpr> exprs,
936  MLIRContext *context) {
937  // Size 0 corner case is useful for canonicalizations.
938  if (sizes.empty())
939  return getAffineConstantExpr(0, context);
940 
941  assert(!exprs.empty() && "expected exprs");
942  auto maps = AffineMap::inferFromExprList(exprs, context);
943  assert(!maps.empty() && "Expected one non-empty map");
944  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
945 
946  AffineExpr expr;
947  bool dynamicPoisonBit = false;
948  int64_t runningSize = 1;
949  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
950  int64_t size = std::get<1>(en);
951  AffineExpr dimExpr = std::get<0>(en);
952  AffineExpr stride = dynamicPoisonBit
953  ? getAffineSymbolExpr(nSymbols++, context)
954  : getAffineConstantExpr(runningSize, context);
955  expr = expr ? expr + dimExpr * stride : dimExpr * stride;
956  if (size > 0) {
957  runningSize *= size;
958  assert(runningSize > 0 && "integer overflow in size computation");
959  } else {
960  dynamicPoisonBit = true;
961  }
962  }
963  return simplifyAffineExpr(expr, numDims, nSymbols);
964 }
965 
967  MLIRContext *context) {
969  exprs.reserve(sizes.size());
970  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
971  exprs.push_back(getAffineDimExpr(dim, context));
972  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
973 }
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
static LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:76
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:102
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Definition: BuiltinTypes.h:397
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns whether this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:98
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if creating operations for unregistered dialects is allowed.
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:166
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:187
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:182
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:177
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:192
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns whether this type is ranked, i.e. it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:108
AttrTypeReplacer.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:340
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:627