MLIR  20.0.0git
BuiltinTypes.cpp
1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/BuiltinAttributes.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/TensorEncoding.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 
25 using namespace mlir;
26 using namespace mlir::detail;
27 
28 //===----------------------------------------------------------------------===//
29 /// Tablegen Type Definitions
30 //===----------------------------------------------------------------------===//
31 
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
34 
35 namespace mlir {
36 #include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
37 } // namespace mlir
38 
39 //===----------------------------------------------------------------------===//
40 // BuiltinDialect
41 //===----------------------------------------------------------------------===//
42 
43 void BuiltinDialect::registerTypes() {
44  addTypes<
45 #define GET_TYPEDEF_LIST
46 #include "mlir/IR/BuiltinTypes.cpp.inc"
47  >();
48 }
49 
50 //===----------------------------------------------------------------------===//
51 /// ComplexType
52 //===----------------------------------------------------------------------===//
53 
54 /// Verify the construction of a complex type.
55 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
56  Type elementType) {
57  if (!elementType.isIntOrFloat())
58  return emitError() << "invalid element type for complex";
59  return success();
60 }
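// Illustrative note (annotation, not part of the upstream source): only
// integer and float element types pass this check, so `complex<f32>` and
// `complex<i64>` verify, while e.g. `complex<index>` or `complex<vector<4xf32>>`
// are rejected with "invalid element type for complex".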
61 
62 //===----------------------------------------------------------------------===//
63 // Integer Type
64 //===----------------------------------------------------------------------===//
65 
66 /// Verify the construction of an integer type.
67 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
68  unsigned width,
69  SignednessSemantics signedness) {
70  if (width > IntegerType::kMaxWidth) {
71  return emitError() << "integer bitwidth is limited to "
72  << IntegerType::kMaxWidth << " bits";
73  }
74  return success();
75 }
76 
77 unsigned IntegerType::getWidth() const { return getImpl()->width; }
78 
79 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
80  return getImpl()->signedness;
81 }
82 
83 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
84  if (!scale)
85  return IntegerType();
86  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
87 }
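// Illustrative sketch (annotation, not part of the upstream source), assuming
// a valid MLIRContext `ctx`:
//
//   IntegerType i8 = IntegerType::get(ctx, 8);
//   IntegerType i32 = i8.scaleElementBitwidth(4); // signless i32
//   IntegerType bad = i8.scaleElementBitwidth(0); // null IntegerType
//
// A zero scale returns a null type, which callers are expected to check.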
88 
89 //===----------------------------------------------------------------------===//
90 // Float Type
91 //===----------------------------------------------------------------------===//
92 
93 unsigned FloatType::getWidth() {
94  // The actual width of TF32 is 19 bits. However, since it is a truncated
95  // version of Float32, we treat it as 32 bits in MLIR FloatType::getWidth
96  // for compatibility.
97  if (llvm::isa<FloatTF32Type>(*this))
98  return 32;
99  return APFloat::semanticsSizeInBits(getFloatSemantics());
100 }
101 
102 /// Returns the floating semantics for the given type.
103 const llvm::fltSemantics &FloatType::getFloatSemantics() {
104  if (llvm::isa<Float6E3M2FNType>(*this))
105  return APFloat::Float6E3M2FN();
106  if (llvm::isa<Float8E5M2Type>(*this))
107  return APFloat::Float8E5M2();
108  if (llvm::isa<Float8E4M3Type>(*this))
109  return APFloat::Float8E4M3();
110  if (llvm::isa<Float8E4M3FNType>(*this))
111  return APFloat::Float8E4M3FN();
112  if (llvm::isa<Float8E5M2FNUZType>(*this))
113  return APFloat::Float8E5M2FNUZ();
114  if (llvm::isa<Float8E4M3FNUZType>(*this))
115  return APFloat::Float8E4M3FNUZ();
116  if (llvm::isa<Float8E4M3B11FNUZType>(*this))
117  return APFloat::Float8E4M3B11FNUZ();
118  if (llvm::isa<Float8E3M4Type>(*this))
119  return APFloat::Float8E3M4();
120  if (llvm::isa<BFloat16Type>(*this))
121  return APFloat::BFloat();
122  if (llvm::isa<Float16Type>(*this))
123  return APFloat::IEEEhalf();
124  if (llvm::isa<FloatTF32Type>(*this))
125  return APFloat::FloatTF32();
126  if (llvm::isa<Float32Type>(*this))
127  return APFloat::IEEEsingle();
128  if (llvm::isa<Float64Type>(*this))
129  return APFloat::IEEEdouble();
130  if (llvm::isa<Float80Type>(*this))
131  return APFloat::x87DoubleExtended();
132  if (llvm::isa<Float128Type>(*this))
133  return APFloat::IEEEquad();
134  llvm_unreachable("non-floating point type used");
135 }
136 
137 FloatType FloatType::scaleElementBitwidth(unsigned scale) {
138  if (!scale)
139  return FloatType();
140  MLIRContext *ctx = getContext();
141  if (isF16() || isBF16()) {
142  if (scale == 2)
143  return FloatType::getF32(ctx);
144  if (scale == 4)
145  return FloatType::getF64(ctx);
146  }
147  if (isF32())
148  if (scale == 2)
149  return FloatType::getF64(ctx);
150  return FloatType();
151 }
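// Illustrative behaviour (annotation, not part of the upstream source):
// scaling is only defined for a few fixed combinations; anything else yields a
// null FloatType.
//
//   f16.scaleElementBitwidth(2)  --> f32
//   bf16.scaleElementBitwidth(4) --> f64
//   f32.scaleElementBitwidth(2)  --> f64
//   f64.scaleElementBitwidth(2)  --> null FloatType (unsupported)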
152 
153 unsigned FloatType::getFPMantissaWidth() {
154  return APFloat::semanticsPrecision(getFloatSemantics());
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // FunctionType
159 //===----------------------------------------------------------------------===//
160 
161 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
162 
163 ArrayRef<Type> FunctionType::getInputs() const {
164  return getImpl()->getInputs();
165 }
166 
167 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
168 
169 ArrayRef<Type> FunctionType::getResults() const {
170  return getImpl()->getResults();
171 }
172 
173 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
174  return get(getContext(), inputs, results);
175 }
176 
177 /// Returns a new function type with the specified arguments and results
178 /// inserted.
179 FunctionType FunctionType::getWithArgsAndResults(
180  ArrayRef<unsigned> argIndices, TypeRange argTypes,
181  ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
182  SmallVector<Type> argStorage, resultStorage;
183  TypeRange newArgTypes =
184  insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
185  TypeRange newResultTypes =
186  insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
187  return clone(newArgTypes, newResultTypes);
188 }
189 
190 /// Returns a new function type without the specified arguments and results.
191 FunctionType
192 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
193  const BitVector &resultIndices) {
194  SmallVector<Type> argStorage, resultStorage;
195  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
196  TypeRange newResultTypes =
197  filterTypesOut(getResults(), resultIndices, resultStorage);
198  return clone(newArgTypes, newResultTypes);
199 }
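// Illustrative example (annotation, not part of the upstream source) of the
// two helpers above, written schematically in MLIR type notation:
//
//   (i32, f32) -> (index)
//     getWithArgsAndResults(argIndices={0}, argTypes={i1},
//                           resultIndices={0}, resultTypes={f64})
//       --> (i1, i32, f32) -> (f64, index)
//     getWithoutArgsAndResults(argIndices={1}, resultIndices={0})
//       --> (i32) -> ()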
200 
201 //===----------------------------------------------------------------------===//
202 // OpaqueType
203 //===----------------------------------------------------------------------===//
204 
205 /// Verify the construction of an opaque type.
206 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
207  StringAttr dialect, StringRef typeData) {
208  if (!Dialect::isValidNamespace(dialect.strref()))
209  return emitError() << "invalid dialect namespace '" << dialect << "'";
210 
211  // Check that the dialect is actually registered.
212  MLIRContext *context = dialect.getContext();
213  if (!context->allowsUnregisteredDialects() &&
214  !context->getLoadedDialect(dialect.strref())) {
215  return emitError()
216  << "`!" << dialect << "<\"" << typeData << "\">"
217  << "` type created with unregistered dialect. If this is "
218  "intended, please call allowUnregisteredDialects() on the "
219  "MLIRContext, or use -allow-unregistered-dialect with "
220  "the MLIR opt tool used";
221  }
222 
223  return success();
224 }
225 
226 //===----------------------------------------------------------------------===//
227 // VectorType
228 //===----------------------------------------------------------------------===//
229 
230 bool VectorType::isValidElementType(Type t) {
231  return isValidVectorTypeElementType(t);
232 }
233 
234 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
235  ArrayRef<int64_t> shape, Type elementType,
236  ArrayRef<bool> scalableDims) {
237  if (!isValidElementType(elementType))
238  return emitError()
239  << "vector elements must be int/index/float type but got "
240  << elementType;
241 
242  if (any_of(shape, [](int64_t i) { return i <= 0; }))
243  return emitError()
244  << "vector types must have positive constant sizes but got "
245  << shape;
246 
247  if (scalableDims.size() != shape.size())
248  return emitError() << "number of dims must match, got "
249  << scalableDims.size() << " and " << shape.size();
250 
251  return success();
252 }
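// Illustrative examples (annotation, not part of the upstream source) of types
// this verifier accepts and rejects:
//
//   vector<4x8xf32>          OK
//   vector<[4]x8xi1>         OK (leading dimension is scalable)
//   vector<4xmemref<2xf32>>  error: elements must be int/index/float
//   vector<0xf32>            error: sizes must be positive constants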
253 
254 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
255  if (!scale)
256  return VectorType();
257  if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
258  if (auto scaledEt = et.scaleElementBitwidth(scale))
259  return VectorType::get(getShape(), scaledEt, getScalableDims());
260  if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
261  if (auto scaledEt = et.scaleElementBitwidth(scale))
262  return VectorType::get(getShape(), scaledEt, getScalableDims());
263  return VectorType();
264 }
265 
266 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
267  Type elementType) const {
268  return VectorType::get(shape.value_or(getShape()), elementType,
269  getScalableDims());
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // TensorType
274 //===----------------------------------------------------------------------===//
275 
276 Type TensorType::getElementType() const {
277  return llvm::TypeSwitch<TensorType, Type>(*this)
278  .Case<RankedTensorType, UnrankedTensorType>(
279  [](auto type) { return type.getElementType(); });
280 }
281 
282 bool TensorType::hasRank() const {
283  return !llvm::isa<UnrankedTensorType>(*this);
284 }
285 
286 ArrayRef<int64_t> TensorType::getShape() const {
287  return llvm::cast<RankedTensorType>(*this).getShape();
288 }
289 
290 TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
291  Type elementType) const {
292  if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
293  if (shape)
294  return RankedTensorType::get(*shape, elementType);
295  return UnrankedTensorType::get(elementType);
296  }
297 
298  auto rankedTy = llvm::cast<RankedTensorType>(*this);
299  if (!shape)
300  return RankedTensorType::get(rankedTy.getShape(), elementType,
301  rankedTy.getEncoding());
302  return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
303  rankedTy.getEncoding());
304 }
305 
306 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
307  Type elementType) const {
308  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
309 }
310 
311 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
312  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
313 }
314 
315 // Check if "elementType" can be an element type of a tensor.
316 static LogicalResult
317 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
318  Type elementType) {
319  if (!TensorType::isValidElementType(elementType))
320  return emitError() << "invalid tensor element type: " << elementType;
321  return success();
322 }
323 
324 /// Return true if the specified element type is ok in a tensor.
325 bool TensorType::isValidElementType(Type type) {
326  // Note: Non standard/builtin types are allowed to exist within tensor
327  // types. Dialects are expected to verify that tensor types have a valid
328  // element type within that dialect.
329  return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
330  IndexType>(type) ||
331  !llvm::isa<BuiltinDialect>(type.getDialect());
332 }
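// Illustrative consequence (annotation, not part of the upstream source):
// builtin element types are allowed only if listed above, while any type from
// another dialect passes this check and is left to that dialect to verify.
// For example, tensor<4xf32> and tensor<2x!llvm.ptr> are accepted here, but
// tensor<4xmemref<2xf32>> is not, because MemRefType is a builtin type outside
// the allowed list.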
333 
334 //===----------------------------------------------------------------------===//
335 // RankedTensorType
336 //===----------------------------------------------------------------------===//
337 
338 LogicalResult
339 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
340  ArrayRef<int64_t> shape, Type elementType,
341  Attribute encoding) {
342  for (int64_t s : shape)
343  if (s < 0 && !ShapedType::isDynamic(s))
344  return emitError() << "invalid tensor dimension size";
345  if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
346  if (failed(v.verifyEncoding(shape, elementType, emitError)))
347  return failure();
348  return checkTensorElementType(emitError, elementType);
349 }
350 
351 //===----------------------------------------------------------------------===//
352 // UnrankedTensorType
353 //===----------------------------------------------------------------------===//
354 
355 LogicalResult
356 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
357  Type elementType) {
358  return checkTensorElementType(emitError, elementType);
359 }
360 
361 //===----------------------------------------------------------------------===//
362 // BaseMemRefType
363 //===----------------------------------------------------------------------===//
364 
365 Type BaseMemRefType::getElementType() const {
366  return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
367  .Case<MemRefType, UnrankedMemRefType>(
368  [](auto type) { return type.getElementType(); });
369 }
370 
371 bool BaseMemRefType::hasRank() const {
372  return !llvm::isa<UnrankedMemRefType>(*this);
373 }
374 
375 ArrayRef<int64_t> BaseMemRefType::getShape() const {
376  return llvm::cast<MemRefType>(*this).getShape();
377 }
378 
379 BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
380  Type elementType) const {
381  if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
382  if (!shape)
383  return UnrankedMemRefType::get(elementType, getMemorySpace());
384  MemRefType::Builder builder(*shape, elementType);
385  builder.setMemorySpace(getMemorySpace());
386  return builder;
387  }
388 
389  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
390  if (shape)
391  builder.setShape(*shape);
392  builder.setElementType(elementType);
393  return builder;
394 }
395 
396 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
397  Type elementType) const {
398  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
399 }
400 
401 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
402  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
403 }
404 
405 Attribute BaseMemRefType::getMemorySpace() const {
406  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
407  return rankedMemRefTy.getMemorySpace();
408  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
409 }
410 
411 unsigned BaseMemRefType::getMemorySpaceAsInt() const {
412  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
413  return rankedMemRefTy.getMemorySpaceAsInt();
414  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
415 }
416 
417 //===----------------------------------------------------------------------===//
418 // MemRefType
419 //===----------------------------------------------------------------------===//
420 
421 std::optional<llvm::SmallDenseSet<unsigned>>
422 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
423  ArrayRef<int64_t> reducedShape,
424  bool matchDynamic) {
425  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
426  llvm::SmallDenseSet<unsigned> unusedDims;
427  unsigned reducedIdx = 0;
428  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
429  // Greedily insert `originalIdx` if match.
430  int64_t origSize = originalShape[originalIdx];
431  // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
432  if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
433  (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
434  ShapedType::isDynamic(origSize))) {
435  reducedIdx++;
436  continue;
437  }
438  if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
439  reducedIdx++;
440  continue;
441  }
442 
443  unusedDims.insert(originalIdx);
444  // If no match on `originalIdx`, the `originalShape` at this dimension
445  // must be 1, otherwise we bail.
446  if (origSize != 1)
447  return std::nullopt;
448  }
449  // The whole reducedShape must be scanned, otherwise we bail.
450  if (reducedIdx != reducedRank)
451  return std::nullopt;
452  return unusedDims;
453 }
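// Worked example (annotation, not part of the upstream source):
//
//   originalShape = [1, 4, 1, 8], reducedShape = [4, 8]
//     --> returns the unused-dimension set {0, 2}
//   originalShape = [2, 4], reducedShape = [4]
//     --> returns std::nullopt; dimension 0 has size 2 and cannot be dropped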
454 
455 SliceVerificationResult
456 mlir::isRankReducedType(ShapedType originalType,
457  ShapedType candidateReducedType) {
458  if (originalType == candidateReducedType)
459  return SliceVerificationResult::Success;
460 
461  ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
462  ShapedType candidateReducedShapedType =
463  llvm::cast<ShapedType>(candidateReducedType);
464 
465  // Rank and size logic is valid for all ShapedTypes.
466  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
467  ArrayRef<int64_t> candidateReducedShape =
468  candidateReducedShapedType.getShape();
469  unsigned originalRank = originalShape.size(),
470  candidateReducedRank = candidateReducedShape.size();
471  if (candidateReducedRank > originalRank)
472  return SliceVerificationResult::RankTooLarge;
473 
474  auto optionalUnusedDimsMask =
475  computeRankReductionMask(originalShape, candidateReducedShape);
476 
477  // Sizes cannot be matched if no rank-reduction mask could be computed.
478  if (!optionalUnusedDimsMask)
479  return SliceVerificationResult::SizeMismatch;
480 
481  if (originalShapedType.getElementType() !=
482  candidateReducedShapedType.getElementType())
483  return SliceVerificationResult::ElemTypeMismatch;
484 
485  return SliceVerificationResult::Success;
486 }
487 
488 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
489  // Empty attribute is allowed as default memory space.
490  if (!memorySpace)
491  return true;
492 
493  // Supported built-in attributes.
494  if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
495  return true;
496 
497  // Allow custom dialect attributes.
498  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
499  return true;
500 
501  return false;
502 }
503 
504 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
505  MLIRContext *ctx) {
506  if (memorySpace == 0)
507  return nullptr;
508 
509  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
510 }
511 
512 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
513  IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
514  if (intMemorySpace && intMemorySpace.getValue() == 0)
515  return nullptr;
516 
517  return memorySpace;
518 }
519 
520 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
521  if (!memorySpace)
522  return 0;
523 
524  assert(llvm::isa<IntegerAttr>(memorySpace) &&
525  "Using `getMemorySpaceInteger` with non-Integer attribute");
526 
527  return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
528 }
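// Illustrative behaviour of the helpers above (annotation, not part of the
// upstream source), assuming a valid MLIRContext `ctx`:
//
//   wrapIntegerMemorySpace(0, ctx)        --> Attribute()   (default space)
//   wrapIntegerMemorySpace(3, ctx)        --> IntegerAttr 3 : i64
//   skipDefaultMemorySpace(IntegerAttr 0) --> Attribute()
//   getMemorySpaceAsInt(Attribute())      --> 0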
529 
530 unsigned MemRefType::getMemorySpaceAsInt() const {
531  return detail::getMemorySpaceAsInt(getMemorySpace());
532 }
533 
534 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
535  MemRefLayoutAttrInterface layout,
536  Attribute memorySpace) {
537  // Use default layout for empty attribute.
538  if (!layout)
539  layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
540  shape.size(), elementType.getContext()));
541 
542  // Drop default memory space value and replace it with empty attribute.
543  memorySpace = skipDefaultMemorySpace(memorySpace);
544 
545  return Base::get(elementType.getContext(), shape, elementType, layout,
546  memorySpace);
547 }
548 
549 MemRefType MemRefType::getChecked(
550  function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
551  Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
552 
553  // Use default layout for empty attribute.
554  if (!layout)
555  layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
556  shape.size(), elementType.getContext()));
557 
558  // Drop default memory space value and replace it with empty attribute.
559  memorySpace = skipDefaultMemorySpace(memorySpace);
560 
561  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
562  elementType, layout, memorySpace);
563 }
564 
565 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
566  AffineMap map, Attribute memorySpace) {
567 
568  // Use default layout for empty map.
569  if (!map)
570  map = AffineMap::getMultiDimIdentityMap(shape.size(),
571  elementType.getContext());
572 
573  // Wrap AffineMap into Attribute.
574  auto layout = AffineMapAttr::get(map);
575 
576  // Drop default memory space value and replace it with empty attribute.
577  memorySpace = skipDefaultMemorySpace(memorySpace);
578 
579  return Base::get(elementType.getContext(), shape, elementType, layout,
580  memorySpace);
581 }
582 
583 MemRefType
584 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
585  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
586  Attribute memorySpace) {
587 
588  // Use default layout for empty map.
589  if (!map)
590  map = AffineMap::getMultiDimIdentityMap(shape.size(),
591  elementType.getContext());
592 
593  // Wrap AffineMap into Attribute.
594  auto layout = AffineMapAttr::get(map);
595 
596  // Drop default memory space value and replace it with empty attribute.
597  memorySpace = skipDefaultMemorySpace(memorySpace);
598 
599  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
600  elementType, layout, memorySpace);
601 }
602 
603 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
604  AffineMap map, unsigned memorySpaceInd) {
605 
606  // Use default layout for empty map.
607  if (!map)
608  map = AffineMap::getMultiDimIdentityMap(shape.size(),
609  elementType.getContext());
610 
611  // Wrap AffineMap into Attribute.
612  auto layout = AffineMapAttr::get(map);
613 
614  // Convert deprecated integer-like memory space to Attribute.
615  Attribute memorySpace =
616  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
617 
618  return Base::get(elementType.getContext(), shape, elementType, layout,
619  memorySpace);
620 }
621 
622 MemRefType
623 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
624  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
625  unsigned memorySpaceInd) {
626 
627  // Use default layout for empty map.
628  if (!map)
629  map = AffineMap::getMultiDimIdentityMap(shape.size(),
630  elementType.getContext());
631 
632  // Wrap AffineMap into Attribute.
633  auto layout = AffineMapAttr::get(map);
634 
635  // Convert deprecated integer-like memory space to Attribute.
636  Attribute memorySpace =
637  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
638 
639  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
640  elementType, layout, memorySpace);
641 }
642 
643 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
644  ArrayRef<int64_t> shape, Type elementType,
645  MemRefLayoutAttrInterface layout,
646  Attribute memorySpace) {
647  if (!BaseMemRefType::isValidElementType(elementType))
648  return emitError() << "invalid memref element type";
649 
650  // Negative sizes are not allowed except for `kDynamic`.
651  for (int64_t s : shape)
652  if (s < 0 && !ShapedType::isDynamic(s))
653  return emitError() << "invalid memref size";
654 
655  assert(layout && "missing layout specification");
656  if (failed(layout.verifyLayout(shape, emitError)))
657  return failure();
658 
659  if (!isSupportedMemorySpace(memorySpace))
660  return emitError() << "unsupported memory space Attribute";
661 
662  return success();
663 }
664 
665 //===----------------------------------------------------------------------===//
666 // UnrankedMemRefType
667 //===----------------------------------------------------------------------===//
668 
669 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
670  return detail::getMemorySpaceAsInt(getMemorySpace());
671 }
672 
673 LogicalResult
674 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
675  Type elementType, Attribute memorySpace) {
676  if (!BaseMemRefType::isValidElementType(elementType))
677  return emitError() << "invalid memref element type";
678 
679  if (!isSupportedMemorySpace(memorySpace))
680  return emitError() << "unsupported memory space Attribute";
681 
682  return success();
683 }
684 
685 // Fallback case for a terminal dim/sym/cst that is not part of a binary op
686 // (i.e. a single term). Accumulate the AffineExpr into the existing one.
687 static void extractStridesFromTerm(AffineExpr e,
688  AffineExpr multiplicativeFactor,
689  MutableArrayRef<AffineExpr> strides,
690  AffineExpr &offset) {
691  if (auto dim = dyn_cast<AffineDimExpr>(e))
692  strides[dim.getPosition()] =
693  strides[dim.getPosition()] + multiplicativeFactor;
694  else
695  offset = offset + e * multiplicativeFactor;
696 }
697 
698 /// Takes a single AffineExpr `e` and populates the `strides` array with the
699 /// strides expressions for each dim position.
700 /// The convention is that the strides for dimensions d0, .. dn appear in
701 /// order to make indexing intuitive into the result.
702 static LogicalResult extractStrides(AffineExpr e,
703  AffineExpr multiplicativeFactor,
704  MutableArrayRef<AffineExpr> strides,
705  AffineExpr &offset) {
706  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
707  if (!bin) {
708  extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
709  return success();
710  }
711 
712  if (bin.getKind() == AffineExprKind::CeilDiv ||
713  bin.getKind() == AffineExprKind::FloorDiv ||
714  bin.getKind() == AffineExprKind::Mod)
715  return failure();
716 
717  if (bin.getKind() == AffineExprKind::Mul) {
718  auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
719  if (dim) {
720  strides[dim.getPosition()] =
721  strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
722  return success();
723  }
724  // LHS and RHS may both contain complex expressions of dims. Try one path
725  // and if it fails try the other. This is guaranteed to succeed because
726  // only one path may have a `dim`, otherwise this is not an AffineExpr in
727  // the first place.
728  if (bin.getLHS().isSymbolicOrConstant())
729  return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
730  strides, offset);
731  return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
732  strides, offset);
733  }
734 
735  if (bin.getKind() == AffineExprKind::Add) {
736  auto res1 =
737  extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
738  auto res2 =
739  extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
740  return success(succeeded(res1) && succeeded(res2));
741  }
742 
743  llvm_unreachable("unexpected binary operation");
744 }
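// Worked example (annotation, not part of the upstream source): running
// extractStrides on the layout expression of
// `memref<42x16xf32, (64 * d0 + d1 + 5)>` with multiplicativeFactor = 1
// accumulates strides = [64, 1] and offset = 5. Layouts containing mod,
// floordiv or ceildiv are rejected because they are not strided.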
745 
746 /// A stride specification is a list of integer values that are either static
747 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode
748 /// the distance in the number of elements between successive entries along a
749 /// particular dimension.
750 ///
751 /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
752 /// non-contiguous memory region of `42` by `16` `f32` elements in which the
753 /// distance between two consecutive elements along the inner dimension is `1`
754 /// and the distance between two consecutive elements along the outer dimension
755 /// is `64`.
756 ///
757 /// The convention is that the strides for dimensions d0, .. dn appear in
758 /// order to make indexing intuitive into the result.
759 static LogicalResult getStridesAndOffset(MemRefType t,
760  SmallVectorImpl<AffineExpr> &strides,
761  AffineExpr &offset) {
762  AffineMap m = t.getLayout().getAffineMap();
763 
764  if (m.getNumResults() != 1 && !m.isIdentity())
765  return failure();
766 
767  auto zero = getAffineConstantExpr(0, t.getContext());
768  auto one = getAffineConstantExpr(1, t.getContext());
769  offset = zero;
770  strides.assign(t.getRank(), zero);
771 
772  // Canonical case for empty map.
773  if (m.isIdentity()) {
774  // 0-D corner case, offset is already 0.
775  if (t.getRank() == 0)
776  return success();
777  auto stridedExpr =
778  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
779  if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
780  return success();
781  assert(false && "unexpected failure: extract strides in canonical layout");
782  }
783 
784  // Non-canonical case requires more work.
785  auto stridedExpr =
786  simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
787  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
788  offset = AffineExpr();
789  strides.clear();
790  return failure();
791  }
792 
793  // Simplify results to allow folding to constants and simple checks.
794  unsigned numDims = m.getNumDims();
795  unsigned numSymbols = m.getNumSymbols();
796  offset = simplifyAffineExpr(offset, numDims, numSymbols);
797  for (auto &stride : strides)
798  stride = simplifyAffineExpr(stride, numDims, numSymbols);
799 
800  // In practice, a strided memref must be internally non-aliasing. Test
801  // against 0 as a proxy.
802  // TODO: static cases can have more advanced checks.
803  // TODO: dynamic cases would require a way to compare symbolic
804  // expressions and would probably need an affine set context propagated
805  // everywhere.
806  if (llvm::any_of(strides, [](AffineExpr e) {
807  return e == getAffineConstantExpr(0, e.getContext());
808  })) {
809  offset = AffineExpr();
810  strides.clear();
811  return failure();
812  }
813 
814  return success();
815 }
816 
817 LogicalResult mlir::getStridesAndOffset(MemRefType t,
818  SmallVectorImpl<int64_t> &strides,
819  int64_t &offset) {
820  // Happy path: the type uses the strided layout directly.
821  if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
822  llvm::append_range(strides, strided.getStrides());
823  offset = strided.getOffset();
824  return success();
825  }
826 
827  // Otherwise, defer to the affine fallback as layouts are supposed to be
828  // convertible to affine maps.
829  AffineExpr offsetExpr;
830  SmallVector<AffineExpr, 4> strideExprs;
831  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
832  return failure();
833  if (auto cst = dyn_cast<AffineConstantExpr>(offsetExpr))
834  offset = cst.getValue();
835  else
836  offset = ShapedType::kDynamic;
837  for (auto e : strideExprs) {
838  if (auto c = dyn_cast<AffineConstantExpr>(e))
839  strides.push_back(c.getValue());
840  else
841  strides.push_back(ShapedType::kDynamic);
842  }
843  return success();
844 }
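// Illustrative results (annotation, not part of the upstream source):
//
//   memref<4x5xf32>                             --> strides [5, 1], offset 0
//   memref<?x?xf32, strided<[?, 1], offset: 8>> --> strides [kDynamic, 1],
//                                                   offset 8
//   memref<4x5xf32, (d0, d1) -> (d0 floordiv 2 + d1)> --> failure (not strided)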
845 
846 std::pair<SmallVector<int64_t>, int64_t>
847 mlir::getStridesAndOffset(MemRefType t) {
848  SmallVector<int64_t> strides;
849  int64_t offset;
850  LogicalResult status = getStridesAndOffset(t, strides, offset);
851  (void)status;
852  assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
853  return {strides, offset};
854 }
855 
856 //===----------------------------------------------------------------------===//
857 /// TupleType
858 //===----------------------------------------------------------------------===//
859 
860 /// Return the element types for this tuple.
861 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
862 
863 /// Accumulate the types contained in this tuple and tuples nested within it.
864 /// Note that this only flattens nested tuples, not any other container type,
865 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
866 /// (i32, tensor<i32>, f32, i64)
867 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
868  for (Type type : getTypes()) {
869  if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
870  nestedTuple.getFlattenedTypes(types);
871  else
872  types.push_back(type);
873  }
874 }
875 
876 /// Return the number of element types.
877 size_t TupleType::size() const { return getImpl()->size(); }
878 
879 //===----------------------------------------------------------------------===//
880 // Type Utilities
881 //===----------------------------------------------------------------------===//
882 
883 /// Return a version of `t` with identity layout if it can be determined
884 /// statically that the layout is the canonical contiguous strided layout.
885 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
886 /// `t` with simplified layout.
887 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
888 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
889  AffineMap m = t.getLayout().getAffineMap();
890 
891  // Already in canonical form.
892  if (m.isIdentity())
893  return t;
894 
895  // Can't reduce to canonical identity form, return in canonical form.
896  if (m.getNumResults() > 1)
897  return t;
898 
899  // Corner-case for 0-D affine maps.
900  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
901  if (auto cst = dyn_cast<AffineConstantExpr>(m.getResult(0)))
902  if (cst.getValue() == 0)
903  return MemRefType::Builder(t).setLayout({});
904  return t;
905  }
906 
907  // 0-D corner case for an empty shape that still has an affine map. Example:
908  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref whose
909  // offset needs to remain, so just return t.
910  if (t.getShape().empty())
911  return t;
912 
913  // If the canonical strided layout for the sizes of `t` is equal to the
914  // simplified layout of `t` we can just return an empty layout. Otherwise,
915  // just simplify the existing layout.
916  AffineExpr expr =
917  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
918  auto simplifiedLayoutExpr =
919  simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
920  if (expr != simplifiedLayoutExpr)
921  return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
922  m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
923  return MemRefType::Builder(t).setLayout({});
924 }
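// Illustrative example (annotation, not part of the upstream source):
//
//   memref<4x5xf32, (d0, d1) -> (d0 * 5 + d1)>
//     --> memref<4x5xf32>   (the layout folds to the identity)
//   memref<4x5xf32, (d0, d1) -> (d0 * 8 + d1)>
//     --> keeps its (simplified) non-identity layout, since the strides do not
//         match the canonical contiguous form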
925 
926 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
927  ArrayRef<AffineExpr> exprs,
928  MLIRContext *context) {
929  // Size 0 corner case is useful for canonicalizations.
930  if (sizes.empty())
931  return getAffineConstantExpr(0, context);
932 
933  assert(!exprs.empty() && "expected exprs");
934  auto maps = AffineMap::inferFromExprList(exprs, context);
935  assert(!maps.empty() && "Expected one non-empty map");
936  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
937 
938  AffineExpr expr;
939  bool dynamicPoisonBit = false;
940  int64_t runningSize = 1;
941  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
942  int64_t size = std::get<1>(en);
943  AffineExpr dimExpr = std::get<0>(en);
944  AffineExpr stride = dynamicPoisonBit
945  ? getAffineSymbolExpr(nSymbols++, context)
946  : getAffineConstantExpr(runningSize, context);
947  expr = expr ? expr + dimExpr * stride : dimExpr * stride;
948  if (size > 0) {
949  runningSize *= size;
950  assert(runningSize > 0 && "integer overflow in size computation");
951  } else {
952  dynamicPoisonBit = true;
953  }
954  }
955  return simplifyAffineExpr(expr, numDims, nSymbols);
956 }
957 
958 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
959  MLIRContext *context) {
960  SmallVector<AffineExpr, 4> exprs;
961  exprs.reserve(sizes.size());
962  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
963  exprs.push_back(getAffineDimExpr(dim, context));
964  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
965 }
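// Worked example (annotation, not part of the upstream source): for sizes
// [4, ?, 2] the strides are built innermost-to-outermost, and the first
// dynamic size turns every stride to its left into a fresh symbol, so (up to
// affine simplification):
//
//   makeCanonicalStridedLayoutExpr({4, ShapedType::kDynamic, 2}, ctx)
//     --> d0 * s0 + d1 * 2 + d2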
966 
967 bool mlir::isStrided(MemRefType t) {
968  int64_t offset;
969  SmallVector<int64_t, 4> strides;
970  auto res = getStridesAndOffset(t, strides, offset);
971  return succeeded(res);
972 }
973 
974 bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
975  int64_t offset;
976  SmallVector<int64_t> strides;
977  auto successStrides = getStridesAndOffset(type, strides, offset);
978  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
979 }
980 
981 bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
982  if (!isLastMemrefDimUnitStride(type))
983  return false;
984 
985  auto memrefShape = type.getShape().take_back(n);
986  if (ShapedType::isDynamicShape(memrefShape))
987  return false;
988 
989  if (type.getLayout().isIdentity())
990  return true;
991 
992  int64_t offset;
993  SmallVector<int64_t> stridesFull;
994  if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
995  return false;
996  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
997 
998  if (strides.empty())
999  return true;
1000 
1001  // Check whether strides match "flattened" dims.
1002  SmallVector<int64_t> flattenedDims;
1003  auto dimProduct = 1;
1004  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
1005  dimProduct *= dim;
1006  flattenedDims.push_back(dimProduct);
1007  }
1008 
1009  strides = strides.drop_back(1);
1010  return llvm::equal(strides, llvm::reverse(flattenedDims));
1011 }
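// Illustrative example (annotation, not part of the upstream source), for
// n = 2:
//
//   memref<2x3x4xf32>                       --> true  (identity layout)
//   memref<2x3x4xf32, strided<[12, 4, 1]>>  --> true  (stride 4 == dim product)
//   memref<2x3x4xf32, strided<[24, 8, 1]>>  --> false (8 != 4: rows are padded)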