MLIR 23.0.0git
BuiltinTypes.cpp
Go to the documentation of this file.
1//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "TypeDetail.h"
11#include "mlir/IR/AffineExpr.h"
12#include "mlir/IR/AffineMap.h"
16#include "mlir/IR/Diagnostics.h"
17#include "mlir/IR/Dialect.h"
20#include "llvm/ADT/APFloat.h"
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/Sequence.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/CheckedArithmetic.h"
25#include <cstring>
26
27using namespace mlir;
28using namespace mlir::detail;
29
30//===----------------------------------------------------------------------===//
31/// Tablegen Type Definitions
32//===----------------------------------------------------------------------===//
33
34#define GET_TYPEDEF_CLASSES
35#include "mlir/IR/BuiltinTypes.cpp.inc"
36
37namespace mlir {
38#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
39} // namespace mlir
40
41//===----------------------------------------------------------------------===//
42// BuiltinDialect
43//===----------------------------------------------------------------------===//
44
45void BuiltinDialect::registerTypes() {
46 addTypes<
47#define GET_TYPEDEF_LIST
48#include "mlir/IR/BuiltinTypes.cpp.inc"
49 >();
50}
51
52//===----------------------------------------------------------------------===//
53/// ComplexType
54//===----------------------------------------------------------------------===//
55
56/// Verify the construction of an integer type.
57LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
58 Type elementType) {
59 if (!elementType.isIntOrFloat())
60 return emitError() << "invalid element type for complex";
61 return success();
62}
63
64size_t ComplexType::getDenseElementBitSize() const {
65 auto elemTy = cast<DenseElementType>(getElementType());
66 return llvm::alignTo<8>(elemTy.getDenseElementBitSize()) * 2;
67}
68
69Attribute ComplexType::convertToAttribute(ArrayRef<char> rawData) const {
70 auto elemTy = cast<DenseElementType>(getElementType());
71 size_t singleElementBytes =
72 llvm::alignTo<8>(elemTy.getDenseElementBitSize()) / 8;
73 Attribute real =
74 elemTy.convertToAttribute(rawData.take_front(singleElementBytes));
75 Attribute imag =
76 elemTy.convertToAttribute(rawData.take_back(singleElementBytes));
77 return ArrayAttr::get(getContext(), {real, imag});
78}
79
80LogicalResult
81ComplexType::convertFromAttribute(Attribute attr,
83 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
84 if (!arrayAttr || arrayAttr.size() != 2)
85 return failure();
86 auto elemTy = cast<DenseElementType>(getElementType());
87 SmallVector<char> realData, imagData;
88 if (failed(elemTy.convertFromAttribute(arrayAttr[0], realData)))
89 return failure();
90 if (failed(elemTy.convertFromAttribute(arrayAttr[1], imagData)))
91 return failure();
92 result.append(realData);
93 result.append(imagData);
94 return success();
95}
96
97//===----------------------------------------------------------------------===//
98// Integer Type
99//===----------------------------------------------------------------------===//
100
101/// Verify the construction of an integer type.
102LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
103 unsigned width,
104 SignednessSemantics signedness) {
105 if (width > IntegerType::kMaxWidth) {
106 return emitError() << "integer bitwidth is limited to "
107 << IntegerType::kMaxWidth << " bits";
108 }
109 return success();
110}
111
112unsigned IntegerType::getWidth() const { return getImpl()->width; }
113
114IntegerType::SignednessSemantics IntegerType::getSignedness() const {
115 return getImpl()->signedness;
116}
117
118IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
119 if (!scale)
120 return IntegerType();
121 return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
122}
123
124size_t IntegerType::getDenseElementBitSize() const {
125 // Return the actual bit width. Storage alignment is handled separately.
126 return getWidth();
127}
128
129Attribute IntegerType::convertToAttribute(ArrayRef<char> rawData) const {
130 APInt value = detail::readBits(rawData.data(), /*bitPos=*/0, getWidth());
131 return IntegerAttr::get(*this, value);
132}
133
135 size_t byteSize = llvm::divideCeil(apInt.getBitWidth(), CHAR_BIT);
136 size_t bitPos = result.size() * CHAR_BIT;
137 result.resize(result.size() + byteSize);
138 detail::writeBits(result.data(), bitPos, apInt);
139}
140
141LogicalResult
142IntegerType::convertFromAttribute(Attribute attr,
144 auto intAttr = dyn_cast<IntegerAttr>(attr);
145 if (!intAttr || intAttr.getType() != *this)
146 return failure();
147 writeAPIntToVector(intAttr.getValue(), result);
148 return success();
149}
150
151//===----------------------------------------------------------------------===//
152// Index Type
153//===----------------------------------------------------------------------===//
154
155size_t IndexType::getDenseElementBitSize() const {
156 return kInternalStorageBitWidth;
157}
158
159Attribute IndexType::convertToAttribute(ArrayRef<char> rawData) const {
160 APInt value =
161 detail::readBits(rawData.data(), /*bitPos=*/0, kInternalStorageBitWidth);
162 return IntegerAttr::get(*this, value);
163}
164
165LogicalResult
166IndexType::convertFromAttribute(Attribute attr,
168 auto intAttr = dyn_cast<IntegerAttr>(attr);
169 if (!intAttr || intAttr.getType() != *this)
170 return failure();
171 writeAPIntToVector(intAttr.getValue(), result);
172 return success();
173}
174
175//===----------------------------------------------------------------------===//
176// Float Types
177//===----------------------------------------------------------------------===//
178
179// Mapping from MLIR FloatType to APFloat semantics.
180#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
181 const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
182 return APFloat::SEM(); \
183 }
184FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
185FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
186FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
187FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
188FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
189FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
190FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
191FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
192FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
193FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
194FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
195FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
196FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
197FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
198FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
199FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
200FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
201FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
202#undef FLOAT_TYPE_SEMANTICS
203
204FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
205 if (scale == 2)
206 return Float32Type::get(getContext());
207 if (scale == 4)
208 return Float64Type::get(getContext());
209 return FloatType();
210}
211
212FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
213 if (scale == 2)
214 return Float32Type::get(getContext());
215 if (scale == 4)
216 return Float64Type::get(getContext());
217 return FloatType();
218}
219
220FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
221 if (scale == 2)
222 return Float64Type::get(getContext());
223 return FloatType();
224}
225
226//===----------------------------------------------------------------------===//
227// FunctionType
228//===----------------------------------------------------------------------===//
229
230unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
231
232ArrayRef<Type> FunctionType::getInputs() const {
233 return getImpl()->getInputs();
234}
235
236unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
237
238ArrayRef<Type> FunctionType::getResults() const {
239 return getImpl()->getResults();
240}
241
242FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
243 return get(getContext(), inputs, results);
244}
245
246/// Returns a new function type with the specified arguments and results
247/// inserted.
248FunctionType FunctionType::getWithArgsAndResults(
249 ArrayRef<unsigned> argIndices, TypeRange argTypes,
250 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
251 SmallVector<Type> argStorage, resultStorage;
252 TypeRange newArgTypes =
253 insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
254 TypeRange newResultTypes =
255 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
256 return clone(newArgTypes, newResultTypes);
257}
258
259/// Returns a new function type without the specified arguments and results.
260FunctionType
261FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
262 const BitVector &resultIndices) {
263 SmallVector<Type> argStorage, resultStorage;
264 TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
265 TypeRange newResultTypes =
266 filterTypesOut(getResults(), resultIndices, resultStorage);
267 return clone(newArgTypes, newResultTypes);
268}
269
270//===----------------------------------------------------------------------===//
271// GraphType
272//===----------------------------------------------------------------------===//
273
274unsigned GraphType::getNumInputs() const { return getImpl()->numInputs; }
275
276ArrayRef<Type> GraphType::getInputs() const { return getImpl()->getInputs(); }
277
278unsigned GraphType::getNumResults() const { return getImpl()->numResults; }
279
280ArrayRef<Type> GraphType::getResults() const { return getImpl()->getResults(); }
281
282GraphType GraphType::clone(TypeRange inputs, TypeRange results) const {
283 return get(getContext(), inputs, results);
284}
285
286/// Returns a new function type with the specified arguments and results
287/// inserted.
288GraphType GraphType::getWithArgsAndResults(ArrayRef<unsigned> argIndices,
289 TypeRange argTypes,
290 ArrayRef<unsigned> resultIndices,
291 TypeRange resultTypes) {
292 SmallVector<Type> argStorage, resultStorage;
293 TypeRange newArgTypes =
294 insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
295 TypeRange newResultTypes =
296 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
297 return clone(newArgTypes, newResultTypes);
298}
299
300/// Returns a new function type without the specified arguments and results.
301GraphType GraphType::getWithoutArgsAndResults(const BitVector &argIndices,
302 const BitVector &resultIndices) {
303 SmallVector<Type> argStorage, resultStorage;
304 TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
305 TypeRange newResultTypes =
306 filterTypesOut(getResults(), resultIndices, resultStorage);
307 return clone(newArgTypes, newResultTypes);
308}
309//===----------------------------------------------------------------------===//
310// OpaqueType
311//===----------------------------------------------------------------------===//
312
313/// Verify the construction of an opaque type.
314LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
315 StringAttr dialect, StringRef typeData) {
316 if (!Dialect::isValidNamespace(dialect.strref()))
317 return emitError() << "invalid dialect namespace '" << dialect << "'";
318
319 // Check that the dialect is actually registered.
320 MLIRContext *context = dialect.getContext();
321 if (!context->allowsUnregisteredDialects() &&
322 !context->getLoadedDialect(dialect.strref())) {
323 return emitError()
324 << "`!" << dialect << "<\"" << typeData << "\">"
325 << "` type created with unregistered dialect. If this is "
326 "intended, please call allowUnregisteredDialects() on the "
327 "MLIRContext, or use -allow-unregistered-dialect with "
328 "the MLIR opt tool used";
329 }
330
331 return success();
332}
333
334//===----------------------------------------------------------------------===//
335// VectorType
336//===----------------------------------------------------------------------===//
337
338bool VectorType::isValidElementType(Type t) {
339 return isValidVectorTypeElementType(t);
340}
341
342LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
343 ArrayRef<int64_t> shape, Type elementType,
344 ArrayRef<bool> scalableDims) {
345 if (!isValidElementType(elementType))
346 return emitError()
347 << "vector elements must be int/index/float type but got "
348 << elementType;
349
350 if (any_of(shape, [](int64_t i) { return i <= 0; }))
351 return emitError()
352 << "vector types must have positive constant sizes but got "
353 << shape;
354
355 if (scalableDims.size() != shape.size())
356 return emitError() << "number of dims must match, got "
357 << scalableDims.size() << " and " << shape.size();
358
359 return success();
360}
361
362VectorType VectorType::scaleElementBitwidth(unsigned scale) {
363 if (!scale)
364 return VectorType();
365 if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
366 if (auto scaledEt = et.scaleElementBitwidth(scale))
367 return VectorType::get(getShape(), scaledEt, getScalableDims());
368 if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
369 if (auto scaledEt = et.scaleElementBitwidth(scale))
370 return VectorType::get(getShape(), scaledEt, getScalableDims());
371 return VectorType();
372}
373
374VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
375 Type elementType) const {
376 return VectorType::get(shape.value_or(getShape()), elementType,
377 getScalableDims());
378}
379
380//===----------------------------------------------------------------------===//
381// TensorType
382//===----------------------------------------------------------------------===//
383
386 .Case<RankedTensorType, UnrankedTensorType>(
387 [](auto type) { return type.getElementType(); });
388}
389
391 return !llvm::isa<UnrankedTensorType>(*this);
392}
393
395 return llvm::cast<RankedTensorType>(*this).getShape();
396}
397
399 Type elementType) const {
400 if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
401 if (shape)
402 return RankedTensorType::get(*shape, elementType);
403 return UnrankedTensorType::get(elementType);
404 }
405
406 auto rankedTy = llvm::cast<RankedTensorType>(*this);
407 if (!shape)
408 return RankedTensorType::get(rankedTy.getShape(), elementType,
409 rankedTy.getEncoding());
410 return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
411 rankedTy.getEncoding());
412}
413
415 Type elementType) const {
416 return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
417}
418
419RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
420 return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
421}
422
423// Check if "elementType" can be an element type of a tensor.
424static LogicalResult
426 Type elementType) {
427 if (!TensorType::isValidElementType(elementType))
428 return emitError() << "invalid tensor element type: " << elementType;
429 return success();
430}
431
432/// Return true if the specified element type is ok in a tensor.
434 // Note: Non standard/builtin types are allowed to exist within tensor
435 // types. Dialects are expected to verify that tensor types have a valid
436 // element type within that dialect.
437 return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
438 IndexType>(type) ||
439 !llvm::isa<BuiltinDialect>(type.getDialect());
440}
441
442//===----------------------------------------------------------------------===//
443// RankedTensorType
444//===----------------------------------------------------------------------===//
445
446LogicalResult
447RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
448 ArrayRef<int64_t> shape, Type elementType,
449 Attribute encoding) {
450 for (int64_t s : shape)
451 if (s < 0 && ShapedType::isStatic(s))
452 return emitError() << "invalid tensor dimension size";
453 if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
454 if (failed(v.verifyEncoding(shape, elementType, emitError)))
455 return failure();
456 return checkTensorElementType(emitError, elementType);
457}
458
459//===----------------------------------------------------------------------===//
460// UnrankedTensorType
461//===----------------------------------------------------------------------===//
462
463LogicalResult
464UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
465 Type elementType) {
466 return checkTensorElementType(emitError, elementType);
467}
468
469//===----------------------------------------------------------------------===//
470// BaseMemRefType
471//===----------------------------------------------------------------------===//
472
475 .Case<MemRefType, UnrankedMemRefType>(
476 [](auto type) { return type.getElementType(); });
477}
478
480 return !llvm::isa<UnrankedMemRefType>(*this);
481}
482
484 return llvm::cast<MemRefType>(*this).getShape();
485}
486
488 Type elementType) const {
489 if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
490 if (!shape)
491 return UnrankedMemRefType::get(elementType, getMemorySpace());
492 MemRefType::Builder builder(*shape, elementType);
494 return builder;
495 }
496
497 MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
498 if (shape)
499 builder.setShape(*shape);
500 builder.setElementType(elementType);
501 return builder;
502}
503
504FailureOr<PtrLikeTypeInterface>
506 std::optional<Type> elementType) const {
507 Type eTy = elementType ? *elementType : getElementType();
508 if (llvm::dyn_cast<UnrankedMemRefType>(*this))
509 return cast<PtrLikeTypeInterface>(
510 UnrankedMemRefType::get(eTy, memorySpace));
511
512 MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
513 builder.setElementType(eTy);
514 builder.setMemorySpace(memorySpace);
515 return cast<PtrLikeTypeInterface>(static_cast<MemRefType>(builder));
516}
517
519 Type elementType) const {
520 return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
521}
522
524 return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
525}
526
528 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
529 return rankedMemRefTy.getMemorySpace();
530 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
531}
532
534 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
535 return rankedMemRefTy.getMemorySpaceAsInt();
536 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
537}
538
539//===----------------------------------------------------------------------===//
540// MemRefType
541//===----------------------------------------------------------------------===//
542
543std::optional<llvm::SmallDenseSet<unsigned>>
545 ArrayRef<int64_t> reducedShape,
546 bool matchDynamic) {
547 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
548 llvm::SmallDenseSet<unsigned> unusedDims;
549 unsigned reducedIdx = 0;
550 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
551 // Greedily insert `originalIdx` if match.
552 int64_t origSize = originalShape[originalIdx];
553 // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
554 if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
555 (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
556 ShapedType::isDynamic(origSize))) {
557 reducedIdx++;
558 continue;
559 }
560 if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
561 reducedIdx++;
562 continue;
563 }
564
565 unusedDims.insert(originalIdx);
566 // If no match on `originalIdx`, the `originalShape` at this dimension
567 // must be 1, otherwise we bail.
568 if (origSize != 1)
569 return std::nullopt;
570 }
571 // The whole reducedShape must be scanned, otherwise we bail.
572 if (reducedIdx != reducedRank)
573 return std::nullopt;
574 return unusedDims;
575}
576
578mlir::isRankReducedType(ShapedType originalType,
579 ShapedType candidateReducedType) {
580 if (originalType == candidateReducedType)
582
583 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
584 ShapedType candidateReducedShapedType =
585 llvm::cast<ShapedType>(candidateReducedType);
586
587 // Rank and size logic is valid for all ShapedTypes.
588 ArrayRef<int64_t> originalShape = originalShapedType.getShape();
589 ArrayRef<int64_t> candidateReducedShape =
590 candidateReducedShapedType.getShape();
591 unsigned originalRank = originalShape.size(),
592 candidateReducedRank = candidateReducedShape.size();
593 if (candidateReducedRank > originalRank)
595
596 auto optionalUnusedDimsMask =
597 computeRankReductionMask(originalShape, candidateReducedShape);
598
599 // Sizes cannot be matched in case empty vector is returned.
600 if (!optionalUnusedDimsMask)
602
603 if (originalShapedType.getElementType() !=
604 candidateReducedShapedType.getElementType())
606
608}
609
611 // Empty attribute is allowed as default memory space.
612 if (!memorySpace)
613 return true;
614
615 // Supported built-in attributes.
616 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
617 return true;
618
619 // Allow custom dialect attributes.
620 if (!isa<BuiltinDialect>(memorySpace.getDialect()))
621 return true;
622
623 return false;
624}
625
627 MLIRContext *ctx) {
628 if (memorySpace == 0)
629 return nullptr;
630
631 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
632}
633
635 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
636 if (intMemorySpace && intMemorySpace.getValue() == 0)
637 return nullptr;
638
639 return memorySpace;
640}
641
643 if (!memorySpace)
644 return 0;
645
646 assert(llvm::isa<IntegerAttr>(memorySpace) &&
647 "Using `getMemorySpaceInteger` with non-Integer attribute");
648
649 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
650}
651
652unsigned MemRefType::getMemorySpaceAsInt() const {
653 return detail::getMemorySpaceAsInt(getMemorySpace());
654}
655
656MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
657 MemRefLayoutAttrInterface layout,
658 Attribute memorySpace) {
659 // Use default layout for empty attribute.
660 if (!layout)
661 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
662 shape.size(), elementType.getContext()));
663
664 // Drop default memory space value and replace it with empty attribute.
665 memorySpace = skipDefaultMemorySpace(memorySpace);
666
667 return Base::get(elementType.getContext(), shape, elementType, layout,
668 memorySpace);
669}
670
671MemRefType MemRefType::getChecked(
673 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
674
675 // Use default layout for empty attribute.
676 if (!layout)
677 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
678 shape.size(), elementType.getContext()));
679
680 // Drop default memory space value and replace it with empty attribute.
681 memorySpace = skipDefaultMemorySpace(memorySpace);
682
683 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
684 elementType, layout, memorySpace);
685}
686
687MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
688 AffineMap map, Attribute memorySpace) {
689
690 // Use default layout for empty map.
691 if (!map)
693 elementType.getContext());
694
695 // Wrap AffineMap into Attribute.
696 auto layout = AffineMapAttr::get(map);
697
698 // Drop default memory space value and replace it with empty attribute.
699 memorySpace = skipDefaultMemorySpace(memorySpace);
700
701 return Base::get(elementType.getContext(), shape, elementType, layout,
702 memorySpace);
703}
704
705MemRefType
706MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
707 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
708 Attribute memorySpace) {
709
710 // Use default layout for empty map.
711 if (!map)
713 elementType.getContext());
714
715 // Wrap AffineMap into Attribute.
716 auto layout = AffineMapAttr::get(map);
717
718 // Drop default memory space value and replace it with empty attribute.
719 memorySpace = skipDefaultMemorySpace(memorySpace);
720
721 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
722 elementType, layout, memorySpace);
723}
724
725MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
726 AffineMap map, unsigned memorySpaceInd) {
727
728 // Use default layout for empty map.
729 if (!map)
731 elementType.getContext());
732
733 // Wrap AffineMap into Attribute.
734 auto layout = AffineMapAttr::get(map);
735
736 // Convert deprecated integer-like memory space to Attribute.
737 Attribute memorySpace =
738 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
739
740 return Base::get(elementType.getContext(), shape, elementType, layout,
741 memorySpace);
742}
743
744MemRefType
745MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
746 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
747 unsigned memorySpaceInd) {
748
749 // Use default layout for empty map.
750 if (!map)
752 elementType.getContext());
753
754 // Wrap AffineMap into Attribute.
755 auto layout = AffineMapAttr::get(map);
756
757 // Convert deprecated integer-like memory space to Attribute.
758 Attribute memorySpace =
759 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
760
761 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
762 elementType, layout, memorySpace);
763}
764
765LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
766 ArrayRef<int64_t> shape, Type elementType,
767 MemRefLayoutAttrInterface layout,
768 Attribute memorySpace) {
769 if (!BaseMemRefType::isValidElementType(elementType))
770 return emitError() << "invalid memref element type";
771
772 // Negative sizes are not allowed except for `kDynamic`.
773 for (int64_t s : shape)
774 if (s < 0 && ShapedType::isStatic(s))
775 return emitError() << "invalid memref size";
776
777 assert(layout && "missing layout specification");
778 if (failed(layout.verifyLayout(shape, emitError)))
779 return failure();
780
781 if (!isSupportedMemorySpace(memorySpace))
782 return emitError() << "unsupported memory space Attribute";
783
784 return success();
785}
786
787bool MemRefType::areTrailingDimsContiguous(int64_t n) {
788 assert(n <= getRank() &&
789 "number of dimensions to check must not exceed rank");
790 return n <= getNumContiguousTrailingDims();
791}
792
793int64_t MemRefType::getNumContiguousTrailingDims() {
794 const int64_t n = getRank();
795
796 // memrefs with identity layout are entirely contiguous.
797 if (getLayout().isIdentity())
798 return n;
799
800 // Get the strides (if any). Failing to do that, conservatively assume a
801 // non-contiguous layout.
802 int64_t offset;
803 SmallVector<int64_t> strides;
804 if (!succeeded(getStridesAndOffset(strides, offset)))
805 return 0;
806
808
809 // A memref with dimensions `d0, d1, ..., dn-1` and strides
810 // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
811 // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
812 // for `i` in `[k, n-1]`.
813 // Ignore stride elements if the corresponding dimension is 1, as they are
814 // of no consequence.
815 int64_t dimProduct = 1;
816 for (int64_t i = n - 1; i >= 0; --i) {
817 if (shape[i] == 1)
818 continue;
819 if (strides[i] != dimProduct)
820 return n - i - 1;
821 if (shape[i] == ShapedType::kDynamic)
822 return n - i;
823 dimProduct *= shape[i];
824 }
825
826 return n;
827}
828
829MemRefType MemRefType::canonicalizeStridedLayout() {
830 AffineMap m = getLayout().getAffineMap();
831
832 // Already in canonical form.
833 if (m.isIdentity())
834 return *this;
835
836 // Can't reduce to canonical identity form, return in canonical form.
837 if (m.getNumResults() > 1)
838 return *this;
839
840 // Corner-case for 0-D affine maps.
841 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
842 if (auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0)))
843 if (cst.getValue() == 0)
844 return MemRefType::Builder(*this).setLayout({});
845 return *this;
846 }
847
848 // 0-D corner case for empty shape that still have an affine map. Example:
849 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
850 // offset needs to remain, just return t.
851 if (getShape().empty())
852 return *this;
853
854 // If the canonical strided layout for the sizes of `t` is equal to the
855 // simplified layout of `t` we can just return an empty layout. Otherwise,
856 // just simplify the existing layout.
858 auto simplifiedLayoutExpr =
860 if (expr != simplifiedLayoutExpr)
861 return MemRefType::Builder(*this).setLayout(
862 AffineMapAttr::get(AffineMap::get(m.getNumDims(), m.getNumSymbols(),
863 simplifiedLayoutExpr)));
864 return MemRefType::Builder(*this).setLayout({});
865}
866
867LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
868 int64_t &offset) const {
869 return getLayout().getStridesAndOffset(getShape(), strides, offset);
870}
871
872std::pair<SmallVector<int64_t>, int64_t>
873MemRefType::getStridesAndOffset() const {
874 SmallVector<int64_t> strides;
875 int64_t offset;
876 LogicalResult status = getStridesAndOffset(strides, offset);
877 (void)status;
878 assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
879 return {strides, offset};
880}
881
882bool MemRefType::isStrided() {
883 int64_t offset;
885 auto res = getStridesAndOffset(strides, offset);
886 return succeeded(res);
887}
888
889bool MemRefType::isLastDimUnitStride() {
890 int64_t offset;
891 SmallVector<int64_t> strides;
892 auto successStrides = getStridesAndOffset(strides, offset);
893 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
894}
895
896//===----------------------------------------------------------------------===//
897// UnrankedMemRefType
898//===----------------------------------------------------------------------===//
899
900unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
901 return detail::getMemorySpaceAsInt(getMemorySpace());
902}
903
904LogicalResult
905UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
906 Type elementType, Attribute memorySpace) {
907 if (!BaseMemRefType::isValidElementType(elementType))
908 return emitError() << "invalid memref element type";
909
910 if (!isSupportedMemorySpace(memorySpace))
911 return emitError() << "unsupported memory space Attribute";
912
913 return success();
914}
915
916//===----------------------------------------------------------------------===//
917/// TupleType
918//===----------------------------------------------------------------------===//
919
920/// Return the elements types for this tuple.
921ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
922
923/// Accumulate the types contained in this tuple and tuples nested within it.
924/// Note that this only flattens nested tuples, not any other container type,
925/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
926/// (i32, tensor<i32>, f32, i64)
927void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
928 for (Type type : getTypes()) {
929 if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
930 nestedTuple.getFlattenedTypes(types);
931 else
932 types.push_back(type);
933 }
934}
935
936/// Return the number of element types.
937size_t TupleType::size() const { return getImpl()->size(); }
938
939//===----------------------------------------------------------------------===//
940// Type Utilities
941//===----------------------------------------------------------------------===//
942
945 MLIRContext *context) {
946 // Size 0 corner case is useful for canonicalizations.
947 if (sizes.empty())
948 return getAffineConstantExpr(0, context);
949
950 assert(!exprs.empty() && "expected exprs");
951 auto maps = AffineMap::inferFromExprList(exprs, context);
952 assert(!maps.empty() && "Expected one non-empty map");
953 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
954
955 AffineExpr expr;
956 bool dynamicPoisonBit = false;
957 int64_t runningSize = 1;
958 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
959 int64_t size = std::get<1>(en);
960 AffineExpr dimExpr = std::get<0>(en);
961 AffineExpr stride = dynamicPoisonBit
962 ? getAffineSymbolExpr(nSymbols++, context)
963 : getAffineConstantExpr(runningSize, context);
964 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
965 if (size > 0) {
966 auto result = llvm::checkedMul(runningSize, size);
967 if (!result) {
968 // Overflow occurred, treat as dynamic
969 dynamicPoisonBit = true;
970 } else {
971 runningSize = *result;
972 }
973 } else {
974 dynamicPoisonBit = true;
975 }
976 }
977 return simplifyAffineExpr(expr, numDims, nSymbols);
978}
979
981 MLIRContext *context) {
983 exprs.reserve(sizes.size());
984 for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
985 exprs.push_back(getAffineDimExpr(dim, context));
986 return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
987}
return success()
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with ShapedType::kDynamic).
static void writeAPIntToVector(APInt apInt, SmallVectorImpl< char > &result)
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)
static Type getElementType(Type type)
Determine the element type of type.
b getContext())
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the largest dim in exprs and as many symbols as the largest symbol in exprs.
AffineExpr getResult(unsigned idx) const
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition Attributes.h:58
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
constexpr Type()=default
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
Clone this type with the given shape and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition Dialect.cpp:95
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if creating operations for unregistered dialects is allowed.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setMemorySpace(Attribute newMemorySpace)
Builder & setElementType(Type newElementType)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
TensorType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
Clone this type with the given shape and element type.
constexpr Type()=default
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
AttrTypeReplacer.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
void writeBits(char *rawData, size_t bitPos, llvm::APInt value)
Write value to byte-aligned position bitPos in rawData.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction code.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides AffineExpr.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
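Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries erased, return the set of indices that specifies which of the entries of originalShape are dropped to obtain reducedShape.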
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions with static size 1.
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)