10 #include "TypeDetail.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/TypeSwitch.h"
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
36 #include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
43 void BuiltinDialect::registerTypes() {
45 #define GET_TYPEDEF_LIST
46 #include "mlir/IR/BuiltinTypes.cpp.inc"
58 return emitError() <<
"invalid element type for complex";
69 SignednessSemantics signedness) {
70 if (width > IntegerType::kMaxWidth) {
71 return emitError() <<
"integer bitwidth is limited to "
72 << IntegerType::kMaxWidth <<
" bits";
77 unsigned IntegerType::getWidth()
const {
return getImpl()->width; }
79 IntegerType::SignednessSemantics IntegerType::getSignedness()
const {
80 return getImpl()->signedness;
83 IntegerType IntegerType::scaleElementBitwidth(
unsigned scale) {
94 return APFloat::semanticsSizeInBits(getFloatSemantics());
99 if (llvm::isa<Float4E2M1FNType>(*
this))
100 return APFloat::Float4E2M1FN();
101 if (llvm::isa<Float6E2M3FNType>(*
this))
102 return APFloat::Float6E2M3FN();
103 if (llvm::isa<Float6E3M2FNType>(*
this))
104 return APFloat::Float6E3M2FN();
105 if (llvm::isa<Float8E5M2Type>(*
this))
106 return APFloat::Float8E5M2();
107 if (llvm::isa<Float8E4M3Type>(*
this))
108 return APFloat::Float8E4M3();
109 if (llvm::isa<Float8E4M3FNType>(*
this))
110 return APFloat::Float8E4M3FN();
111 if (llvm::isa<Float8E5M2FNUZType>(*
this))
112 return APFloat::Float8E5M2FNUZ();
113 if (llvm::isa<Float8E4M3FNUZType>(*
this))
114 return APFloat::Float8E4M3FNUZ();
115 if (llvm::isa<Float8E4M3B11FNUZType>(*
this))
116 return APFloat::Float8E4M3B11FNUZ();
117 if (llvm::isa<Float8E3M4Type>(*
this))
118 return APFloat::Float8E3M4();
119 if (llvm::isa<Float8E8M0FNUType>(*
this))
120 return APFloat::Float8E8M0FNU();
121 if (llvm::isa<BFloat16Type>(*
this))
122 return APFloat::BFloat();
123 if (llvm::isa<Float16Type>(*
this))
124 return APFloat::IEEEhalf();
125 if (llvm::isa<FloatTF32Type>(*
this))
126 return APFloat::FloatTF32();
127 if (llvm::isa<Float32Type>(*
this))
128 return APFloat::IEEEsingle();
129 if (llvm::isa<Float64Type>(*
this))
130 return APFloat::IEEEdouble();
131 if (llvm::isa<Float80Type>(*
this))
132 return APFloat::x87DoubleExtended();
133 if (llvm::isa<Float128Type>(*
this))
134 return APFloat::IEEEquad();
135 llvm_unreachable(
"non-floating point type used");
142 if (isF16() || isBF16()) {
155 return APFloat::semanticsPrecision(getFloatSemantics());
162 unsigned FunctionType::getNumInputs()
const {
return getImpl()->numInputs; }
165 return getImpl()->getInputs();
168 unsigned FunctionType::getNumResults()
const {
return getImpl()->numResults; }
171 return getImpl()->getResults();
180 FunctionType FunctionType::getWithArgsAndResults(
187 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
188 return clone(newArgTypes, newResultTypes);
193 FunctionType::getWithoutArgsAndResults(
const BitVector &argIndices,
194 const BitVector &resultIndices) {
199 return clone(newArgTypes, newResultTypes);
208 StringAttr dialect, StringRef typeData) {
210 return emitError() <<
"invalid dialect namespace '" << dialect <<
"'";
217 <<
"`!" << dialect <<
"<\"" << typeData <<
"\">"
218 <<
"` type created with unregistered dialect. If this is "
219 "intended, please call allowUnregisteredDialects() on the "
220 "MLIRContext, or use -allow-unregistered-dialect with "
221 "the MLIR opt tool used";
231 bool VectorType::isValidElementType(
Type t) {
232 return isValidVectorTypeElementType(t);
238 if (!isValidElementType(elementType))
240 <<
"vector elements must be int/index/float type but got "
243 if (any_of(shape, [](int64_t i) {
return i <= 0; }))
245 <<
"vector types must have positive constant sizes but got "
248 if (scalableDims.size() != shape.size())
249 return emitError() <<
"number of dims must match, got "
250 << scalableDims.size() <<
" and " << shape.size();
255 VectorType VectorType::scaleElementBitwidth(
unsigned scale) {
259 if (
auto scaledEt = et.scaleElementBitwidth(scale))
262 if (
auto scaledEt = et.scaleElementBitwidth(scale))
268 Type elementType)
const {
279 .Case<RankedTensorType, UnrankedTensorType>(
280 [](
auto type) {
return type.getElementType(); });
284 return !llvm::isa<UnrankedTensorType>(*
this);
288 return llvm::cast<RankedTensorType>(*this).getShape();
292 Type elementType)
const {
293 if (llvm::dyn_cast<UnrankedTensorType>(*
this)) {
299 auto rankedTy = llvm::cast<RankedTensorType>(*
this);
302 rankedTy.getEncoding());
304 rankedTy.getEncoding());
308 Type elementType)
const {
309 return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
313 return ::llvm::cast<RankedTensorType>(cloneWith(shape,
getElementType()));
321 return emitError() <<
"invalid tensor element type: " << elementType;
330 return llvm::isa<ComplexType,
FloatType, IntegerType, OpaqueType, VectorType,
332 !llvm::isa<BuiltinDialect>(type.
getDialect());
343 for (int64_t s : shape)
344 if (s < 0 && !ShapedType::isDynamic(s))
345 return emitError() <<
"invalid tensor dimension size";
346 if (
auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
347 if (failed(v.verifyEncoding(shape, elementType,
emitError)))
369 [](
auto type) {
return type.getElementType(); });
373 return !llvm::isa<UnrankedMemRefType>(*
this);
377 return llvm::cast<MemRefType>(*this).getShape();
381 Type elementType)
const {
382 if (llvm::dyn_cast<UnrankedMemRefType>(*
this)) {
398 Type elementType)
const {
399 return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
403 return ::llvm::cast<MemRefType>(cloneWith(shape,
getElementType()));
407 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
408 return rankedMemRefTy.getMemorySpace();
409 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
413 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
414 return rankedMemRefTy.getMemorySpaceAsInt();
415 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
422 std::optional<llvm::SmallDenseSet<unsigned>>
426 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
427 llvm::SmallDenseSet<unsigned> unusedDims;
428 unsigned reducedIdx = 0;
429 for (
unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
431 int64_t origSize = originalShape[originalIdx];
433 if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
434 (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
435 ShapedType::isDynamic(origSize))) {
439 if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
444 unusedDims.insert(originalIdx);
451 if (reducedIdx != reducedRank)
458 ShapedType candidateReducedType) {
459 if (originalType == candidateReducedType)
462 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
463 ShapedType candidateReducedShapedType =
464 llvm::cast<ShapedType>(candidateReducedType);
469 candidateReducedShapedType.getShape();
470 unsigned originalRank = originalShape.size(),
471 candidateReducedRank = candidateReducedShape.size();
472 if (candidateReducedRank > originalRank)
475 auto optionalUnusedDimsMask =
479 if (!optionalUnusedDimsMask)
482 if (originalShapedType.getElementType() !=
483 candidateReducedShapedType.getElementType())
495 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
499 if (!isa<BuiltinDialect>(memorySpace.
getDialect()))
507 if (memorySpace == 0)
514 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
515 if (intMemorySpace && intMemorySpace.getValue() == 0)
525 assert(llvm::isa<IntegerAttr>(memorySpace) &&
526 "Using `getMemorySpaceInteger` with non-Integer attribute");
528 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
536 MemRefLayoutAttrInterface layout,
550 MemRefType MemRefType::getChecked(
552 Type elementType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
562 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
563 elementType, layout, memorySpace);
600 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
601 elementType, layout, memorySpace);
605 AffineMap map,
unsigned memorySpaceInd) {
626 unsigned memorySpaceInd) {
640 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
641 elementType, layout, memorySpace);
646 MemRefLayoutAttrInterface layout,
649 return emitError() <<
"invalid memref element type";
652 for (int64_t s : shape)
653 if (s < 0 && !ShapedType::isDynamic(s))
654 return emitError() <<
"invalid memref size";
656 assert(layout &&
"missing layout specification");
657 if (failed(layout.verifyLayout(shape,
emitError)))
661 return emitError() <<
"unsupported memory space Attribute";
678 return emitError() <<
"invalid memref element type";
681 return emitError() <<
"unsupported memory space Attribute";
692 if (
auto dim = dyn_cast<AffineDimExpr>(e))
693 strides[dim.getPosition()] =
694 strides[dim.getPosition()] + multiplicativeFactor;
696 offset = offset + e * multiplicativeFactor;
707 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
719 auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
721 strides[dim.getPosition()] =
722 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
729 if (bin.getLHS().isSymbolicOrConstant())
730 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
732 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
738 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
740 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
741 return success(succeeded(res1) && succeeded(res2));
744 llvm_unreachable(
"unexpected binary operation");
761 SmallVectorImpl<AffineExpr> &strides,
763 AffineMap m = t.getLayout().getAffineMap();
765 if (m.getNumResults() != 1 && !m.isIdentity())
771 strides.assign(t.getRank(), zero);
774 if (m.isIdentity()) {
776 if (t.getRank() == 0)
782 assert(
false &&
"unexpected failure: extract strides in canonical layout");
795 unsigned numDims = m.getNumDims();
796 unsigned numSymbols = m.getNumSymbols();
798 for (
auto &stride : strides)
808 if (
auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
809 llvm::append_range(strides, strided.getStrides());
810 offset = strided.getOffset();
820 if (
auto cst = dyn_cast<AffineConstantExpr>(offsetExpr))
821 offset = cst.getValue();
823 offset = ShapedType::kDynamic;
824 for (
auto e : strideExprs) {
825 if (
auto c = dyn_cast<AffineConstantExpr>(e))
826 strides.push_back(c.getValue());
828 strides.push_back(ShapedType::kDynamic);
833 std::pair<SmallVector<int64_t>, int64_t>
839 assert(succeeded(status) &&
"Invalid use of check-free getStridesAndOffset");
840 return {strides, offset};
848 ArrayRef<Type> TupleType::getTypes()
const {
return getImpl()->getTypes(); }
855 for (
Type type : getTypes()) {
856 if (
auto nestedTuple = llvm::dyn_cast<TupleType>(type))
857 nestedTuple.getFlattenedTypes(types);
859 types.push_back(type);
864 size_t TupleType::size()
const {
return getImpl()->size(); }
876 AffineMap m = t.getLayout().getAffineMap();
883 if (m.getNumResults() > 1)
887 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
888 if (
auto cst = dyn_cast<AffineConstantExpr>(m.getResult(0)))
889 if (cst.getValue() == 0)
897 if (t.getShape().empty())
905 auto simplifiedLayoutExpr =
907 if (expr != simplifiedLayoutExpr)
909 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
920 assert(!exprs.empty() &&
"expected exprs");
922 assert(!maps.empty() &&
"Expected one non-empty map");
923 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
926 bool dynamicPoisonBit =
false;
927 int64_t runningSize = 1;
928 for (
auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
929 int64_t size = std::get<1>(en);
934 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
937 assert(runningSize > 0 &&
"integer overflow in size computation");
939 dynamicPoisonBit =
true;
948 exprs.reserve(sizes.size());
949 for (
auto dim : llvm::seq<unsigned>(0, sizes.size()))
958 return succeeded(res);
965 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
972 auto memrefShape = type.getShape().take_back(n);
973 if (ShapedType::isDynamicShape(memrefShape))
976 if (type.getLayout().isIdentity())
991 for (
auto dim : llvm::reverse(memrefShape.drop_front(1))) {
993 flattenedDims.push_back(dimProduct);
996 strides = strides.drop_back(1);
997 return llvm::equal(strides, llvm::reverse(flattenedDims));
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in the old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
static FloatType getF64(MLIRContext *ctx)
FloatType scaleElementBitwidth(unsigned scale)
Get or create a new FloatType with bitwidth scaled by scale.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
static FloatType getF32(MLIRContext *ctx)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if we allow creating operations for unregistered dialects.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setElementType(Type newElementType)
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setMemorySpace(Attribute newMemorySpace)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in the old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the `detail` namespace directly.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
bool trailingNDimsContiguous(MemRefType type, int64_t n)
Return "true" if the last N dimensions of the given type are contiguous.