10 #include "TypeDetail.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/TypeSwitch.h"
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
36 #include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
43 void BuiltinDialect::registerTypes() {
45 #define GET_TYPEDEF_LIST
46 #include "mlir/IR/BuiltinTypes.cpp.inc"
58 return emitError() <<
"invalid element type for complex";
69 SignednessSemantics signedness) {
70 if (width > IntegerType::kMaxWidth) {
71 return emitError() <<
"integer bitwidth is limited to "
72 << IntegerType::kMaxWidth <<
" bits";
77 unsigned IntegerType::getWidth()
const {
return getImpl()->width; }
79 IntegerType::SignednessSemantics IntegerType::getSignedness()
const {
80 return getImpl()->signedness;
83 IntegerType IntegerType::scaleElementBitwidth(
unsigned scale) {
97 if (llvm::isa<FloatTF32Type>(*
this))
99 return APFloat::semanticsSizeInBits(getFloatSemantics());
104 if (llvm::isa<Float4E2M1FNType>(*
this))
105 return APFloat::Float4E2M1FN();
106 if (llvm::isa<Float6E2M3FNType>(*
this))
107 return APFloat::Float6E2M3FN();
108 if (llvm::isa<Float6E3M2FNType>(*
this))
109 return APFloat::Float6E3M2FN();
110 if (llvm::isa<Float8E5M2Type>(*
this))
111 return APFloat::Float8E5M2();
112 if (llvm::isa<Float8E4M3Type>(*
this))
113 return APFloat::Float8E4M3();
114 if (llvm::isa<Float8E4M3FNType>(*
this))
115 return APFloat::Float8E4M3FN();
116 if (llvm::isa<Float8E5M2FNUZType>(*
this))
117 return APFloat::Float8E5M2FNUZ();
118 if (llvm::isa<Float8E4M3FNUZType>(*
this))
119 return APFloat::Float8E4M3FNUZ();
120 if (llvm::isa<Float8E4M3B11FNUZType>(*
this))
121 return APFloat::Float8E4M3B11FNUZ();
122 if (llvm::isa<Float8E3M4Type>(*
this))
123 return APFloat::Float8E3M4();
124 if (llvm::isa<Float8E8M0FNUType>(*
this))
125 return APFloat::Float8E8M0FNU();
126 if (llvm::isa<BFloat16Type>(*
this))
127 return APFloat::BFloat();
128 if (llvm::isa<Float16Type>(*
this))
129 return APFloat::IEEEhalf();
130 if (llvm::isa<FloatTF32Type>(*
this))
131 return APFloat::FloatTF32();
132 if (llvm::isa<Float32Type>(*
this))
133 return APFloat::IEEEsingle();
134 if (llvm::isa<Float64Type>(*
this))
135 return APFloat::IEEEdouble();
136 if (llvm::isa<Float80Type>(*
this))
137 return APFloat::x87DoubleExtended();
138 if (llvm::isa<Float128Type>(*
this))
139 return APFloat::IEEEquad();
140 llvm_unreachable(
"non-floating point type used");
147 if (isF16() || isBF16()) {
160 return APFloat::semanticsPrecision(getFloatSemantics());
167 unsigned FunctionType::getNumInputs()
const {
return getImpl()->numInputs; }
170 return getImpl()->getInputs();
173 unsigned FunctionType::getNumResults()
const {
return getImpl()->numResults; }
176 return getImpl()->getResults();
185 FunctionType FunctionType::getWithArgsAndResults(
192 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
193 return clone(newArgTypes, newResultTypes);
198 FunctionType::getWithoutArgsAndResults(
const BitVector &argIndices,
199 const BitVector &resultIndices) {
204 return clone(newArgTypes, newResultTypes);
213 StringAttr dialect, StringRef typeData) {
215 return emitError() <<
"invalid dialect namespace '" << dialect <<
"'";
222 <<
"`!" << dialect <<
"<\"" << typeData <<
"\">"
223 <<
"` type created with unregistered dialect. If this is "
224 "intended, please call allowUnregisteredDialects() on the "
225 "MLIRContext, or use -allow-unregistered-dialect with "
226 "the MLIR opt tool used";
236 bool VectorType::isValidElementType(
Type t) {
237 return isValidVectorTypeElementType(t);
243 if (!isValidElementType(elementType))
245 <<
"vector elements must be int/index/float type but got "
248 if (any_of(shape, [](int64_t i) {
return i <= 0; }))
250 <<
"vector types must have positive constant sizes but got "
253 if (scalableDims.size() != shape.size())
254 return emitError() <<
"number of dims must match, got "
255 << scalableDims.size() <<
" and " << shape.size();
260 VectorType VectorType::scaleElementBitwidth(
unsigned scale) {
264 if (
auto scaledEt = et.scaleElementBitwidth(scale))
267 if (
auto scaledEt = et.scaleElementBitwidth(scale))
273 Type elementType)
const {
284 .Case<RankedTensorType, UnrankedTensorType>(
285 [](
auto type) {
return type.getElementType(); });
289 return !llvm::isa<UnrankedTensorType>(*
this);
293 return llvm::cast<RankedTensorType>(*this).getShape();
297 Type elementType)
const {
298 if (llvm::dyn_cast<UnrankedTensorType>(*
this)) {
304 auto rankedTy = llvm::cast<RankedTensorType>(*
this);
307 rankedTy.getEncoding());
309 rankedTy.getEncoding());
313 Type elementType)
const {
314 return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
318 return ::llvm::cast<RankedTensorType>(cloneWith(shape,
getElementType()));
326 return emitError() <<
"invalid tensor element type: " << elementType;
335 return llvm::isa<ComplexType,
FloatType, IntegerType, OpaqueType, VectorType,
337 !llvm::isa<BuiltinDialect>(type.
getDialect());
348 for (int64_t s : shape)
349 if (s < 0 && !ShapedType::isDynamic(s))
350 return emitError() <<
"invalid tensor dimension size";
351 if (
auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
352 if (failed(v.verifyEncoding(shape, elementType,
emitError)))
374 [](
auto type) {
return type.getElementType(); });
378 return !llvm::isa<UnrankedMemRefType>(*
this);
382 return llvm::cast<MemRefType>(*this).getShape();
386 Type elementType)
const {
387 if (llvm::dyn_cast<UnrankedMemRefType>(*
this)) {
403 Type elementType)
const {
404 return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
408 return ::llvm::cast<MemRefType>(cloneWith(shape,
getElementType()));
412 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
413 return rankedMemRefTy.getMemorySpace();
414 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
418 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
419 return rankedMemRefTy.getMemorySpaceAsInt();
420 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
427 std::optional<llvm::SmallDenseSet<unsigned>>
431 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
432 llvm::SmallDenseSet<unsigned> unusedDims;
433 unsigned reducedIdx = 0;
434 for (
unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
436 int64_t origSize = originalShape[originalIdx];
438 if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
439 (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
440 ShapedType::isDynamic(origSize))) {
444 if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
449 unusedDims.insert(originalIdx);
456 if (reducedIdx != reducedRank)
463 ShapedType candidateReducedType) {
464 if (originalType == candidateReducedType)
467 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
468 ShapedType candidateReducedShapedType =
469 llvm::cast<ShapedType>(candidateReducedType);
474 candidateReducedShapedType.getShape();
475 unsigned originalRank = originalShape.size(),
476 candidateReducedRank = candidateReducedShape.size();
477 if (candidateReducedRank > originalRank)
480 auto optionalUnusedDimsMask =
484 if (!optionalUnusedDimsMask)
487 if (originalShapedType.getElementType() !=
488 candidateReducedShapedType.getElementType())
500 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
504 if (!isa<BuiltinDialect>(memorySpace.
getDialect()))
512 if (memorySpace == 0)
519 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
520 if (intMemorySpace && intMemorySpace.getValue() == 0)
530 assert(llvm::isa<IntegerAttr>(memorySpace) &&
531 "Using `getMemorySpaceInteger` with non-Integer attribute");
533 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
541 MemRefLayoutAttrInterface layout,
555 MemRefType MemRefType::getChecked(
557 Type elementType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
567 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
568 elementType, layout, memorySpace);
605 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
606 elementType, layout, memorySpace);
610 AffineMap map,
unsigned memorySpaceInd) {
631 unsigned memorySpaceInd) {
645 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
646 elementType, layout, memorySpace);
651 MemRefLayoutAttrInterface layout,
654 return emitError() <<
"invalid memref element type";
657 for (int64_t s : shape)
658 if (s < 0 && !ShapedType::isDynamic(s))
659 return emitError() <<
"invalid memref size";
661 assert(layout &&
"missing layout specification");
662 if (failed(layout.verifyLayout(shape,
emitError)))
666 return emitError() <<
"unsupported memory space Attribute";
683 return emitError() <<
"invalid memref element type";
686 return emitError() <<
"unsupported memory space Attribute";
697 if (
auto dim = dyn_cast<AffineDimExpr>(e))
698 strides[dim.getPosition()] =
699 strides[dim.getPosition()] + multiplicativeFactor;
701 offset = offset + e * multiplicativeFactor;
712 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
724 auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
726 strides[dim.getPosition()] =
727 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
734 if (bin.getLHS().isSymbolicOrConstant())
735 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
737 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
743 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
745 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
746 return success(succeeded(res1) && succeeded(res2));
749 llvm_unreachable(
"unexpected binary operation");
766 SmallVectorImpl<AffineExpr> &strides,
768 AffineMap m = t.getLayout().getAffineMap();
770 if (m.getNumResults() != 1 && !m.isIdentity())
776 strides.assign(t.getRank(), zero);
779 if (m.isIdentity()) {
781 if (t.getRank() == 0)
787 assert(
false &&
"unexpected failure: extract strides in canonical layout");
800 unsigned numDims = m.getNumDims();
801 unsigned numSymbols = m.getNumSymbols();
803 for (
auto &stride : strides)
827 if (
auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
828 llvm::append_range(strides, strided.getStrides());
829 offset = strided.getOffset();
839 if (
auto cst = dyn_cast<AffineConstantExpr>(offsetExpr))
840 offset = cst.getValue();
842 offset = ShapedType::kDynamic;
843 for (
auto e : strideExprs) {
844 if (
auto c = dyn_cast<AffineConstantExpr>(e))
845 strides.push_back(c.getValue());
847 strides.push_back(ShapedType::kDynamic);
852 std::pair<SmallVector<int64_t>, int64_t>
858 assert(succeeded(status) &&
"Invalid use of check-free getStridesAndOffset");
859 return {strides, offset};
867 ArrayRef<Type> TupleType::getTypes()
const {
return getImpl()->getTypes(); }
874 for (
Type type : getTypes()) {
875 if (
auto nestedTuple = llvm::dyn_cast<TupleType>(type))
876 nestedTuple.getFlattenedTypes(types);
878 types.push_back(type);
883 size_t TupleType::size()
const {
return getImpl()->size(); }
895 AffineMap m = t.getLayout().getAffineMap();
902 if (m.getNumResults() > 1)
906 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
907 if (
auto cst = dyn_cast<AffineConstantExpr>(m.getResult(0)))
908 if (cst.getValue() == 0)
916 if (t.getShape().empty())
924 auto simplifiedLayoutExpr =
926 if (expr != simplifiedLayoutExpr)
928 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
939 assert(!exprs.empty() &&
"expected exprs");
941 assert(!maps.empty() &&
"Expected one non-empty map");
942 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
945 bool dynamicPoisonBit =
false;
946 int64_t runningSize = 1;
947 for (
auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
948 int64_t size = std::get<1>(en);
953 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
956 assert(runningSize > 0 &&
"integer overflow in size computation");
958 dynamicPoisonBit =
true;
967 exprs.reserve(sizes.size());
968 for (
auto dim : llvm::seq<unsigned>(0, sizes.size()))
977 return succeeded(res);
984 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
991 auto memrefShape = type.getShape().take_back(n);
992 if (ShapedType::isDynamicShape(memrefShape))
995 if (type.getLayout().isIdentity())
1004 if (strides.empty())
1009 auto dimProduct = 1;
1010 for (
auto dim : llvm::reverse(memrefShape.drop_front(1))) {
1012 flattenedDims.push_back(dimProduct);
1015 strides = strides.drop_back(1);
1016 return llvm::equal(strides, llvm::reverse(flattenedDims));
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
static FloatType getF64(MLIRContext *ctx)
FloatType scaleElementBitwidth(unsigned scale)
Get or create a new FloatType with bitwidth scaled by scale.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
static FloatType getF32(MLIRContext *ctx)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if we allow creating operations for unregistered dialects.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setElementType(Type newElementType)
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setMemorySpace(Attribute newMemorySpace)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
bool trailingNDimsContiguous(MemRefType type, int64_t n)
Return "true" if the last N dimensions of the given type are contiguous.