MLIR  17.0.0git
BuiltinTypes.cpp
Go to the documentation of this file.
1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/TensorEncoding.h"
20 #include "llvm/ADT/APFloat.h"
21 #include "llvm/ADT/BitVector.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/Twine.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 /// Tablegen Type Definitions
31 //===----------------------------------------------------------------------===//
32 
33 #define GET_TYPEDEF_CLASSES
34 #include "mlir/IR/BuiltinTypes.cpp.inc"
35 
36 //===----------------------------------------------------------------------===//
37 // BuiltinDialect
38 //===----------------------------------------------------------------------===//
39 
40 void BuiltinDialect::registerTypes() {
41  addTypes<
42 #define GET_TYPEDEF_LIST
43 #include "mlir/IR/BuiltinTypes.cpp.inc"
44  >();
45 }
46 
47 //===----------------------------------------------------------------------===//
48 /// ComplexType
49 //===----------------------------------------------------------------------===//
50 
51 /// Verify the construction of a complex type.
53  Type elementType) {
54  if (!elementType.isIntOrFloat())
55  return emitError() << "invalid element type for complex";
56  return success();
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // Integer Type
61 //===----------------------------------------------------------------------===//
62 
63 /// Verify the construction of an integer type.
65  unsigned width,
66  SignednessSemantics signedness) {
67  if (width > IntegerType::kMaxWidth) {
68  return emitError() << "integer bitwidth is limited to "
69  << IntegerType::kMaxWidth << " bits";
70  }
71  return success();
72 }
73 
74 unsigned IntegerType::getWidth() const { return getImpl()->width; }
75 
76 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
77  return getImpl()->signedness;
78 }
79 
80 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
81  if (!scale)
82  return IntegerType();
83  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Float Type
88 //===----------------------------------------------------------------------===//
89 
90 unsigned FloatType::getWidth() {
91  if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
92  Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
93  return 8;
94  if (isa<Float16Type, BFloat16Type>())
95  return 16;
96  if (isa<Float32Type>())
97  return 32;
98  if (isa<Float64Type>())
99  return 64;
100  if (isa<Float80Type>())
101  return 80;
102  if (isa<Float128Type>())
103  return 128;
104  llvm_unreachable("unexpected float type");
105 }
106 
107 /// Returns the floating semantics for the given type.
108 const llvm::fltSemantics &FloatType::getFloatSemantics() {
109  if (isa<Float8E5M2Type>())
110  return APFloat::Float8E5M2();
111  if (isa<Float8E4M3FNType>())
112  return APFloat::Float8E4M3FN();
113  if (isa<Float8E5M2FNUZType>())
114  return APFloat::Float8E5M2FNUZ();
115  if (isa<Float8E4M3FNUZType>())
116  return APFloat::Float8E4M3FNUZ();
117  if (isa<Float8E4M3B11FNUZType>())
118  return APFloat::Float8E4M3B11FNUZ();
119  if (isa<BFloat16Type>())
120  return APFloat::BFloat();
121  if (isa<Float16Type>())
122  return APFloat::IEEEhalf();
123  if (isa<Float32Type>())
124  return APFloat::IEEEsingle();
125  if (isa<Float64Type>())
126  return APFloat::IEEEdouble();
127  if (isa<Float80Type>())
128  return APFloat::x87DoubleExtended();
129  if (isa<Float128Type>())
130  return APFloat::IEEEquad();
131  llvm_unreachable("non-floating point type used");
132 }
133 
135  if (!scale)
136  return FloatType();
137  MLIRContext *ctx = getContext();
138  if (isF16() || isBF16()) {
139  if (scale == 2)
140  return FloatType::getF32(ctx);
141  if (scale == 4)
142  return FloatType::getF64(ctx);
143  }
144  if (isF32())
145  if (scale == 2)
146  return FloatType::getF64(ctx);
147  return FloatType();
148 }
149 
151  return APFloat::semanticsPrecision(getFloatSemantics());
152 }
153 
154 //===----------------------------------------------------------------------===//
155 // FunctionType
156 //===----------------------------------------------------------------------===//
157 
158 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
159 
160 ArrayRef<Type> FunctionType::getInputs() const {
161  return getImpl()->getInputs();
162 }
163 
164 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
165 
166 ArrayRef<Type> FunctionType::getResults() const {
167  return getImpl()->getResults();
168 }
169 
170 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
171  return get(getContext(), inputs, results);
172 }
173 
174 /// Returns a new function type with the specified arguments and results
175 /// inserted.
176 FunctionType FunctionType::getWithArgsAndResults(
177  ArrayRef<unsigned> argIndices, TypeRange argTypes,
178  ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
179  SmallVector<Type> argStorage, resultStorage;
181  getInputs(), argIndices, argTypes, argStorage);
183  getResults(), resultIndices, resultTypes, resultStorage);
184  return clone(newArgTypes, newResultTypes);
185 }
186 
187 /// Returns a new function type without the specified arguments and results.
188 FunctionType
189 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
190  const BitVector &resultIndices) {
191  SmallVector<Type> argStorage, resultStorage;
193  getInputs(), argIndices, argStorage);
195  getResults(), resultIndices, resultStorage);
196  return clone(newArgTypes, newResultTypes);
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // OpaqueType
201 //===----------------------------------------------------------------------===//
202 
203 /// Verify the construction of an opaque type.
205  StringAttr dialect, StringRef typeData) {
206  if (!Dialect::isValidNamespace(dialect.strref()))
207  return emitError() << "invalid dialect namespace '" << dialect << "'";
208 
209  // Check that the dialect is actually registered.
210  MLIRContext *context = dialect.getContext();
211  if (!context->allowsUnregisteredDialects() &&
212  !context->getLoadedDialect(dialect.strref())) {
213  return emitError()
214  << "`!" << dialect << "<\"" << typeData << "\">"
215  << "` type created with unregistered dialect. If this is "
216  "intended, please call allowUnregisteredDialects() on the "
217  "MLIRContext, or use -allow-unregistered-dialect with "
218  "the MLIR opt tool used";
219  }
220 
221  return success();
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // VectorType
226 //===----------------------------------------------------------------------===//
227 
229  ArrayRef<int64_t> shape, Type elementType,
230  unsigned numScalableDims) {
231  if (!isValidElementType(elementType))
232  return emitError()
233  << "vector elements must be int/index/float type but got "
234  << elementType;
235 
236  if (any_of(shape, [](int64_t i) { return i <= 0; }))
237  return emitError()
238  << "vector types must have positive constant sizes but got "
239  << shape;
240 
241  return success();
242 }
243 
244 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
245  if (!scale)
246  return VectorType();
247  if (auto et = getElementType().dyn_cast<IntegerType>())
248  if (auto scaledEt = et.scaleElementBitwidth(scale))
249  return VectorType::get(getShape(), scaledEt, getNumScalableDims());
250  if (auto et = getElementType().dyn_cast<FloatType>())
251  if (auto scaledEt = et.scaleElementBitwidth(scale))
252  return VectorType::get(getShape(), scaledEt, getNumScalableDims());
253  return VectorType();
254 }
255 
256 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
257  Type elementType) const {
258  return VectorType::get(shape.value_or(getShape()), elementType,
259  getNumScalableDims());
260 }
261 
262 //===----------------------------------------------------------------------===//
263 // TensorType
264 //===----------------------------------------------------------------------===//
265 
268  .Case<RankedTensorType, UnrankedTensorType>(
269  [](auto type) { return type.getElementType(); });
270 }
271 
272 bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
273 
275  return cast<RankedTensorType>().getShape();
276 }
277 
279  Type elementType) const {
280  if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
281  if (shape)
282  return RankedTensorType::get(*shape, elementType);
283  return UnrankedTensorType::get(elementType);
284  }
285 
286  auto rankedTy = cast<RankedTensorType>();
287  if (!shape)
288  return RankedTensorType::get(rankedTy.getShape(), elementType,
289  rankedTy.getEncoding());
290  return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
291  rankedTy.getEncoding());
292 }
293 
294 // Check if "elementType" can be an element type of a tensor.
295 static LogicalResult
297  Type elementType) {
298  if (!TensorType::isValidElementType(elementType))
299  return emitError() << "invalid tensor element type: " << elementType;
300  return success();
301 }
302 
303 /// Return true if the specified element type is ok in a tensor.
305  // Note: Non standard/builtin types are allowed to exist within tensor
306  // types. Dialects are expected to verify that tensor types have a valid
307  // element type within that dialect.
308  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
309  IndexType>() ||
310  !llvm::isa<BuiltinDialect>(type.getDialect());
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // RankedTensorType
315 //===----------------------------------------------------------------------===//
316 
319  ArrayRef<int64_t> shape, Type elementType,
320  Attribute encoding) {
321  for (int64_t s : shape)
322  if (s < 0 && !ShapedType::isDynamic(s))
323  return emitError() << "invalid tensor dimension size";
324  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
325  if (failed(v.verifyEncoding(shape, elementType, emitError)))
326  return failure();
327  return checkTensorElementType(emitError, elementType);
328 }
329 
330 //===----------------------------------------------------------------------===//
331 // UnrankedTensorType
332 //===----------------------------------------------------------------------===//
333 
336  Type elementType) {
337  return checkTensorElementType(emitError, elementType);
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // BaseMemRefType
342 //===----------------------------------------------------------------------===//
343 
346  .Case<MemRefType, UnrankedMemRefType>(
347  [](auto type) { return type.getElementType(); });
348 }
349 
350 bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
351 
353  return cast<MemRefType>().getShape();
354 }
355 
357  Type elementType) const {
358  if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
359  if (!shape)
360  return UnrankedMemRefType::get(elementType, getMemorySpace());
361  MemRefType::Builder builder(*shape, elementType);
362  builder.setMemorySpace(getMemorySpace());
363  return builder;
364  }
365 
366  MemRefType::Builder builder(cast<MemRefType>());
367  if (shape)
368  builder.setShape(*shape);
369  builder.setElementType(elementType);
370  return builder;
371 }
372 
374  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
375  return rankedMemRefTy.getMemorySpace();
376  return cast<UnrankedMemRefType>().getMemorySpace();
377 }
378 
380  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
381  return rankedMemRefTy.getMemorySpaceAsInt();
382  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // MemRefType
387 //===----------------------------------------------------------------------===//
388 
389 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
390 /// `originalShape` with some `1` entries erased, return the set of indices
391 /// that specifies which of the entries of `originalShape` are dropped to obtain
392 /// `reducedShape`. The returned mask can be applied as a projection to
393 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
394 /// which dimensions must be kept when, e.g., computing MemRef strides under
395 /// rank-reducing operations. Return std::nullopt if reducedShape cannot be
396 /// obtained by dropping only `1` entries in `originalShape`.
397 std::optional<llvm::SmallDenseSet<unsigned>>
399  ArrayRef<int64_t> reducedShape) {
400  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
401  llvm::SmallDenseSet<unsigned> unusedDims;
402  unsigned reducedIdx = 0;
403  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
404  // Greedily insert `originalIdx` if match.
405  if (reducedIdx < reducedRank &&
406  originalShape[originalIdx] == reducedShape[reducedIdx]) {
407  reducedIdx++;
408  continue;
409  }
410 
411  unusedDims.insert(originalIdx);
412  // If no match on `originalIdx`, the `originalShape` at this dimension
413  // must be 1, otherwise we bail.
414  if (originalShape[originalIdx] != 1)
415  return std::nullopt;
416  }
417  // The whole reducedShape must be scanned, otherwise we bail.
418  if (reducedIdx != reducedRank)
419  return std::nullopt;
420  return unusedDims;
421 }
422 
424 mlir::isRankReducedType(ShapedType originalType,
425  ShapedType candidateReducedType) {
426  if (originalType == candidateReducedType)
428 
429  ShapedType originalShapedType = originalType.cast<ShapedType>();
430  ShapedType candidateReducedShapedType =
431  candidateReducedType.cast<ShapedType>();
432 
433  // Rank and size logic is valid for all ShapedTypes.
434  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
435  ArrayRef<int64_t> candidateReducedShape =
436  candidateReducedShapedType.getShape();
437  unsigned originalRank = originalShape.size(),
438  candidateReducedRank = candidateReducedShape.size();
439  if (candidateReducedRank > originalRank)
441 
442  auto optionalUnusedDimsMask =
443  computeRankReductionMask(originalShape, candidateReducedShape);
444 
445  // Sizes cannot be matched in case empty vector is returned.
446  if (!optionalUnusedDimsMask)
448 
449  if (originalShapedType.getElementType() !=
450  candidateReducedShapedType.getElementType())
452 
454 }
455 
457  // Empty attribute is allowed as default memory space.
458  if (!memorySpace)
459  return true;
460 
461  // Supported built-in attributes.
462  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
463  return true;
464 
465  // Allow custom dialect attributes.
466  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
467  return true;
468 
469  return false;
470 }
471 
473  MLIRContext *ctx) {
474  if (memorySpace == 0)
475  return nullptr;
476 
477  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
478 }
479 
481  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
482  if (intMemorySpace && intMemorySpace.getValue() == 0)
483  return nullptr;
484 
485  return memorySpace;
486 }
487 
489  if (!memorySpace)
490  return 0;
491 
492  assert(memorySpace.isa<IntegerAttr>() &&
493  "Using `getMemorySpaceInteger` with non-Integer attribute");
494 
495  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
496 }
497 
498 unsigned MemRefType::getMemorySpaceAsInt() const {
499  return detail::getMemorySpaceAsInt(getMemorySpace());
500 }
501 
502 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
503  MemRefLayoutAttrInterface layout,
504  Attribute memorySpace) {
505  // Use default layout for empty attribute.
506  if (!layout)
507  layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
508  shape.size(), elementType.getContext()));
509 
510  // Drop default memory space value and replace it with empty attribute.
511  memorySpace = skipDefaultMemorySpace(memorySpace);
512 
513  return Base::get(elementType.getContext(), shape, elementType, layout,
514  memorySpace);
515 }
516 
517 MemRefType MemRefType::getChecked(
518  function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
519  Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
520 
521  // Use default layout for empty attribute.
522  if (!layout)
523  layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
524  shape.size(), elementType.getContext()));
525 
526  // Drop default memory space value and replace it with empty attribute.
527  memorySpace = skipDefaultMemorySpace(memorySpace);
528 
529  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
530  elementType, layout, memorySpace);
531 }
532 
533 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
534  AffineMap map, Attribute memorySpace) {
535 
536  // Use default layout for empty map.
537  if (!map)
538  map = AffineMap::getMultiDimIdentityMap(shape.size(),
539  elementType.getContext());
540 
541  // Wrap AffineMap into Attribute.
542  Attribute layout = AffineMapAttr::get(map);
543 
544  // Drop default memory space value and replace it with empty attribute.
545  memorySpace = skipDefaultMemorySpace(memorySpace);
546 
547  return Base::get(elementType.getContext(), shape, elementType, layout,
548  memorySpace);
549 }
550 
551 MemRefType
552 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
553  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
554  Attribute memorySpace) {
555 
556  // Use default layout for empty map.
557  if (!map)
558  map = AffineMap::getMultiDimIdentityMap(shape.size(),
559  elementType.getContext());
560 
561  // Wrap AffineMap into Attribute.
562  Attribute layout = AffineMapAttr::get(map);
563 
564  // Drop default memory space value and replace it with empty attribute.
565  memorySpace = skipDefaultMemorySpace(memorySpace);
566 
567  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
568  elementType, layout, memorySpace);
569 }
570 
571 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
572  AffineMap map, unsigned memorySpaceInd) {
573 
574  // Use default layout for empty map.
575  if (!map)
576  map = AffineMap::getMultiDimIdentityMap(shape.size(),
577  elementType.getContext());
578 
579  // Wrap AffineMap into Attribute.
580  Attribute layout = AffineMapAttr::get(map);
581 
582  // Convert deprecated integer-like memory space to Attribute.
583  Attribute memorySpace =
584  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
585 
586  return Base::get(elementType.getContext(), shape, elementType, layout,
587  memorySpace);
588 }
589 
590 MemRefType
591 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
592  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
593  unsigned memorySpaceInd) {
594 
595  // Use default layout for empty map.
596  if (!map)
597  map = AffineMap::getMultiDimIdentityMap(shape.size(),
598  elementType.getContext());
599 
600  // Wrap AffineMap into Attribute.
601  Attribute layout = AffineMapAttr::get(map);
602 
603  // Convert deprecated integer-like memory space to Attribute.
604  Attribute memorySpace =
605  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
606 
607  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
608  elementType, layout, memorySpace);
609 }
610 
612  ArrayRef<int64_t> shape, Type elementType,
613  MemRefLayoutAttrInterface layout,
614  Attribute memorySpace) {
615  if (!BaseMemRefType::isValidElementType(elementType))
616  return emitError() << "invalid memref element type";
617 
618  // Negative sizes are not allowed except for `kDynamic`.
619  for (int64_t s : shape)
620  if (s < 0 && !ShapedType::isDynamic(s))
621  return emitError() << "invalid memref size";
622 
623  assert(layout && "missing layout specification");
624  if (failed(layout.verifyLayout(shape, emitError)))
625  return failure();
626 
627  if (!isSupportedMemorySpace(memorySpace))
628  return emitError() << "unsupported memory space Attribute";
629 
630  return success();
631 }
632 
633 //===----------------------------------------------------------------------===//
634 // UnrankedMemRefType
635 //===----------------------------------------------------------------------===//
636 
638  return detail::getMemorySpaceAsInt(getMemorySpace());
639 }
640 
643  Type elementType, Attribute memorySpace) {
644  if (!BaseMemRefType::isValidElementType(elementType))
645  return emitError() << "invalid memref element type";
646 
647  if (!isSupportedMemorySpace(memorySpace))
648  return emitError() << "unsupported memory space Attribute";
649 
650  return success();
651 }
652 
653 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
654 // i.e. single term). Accumulate the AffineExpr into the existing one.
656  AffineExpr multiplicativeFactor,
658  AffineExpr &offset) {
659  if (auto dim = e.dyn_cast<AffineDimExpr>())
660  strides[dim.getPosition()] =
661  strides[dim.getPosition()] + multiplicativeFactor;
662  else
663  offset = offset + e * multiplicativeFactor;
664 }
665 
666 /// Takes a single AffineExpr `e` and populates the `strides` array with the
667 /// strides expressions for each dim position.
668 /// The convention is that the strides for dimensions d0, .. dn appear in
669 /// order to make indexing intuitive into the result.
671  AffineExpr multiplicativeFactor,
673  AffineExpr &offset) {
674  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
675  if (!bin) {
676  extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
677  return success();
678  }
679 
680  if (bin.getKind() == AffineExprKind::CeilDiv ||
681  bin.getKind() == AffineExprKind::FloorDiv ||
682  bin.getKind() == AffineExprKind::Mod)
683  return failure();
684 
685  if (bin.getKind() == AffineExprKind::Mul) {
686  auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
687  if (dim) {
688  strides[dim.getPosition()] =
689  strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
690  return success();
691  }
692  // LHS and RHS may both contain complex expressions of dims. Try one path
693  // and if it fails try the other. This is guaranteed to succeed because
694  // only one path may have a `dim`, otherwise this is not an AffineExpr in
695  // the first place.
696  if (bin.getLHS().isSymbolicOrConstant())
697  return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
698  strides, offset);
699  return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
700  strides, offset);
701  }
702 
703  if (bin.getKind() == AffineExprKind::Add) {
704  auto res1 =
705  extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
706  auto res2 =
707  extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
708  return success(succeeded(res1) && succeeded(res2));
709  }
710 
711  llvm_unreachable("unexpected binary operation");
712 }
713 
714 /// A stride specification is a list of integer values that are either static
715 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode
716 /// the distance in the number of elements between successive entries along a
717 /// particular dimension.
718 ///
719 /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
720 /// non-contiguous memory region of `42` by `16` `f32` elements in which the
721 /// distance between two consecutive elements along the outer dimension is `64`
722 /// and the distance between two consecutive elements along the inner dimension
723 /// is `1`.
724 ///
725 /// The convention is that the strides for dimensions d0, .. dn appear in
726 /// order to make indexing intuitive into the result.
727 static LogicalResult getStridesAndOffset(MemRefType t,
728  SmallVectorImpl<AffineExpr> &strides,
729  AffineExpr &offset) {
730  AffineMap m = t.getLayout().getAffineMap();
731 
732  if (m.getNumResults() != 1 && !m.isIdentity())
733  return failure();
734 
735  auto zero = getAffineConstantExpr(0, t.getContext());
736  auto one = getAffineConstantExpr(1, t.getContext());
737  offset = zero;
738  strides.assign(t.getRank(), zero);
739 
740  // Canonical case for empty map.
741  if (m.isIdentity()) {
742  // 0-D corner case, offset is already 0.
743  if (t.getRank() == 0)
744  return success();
745  auto stridedExpr =
746  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
747  if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
748  return success();
749  assert(false && "unexpected failure: extract strides in canonical layout");
750  }
751 
752  // Non-canonical case requires more work.
753  auto stridedExpr =
755  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
756  offset = AffineExpr();
757  strides.clear();
758  return failure();
759  }
760 
761  // Simplify results to allow folding to constants and simple checks.
762  unsigned numDims = m.getNumDims();
763  unsigned numSymbols = m.getNumSymbols();
764  offset = simplifyAffineExpr(offset, numDims, numSymbols);
765  for (auto &stride : strides)
766  stride = simplifyAffineExpr(stride, numDims, numSymbols);
767 
768  // In practice, a strided memref must be internally non-aliasing. Test
769  // against 0 as a proxy.
770  // TODO: static cases can have more advanced checks.
771  // TODO: dynamic cases would require a way to compare symbolic
772  // expressions and would probably need an affine set context propagated
773  // everywhere.
774  if (llvm::any_of(strides, [](AffineExpr e) {
775  return e == getAffineConstantExpr(0, e.getContext());
776  })) {
777  offset = AffineExpr();
778  strides.clear();
779  return failure();
780  }
781 
782  return success();
783 }
784 
786  SmallVectorImpl<int64_t> &strides,
787  int64_t &offset) {
788  // Happy path: the type uses the strided layout directly.
789  if (auto strided = t.getLayout().dyn_cast<StridedLayoutAttr>()) {
790  llvm::append_range(strides, strided.getStrides());
791  offset = strided.getOffset();
792  return success();
793  }
794 
795  // Otherwise, defer to the affine fallback as layouts are supposed to be
796  // convertible to affine maps.
797  AffineExpr offsetExpr;
798  SmallVector<AffineExpr, 4> strideExprs;
799  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
800  return failure();
801  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
802  offset = cst.getValue();
803  else
804  offset = ShapedType::kDynamic;
805  for (auto e : strideExprs) {
806  if (auto c = e.dyn_cast<AffineConstantExpr>())
807  strides.push_back(c.getValue());
808  else
809  strides.push_back(ShapedType::kDynamic);
810  }
811  return success();
812 }
813 
814 std::pair<SmallVector<int64_t>, int64_t>
816  SmallVector<int64_t> strides;
817  int64_t offset;
818  LogicalResult status = getStridesAndOffset(t, strides, offset);
819  (void)status;
820  assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
821  return {strides, offset};
822 }
823 
824 //===----------------------------------------------------------------------===//
825 /// TupleType
826 //===----------------------------------------------------------------------===//
827 
828 /// Return the elements types for this tuple.
829 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
830 
831 /// Accumulate the types contained in this tuple and tuples nested within it.
832 /// Note that this only flattens nested tuples, not any other container type,
833 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
834 /// (i32, tensor<i32>, f32, i64)
835 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
836  for (Type type : getTypes()) {
837  if (auto nestedTuple = type.dyn_cast<TupleType>())
838  nestedTuple.getFlattenedTypes(types);
839  else
840  types.push_back(type);
841  }
842 }
843 
844 /// Return the number of element types.
845 size_t TupleType::size() const { return getImpl()->size(); }
846 
847 //===----------------------------------------------------------------------===//
848 // Type Utilities
849 //===----------------------------------------------------------------------===//
850 
851 /// Return a version of `t` with identity layout if it can be determined
852 /// statically that the layout is the canonical contiguous strided layout.
853 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
854 /// `t` with simplified layout.
855 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
856 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
857  AffineMap m = t.getLayout().getAffineMap();
858 
859  // Already in canonical form.
860  if (m.isIdentity())
861  return t;
862 
863  // Can't reduce to canonical identity form, return in canonical form.
864  if (m.getNumResults() > 1)
865  return t;
866 
867  // Corner-case for 0-D affine maps.
868  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
869  if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
870  if (cst.getValue() == 0)
871  return MemRefType::Builder(t).setLayout({});
872  return t;
873  }
874 
875  // 0-D corner case for empty shape that still have an affine map. Example:
876  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
877  // offset needs to remain, just return t.
878  if (t.getShape().empty())
879  return t;
880 
881  // If the canonical strided layout for the sizes of `t` is equal to the
882  // simplified layout of `t` we can just return an empty layout. Otherwise,
883  // just simplify the existing layout.
884  AffineExpr expr =
885  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
886  auto simplifiedLayoutExpr =
888  if (expr != simplifiedLayoutExpr)
889  return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
890  m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
891  return MemRefType::Builder(t).setLayout({});
892 }
893 
895  ArrayRef<AffineExpr> exprs,
896  MLIRContext *context) {
897  // Size 0 corner case is useful for canonicalizations.
898  if (sizes.empty())
899  return getAffineConstantExpr(0, context);
900 
901  assert(!exprs.empty() && "expected exprs");
902  auto maps = AffineMap::inferFromExprList(exprs);
903  assert(!maps.empty() && "Expected one non-empty map");
904  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
905 
906  AffineExpr expr;
907  bool dynamicPoisonBit = false;
908  int64_t runningSize = 1;
909  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
910  int64_t size = std::get<1>(en);
911  AffineExpr dimExpr = std::get<0>(en);
912  AffineExpr stride = dynamicPoisonBit
913  ? getAffineSymbolExpr(nSymbols++, context)
914  : getAffineConstantExpr(runningSize, context);
915  expr = expr ? expr + dimExpr * stride : dimExpr * stride;
916  if (size > 0) {
917  runningSize *= size;
918  assert(runningSize > 0 && "integer overflow in size computation");
919  } else {
920  dynamicPoisonBit = true;
921  }
922  }
923  return simplifyAffineExpr(expr, numDims, nSymbols);
924 }
925 
927  MLIRContext *context) {
929  exprs.reserve(sizes.size());
930  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
931  exprs.push_back(getAffineDimExpr(dim, context));
932  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
933 }
934 
935 /// Return true if the layout for `t` is compatible with strided semantics.
936 bool mlir::isStrided(MemRefType t) {
937  int64_t offset;
938  SmallVector<int64_t, 4> strides;
939  auto res = getStridesAndOffset(t, strides, offset);
940  return succeeded(res);
941 }
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:698
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Affine binary operation expression.
Definition: AffineExpr.h:207
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
Base type for affine expression.
Definition: AffineExpr.h:68
MLIRContext * getContext() const
Definition: AffineExpr.cpp:25
U dyn_cast() const
Definition: AffineExpr.h:281
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:43
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:262
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:328
unsigned getNumDims() const
Definition: AffineMap.cpp:324
unsigned getNumResults() const
Definition: AffineMap.cpp:332
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:242
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:341
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:273
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast_or_null() const
Definition: Attributes.h:171
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:71
U cast() const
Definition: Attributes.h:176
bool isa() const
Casting utility functions.
Definition: Attributes.h:156
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:116
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Definition: BuiltinTypes.h:373
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:92
static FloatType getF64(MLIRContext *ctx)
Definition: BuiltinTypes.h:418
FloatType scaleElementBitwidth(unsigned scale)
Get or create a new FloatType with bitwidth scaled by scale.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
static FloatType getF32(MLIRContext *ctx)
Definition: BuiltinTypes.h:414
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if operations may be created for unregistered dialects.
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:168
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:189
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:184
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:179
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:194
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:80
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:118
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:108
bool isa() const
Definition: Types.h:301
Detect if any of the given parameter types has a sub-element handler.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:348
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:527
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool isStrided(MemRefType t)
Return true if the layout for t is compatible with strided semantics.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using classes from the detail namespace.
Definition: AffineExpr.cpp:502
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:374
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:512
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26