MLIR  19.0.0git
BuiltinTypes.cpp
Go to the documentation of this file.
1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/TensorEncoding.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 
25 using namespace mlir;
26 using namespace mlir::detail;
27 
28 //===----------------------------------------------------------------------===//
29 /// Tablegen Type Definitions
30 //===----------------------------------------------------------------------===//
31 
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
34 
35 //===----------------------------------------------------------------------===//
36 // BuiltinDialect
37 //===----------------------------------------------------------------------===//
38 
39 void BuiltinDialect::registerTypes() {
40  addTypes<
41 #define GET_TYPEDEF_LIST
42 #include "mlir/IR/BuiltinTypes.cpp.inc"
43  >();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 /// ComplexType
48 //===----------------------------------------------------------------------===//
49 
/// Verify the construction of a complex type.
52  Type elementType) {
53  if (!elementType.isIntOrFloat())
54  return emitError() << "invalid element type for complex";
55  return success();
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // Integer Type
60 //===----------------------------------------------------------------------===//
61 
62 /// Verify the construction of an integer type.
64  unsigned width,
65  SignednessSemantics signedness) {
66  if (width > IntegerType::kMaxWidth) {
67  return emitError() << "integer bitwidth is limited to "
68  << IntegerType::kMaxWidth << " bits";
69  }
70  return success();
71 }
72 
73 unsigned IntegerType::getWidth() const { return getImpl()->width; }
74 
75 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
76  return getImpl()->signedness;
77 }
78 
79 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
80  if (!scale)
81  return IntegerType();
82  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // Float Type
87 //===----------------------------------------------------------------------===//
88 
89 unsigned FloatType::getWidth() {
90  if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
91  Float8E4M3FNUZType, Float8E4M3B11FNUZType>(*this))
92  return 8;
93  if (llvm::isa<Float16Type, BFloat16Type>(*this))
94  return 16;
95  if (llvm::isa<Float32Type, FloatTF32Type>(*this))
96  return 32;
97  if (llvm::isa<Float64Type>(*this))
98  return 64;
99  if (llvm::isa<Float80Type>(*this))
100  return 80;
101  if (llvm::isa<Float128Type>(*this))
102  return 128;
103  llvm_unreachable("unexpected float type");
104 }
105 
106 /// Returns the floating semantics for the given type.
107 const llvm::fltSemantics &FloatType::getFloatSemantics() {
108  if (llvm::isa<Float8E5M2Type>(*this))
109  return APFloat::Float8E5M2();
110  if (llvm::isa<Float8E4M3FNType>(*this))
111  return APFloat::Float8E4M3FN();
112  if (llvm::isa<Float8E5M2FNUZType>(*this))
113  return APFloat::Float8E5M2FNUZ();
114  if (llvm::isa<Float8E4M3FNUZType>(*this))
115  return APFloat::Float8E4M3FNUZ();
116  if (llvm::isa<Float8E4M3B11FNUZType>(*this))
117  return APFloat::Float8E4M3B11FNUZ();
118  if (llvm::isa<BFloat16Type>(*this))
119  return APFloat::BFloat();
120  if (llvm::isa<Float16Type>(*this))
121  return APFloat::IEEEhalf();
122  if (llvm::isa<FloatTF32Type>(*this))
123  return APFloat::FloatTF32();
124  if (llvm::isa<Float32Type>(*this))
125  return APFloat::IEEEsingle();
126  if (llvm::isa<Float64Type>(*this))
127  return APFloat::IEEEdouble();
128  if (llvm::isa<Float80Type>(*this))
129  return APFloat::x87DoubleExtended();
130  if (llvm::isa<Float128Type>(*this))
131  return APFloat::IEEEquad();
132  llvm_unreachable("non-floating point type used");
133 }
134 
136  if (!scale)
137  return FloatType();
138  MLIRContext *ctx = getContext();
139  if (isF16() || isBF16()) {
140  if (scale == 2)
141  return FloatType::getF32(ctx);
142  if (scale == 4)
143  return FloatType::getF64(ctx);
144  }
145  if (isF32())
146  if (scale == 2)
147  return FloatType::getF64(ctx);
148  return FloatType();
149 }
150 
152  return APFloat::semanticsPrecision(getFloatSemantics());
153 }
154 
155 //===----------------------------------------------------------------------===//
156 // FunctionType
157 //===----------------------------------------------------------------------===//
158 
159 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
160 
161 ArrayRef<Type> FunctionType::getInputs() const {
162  return getImpl()->getInputs();
163 }
164 
165 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
166 
167 ArrayRef<Type> FunctionType::getResults() const {
168  return getImpl()->getResults();
169 }
170 
171 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
172  return get(getContext(), inputs, results);
173 }
174 
175 /// Returns a new function type with the specified arguments and results
176 /// inserted.
177 FunctionType FunctionType::getWithArgsAndResults(
178  ArrayRef<unsigned> argIndices, TypeRange argTypes,
179  ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
180  SmallVector<Type> argStorage, resultStorage;
181  TypeRange newArgTypes =
182  insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
183  TypeRange newResultTypes =
184  insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
185  return clone(newArgTypes, newResultTypes);
186 }
187 
188 /// Returns a new function type without the specified arguments and results.
189 FunctionType
190 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
191  const BitVector &resultIndices) {
192  SmallVector<Type> argStorage, resultStorage;
193  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
194  TypeRange newResultTypes =
195  filterTypesOut(getResults(), resultIndices, resultStorage);
196  return clone(newArgTypes, newResultTypes);
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // OpaqueType
201 //===----------------------------------------------------------------------===//
202 
203 /// Verify the construction of an opaque type.
205  StringAttr dialect, StringRef typeData) {
206  if (!Dialect::isValidNamespace(dialect.strref()))
207  return emitError() << "invalid dialect namespace '" << dialect << "'";
208 
209  // Check that the dialect is actually registered.
210  MLIRContext *context = dialect.getContext();
211  if (!context->allowsUnregisteredDialects() &&
212  !context->getLoadedDialect(dialect.strref())) {
213  return emitError()
214  << "`!" << dialect << "<\"" << typeData << "\">"
215  << "` type created with unregistered dialect. If this is "
216  "intended, please call allowUnregisteredDialects() on the "
217  "MLIRContext, or use -allow-unregistered-dialect with "
218  "the MLIR opt tool used";
219  }
220 
221  return success();
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // VectorType
226 //===----------------------------------------------------------------------===//
227 
229  ArrayRef<int64_t> shape, Type elementType,
230  ArrayRef<bool> scalableDims) {
231  if (!isValidElementType(elementType))
232  return emitError()
233  << "vector elements must be int/index/float type but got "
234  << elementType;
235 
236  if (any_of(shape, [](int64_t i) { return i <= 0; }))
237  return emitError()
238  << "vector types must have positive constant sizes but got "
239  << shape;
240 
241  if (scalableDims.size() != shape.size())
242  return emitError() << "number of dims must match, got "
243  << scalableDims.size() << " and " << shape.size();
244 
245  return success();
246 }
247 
248 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
249  if (!scale)
250  return VectorType();
251  if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
252  if (auto scaledEt = et.scaleElementBitwidth(scale))
253  return VectorType::get(getShape(), scaledEt, getScalableDims());
254  if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
255  if (auto scaledEt = et.scaleElementBitwidth(scale))
256  return VectorType::get(getShape(), scaledEt, getScalableDims());
257  return VectorType();
258 }
259 
260 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
261  Type elementType) const {
262  return VectorType::get(shape.value_or(getShape()), elementType,
263  getScalableDims());
264 }
265 
266 //===----------------------------------------------------------------------===//
267 // TensorType
268 //===----------------------------------------------------------------------===//
269 
272  .Case<RankedTensorType, UnrankedTensorType>(
273  [](auto type) { return type.getElementType(); });
274 }
275 
276 bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
277 
279  return llvm::cast<RankedTensorType>(*this).getShape();
280 }
281 
283  Type elementType) const {
284  if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
285  if (shape)
286  return RankedTensorType::get(*shape, elementType);
287  return UnrankedTensorType::get(elementType);
288  }
289 
290  auto rankedTy = llvm::cast<RankedTensorType>(*this);
291  if (!shape)
292  return RankedTensorType::get(rankedTy.getShape(), elementType,
293  rankedTy.getEncoding());
294  return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
295  rankedTy.getEncoding());
296 }
297 
298 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
299  Type elementType) const {
300  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
301 }
302 
303 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
304  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
305 }
306 
307 // Check if "elementType" can be an element type of a tensor.
308 static LogicalResult
310  Type elementType) {
311  if (!TensorType::isValidElementType(elementType))
312  return emitError() << "invalid tensor element type: " << elementType;
313  return success();
314 }
315 
316 /// Return true if the specified element type is ok in a tensor.
318  // Note: Non standard/builtin types are allowed to exist within tensor
319  // types. Dialects are expected to verify that tensor types have a valid
320  // element type within that dialect.
321  return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
322  IndexType>(type) ||
323  !llvm::isa<BuiltinDialect>(type.getDialect());
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // RankedTensorType
328 //===----------------------------------------------------------------------===//
329 
332  ArrayRef<int64_t> shape, Type elementType,
333  Attribute encoding) {
334  for (int64_t s : shape)
335  if (s < 0 && !ShapedType::isDynamic(s))
336  return emitError() << "invalid tensor dimension size";
337  if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
338  if (failed(v.verifyEncoding(shape, elementType, emitError)))
339  return failure();
340  return checkTensorElementType(emitError, elementType);
341 }
342 
343 //===----------------------------------------------------------------------===//
344 // UnrankedTensorType
345 //===----------------------------------------------------------------------===//
346 
349  Type elementType) {
350  return checkTensorElementType(emitError, elementType);
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // BaseMemRefType
355 //===----------------------------------------------------------------------===//
356 
359  .Case<MemRefType, UnrankedMemRefType>(
360  [](auto type) { return type.getElementType(); });
361 }
362 
363 bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
364 
366  return llvm::cast<MemRefType>(*this).getShape();
367 }
368 
370  Type elementType) const {
371  if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
372  if (!shape)
373  return UnrankedMemRefType::get(elementType, getMemorySpace());
374  MemRefType::Builder builder(*shape, elementType);
375  builder.setMemorySpace(getMemorySpace());
376  return builder;
377  }
378 
379  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
380  if (shape)
381  builder.setShape(*shape);
382  builder.setElementType(elementType);
383  return builder;
384 }
385 
387  Type elementType) const {
388  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
389 }
390 
391 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
392  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
393 }
394 
396  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
397  return rankedMemRefTy.getMemorySpace();
398  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
399 }
400 
402  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
403  return rankedMemRefTy.getMemorySpaceAsInt();
404  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
405 }
406 
407 //===----------------------------------------------------------------------===//
408 // MemRefType
409 //===----------------------------------------------------------------------===//
410 
411 std::optional<llvm::SmallDenseSet<unsigned>>
413  ArrayRef<int64_t> reducedShape,
414  bool matchDynamic) {
415  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
416  llvm::SmallDenseSet<unsigned> unusedDims;
417  unsigned reducedIdx = 0;
418  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
419  // Greedily insert `originalIdx` if match.
420  int64_t origSize = originalShape[originalIdx];
421  // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
422  if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
423  (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
424  ShapedType::isDynamic(origSize))) {
425  reducedIdx++;
426  continue;
427  }
428  if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
429  reducedIdx++;
430  continue;
431  }
432 
433  unusedDims.insert(originalIdx);
434  // If no match on `originalIdx`, the `originalShape` at this dimension
435  // must be 1, otherwise we bail.
436  if (origSize != 1)
437  return std::nullopt;
438  }
439  // The whole reducedShape must be scanned, otherwise we bail.
440  if (reducedIdx != reducedRank)
441  return std::nullopt;
442  return unusedDims;
443 }
444 
446 mlir::isRankReducedType(ShapedType originalType,
447  ShapedType candidateReducedType) {
448  if (originalType == candidateReducedType)
450 
451  ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
452  ShapedType candidateReducedShapedType =
453  llvm::cast<ShapedType>(candidateReducedType);
454 
455  // Rank and size logic is valid for all ShapedTypes.
456  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
457  ArrayRef<int64_t> candidateReducedShape =
458  candidateReducedShapedType.getShape();
459  unsigned originalRank = originalShape.size(),
460  candidateReducedRank = candidateReducedShape.size();
461  if (candidateReducedRank > originalRank)
463 
464  auto optionalUnusedDimsMask =
465  computeRankReductionMask(originalShape, candidateReducedShape);
466 
467  // Sizes cannot be matched in case empty vector is returned.
468  if (!optionalUnusedDimsMask)
470 
471  if (originalShapedType.getElementType() !=
472  candidateReducedShapedType.getElementType())
474 
476 }
477 
479  // Empty attribute is allowed as default memory space.
480  if (!memorySpace)
481  return true;
482 
483  // Supported built-in attributes.
484  if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
485  return true;
486 
487  // Allow custom dialect attributes.
488  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
489  return true;
490 
491  return false;
492 }
493 
495  MLIRContext *ctx) {
496  if (memorySpace == 0)
497  return nullptr;
498 
499  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
500 }
501 
503  IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
504  if (intMemorySpace && intMemorySpace.getValue() == 0)
505  return nullptr;
506 
507  return memorySpace;
508 }
509 
511  if (!memorySpace)
512  return 0;
513 
514  assert(llvm::isa<IntegerAttr>(memorySpace) &&
515  "Using `getMemorySpaceInteger` with non-Integer attribute");
516 
517  return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
518 }
519 
520 unsigned MemRefType::getMemorySpaceAsInt() const {
521  return detail::getMemorySpaceAsInt(getMemorySpace());
522 }
523 
524 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
525  MemRefLayoutAttrInterface layout,
526  Attribute memorySpace) {
527  // Use default layout for empty attribute.
528  if (!layout)
530  shape.size(), elementType.getContext()));
531 
532  // Drop default memory space value and replace it with empty attribute.
533  memorySpace = skipDefaultMemorySpace(memorySpace);
534 
535  return Base::get(elementType.getContext(), shape, elementType, layout,
536  memorySpace);
537 }
538 
539 MemRefType MemRefType::getChecked(
540  function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
541  Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
542 
543  // Use default layout for empty attribute.
544  if (!layout)
546  shape.size(), elementType.getContext()));
547 
548  // Drop default memory space value and replace it with empty attribute.
549  memorySpace = skipDefaultMemorySpace(memorySpace);
550 
551  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
552  elementType, layout, memorySpace);
553 }
554 
555 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
556  AffineMap map, Attribute memorySpace) {
557 
558  // Use default layout for empty map.
559  if (!map)
560  map = AffineMap::getMultiDimIdentityMap(shape.size(),
561  elementType.getContext());
562 
563  // Wrap AffineMap into Attribute.
564  auto layout = AffineMapAttr::get(map);
565 
566  // Drop default memory space value and replace it with empty attribute.
567  memorySpace = skipDefaultMemorySpace(memorySpace);
568 
569  return Base::get(elementType.getContext(), shape, elementType, layout,
570  memorySpace);
571 }
572 
573 MemRefType
574 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
575  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
576  Attribute memorySpace) {
577 
578  // Use default layout for empty map.
579  if (!map)
580  map = AffineMap::getMultiDimIdentityMap(shape.size(),
581  elementType.getContext());
582 
583  // Wrap AffineMap into Attribute.
584  auto layout = AffineMapAttr::get(map);
585 
586  // Drop default memory space value and replace it with empty attribute.
587  memorySpace = skipDefaultMemorySpace(memorySpace);
588 
589  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
590  elementType, layout, memorySpace);
591 }
592 
593 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
594  AffineMap map, unsigned memorySpaceInd) {
595 
596  // Use default layout for empty map.
597  if (!map)
598  map = AffineMap::getMultiDimIdentityMap(shape.size(),
599  elementType.getContext());
600 
601  // Wrap AffineMap into Attribute.
602  auto layout = AffineMapAttr::get(map);
603 
604  // Convert deprecated integer-like memory space to Attribute.
605  Attribute memorySpace =
606  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
607 
608  return Base::get(elementType.getContext(), shape, elementType, layout,
609  memorySpace);
610 }
611 
612 MemRefType
613 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
614  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
615  unsigned memorySpaceInd) {
616 
617  // Use default layout for empty map.
618  if (!map)
619  map = AffineMap::getMultiDimIdentityMap(shape.size(),
620  elementType.getContext());
621 
622  // Wrap AffineMap into Attribute.
623  auto layout = AffineMapAttr::get(map);
624 
625  // Convert deprecated integer-like memory space to Attribute.
626  Attribute memorySpace =
627  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
628 
629  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
630  elementType, layout, memorySpace);
631 }
632 
634  ArrayRef<int64_t> shape, Type elementType,
635  MemRefLayoutAttrInterface layout,
636  Attribute memorySpace) {
637  if (!BaseMemRefType::isValidElementType(elementType))
638  return emitError() << "invalid memref element type";
639 
640  // Negative sizes are not allowed except for `kDynamic`.
641  for (int64_t s : shape)
642  if (s < 0 && !ShapedType::isDynamic(s))
643  return emitError() << "invalid memref size";
644 
645  assert(layout && "missing layout specification");
646  if (failed(layout.verifyLayout(shape, emitError)))
647  return failure();
648 
649  if (!isSupportedMemorySpace(memorySpace))
650  return emitError() << "unsupported memory space Attribute";
651 
652  return success();
653 }
654 
655 //===----------------------------------------------------------------------===//
656 // UnrankedMemRefType
657 //===----------------------------------------------------------------------===//
658 
660  return detail::getMemorySpaceAsInt(getMemorySpace());
661 }
662 
665  Type elementType, Attribute memorySpace) {
666  if (!BaseMemRefType::isValidElementType(elementType))
667  return emitError() << "invalid memref element type";
668 
669  if (!isSupportedMemorySpace(memorySpace))
670  return emitError() << "unsupported memory space Attribute";
671 
672  return success();
673 }
674 
// Fallback case for a terminal dim/sym/cst that is not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
678  AffineExpr multiplicativeFactor,
680  AffineExpr &offset) {
681  if (auto dim = dyn_cast<AffineDimExpr>(e))
682  strides[dim.getPosition()] =
683  strides[dim.getPosition()] + multiplicativeFactor;
684  else
685  offset = offset + e * multiplicativeFactor;
686 }
687 
688 /// Takes a single AffineExpr `e` and populates the `strides` array with the
689 /// strides expressions for each dim position.
690 /// The convention is that the strides for dimensions d0, .. dn appear in
691 /// order to make indexing intuitive into the result.
693  AffineExpr multiplicativeFactor,
695  AffineExpr &offset) {
696  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
697  if (!bin) {
698  extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
699  return success();
700  }
701 
702  if (bin.getKind() == AffineExprKind::CeilDiv ||
703  bin.getKind() == AffineExprKind::FloorDiv ||
704  bin.getKind() == AffineExprKind::Mod)
705  return failure();
706 
707  if (bin.getKind() == AffineExprKind::Mul) {
708  auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
709  if (dim) {
710  strides[dim.getPosition()] =
711  strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
712  return success();
713  }
714  // LHS and RHS may both contain complex expressions of dims. Try one path
715  // and if it fails try the other. This is guaranteed to succeed because
716  // only one path may have a `dim`, otherwise this is not an AffineExpr in
717  // the first place.
718  if (bin.getLHS().isSymbolicOrConstant())
719  return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
720  strides, offset);
721  return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
722  strides, offset);
723  }
724 
725  if (bin.getKind() == AffineExprKind::Add) {
726  auto res1 =
727  extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
728  auto res2 =
729  extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
730  return success(succeeded(res1) && succeeded(res2));
731  }
732 
733  llvm_unreachable("unexpected binary operation");
734 }
735 
/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
/// the distance in the number of elements between successive entries along a
/// particular dimension.
///
/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
/// non-contiguous memory region of `42` by `16` `f32` elements in which the
/// distance between two consecutive elements along the outer dimension is `64`
/// and the distance between two consecutive elements along the inner dimension
/// is `1`.
///
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
749 static LogicalResult getStridesAndOffset(MemRefType t,
750  SmallVectorImpl<AffineExpr> &strides,
751  AffineExpr &offset) {
752  AffineMap m = t.getLayout().getAffineMap();
753 
754  if (m.getNumResults() != 1 && !m.isIdentity())
755  return failure();
756 
757  auto zero = getAffineConstantExpr(0, t.getContext());
758  auto one = getAffineConstantExpr(1, t.getContext());
759  offset = zero;
760  strides.assign(t.getRank(), zero);
761 
762  // Canonical case for empty map.
763  if (m.isIdentity()) {
764  // 0-D corner case, offset is already 0.
765  if (t.getRank() == 0)
766  return success();
767  auto stridedExpr =
768  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
769  if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
770  return success();
771  assert(false && "unexpected failure: extract strides in canonical layout");
772  }
773 
774  // Non-canonical case requires more work.
775  auto stridedExpr =
777  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
778  offset = AffineExpr();
779  strides.clear();
780  return failure();
781  }
782 
783  // Simplify results to allow folding to constants and simple checks.
784  unsigned numDims = m.getNumDims();
785  unsigned numSymbols = m.getNumSymbols();
786  offset = simplifyAffineExpr(offset, numDims, numSymbols);
787  for (auto &stride : strides)
788  stride = simplifyAffineExpr(stride, numDims, numSymbols);
789 
790  // In practice, a strided memref must be internally non-aliasing. Test
791  // against 0 as a proxy.
792  // TODO: static cases can have more advanced checks.
793  // TODO: dynamic cases would require a way to compare symbolic
794  // expressions and would probably need an affine set context propagated
795  // everywhere.
796  if (llvm::any_of(strides, [](AffineExpr e) {
797  return e == getAffineConstantExpr(0, e.getContext());
798  })) {
799  offset = AffineExpr();
800  strides.clear();
801  return failure();
802  }
803 
804  return success();
805 }
806 
808  SmallVectorImpl<int64_t> &strides,
809  int64_t &offset) {
810  // Happy path: the type uses the strided layout directly.
811  if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
812  llvm::append_range(strides, strided.getStrides());
813  offset = strided.getOffset();
814  return success();
815  }
816 
817  // Otherwise, defer to the affine fallback as layouts are supposed to be
818  // convertible to affine maps.
819  AffineExpr offsetExpr;
820  SmallVector<AffineExpr, 4> strideExprs;
821  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
822  return failure();
823  if (auto cst = dyn_cast<AffineConstantExpr>(offsetExpr))
824  offset = cst.getValue();
825  else
826  offset = ShapedType::kDynamic;
827  for (auto e : strideExprs) {
828  if (auto c = dyn_cast<AffineConstantExpr>(e))
829  strides.push_back(c.getValue());
830  else
831  strides.push_back(ShapedType::kDynamic);
832  }
833  return success();
834 }
835 
836 std::pair<SmallVector<int64_t>, int64_t>
838  SmallVector<int64_t> strides;
839  int64_t offset;
840  LogicalResult status = getStridesAndOffset(t, strides, offset);
841  (void)status;
842  assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
843  return {strides, offset};
844 }
845 
846 //===----------------------------------------------------------------------===//
847 /// TupleType
848 //===----------------------------------------------------------------------===//
849 
850 /// Return the elements types for this tuple.
851 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
852 
853 /// Accumulate the types contained in this tuple and tuples nested within it.
854 /// Note that this only flattens nested tuples, not any other container type,
855 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
856 /// (i32, tensor<i32>, f32, i64)
857 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
858  for (Type type : getTypes()) {
859  if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
860  nestedTuple.getFlattenedTypes(types);
861  else
862  types.push_back(type);
863  }
864 }
865 
866 /// Return the number of element types.
867 size_t TupleType::size() const { return getImpl()->size(); }
868 
869 //===----------------------------------------------------------------------===//
870 // Type Utilities
871 //===----------------------------------------------------------------------===//
872 
873 /// Return a version of `t` with identity layout if it can be determined
874 /// statically that the layout is the canonical contiguous strided layout.
875 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
876 /// `t` with simplified layout.
877 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
878 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
879  AffineMap m = t.getLayout().getAffineMap();
880 
881  // Already in canonical form.
882  if (m.isIdentity())
883  return t;
884 
885  // Can't reduce to canonical identity form, return in canonical form.
886  if (m.getNumResults() > 1)
887  return t;
888 
889  // Corner-case for 0-D affine maps.
890  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
891  if (auto cst = dyn_cast<AffineConstantExpr>(m.getResult(0)))
892  if (cst.getValue() == 0)
893  return MemRefType::Builder(t).setLayout({});
894  return t;
895  }
896 
897  // 0-D corner case for empty shape that still have an affine map. Example:
898  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
899  // offset needs to remain, just return t.
900  if (t.getShape().empty())
901  return t;
902 
903  // If the canonical strided layout for the sizes of `t` is equal to the
904  // simplified layout of `t` we can just return an empty layout. Otherwise,
905  // just simplify the existing layout.
906  AffineExpr expr =
907  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
908  auto simplifiedLayoutExpr =
910  if (expr != simplifiedLayoutExpr)
912  m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
913  return MemRefType::Builder(t).setLayout({});
914 }
915 
917  ArrayRef<AffineExpr> exprs,
918  MLIRContext *context) {
919  // Size 0 corner case is useful for canonicalizations.
920  if (sizes.empty())
921  return getAffineConstantExpr(0, context);
922 
923  assert(!exprs.empty() && "expected exprs");
924  auto maps = AffineMap::inferFromExprList(exprs, context);
925  assert(!maps.empty() && "Expected one non-empty map");
926  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
927 
928  AffineExpr expr;
929  bool dynamicPoisonBit = false;
930  int64_t runningSize = 1;
931  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
932  int64_t size = std::get<1>(en);
933  AffineExpr dimExpr = std::get<0>(en);
934  AffineExpr stride = dynamicPoisonBit
935  ? getAffineSymbolExpr(nSymbols++, context)
936  : getAffineConstantExpr(runningSize, context);
937  expr = expr ? expr + dimExpr * stride : dimExpr * stride;
938  if (size > 0) {
939  runningSize *= size;
940  assert(runningSize > 0 && "integer overflow in size computation");
941  } else {
942  dynamicPoisonBit = true;
943  }
944  }
945  return simplifyAffineExpr(expr, numDims, nSymbols);
946 }
947 
949  MLIRContext *context) {
951  exprs.reserve(sizes.size());
952  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
953  exprs.push_back(getAffineDimExpr(dim, context));
954  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
955 }
956 
957 bool mlir::isStrided(MemRefType t) {
958  int64_t offset;
959  SmallVector<int64_t, 4> strides;
960  auto res = getStridesAndOffset(t, strides, offset);
961  return succeeded(res);
962 }
963 
964 bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
965  int64_t offset;
966  SmallVector<int64_t> strides;
967  auto successStrides = getStridesAndOffset(type, strides, offset);
968  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
969 }
970 
971 bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
972  if (!isLastMemrefDimUnitStride(type))
973  return false;
974 
975  auto memrefShape = type.getShape().take_back(n);
976  if (ShapedType::isDynamicShape(memrefShape))
977  return false;
978 
979  if (type.getLayout().isIdentity())
980  return true;
981 
982  int64_t offset;
983  SmallVector<int64_t> stridesFull;
984  if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
985  return false;
986  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
987 
988  if (strides.empty())
989  return true;
990 
991  // Check whether strides match "flattened" dims.
992  SmallVector<int64_t> flattenedDims;
993  auto dimProduct = 1;
994  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
995  dimProduct *= dim;
996  flattenedDims.push_back(dimProduct);
997  }
998 
999  strides = strides.drop_back(1);
1000  return llvm::equal(strides, llvm::reverse(flattenedDims));
1001 }
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:69
MLIRContext * getContext() const
Definition: AffineExpr.cpp:29
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:322
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:386
unsigned getNumDims() const
Definition: AffineMap.cpp:382
unsigned getNumResults() const
Definition: AffineMap.cpp:390
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:399
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:333
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:300
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:76
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:138
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Definition: BuiltinTypes.h:400
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns true if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns true if the given string is a valid dialect namespace.
Definition: Dialect.cpp:92
static FloatType getF64(MLIRContext *ctx)
Definition: BuiltinTypes.h:450
FloatType scaleElementBitwidth(unsigned scale)
Get or create a new FloatType with bitwidth scaled by scale.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
static FloatType getF32(MLIRContext *ctx)
Definition: BuiltinTypes.h:446
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if creating operations for unregistered dialects is allowed.
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:201
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:222
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:217
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:212
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:227
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:91
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns true if this type is ranked, i.e. it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:119
Detect if any of the given parameter types has a sub-element handler.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:375
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:627
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:603
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:613
bool trailingNDimsContiguous(MemRefType type, int64_t n)
Return "true" if the last N dimensions of the given type are contiguous.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26