MLIR  16.0.0git
BuiltinTypes.cpp
Go to the documentation of this file.
1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/TensorEncoding.h"
20 #include "llvm/ADT/APFloat.h"
21 #include "llvm/ADT/BitVector.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/Twine.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 // Tablegen Type Definitions
31 //===----------------------------------------------------------------------===//
32 
33 #define GET_TYPEDEF_CLASSES
34 #include "mlir/IR/BuiltinTypes.cpp.inc"
35 
36 //===----------------------------------------------------------------------===//
37 // BuiltinDialect
38 //===----------------------------------------------------------------------===//
39 
/// Register all builtin types with the dialect. The type list comes from the
/// TableGen-generated BuiltinTypes.cpp.inc, included here with
/// GET_TYPEDEF_LIST defined; that inclusion is expected to expand to the
/// addTypes<> template argument list.
40 void BuiltinDialect::registerTypes() {
41  addTypes<
42 #define GET_TYPEDEF_LIST
43 #include "mlir/IR/BuiltinTypes.cpp.inc"
44  >();
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // ComplexType
49 //===----------------------------------------------------------------------===//
50 
51 /// Verify the construction of a complex type: the element type must be an
/// integer or floating-point type (Type::isIntOrFloat).
53  Type elementType) {
54  if (!elementType.isIntOrFloat())
55  return emitError() << "invalid element type for complex";
56  return success();
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // Integer Type
61 //===----------------------------------------------------------------------===//
62 
63 /// Verify the construction of an integer type: the bitwidth must not exceed
/// IntegerType::kMaxWidth. The signedness is accepted as-is.
65  unsigned width,
66  SignednessSemantics signedness) {
67  if (width > IntegerType::kMaxWidth) {
68  return emitError() << "integer bitwidth is limited to "
69  << IntegerType::kMaxWidth << " bits";
70  }
71  return success();
72 }
73 
74 unsigned IntegerType::getWidth() const { return getImpl()->width; }
75 
76 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
77  return getImpl()->signedness;
78 }
79 
80 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
81  if (!scale)
82  return IntegerType();
83  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Float Type
88 //===----------------------------------------------------------------------===//
89 
90 unsigned FloatType::getWidth() {
91  if (isa<Float8E5M2Type, Float8E4M3FNType>())
92  return 8;
93  if (isa<Float16Type, BFloat16Type>())
94  return 16;
95  if (isa<Float32Type>())
96  return 32;
97  if (isa<Float64Type>())
98  return 64;
99  if (isa<Float80Type>())
100  return 80;
101  if (isa<Float128Type>())
102  return 128;
103  llvm_unreachable("unexpected float type");
104 }
105 
106 /// Returns the floating semantics for the given type.
107 const llvm::fltSemantics &FloatType::getFloatSemantics() {
108  if (isa<Float8E5M2Type>())
109  return APFloat::Float8E5M2();
110  if (isa<Float8E4M3FNType>())
111  return APFloat::Float8E4M3FN();
112  if (isa<BFloat16Type>())
113  return APFloat::BFloat();
114  if (isa<Float16Type>())
115  return APFloat::IEEEhalf();
116  if (isa<Float32Type>())
117  return APFloat::IEEEsingle();
118  if (isa<Float64Type>())
119  return APFloat::IEEEdouble();
120  if (isa<Float80Type>())
121  return APFloat::x87DoubleExtended();
122  if (isa<Float128Type>())
123  return APFloat::IEEEquad();
124  llvm_unreachable("non-floating point type used");
125 }
126 
// NOTE(review): the signature line of this function is missing from this
// listing. Judging by its body and its call from
// VectorType::scaleElementBitwidth, it is presumably
// FloatType::scaleElementBitwidth(unsigned scale).
// Supported widenings: f16/bf16 x2 -> f32, f16/bf16 x4 -> f64, f32 x2 -> f64.
// bf16 widens to the IEEE f32/f64 types, not to another bfloat format. Any
// other combination, including scale == 0, yields a null FloatType.
128  if (!scale)
129  return FloatType();
130  MLIRContext *ctx = getContext();
131  if (isF16() || isBF16()) {
132  if (scale == 2)
133  return FloatType::getF32(ctx);
134  if (scale == 4)
135  return FloatType::getF64(ctx);
136  }
137  if (isF32())
138  if (scale == 2)
139  return FloatType::getF64(ctx);
140  return FloatType();
141 }
142 
// Returns the precision of this type's floating-point semantics as reported by
// APFloat::semanticsPrecision (the signature line is missing from this
// listing).
144  return APFloat::semanticsPrecision(getFloatSemantics());
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // FunctionType
149 //===----------------------------------------------------------------------===//
150 
151 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
152 
153 ArrayRef<Type> FunctionType::getInputs() const {
154  return getImpl()->getInputs();
155 }
156 
157 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
158 
159 ArrayRef<Type> FunctionType::getResults() const {
160  return getImpl()->getResults();
161 }
162 
163 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
164  return get(getContext(), inputs, results);
165 }
166 
167 /// Returns a new function type with the specified arguments and results
168 /// inserted.
/// `argIndices`/`resultIndices` are presumably the insertion positions for the
/// corresponding `argTypes`/`resultTypes` (TODO confirm: the helper calls are
/// partially missing from this listing). The new type is built via clone().
169 FunctionType FunctionType::getWithArgsAndResults(
170  ArrayRef<unsigned> argIndices, TypeRange argTypes,
171  ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  // Scratch vectors that presumably back `newArgTypes`/`newResultTypes`; the
  // lines declaring those two values are missing from this listing.
172  SmallVector<Type> argStorage, resultStorage;
174  getInputs(), argIndices, argTypes, argStorage);
176  getResults(), resultIndices, resultTypes, resultStorage);
177  return clone(newArgTypes, newResultTypes);
178 }
179 
180 /// Returns a new function type without the specified arguments and results.
/// Set bits in `argIndices`/`resultIndices` presumably mark the inputs/results
/// to drop (TODO confirm against the helper calls missing from this listing).
181 FunctionType
182 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
183  const BitVector &resultIndices) {
  // Scratch vectors that presumably back `newArgTypes`/`newResultTypes`.
184  SmallVector<Type> argStorage, resultStorage;
186  getInputs(), argIndices, argStorage);
188  getResults(), resultIndices, resultStorage);
189  return clone(newArgTypes, newResultTypes);
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // OpaqueType
194 //===----------------------------------------------------------------------===//
195 
196 /// Verify the construction of an opaque type.
/// Fails if `dialect` is not a valid dialect namespace. Also fails if that
/// dialect is not loaded in the context while unregistered dialects are
/// disallowed.
198  StringAttr dialect, StringRef typeData) {
199  if (!Dialect::isValidNamespace(dialect.strref()))
200  return emitError() << "invalid dialect namespace '" << dialect << "'";
201 
202  // Unless unregistered dialects are allowed, the dialect must already be
  // loaded in the context.
203  MLIRContext *context = dialect.getContext();
204  if (!context->allowsUnregisteredDialects() &&
205  !context->getLoadedDialect(dialect.strref())) {
206  return emitError()
207  << "`!" << dialect << "<\"" << typeData << "\">"
208  << "` type created with unregistered dialect. If this is "
209  "intended, please call allowUnregisteredDialects() on the "
210  "MLIRContext, or use -allow-unregistered-dialect with "
211  "the MLIR opt tool used";
212  }
213 
214  return success();
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // VectorType
219 //===----------------------------------------------------------------------===//
220 
222  ArrayRef<int64_t> shape, Type elementType,
223  unsigned numScalableDims) {
  // Verify a vector type (signature start missing from this listing): the
  // element type must satisfy isValidElementType, and every dimension must be
  // strictly positive.
224  if (!isValidElementType(elementType))
225  return emitError()
226  << "vector elements must be int/index/float type but got "
227  << elementType;
228 
  // Zero and negative sizes are rejected. This presumably also rejects dynamic
  // sizes, since the tensor/memref verifiers in this file treat the dynamic
  // sentinel as a negative value.
229  if (any_of(shape, [](int64_t i) { return i <= 0; }))
230  return emitError()
231  << "vector types must have positive constant sizes but got "
232  << shape;
233 
  // NOTE(review): `numScalableDims` is not validated here (e.g. against the
  // rank).
234  return success();
235 }
236 
237 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
238  if (!scale)
239  return VectorType();
240  if (auto et = getElementType().dyn_cast<IntegerType>())
241  if (auto scaledEt = et.scaleElementBitwidth(scale))
242  return VectorType::get(getShape(), scaledEt, getNumScalableDims());
243  if (auto et = getElementType().dyn_cast<FloatType>())
244  if (auto scaledEt = et.scaleElementBitwidth(scale))
245  return VectorType::get(getShape(), scaledEt, getNumScalableDims());
246  return VectorType();
247 }
248 
249 VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
250  Type elementType) const {
251  return VectorType::get(shape.value_or(getShape()), elementType,
252  getNumScalableDims());
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // TensorType
257 //===----------------------------------------------------------------------===//
258 
  // Dispatch on the concrete tensor kind; both kinds expose getElementType().
261  .Case<RankedTensorType, UnrankedTensorType>(
262  [](auto type) { return type.getElementType(); });
263 }
264 
/// A TensorType is ranked unless it is an UnrankedTensorType.
265 bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
266 
  // Only meaningful for ranked tensors: cast<> requires a RankedTensorType.
