MLIR  20.0.0git
BuiltinTypes.cpp
1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/BuiltinAttributes.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/TensorEncoding.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 
25 using namespace mlir;
26 using namespace mlir::detail;
27 
28 //===----------------------------------------------------------------------===//
29 /// Tablegen Type Definitions
30 //===----------------------------------------------------------------------===//
31 
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
34 
35 //===----------------------------------------------------------------------===//
36 // BuiltinDialect
37 //===----------------------------------------------------------------------===//
38 
39 void BuiltinDialect::registerTypes() {
40  addTypes<
41 #define GET_TYPEDEF_LIST
42 #include "mlir/IR/BuiltinTypes.cpp.inc"
43  >();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 /// ComplexType
48 //===----------------------------------------------------------------------===//
49 
50 /// Verify the construction of a complex type.
51 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
52  Type elementType) {
53  if (!elementType.isIntOrFloat())
54  return emitError() << "invalid element type for complex";
55  return success();
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // Integer Type
60 //===----------------------------------------------------------------------===//
61 
62 /// Verify the construction of an integer type.
64  unsigned width,
65  SignednessSemantics signedness) {
66  if (width > IntegerType::kMaxWidth) {
67  return emitError() << "integer bitwidth is limited to "
68  << IntegerType::kMaxWidth << " bits";
69  }
70  return success();
71 }
72 
73 unsigned IntegerType::getWidth() const { return getImpl()->width; }
74 
75 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
76  return getImpl()->signedness;
77 }
78 
79 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
80  if (!scale)
81  return IntegerType();
82  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
83 }
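// Illustrative use of the scaling helper above (a sketch; `ctx` stands in for
// a valid MLIRContext* and is not part of this file):
//   IntegerType i8 = IntegerType::get(ctx, 8, IntegerType::Signless);
//   IntegerType i32 = i8.scaleElementBitwidth(4); // 8 * 4 = 32 bits, signless
//   IntegerType bad = i8.scaleElementBitwidth(0); // null IntegerType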
84 
85 //===----------------------------------------------------------------------===//
86 // Float Type
87 //===----------------------------------------------------------------------===//
88 
89 unsigned FloatType::getWidth() {
90  if (llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
91  Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E4M3B11FNUZType>(
92  *this))
93  return 8;
94  if (llvm::isa<Float16Type, BFloat16Type>(*this))
95  return 16;
96  if (llvm::isa<Float32Type, FloatTF32Type>(*this))
97  return 32;
98  if (llvm::isa<Float64Type>(*this))
99  return 64;
100  if (llvm::isa<Float80Type>(*this))
101  return 80;
102  if (llvm::isa<Float128Type>(*this))
103  return 128;
104  llvm_unreachable("unexpected float type");
105 }
106 
107 /// Returns the floating semantics for the given type.
108 const llvm::fltSemantics &FloatType::getFloatSemantics() {
109  if (llvm::isa<Float8E5M2Type>(*this))
110  return APFloat::Float8E5M2();
111  if (llvm::isa<Float8E4M3Type>(*this))
112  return APFloat::Float8E4M3();
113  if (llvm::isa<Float8E4M3FNType>(*this))
114  return APFloat::Float8E4M3FN();
115  if (llvm::isa<Float8E5M2FNUZType>(*this))
116  return APFloat::Float8E5M2FNUZ();
117  if (llvm::isa<Float8E4M3FNUZType>(*this))
118  return APFloat::Float8E4M3FNUZ();
119  if (llvm::isa<Float8E4M3B11FNUZType>(*this))
120  return APFloat::Float8E4M3B11FNUZ();
121  if (llvm::isa<BFloat16Type>(*this))
122  return APFloat::BFloat();
123  if (llvm::isa<Float16Type>(*this))
124  return APFloat::IEEEhalf();
125  if (llvm::isa<FloatTF32Type>(*this))
126  return APFloat::FloatTF32();
127  if (llvm::isa<Float32Type>(*this))
128  return APFloat::IEEEsingle();
129  if (llvm::isa<Float64Type>(*this))
130  return APFloat::IEEEdouble();
131  if (llvm::isa<Float80Type>(*this))
132  return APFloat::x87DoubleExtended();
133  if (llvm::isa<Float128Type>(*this))
134  return APFloat::IEEEquad();
135  llvm_unreachable("non-floating point type used");
136 }
137 
138 FloatType FloatType::scaleElementBitwidth(unsigned scale) {
139  if (!scale)
140  return FloatType();
141  MLIRContext *ctx = getContext();
142  if (isF16() || isBF16()) {
143  if (scale == 2)
144  return FloatType::getF32(ctx);
145  if (scale == 4)
146  return FloatType::getF64(ctx);
147  }
148  if (isF32())
149  if (scale == 2)
150  return FloatType::getF64(ctx);
151  return FloatType();
152 }
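// Illustrative use (a sketch; `ctx` is an assumed MLIRContext*). Only the
// promotions spelled out above succeed; anything else yields a null type:
//   FloatType f16 = FloatType::getF16(ctx);
//   FloatType f32 = f16.scaleElementBitwidth(2); // f32
//   FloatType f64 = f16.scaleElementBitwidth(4); // f64
//   FloatType bad = f16.scaleElementBitwidth(8); // null FloatType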
153 
154 unsigned FloatType::getFPMantissaWidth() {
155  return APFloat::semanticsPrecision(getFloatSemantics());
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // FunctionType
160 //===----------------------------------------------------------------------===//
161 
162 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
163 
164 ArrayRef<Type> FunctionType::getInputs() const {
165  return getImpl()->getInputs();
166 }
167 
168 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
169 
170 ArrayRef<Type> FunctionType::getResults() const {
171  return getImpl()->getResults();
172 }
173 
174 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
175  return get(getContext(), inputs, results);
176 }
177 
178 /// Returns a new function type with the specified arguments and results
179 /// inserted.
180 FunctionType FunctionType::getWithArgsAndResults(
181  ArrayRef<unsigned> argIndices, TypeRange argTypes,
182  ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
183  SmallVector<Type> argStorage, resultStorage;
184  TypeRange newArgTypes =
185  insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
186  TypeRange newResultTypes =
187  insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
188  return clone(newArgTypes, newResultTypes);
189 }
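// Illustrative use of the insertion helper above (a sketch; `ctx`, `i32` and
// `f32` are assumed to exist):
//   FunctionType fn = FunctionType::get(ctx, {i32}, {f32});    // (i32) -> f32
//   FunctionType fn2 = fn.getWithArgsAndResults(
//       /*argIndices=*/{1}, {f32}, /*resultIndices=*/{0}, {i32});
//   // fn2 is (i32, f32) -> (i32, f32): the extra argument lands at position 1
//   // and the extra result at position 0.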
190 
191 /// Returns a new function type without the specified arguments and results.
192 FunctionType
193 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
194  const BitVector &resultIndices) {
195  SmallVector<Type> argStorage, resultStorage;
196  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
197  TypeRange newResultTypes =
198  filterTypesOut(getResults(), resultIndices, resultStorage);
199  return clone(newArgTypes, newResultTypes);
200 }
201 
202 //===----------------------------------------------------------------------===//
203 // OpaqueType
204 //===----------------------------------------------------------------------===//
205 
206 /// Verify the construction of an opaque type.
207 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
208  StringAttr dialect, StringRef typeData) {
209  if (!Dialect::isValidNamespace(dialect.strref()))
210  return emitError() << "invalid dialect namespace '" << dialect << "'";
211 
212  // Check that the dialect is actually registered.
213  MLIRContext *context = dialect.getContext();
214  if (!context->allowsUnregisteredDialects() &&
215  !context->getLoadedDialect(dialect.strref())) {
216  return emitError()
217  << "`!" << dialect << "<\"" << typeData << "\">"
218  << "` type created with unregistered dialect. If this is "
219  "intended, please call allowUnregisteredDialects() on the "
220  "MLIRContext, or use -allow-unregistered-dialect with "
221  "the MLIR opt tool used";
222  }
223 
224  return success();
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // VectorType
229 //===----------------------------------------------------------------------===//
230 
231 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
232  ArrayRef<int64_t> shape, Type elementType,
233  ArrayRef<bool> scalableDims) {
234  if (!isValidElementType(elementType))
235  return emitError()
236  << "vector elements must be int/index/float type but got "
237  << elementType;
238 
239  if (any_of(shape, [](int64_t i) { return i <= 0; }))
240  return emitError()
241  << "vector types must have positive constant sizes but got "
242  << shape;
243 
244  if (scalableDims.size() != shape.size())
245  return emitError() << "number of dims must match, got "
246  << scalableDims.size() << " and " << shape.size();
247 
248  return success();
249 }
250 
251 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
252  if (!scale)
253  return VectorType();
254  if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
255  if (auto scaledEt = et.scaleElementBitwidth(scale))
256  return VectorType::get(getShape(), scaledEt, getScalableDims());
257  if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
258  if (auto scaledEt = et.scaleElementBitwidth(scale))
259  return VectorType::get(getShape(), scaledEt, getScalableDims());
260  return VectorType();
261 }
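// Illustrative sketch of bitwidth scaling on vectors (names assumed):
//   auto v = VectorType::get({4}, IntegerType::get(ctx, 8)); // vector<4xi8>
//   auto w = v.scaleElementBitwidth(2);                      // vector<4xi16>
//   // Scalable dims carry over, e.g. vector<[4]xi8> becomes vector<[4]xi16>.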
262 
263 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
264  Type elementType) const {
265  return VectorType::get(shape.value_or(getShape()), elementType,
266  getScalableDims());
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // TensorType
271 //===----------------------------------------------------------------------===//
272 
273 Type TensorType::getElementType() const {
274  return llvm::TypeSwitch<TensorType, Type>(*this)
275  .Case<RankedTensorType, UnrankedTensorType>(
276  [](auto type) { return type.getElementType(); });
277 }
278 
279 bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
280 
281 ArrayRef<int64_t> TensorType::getShape() const {
282  return llvm::cast<RankedTensorType>(*this).getShape();
283 }
284 
285 TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
286  Type elementType) const {
287  if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
288  if (shape)
289  return RankedTensorType::get(*shape, elementType);
290  return UnrankedTensorType::get(elementType);
291  }
292 
293  auto rankedTy = llvm::cast<RankedTensorType>(*this);
294  if (!shape)
295  return RankedTensorType::get(rankedTy.getShape(), elementType,
296  rankedTy.getEncoding());
297  return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
298  rankedTy.getEncoding());
299 }
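// Illustrative behaviour of cloneWith (a sketch; `f32` is assumed):
//   auto unranked = UnrankedTensorType::get(f32);
//   SmallVector<int64_t> shape = {2, 3};
//   unranked.cloneWith(ArrayRef<int64_t>(shape), f32); // tensor<2x3xf32>
//   unranked.cloneWith(std::nullopt, f32);             // stays unranked
//   // For ranked tensors the encoding attribute is carried over unchanged.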
300 
301 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
302  Type elementType) const {
303  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
304 }
305 
306 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
307  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
308 }
309 
310 // Check if "elementType" can be an element type of a tensor.
311 static LogicalResult
312 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
313  Type elementType) {
314  if (!TensorType::isValidElementType(elementType))
315  return emitError() << "invalid tensor element type: " << elementType;
316  return success();
317 }
318 
319 /// Return true if the specified element type is ok in a tensor.
320 bool TensorType::isValidElementType(Type type) {
321  // Note: Non standard/builtin types are allowed to exist within tensor
322  // types. Dialects are expected to verify that tensor types have a valid
323  // element type within that dialect.
324  return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
325  IndexType>(type) ||
326  !llvm::isa<BuiltinDialect>(type.getDialect());
327 }
328 
329 //===----------------------------------------------------------------------===//
330 // RankedTensorType
331 //===----------------------------------------------------------------------===//
332 
333 LogicalResult
334 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
335  ArrayRef<int64_t> shape, Type elementType,
336  Attribute encoding) {
337  for (int64_t s : shape)
338  if (s < 0 && !ShapedType::isDynamic(s))
339  return emitError() << "invalid tensor dimension size";
340  if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
341  if (failed(v.verifyEncoding(shape, elementType, emitError)))
342  return failure();
343  return checkTensorElementType(emitError, elementType);
344 }
345 
346 //===----------------------------------------------------------------------===//
347 // UnrankedTensorType
348 //===----------------------------------------------------------------------===//
349 
350 LogicalResult
351 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
352  Type elementType) {
353  return checkTensorElementType(emitError, elementType);
354 }
355 
356 //===----------------------------------------------------------------------===//
357 // BaseMemRefType
358 //===----------------------------------------------------------------------===//
359 
360 Type BaseMemRefType::getElementType() const {
361  return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
362  .Case<MemRefType, UnrankedMemRefType>(
363  [](auto type) { return type.getElementType(); });
364 }
365 
366 bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
367 
368 ArrayRef<int64_t> BaseMemRefType::getShape() const {
369  return llvm::cast<MemRefType>(*this).getShape();
370 }
371 
372 BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
373  Type elementType) const {
374  if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
375  if (!shape)
376  return UnrankedMemRefType::get(elementType, getMemorySpace());
377  MemRefType::Builder builder(*shape, elementType);
378  builder.setMemorySpace(getMemorySpace());
379  return builder;
380  }
381 
382  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
383  if (shape)
384  builder.setShape(*shape);
385  builder.setElementType(elementType);
386  return builder;
387 }
388 
389 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
390  Type elementType) const {
391  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
392 }
393 
394 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
395  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
396 }
397 
398 Attribute BaseMemRefType::getMemorySpace() const {
399  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
400  return rankedMemRefTy.getMemorySpace();
401  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
402 }
403 
404 unsigned BaseMemRefType::getMemorySpaceAsInt() const {
405  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
406  return rankedMemRefTy.getMemorySpaceAsInt();
407  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
408 }
409 
410 //===----------------------------------------------------------------------===//
411 // MemRefType
412 //===----------------------------------------------------------------------===//
413 
414 std::optional<llvm::SmallDenseSet<unsigned>>
415 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
416  ArrayRef<int64_t> reducedShape,
417  bool matchDynamic) {
418  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
419  llvm::SmallDenseSet<unsigned> unusedDims;
420  unsigned reducedIdx = 0;
421  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
422  // Greedily insert `originalIdx` if match.
423  int64_t origSize = originalShape[originalIdx];
424  // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
425  if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
426  (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
427  ShapedType::isDynamic(origSize))) {
428  reducedIdx++;
429  continue;
430  }
431  if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
432  reducedIdx++;
433  continue;
434  }
435 
436  unusedDims.insert(originalIdx);
437  // If no match on `originalIdx`, the `originalShape` at this dimension
438  // must be 1, otherwise we bail.
439  if (origSize != 1)
440  return std::nullopt;
441  }
442  // The whole reducedShape must be scanned, otherwise we bail.
443  if (reducedIdx != reducedRank)
444  return std::nullopt;
445  return unusedDims;
446 }
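// Worked example (illustrative): originalShape = [1, 4, 1, 8] matched against
// reducedShape = [4, 8] drops the unit dimensions, so the returned mask is
// {0, 2}. With originalShape = [2, 4] and reducedShape = [4] the helper
// returns std::nullopt, because the unmatched dimension 0 has size 2, not 1.
//   auto mask = computeRankReductionMask({1, 4, 1, 8}, {4, 8}); // -> {0, 2}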
447 
448 SliceVerificationResult
449 mlir::isRankReducedType(ShapedType originalType,
450  ShapedType candidateReducedType) {
451  if (originalType == candidateReducedType)
452  return SliceVerificationResult::Success;
453 
454  ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
455  ShapedType candidateReducedShapedType =
456  llvm::cast<ShapedType>(candidateReducedType);
457 
458  // Rank and size logic is valid for all ShapedTypes.
459  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
460  ArrayRef<int64_t> candidateReducedShape =
461  candidateReducedShapedType.getShape();
462  unsigned originalRank = originalShape.size(),
463  candidateReducedRank = candidateReducedShape.size();
464  if (candidateReducedRank > originalRank)
465  return SliceVerificationResult::RankTooLarge;
466 
467  auto optionalUnusedDimsMask =
468  computeRankReductionMask(originalShape, candidateReducedShape);
469 
470  // Sizes cannot be matched in case empty vector is returned.
471  if (!optionalUnusedDimsMask)
472  return SliceVerificationResult::SizeMismatch;
473 
474  if (originalShapedType.getElementType() !=
475  candidateReducedShapedType.getElementType())
476  return SliceVerificationResult::ElemTypeMismatch;
477 
478  return SliceVerificationResult::Success;
479 }
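// Illustrative outcomes of the slice check above (a sketch):
//   tensor<1x4x1x8xf32> -> tensor<4x8xf32>   : Success (unit dims dropped)
//   tensor<4x8xf32>     -> tensor<4x8x1xf32> : RankTooLarge
//   tensor<1x4xf32>     -> tensor<8xf32>     : SizeMismatch
//   tensor<4x8xf32>     -> tensor<4x8xi32>   : ElemTypeMismatch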
480 
481 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
482  // Empty attribute is allowed as default memory space.
483  if (!memorySpace)
484  return true;
485 
486  // Supported built-in attributes.
487  if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
488  return true;
489 
490  // Allow custom dialect attributes.
491  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
492  return true;
493 
494  return false;
495 }
496 
497 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
498  MLIRContext *ctx) {
499  if (memorySpace == 0)
500  return nullptr;
501 
502  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
503 }
504 
505 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
506  IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
507  if (intMemorySpace && intMemorySpace.getValue() == 0)
508  return nullptr;
509 
510  return memorySpace;
511 }
512 
513 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
514  if (!memorySpace)
515  return 0;
516 
517  assert(llvm::isa<IntegerAttr>(memorySpace) &&
518  "Using `getMemorySpaceInteger` with non-Integer attribute");
519 
520  return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
521 }
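// Round trip through the memory-space helpers above (illustrative; `ctx` is
// an assumed MLIRContext*):
//   Attribute ms = detail::wrapIntegerMemorySpace(1, ctx); // IntegerAttr 1 : i64
//   detail::wrapIntegerMemorySpace(0, ctx);                // null (default space)
//   detail::skipDefaultMemorySpace(ms);                    // unchanged, non-zero
//   detail::getMemorySpaceAsInt(ms);                       // 1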
522 
523 unsigned MemRefType::getMemorySpaceAsInt() const {
524  return detail::getMemorySpaceAsInt(getMemorySpace());
525 }
526 
527 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
528  MemRefLayoutAttrInterface layout,
529  Attribute memorySpace) {
530  // Use default layout for empty attribute.
531  if (!layout)
532  layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
533  shape.size(), elementType.getContext()));
534 
535  // Drop default memory space value and replace it with empty attribute.
536  memorySpace = skipDefaultMemorySpace(memorySpace);
537 
538  return Base::get(elementType.getContext(), shape, elementType, layout,
539  memorySpace);
540 }
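// Illustrative construction (a sketch; `f32` is an assumed element type). The
// layout and memory space default to empty, which the builder above turns
// into the identity layout and the default memory space:
//   auto m = MemRefType::get({4, 8}, f32); // memref<4x8xf32>
//   m.getLayout().isIdentity();            // true
//   m.getMemorySpace();                    // null Attribute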
541 
542 MemRefType MemRefType::getChecked(
543  function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
544  Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
545 
546  // Use default layout for empty attribute.
547  if (!layout)
548  layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
549  shape.size(), elementType.getContext()));
550 
551  // Drop default memory space value and replace it with empty attribute.
552  memorySpace = skipDefaultMemorySpace(memorySpace);
553 
554  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
555  elementType, layout, memorySpace);
556 }
557 
558 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
559  AffineMap map, Attribute memorySpace) {
560 
561  // Use default layout for empty map.
562  if (!map)
563  map = AffineMap::getMultiDimIdentityMap(shape.size(),
564  elementType.getContext());
565 
566  // Wrap AffineMap into Attribute.
567  auto layout = AffineMapAttr::get(map);
568 
569  // Drop default memory space value and replace it with empty attribute.
570  memorySpace = skipDefaultMemorySpace(memorySpace);
571 
572  return Base::get(elementType.getContext(), shape, elementType, layout,
573  memorySpace);
574 }
575 
576 MemRefType
577 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
578  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
579  Attribute memorySpace) {
580 
581  // Use default layout for empty map.
582  if (!map)
583  map = AffineMap::getMultiDimIdentityMap(shape.size(),
584  elementType.getContext());
585 
586  // Wrap AffineMap into Attribute.
587  auto layout = AffineMapAttr::get(map);
588 
589  // Drop default memory space value and replace it with empty attribute.
590  memorySpace = skipDefaultMemorySpace(memorySpace);
591 
592  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
593  elementType, layout, memorySpace);
594 }
595 
596 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
597  AffineMap map, unsigned memorySpaceInd) {
598 
599  // Use default layout for empty map.
600  if (!map)
601  map = AffineMap::getMultiDimIdentityMap(shape.size(),
602  elementType.getContext());
603 
604  // Wrap AffineMap into Attribute.
605  auto layout = AffineMapAttr::get(map);
606 
607  // Convert deprecated integer-like memory space to Attribute.
608  Attribute memorySpace =
609  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
610 
611  return Base::get(elementType.getContext(), shape, elementType, layout,
612  memorySpace);
613 }
614 
615 MemRefType
616 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
617  ArrayRef<int64_t> shape, Type elementType, AffineMap map,
618  unsigned memorySpaceInd) {
619 
620  // Use default layout for empty map.
621  if (!map)
622  map = AffineMap::getMultiDimIdentityMap(shape.size(),
623  elementType.getContext());
624 
625  // Wrap AffineMap into Attribute.
626  auto layout = AffineMapAttr::get(map);
627 
628  // Convert deprecated integer-like memory space to Attribute.
629  Attribute memorySpace =
630  wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
631 
632  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
633  elementType, layout, memorySpace);
634 }
635 
636 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
637  ArrayRef<int64_t> shape, Type elementType,
638  MemRefLayoutAttrInterface layout,
639  Attribute memorySpace) {
640  if (!BaseMemRefType::isValidElementType(elementType))
641  return emitError() << "invalid memref element type";
642 
643  // Negative sizes are not allowed except for `kDynamic`.
644  for (int64_t s : shape)
645  if (s < 0 && !ShapedType::isDynamic(s))
646  return emitError() << "invalid memref size";
647 
648  assert(layout && "missing layout specification");
649  if (failed(layout.verifyLayout(shape, emitError)))
650  return failure();
651 
652  if (!isSupportedMemorySpace(memorySpace))
653  return emitError() << "unsupported memory space Attribute";
654 
655  return success();
656 }
657 
658 //===----------------------------------------------------------------------===//
659 // UnrankedMemRefType
660 //===----------------------------------------------------------------------===//
661 
662 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
663  return detail::getMemorySpaceAsInt(getMemorySpace());
664 }
665 
666 LogicalResult
667 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
668  Type elementType, Attribute memorySpace) {
669  if (!BaseMemRefType::isValidElementType(elementType))
670  return emitError() << "invalid memref element type";
671 
672  if (!isSupportedMemorySpace(memorySpace))
673  return emitError() << "unsupported memory space Attribute";
674 
675  return success();
676 }
677 
678 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
679 // i.e. single term). Accumulate the AffineExpr into the existing one.
680 static void extractStridesFromTerm(AffineExpr e,
681  AffineExpr multiplicativeFactor,
682  MutableArrayRef<AffineExpr> strides,
683  AffineExpr &offset) {
684  if (auto dim = dyn_cast<AffineDimExpr>(e))
685  strides[dim.getPosition()] =
686  strides[dim.getPosition()] + multiplicativeFactor;
687  else
688  offset = offset + e * multiplicativeFactor;
689 }
690 
691 /// Takes a single AffineExpr `e` and populates the `strides` array with the
692 /// strides expressions for each dim position.
693 /// The convention is that the strides for dimensions d0, .. dn appear in
694 /// order to make indexing intuitive into the result.
695 static LogicalResult extractStrides(AffineExpr e,
696  AffineExpr multiplicativeFactor,
697  MutableArrayRef<AffineExpr> strides,
698  AffineExpr &offset) {
699  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
700  if (!bin) {
701  extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
702  return success();
703  }
704 
705  if (bin.getKind() == AffineExprKind::CeilDiv ||
706  bin.getKind() == AffineExprKind::FloorDiv ||
707  bin.getKind() == AffineExprKind::Mod)
708  return failure();
709 
710  if (bin.getKind() == AffineExprKind::Mul) {
711  auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
712  if (dim) {
713  strides[dim.getPosition()] =
714  strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
715  return success();
716  }
717  // LHS and RHS may both contain complex expressions of dims. Try one path
718  // and if it fails try the other. This is guaranteed to succeed because
719  // only one path may have a `dim`, otherwise this is not an AffineExpr in
720  // the first place.
721  if (bin.getLHS().isSymbolicOrConstant())
722  return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
723  strides, offset);
724  return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
725  strides, offset);
726  }
727 
728  if (bin.getKind() == AffineExprKind::Add) {
729  auto res1 =
730  extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
731  auto res2 =
732  extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
733  return success(succeeded(res1) && succeeded(res2));
734  }
735 
736  llvm_unreachable("unexpected binary operation");
737 }
738 
739 /// A stride specification is a list of integer values that are either static
740 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode
741 /// the distance in the number of elements between successive entries along a
742 /// particular dimension.
743 ///
744 /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
745 /// non-contiguous memory region of `42` by `16` `f32` elements in which the
746 /// distance between two consecutive elements along the outer dimension is `1`
747 /// and the distance between two consecutive elements along the inner dimension
748 /// is `64`.
749 ///
750 /// The convention is that the strides for dimensions d0, .. dn appear in
751 /// order to make indexing intuitive into the result.
752 static LogicalResult getStridesAndOffset(MemRefType t,
753  SmallVectorImpl<AffineExpr> &strides,
754  AffineExpr &offset) {
755  AffineMap m = t.getLayout().getAffineMap();
756 
757  if (m.getNumResults() != 1 && !m.isIdentity())
758  return failure();
759 
760  auto zero = getAffineConstantExpr(0, t.getContext());
761  auto one = getAffineConstantExpr(1, t.getContext());
762  offset = zero;
763  strides.assign(t.getRank(), zero);
764 
765  // Canonical case for empty map.
766  if (m.isIdentity()) {
767  // 0-D corner case, offset is already 0.
768  if (t.getRank() == 0)
769  return success();
770  auto stridedExpr =
771  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
772  if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
773  return success();
774  assert(false && "unexpected failure: extract strides in canonical layout");
775  }
776 
777  // Non-canonical case requires more work.
778  auto stridedExpr =
779  simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
780  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
781  offset = AffineExpr();
782  strides.clear();
783  return failure();
784  }
785 
786  // Simplify results to allow folding to constants and simple checks.
787  unsigned numDims = m.getNumDims();
788  unsigned numSymbols = m.getNumSymbols();
789  offset = simplifyAffineExpr(offset, numDims, numSymbols);
790  for (auto &stride : strides)
791  stride = simplifyAffineExpr(stride, numDims, numSymbols);
792 
793  // In practice, a strided memref must be internally non-aliasing. Test
794  // against 0 as a proxy.
795  // TODO: static cases can have more advanced checks.
796  // TODO: dynamic cases would require a way to compare symbolic
797  // expressions and would probably need an affine set context propagated
798  // everywhere.
799  if (llvm::any_of(strides, [](AffineExpr e) {
800  return e == getAffineConstantExpr(0, e.getContext());
801  })) {
802  offset = AffineExpr();
803  strides.clear();
804  return failure();
805  }
806 
807  return success();
808 }
809 
810 LogicalResult mlir::getStridesAndOffset(MemRefType t,
811  SmallVectorImpl<int64_t> &strides,
812  int64_t &offset) {
813  // Happy path: the type uses the strided layout directly.
814  if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
815  llvm::append_range(strides, strided.getStrides());
816  offset = strided.getOffset();
817  return success();
818  }
819 
820  // Otherwise, defer to the affine fallback as layouts are supposed to be
821  // convertible to affine maps.
822  AffineExpr offsetExpr;
823  SmallVector<AffineExpr, 4> strideExprs;
824  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
825  return failure();
826  if (auto cst = dyn_cast<AffineConstantExpr>(offsetExpr))
827  offset = cst.getValue();
828  else
829  offset = ShapedType::kDynamic;
830  for (auto e : strideExprs) {
831  if (auto c = dyn_cast<AffineConstantExpr>(e))
832  strides.push_back(c.getValue());
833  else
834  strides.push_back(ShapedType::kDynamic);
835  }
836  return success();
837 }
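// Worked example (illustrative): for the canonical contiguous layout, e.g.
// memref<3x4xf32>, this fills strides = [4, 1] and offset = 0. For
// memref<3x4xf32, strided<[?, 1], offset: ?>> the dynamic entries come back
// as ShapedType::kDynamic.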
838 
839 std::pair<SmallVector<int64_t>, int64_t>
840 mlir::getStridesAndOffset(MemRefType t) {
841  SmallVector<int64_t> strides;
842  int64_t offset;
843  LogicalResult status = getStridesAndOffset(t, strides, offset);
844  (void)status;
845  assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
846  return {strides, offset};
847 }
848 
849 //===----------------------------------------------------------------------===//
850 /// TupleType
851 //===----------------------------------------------------------------------===//
852 
853 /// Return the elements types for this tuple.
854 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
855 
856 /// Accumulate the types contained in this tuple and tuples nested within it.
857 /// Note that this only flattens nested tuples, not any other container type,
858 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
859 /// (i32, tensor<i32>, f32, i64)
860 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
861  for (Type type : getTypes()) {
862  if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
863  nestedTuple.getFlattenedTypes(types);
864  else
865  types.push_back(type);
866  }
867 }
868 
869 /// Return the number of element types.
870 size_t TupleType::size() const { return getImpl()->size(); }
871 
872 //===----------------------------------------------------------------------===//
873 // Type Utilities
874 //===----------------------------------------------------------------------===//
875 
876 /// Return a version of `t` with identity layout if it can be determined
877 /// statically that the layout is the canonical contiguous strided layout.
878 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
879 /// `t` with simplified layout.
880 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
881 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
882  AffineMap m = t.getLayout().getAffineMap();
883 
884  // Already in canonical form.
885  if (m.isIdentity())
886  return t;
887 
888  // Can't reduce to canonical identity form, return in canonical form.
889  if (m.getNumResults() > 1)
890  return t;
891 
892  // Corner-case for 0-D affine maps.
893  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
894  if (auto cst = dyn_cast<AffineConstantExpr>(m.getResult(0)))
895  if (cst.getValue() == 0)
896  return MemRefType::Builder(t).setLayout({});
897  return t;
898  }
899 
900  // 0-D corner case for empty shape that still have an affine map. Example:
901  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
902  // offset needs to remain, just return t.
903  if (t.getShape().empty())
904  return t;
905 
906  // If the canonical strided layout for the sizes of `t` is equal to the
907  // simplified layout of `t` we can just return an empty layout. Otherwise,
908  // just simplify the existing layout.
909  AffineExpr expr =
910  makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
911  auto simplifiedLayoutExpr =
912  simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
913  if (expr != simplifiedLayoutExpr)
914  return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
915  m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
916  return MemRefType::Builder(t).setLayout({});
917 }
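// Illustrative effect (a sketch): a layout that merely spells out the
// canonical strides, e.g.
//   memref<42x16xf32, affine_map<(d0, d1) -> (d0 * 16 + d1)>>
// canonicalizes to the identity-layout form memref<42x16xf32>, while a
// genuinely non-contiguous layout such as (d0, d1) -> (d0 * 32 + d1) is only
// simplified, not erased.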
918 
919 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
920  ArrayRef<AffineExpr> exprs,
921  MLIRContext *context) {
922  // Size 0 corner case is useful for canonicalizations.
923  if (sizes.empty())
924  return getAffineConstantExpr(0, context);
925 
926  assert(!exprs.empty() && "expected exprs");
927  auto maps = AffineMap::inferFromExprList(exprs, context);
928  assert(!maps.empty() && "Expected one non-empty map");
929  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
930 
931  AffineExpr expr;
932  bool dynamicPoisonBit = false;
933  int64_t runningSize = 1;
934  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
935  int64_t size = std::get<1>(en);
936  AffineExpr dimExpr = std::get<0>(en);
937  AffineExpr stride = dynamicPoisonBit
938  ? getAffineSymbolExpr(nSymbols++, context)
939  : getAffineConstantExpr(runningSize, context);
940  expr = expr ? expr + dimExpr * stride : dimExpr * stride;
941  if (size > 0) {
942  runningSize *= size;
943  assert(runningSize > 0 && "integer overflow in size computation");
944  } else {
945  dynamicPoisonBit = true;
946  }
947  }
948  return simplifyAffineExpr(expr, numDims, nSymbols);
949 }
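// Worked example (illustrative): sizes = [3, 4, 5] with exprs = [d0, d1, d2]
// walks the dimensions right-to-left, producing d0 * 20 + d1 * 5 + d2. A
// dynamic size poisons the strides to its left, which become fresh symbols:
// sizes = [3, ?, 5] gives d0 * s0 + d1 * 5 + d2.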
950 
951 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
952  MLIRContext *context) {
953  SmallVector<AffineExpr, 4> exprs;
954  exprs.reserve(sizes.size());
955  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
956  exprs.push_back(getAffineDimExpr(dim, context));
957  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
958 }
959 
960 bool mlir::isStrided(MemRefType t) {
961  int64_t offset;
962  SmallVector<int64_t, 4> strides;
963  auto res = getStridesAndOffset(t, strides, offset);
964  return succeeded(res);
965 }
966 
967 bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
968  int64_t offset;
969  SmallVector<int64_t> strides;
970  auto successStrides = getStridesAndOffset(type, strides, offset);
971  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
972 }
973 
974 bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
975  if (!isLastMemrefDimUnitStride(type))
976  return false;
977 
978  auto memrefShape = type.getShape().take_back(n);
979  if (ShapedType::isDynamicShape(memrefShape))
980  return false;
981 
982  if (type.getLayout().isIdentity())
983  return true;
984 
985  int64_t offset;
986  SmallVector<int64_t> stridesFull;
987  if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
988  return false;
989  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
990 
991  if (strides.empty())
992  return true;
993 
994  // Check whether strides match "flattened" dims.
995  SmallVector<int64_t> flattenedDims;
996  auto dimProduct = 1;
997  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
998  dimProduct *= dim;
999  flattenedDims.push_back(dimProduct);
1000  }
1001 
1002  strides = strides.drop_back(1);
1003  return llvm::equal(strides, llvm::reverse(flattenedDims));
1004 }
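// Worked example (illustrative): memref<2x3x4xf32> with canonical strides
// [12, 4, 1] has its trailing 2 dims contiguous, because the trailing strides
// [4, 1] match the flattened trailing sizes. A transposed layout with strides
// [1, 2, 6] already fails the unit-stride check on the last dimension.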