MLIR 23.0.0git
BuiltinTypes.cpp
Go to the documentation of this file.
1//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "TypeDetail.h"
11#include "mlir/IR/AffineExpr.h"
12#include "mlir/IR/AffineMap.h"
15#include "mlir/IR/Diagnostics.h"
16#include "mlir/IR/Dialect.h"
19#include "llvm/ADT/APFloat.h"
20#include "llvm/ADT/Sequence.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include "llvm/Support/CheckedArithmetic.h"
23
24using namespace mlir;
25using namespace mlir::detail;
26
27//===----------------------------------------------------------------------===//
28/// Tablegen Type Definitions
29//===----------------------------------------------------------------------===//
30
31#define GET_TYPEDEF_CLASSES
32#include "mlir/IR/BuiltinTypes.cpp.inc"
33
34namespace mlir {
35#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
36} // namespace mlir
37
38//===----------------------------------------------------------------------===//
39// BuiltinDialect
40//===----------------------------------------------------------------------===//
41
42void BuiltinDialect::registerTypes() {
43 addTypes<
44#define GET_TYPEDEF_LIST
45#include "mlir/IR/BuiltinTypes.cpp.inc"
46 >();
47}
48
49//===----------------------------------------------------------------------===//
50/// ComplexType
51//===----------------------------------------------------------------------===//
52
53/// Verify the construction of an integer type.
54LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
55 Type elementType) {
56 if (!elementType.isIntOrFloat())
57 return emitError() << "invalid element type for complex";
58 return success();
59}
60
61//===----------------------------------------------------------------------===//
62// Integer Type
63//===----------------------------------------------------------------------===//
64
65/// Verify the construction of an integer type.
66LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
67 unsigned width,
68 SignednessSemantics signedness) {
69 if (width > IntegerType::kMaxWidth) {
70 return emitError() << "integer bitwidth is limited to "
71 << IntegerType::kMaxWidth << " bits";
72 }
73 return success();
74}
75
76unsigned IntegerType::getWidth() const { return getImpl()->width; }
77
78IntegerType::SignednessSemantics IntegerType::getSignedness() const {
79 return getImpl()->signedness;
80}
81
82IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
83 if (!scale)
84 return IntegerType();
85 return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
86}
87
88//===----------------------------------------------------------------------===//
89// Float Types
90//===----------------------------------------------------------------------===//
91
92// Mapping from MLIR FloatType to APFloat semantics.
93#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
94 const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
95 return APFloat::SEM(); \
96 }
97FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
98FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
99FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
100FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
101FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
102FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
103FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
104FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
105FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
106FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
107FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
108FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
109FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
110FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
111FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
112FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
113FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
114FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
115#undef FLOAT_TYPE_SEMANTICS
116
117FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
118 if (scale == 2)
119 return Float32Type::get(getContext());
120 if (scale == 4)
121 return Float64Type::get(getContext());
122 return FloatType();
123}
124
125FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
126 if (scale == 2)
127 return Float32Type::get(getContext());
128 if (scale == 4)
129 return Float64Type::get(getContext());
130 return FloatType();
131}
132
133FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
134 if (scale == 2)
135 return Float64Type::get(getContext());
136 return FloatType();
137}
138
139//===----------------------------------------------------------------------===//
140// FunctionType
141//===----------------------------------------------------------------------===//
142
143unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
144
145ArrayRef<Type> FunctionType::getInputs() const {
146 return getImpl()->getInputs();
147}
148
149unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
150
151ArrayRef<Type> FunctionType::getResults() const {
152 return getImpl()->getResults();
153}
154
155FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
156 return get(getContext(), inputs, results);
157}
158
159/// Returns a new function type with the specified arguments and results
160/// inserted.
161FunctionType FunctionType::getWithArgsAndResults(
162 ArrayRef<unsigned> argIndices, TypeRange argTypes,
163 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
164 SmallVector<Type> argStorage, resultStorage;
165 TypeRange newArgTypes =
166 insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
167 TypeRange newResultTypes =
168 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
169 return clone(newArgTypes, newResultTypes);
170}
171
172/// Returns a new function type without the specified arguments and results.
173FunctionType
174FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
175 const BitVector &resultIndices) {
176 SmallVector<Type> argStorage, resultStorage;
177 TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
178 TypeRange newResultTypes =
179 filterTypesOut(getResults(), resultIndices, resultStorage);
180 return clone(newArgTypes, newResultTypes);
181}
182
183//===----------------------------------------------------------------------===//
184// GraphType
185//===----------------------------------------------------------------------===//
186
187unsigned GraphType::getNumInputs() const { return getImpl()->numInputs; }
188
189ArrayRef<Type> GraphType::getInputs() const { return getImpl()->getInputs(); }
190
191unsigned GraphType::getNumResults() const { return getImpl()->numResults; }
192
193ArrayRef<Type> GraphType::getResults() const { return getImpl()->getResults(); }
194
195GraphType GraphType::clone(TypeRange inputs, TypeRange results) const {
196 return get(getContext(), inputs, results);
197}
198
199/// Returns a new function type with the specified arguments and results
200/// inserted.
201GraphType GraphType::getWithArgsAndResults(ArrayRef<unsigned> argIndices,
202 TypeRange argTypes,
203 ArrayRef<unsigned> resultIndices,
204 TypeRange resultTypes) {
205 SmallVector<Type> argStorage, resultStorage;
206 TypeRange newArgTypes =
207 insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
208 TypeRange newResultTypes =
209 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
210 return clone(newArgTypes, newResultTypes);
211}
212
213/// Returns a new function type without the specified arguments and results.
214GraphType GraphType::getWithoutArgsAndResults(const BitVector &argIndices,
215 const BitVector &resultIndices) {
216 SmallVector<Type> argStorage, resultStorage;
217 TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
218 TypeRange newResultTypes =
219 filterTypesOut(getResults(), resultIndices, resultStorage);
220 return clone(newArgTypes, newResultTypes);
221}
222//===----------------------------------------------------------------------===//
223// OpaqueType
224//===----------------------------------------------------------------------===//
225
226/// Verify the construction of an opaque type.
227LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
228 StringAttr dialect, StringRef typeData) {
229 if (!Dialect::isValidNamespace(dialect.strref()))
230 return emitError() << "invalid dialect namespace '" << dialect << "'";
231
232 // Check that the dialect is actually registered.
233 MLIRContext *context = dialect.getContext();
234 if (!context->allowsUnregisteredDialects() &&
235 !context->getLoadedDialect(dialect.strref())) {
236 return emitError()
237 << "`!" << dialect << "<\"" << typeData << "\">"
238 << "` type created with unregistered dialect. If this is "
239 "intended, please call allowUnregisteredDialects() on the "
240 "MLIRContext, or use -allow-unregistered-dialect with "
241 "the MLIR opt tool used";
242 }
243
244 return success();
245}
246
247//===----------------------------------------------------------------------===//
248// VectorType
249//===----------------------------------------------------------------------===//
250
251bool VectorType::isValidElementType(Type t) {
253}
254
255LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
256 ArrayRef<int64_t> shape, Type elementType,
257 ArrayRef<bool> scalableDims) {
258 if (!isValidElementType(elementType))
259 return emitError()
260 << "vector elements must be int/index/float type but got "
261 << elementType;
262
263 if (any_of(shape, [](int64_t i) { return i <= 0; }))
264 return emitError()
265 << "vector types must have positive constant sizes but got "
266 << shape;
267
268 if (scalableDims.size() != shape.size())
269 return emitError() << "number of dims must match, got "
270 << scalableDims.size() << " and " << shape.size();
271
272 return success();
273}
274
275VectorType VectorType::scaleElementBitwidth(unsigned scale) {
276 if (!scale)
277 return VectorType();
278 if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
279 if (auto scaledEt = et.scaleElementBitwidth(scale))
280 return VectorType::get(getShape(), scaledEt, getScalableDims());
281 if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
282 if (auto scaledEt = et.scaleElementBitwidth(scale))
283 return VectorType::get(getShape(), scaledEt, getScalableDims());
284 return VectorType();
285}
286
287VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
288 Type elementType) const {
289 return VectorType::get(shape.value_or(getShape()), elementType,
290 getScalableDims());
291}
292
293//===----------------------------------------------------------------------===//
294// TensorType
295//===----------------------------------------------------------------------===//
296
299 .Case<RankedTensorType, UnrankedTensorType>(
300 [](auto type) { return type.getElementType(); });
301}
302
304 return !llvm::isa<UnrankedTensorType>(*this);
305}
306
308 return llvm::cast<RankedTensorType>(*this).getShape();
309}
310
312 Type elementType) const {
313 if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
314 if (shape)
315 return RankedTensorType::get(*shape, elementType);
316 return UnrankedTensorType::get(elementType);
317 }
318
319 auto rankedTy = llvm::cast<RankedTensorType>(*this);
320 if (!shape)
321 return RankedTensorType::get(rankedTy.getShape(), elementType,
322 rankedTy.getEncoding());
323 return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
324 rankedTy.getEncoding());
325}
326
328 Type elementType) const {
329 return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
330}
331
332RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
333 return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
334}
335
336// Check if "elementType" can be an element type of a tensor.
337static LogicalResult
339 Type elementType) {
340 if (!TensorType::isValidElementType(elementType))
341 return emitError() << "invalid tensor element type: " << elementType;
342 return success();
343}
344
345/// Return true if the specified element type is ok in a tensor.
347 // Note: Non standard/builtin types are allowed to exist within tensor
348 // types. Dialects are expected to verify that tensor types have a valid
349 // element type within that dialect.
350 return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
351 IndexType>(type) ||
352 !llvm::isa<BuiltinDialect>(type.getDialect());
353}
354
355//===----------------------------------------------------------------------===//
356// RankedTensorType
357//===----------------------------------------------------------------------===//
358
359LogicalResult
360RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
361 ArrayRef<int64_t> shape, Type elementType,
362 Attribute encoding) {
363 for (int64_t s : shape)
364 if (s < 0 && ShapedType::isStatic(s))
365 return emitError() << "invalid tensor dimension size";
366 if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
367 if (failed(v.verifyEncoding(shape, elementType, emitError)))
368 return failure();
369 return checkTensorElementType(emitError, elementType);
370}
371
372//===----------------------------------------------------------------------===//
373// UnrankedTensorType
374//===----------------------------------------------------------------------===//
375
376LogicalResult
377UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
378 Type elementType) {
379 return checkTensorElementType(emitError, elementType);
380}
381
382//===----------------------------------------------------------------------===//
383// BaseMemRefType
384//===----------------------------------------------------------------------===//
385
388 .Case<MemRefType, UnrankedMemRefType>(
389 [](auto type) { return type.getElementType(); });
390}
391
393 return !llvm::isa<UnrankedMemRefType>(*this);
394}
395
397 return llvm::cast<MemRefType>(*this).getShape();
398}
399
401 Type elementType) const {
402 if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
403 if (!shape)
404 return UnrankedMemRefType::get(elementType, getMemorySpace());
405 MemRefType::Builder builder(*shape, elementType);
407 return builder;
408 }
409
410 MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
411 if (shape)
412 builder.setShape(*shape);
413 builder.setElementType(elementType);
414 return builder;
415}
416
417FailureOr<PtrLikeTypeInterface>
419 std::optional<Type> elementType) const {
420 Type eTy = elementType ? *elementType : getElementType();
421 if (llvm::dyn_cast<UnrankedMemRefType>(*this))
422 return cast<PtrLikeTypeInterface>(
423 UnrankedMemRefType::get(eTy, memorySpace));
424
425 MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
426 builder.setElementType(eTy);
427 builder.setMemorySpace(memorySpace);
428 return cast<PtrLikeTypeInterface>(static_cast<MemRefType>(builder));
429}
430
432 Type elementType) const {
433 return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
434}
435
437 return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
438}
439
441 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
442 return rankedMemRefTy.getMemorySpace();
443 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
444}
445
447 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
448 return rankedMemRefTy.getMemorySpaceAsInt();
449 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
450}
451
452//===----------------------------------------------------------------------===//
453// MemRefType
454//===----------------------------------------------------------------------===//
455
456std::optional<llvm::SmallDenseSet<unsigned>>
458 ArrayRef<int64_t> reducedShape,
459 bool matchDynamic) {
460 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
461 llvm::SmallDenseSet<unsigned> unusedDims;
462 unsigned reducedIdx = 0;
463 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
464 // Greedily insert `originalIdx` if match.
465 int64_t origSize = originalShape[originalIdx];
466 // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
467 if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
468 (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
469 ShapedType::isDynamic(origSize))) {
470 reducedIdx++;
471 continue;
472 }
473 if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
474 reducedIdx++;
475 continue;
476 }
477
478 unusedDims.insert(originalIdx);
479 // If no match on `originalIdx`, the `originalShape` at this dimension
480 // must be 1, otherwise we bail.
481 if (origSize != 1)
482 return std::nullopt;
483 }
484 // The whole reducedShape must be scanned, otherwise we bail.
485 if (reducedIdx != reducedRank)
486 return std::nullopt;
487 return unusedDims;
488}
489
491mlir::isRankReducedType(ShapedType originalType,
492 ShapedType candidateReducedType) {
493 if (originalType == candidateReducedType)
495
496 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
497 ShapedType candidateReducedShapedType =
498 llvm::cast<ShapedType>(candidateReducedType);
499
500 // Rank and size logic is valid for all ShapedTypes.
501 ArrayRef<int64_t> originalShape = originalShapedType.getShape();
502 ArrayRef<int64_t> candidateReducedShape =
503 candidateReducedShapedType.getShape();
504 unsigned originalRank = originalShape.size(),
505 candidateReducedRank = candidateReducedShape.size();
506 if (candidateReducedRank > originalRank)
508
509 auto optionalUnusedDimsMask =
510 computeRankReductionMask(originalShape, candidateReducedShape);
511
512 // Sizes cannot be matched in case empty vector is returned.
513 if (!optionalUnusedDimsMask)
515
516 if (originalShapedType.getElementType() !=
517 candidateReducedShapedType.getElementType())
519
521}
522
524 // Empty attribute is allowed as default memory space.
525 if (!memorySpace)
526 return true;
527
528 // Supported built-in attributes.
529 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
530 return true;
531
532 // Allow custom dialect attributes.
533 if (!isa<BuiltinDialect>(memorySpace.getDialect()))
534 return true;
535
536 return false;
537}
538
540 MLIRContext *ctx) {
541 if (memorySpace == 0)
542 return nullptr;
543
544 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
545}
546
548 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
549 if (intMemorySpace && intMemorySpace.getValue() == 0)
550 return nullptr;
551
552 return memorySpace;
553}
554
556 if (!memorySpace)
557 return 0;
558
559 assert(llvm::isa<IntegerAttr>(memorySpace) &&
560 "Using `getMemorySpaceInteger` with non-Integer attribute");
561
562 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
563}
564
565unsigned MemRefType::getMemorySpaceAsInt() const {
566 return detail::getMemorySpaceAsInt(getMemorySpace());
567}
568
569MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
570 MemRefLayoutAttrInterface layout,
571 Attribute memorySpace) {
572 // Use default layout for empty attribute.
573 if (!layout)
574 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
575 shape.size(), elementType.getContext()));
576
577 // Drop default memory space value and replace it with empty attribute.
578 memorySpace = skipDefaultMemorySpace(memorySpace);
579
580 return Base::get(elementType.getContext(), shape, elementType, layout,
581 memorySpace);
582}
583
584MemRefType MemRefType::getChecked(
586 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
587
588 // Use default layout for empty attribute.
589 if (!layout)
590 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
591 shape.size(), elementType.getContext()));
592
593 // Drop default memory space value and replace it with empty attribute.
594 memorySpace = skipDefaultMemorySpace(memorySpace);
595
596 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
597 elementType, layout, memorySpace);
598}
599
600MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
601 AffineMap map, Attribute memorySpace) {
602
603 // Use default layout for empty map.
604 if (!map)
606 elementType.getContext());
607
608 // Wrap AffineMap into Attribute.
609 auto layout = AffineMapAttr::get(map);
610
611 // Drop default memory space value and replace it with empty attribute.
612 memorySpace = skipDefaultMemorySpace(memorySpace);
613
614 return Base::get(elementType.getContext(), shape, elementType, layout,
615 memorySpace);
616}
617
618MemRefType
619MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
620 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
621 Attribute memorySpace) {
622
623 // Use default layout for empty map.
624 if (!map)
626 elementType.getContext());
627
628 // Wrap AffineMap into Attribute.
629 auto layout = AffineMapAttr::get(map);
630
631 // Drop default memory space value and replace it with empty attribute.
632 memorySpace = skipDefaultMemorySpace(memorySpace);
633
634 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
635 elementType, layout, memorySpace);
636}
637
638MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
639 AffineMap map, unsigned memorySpaceInd) {
640
641 // Use default layout for empty map.
642 if (!map)
644 elementType.getContext());
645
646 // Wrap AffineMap into Attribute.
647 auto layout = AffineMapAttr::get(map);
648
649 // Convert deprecated integer-like memory space to Attribute.
650 Attribute memorySpace =
651 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
652
653 return Base::get(elementType.getContext(), shape, elementType, layout,
654 memorySpace);
655}
656
657MemRefType
658MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
659 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
660 unsigned memorySpaceInd) {
661
662 // Use default layout for empty map.
663 if (!map)
665 elementType.getContext());
666
667 // Wrap AffineMap into Attribute.
668 auto layout = AffineMapAttr::get(map);
669
670 // Convert deprecated integer-like memory space to Attribute.
671 Attribute memorySpace =
672 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
673
674 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
675 elementType, layout, memorySpace);
676}
677
678LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
679 ArrayRef<int64_t> shape, Type elementType,
680 MemRefLayoutAttrInterface layout,
681 Attribute memorySpace) {
682 if (!BaseMemRefType::isValidElementType(elementType))
683 return emitError() << "invalid memref element type";
684
685 // Negative sizes are not allowed except for `kDynamic`.
686 for (int64_t s : shape)
687 if (s < 0 && ShapedType::isStatic(s))
688 return emitError() << "invalid memref size";
689
690 assert(layout && "missing layout specification");
691 if (failed(layout.verifyLayout(shape, emitError)))
692 return failure();
693
694 if (!isSupportedMemorySpace(memorySpace))
695 return emitError() << "unsupported memory space Attribute";
696
697 return success();
698}
699
700bool MemRefType::areTrailingDimsContiguous(int64_t n) {
701 assert(n <= getRank() &&
702 "number of dimensions to check must not exceed rank");
703 return n <= getNumContiguousTrailingDims();
704}
705
706int64_t MemRefType::getNumContiguousTrailingDims() {
707 const int64_t n = getRank();
708
709 // memrefs with identity layout are entirely contiguous.
710 if (getLayout().isIdentity())
711 return n;
712
713 // Get the strides (if any). Failing to do that, conservatively assume a
714 // non-contiguous layout.
715 int64_t offset;
716 SmallVector<int64_t> strides;
717 if (!succeeded(getStridesAndOffset(strides, offset)))
718 return 0;
719
721
722 // A memref with dimensions `d0, d1, ..., dn-1` and strides
723 // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
724 // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
725 // for `i` in `[k, n-1]`.
726 // Ignore stride elements if the corresponding dimension is 1, as they are
727 // of no consequence.
728 int64_t dimProduct = 1;
729 for (int64_t i = n - 1; i >= 0; --i) {
730 if (shape[i] == 1)
731 continue;
732 if (strides[i] != dimProduct)
733 return n - i - 1;
734 if (shape[i] == ShapedType::kDynamic)
735 return n - i;
736 dimProduct *= shape[i];
737 }
738
739 return n;
740}
741
742MemRefType MemRefType::canonicalizeStridedLayout() {
743 AffineMap m = getLayout().getAffineMap();
744
745 // Already in canonical form.
746 if (m.isIdentity())
747 return *this;
748
749 // Can't reduce to canonical identity form, return in canonical form.
750 if (m.getNumResults() > 1)
751 return *this;
752
753 // Corner-case for 0-D affine maps.
754 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
755 if (auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0)))
756 if (cst.getValue() == 0)
757 return MemRefType::Builder(*this).setLayout({});
758 return *this;
759 }
760
761 // 0-D corner case for empty shape that still have an affine map. Example:
762 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
763 // offset needs to remain, just return t.
764 if (getShape().empty())
765 return *this;
766
767 // If the canonical strided layout for the sizes of `t` is equal to the
768 // simplified layout of `t` we can just return an empty layout. Otherwise,
769 // just simplify the existing layout.
771 auto simplifiedLayoutExpr =
773 if (expr != simplifiedLayoutExpr)
774 return MemRefType::Builder(*this).setLayout(
775 AffineMapAttr::get(AffineMap::get(m.getNumDims(), m.getNumSymbols(),
776 simplifiedLayoutExpr)));
777 return MemRefType::Builder(*this).setLayout({});
778}
779
780LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
781 int64_t &offset) const {
782 return getLayout().getStridesAndOffset(getShape(), strides, offset);
783}
784
785std::pair<SmallVector<int64_t>, int64_t>
786MemRefType::getStridesAndOffset() const {
787 SmallVector<int64_t> strides;
788 int64_t offset;
789 LogicalResult status = getStridesAndOffset(strides, offset);
790 (void)status;
791 assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
792 return {strides, offset};
793}
794
795bool MemRefType::isStrided() {
796 int64_t offset;
798 auto res = getStridesAndOffset(strides, offset);
799 return succeeded(res);
800}
801
802bool MemRefType::isLastDimUnitStride() {
803 int64_t offset;
804 SmallVector<int64_t> strides;
805 auto successStrides = getStridesAndOffset(strides, offset);
806 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
807}
808
809//===----------------------------------------------------------------------===//
810// UnrankedMemRefType
811//===----------------------------------------------------------------------===//
812
813unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
814 return detail::getMemorySpaceAsInt(getMemorySpace());
815}
816
817LogicalResult
818UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
819 Type elementType, Attribute memorySpace) {
820 if (!BaseMemRefType::isValidElementType(elementType))
821 return emitError() << "invalid memref element type";
822
823 if (!isSupportedMemorySpace(memorySpace))
824 return emitError() << "unsupported memory space Attribute";
825
826 return success();
827}
828
829//===----------------------------------------------------------------------===//
830/// TupleType
831//===----------------------------------------------------------------------===//
832
833/// Return the elements types for this tuple.
834ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
835
836/// Accumulate the types contained in this tuple and tuples nested within it.
837/// Note that this only flattens nested tuples, not any other container type,
838/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
839/// (i32, tensor<i32>, f32, i64)
840void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
841 for (Type type : getTypes()) {
842 if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
843 nestedTuple.getFlattenedTypes(types);
844 else
845 types.push_back(type);
846 }
847}
848
849/// Return the number of element types.
850size_t TupleType::size() const { return getImpl()->size(); }
851
852//===----------------------------------------------------------------------===//
853// Type Utilities
854//===----------------------------------------------------------------------===//
855
858 MLIRContext *context) {
859 // Size 0 corner case is useful for canonicalizations.
860 if (sizes.empty())
861 return getAffineConstantExpr(0, context);
862
863 assert(!exprs.empty() && "expected exprs");
864 auto maps = AffineMap::inferFromExprList(exprs, context);
865 assert(!maps.empty() && "Expected one non-empty map");
866 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
867
868 AffineExpr expr;
869 bool dynamicPoisonBit = false;
870 int64_t runningSize = 1;
871 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
872 int64_t size = std::get<1>(en);
873 AffineExpr dimExpr = std::get<0>(en);
874 AffineExpr stride = dynamicPoisonBit
875 ? getAffineSymbolExpr(nSymbols++, context)
876 : getAffineConstantExpr(runningSize, context);
877 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
878 if (size > 0) {
879 auto result = llvm::checkedMul(runningSize, size);
880 if (!result) {
881 // Overflow occurred, treat as dynamic
882 dynamicPoisonBit = true;
883 } else {
884 runningSize = *result;
885 }
886 } else {
887 dynamicPoisonBit = true;
888 }
889 }
890 return simplifyAffineExpr(expr, numDims, nSymbols);
891}
892
894 MLIRContext *context) {
896 exprs.reserve(sizes.size());
897 for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
898 exprs.push_back(getAffineDimExpr(dim, context));
899 return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
900}
return success()
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)
static Type getElementType(Type type)
Determine the element type of type.
b getContext())
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition Attributes.h:58
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
constexpr Type()=default
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
Clone this type with the given shape and element type.
bool hasRank() const
Returns true if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition Dialect.cpp:95
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if creating operations for unregistered dialects is allowed.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setMemorySpace(Attribute newMemorySpace)
Builder & setElementType(Type newElementType)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
TensorType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
Clone this type with the given shape and element type.
constexpr Type()=default
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns true if this type is ranked, i.e. it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
AttrTypeReplacer.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool isValidVectorTypeElementType(::mlir::Type type)
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)