10#include "TypeDetail.h"
20#include "llvm/ADT/APFloat.h"
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/Sequence.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/CheckedArithmetic.h"
34#define GET_TYPEDEF_CLASSES
35#include "mlir/IR/BuiltinTypes.cpp.inc"
38#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
45void BuiltinDialect::registerTypes() {
47#define GET_TYPEDEF_LIST
48#include "mlir/IR/BuiltinTypes.cpp.inc"
60 return emitError() <<
"invalid element type for complex";
64size_t ComplexType::getDenseElementBitSize()
const {
66 return llvm::alignTo<8>(elemTy.getDenseElementBitSize()) * 2;
71 size_t singleElementBytes =
72 llvm::alignTo<8>(elemTy.getDenseElementBitSize()) / 8;
74 elemTy.convertToAttribute(rawData.take_front(singleElementBytes));
76 elemTy.convertToAttribute(rawData.take_back(singleElementBytes));
77 return ArrayAttr::get(
getContext(), {real, imag});
81ComplexType::convertFromAttribute(
Attribute attr,
83 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
84 if (!arrayAttr || arrayAttr.size() != 2)
88 if (
failed(elemTy.convertFromAttribute(arrayAttr[0], realData)))
90 if (
failed(elemTy.convertFromAttribute(arrayAttr[1], imagData)))
104 SignednessSemantics signedness) {
105 if (width > IntegerType::kMaxWidth) {
106 return emitError() <<
"integer bitwidth is limited to "
107 << IntegerType::kMaxWidth <<
" bits";
112unsigned IntegerType::getWidth()
const {
return getImpl()->width; }
114IntegerType::SignednessSemantics IntegerType::getSignedness()
const {
115 return getImpl()->signedness;
118IntegerType IntegerType::scaleElementBitwidth(
unsigned scale) {
120 return IntegerType();
121 return IntegerType::get(
getContext(), scale * getWidth(), getSignedness());
124size_t IntegerType::getDenseElementBitSize()
const {
130 APInt value = detail::readBits(rawData.data(), 0, getWidth());
131 return IntegerAttr::get(*
this, value);
135 size_t byteSize = llvm::divideCeil(apInt.getBitWidth(), CHAR_BIT);
136 size_t bitPos =
result.size() * CHAR_BIT;
142IntegerType::convertFromAttribute(
Attribute attr,
144 auto intAttr = dyn_cast<IntegerAttr>(attr);
145 if (!intAttr || intAttr.getType() != *
this)
155size_t IndexType::getDenseElementBitSize()
const {
156 return kInternalStorageBitWidth;
161 detail::readBits(rawData.data(), 0, kInternalStorageBitWidth);
162 return IntegerAttr::get(*
this, value);
166IndexType::convertFromAttribute(
Attribute attr,
168 auto intAttr = dyn_cast<IntegerAttr>(attr);
169 if (!intAttr || intAttr.getType() != *
this)
180#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
181 const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
182 return APFloat::SEM(); \
202#undef FLOAT_TYPE_SEMANTICS
204FloatType Float16Type::scaleElementBitwidth(
unsigned scale)
const {
212FloatType BFloat16Type::scaleElementBitwidth(
unsigned scale)
const {
220FloatType Float32Type::scaleElementBitwidth(
unsigned scale)
const {
230unsigned FunctionType::getNumInputs()
const {
return getImpl()->numInputs; }
233 return getImpl()->getInputs();
236unsigned FunctionType::getNumResults()
const {
return getImpl()->numResults; }
239 return getImpl()->getResults();
248FunctionType FunctionType::getWithArgsAndResults(
255 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
256 return clone(newArgTypes, newResultTypes);
261FunctionType::getWithoutArgsAndResults(
const BitVector &argIndices,
262 const BitVector &resultIndices) {
267 return clone(newArgTypes, newResultTypes);
274unsigned GraphType::getNumInputs()
const {
return getImpl()->numInputs; }
276ArrayRef<Type> GraphType::getInputs()
const {
return getImpl()->getInputs(); }
278unsigned GraphType::getNumResults()
const {
return getImpl()->numResults; }
280ArrayRef<Type> GraphType::getResults()
const {
return getImpl()->getResults(); }
296 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
297 return clone(newArgTypes, newResultTypes);
301GraphType GraphType::getWithoutArgsAndResults(
const BitVector &argIndices,
302 const BitVector &resultIndices) {
307 return clone(newArgTypes, newResultTypes);
315 StringAttr dialect, StringRef typeData) {
317 return emitError() <<
"invalid dialect namespace '" << dialect <<
"'";
324 <<
"`!" << dialect <<
"<\"" << typeData <<
"\">"
325 <<
"` type created with unregistered dialect. If this is "
326 "intended, please call allowUnregisteredDialects() on the "
327 "MLIRContext, or use -allow-unregistered-dialect with "
328 "the MLIR opt tool used";
338bool VectorType::isValidElementType(
Type t) {
339 return isValidVectorTypeElementType(t);
345 if (!isValidElementType(elementType))
347 <<
"vector elements must be int/index/float type but got "
350 if (any_of(shape, [](int64_t i) {
return i <= 0; }))
352 <<
"vector types must have positive constant sizes but got "
355 if (scalableDims.size() != shape.size())
356 return emitError() <<
"number of dims must match, got "
357 << scalableDims.size() <<
" and " << shape.size();
362VectorType VectorType::scaleElementBitwidth(
unsigned scale) {
366 if (
auto scaledEt = et.scaleElementBitwidth(scale))
367 return VectorType::get(
getShape(), scaledEt, getScalableDims());
369 if (
auto scaledEt = et.scaleElementBitwidth(scale))
370 return VectorType::get(
getShape(), scaledEt, getScalableDims());
375 Type elementType)
const {
376 return VectorType::get(shape.value_or(
getShape()), elementType,
386 .Case<RankedTensorType, UnrankedTensorType>(
387 [](
auto type) {
return type.getElementType(); });
391 return !llvm::isa<UnrankedTensorType>(*
this);
395 return llvm::cast<RankedTensorType>(*this).getShape();
399 Type elementType)
const {
400 if (llvm::dyn_cast<UnrankedTensorType>(*
this)) {
402 return RankedTensorType::get(*
shape, elementType);
403 return UnrankedTensorType::get(elementType);
406 auto rankedTy = llvm::cast<RankedTensorType>(*
this);
408 return RankedTensorType::get(rankedTy.getShape(), elementType,
409 rankedTy.getEncoding());
410 return RankedTensorType::get(
shape.value_or(rankedTy.getShape()), elementType,
411 rankedTy.getEncoding());
415 Type elementType)
const {
416 return ::llvm::cast<RankedTensorType>(
cloneWith(
shape, elementType));
428 return emitError() <<
"invalid tensor element type: " << elementType;
437 return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
439 !llvm::isa<BuiltinDialect>(type.
getDialect());
451 if (s < 0 && ShapedType::isStatic(s))
452 return emitError() <<
"invalid tensor dimension size";
453 if (
auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
476 [](
auto type) {
return type.getElementType(); });
480 return !llvm::isa<UnrankedMemRefType>(*
this);
484 return llvm::cast<MemRefType>(*this).getShape();
488 Type elementType)
const {
489 if (llvm::dyn_cast<UnrankedMemRefType>(*
this)) {
504FailureOr<PtrLikeTypeInterface>
506 std::optional<Type> elementType)
const {
508 if (llvm::dyn_cast<UnrankedMemRefType>(*
this))
509 return cast<PtrLikeTypeInterface>(
510 UnrankedMemRefType::get(eTy, memorySpace));
515 return cast<PtrLikeTypeInterface>(
static_cast<MemRefType
>(builder));
519 Type elementType)
const {
528 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
529 return rankedMemRefTy.getMemorySpace();
530 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
534 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
535 return rankedMemRefTy.getMemorySpaceAsInt();
536 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
543std::optional<llvm::SmallDenseSet<unsigned>>
547 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
548 llvm::SmallDenseSet<unsigned> unusedDims;
549 unsigned reducedIdx = 0;
550 for (
unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
552 int64_t origSize = originalShape[originalIdx];
554 if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
555 (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
556 ShapedType::isDynamic(origSize))) {
560 if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
565 unusedDims.insert(originalIdx);
572 if (reducedIdx != reducedRank)
579 ShapedType candidateReducedType) {
580 if (originalType == candidateReducedType)
583 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
584 ShapedType candidateReducedShapedType =
585 llvm::cast<ShapedType>(candidateReducedType);
590 candidateReducedShapedType.getShape();
591 unsigned originalRank = originalShape.size(),
592 candidateReducedRank = candidateReducedShape.size();
593 if (candidateReducedRank > originalRank)
596 auto optionalUnusedDimsMask =
600 if (!optionalUnusedDimsMask)
603 if (originalShapedType.getElementType() !=
604 candidateReducedShapedType.getElementType())
616 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
620 if (!isa<BuiltinDialect>(memorySpace.
getDialect()))
628 if (memorySpace == 0)
631 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
635 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
636 if (intMemorySpace && intMemorySpace.getValue() == 0)
646 assert(llvm::isa<IntegerAttr>(memorySpace) &&
647 "Using `getMemorySpaceInteger` with non-Integer attribute");
649 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
652unsigned MemRefType::getMemorySpaceAsInt()
const {
657 MemRefLayoutAttrInterface layout,
671MemRefType MemRefType::getChecked(
673 Type elementType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
684 elementType, layout, memorySpace);
696 auto layout = AffineMapAttr::get(map);
716 auto layout = AffineMapAttr::get(map);
722 elementType, layout, memorySpace);
726 AffineMap map,
unsigned memorySpaceInd) {
734 auto layout = AffineMapAttr::get(map);
747 unsigned memorySpaceInd) {
755 auto layout = AffineMapAttr::get(map);
762 elementType, layout, memorySpace);
767 MemRefLayoutAttrInterface layout,
770 return emitError() <<
"invalid memref element type";
773 for (int64_t s :
shape)
774 if (s < 0 && ShapedType::isStatic(s))
775 return emitError() <<
"invalid memref size";
777 assert(layout &&
"missing layout specification");
782 return emitError() <<
"unsupported memory space Attribute";
787bool MemRefType::areTrailingDimsContiguous(int64_t n) {
788 assert(n <= getRank() &&
789 "number of dimensions to check must not exceed rank");
790 return n <= getNumContiguousTrailingDims();
793int64_t MemRefType::getNumContiguousTrailingDims() {
794 const int64_t n = getRank();
797 if (getLayout().isIdentity())
815 int64_t dimProduct = 1;
816 for (int64_t i = n - 1; i >= 0; --i) {
819 if (strides[i] != dimProduct)
821 if (
shape[i] == ShapedType::kDynamic)
823 dimProduct *=
shape[i];
829MemRefType MemRefType::canonicalizeStridedLayout() {
830 AffineMap m = getLayout().getAffineMap();
842 if (
auto cst = llvm::dyn_cast<AffineConstantExpr>(m.
getResult(0)))
843 if (cst.getValue() == 0)
858 auto simplifiedLayoutExpr =
860 if (expr != simplifiedLayoutExpr)
863 simplifiedLayoutExpr)));
868 int64_t &offset)
const {
869 return getLayout().getStridesAndOffset(
getShape(), strides, offset);
872std::pair<SmallVector<int64_t>, int64_t>
873MemRefType::getStridesAndOffset()
const {
878 assert(succeeded(status) &&
"Invalid use of check-free getStridesAndOffset");
879 return {strides, offset};
882bool MemRefType::isStrided() {
886 return succeeded(res);
889bool MemRefType::isLastDimUnitStride() {
893 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
900unsigned UnrankedMemRefType::getMemorySpaceAsInt()
const {
908 return emitError() <<
"invalid memref element type";
911 return emitError() <<
"unsupported memory space Attribute";
921ArrayRef<Type> TupleType::getTypes()
const {
return getImpl()->getTypes(); }
928 for (
Type type : getTypes()) {
929 if (
auto nestedTuple = llvm::dyn_cast<TupleType>(type))
930 nestedTuple.getFlattenedTypes(types);
932 types.push_back(type);
937size_t TupleType::size()
const {
return getImpl()->size(); }
950 assert(!exprs.empty() &&
"expected exprs");
952 assert(!maps.empty() &&
"Expected one non-empty map");
953 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
956 bool dynamicPoisonBit =
false;
958 for (
auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
959 int64_t size = std::get<1>(en);
964 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
966 auto result = llvm::checkedMul(runningSize, size);
969 dynamicPoisonBit =
true;
974 dynamicPoisonBit =
true;
983 exprs.reserve(sizes.size());
984 for (
auto dim : llvm::seq<unsigned>(0, sizes.size()))
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
static void writeAPIntToVector(APInt apInt, SmallVectorImpl< char > &result)
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)
static Type getElementType(Type type)
Determine the element type of type.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
Clone this type with the given shape and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if we allow to create operation for unregistered dialects.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setMemorySpace(Attribute newMemorySpace)
Builder & setElementType(Type newElementType)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
TensorType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
void writeBits(char *rawData, size_t bitPos, llvm::APInt value)
Write value to byte-aligned position bitPos in rawData.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
llvm::function_ref< Fn > function_ref
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)