268  return cast<RankedTensorType>().getShape();
269 }
270 
272  Type elementType) const {
  // Rebuild this tensor type with `elementType` and, if provided, a new shape
  // (presumably TensorType::cloneWith; the signature start is missing from
  // this listing).
  // Unranked source: a provided shape yields a ranked tensor without an
  // encoding; otherwise the result stays unranked.
  // NOTE(review): `unrankedTy` is unused; only the dyn_cast test matters.
273  if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
274  if (shape)
275  return RankedTensorType::get(*shape, elementType);
276  return UnrankedTensorType::get(elementType);
277  }
278 
  // Ranked source: the encoding is always preserved, and the current shape is
  // kept unless a new one is provided.
279  auto rankedTy = cast<RankedTensorType>();
280  if (!shape)
281  return RankedTensorType::get(rankedTy.getShape(), elementType,
282  rankedTy.getEncoding());
  // NOTE(review): `shape` is known to be set here, so the value_or fallback is
  // dead; `*shape` would be equivalent.
283  return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
284  rankedTy.getEncoding());
285 }
286 
287 // Check if "elementType" can be an element type of a tensor. If it cannot,
// report an error through `emitError` and return failure.
288 static LogicalResult
290  Type elementType) {
291  if (!TensorType::isValidElementType(elementType))
292  return emitError() << "invalid tensor element type: " << elementType;
293  return success();
294 }
295 
296 /// Return true if the specified element type is ok in a tensor.
/// Accepted: the builtin complex, float, integer, opaque, vector and index
/// types, plus any type whose dialect is not the builtin dialect.
298  // Note: Non standard/builtin types are allowed to exist within tensor
299  // types. Dialects are expected to verify that tensor types have a valid
300  // element type within that dialect.
301  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
302  IndexType>() ||
303  !llvm::isa<BuiltinDialect>(type.getDialect());
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // RankedTensorType
308 //===----------------------------------------------------------------------===//
309 
312  ArrayRef<int64_t> shape, Type elementType,
313  Attribute encoding) {
  // Verify a ranked tensor type (signature start missing from this listing).
  // Negative sizes are only allowed for the dynamic-size sentinel.
314  for (int64_t s : shape)
315  if (s < 0 && !ShapedType::isDynamic(s))
316  return emitError() << "invalid tensor dimension size";
  // An encoding implementing VerifiableTensorEncoding validates the shape and
  // element type itself. Other encodings, or a null one, are not checked here.
317  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
318  if (failed(v.verifyEncoding(shape, elementType, emitError)))
319  return failure();
320  return checkTensorElementType(emitError, elementType);
321 }
322 
323 //===----------------------------------------------------------------------===//
324 // UnrankedTensorType
325 //===----------------------------------------------------------------------===//
326 
329  Type elementType) {
  // Verify an unranked tensor type: only the element type is constrained.
330  return checkTensorElementType(emitError, elementType);
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // BaseMemRefType
335 //===----------------------------------------------------------------------===//
336 
  // Dispatch on the concrete memref kind; both kinds expose getElementType().
339  .Case<MemRefType, UnrankedMemRefType>(
340  [](auto type) { return type.getElementType(); });
341 }
342 
/// A BaseMemRefType is ranked unless it is an UnrankedMemRefType.
343 bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
344 
  // Only meaningful for ranked memrefs: cast<> requires a MemRefType.
