MLIR 22.0.0git
BuiltinTypes.cpp
Go to the documentation of this file.
1//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "TypeDetail.h"
11#include "mlir/IR/AffineExpr.h"
12#include "mlir/IR/AffineMap.h"
15#include "mlir/IR/Diagnostics.h"
16#include "mlir/IR/Dialect.h"
19#include "llvm/ADT/APFloat.h"
20#include "llvm/ADT/Sequence.h"
21#include "llvm/ADT/TypeSwitch.h"
22
23using namespace mlir;
24using namespace mlir::detail;
25
26//===----------------------------------------------------------------------===//
27/// Tablegen Type Definitions
28//===----------------------------------------------------------------------===//
29
30#define GET_TYPEDEF_CLASSES
31#include "mlir/IR/BuiltinTypes.cpp.inc"
32
33namespace mlir {
34#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
35} // namespace mlir
36
37//===----------------------------------------------------------------------===//
38// BuiltinDialect
39//===----------------------------------------------------------------------===//
40
/// Register all builtin types with the builtin dialect. The list of types is
/// supplied by the TableGen-generated BuiltinTypes.cpp.inc, included here with
/// GET_TYPEDEF_LIST defined; presumably that expansion is the list of
/// generated type classes used as the addTypes<> template arguments.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}
47
48//===----------------------------------------------------------------------===//
49/// ComplexType
50//===----------------------------------------------------------------------===//
51
52/// Verify the construction of an integer type.
53LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
54 Type elementType) {
55 if (!elementType.isIntOrFloat())
56 return emitError() << "invalid element type for complex";
57 return success();
58}
59
60//===----------------------------------------------------------------------===//
61// Integer Type
62//===----------------------------------------------------------------------===//
63
64/// Verify the construction of an integer type.
65LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
66 unsigned width,
67 SignednessSemantics signedness) {
68 if (width > IntegerType::kMaxWidth) {
69 return emitError() << "integer bitwidth is limited to "
70 << IntegerType::kMaxWidth << " bits";
71 }
72 return success();
73}
74
75unsigned IntegerType::getWidth() const { return getImpl()->width; }
76
77IntegerType::SignednessSemantics IntegerType::getSignedness() const {
78 return getImpl()->signedness;
79}
80
81IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
82 if (!scale)
83 return IntegerType();
84 return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
85}
86
87//===----------------------------------------------------------------------===//
88// Float Types
89//===----------------------------------------------------------------------===//
90
// Mapping from MLIR FloatType to APFloat semantics.
// FLOAT_TYPE_SEMANTICS(TYPE, SEM) defines TYPE::getFloatSemantics() so that it
// returns the llvm::fltSemantics object obtained from APFloat::SEM(). Each
// invocation below pairs one builtin float type with its APFloat semantics;
// the macro is #undef'd right after the table so it does not leak.
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
  const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
    return APFloat::SEM(); \
  }
FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
#undef FLOAT_TYPE_SEMANTICS
115
116FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
117 if (scale == 2)
118 return Float32Type::get(getContext());
119 if (scale == 4)
120 return Float64Type::get(getContext());
121 return FloatType();
122}
123
124FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
125 if (scale == 2)
126 return Float32Type::get(getContext());
127 if (scale == 4)
128 return Float64Type::get(getContext());
129 return FloatType();
130}
131
132FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
133 if (scale == 2)
134 return Float64Type::get(getContext());
135 return FloatType();
136}
137
138//===----------------------------------------------------------------------===//
139// FunctionType
140//===----------------------------------------------------------------------===//
141
142unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
143
144ArrayRef<Type> FunctionType::getInputs() const {
145 return getImpl()->getInputs();
146}
147
148unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
149
150ArrayRef<Type> FunctionType::getResults() const {
151 return getImpl()->getResults();
152}
153
154FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
155 return get(getContext(), inputs, results);
156}
157
158/// Returns a new function type with the specified arguments and results
159/// inserted.
160FunctionType FunctionType::getWithArgsAndResults(
161 ArrayRef<unsigned> argIndices, TypeRange argTypes,
162 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
163 SmallVector<Type> argStorage, resultStorage;
164 TypeRange newArgTypes =
165 insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
166 TypeRange newResultTypes =
167 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
168 return clone(newArgTypes, newResultTypes);
169}
170
171/// Returns a new function type without the specified arguments and results.
172FunctionType
173FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
174 const BitVector &resultIndices) {
175 SmallVector<Type> argStorage, resultStorage;
176 TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
177 TypeRange newResultTypes =
178 filterTypesOut(getResults(), resultIndices, resultStorage);
179 return clone(newArgTypes, newResultTypes);
180}
181
182//===----------------------------------------------------------------------===//
183// GraphType
184//===----------------------------------------------------------------------===//
185
186unsigned GraphType::getNumInputs() const { return getImpl()->numInputs; }
187
188ArrayRef<Type> GraphType::getInputs() const { return getImpl()->getInputs(); }
189
190unsigned GraphType::getNumResults() const { return getImpl()->numResults; }
191
192ArrayRef<Type> GraphType::getResults() const { return getImpl()->getResults(); }
193
194GraphType GraphType::clone(TypeRange inputs, TypeRange results) const {
195 return get(getContext(), inputs, results);
196}
197
198/// Returns a new function type with the specified arguments and results
199/// inserted.
200GraphType GraphType::getWithArgsAndResults(ArrayRef<unsigned> argIndices,
201 TypeRange argTypes,
202 ArrayRef<unsigned> resultIndices,
203 TypeRange resultTypes) {
204 SmallVector<Type> argStorage, resultStorage;
205 TypeRange newArgTypes =
206 insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
207 TypeRange newResultTypes =
208 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
209 return clone(newArgTypes, newResultTypes);
210}
211
212/// Returns a new function type without the specified arguments and results.
213GraphType GraphType::getWithoutArgsAndResults(const BitVector &argIndices,
214 const BitVector &resultIndices) {
215 SmallVector<Type> argStorage, resultStorage;
216 TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
217 TypeRange newResultTypes =
218 filterTypesOut(getResults(), resultIndices, resultStorage);
219 return clone(newArgTypes, newResultTypes);
220}
221//===----------------------------------------------------------------------===//
222// OpaqueType
223//===----------------------------------------------------------------------===//
224
225/// Verify the construction of an opaque type.
226LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
227 StringAttr dialect, StringRef typeData) {
228 if (!Dialect::isValidNamespace(dialect.strref()))
229 return emitError() << "invalid dialect namespace '" << dialect << "'";
230
231 // Check that the dialect is actually registered.
232 MLIRContext *context = dialect.getContext();
233 if (!context->allowsUnregisteredDialects() &&
234 !context->getLoadedDialect(dialect.strref())) {
235 return emitError()
236 << "`!" << dialect << "<\"" << typeData << "\">"
237 << "` type created with unregistered dialect. If this is "
238 "intended, please call allowUnregisteredDialects() on the "
239 "MLIRContext, or use -allow-unregistered-dialect with "
240 "the MLIR opt tool used";
241 }
242
243 return success();
244}
245
246//===----------------------------------------------------------------------===//
247// VectorType
248//===----------------------------------------------------------------------===//
249
250bool VectorType::isValidElementType(Type t) {
252}
253
254LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
255 ArrayRef<int64_t> shape, Type elementType,
256 ArrayRef<bool> scalableDims) {
257 if (!isValidElementType(elementType))
258 return emitError()
259 << "vector elements must be int/index/float type but got "
260 << elementType;
261
262 if (any_of(shape, [](int64_t i) { return i <= 0; }))
263 return emitError()
264 << "vector types must have positive constant sizes but got "
265 << shape;
266
267 if (scalableDims.size() != shape.size())
268 return emitError() << "number of dims must match, got "
269 << scalableDims.size() << " and " << shape.size();
270
271 return success();
272}
273
274VectorType VectorType::scaleElementBitwidth(unsigned scale) {
275 if (!scale)
276 return VectorType();
277 if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
278 if (auto scaledEt = et.scaleElementBitwidth(scale))
279 return VectorType::get(getShape(), scaledEt, getScalableDims());
280 if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
281 if (auto scaledEt = et.scaleElementBitwidth(scale))
282 return VectorType::get(getShape(), scaledEt, getScalableDims());
283 return VectorType();
284}
285
286VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
287 Type elementType) const {
288 return VectorType::get(shape.value_or(getShape()), elementType,
289 getScalableDims());
290}
291
292//===----------------------------------------------------------------------===//
293// TensorType
294//===----------------------------------------------------------------------===//
295
298 .Case<RankedTensorType, UnrankedTensorType>(
299 [](auto type) { return type.getElementType(); });
300}
301
303 return !llvm::isa<UnrankedTensorType>(*this);
304}
305
307 return llvm::cast<RankedTensorType>(*this).getShape();
308}
309
311 Type elementType) const {
312 if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
313 if (shape)
314 return RankedTensorType::get(*shape, elementType);
315 return UnrankedTensorType::get(elementType);
316 }
317
318 auto rankedTy = llvm::cast<RankedTensorType>(*this);
319 if (!shape)
320 return RankedTensorType::get(rankedTy.getShape(), elementType,
321 rankedTy.getEncoding());
322 return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
323 rankedTy.getEncoding());
324}
325
327 Type elementType) const {
328 return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
329}
330
331RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
332 return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
333}
334
335// Check if "elementType" can be an element type of a tensor.
336static LogicalResult
338 Type elementType) {
339 if (!TensorType::isValidElementType(elementType))
340 return emitError() << "invalid tensor element type: " << elementType;
341 return success();
342}
343
344/// Return true if the specified element type is ok in a tensor.
346 // Note: Non standard/builtin types are allowed to exist within tensor
347 // types. Dialects are expected to verify that tensor types have a valid
348 // element type within that dialect.
349 return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
350 IndexType>(type) ||
351 !llvm::isa<BuiltinDialect>(type.getDialect());
352}
353
354//===----------------------------------------------------------------------===//
355// RankedTensorType
356//===----------------------------------------------------------------------===//
357
358LogicalResult
359RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
360 ArrayRef<int64_t> shape, Type elementType,
361 Attribute encoding) {
362 for (int64_t s : shape)
363 if (s < 0 && ShapedType::isStatic(s))
364 return emitError() << "invalid tensor dimension size";
365 if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
366 if (failed(v.verifyEncoding(shape, elementType, emitError)))
367 return failure();
368 return checkTensorElementType(emitError, elementType);
369}
370
371//===----------------------------------------------------------------------===//
372// UnrankedTensorType
373//===----------------------------------------------------------------------===//
374
375LogicalResult
376UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
377 Type elementType) {
378 return checkTensorElementType(emitError, elementType);
379}
380
381//===----------------------------------------------------------------------===//
382// BaseMemRefType
383//===----------------------------------------------------------------------===//
384
387 .Case<MemRefType, UnrankedMemRefType>(
388 [](auto type) { return type.getElementType(); });
389}
390
392 return !llvm::isa<UnrankedMemRefType>(*this);
393}
394
396 return llvm::cast<MemRefType>(*this).getShape();
397}
398
400 Type elementType) const {
401 if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
402 if (!shape)
403 return UnrankedMemRefType::get(elementType, getMemorySpace());
404 MemRefType::Builder builder(*shape, elementType);
406 return builder;
407 }
408
409 MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
410 if (shape)
411 builder.setShape(*shape);
412 builder.setElementType(elementType);
413 return builder;
414}
415
416FailureOr<PtrLikeTypeInterface>
418 std::optional<Type> elementType) const {
419 Type eTy = elementType ? *elementType : getElementType();
420 if (llvm::dyn_cast<UnrankedMemRefType>(*this))
421 return cast<PtrLikeTypeInterface>(
422 UnrankedMemRefType::get(eTy, memorySpace));
423
424 MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
425 builder.setElementType(eTy);
426 builder.setMemorySpace(memorySpace);
427 return cast<PtrLikeTypeInterface>(static_cast<MemRefType>(builder));
428}
429
431 Type elementType) const {
432 return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
433}
434
436 return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
437}
438
440 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
441 return rankedMemRefTy.getMemorySpace();
442 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
443}
444
446 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
447 return rankedMemRefTy.getMemorySpaceAsInt();
448 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
449}
450
451//===----------------------------------------------------------------------===//
452// MemRefType
453//===----------------------------------------------------------------------===//
454
455std::optional<llvm::SmallDenseSet<unsigned>>
457 ArrayRef<int64_t> reducedShape,
458 bool matchDynamic) {
459 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
460 llvm::SmallDenseSet<unsigned> unusedDims;
461 unsigned reducedIdx = 0;
462 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
463 // Greedily insert `originalIdx` if match.
464 int64_t origSize = originalShape[originalIdx];
465 // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
466 if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
467 (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
468 ShapedType::isDynamic(origSize))) {
469 reducedIdx++;
470 continue;
471 }
472 if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
473 reducedIdx++;
474 continue;
475 }
476
477 unusedDims.insert(originalIdx);
478 // If no match on `originalIdx`, the `originalShape` at this dimension
479 // must be 1, otherwise we bail.
480 if (origSize != 1)
481 return std::nullopt;
482 }
483 // The whole reducedShape must be scanned, otherwise we bail.
484 if (reducedIdx != reducedRank)
485 return std::nullopt;
486 return unusedDims;
487}
488
490mlir::isRankReducedType(ShapedType originalType,
491 ShapedType candidateReducedType) {
492 if (originalType == candidateReducedType)
494
495 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
496 ShapedType candidateReducedShapedType =
497 llvm::cast<ShapedType>(candidateReducedType);
498
499 // Rank and size logic is valid for all ShapedTypes.
500 ArrayRef<int64_t> originalShape = originalShapedType.getShape();
501 ArrayRef<int64_t> candidateReducedShape =
502 candidateReducedShapedType.getShape();
503 unsigned originalRank = originalShape.size(),
504 candidateReducedRank = candidateReducedShape.size();
505 if (candidateReducedRank > originalRank)
507
508 auto optionalUnusedDimsMask =
509 computeRankReductionMask(originalShape, candidateReducedShape);
510
511 // Sizes cannot be matched in case empty vector is returned.
512 if (!optionalUnusedDimsMask)
514
515 if (originalShapedType.getElementType() !=
516 candidateReducedShapedType.getElementType())
518
520}
521
523 // Empty attribute is allowed as default memory space.
524 if (!memorySpace)
525 return true;
526
527 // Supported built-in attributes.
528 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
529 return true;
530
531 // Allow custom dialect attributes.
532 if (!isa<BuiltinDialect>(memorySpace.getDialect()))
533 return true;
534
535 return false;
536}
537
539 MLIRContext *ctx) {
540 if (memorySpace == 0)
541 return nullptr;
542
543 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
544}
545
547 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
548 if (intMemorySpace && intMemorySpace.getValue() == 0)
549 return nullptr;
550
551 return memorySpace;
552}
553
555 if (!memorySpace)
556 return 0;
557
558 assert(llvm::isa<IntegerAttr>(memorySpace) &&
559 "Using `getMemorySpaceInteger` with non-Integer attribute");
560
561 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
562}
563
564unsigned MemRefType::getMemorySpaceAsInt() const {
565 return detail::getMemorySpaceAsInt(getMemorySpace());
566}
567
568MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
569 MemRefLayoutAttrInterface layout,
570 Attribute memorySpace) {
571 // Use default layout for empty attribute.
572 if (!layout)
573 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
574 shape.size(), elementType.getContext()));
575
576 // Drop default memory space value and replace it with empty attribute.
577 memorySpace = skipDefaultMemorySpace(memorySpace);
578
579 return Base::get(elementType.getContext(), shape, elementType, layout,
580 memorySpace);
581}
582
583MemRefType MemRefType::getChecked(
585 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
586
587 // Use default layout for empty attribute.
588 if (!layout)
589 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
590 shape.size(), elementType.getContext()));
591
592 // Drop default memory space value and replace it with empty attribute.
593 memorySpace = skipDefaultMemorySpace(memorySpace);
594
595 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
596 elementType, layout, memorySpace);
597}
598
599MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
600 AffineMap map, Attribute memorySpace) {
601
602 // Use default layout for empty map.
603 if (!map)
605 elementType.getContext());
606
607 // Wrap AffineMap into Attribute.
608 auto layout = AffineMapAttr::get(map);
609
610 // Drop default memory space value and replace it with empty attribute.
611 memorySpace = skipDefaultMemorySpace(memorySpace);
612
613 return Base::get(elementType.getContext(), shape, elementType, layout,
614 memorySpace);
615}
616
617MemRefType
618MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
619 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
620 Attribute memorySpace) {
621
622 // Use default layout for empty map.
623 if (!map)
625 elementType.getContext());
626
627 // Wrap AffineMap into Attribute.
628 auto layout = AffineMapAttr::get(map);
629
630 // Drop default memory space value and replace it with empty attribute.
631 memorySpace = skipDefaultMemorySpace(memorySpace);
632
633 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
634 elementType, layout, memorySpace);
635}
636
637MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
638 AffineMap map, unsigned memorySpaceInd) {
639
640 // Use default layout for empty map.
641 if (!map)
643 elementType.getContext());
644
645 // Wrap AffineMap into Attribute.
646 auto layout = AffineMapAttr::get(map);
647
648 // Convert deprecated integer-like memory space to Attribute.
649 Attribute memorySpace =
650 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
651
652 return Base::get(elementType.getContext(), shape, elementType, layout,
653 memorySpace);
654}
655
656MemRefType
657MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
658 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
659 unsigned memorySpaceInd) {
660
661 // Use default layout for empty map.
662 if (!map)
664 elementType.getContext());
665
666 // Wrap AffineMap into Attribute.
667 auto layout = AffineMapAttr::get(map);
668
669 // Convert deprecated integer-like memory space to Attribute.
670 Attribute memorySpace =
671 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
672
673 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
674 elementType, layout, memorySpace);
675}
676
677LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
678 ArrayRef<int64_t> shape, Type elementType,
679 MemRefLayoutAttrInterface layout,
680 Attribute memorySpace) {
681 if (!BaseMemRefType::isValidElementType(elementType))
682 return emitError() << "invalid memref element type";
683
684 // Negative sizes are not allowed except for `kDynamic`.
685 for (int64_t s : shape)
686 if (s < 0 && ShapedType::isStatic(s))
687 return emitError() << "invalid memref size";
688
689 assert(layout && "missing layout specification");
690 if (failed(layout.verifyLayout(shape, emitError)))
691 return failure();
692
693 if (!isSupportedMemorySpace(memorySpace))
694 return emitError() << "unsupported memory space Attribute";
695
696 return success();
697}
698
699bool MemRefType::areTrailingDimsContiguous(int64_t n) {
700 assert(n <= getRank() &&
701 "number of dimensions to check must not exceed rank");
702 return n <= getNumContiguousTrailingDims();
703}
704
705int64_t MemRefType::getNumContiguousTrailingDims() {
706 const int64_t n = getRank();
707
708 // memrefs with identity layout are entirely contiguous.
709 if (getLayout().isIdentity())
710 return n;
711
712 // Get the strides (if any). Failing to do that, conservatively assume a
713 // non-contiguous layout.
714 int64_t offset;
715 SmallVector<int64_t> strides;
716 if (!succeeded(getStridesAndOffset(strides, offset)))
717 return 0;
718
720
721 // A memref with dimensions `d0, d1, ..., dn-1` and strides
722 // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
723 // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
724 // for `i` in `[k, n-1]`.
725 // Ignore stride elements if the corresponding dimension is 1, as they are
726 // of no consequence.
727 int64_t dimProduct = 1;
728 for (int64_t i = n - 1; i >= 0; --i) {
729 if (shape[i] == 1)
730 continue;
731 if (strides[i] != dimProduct)
732 return n - i - 1;
733 if (shape[i] == ShapedType::kDynamic)
734 return n - i;
735 dimProduct *= shape[i];
736 }
737
738 return n;
739}
740
741MemRefType MemRefType::canonicalizeStridedLayout() {
742 AffineMap m = getLayout().getAffineMap();
743
744 // Already in canonical form.
745 if (m.isIdentity())
746 return *this;
747
748 // Can't reduce to canonical identity form, return in canonical form.
749 if (m.getNumResults() > 1)
750 return *this;
751
752 // Corner-case for 0-D affine maps.
753 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
754 if (auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0)))
755 if (cst.getValue() == 0)
756 return MemRefType::Builder(*this).setLayout({});
757 return *this;
758 }
759
760 // 0-D corner case for empty shape that still have an affine map. Example:
761 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
762 // offset needs to remain, just return t.
763 if (getShape().empty())
764 return *this;
765
766 // If the canonical strided layout for the sizes of `t` is equal to the
767 // simplified layout of `t` we can just return an empty layout. Otherwise,
768 // just simplify the existing layout.
770 auto simplifiedLayoutExpr =
772 if (expr != simplifiedLayoutExpr)
773 return MemRefType::Builder(*this).setLayout(
774 AffineMapAttr::get(AffineMap::get(m.getNumDims(), m.getNumSymbols(),
775 simplifiedLayoutExpr)));
776 return MemRefType::Builder(*this).setLayout({});
777}
778
779LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
780 int64_t &offset) const {
781 return getLayout().getStridesAndOffset(getShape(), strides, offset);
782}
783
784std::pair<SmallVector<int64_t>, int64_t>
785MemRefType::getStridesAndOffset() const {
786 SmallVector<int64_t> strides;
787 int64_t offset;
788 LogicalResult status = getStridesAndOffset(strides, offset);
789 (void)status;
790 assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
791 return {strides, offset};
792}
793
794bool MemRefType::isStrided() {
795 int64_t offset;
797 auto res = getStridesAndOffset(strides, offset);
798 return succeeded(res);
799}
800
801bool MemRefType::isLastDimUnitStride() {
802 int64_t offset;
803 SmallVector<int64_t> strides;
804 auto successStrides = getStridesAndOffset(strides, offset);
805 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
806}
807
808//===----------------------------------------------------------------------===//
809// UnrankedMemRefType
810//===----------------------------------------------------------------------===//
811
812unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
813 return detail::getMemorySpaceAsInt(getMemorySpace());
814}
815
816LogicalResult
817UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
818 Type elementType, Attribute memorySpace) {
819 if (!BaseMemRefType::isValidElementType(elementType))
820 return emitError() << "invalid memref element type";
821
822 if (!isSupportedMemorySpace(memorySpace))
823 return emitError() << "unsupported memory space Attribute";
824
825 return success();
826}
827
828//===----------------------------------------------------------------------===//
829/// TupleType
830//===----------------------------------------------------------------------===//
831
832/// Return the elements types for this tuple.
833ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
834
835/// Accumulate the types contained in this tuple and tuples nested within it.
836/// Note that this only flattens nested tuples, not any other container type,
837/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
838/// (i32, tensor<i32>, f32, i64)
839void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
840 for (Type type : getTypes()) {
841 if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
842 nestedTuple.getFlattenedTypes(types);
843 else
844 types.push_back(type);
845 }
846}
847
848/// Return the number of element types.
849size_t TupleType::size() const { return getImpl()->size(); }
850
851//===----------------------------------------------------------------------===//
852// Type Utilities
853//===----------------------------------------------------------------------===//
854
857 MLIRContext *context) {
858 // Size 0 corner case is useful for canonicalizations.
859 if (sizes.empty())
860 return getAffineConstantExpr(0, context);
861
862 assert(!exprs.empty() && "expected exprs");
863 auto maps = AffineMap::inferFromExprList(exprs, context);
864 assert(!maps.empty() && "Expected one non-empty map");
865 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
866
867 AffineExpr expr;
868 bool dynamicPoisonBit = false;
869 int64_t runningSize = 1;
870 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
871 int64_t size = std::get<1>(en);
872 AffineExpr dimExpr = std::get<0>(en);
873 AffineExpr stride = dynamicPoisonBit
874 ? getAffineSymbolExpr(nSymbols++, context)
875 : getAffineConstantExpr(runningSize, context);
876 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
877 if (size > 0) {
878 runningSize *= size;
879 assert(runningSize > 0 && "integer overflow in size computation");
880 } else {
881 dynamicPoisonBit = true;
882 }
883 }
884 return simplifyAffineExpr(expr, numDims, nSymbols);
885}
886
888 MLIRContext *context) {
890 exprs.reserve(sizes.size());
891 for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
892 exprs.push_back(getAffineDimExpr(dim, context));
893 return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
894}
return success()
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)
static Type getElementType(Type type)
Determine the element type of type.
b getContext())
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition Attributes.h:58
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
constexpr Type()=default
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
Clone this type with the given shape and element type.
bool hasRank() const
Returns true if this type is ranked, i.e., it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns true if the given string is a valid dialect namespace.
Definition Dialect.cpp:95
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if we allow operations to be created for unregistered dialects.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setMemorySpace(Attribute newMemorySpace)
Builder & setElementType(Type newElementType)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
TensorType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
Clone this type with the given shape and element type.
constexpr Type()=default
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns true if this type is ranked, i.e., it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
AttrTypeReplacer.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
bool isValidVectorTypeElementType(::mlir::Type type)
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)