MLIR  14.0.0git
BuiltinTypes.cpp
Go to the documentation of this file.
1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/TensorEncoding.h"
20 #include "llvm/ADT/APFloat.h"
21 #include "llvm/ADT/BitVector.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/Twine.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 /// Tablegen Type Definitions
31 //===----------------------------------------------------------------------===//
32 
33 #define GET_TYPEDEF_CLASSES
34 #include "mlir/IR/BuiltinTypes.cpp.inc"
35 
36 //===----------------------------------------------------------------------===//
37 // BuiltinDialect
38 //===----------------------------------------------------------------------===//
39 
40 void BuiltinDialect::registerTypes() {
41  addTypes<
42 #define GET_TYPEDEF_LIST
43 #include "mlir/IR/BuiltinTypes.cpp.inc"
44  >();
45 }
46 
47 //===----------------------------------------------------------------------===//
48 /// ComplexType
49 //===----------------------------------------------------------------------===//
50 
51 /// Verify the construction of a complex type.
53  Type elementType) {
54  if (!elementType.isIntOrFloat())
55  return emitError() << "invalid element type for complex";
56  return success();
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // Integer Type
61 //===----------------------------------------------------------------------===//
62 
63 // static constexpr must have a definition (until in C++17 and inline variable).
64 constexpr unsigned IntegerType::kMaxWidth;
65 
66 /// Verify the construction of an integer type.
68  unsigned width,
69  SignednessSemantics signedness) {
70  if (width > IntegerType::kMaxWidth) {
71  return emitError() << "integer bitwidth is limited to "
72  << IntegerType::kMaxWidth << " bits";
73  }
74  return success();
75 }
76 
77 unsigned IntegerType::getWidth() const { return getImpl()->width; }
78 
79 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
80  return getImpl()->signedness;
81 }
82 
83 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
84  if (!scale)
85  return IntegerType();
86  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Float Type
91 //===----------------------------------------------------------------------===//
92 
93 unsigned FloatType::getWidth() {
94  if (isa<Float16Type, BFloat16Type>())
95  return 16;
96  if (isa<Float32Type>())
97  return 32;
98  if (isa<Float64Type>())
99  return 64;
100  if (isa<Float80Type>())
101  return 80;
102  if (isa<Float128Type>())
103  return 128;
104  llvm_unreachable("unexpected float type");
105 }
106 
107 /// Returns the floating semantics for the given type.
108 const llvm::fltSemantics &FloatType::getFloatSemantics() {
109  if (isa<BFloat16Type>())
110  return APFloat::BFloat();
111  if (isa<Float16Type>())
112  return APFloat::IEEEhalf();
113  if (isa<Float32Type>())
114  return APFloat::IEEEsingle();
115  if (isa<Float64Type>())
116  return APFloat::IEEEdouble();
117  if (isa<Float80Type>())
118  return APFloat::x87DoubleExtended();
119  if (isa<Float128Type>())
120  return APFloat::IEEEquad();
121  llvm_unreachable("non-floating point type used");
122 }
123 
125  if (!scale)
126  return FloatType();
127  MLIRContext *ctx = getContext();
128  if (isF16() || isBF16()) {
129  if (scale == 2)
130  return FloatType::getF32(ctx);
131  if (scale == 4)
132  return FloatType::getF64(ctx);
133  }
134  if (isF32())
135  if (scale == 2)
136  return FloatType::getF64(ctx);
137  return FloatType();
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // FunctionType
142 //===----------------------------------------------------------------------===//
143 
144 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
145 
146 ArrayRef<Type> FunctionType::getInputs() const {
147  return getImpl()->getInputs();
148 }
149 
150 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
151 
152 ArrayRef<Type> FunctionType::getResults() const {
153  return getImpl()->getResults();
154 }
155 
156 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
157  return get(getContext(), inputs, results);
158 }
159 
160 /// Returns a new function type with the specified arguments and results
161 /// inserted.
162 FunctionType FunctionType::getWithArgsAndResults(
163  ArrayRef<unsigned> argIndices, TypeRange argTypes,
164  ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
165  SmallVector<Type> argStorage, resultStorage;
167  getInputs(), argIndices, argTypes, argStorage);
169  getResults(), resultIndices, resultTypes, resultStorage);
170  return clone(newArgTypes, newResultTypes);
171 }
172 
173 /// Returns a new function type without the specified arguments and results.
174 FunctionType
175 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
176  ArrayRef<unsigned> resultIndices) {
177  SmallVector<Type> argStorage, resultStorage;
179  getInputs(), argIndices, argStorage);
181  getResults(), resultIndices, resultStorage);
182  return clone(newArgTypes, newResultTypes);
183 }
184 
185 void FunctionType::walkImmediateSubElements(
186  function_ref<void(Attribute)> walkAttrsFn,
187  function_ref<void(Type)> walkTypesFn) const {
188  for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
189  walkTypesFn(type);
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // OpaqueType
194 //===----------------------------------------------------------------------===//
195 
196 /// Verify the construction of an opaque type.
198  StringAttr dialect, StringRef typeData) {
199  if (!Dialect::isValidNamespace(dialect.strref()))
200  return emitError() << "invalid dialect namespace '" << dialect << "'";
201 
202  // Check that the dialect is actually registered.
203  MLIRContext *context = dialect.getContext();
204  if (!context->allowsUnregisteredDialects() &&
205  !context->getLoadedDialect(dialect.strref())) {
206  return emitError()
207  << "`!" << dialect << "<\"" << typeData << "\">"
208  << "` type created with unregistered dialect. If this is "
209  "intended, please call allowUnregisteredDialects() on the "
210  "MLIRContext, or use -allow-unregistered-dialect with "
211  "the MLIR opt tool used";
212  }
213 
214  return success();
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // VectorType
219 //===----------------------------------------------------------------------===//
220 
222  ArrayRef<int64_t> shape, Type elementType,
223  unsigned numScalableDims) {
224  if (!isValidElementType(elementType))
225  return emitError()
226  << "vector elements must be int/index/float type but got "
227  << elementType;
228 
229  if (any_of(shape, [](int64_t i) { return i <= 0; }))
230  return emitError()
231  << "vector types must have positive constant sizes but got "
232  << shape;
233 
234  return success();
235 }
236 
237 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
238  if (!scale)
239  return VectorType();
240  if (auto et = getElementType().dyn_cast<IntegerType>())
241  if (auto scaledEt = et.scaleElementBitwidth(scale))
242  return VectorType::get(getShape(), scaledEt, getNumScalableDims());
243  if (auto et = getElementType().dyn_cast<FloatType>())
244  if (auto scaledEt = et.scaleElementBitwidth(scale))
245  return VectorType::get(getShape(), scaledEt, getNumScalableDims());
246  return VectorType();
247 }
248 
249 void VectorType::walkImmediateSubElements(
250  function_ref<void(Attribute)> walkAttrsFn,
251  function_ref<void(Type)> walkTypesFn) const {
252  walkTypesFn(getElementType());
253 }
254 
255 VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
256  Type elementType) const {
257  return VectorType::get(shape.getValueOr(getShape()), elementType,
258  getNumScalableDims());
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // TensorType
263 //===----------------------------------------------------------------------===//
264 
267  .Case<RankedTensorType, UnrankedTensorType>(
268  [](auto type) { return type.getElementType(); });
269 }
270 
271 bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
272 
274  return cast<RankedTensorType>().getShape();
275 }
276 
278  Type elementType) const {
279  if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
280  if (shape)
281  return RankedTensorType::get(*shape, elementType);
282  return UnrankedTensorType::get(elementType);
283  }
284 
285  auto rankedTy = cast<RankedTensorType>();
286  if (!shape)
287  return RankedTensorType::get(rankedTy.getShape(), elementType,
288  rankedTy.getEncoding());
289  return RankedTensorType::get(shape.getValueOr(rankedTy.getShape()),
290  elementType, rankedTy.getEncoding());
291 }
292 
293 // Check if "elementType" can be an element type of a tensor.
294 static LogicalResult
296  Type elementType) {
297  if (!TensorType::isValidElementType(elementType))
298  return emitError() << "invalid tensor element type: " << elementType;
299  return success();
300 }
301 
302 /// Return true if the specified element type is ok in a tensor.
304  // Note: Non standard/builtin types are allowed to exist within tensor
305  // types. Dialects are expected to verify that tensor types have a valid
306  // element type within that dialect.
307  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
308  IndexType>() ||
309  !llvm::isa<BuiltinDialect>(type.getDialect());
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // RankedTensorType
314 //===----------------------------------------------------------------------===//
315 
318  ArrayRef<int64_t> shape, Type elementType,
319  Attribute encoding) {
320  for (int64_t s : shape)
321  if (s < -1)
322  return emitError() << "invalid tensor dimension size";
323  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
324  if (failed(v.verifyEncoding(shape, elementType, emitError)))
325  return failure();
326  return checkTensorElementType(emitError, elementType);
327 }
328 
329 void RankedTensorType::walkImmediateSubElements(
330  function_ref<void(Attribute)> walkAttrsFn,
331  function_ref<void(Type)> walkTypesFn) const {
332  walkTypesFn(getElementType());
333  if (Attribute encoding = getEncoding())
334  walkAttrsFn(encoding);
335 }
336 
337 //===----------------------------------------------------------------------===//
338 // UnrankedTensorType
339 //===----------------------------------------------------------------------===//
340 
343  Type elementType) {
344  return checkTensorElementType(emitError, elementType);
345 }
346 
347 void UnrankedTensorType::walkImmediateSubElements(
348  function_ref<void(Attribute)> walkAttrsFn,
349  function_ref<void(Type)> walkTypesFn) const {
350  walkTypesFn(getElementType());
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // BaseMemRefType
355 //===----------------------------------------------------------------------===//
356 
359  .Case<MemRefType, UnrankedMemRefType>(
360  [](auto type) { return type.getElementType(); });
361 }
362 
363 bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
364 
366  return cast<MemRefType>().getShape();
367 }
368 
370  Type elementType) const {
371  if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
372  if (!shape)
373  return UnrankedMemRefType::get(elementType, getMemorySpace());
374  MemRefType::Builder builder(*shape, elementType);
375  builder.setMemorySpace(getMemorySpace());
376  return builder;
377  }
378 
379  MemRefType::Builder builder(cast<MemRefType>());
380  if (shape)
381  builder.setShape(*shape);
382  builder.setElementType(elementType);
383  return builder;
384 }
385 
387  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
388  return rankedMemRefTy.getMemorySpace();
389  return cast<UnrankedMemRefType>().getMemorySpace();
390 }
391 
393  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
394  return rankedMemRefTy.getMemorySpaceAsInt();
395  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
396 }
397 
398 //===----------------------------------------------------------------------===//
399 // MemRefType
400 //===----------------------------------------------------------------------===//
401 
402 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
403 /// `originalShape` with some `1` entries erased, return the set of indices
404 /// that specifies which of the entries of `originalShape` are dropped to obtain
405 /// `reducedShape`. The returned mask can be applied as a projection to
406 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
407 /// which dimensions must be kept when, e.g., computing MemRef strides under
408 /// rank-reducing operations. Return None if reducedShape cannot be obtained
409 /// by dropping only `1` entries in `originalShape`.
412  ArrayRef<int64_t> reducedShape) {
413  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
414  llvm::SmallDenseSet<unsigned> unusedDims;
415  unsigned reducedIdx = 0;
416  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
417  // Greedily insert `originalIdx` if match.
418  if (reducedIdx < reducedRank &&
419  originalShape[originalIdx] == reducedShape[reducedIdx]) {
420  reducedIdx++;
421  continue;
422  }
423 
424  unusedDims.insert(originalIdx);
425  // If no match on `originalIdx`, the `originalShape` at this dimension
426  // must be 1, otherwise we bail.
427  if (originalShape[originalIdx] != 1)
428  return llvm::None;
429  }
430  // The whole reducedShape must be scanned, otherwise we bail.
431  if (reducedIdx != reducedRank)
432  return llvm::None;
433  return unusedDims;
434 }
435 
437 mlir::isRankReducedType(ShapedType originalType,
438  ShapedType candidateReducedType) {
439  if (originalType == candidateReducedType)
441 
442  ShapedType originalShapedType = originalType.cast<ShapedType>();
443  ShapedType candidateReducedShapedType =
444  candidateReducedType.cast<ShapedType>();
445 
446  // Rank and size logic is valid for all ShapedTypes.
447  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
448  ArrayRef<int64_t> candidateReducedShape =
449  candidateReducedShapedType.getShape();
450  unsigned originalRank = originalShape.size(),
451  candidateReducedRank = candidateReducedShape.size();
452  if (candidateReducedRank > originalRank)
454 
455  auto optionalUnusedDimsMask =
456  computeRankReductionMask(originalShape, candidateReducedShape);
457 
458  // Sizes cannot be matched in case empty vector is returned.
459  if (!optionalUnusedDimsMask.hasValue())
461 
462  if (originalShapedType.getElementType() !=
463  candidateReducedShapedType.getElementType())
465 
467 }
468 
470  // Empty attribute is allowed as default memory space.
471  if (!memorySpace)
472  return true;
473 
474  // Supported built-in attributes.
475  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
476  return true;
477 
478  // Allow custom dialect attributes.
479  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
480  return true;
481 
482  return false;
483 }
484 
486  MLIRContext *ctx) {
487  if (memorySpace == 0)
488  return nullptr;
489 
490  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
491 }
492 
494  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
495  if (intMemorySpace && intMemorySpace.getValue() == 0)
496  return nullptr;
497 
498  return memorySpace;
499 }
500 
502  if (!memorySpace)
503  return 0;
504 
505  assert(memorySpace.isa<IntegerAttr>() &&
506  "Using `getMemorySpaceInteger` with non-Integer attribute");
507 
508  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
509 }
510 
512 MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
513  memorySpace =
514  wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
515  return *this;
516 }
517 
518 unsigned MemRefType::getMemorySpaceAsInt() const {
519  return detail::getMemorySpaceAsInt(getMemorySpace());
520 }
521 
522 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
523  MemRefLayoutAttrInterface layout,
524  Attribute memorySpace) {
525  // Use default layout for empty attribute.
526  if (!layout)
527  layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
528  shape.size(), elementType.getContext()));
529 
530  // Drop default memory space value and replace it with empty attribute.
531  memorySpace = skipDefaultMemorySpace(memorySpace);
532 
533  return Base::get(elementType.getContext(), shape, elementType, layout,
534  memorySpace);
535 }
536 
537 MemRefType MemRefType::getChecked(
538  function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
539  Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
540 
541  // Use default layout for empty attribute.
542  if (!layout)
543  layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
544  shape.size(), elementType.getContext()));
545 
546  // Drop default memory space value and replace it with empty attribute.
547  memorySpace = skipDefaultMemorySpace(memorySpace);
548 
549  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
550  elementType, layout, memorySpace);
551 }
552 
553 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
554  AffineMap map, Attribute memorySpace) {
555 
556  // Use default layout for empty map.
557  if (!map)
558  map = AffineMap::getMultiDimIdentityMap(shape.size(),
559  elementType.getContext());
560 
561  // Wrap AffineMap into Attribute.
562  Attribute layout = AffineMapAttr::get(map);
563 
564  // Drop default memory space value and replace it with empty attribute.
565  memorySpace = skipDefaultMemorySpace(memorySpace);
566 
567  return Base::get(elementType.getContext(), shape, elementType, layout,
568  memorySpace);
569 }
570 
571 MemRefType
572 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
573  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
574  Attribute memorySpace) {
575 
576  // Use default layout for empty map.
577  if (!map)
578  map = AffineMap::getMultiDimIdentityMap(shape.size(),
579  elementType.getContext());
580 
581  // Wrap AffineMap into Attribute.
582  Attribute layout = AffineMapAttr::get(map);
583 
584  // Drop default memory space value and replace it with empty attribute.
585  memorySpace = skipDefaultMemorySpace(memorySpace);
586 
587  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
588  elementType, layout, memorySpace);
589 }
590 
591 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
592  AffineMap map, unsigned memorySpaceInd) {
593 
594  // Use default layout for empty map.
595  if (!map)
596  map = AffineMap::getMultiDimIdentityMap(shape.size(),
597  elementType.getContext());
598 
599  // Wrap AffineMap into Attribute.
600  Attribute layout = AffineMapAttr::get(map);
601 
602  // Convert deprecated integer-like memory space to Attribute.
603  Attribute memorySpace =
604  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
605 
606  return Base::get(elementType.getContext(), shape, elementType, layout,
607  memorySpace);
608 }
609 
610 MemRefType
611 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
612  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
613  unsigned memorySpaceInd) {
614 
615  // Use default layout for empty map.
616  if (!map)
617  map = AffineMap::getMultiDimIdentityMap(shape.size(),
618  elementType.getContext());
619 
620  // Wrap AffineMap into Attribute.
621  Attribute layout = AffineMapAttr::get(map);
622 
623  // Convert deprecated integer-like memory space to Attribute.
624  Attribute memorySpace =
625  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
626 
627  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
628  elementType, layout, memorySpace);
629 }
630 
632  ArrayRef<int64_t> shape, Type elementType,
633  MemRefLayoutAttrInterface layout,
634  Attribute memorySpace) {
635  if (!BaseMemRefType::isValidElementType(elementType))
636  return emitError() << "invalid memref element type";
637 
638  // Negative sizes are not allowed except for `-1` that means dynamic size.
639  for (int64_t s : shape)
640  if (s < -1)
641  return emitError() << "invalid memref size";
642 
643  assert(layout && "missing layout specification");
644  if (failed(layout.verifyLayout(shape, emitError)))
645  return failure();
646 
647  if (!isSupportedMemorySpace(memorySpace))
648  return emitError() << "unsupported memory space Attribute";
649 
650  return success();
651 }
652 
653 void MemRefType::walkImmediateSubElements(
654  function_ref<void(Attribute)> walkAttrsFn,
655  function_ref<void(Type)> walkTypesFn) const {
656  walkTypesFn(getElementType());
657  if (!getLayout().isIdentity())
658  walkAttrsFn(getLayout());
659  walkAttrsFn(getMemorySpace());
660 }
661 
662 //===----------------------------------------------------------------------===//
663 // UnrankedMemRefType
664 //===----------------------------------------------------------------------===//
665 
667  return detail::getMemorySpaceAsInt(getMemorySpace());
668 }
669 
672  Type elementType, Attribute memorySpace) {
673  if (!BaseMemRefType::isValidElementType(elementType))
674  return emitError() << "invalid memref element type";
675 
676  if (!isSupportedMemorySpace(memorySpace))
677  return emitError() << "unsupported memory space Attribute";
678 
679  return success();
680 }
681 
682 // Fallback case for a terminal dim/sym/cst expression that is not part of a
683 // binary op (i.e. a single term). Accumulate it into the existing strides/offset.
685  AffineExpr multiplicativeFactor,
687  AffineExpr &offset) {
688  if (auto dim = e.dyn_cast<AffineDimExpr>())
689  strides[dim.getPosition()] =
690  strides[dim.getPosition()] + multiplicativeFactor;
691  else
692  offset = offset + e * multiplicativeFactor;
693 }
694 
695 /// Takes a single AffineExpr `e` and populates the `strides` array with the
696 /// strides expressions for each dim position.
697 /// The convention is that the strides for dimensions d0, .. dn appear in
698 /// order to make indexing intuitive into the result.
700  AffineExpr multiplicativeFactor,
702  AffineExpr &offset) {
703  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
704  if (!bin) {
705  extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
706  return success();
707  }
708 
709  if (bin.getKind() == AffineExprKind::CeilDiv ||
710  bin.getKind() == AffineExprKind::FloorDiv ||
711  bin.getKind() == AffineExprKind::Mod)
712  return failure();
713 
714  if (bin.getKind() == AffineExprKind::Mul) {
715  auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
716  if (dim) {
717  strides[dim.getPosition()] =
718  strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
719  return success();
720  }
721  // LHS and RHS may both contain complex expressions of dims. Try one path
722  // and if it fails try the other. This is guaranteed to succeed because
723  // only one path may have a `dim`, otherwise this is not an AffineExpr in
724  // the first place.
725  if (bin.getLHS().isSymbolicOrConstant())
726  return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
727  strides, offset);
728  return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
729  strides, offset);
730  }
731 
732  if (bin.getKind() == AffineExprKind::Add) {
733  auto res1 =
734  extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
735  auto res2 =
736  extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
737  return success(succeeded(res1) && succeeded(res2));
738  }
739 
740  llvm_unreachable("unexpected binary operation");
741 }
742 
745  AffineExpr &offset) {
746  AffineMap m = t.getLayout().getAffineMap();
747 
748  if (m.getNumResults() != 1 && !m.isIdentity())
749  return failure();
750 
751  auto zero = getAffineConstantExpr(0, t.getContext());
752  auto one = getAffineConstantExpr(1, t.getContext());
753  offset = zero;
754  strides.assign(t.getRank(), zero);
755 
756  // Canonical case for empty map.
757  if (m.isIdentity()) {
758  // 0-D corner case, offset is already 0.
759  if (t.getRank() == 0)
760  return success();
761  auto stridedExpr =
762  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
763  if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
764  return success();
765  assert(false && "unexpected failure: extract strides in canonical layout");
766  }
767 
768  // Non-canonical case requires more work.
769  auto stridedExpr =
771  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
772  offset = AffineExpr();
773  strides.clear();
774  return failure();
775  }
776 
777  // Simplify results to allow folding to constants and simple checks.
778  unsigned numDims = m.getNumDims();
779  unsigned numSymbols = m.getNumSymbols();
780  offset = simplifyAffineExpr(offset, numDims, numSymbols);
781  for (auto &stride : strides)
782  stride = simplifyAffineExpr(stride, numDims, numSymbols);
783 
784  /// In practice, a strided memref must be internally non-aliasing. Test
785  /// against 0 as a proxy.
786  /// TODO: static cases can have more advanced checks.
787  /// TODO: dynamic cases would require a way to compare symbolic
788  /// expressions and would probably need an affine set context propagated
789  /// everywhere.
790  if (llvm::any_of(strides, [](AffineExpr e) {
791  return e == getAffineConstantExpr(0, e.getContext());
792  })) {
793  offset = AffineExpr();
794  strides.clear();
795  return failure();
796  }
797 
798  return success();
799 }
800 
802  SmallVectorImpl<int64_t> &strides,
803  int64_t &offset) {
804  AffineExpr offsetExpr;
805  SmallVector<AffineExpr, 4> strideExprs;
806  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
807  return failure();
808  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
809  offset = cst.getValue();
810  else
811  offset = ShapedType::kDynamicStrideOrOffset;
812  for (auto e : strideExprs) {
813  if (auto c = e.dyn_cast<AffineConstantExpr>())
814  strides.push_back(c.getValue());
815  else
816  strides.push_back(ShapedType::kDynamicStrideOrOffset);
817  }
818  return success();
819 }
820 
821 void UnrankedMemRefType::walkImmediateSubElements(
822  function_ref<void(Attribute)> walkAttrsFn,
823  function_ref<void(Type)> walkTypesFn) const {
824  walkTypesFn(getElementType());
825  walkAttrsFn(getMemorySpace());
826 }
827 
828 //===----------------------------------------------------------------------===//
829 /// TupleType
830 //===----------------------------------------------------------------------===//
831 
832 /// Return the elements types for this tuple.
833 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
834 
835 /// Accumulate the types contained in this tuple and tuples nested within it.
836 /// Note that this only flattens nested tuples, not any other container type,
837 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
838 /// (i32, tensor<i32>, f32, i64)
840  for (Type type : getTypes()) {
841  if (auto nestedTuple = type.dyn_cast<TupleType>())
842  nestedTuple.getFlattenedTypes(types);
843  else
844  types.push_back(type);
845  }
846 }
847 
848 /// Return the number of element types.
849 size_t TupleType::size() const { return getImpl()->size(); }
850 
851 void TupleType::walkImmediateSubElements(
852  function_ref<void(Attribute)> walkAttrsFn,
853  function_ref<void(Type)> walkTypesFn) const {
854  for (Type type : getTypes())
855  walkTypesFn(type);
856 }
857 
858 //===----------------------------------------------------------------------===//
859 // Type Utilities
860 //===----------------------------------------------------------------------===//
861 
863  int64_t offset,
864  MLIRContext *context) {
865  AffineExpr expr;
866  unsigned nSymbols = 0;
867 
868  // AffineExpr for offset.
869  // Static case.
870  if (offset != MemRefType::getDynamicStrideOrOffset()) {
871  auto cst = getAffineConstantExpr(offset, context);
872  expr = cst;
873  } else {
874  // Dynamic case, new symbol for the offset.
875  auto sym = getAffineSymbolExpr(nSymbols++, context);
876  expr = sym;
877  }
878 
879  // AffineExpr for strides.
880  for (const auto &en : llvm::enumerate(strides)) {
881  auto dim = en.index();
882  auto stride = en.value();
883  assert(stride != 0 && "Invalid stride specification");
884  auto d = getAffineDimExpr(dim, context);
885  AffineExpr mult;
886  // Static case.
887  if (stride != MemRefType::getDynamicStrideOrOffset())
888  mult = getAffineConstantExpr(stride, context);
889  else
890  // Dynamic case, new symbol for each new stride.
891  mult = getAffineSymbolExpr(nSymbols++, context);
892  expr = expr + d * mult;
893  }
894 
895  return AffineMap::get(strides.size(), nSymbols, expr);
896 }
897 
898 /// Return a version of `t` with identity layout if it can be determined
899 /// statically that the layout is the canonical contiguous strided layout.
900 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
901 /// `t` with simplified layout.
902 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
903 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
904  AffineMap m = t.getLayout().getAffineMap();
905 
906  // Already in canonical form.
907  if (m.isIdentity())
908  return t;
909 
910  // Can't reduce to canonical identity form, return in canonical form.
911  if (m.getNumResults() > 1)
912  return t;
913 
914  // Corner-case for 0-D affine maps.
915  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
916  if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
917  if (cst.getValue() == 0)
918  return MemRefType::Builder(t).setLayout({});
919  return t;
920  }
921 
922  // 0-D corner case for empty shape that still have an affine map. Example:
923  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
924  // offset needs to remain, just return t.
925  if (t.getShape().empty())
926  return t;
927 
928  // If the canonical strided layout for the sizes of `t` is equal to the
929  // simplified layout of `t` we can just return an empty layout. Otherwise,
930  // just simplify the existing layout.
931  AffineExpr expr =
932  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
933  auto simplifiedLayoutExpr =
935  if (expr != simplifiedLayoutExpr)
936  return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
937  m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
938  return MemRefType::Builder(t).setLayout({});
939 }
940 
942  ArrayRef<AffineExpr> exprs,
943  MLIRContext *context) {
944  assert(!sizes.empty() && !exprs.empty() &&
945  "expected non-empty sizes and exprs");
946 
947  // Size 0 corner case is useful for canonicalizations.
948  if (llvm::is_contained(sizes, 0))
949  return getAffineConstantExpr(0, context);
950 
951  auto maps = AffineMap::inferFromExprList(exprs);
952  assert(!maps.empty() && "Expected one non-empty map");
953  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
954 
955  AffineExpr expr;
956  bool dynamicPoisonBit = false;
957  int64_t runningSize = 1;
958  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
959  int64_t size = std::get<1>(en);
960  // Degenerate case, no size =-> no stride
961  if (size == 0)
962  continue;
963  AffineExpr dimExpr = std::get<0>(en);
964  AffineExpr stride = dynamicPoisonBit
965  ? getAffineSymbolExpr(nSymbols++, context)
966  : getAffineConstantExpr(runningSize, context);
967  expr = expr ? expr + dimExpr * stride : dimExpr * stride;
968  if (size > 0) {
969  runningSize *= size;
970  assert(runningSize > 0 && "integer overflow in size computation");
971  } else {
972  dynamicPoisonBit = true;
973  }
974  }
975  return simplifyAffineExpr(expr, numDims, nSymbols);
976 }
977 
978 /// Return a version of `t` with a layout that has all dynamic offset and
979 /// strides. This is used to erase the static layout.
980 MemRefType mlir::eraseStridedLayout(MemRefType t) {
981  auto val = ShapedType::kDynamicStrideOrOffset;
982  return MemRefType::Builder(t).setLayout(
983  AffineMapAttr::get(makeStridedLinearLayoutMap(
984  SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
985 }
986 
988  MLIRContext *context) {
990  exprs.reserve(sizes.size());
991  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
992  exprs.push_back(getAffineDimExpr(dim, context));
993  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
994 }
995 
996 /// Return true if the layout for `t` is compatible with strided semantics.
997 bool mlir::isStrided(MemRefType t) {
998  int64_t offset;
999  SmallVector<int64_t, 4> strides;
1000  auto res = getStridesAndOffset(t, strides, offset);
1001  return succeeded(res);
1002 }
1003 
1004 /// Return the layout map in strided linear layout AffineMap form.
1005 /// Return null if the layout is not compatible with a strided layout.
1007  int64_t offset;
1008  SmallVector<int64_t, 4> strides;
1009  if (failed(getStridesAndOffset(t, strides, offset)))
1010  return AffineMap();
1011  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
1012 }
1013 
1014 /// Return the AffineExpr representation of the offset, assuming `memRefType`
1015 /// is a strided memref.
1016 static AffineExpr getOffsetExpr(MemRefType memrefType) {
1017  SmallVector<AffineExpr> strides;
1018  AffineExpr offset;
1019  if (failed(getStridesAndOffset(memrefType, strides, offset)))
1020  assert(false && "expected strided memref");
1021  return offset;
1022 }
1023 
1024 /// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
1025 /// `offset` AffineExpr.
1027  ArrayRef<int64_t> shape,
1028  Type elementType,
1029  AffineExpr offset) {
1030  AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
1031  AffineExpr contiguousRowMajor = canonical + offset;
1032  AffineMap contiguousRowMajorMap =
1033  AffineMap::inferFromExprList({contiguousRowMajor})[0];
1034  return MemRefType::get(shape, elementType, contiguousRowMajorMap);
1035 }
1036 
1037 /// Helper determining if a memref is static-shape and contiguous-row-major
1038 /// layout, while still allowing for an arbitrary offset (any static or
1039 /// dynamic value).
1040 bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
1041  if (!memrefType.hasStaticShape())
1042  return false;
1043  AffineExpr offset = getOffsetExpr(memrefType);
1044  MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
1045  memrefType.getContext(), memrefType.getShape(),
1046  memrefType.getElementType(), offset);
1047  return canonicalizeStridedLayout(memrefType) ==
1048  canonicalizeStridedLayout(contiguousRowMajorMemRefType);
1049 }
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Affine binary operation expression.
Definition: AffineExpr.h:207
Include the generated interface declarations.
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:114
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
U cast() const
Definition: Attributes.h:123
RHS of mod is always a constant or a symbolic expression with a positive value.
U dyn_cast_or_null() const
Definition: Attributes.h:120
MemRefType eraseStridedLayout(MemRefType t)
Return a version of t with a layout that has all dynamic offset and strides.
unsigned getNumSymbols() const
Definition: AffineMap.cpp:298
unsigned getNumDims() const
Definition: AffineMap.cpp:294
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:516
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:301
AffineMap getStridedLinearLayoutMap(MemRefType t)
Return the layout map in strided linear layout AffineMap form.
llvm::Optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
LogicalResult verify(Operation *op)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:353
bool isa() const
Definition: Attributes.h:107
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static AffineExpr getOffsetExpr(MemRefType memrefType)
Return the AffineExpr representation of the offset, assuming memRefType is a strided memref...
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:87
static FloatType getF32(MLIRContext *ctx)
Definition: BuiltinTypes.h:378
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
TensorType cloneWith(Optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:315
U dyn_cast() const
Definition: AffineExpr.h:281
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
TypeRange filterTypesOut(TypeRange types, ArrayRef< unsigned > indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
bool isStrided(MemRefType t)
Return true if the layout for t is compatible with strided semantics.
unsigned getWidth()
Return the bitwidth of this float type.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:70
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
U dyn_cast() const
Definition: Types.h:244
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:501
Attributes are known-constant values of operations.
Definition: Attributes.h:24
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
BaseMemRefType cloneWith(Optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:334
Base type for affine expression.
Definition: AffineExpr.h:68
MLIRContext * getContext() const
Definition: AffineExpr.cpp:23
RHS of mul is always a constant or a symbolic expression.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
unsigned getNumResults() const
Definition: AffineMap.cpp:302
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued...
Definition: AffineMap.h:38
RHS of floordiv is always a constant or a symbolic expression.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:73
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:491
RHS of ceildiv is always a constant or a symbolic expression.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context, ArrayRef< int64_t > shape, Type elementType, AffineExpr offset)
Helper to construct a contiguous MemRefType of shape, elementType and offset AffineExpr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:109
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Definition: BuiltinTypes.h:359
Type getElementType() const
Returns the element type of this memref type.
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:177
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:161
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:255
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Type getElementType() const
Returns the element type of this tensor type.
FloatType scaleElementBitwidth(unsigned scale)
Get or create a new FloatType with bitwidth scaled by scale.
AffineMap makeStridedLinearLayoutMap(ArrayRef< int64_t > strides, int64_t offset, MLIRContext *context)
Given a list of strides (in which MemRefType::getDynamicStrideOrOffset() represents a dynamic value)...
bool allowsUnregisteredDialects()
Return true if we allow to create operation for unregistered dialects.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:187
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:244
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:182
bool isa() const
Definition: Types.h:234
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:235
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:172
static FloatType getF64(MLIRContext *ctx)
Definition: BuiltinTypes.h:382
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType)
Helper determining if a memref is static-shape and contiguous-row-major layout, while still allowing ...
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:182