346  return cast<MemRefType>().getShape();
347 }
348 
350  Type elementType) const {
  // Rebuild this memref type with `elementType` and, if provided, a new shape.
  // The memory space is always preserved. The signature start is missing from
  // this listing; returning `builder` relies on MemRefType::Builder converting
  // implicitly to the result type.
  // NOTE(review): `unrankedTy` is unused; only the dyn_cast test matters.
351  if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
352  if (!shape)
353  return UnrankedMemRefType::get(elementType, getMemorySpace());
  // A provided shape turns the unranked memref into a ranked one; no explicit
  // layout is set on the builder.
354  MemRefType::Builder builder(*shape, elementType);
355  builder.setMemorySpace(getMemorySpace());
356  return builder;
357  }
358 
  // Ranked source: start from the existing type and override only the shape
  // (if given) and the element type.
  // NOTE(review): the existing layout is kept even when the shape changes; a
  // layout whose rank no longer matches presumably fails verification.
359  MemRefType::Builder builder(cast<MemRefType>());
360  if (shape)
361  builder.setShape(*shape);
362  builder.setElementType(elementType);
363  return builder;
364 }
365 
  // Return the memory space attribute of the underlying ranked or unranked
  // memref.
367  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
368  return rankedMemRefTy.getMemorySpace();
369  return cast<UnrankedMemRefType>().getMemorySpace();
370 }
371 
  // Same as above, but the memory space is returned as an unsigned integer.
