BuiltinTypes.cpp
//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
/// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

namespace mlir {
#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
} // namespace mlir

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}
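
// Illustrative usage sketch (hypothetical helper, not from the original file):
// scaling an integer type multiplies its bitwidth, and a scale of 0 yields a
// null type. Assumes a caller-provided MLIRContext.
static bool exampleScaleIntegerBitwidth(MLIRContext *ctx) {
  IntegerType i8 = IntegerType::get(ctx, 8);
  // i8 scaled by 4 becomes i32; the signedness is preserved.
  IntegerType scaled = i8.scaleElementBitwidth(4);
  // A scale of 0 returns a null IntegerType.
  IntegerType nullType = i8.scaleElementBitwidth(0);
  return scaled && scaled.getWidth() == 32 && !nullType;
}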

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  return APFloat::semanticsSizeInBits(getFloatSemantics());
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (llvm::isa<Float4E2M1FNType>(*this))
    return APFloat::Float4E2M1FN();
  if (llvm::isa<Float6E2M3FNType>(*this))
    return APFloat::Float6E2M3FN();
  if (llvm::isa<Float6E3M2FNType>(*this))
    return APFloat::Float6E3M2FN();
  if (llvm::isa<Float8E5M2Type>(*this))
    return APFloat::Float8E5M2();
  if (llvm::isa<Float8E4M3Type>(*this))
    return APFloat::Float8E4M3();
  if (llvm::isa<Float8E4M3FNType>(*this))
    return APFloat::Float8E4M3FN();
  if (llvm::isa<Float8E5M2FNUZType>(*this))
    return APFloat::Float8E5M2FNUZ();
  if (llvm::isa<Float8E4M3FNUZType>(*this))
    return APFloat::Float8E4M3FNUZ();
  if (llvm::isa<Float8E4M3B11FNUZType>(*this))
    return APFloat::Float8E4M3B11FNUZ();
  if (llvm::isa<Float8E3M4Type>(*this))
    return APFloat::Float8E3M4();
  if (llvm::isa<Float8E8M0FNUType>(*this))
    return APFloat::Float8E8M0FNU();
  if (llvm::isa<BFloat16Type>(*this))
    return APFloat::BFloat();
  if (llvm::isa<Float16Type>(*this))
    return APFloat::IEEEhalf();
  if (llvm::isa<FloatTF32Type>(*this))
    return APFloat::FloatTF32();
  if (llvm::isa<Float32Type>(*this))
    return APFloat::IEEEsingle();
  if (llvm::isa<Float64Type>(*this))
    return APFloat::IEEEdouble();
  if (llvm::isa<Float80Type>(*this))
    return APFloat::x87DoubleExtended();
  if (llvm::isa<Float128Type>(*this))
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}
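
// Illustrative usage sketch (hypothetical helper, not from the original file):
// only f16/bf16 scaled by 2 or 4 and f32 scaled by 2 are supported above; any
// other combination produces a null FloatType.
static bool exampleScaleFloatBitwidth(MLIRContext *ctx) {
  FloatType f16 = Float16Type::get(ctx);
  FloatType f64 = Float64Type::get(ctx);
  return f16.scaleElementBitwidth(2).isF32() &&
         f16.scaleElementBitwidth(4).isF64() &&
         !f64.scaleElementBitwidth(2);
}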

unsigned FloatType::getFPMantissaWidth() {
  return APFloat::semanticsPrecision(getFloatSemantics());
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
  return get(getContext(), inputs, results);
}

/// Returns a new function type with the specified arguments and results
/// inserted.
FunctionType FunctionType::getWithArgsAndResults(
    ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes =
      insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
  TypeRange newResultTypes =
      insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
  return clone(newArgTypes, newResultTypes);
}

/// Returns a new function type without the specified arguments and results.
FunctionType
FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
                                       const BitVector &resultIndices) {
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
  TypeRange newResultTypes =
      filterTypesOut(getResults(), resultIndices, resultStorage);
  return clone(newArgTypes, newResultTypes);
}
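
// Illustrative usage sketch (hypothetical helper, not from the original file):
// growing and shrinking a FunctionType with the helpers above. Assumes a
// caller-provided MLIRContext.
static FunctionType exampleFunctionTypeSurgery(MLIRContext *ctx) {
  Type i32 = IntegerType::get(ctx, 32);
  Type f32 = Float32Type::get(ctx);
  // Start from (i32) -> (f32).
  FunctionType fn = FunctionType::get(ctx, {i32}, {f32});
  // Insert an extra f32 argument at position 0, giving (f32, i32) -> (f32).
  SmallVector<unsigned> argIndices = {0};
  SmallVector<Type> argTypes = {f32};
  FunctionType grown = fn.getWithArgsAndResults(argIndices, argTypes,
                                                /*resultIndices=*/{},
                                                /*resultTypes=*/{});
  // Drop every result, giving (f32, i32) -> ().
  BitVector dropArgs(grown.getNumInputs());                 // all false: keep.
  BitVector dropResults(grown.getNumResults(), /*t=*/true); // all true: drop.
  return grown.getWithoutArgsAndResults(dropArgs, dropResults);
}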

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 StringAttr dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "the MLIR opt tool used";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

bool VectorType::isValidElementType(Type t) {
  return isValidVectorTypeElementType(t);
}

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 ArrayRef<bool> scalableDims) {
  if (!isValidElementType(elementType))
    return emitError()
           << "vector elements must be int/index/float type but got "
           << elementType;

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError()
           << "vector types must have positive constant sizes but got "
           << shape;

  if (scalableDims.size() != shape.size())
    return emitError() << "number of dims must match, got "
                       << scalableDims.size() << " and " << shape.size();

  return success();
}

VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getScalableDims());
  if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getScalableDims());
  return VectorType();
}
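
// Illustrative usage sketch (hypothetical helper, not from the original file):
// scaling a vector's element bitwidth rebuilds the vector with the scaled
// element type, e.g. vector<4xi8> scaled by 2 becomes vector<4xi16>.
static VectorType exampleScaleVectorBitwidth(MLIRContext *ctx) {
  auto v4i8 = VectorType::get({4}, IntegerType::get(ctx, 8));
  return v4i8.scaleElementBitwidth(2);
}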

VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  return VectorType::get(shape.value_or(getShape()), elementType,
                         getScalableDims());
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

Type TensorType::getElementType() const {
  return llvm::TypeSwitch<TensorType, Type>(*this)
      .Case<RankedTensorType, UnrankedTensorType>(
          [](auto type) { return type.getElementType(); });
}

bool TensorType::hasRank() const {
  return !llvm::isa<UnrankedTensorType>(*this);
}

ArrayRef<int64_t> TensorType::getShape() const {
  return llvm::cast<RankedTensorType>(*this).getShape();
}

TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
    if (shape)
      return RankedTensorType::get(*shape, elementType);
    return UnrankedTensorType::get(elementType);
  }

  auto rankedTy = llvm::cast<RankedTensorType>(*this);
  if (!shape)
    return RankedTensorType::get(rankedTy.getShape(), elementType,
                                 rankedTy.getEncoding());
  return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
                               rankedTy.getEncoding());
}

RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
                                   Type elementType) const {
  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
}

RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
}
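
// Illustrative usage sketch (hypothetical helper, not from the original file):
// cloning a ranked tensor with a new shape keeps the element type and carries
// the encoding over, e.g. tensor<4x8xf32> cloned to {2, 16} is tensor<2x16xf32>.
static RankedTensorType exampleCloneTensor(MLIRContext *ctx) {
  TensorType source = RankedTensorType::get({4, 8}, Float32Type::get(ctx));
  return source.clone({2, 16});
}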

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                   IndexType>(type) ||
         !llvm::isa<BuiltinDialect>(type.getDialect());
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < 0 && !ShapedType::isDynamic(s))
      return emitError() << "invalid tensor dimension size";
  if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Type BaseMemRefType::getElementType() const {
  return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
      .Case<MemRefType, UnrankedMemRefType>(
          [](auto type) { return type.getElementType(); });
}

bool BaseMemRefType::hasRank() const {
  return !llvm::isa<UnrankedMemRefType>(*this);
}

ArrayRef<int64_t> BaseMemRefType::getShape() const {
  return llvm::cast<MemRefType>(*this).getShape();
}

BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                         Type elementType) const {
  if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
    if (!shape)
      return UnrankedMemRefType::get(elementType, getMemorySpace());
    MemRefType::Builder builder(*shape, elementType);
    builder.setMemorySpace(getMemorySpace());
    return builder;
  }

  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
  if (shape)
    builder.setShape(*shape);
  builder.setElementType(elementType);
  return builder;
}

MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
                                 Type elementType) const {
  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
}

MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
}

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
    return rankedMemRefTy.getMemorySpace();
  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
    return rankedMemRefTy.getMemorySpaceAsInt();
  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

std::optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape,
                               bool matchDynamic) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if match.
    int64_t origSize = originalShape[originalIdx];
    // If `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
    if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
        (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
         ShapedType::isDynamic(origSize))) {
      reducedIdx++;
      continue;
    }
    if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (origSize != 1)
      return std::nullopt;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return std::nullopt;
  return unusedDims;
}
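
// Illustrative usage sketch (hypothetical helper, not from the original file):
// reducing 1x4x1x8 to 4x8 drops the unit dimensions 0 and 2, so the returned
// mask contains exactly {0, 2}; an unmatchable reduced shape yields
// std::nullopt instead.
static std::optional<llvm::SmallDenseSet<unsigned>> exampleRankReductionMask() {
  SmallVector<int64_t> originalShape = {1, 4, 1, 8};
  SmallVector<int64_t> reducedShape = {4, 8};
  return mlir::computeRankReductionMask(originalShape, reducedShape);
}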

SliceVerificationResult
mlir::isRankReducedType(ShapedType originalType,
                        ShapedType candidateReducedType) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;

  ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
  ShapedType candidateReducedShapedType =
      llvm::cast<ShapedType>(candidateReducedType);

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SliceVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched in case empty vector is returned.
  if (!optionalUnusedDimsMask)
    return SliceVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SliceVerificationResult::ElemTypeMismatch;

  return SliceVerificationResult::Success;
}

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
    return true;

  // Allow custom dialect attributes.
  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(llvm::isa<IntegerAttr>(memorySpace) &&
         "Using `getMemorySpaceInteger` with non-Integer attribute");

  return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
}
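
// Illustrative usage sketch (hypothetical helper, not from the original file):
// round-tripping the deprecated integer memory space through the Attribute
// form. Space 0 is the default and is normalized to a null Attribute.
static bool exampleMemorySpaceHelpers(MLIRContext *ctx) {
  Attribute space3 = detail::wrapIntegerMemorySpace(3, ctx);
  Attribute space0 = detail::wrapIntegerMemorySpace(0, ctx);
  return detail::isSupportedMemorySpace(space3) &&
         detail::getMemorySpaceAsInt(space3) == 3 && !space0;
}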

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           MemRefLayoutAttrInterface layout,
                           Attribute memorySpace) {
  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}
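
// Illustrative usage sketch (hypothetical helper, not from the original file):
// the layout and memory space parameters are optional, so memref<4x8xf32>
// needs only a shape and an element type; an explicit StridedLayoutAttr and a
// wrapped integer memory space produce memref<4x8xf32, strided<[8, 1]>, 3>.
static MemRefType exampleBuildMemRef(MLIRContext *ctx) {
  auto f32 = Float32Type::get(ctx);
  MemRefType plain = MemRefType::get({4, 8}, f32);
  MemRefType strided =
      MemRefType::get({4, 8}, f32,
                      StridedLayoutAttr::get(ctx, /*offset=*/0,
                                             /*strides=*/{8, 1}),
                      wrapIntegerMemorySpace(3, ctx));
  return strided ? strided : plain;
}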

MemRefType MemRefType::getChecked(
    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
    Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {

  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `kDynamic`.
  for (int64_t s : shape)
    if (s < 0 && !ShapedType::isDynamic(s))
      return emitError() << "invalid memref size";

  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

// Fallback cases for terminal dim/sym/cst that are not part of a binary op
// (i.e. single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = dyn_cast<AffineDimExpr>(e))
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// strides expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
/// the distance in the number of elements between successive entries along a
/// particular dimension.
///
/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
/// non-contiguous memory region of `42` by `16` `f32` elements in which the
/// distance between two consecutive elements along the outer dimension is `64`
/// and the distance between two consecutive elements along the inner dimension
/// is `1`.
///
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
static LogicalResult getStridesAndOffset(MemRefType t,
                                         SmallVectorImpl<AffineExpr> &strides,
                                         AffineExpr &offset) {
  AffineMap m = t.getLayout().getAffineMap();

  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  // Happy path: the type uses the strided layout directly.
  if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
    llvm::append_range(strides, strided.getStrides());
    offset = strided.getOffset();
    return success();
  }

  // Otherwise, defer to the affine fallback as layouts are supposed to be
  // convertible to affine maps.
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = dyn_cast<AffineConstantExpr>(offsetExpr))
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamic;
  for (auto e : strideExprs) {
    if (auto c = dyn_cast<AffineConstantExpr>(e))
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamic);
  }
  return success();
}
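
// Illustrative usage sketch (hypothetical helper, not from the original file):
// a contiguous memref<42x16xf32> has offset 0 and canonical strides [16, 1].
static bool exampleGetStridesAndOffset(MLIRContext *ctx) {
  MemRefType t = MemRefType::get({42, 16}, Float32Type::get(ctx));
  SmallVector<int64_t> strides;
  int64_t offset = 0;
  if (failed(mlir::getStridesAndOffset(t, strides, offset)))
    return false;
  return offset == 0 && strides.size() == 2 && strides[0] == 16 &&
         strides[1] == 1;
}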

std::pair<SmallVector<int64_t>, int64_t>
mlir::getStridesAndOffset(MemRefType t) {
  SmallVector<int64_t> strides;
  int64_t offset;
  LogicalResult status = getStridesAndOffset(t, strides, offset);
  (void)status;
  assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
  return {strides, offset};
}

//===----------------------------------------------------------------------===//
/// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64)
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}
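
// Illustrative usage sketch (hypothetical helper, not from the original file):
// flattening tuple<i32, tuple<f32, i64>> yields the type list (i32, f32, i64).
static SmallVector<Type> exampleFlattenTuple(MLIRContext *ctx) {
  Type i32 = IntegerType::get(ctx, 32);
  Type f32 = Float32Type::get(ctx);
  Type i64 = IntegerType::get(ctx, 64);
  TupleType nested = TupleType::get(ctx, {f32, i64});
  TupleType outer = TupleType::get(ctx, {i32, nested});
  SmallVector<Type> flattened;
  outer.getFlattenedTypes(flattened);
  return flattened;
}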

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  AffineMap m = t.getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (m.getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = dyn_cast<AffineConstantExpr>(m.getResult(0)))
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setLayout({});
    return t;
  }

  // 0-D corner case for empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
  // offset needs to remain, just return t.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
  return MemRefType::Builder(t).setLayout({});
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  // Size 0 corner case is useful for canonicalizations.
  if (sizes.empty())
    return getAffineConstantExpr(0, context);

  assert(!exprs.empty() && "expected exprs");
  auto maps = AffineMap::inferFromExprList(exprs, context);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}
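
// Illustrative usage sketch (hypothetical helper, not from the original file):
// for static sizes [42, 16] the canonical contiguous layout expression is
// d0 * 16 + d1; a dynamic size would introduce a symbol for the affected
// stride instead of a constant.
static AffineExpr exampleCanonicalStridedExpr(MLIRContext *ctx) {
  return mlir::makeCanonicalStridedLayoutExpr({42, 16}, ctx);
}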

bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
  if (!isLastMemrefDimUnitStride(type))
    return false;

  auto memrefShape = type.getShape().take_back(n);
  if (ShapedType::isDynamicShape(memrefShape))
    return false;

  if (type.getLayout().isIdentity())
    return true;

  int64_t offset;
  SmallVector<int64_t> stridesFull;
  if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
    return false;
  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);

  if (strides.empty())
    return true;

  // Check whether strides match "flattened" dims.
  SmallVector<int64_t> flattenedDims;
  auto dimProduct = 1;
  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
    dimProduct *= dim;
    flattenedDims.push_back(dimProduct);
  }

  strides = strides.drop_back(1);
  return llvm::equal(strides, llvm::reverse(flattenedDims));
}
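
// Illustrative usage sketch (hypothetical helper, not from the original file):
// an identity-layout memref<2x3x4xf32> is contiguous in all three trailing
// dimensions, while a view with strides [24, 8, 1] is only contiguous in the
// innermost one.
static bool exampleTrailingContiguity(MLIRContext *ctx) {
  auto f32 = Float32Type::get(ctx);
  MemRefType dense = MemRefType::get({2, 3, 4}, f32);
  MemRefType gapped =
      MemRefType::get({2, 3, 4}, f32,
                      StridedLayoutAttr::get(ctx, /*offset=*/0,
                                             /*strides=*/{24, 8, 1}));
  return mlir::trailingNDimsContiguous(dense, 3) &&
         mlir::trailingNDimsContiguous(gapped, 1) &&
         !mlir::trailingNDimsContiguous(gapped, 2);
}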