10 #include "TypeDetail.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/TypeSwitch.h"
30 #define GET_TYPEDEF_CLASSES
31 #include "mlir/IR/BuiltinTypes.cpp.inc"
34 #include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
41 void BuiltinDialect::registerTypes() {
43 #define GET_TYPEDEF_LIST
44 #include "mlir/IR/BuiltinTypes.cpp.inc"
56 return emitError() <<
"invalid element type for complex";
67 SignednessSemantics signedness) {
68 if (width > IntegerType::kMaxWidth) {
69 return emitError() <<
"integer bitwidth is limited to "
70 << IntegerType::kMaxWidth <<
" bits";
75 unsigned IntegerType::getWidth()
const {
return getImpl()->width; }
77 IntegerType::SignednessSemantics IntegerType::getSignedness()
const {
78 return getImpl()->signedness;
81 IntegerType IntegerType::scaleElementBitwidth(
unsigned scale) {
92 #define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
93 const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
94 return APFloat::SEM(); \
114 #undef FLOAT_TYPE_SEMANTICS
116 FloatType Float16Type::scaleElementBitwidth(
unsigned scale)
const {
124 FloatType BFloat16Type::scaleElementBitwidth(
unsigned scale)
const {
132 FloatType Float32Type::scaleElementBitwidth(
unsigned scale)
const {
142 unsigned FunctionType::getNumInputs()
const {
return getImpl()->numInputs; }
145 return getImpl()->getInputs();
148 unsigned FunctionType::getNumResults()
const {
return getImpl()->numResults; }
151 return getImpl()->getResults();
160 FunctionType FunctionType::getWithArgsAndResults(
167 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
168 return clone(newArgTypes, newResultTypes);
173 FunctionType::getWithoutArgsAndResults(
const BitVector &argIndices,
174 const BitVector &resultIndices) {
179 return clone(newArgTypes, newResultTypes);
188 StringAttr dialect, StringRef typeData) {
190 return emitError() <<
"invalid dialect namespace '" << dialect <<
"'";
197 <<
"`!" << dialect <<
"<\"" << typeData <<
"\">"
198 <<
"` type created with unregistered dialect. If this is "
199 "intended, please call allowUnregisteredDialects() on the "
200 "MLIRContext, or use -allow-unregistered-dialect with "
201 "the MLIR opt tool used";
211 bool VectorType::isValidElementType(
Type t) {
212 return isValidVectorTypeElementType(t);
218 if (!isValidElementType(elementType))
220 <<
"vector elements must be int/index/float type but got "
223 if (any_of(shape, [](int64_t i) {
return i <= 0; }))
225 <<
"vector types must have positive constant sizes but got "
228 if (scalableDims.size() != shape.size())
229 return emitError() <<
"number of dims must match, got "
230 << scalableDims.size() <<
" and " << shape.size();
235 VectorType VectorType::scaleElementBitwidth(
unsigned scale) {
239 if (
auto scaledEt = et.scaleElementBitwidth(scale))
242 if (
auto scaledEt = et.scaleElementBitwidth(scale))
248 Type elementType)
const {
259 .Case<RankedTensorType, UnrankedTensorType>(
260 [](
auto type) {
return type.getElementType(); });
264 return !llvm::isa<UnrankedTensorType>(*
this);
268 return llvm::cast<RankedTensorType>(*this).getShape();
272 Type elementType)
const {
273 if (llvm::dyn_cast<UnrankedTensorType>(*
this)) {
279 auto rankedTy = llvm::cast<RankedTensorType>(*
this);
282 rankedTy.getEncoding());
284 rankedTy.getEncoding());
288 Type elementType)
const {
289 return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
293 return ::llvm::cast<RankedTensorType>(cloneWith(shape,
getElementType()));
301 return emitError() <<
"invalid tensor element type: " << elementType;
310 return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
312 !llvm::isa<BuiltinDialect>(type.
getDialect());
323 for (int64_t s : shape)
324 if (s < 0 && ShapedType::isStatic(s))
325 return emitError() <<
"invalid tensor dimension size";
326 if (
auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
327 if (failed(v.verifyEncoding(shape, elementType,
emitError)))
349 [](
auto type) {
return type.getElementType(); });
353 return !llvm::isa<UnrankedMemRefType>(*
this);
357 return llvm::cast<MemRefType>(*this).getShape();
361 Type elementType)
const {
362 if (llvm::dyn_cast<UnrankedMemRefType>(*
this)) {
377 FailureOr<PtrLikeTypeInterface>
379 std::optional<Type> elementType)
const {
381 if (llvm::dyn_cast<UnrankedMemRefType>(*
this))
382 return cast<PtrLikeTypeInterface>(
388 return cast<PtrLikeTypeInterface>(
static_cast<MemRefType
>(builder));
392 Type elementType)
const {
393 return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
397 return ::llvm::cast<MemRefType>(cloneWith(shape,
getElementType()));
401 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
402 return rankedMemRefTy.getMemorySpace();
403 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
407 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
408 return rankedMemRefTy.getMemorySpaceAsInt();
409 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
416 std::optional<llvm::SmallDenseSet<unsigned>>
420 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
421 llvm::SmallDenseSet<unsigned> unusedDims;
422 unsigned reducedIdx = 0;
423 for (
unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
425 int64_t origSize = originalShape[originalIdx];
427 if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
428 (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
429 ShapedType::isDynamic(origSize))) {
433 if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
438 unusedDims.insert(originalIdx);
445 if (reducedIdx != reducedRank)
452 ShapedType candidateReducedType) {
453 if (originalType == candidateReducedType)
456 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
457 ShapedType candidateReducedShapedType =
458 llvm::cast<ShapedType>(candidateReducedType);
463 candidateReducedShapedType.getShape();
464 unsigned originalRank = originalShape.size(),
465 candidateReducedRank = candidateReducedShape.size();
466 if (candidateReducedRank > originalRank)
469 auto optionalUnusedDimsMask =
473 if (!optionalUnusedDimsMask)
476 if (originalShapedType.getElementType() !=
477 candidateReducedShapedType.getElementType())
489 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
493 if (!isa<BuiltinDialect>(memorySpace.
getDialect()))
501 if (memorySpace == 0)
508 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
509 if (intMemorySpace && intMemorySpace.getValue() == 0)
519 assert(llvm::isa<IntegerAttr>(memorySpace) &&
520 "Using `getMemorySpaceInteger` with non-Integer attribute");
522 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
530 MemRefLayoutAttrInterface layout,
544 MemRefType MemRefType::getChecked(
546 Type elementType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
556 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
557 elementType, layout, memorySpace);
594 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
595 elementType, layout, memorySpace);
599 AffineMap map,
unsigned memorySpaceInd) {
620 unsigned memorySpaceInd) {
634 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
635 elementType, layout, memorySpace);
640 MemRefLayoutAttrInterface layout,
643 return emitError() <<
"invalid memref element type";
646 for (int64_t s : shape)
647 if (s < 0 && ShapedType::isStatic(s))
648 return emitError() <<
"invalid memref size";
650 assert(layout &&
"missing layout specification");
651 if (failed(layout.verifyLayout(shape,
emitError)))
655 return emitError() <<
"unsupported memory space Attribute";
660 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
661 assert(n <= getRank() &&
662 "number of dimensions to check must not exceed rank");
663 return n <= getNumContiguousTrailingDims();
666 int64_t MemRefType::getNumContiguousTrailingDims() {
667 const int64_t n = getRank();
670 if (getLayout().isIdentity())
688 int64_t dimProduct = 1;
689 for (int64_t i = n - 1; i >= 0; --i) {
692 if (strides[i] != dimProduct)
694 if (shape[i] == ShapedType::kDynamic)
696 dimProduct *= shape[i];
702 MemRefType MemRefType::canonicalizeStridedLayout() {
703 AffineMap m = getLayout().getAffineMap();
715 if (
auto cst = llvm::dyn_cast<AffineConstantExpr>(m.
getResult(0)))
716 if (cst.getValue() == 0)
731 auto simplifiedLayoutExpr =
733 if (expr != simplifiedLayoutExpr)
736 simplifiedLayoutExpr)));
741 int64_t &offset)
const {
742 return getLayout().getStridesAndOffset(
getShape(), strides, offset);
745 std::pair<SmallVector<int64_t>, int64_t>
751 assert(succeeded(status) &&
"Invalid use of check-free getStridesAndOffset");
752 return {strides, offset};
755 bool MemRefType::isStrided() {
759 return succeeded(res);
762 bool MemRefType::isLastDimUnitStride() {
766 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
781 return emitError() <<
"invalid memref element type";
784 return emitError() <<
"unsupported memory space Attribute";
794 ArrayRef<Type> TupleType::getTypes()
const {
return getImpl()->getTypes(); }
801 for (
Type type : getTypes()) {
802 if (
auto nestedTuple = llvm::dyn_cast<TupleType>(type))
803 nestedTuple.getFlattenedTypes(types);
805 types.push_back(type);
810 size_t TupleType::size()
const {
return getImpl()->size(); }
823 assert(!exprs.empty() &&
"expected exprs");
825 assert(!maps.empty() &&
"Expected one non-empty map");
826 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
829 bool dynamicPoisonBit =
false;
830 int64_t runningSize = 1;
831 for (
auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
832 int64_t size = std::get<1>(en);
837 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
840 assert(runningSize > 0 &&
"integer overflow in size computation");
842 dynamicPoisonBit =
true;
851 exprs.reserve(sizes.size());
852 for (
auto dim : llvm::seq<unsigned>(0, sizes.size()))
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isIdentity() const
Returns true if this affine map is an identity affine map.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if we allow to create operation for unregistered dialects.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setElementType(Type newElementType)
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setMemorySpace(Attribute newMemorySpace)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)