373  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
374  return rankedMemRefTy.getMemorySpaceAsInt();
375  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
376 }
377 
378 //===----------------------------------------------------------------------===//
379 // MemRefType
380 //===----------------------------------------------------------------------===//
381 
382 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
383 /// `originalShape` with some `1` entries erased, return the set of indices
384 /// that specifies which of the entries of `originalShape` are dropped to obtain
385 /// `reducedShape`. The returned mask can be applied as a projection to
386 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
387 /// which dimensions must be kept when e.g. computing MemRef strides under
388 /// rank-reducing operations. Return std::nullopt if reducedShape cannot be
389 /// obtained by dropping only `1` entries in `originalShape`.
/// Matching is greedy from the left: an original `1` that equals the next
/// reduced entry is kept. For example, original [1, 1] with reduced [1]
/// yields {1}, not {0}.
392  ArrayRef<int64_t> reducedShape) {
393  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  // Indices of `originalShape` that are dropped to obtain `reducedShape`.
394  llvm::SmallDenseSet<unsigned> unusedDims;
395  unsigned reducedIdx = 0;
396  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
397  // Greedily match `originalIdx` against the next unmatched reduced entry.
398  if (reducedIdx < reducedRank &&
399  originalShape[originalIdx] == reducedShape[reducedIdx]) {
400  reducedIdx++;
401  continue;
402  }
403 
404  unusedDims.insert(originalIdx);
405  // If no match on `originalIdx`, the `originalShape` at this dimension
406  // must be 1, otherwise we bail.
407  if (originalShape[originalIdx] != 1)
408  return std::nullopt;
409  }
410  // The whole reducedShape must be scanned, otherwise we bail.
411  if (reducedIdx != reducedRank)
412  return std::nullopt;
413  return unusedDims;
414 }
415 
// Check whether `candidateReducedType` can be obtained from `originalType` by
// dropping unit dimensions while keeping the same element type. The values
// returned by the early exits are on lines missing from this listing.
417 mlir::isRankReducedType(ShapedType originalType,
418  ShapedType candidateReducedType) {
  // Identical types are handled up front.
419  if (originalType == candidateReducedType)
421 
  // NOTE(review): both parameters are already ShapedType, so these casts are
  // redundant.
422  ShapedType originalShapedType = originalType.cast<ShapedType>();
423  ShapedType candidateReducedShapedType =
424  candidateReducedType.cast<ShapedType>();
425 
426  // Rank and size logic is valid for all ShapedTypes.
427  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
428  ArrayRef<int64_t> candidateReducedShape =
429  candidateReducedShapedType.getShape();
430  unsigned originalRank = originalShape.size(),
431  candidateReducedRank = candidateReducedShape.size();
  // A rank-reduced candidate cannot have a larger rank than the original.
432  if (candidateReducedRank > originalRank)
434 
435  auto optionalUnusedDimsMask =
436  computeRankReductionMask(originalShape, candidateReducedShape);
437 
438  // Sizes cannot be matched when std::nullopt is returned. An empty set is a
  // valid result, meaning no dimension was dropped.
439  if (!optionalUnusedDimsMask)
441 
  // Element types must agree as well.
442  if (originalShapedType.getElementType() !=
443  candidateReducedShapedType.getElementType())
445 
447 }
448 
  // Presumably isSupportedMemorySpace (used by the memref verifiers; the
  // signature line is missing from this listing). Returns true if
  // `memorySpace` is null, an IntegerAttr/StringAttr/DictionaryAttr, or any
  // attribute from a non-builtin dialect.
450  // Empty attribute is allowed as default memory space.
451  if (!memorySpace)
452  return true;
453 
454  // Supported built-in attributes.
455  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
456  return true;
457 
458  // Allow custom dialect attributes.
459  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
460  return true;
461 
462  return false;
463 }
464 
466  MLIRContext *ctx) {
  // Convert a legacy integer memory space to an attribute. 0 (the default)
  // becomes the null attribute; any other value becomes a 64-bit IntegerAttr.
467  if (memorySpace == 0)
468  return nullptr;
469 
470  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
471 }
472 
  // Canonicalize a memory space attribute: an IntegerAttr holding 0 becomes the
  // null attribute, and everything else is returned unchanged.
474  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
475  if (intMemorySpace && intMemorySpace.getValue() == 0)
476  return nullptr;
477 
478  return memorySpace;
479 }
480 
  // Return the memory space as an unsigned integer. A null attribute maps to 0;
  // otherwise `memorySpace` must be an IntegerAttr.
