MLIR  20.0.0git
BuiltinTypes.cpp
Go to the documentation of this file.
1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/TensorEncoding.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 
25 using namespace mlir;
26 using namespace mlir::detail;
27 
28 //===----------------------------------------------------------------------===//
29 /// Tablegen Type Definitions
30 //===----------------------------------------------------------------------===//
31 
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
34 
35 namespace mlir {
36 #include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
37 } // namespace mlir
38 
39 //===----------------------------------------------------------------------===//
40 // BuiltinDialect
41 //===----------------------------------------------------------------------===//
42 
/// Registers all tablegen-generated builtin types with the dialect.
void BuiltinDialect::registerTypes() {
  addTypes<
// Expands to the comma-separated list of all builtin type classes.
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}
49 
50 //===----------------------------------------------------------------------===//
51 /// ComplexType
52 //===----------------------------------------------------------------------===//
53 
54 /// Verify the construction of an integer type.
56  Type elementType) {
57  if (!elementType.isIntOrFloat())
58  return emitError() << "invalid element type for complex";
59  return success();
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // Integer Type
64 //===----------------------------------------------------------------------===//
65 
66 /// Verify the construction of an integer type.
68  unsigned width,
69  SignednessSemantics signedness) {
70  if (width > IntegerType::kMaxWidth) {
71  return emitError() << "integer bitwidth is limited to "
72  << IntegerType::kMaxWidth << " bits";
73  }
74  return success();
75 }
76 
/// Returns the bitwidth stored on the uniqued type storage.
unsigned IntegerType::getWidth() const { return getImpl()->width; }
78 
/// Returns the signedness semantics (signless/signed/unsigned) stored on the
/// uniqued type storage.
IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}
82 
83 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
84  if (!scale)
85  return IntegerType();
86  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Float Type
91 //===----------------------------------------------------------------------===//
92 
93 unsigned FloatType::getWidth() {
94  // The actual width of TF32 is 19 bits. However, since it is a truncated
95  // version of Float32, we treat it as 32 bits in MLIR FloatType::getWidth
96  // for compatibility.
97  if (llvm::isa<FloatTF32Type>(*this))
98  return 32;
99  return APFloat::semanticsSizeInBits(getFloatSemantics());
100 }
101 
102 /// Returns the floating semantics for the given type.
/// Maps each builtin float type to its llvm::fltSemantics descriptor.
/// Aborts if this is not one of the known builtin float types.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  // Sub-byte and FP8 formats.
  if (llvm::isa<Float4E2M1FNType>(*this))
    return APFloat::Float4E2M1FN();
  if (llvm::isa<Float6E2M3FNType>(*this))
    return APFloat::Float6E2M3FN();
  if (llvm::isa<Float6E3M2FNType>(*this))
    return APFloat::Float6E3M2FN();
  if (llvm::isa<Float8E5M2Type>(*this))
    return APFloat::Float8E5M2();
  if (llvm::isa<Float8E4M3Type>(*this))
    return APFloat::Float8E4M3();
  if (llvm::isa<Float8E4M3FNType>(*this))
    return APFloat::Float8E4M3FN();
  if (llvm::isa<Float8E5M2FNUZType>(*this))
    return APFloat::Float8E5M2FNUZ();
  if (llvm::isa<Float8E4M3FNUZType>(*this))
    return APFloat::Float8E4M3FNUZ();
  if (llvm::isa<Float8E4M3B11FNUZType>(*this))
    return APFloat::Float8E4M3B11FNUZ();
  if (llvm::isa<Float8E3M4Type>(*this))
    return APFloat::Float8E3M4();
  if (llvm::isa<Float8E8M0FNUType>(*this))
    return APFloat::Float8E8M0FNU();
  // 16-bit and wider formats.
  if (llvm::isa<BFloat16Type>(*this))
    return APFloat::BFloat();
  if (llvm::isa<Float16Type>(*this))
    return APFloat::IEEEhalf();
  if (llvm::isa<FloatTF32Type>(*this))
    return APFloat::FloatTF32();
  if (llvm::isa<Float32Type>(*this))
    return APFloat::IEEEsingle();
  if (llvm::isa<Float64Type>(*this))
    return APFloat::IEEEdouble();
  if (llvm::isa<Float80Type>(*this))
    return APFloat::x87DoubleExtended();
  if (llvm::isa<Float128Type>(*this))
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}
142 
144  if (!scale)
145  return FloatType();
146  MLIRContext *ctx = getContext();
147  if (isF16() || isBF16()) {
148  if (scale == 2)
149  return FloatType::getF32(ctx);
150  if (scale == 4)
151  return FloatType::getF64(ctx);
152  }
153  if (isF32())
154  if (scale == 2)
155  return FloatType::getF64(ctx);
156  return FloatType();
157 }
158 
160  return APFloat::semanticsPrecision(getFloatSemantics());
161 }
162 
163 //===----------------------------------------------------------------------===//
164 // FunctionType
165 //===----------------------------------------------------------------------===//
166 
/// Returns the number of function inputs from the uniqued storage.
unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
168 
/// Returns the uniqued list of input types.
ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}
172 
/// Returns the number of function results from the uniqued storage.
unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
174 
/// Returns the uniqued list of result types.
ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}
178 
/// Returns a function type in the same context with the given inputs and
/// results.
FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
  return get(getContext(), inputs, results);
}
182 
183 /// Returns a new function type with the specified arguments and results
184 /// inserted.
185 FunctionType FunctionType::getWithArgsAndResults(
186  ArrayRef<unsigned> argIndices, TypeRange argTypes,
187  ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
188  SmallVector<Type> argStorage, resultStorage;
189  TypeRange newArgTypes =
190  insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
191  TypeRange newResultTypes =
192  insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
193  return clone(newArgTypes, newResultTypes);
194 }
195 
196 /// Returns a new function type without the specified arguments and results.
197 FunctionType
198 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
199  const BitVector &resultIndices) {
200  SmallVector<Type> argStorage, resultStorage;
201  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
202  TypeRange newResultTypes =
203  filterTypesOut(getResults(), resultIndices, resultStorage);
204  return clone(newArgTypes, newResultTypes);
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // OpaqueType
209 //===----------------------------------------------------------------------===//
210 
211 /// Verify the construction of an opaque type.
213  StringAttr dialect, StringRef typeData) {
214  if (!Dialect::isValidNamespace(dialect.strref()))
215  return emitError() << "invalid dialect namespace '" << dialect << "'";
216 
217  // Check that the dialect is actually registered.
218  MLIRContext *context = dialect.getContext();
219  if (!context->allowsUnregisteredDialects() &&
220  !context->getLoadedDialect(dialect.strref())) {
221  return emitError()
222  << "`!" << dialect << "<\"" << typeData << "\">"
223  << "` type created with unregistered dialect. If this is "
224  "intended, please call allowUnregisteredDialects() on the "
225  "MLIRContext, or use -allow-unregistered-dialect with "
226  "the MLIR opt tool used";
227  }
228 
229  return success();
230 }
231 
232 //===----------------------------------------------------------------------===//
233 // VectorType
234 //===----------------------------------------------------------------------===//
235 
/// Returns true if `t` may be used as a vector element type; defers to the
/// shared vector element type predicate.
bool VectorType::isValidElementType(Type t) {
  return isValidVectorTypeElementType(t);
}
239 
241  ArrayRef<int64_t> shape, Type elementType,
242  ArrayRef<bool> scalableDims) {
243  if (!isValidElementType(elementType))
244  return emitError()
245  << "vector elements must be int/index/float type but got "
246  << elementType;
247 
248  if (any_of(shape, [](int64_t i) { return i <= 0; }))
249  return emitError()
250  << "vector types must have positive constant sizes but got "
251  << shape;
252 
253  if (scalableDims.size() != shape.size())
254  return emitError() << "number of dims must match, got "
255  << scalableDims.size() << " and " << shape.size();
256 
257  return success();
258 }
259 
260 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
261  if (!scale)
262  return VectorType();
263  if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
264  if (auto scaledEt = et.scaleElementBitwidth(scale))
265  return VectorType::get(getShape(), scaledEt, getScalableDims());
266  if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
267  if (auto scaledEt = et.scaleElementBitwidth(scale))
268  return VectorType::get(getShape(), scaledEt, getScalableDims());
269  return VectorType();
270 }
271 
272 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
273  Type elementType) const {
274  return VectorType::get(shape.value_or(getShape()), elementType,
275  getScalableDims());
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // TensorType
280 //===----------------------------------------------------------------------===//
281 
284  .Case<RankedTensorType, UnrankedTensorType>(
285  [](auto type) { return type.getElementType(); });
286 }
287 
/// A tensor type is ranked unless it is the unranked variant.
bool TensorType::hasRank() const {
  return !llvm::isa<UnrankedTensorType>(*this);
}
291 
293  return llvm::cast<RankedTensorType>(*this).getShape();
294 }
295 
297  Type elementType) const {
298  if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
299  if (shape)
300  return RankedTensorType::get(*shape, elementType);
301  return UnrankedTensorType::get(elementType);
302  }
303 
304  auto rankedTy = llvm::cast<RankedTensorType>(*this);
305  if (!shape)
306  return RankedTensorType::get(rankedTy.getShape(), elementType,
307  rankedTy.getEncoding());
308  return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
309  rankedTy.getEncoding());
310 }
311 
/// Clones into a ranked tensor with the given shape and element type.
RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
                                   Type elementType) const {
  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
}
316 
/// Clones into a ranked tensor with the given shape and this type's element
/// type.
RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
}
320 
321 // Check if "elementType" can be an element type of a tensor.
322 static LogicalResult
324  Type elementType) {
325  if (!TensorType::isValidElementType(elementType))
326  return emitError() << "invalid tensor element type: " << elementType;
327  return success();
328 }
329 
330 /// Return true if the specified element type is ok in a tensor.
332  // Note: Non standard/builtin types are allowed to exist within tensor
333  // types. Dialects are expected to verify that tensor types have a valid
334  // element type within that dialect.
335  return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
336  IndexType>(type) ||
337  !llvm::isa<BuiltinDialect>(type.getDialect());
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // RankedTensorType
342 //===----------------------------------------------------------------------===//
343 
344 LogicalResult
346  ArrayRef<int64_t> shape, Type elementType,
347  Attribute encoding) {
348  for (int64_t s : shape)
349  if (s < 0 && !ShapedType::isDynamic(s))
350  return emitError() << "invalid tensor dimension size";
351  if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
352  if (failed(v.verifyEncoding(shape, elementType, emitError)))
353  return failure();
354  return checkTensorElementType(emitError, elementType);
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // UnrankedTensorType
359 //===----------------------------------------------------------------------===//
360 
361 LogicalResult
363  Type elementType) {
364  return checkTensorElementType(emitError, elementType);
365 }
366 
367 //===----------------------------------------------------------------------===//
368 // BaseMemRefType
369 //===----------------------------------------------------------------------===//
370 
373  .Case<MemRefType, UnrankedMemRefType>(
374  [](auto type) { return type.getElementType(); });
375 }
376 
378  return !llvm::isa<UnrankedMemRefType>(*this);
379 }
380 
382  return llvm::cast<MemRefType>(*this).getShape();
383 }
384 
386  Type elementType) const {
387  if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
388  if (!shape)
389  return UnrankedMemRefType::get(elementType, getMemorySpace());
390  MemRefType::Builder builder(*shape, elementType);
391  builder.setMemorySpace(getMemorySpace());
392  return builder;
393  }
394 
395  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
396  if (shape)
397  builder.setShape(*shape);
398  builder.setElementType(elementType);
399  return builder;
400 }
401 
403  Type elementType) const {
404  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
405 }
406 
/// Clones into a ranked memref with the given shape and this type's element
/// type.
MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
}
410 
412  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
413  return rankedMemRefTy.getMemorySpace();
414  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
415 }
416 
418  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
419  return rankedMemRefTy.getMemorySpaceAsInt();
420  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
421 }
422 
423 //===----------------------------------------------------------------------===//
424 // MemRefType
425 //===----------------------------------------------------------------------===//
426 
427 std::optional<llvm::SmallDenseSet<unsigned>>
429  ArrayRef<int64_t> reducedShape,
430  bool matchDynamic) {
431  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
432  llvm::SmallDenseSet<unsigned> unusedDims;
433  unsigned reducedIdx = 0;
434  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
435  // Greedily insert `originalIdx` if match.
436  int64_t origSize = originalShape[originalIdx];
437  // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
438  if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
439  (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
440  ShapedType::isDynamic(origSize))) {
441  reducedIdx++;
442  continue;
443  }
444  if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
445  reducedIdx++;
446  continue;
447  }
448 
449  unusedDims.insert(originalIdx);
450  // If no match on `originalIdx`, the `originalShape` at this dimension
451  // must be 1, otherwise we bail.
452  if (origSize != 1)
453  return std::nullopt;
454  }
455  // The whole reducedShape must be scanned, otherwise we bail.
456  if (reducedIdx != reducedRank)
457  return std::nullopt;
458  return unusedDims;
459 }
460 
462 mlir::isRankReducedType(ShapedType originalType,
463  ShapedType candidateReducedType) {
464  if (originalType == candidateReducedType)
466 
467  ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
468  ShapedType candidateReducedShapedType =
469  llvm::cast<ShapedType>(candidateReducedType);
470 
471  // Rank and size logic is valid for all ShapedTypes.
472  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
473  ArrayRef<int64_t> candidateReducedShape =
474  candidateReducedShapedType.getShape();
475  unsigned originalRank = originalShape.size(),
476  candidateReducedRank = candidateReducedShape.size();
477  if (candidateReducedRank > originalRank)
479 
480  auto optionalUnusedDimsMask =
481  computeRankReductionMask(originalShape, candidateReducedShape);
482 
483  // Sizes cannot be matched in case empty vector is returned.
484  if (!optionalUnusedDimsMask)
486 
487  if (originalShapedType.getElementType() !=
488  candidateReducedShapedType.getElementType())
490 
492 }
493 
495  // Empty attribute is allowed as default memory space.
496  if (!memorySpace)
497  return true;
498 
499  // Supported built-in attributes.
500  if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
501  return true;
502 
503  // Allow custom dialect attributes.
504  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
505  return true;
506 
507  return false;
508 }
509 
511  MLIRContext *ctx) {
512  if (memorySpace == 0)
513  return nullptr;
514 
515  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
516 }
517 
519  IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
520  if (intMemorySpace && intMemorySpace.getValue() == 0)
521  return nullptr;
522 
523  return memorySpace;
524 }
525 
527  if (!memorySpace)
528  return 0;
529 
530  assert(llvm::isa<IntegerAttr>(memorySpace) &&
531  "Using `getMemorySpaceInteger` with non-Integer attribute");
532 
533  return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
534 }
535 
/// Returns the memory space as an integer; the detail helper asserts if the
/// memory space attribute is present but not an IntegerAttr.
unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}
539 
540 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
541  MemRefLayoutAttrInterface layout,
542  Attribute memorySpace) {
543  // Use default layout for empty attribute.
544  if (!layout)
546  shape.size(), elementType.getContext()));
547 
548  // Drop default memory space value and replace it with empty attribute.
549  memorySpace = skipDefaultMemorySpace(memorySpace);
550 
551  return Base::get(elementType.getContext(), shape, elementType, layout,
552  memorySpace);
553 }
554 
555 MemRefType MemRefType::getChecked(
556  function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
557  Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
558 
559  // Use default layout for empty attribute.
560  if (!layout)
562  shape.size(), elementType.getContext()));
563 
564  // Drop default memory space value and replace it with empty attribute.
565  memorySpace = skipDefaultMemorySpace(memorySpace);
566 
567  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
568  elementType, layout, memorySpace);
569 }
570 
571 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
572  AffineMap map, Attribute memorySpace) {
573 
574  // Use default layout for empty map.
575  if (!map)
576  map = AffineMap::getMultiDimIdentityMap(shape.size(),
577  elementType.getContext());
578 
579  // Wrap AffineMap into Attribute.
580  auto layout = AffineMapAttr::get(map);
581 
582  // Drop default memory space value and replace it with empty attribute.
583  memorySpace = skipDefaultMemorySpace(memorySpace);
584 
585  return Base::get(elementType.getContext(), shape, elementType, layout,
586  memorySpace);
587 }
588 
589 MemRefType
590 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
591  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
592  Attribute memorySpace) {
593 
594  // Use default layout for empty map.
595  if (!map)
596  map = AffineMap::getMultiDimIdentityMap(shape.size(),
597  elementType.getContext());
598 
599  // Wrap AffineMap into Attribute.
600  auto layout = AffineMapAttr::get(map);
601 
602  // Drop default memory space value and replace it with empty attribute.
603  memorySpace = skipDefaultMemorySpace(memorySpace);
604 
605  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
606  elementType, layout, memorySpace);
607 }
608 
609 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
610  AffineMap map, unsigned memorySpaceInd) {
611 
612  // Use default layout for empty map.
613  if (!map)
614  map = AffineMap::getMultiDimIdentityMap(shape.size(),
615  elementType.getContext());
616 
617  // Wrap AffineMap into Attribute.
618  auto layout = AffineMapAttr::get(map);
619 
620  // Convert deprecated integer-like memory space to Attribute.
621  Attribute memorySpace =
622  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
623 
624  return Base::get(elementType.getContext(), shape, elementType, layout,
625  memorySpace);
626 }
627 
628 MemRefType
629 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
630  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
631  unsigned memorySpaceInd) {
632 
633  // Use default layout for empty map.
634  if (!map)
635  map = AffineMap::getMultiDimIdentityMap(shape.size(),
636  elementType.getContext());
637 
638  // Wrap AffineMap into Attribute.
639  auto layout = AffineMapAttr::get(map);
640 
641  // Convert deprecated integer-like memory space to Attribute.
642  Attribute memorySpace =
643  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
644 
645  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
646  elementType, layout, memorySpace);
647 }
648 
650  ArrayRef<int64_t> shape, Type elementType,
651  MemRefLayoutAttrInterface layout,
652  Attribute memorySpace) {
653  if (!BaseMemRefType::isValidElementType(elementType))
654  return emitError() << "invalid memref element type";
655 
656  // Negative sizes are not allowed except for `kDynamic`.
657  for (int64_t s : shape)
658  if (s < 0 && !ShapedType::isDynamic(s))
659  return emitError() << "invalid memref size";
660 
661  assert(layout && "missing layout specification");
662  if (failed(layout.verifyLayout(shape, emitError)))
663  return failure();
664 
665  if (!isSupportedMemorySpace(memorySpace))
666  return emitError() << "unsupported memory space Attribute";
667 
668  return success();
669 }
670 
671 //===----------------------------------------------------------------------===//
672 // UnrankedMemRefType
673 //===----------------------------------------------------------------------===//
674 
676  return detail::getMemorySpaceAsInt(getMemorySpace());
677 }
678 
679 LogicalResult
681  Type elementType, Attribute memorySpace) {
682  if (!BaseMemRefType::isValidElementType(elementType))
683  return emitError() << "invalid memref element type";
684 
685  if (!isSupportedMemorySpace(memorySpace))
686  return emitError() << "unsupported memory space Attribute";
687 
688  return success();
689 }
690 
691 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
692 // i.e. single term). Accumulate the AffineExpr into the existing one.
694  AffineExpr multiplicativeFactor,
696  AffineExpr &offset) {
697  if (auto dim = dyn_cast<AffineDimExpr>(e))
698  strides[dim.getPosition()] =
699  strides[dim.getPosition()] + multiplicativeFactor;
700  else
701  offset = offset + e * multiplicativeFactor;
702 }
703 
704 /// Takes a single AffineExpr `e` and populates the `strides` array with the
705 /// strides expressions for each dim position.
706 /// The convention is that the strides for dimensions d0, .. dn appear in
707 /// order to make indexing intuitive into the result.
708 static LogicalResult extractStrides(AffineExpr e,
709  AffineExpr multiplicativeFactor,
711  AffineExpr &offset) {
712  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
713  if (!bin) {
714  extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
715  return success();
716  }
717 
718  if (bin.getKind() == AffineExprKind::CeilDiv ||
719  bin.getKind() == AffineExprKind::FloorDiv ||
720  bin.getKind() == AffineExprKind::Mod)
721  return failure();
722 
723  if (bin.getKind() == AffineExprKind::Mul) {
724  auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
725  if (dim) {
726  strides[dim.getPosition()] =
727  strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
728  return success();
729  }
730  // LHS and RHS may both contain complex expressions of dims. Try one path
731  // and if it fails try the other. This is guaranteed to succeed because
732  // only one path may have a `dim`, otherwise this is not an AffineExpr in
733  // the first place.
734  if (bin.getLHS().isSymbolicOrConstant())
735  return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
736  strides, offset);
737  return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
738  strides, offset);
739  }
740 
741  if (bin.getKind() == AffineExprKind::Add) {
742  auto res1 =
743  extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
744  auto res2 =
745  extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
746  return success(succeeded(res1) && succeeded(res2));
747  }
748 
749  llvm_unreachable("unexpected binary operation");
750 }
751 
752 /// A stride specification is a list of integer values that are either static
753 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode
754 /// the distance in the number of elements between successive entries along a
755 /// particular dimension.
756 ///
757 /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
758 /// non-contiguous memory region of `42` by `16` `f32` elements in which the
759 /// distance between two consecutive elements along the outer dimension is `1`
760 /// and the distance between two consecutive elements along the inner dimension
761 /// is `64`.
762 ///
763 /// The convention is that the strides for dimensions d0, .. dn appear in
764 /// order to make indexing intuitive into the result.
static LogicalResult getStridesAndOffset(MemRefType t,
                                         SmallVectorImpl<AffineExpr> &strides,
                                         AffineExpr &offset) {
  AffineMap m = t.getLayout().getAffineMap();

  // Strides can only be extracted from a single-result layout map (the
  // multi-result identity map is handled specially below).
  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  // Start from a zero offset and all-zero strides; extraction accumulates
  // into these.
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    // Extraction from the canonical layout expression must always succeed.
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    // Signal failure by clearing the outputs.
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic
  // expressions and would probably need an affine set context propagated
  // everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}
822 
823 LogicalResult mlir::getStridesAndOffset(MemRefType t,
824  SmallVectorImpl<int64_t> &strides,
825  int64_t &offset) {
826  // Happy path: the type uses the strided layout directly.
827  if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
828  llvm::append_range(strides, strided.getStrides());
829  offset = strided.getOffset();
830  return success();
831  }
832 
833  // Otherwise, defer to the affine fallback as layouts are supposed to be
834  // convertible to affine maps.
835  AffineExpr offsetExpr;
836  SmallVector<AffineExpr, 4> strideExprs;
837  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
838  return failure();
839  if (auto cst = dyn_cast<AffineConstantExpr>(offsetExpr))
840  offset = cst.getValue();
841  else
842  offset = ShapedType::kDynamic;
843  for (auto e : strideExprs) {
844  if (auto c = dyn_cast<AffineConstantExpr>(e))
845  strides.push_back(c.getValue());
846  else
847  strides.push_back(ShapedType::kDynamic);
848  }
849  return success();
850 }
851 
852 std::pair<SmallVector<int64_t>, int64_t>
854  SmallVector<int64_t> strides;
855  int64_t offset;
856  LogicalResult status = getStridesAndOffset(t, strides, offset);
857  (void)status;
858  assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
859  return {strides, offset};
860 }
861 
862 //===----------------------------------------------------------------------===//
863 /// TupleType
864 //===----------------------------------------------------------------------===//
865 
866 /// Return the elements types for this tuple.
// Reads the uniqued element type list directly from the type storage.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
868 
869 /// Accumulate the types contained in this tuple and tuples nested within it.
870 /// Note that this only flattens nested tuples, not any other container type,
871 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
872 /// (i32, tensor<i32>, f32, i64)
873 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
874  for (Type type : getTypes()) {
875  if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
876  nestedTuple.getFlattenedTypes(types);
877  else
878  types.push_back(type);
879  }
880 }
881 
882 /// Return the number of element types.
// Reads the element count directly from the type storage.
size_t TupleType::size() const { return getImpl()->size(); }
884 
885 //===----------------------------------------------------------------------===//
886 // Type Utilities
887 //===----------------------------------------------------------------------===//
888 
889 /// Return a version of `t` with identity layout if it can be determined
890 /// statically that the layout is the canonical contiguous strided layout.
891 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
892 /// `t` with simplified layout.
893 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
894 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
895  AffineMap m = t.getLayout().getAffineMap();
896 
897  // Already in canonical form.
898  if (m.isIdentity())
899  return t;
900 
901  // Can't reduce to canonical identity form, return in canonical form.
902  if (m.getNumResults() > 1)
903  return t;
904 
905  // Corner-case for 0-D affine maps.
906  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
907  if (auto cst = dyn_cast<AffineConstantExpr>(m.getResult(0)))
908  if (cst.getValue() == 0)
909  return MemRefType::Builder(t).setLayout({});
910  return t;
911  }
912 
913  // 0-D corner case for empty shape that still have an affine map. Example:
914  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
915  // offset needs to remain, just return t.
916  if (t.getShape().empty())
917  return t;
918 
919  // If the canonical strided layout for the sizes of `t` is equal to the
920  // simplified layout of `t` we can just return an empty layout. Otherwise,
921  // just simplify the existing layout.
922  AffineExpr expr =
923  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
924  auto simplifiedLayoutExpr =
925  simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
926  if (expr != simplifiedLayoutExpr)
928  m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
929  return MemRefType::Builder(t).setLayout({});
930 }
931 
933  ArrayRef<AffineExpr> exprs,
934  MLIRContext *context) {
935  // Size 0 corner case is useful for canonicalizations.
936  if (sizes.empty())
937  return getAffineConstantExpr(0, context);
938 
939  assert(!exprs.empty() && "expected exprs");
940  auto maps = AffineMap::inferFromExprList(exprs, context);
941  assert(!maps.empty() && "Expected one non-empty map");
942  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
943 
944  AffineExpr expr;
945  bool dynamicPoisonBit = false;
946  int64_t runningSize = 1;
947  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
948  int64_t size = std::get<1>(en);
949  AffineExpr dimExpr = std::get<0>(en);
950  AffineExpr stride = dynamicPoisonBit
951  ? getAffineSymbolExpr(nSymbols++, context)
952  : getAffineConstantExpr(runningSize, context);
953  expr = expr ? expr + dimExpr * stride : dimExpr * stride;
954  if (size > 0) {
955  runningSize *= size;
956  assert(runningSize > 0 && "integer overflow in size computation");
957  } else {
958  dynamicPoisonBit = true;
959  }
960  }
961  return simplifyAffineExpr(expr, numDims, nSymbols);
962 }
963 
965  MLIRContext *context) {
967  exprs.reserve(sizes.size());
968  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
969  exprs.push_back(getAffineDimExpr(dim, context));
970  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
971 }
972 
973 bool mlir::isStrided(MemRefType t) {
974  int64_t offset;
975  SmallVector<int64_t, 4> strides;
976  auto res = getStridesAndOffset(t, strides, offset);
977  return succeeded(res);
978 }
979 
980 bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
981  int64_t offset;
982  SmallVector<int64_t> strides;
983  auto successStrides = getStridesAndOffset(type, strides, offset);
984  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
985 }
986 
987 bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
988  if (!isLastMemrefDimUnitStride(type))
989  return false;
990 
991  auto memrefShape = type.getShape().take_back(n);
992  if (ShapedType::isDynamicShape(memrefShape))
993  return false;
994 
995  if (type.getLayout().isIdentity())
996  return true;
997 
998  int64_t offset;
999  SmallVector<int64_t> stridesFull;
1000  if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
1001  return false;
1002  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
1003 
1004  if (strides.empty())
1005  return true;
1006 
1007  // Check whether strides match "flattened" dims.
1008  SmallVector<int64_t> flattenedDims;
1009  auto dimProduct = 1;
1010  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
1011  dimProduct *= dim;
1012  flattenedDims.push_back(dimProduct);
1013  }
1014 
1015  strides = strides.drop_back(1);
1016  return llvm::equal(strides, llvm::reverse(flattenedDims));
1017 }
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:68
MLIRContext * getContext() const
Definition: AffineExpr.cpp:33
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:76
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:149
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Definition: BuiltinTypes.h:412
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns true if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:98
static FloatType getF64(MLIRContext *ctx)
Definition: BuiltinTypes.h:488
FloatType scaleElementBitwidth(unsigned scale)
Get or create a new FloatType with bitwidth scaled by scale.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
static FloatType getF32(MLIRContext *ctx)
Definition: BuiltinTypes.h:484
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if we allow to create operation for unregistered dialects.
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:213
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:234
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:229
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:224
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:239
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:102
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns true if this type is ranked, i.e. it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:127
AttrTypeReplacer.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:387
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:627
bool trailingNDimsContiguous(MemRefType type, int64_t n)
Return "true" if the last N dimensions of the given type are contiguous.