482  if (!memorySpace)
483  return 0;
484 
  // NOTE(review): the assert message names `getMemorySpaceInteger`, which
  // looks like a stale function name. Also, getInt() is narrowed to unsigned
  // below without a range check.
485  assert(memorySpace.isa<IntegerAttr>() &&
486  "Using `getMemorySpaceInteger` with non-Integer attribute");
487 
488  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
489 }
490 
491 unsigned MemRefType::getMemorySpaceAsInt() const {
492  return detail::getMemorySpaceAsInt(getMemorySpace());
493 }
494 
495 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
496  MemRefLayoutAttrInterface layout,
497  Attribute memorySpace) {
498  // Use default layout for empty attribute.
499  if (!layout)
500  layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
501  shape.size(), elementType.getContext()));
502 
503  // Drop default memory space value and replace it with empty attribute.
504  memorySpace = skipDefaultMemorySpace(memorySpace);
505 
506  return Base::get(elementType.getContext(), shape, elementType, layout,
507  memorySpace);
508 }
509 
510 MemRefType MemRefType::getChecked(
511  function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
512  Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
513 
514  // Use default layout for empty attribute.
515  if (!layout)
516  layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
517  shape.size(), elementType.getContext()));
518 
519  // Drop default memory space value and replace it with empty attribute.
520  memorySpace = skipDefaultMemorySpace(memorySpace);
521 
522  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
523  elementType, layout, memorySpace);
524 }
525 
526 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
527  AffineMap map, Attribute memorySpace) {
528 
529  // Use default layout for empty map.
530  if (!map)
531  map = AffineMap::getMultiDimIdentityMap(shape.size(),
532  elementType.getContext());
533 
534  // Wrap AffineMap into Attribute.
535  Attribute layout = AffineMapAttr::get(map);
536 
537  // Drop default memory space value and replace it with empty attribute.
538  memorySpace = skipDefaultMemorySpace(memorySpace);
539 
540  return Base::get(elementType.getContext(), shape, elementType, layout,
541  memorySpace);
542 }
543 
544 MemRefType
545 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
546  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
547  Attribute memorySpace) {
548 
549  // Use default layout for empty map.
550  if (!map)
551  map = AffineMap::getMultiDimIdentityMap(shape.size(),
552  elementType.getContext());
553 
554  // Wrap AffineMap into Attribute.
555  Attribute layout = AffineMapAttr::get(map);
556 
557  // Drop default memory space value and replace it with empty attribute.
558  memorySpace = skipDefaultMemorySpace(memorySpace);
559 
560  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
561  elementType, layout, memorySpace);
562 }
563 
564 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
565  AffineMap map, unsigned memorySpaceInd) {
566 
567  // Use default layout for empty map.
568  if (!map)
569  map = AffineMap::getMultiDimIdentityMap(shape.size(),
570  elementType.getContext());
571 
572  // Wrap AffineMap into Attribute.
573  Attribute layout = AffineMapAttr::get(map);
574 
575  // Convert deprecated integer-like memory space to Attribute.
576  Attribute memorySpace =
577  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
578 
579  return Base::get(elementType.getContext(), shape, elementType, layout,
580  memorySpace);
581 }
582 
583 MemRefType
584 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
585  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
586  unsigned memorySpaceInd) {
587 
588  // Use default layout for empty map.
589  if (!map)
590  map = AffineMap::getMultiDimIdentityMap(shape.size(),
591  elementType.getContext());
592 
593  // Wrap AffineMap into Attribute.
594  Attribute layout = AffineMapAttr::get(map);
595 
596  // Convert deprecated integer-like memory space to Attribute.
597  Attribute memorySpace =
598  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
599 
600  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
601  elementType, layout, memorySpace);
602 }
603 
605  ArrayRef<int64_t> shape, Type elementType,
606  MemRefLayoutAttrInterface layout,
607  Attribute memorySpace) {
  // Verify a ranked memref type (signature start missing from this listing).
608  if (!BaseMemRefType::isValidElementType(elementType))
609  return emitError() << "invalid memref element type";
610 
611  // Negative sizes are not allowed except for `kDynamic`.
612  for (int64_t s : shape)
613  if (s < 0 && !ShapedType::isDynamic(s))
614  return emitError() << "invalid memref size";
615 
  // MemRefType::get/getChecked substitute the identity layout for a null one,
  // so a null layout here indicates a caller bug. The layout then validates
  // itself against the shape.
616  assert(layout && "missing layout specification");
617  if (failed(layout.verifyLayout(shape, emitError)))
618  return failure();
619 
620  if (!isSupportedMemorySpace(memorySpace))
621  return emitError() << "unsupported memory space Attribute";
622 
623  return success();
624 }
625 
626 //===----------------------------------------------------------------------===//
627 // UnrankedMemRefType
628 //===----------------------------------------------------------------------===//
629 
  // Presumably UnrankedMemRefType::getMemorySpaceAsInt (signature line missing
  // from this listing). A null memory space maps to 0.
631  return detail::getMemorySpaceAsInt(getMemorySpace());
632 }
633 
636  Type elementType, Attribute memorySpace) {
  // Verify an unranked memref type. Only the element type and memory space are
  // checked; there is no shape or layout to validate.
637  if (!BaseMemRefType::isValidElementType(elementType))
638  return emitError() << "invalid memref element type";
639 
640  if (!isSupportedMemorySpace(memorySpace))
641  return emitError() << "unsupported memory space Attribute";
642 
643  return success();
644 }
645 
646 // Fallback case for a terminal dim/sym/cst that is not part of a binary op
647 // (i.e. a single term). A dim `d_i` adds `multiplicativeFactor` to
// `strides[i]`; any other term (symbol or constant) adds
// `e * multiplicativeFactor` to `offset`.
649  AffineExpr multiplicativeFactor,
651  AffineExpr &offset) {
652  if (auto dim = e.dyn_cast<AffineDimExpr>())
653  strides[dim.getPosition()] =
654  strides[dim.getPosition()] + multiplicativeFactor;
655  else
656  offset = offset + e * multiplicativeFactor;
657 }
658 
659 /// Takes a single AffineExpr `e` and populates the `strides` array with the
660 /// strides expressions for each dim position.
661 /// The convention is that the strides for dimensions d0, .. dn appear in
662 /// order to make indexing intuitive into the result.
/// `multiplicativeFactor` scales every contribution of `e`. Fails on ceildiv,
/// floordiv and mod, which cannot be expressed as strides.
664  AffineExpr multiplicativeFactor,
666  AffineExpr &offset) {
667  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
668  if (!bin) {
669  extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
670  return success();
671  }
672 
673  if (bin.getKind() == AffineExprKind::CeilDiv ||
674  bin.getKind() == AffineExprKind::FloorDiv ||
675  bin.getKind() == AffineExprKind::Mod)
676  return failure();
677 
678  if (bin.getKind() == AffineExprKind::Mul) {
679  auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
680  if (dim) {
681  strides[dim.getPosition()] =
682  strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
683  return success();
684  }
685  // LHS and RHS may both contain complex expressions of dims. In a valid
686  // affine product at most one side involves dims, so recurse into the side
687  // that is not symbolic-or-constant and fold the other side into the
688  // multiplicative factor (there is no retry of the other path).
689  if (bin.getLHS().isSymbolicOrConstant())
690  return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
691  strides, offset);
692  return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
693  strides, offset);
694  }
695 
696  if (bin.getKind() == AffineExprKind::Add) {
    // Both operands are always visited, so on failure `strides`/`offset` may
    // be partially updated; the caller in this file clears them in that case.
697  auto res1 =
698  extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
699  auto res2 =
700  extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
701  return success(succeeded(res1) && succeeded(res2));
702  }
703 
704  llvm_unreachable("unexpected binary operation");
705 }
706 
707 /// A stride specification is a list of integer values that are either static
708 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode
709 /// the distance in the number of elements between successive entries along a
710 /// particular dimension.
711 ///
712 /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
713 /// non-contiguous memory region of `42` by `16` `f32` elements in which the
714 /// distance between two consecutive elements along the outer dimension is `64`
715 /// and the distance between two consecutive elements along the inner dimension
716 /// is `1`.
717 ///
718 /// The convention is that the strides for dimensions d0, .. dn appear in
719 /// order to make indexing intuitive into the result.
/// This overload produces AffineExpr strides/offset. On failure, `offset` is
/// reset to a null AffineExpr and `strides` is cleared (except on the
/// early-exit path for unsupported maps, which leaves them untouched).
720 static LogicalResult getStridesAndOffset(MemRefType t,
721  SmallVectorImpl<AffineExpr> &strides,
722  AffineExpr &offset) {
723  AffineMap m = t.getLayout().getAffineMap();
724 
  // Only single-result layout maps, or the identity map of any rank, are
  // handled.
725  if (m.getNumResults() != 1 && !m.isIdentity())
726  return failure();
727 
728  auto zero = getAffineConstantExpr(0, t.getContext());
729  auto one = getAffineConstantExpr(1, t.getContext());
730  offset = zero;
731  strides.assign(t.getRank(), zero);
732 
733  // Canonical case: identity layout map.
734  if (m.isIdentity()) {
735  // 0-D corner case, offset is already 0.
736  if (t.getRank() == 0)
737  return success();
738  auto stridedExpr =
739  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
740  if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
741  return success();
    // NOTE(review): with NDEBUG this assert is compiled out and control falls
    // through to the non-canonical path below.
742  assert(false && "unexpected failure: extract strides in canonical layout");
743  }
744 
745  // Non-canonical case requires more work.
746  auto stridedExpr =
748  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    // Reset outputs so callers never observe partially-extracted values.
749  offset = AffineExpr();
750  strides.clear();
751  return failure();
752  }
753 
754  // Simplify results to allow folding to constants and simple checks.
755  unsigned numDims = m.getNumDims();
756  unsigned numSymbols = m.getNumSymbols();
757  offset = simplifyAffineExpr(offset, numDims, numSymbols);
758  for (auto &stride : strides)
759  stride = simplifyAffineExpr(stride, numDims, numSymbols);
760 
761  // In practice, a strided memref must be internally non-aliasing. Test
762  // against 0 as a proxy.
763  // TODO: static cases can have more advanced checks.
764  // TODO: dynamic cases would require a way to compare symbolic
765  // expressions and would probably need an affine set context propagated
766  // everywhere.
767  if (llvm::any_of(strides, [](AffineExpr e) {
768  return e == getAffineConstantExpr(0, e.getContext());
769  })) {
770  offset = AffineExpr();
771  strides.clear();
772  return failure();
773  }
774 
775  return success();
776 }
777 
779  SmallVectorImpl<int64_t> &strides,
780  int64_t &offset) {
  // Compute integer strides and offset for `t` (signature start missing from
  // this listing). Non-constant strides/offset are reported as
  // ShapedType::kDynamic.
  // NOTE(review): results are appended to `strides`, so callers presumably
  // pass an empty vector. On failure of the affine fallback, `offset` is left
  // unassigned.
781  // Happy path: the type uses the strided layout directly.
782  if (auto strided = t.getLayout().dyn_cast<StridedLayoutAttr>()) {
783  llvm::append_range(strides, strided.getStrides());
784  offset = strided.getOffset();
785  return success();
786  }
787 
788  // Otherwise, defer to the affine fallback as layouts are supposed to be
789  // convertible to affine maps.
790  AffineExpr offsetExpr;
791  SmallVector<AffineExpr, 4> strideExprs;
792  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
793  return failure();
794  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
795  offset = cst.getValue();
796  else
797  offset = ShapedType::kDynamic;
798  for (auto e : strideExprs) {
799  if (auto c = e.dyn_cast<AffineConstantExpr>())
800  strides.push_back(c.getValue());
801  else
802  strides.push_back(ShapedType::kDynamic);
803  }
804  return success();
805 }
806 
// Asserting convenience variant that returns {strides, offset}. It is only
// valid for layouts known to be strided: failure is checked only by the
// assert, so with NDEBUG a failing call returns an uninitialized offset.
807 std::pair<SmallVector<int64_t>, int64_t>
809  SmallVector<int64_t> strides;
810  int64_t offset;
811  LogicalResult status = getStridesAndOffset(t, strides, offset);
812  (void)status;
813  assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
814  return {strides, offset};
815 }
816 
817 //===----------------------------------------------------------------------===//
818 // TupleType
819 //===----------------------------------------------------------------------===//
820 
821 /// Return the elements types for this tuple.
822 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
823 
824 /// Accumulate the types contained in this tuple and tuples nested within it.
825 /// Note that this only flattens nested tuples, not any other container type,
826 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
827 /// (i32, tensor<i32>, f32, i64)
828 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
829  for (Type type : getTypes()) {
830  if (auto nestedTuple = type.dyn_cast<TupleType>())
831  nestedTuple.getFlattenedTypes(types);
832  else
833  types.push_back(type);
834  }
835 }
836 
837 /// Return the number of element types.
838 size_t TupleType::size() const { return getImpl()->size(); }
839 
840 //===----------------------------------------------------------------------===//
841 // Type Utilities
842 //===----------------------------------------------------------------------===//
843 
844 /// Return a version of `t` with identity layout if it can be determined
845 /// statically that the layout is the canonical contiguous strided layout.
846 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
847 /// `t` with simplified layout.
848 /// If `t` has a multi-result layout, just return `t`.
849 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
850  AffineMap m = t.getLayout().getAffineMap();
851 
852  // Already in canonical form.
853  if (m.isIdentity())
854  return t;
855 
856  // Multi-result layouts can't be reduced to the canonical identity form;
  // return `t` unchanged.
857  if (m.getNumResults() > 1)
858  return t;
859 
860  // Corner-case for 0-D affine maps: only a constant-zero result collapses to
  // the identity (empty) layout; anything else is kept as-is.
861  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
862  if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
863  if (cst.getValue() == 0)
864  return MemRefType::Builder(t).setLayout({});
865  return t;
866  }
867 
868  // 0-D corner case for an empty shape that still has an affine map. Example:
869  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
870  // offset needs to remain, just return t.
871  if (t.getShape().empty())
872  return t;
873 
874  // If the canonical strided layout for the sizes of `t` is equal to the
875  // simplified layout of `t` we can just return an empty layout. Otherwise,
876  // just simplify the existing layout.
877  AffineExpr expr =
878  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
879  auto simplifiedLayoutExpr =
881  if (expr != simplifiedLayoutExpr)
882  return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
883  m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
884  return MemRefType::Builder(t).setLayout({});
885 }
886 
888  ArrayRef<AffineExpr> exprs,
889  MLIRContext *context) {
  // Build the canonical row-major contiguous strided layout expression for
  // `sizes`, using `exprs` as the per-dimension expressions (signature start
  // missing from this listing).
890  // Rank-0 (empty `sizes`) corner case is useful for canonicalizations.
891  if (sizes.empty())
892  return getAffineConstantExpr(0, context);
893 
894  assert(!exprs.empty() && "expected exprs");
895  auto maps = AffineMap::inferFromExprList(exprs);
896  assert(!maps.empty() && "Expected one non-empty map");
897  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
898 
899  AffineExpr expr;
900  bool dynamicPoisonBit = false;
901  int64_t runningSize = 1;
  // Walk dimensions innermost to outermost. A static stride is the running
  // product of the inner sizes. Once a non-positive size has been seen (dynamic
  // or zero), every further-outer stride becomes a fresh symbol.
902  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
903  int64_t size = std::get<1>(en);
904  AffineExpr dimExpr = std::get<0>(en);
905  AffineExpr stride = dynamicPoisonBit
906  ? getAffineSymbolExpr(nSymbols++, context)
907  : getAffineConstantExpr(runningSize, context);
908  expr = expr ? expr + dimExpr * stride : dimExpr * stride;
909  if (size > 0) {
      // NOTE(review): this check runs after the signed multiplication, where
      // overflow is already UB, and it is compiled out under NDEBUG.
910  runningSize *= size;
911  assert(runningSize > 0 && "integer overflow in size computation");
912  } else {
913  dynamicPoisonBit = true;
914  }
915  }
916  return simplifyAffineExpr(expr, numDims, nSymbols);
917 }
918 
920  MLIRContext *context) {
  // Convenience overload that uses d0, d1, ... as the per-dimension
  // expressions.
922  exprs.reserve(sizes.size());
923  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
924  exprs.push_back(getAffineDimExpr(dim, context));
925  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
926 }
927 
928 /// Return true if the layout for `t` is compatible with strided semantics.
929 bool mlir::isStrided(MemRefType t) {
930  int64_t offset;
931  SmallVector<int64_t, 4> strides;
932  auto res = getStridesAndOffset(t, strides, offset);
933  return succeeded(res);
934 }
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:698
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
Affine binary operation expression.
Definition: AffineExpr.h:207
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
Base type for affine expression.
Definition: AffineExpr.h:68
MLIRContext * getContext() const
Definition: AffineExpr.cpp:24
U dyn_cast() const
Definition: AffineExpr.h:281
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:42
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:256
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:310
unsigned getNumDims() const
Definition: AffineMap.cpp:306
unsigned getNumResults() const
Definition: AffineMap.cpp:314
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:236
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:323
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:267
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast_or_null() const
Definition: Attributes.h:132
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:74
U cast() const
Definition: Attributes.h:137
bool isa() const
Casting utility functions.
Definition: Attributes.h:117
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:114
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Definition: BuiltinTypes.h:371
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
BaseMemRefType cloneWith(Optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
Type getElementType() const
Returns the element type of this memref type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:92
static FloatType getF64(MLIRContext *ctx)
Definition: BuiltinTypes.h:402
FloatType scaleElementBitwidth(unsigned scale)
Get or create a new FloatType with bitwidth scaled by scale.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
static FloatType getF32(MLIRContext *ctx)
Definition: BuiltinTypes.h:398
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:307
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if we allow creating operations for unregistered dialects.
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:166
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:187
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:182
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:177
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:192
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:78
TensorType cloneWith(Optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:121
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:89
bool isa() const
Definition: Types.h:260
Detect if any of the given parameter types has a sub-element handler.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:346
llvm::Optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:513
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool isStrided(MemRefType t)
Return true if the layout for t is compatible with strided semantics.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:488
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:372
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:498
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26