10#include "TypeDetail.h"
19#include "llvm/ADT/APFloat.h"
20#include "llvm/ADT/Sequence.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include "llvm/Support/CheckedArithmetic.h"
31#define GET_TYPEDEF_CLASSES
32#include "mlir/IR/BuiltinTypes.cpp.inc"
35#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
42void BuiltinDialect::registerTypes() {
44#define GET_TYPEDEF_LIST
45#include "mlir/IR/BuiltinTypes.cpp.inc"
57 return emitError() <<
"invalid element type for complex";
68 SignednessSemantics signedness) {
69 if (width > IntegerType::kMaxWidth) {
70 return emitError() <<
"integer bitwidth is limited to "
71 << IntegerType::kMaxWidth <<
" bits";
76unsigned IntegerType::getWidth()
const {
return getImpl()->width; }
78IntegerType::SignednessSemantics IntegerType::getSignedness()
const {
79 return getImpl()->signedness;
82IntegerType IntegerType::scaleElementBitwidth(
unsigned scale) {
85 return IntegerType::get(
getContext(), scale * getWidth(), getSignedness());
93#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
94 const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
95 return APFloat::SEM(); \
115#undef FLOAT_TYPE_SEMANTICS
117FloatType Float16Type::scaleElementBitwidth(
unsigned scale)
const {
125FloatType BFloat16Type::scaleElementBitwidth(
unsigned scale)
const {
133FloatType Float32Type::scaleElementBitwidth(
unsigned scale)
const {
143unsigned FunctionType::getNumInputs()
const {
return getImpl()->numInputs; }
146 return getImpl()->getInputs();
149unsigned FunctionType::getNumResults()
const {
return getImpl()->numResults; }
152 return getImpl()->getResults();
161FunctionType FunctionType::getWithArgsAndResults(
168 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
169 return clone(newArgTypes, newResultTypes);
174FunctionType::getWithoutArgsAndResults(
const BitVector &argIndices,
175 const BitVector &resultIndices) {
180 return clone(newArgTypes, newResultTypes);
187unsigned GraphType::getNumInputs()
const {
return getImpl()->numInputs; }
189ArrayRef<Type> GraphType::getInputs()
const {
return getImpl()->getInputs(); }
191unsigned GraphType::getNumResults()
const {
return getImpl()->numResults; }
193ArrayRef<Type> GraphType::getResults()
const {
return getImpl()->getResults(); }
209 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
210 return clone(newArgTypes, newResultTypes);
214GraphType GraphType::getWithoutArgsAndResults(
const BitVector &argIndices,
215 const BitVector &resultIndices) {
220 return clone(newArgTypes, newResultTypes);
228 StringAttr dialect, StringRef typeData) {
230 return emitError() <<
"invalid dialect namespace '" << dialect <<
"'";
237 <<
"`!" << dialect <<
"<\"" << typeData <<
"\">"
238 <<
"` type created with unregistered dialect. If this is "
239 "intended, please call allowUnregisteredDialects() on the "
240 "MLIRContext, or use -allow-unregistered-dialect with "
241 "the MLIR opt tool used";
251bool VectorType::isValidElementType(
Type t) {
258 if (!isValidElementType(elementType))
260 <<
"vector elements must be int/index/float type but got "
263 if (any_of(shape, [](int64_t i) {
return i <= 0; }))
265 <<
"vector types must have positive constant sizes but got "
268 if (scalableDims.size() != shape.size())
269 return emitError() <<
"number of dims must match, got "
270 << scalableDims.size() <<
" and " << shape.size();
275VectorType VectorType::scaleElementBitwidth(
unsigned scale) {
279 if (
auto scaledEt = et.scaleElementBitwidth(scale))
280 return VectorType::get(
getShape(), scaledEt, getScalableDims());
282 if (
auto scaledEt = et.scaleElementBitwidth(scale))
283 return VectorType::get(
getShape(), scaledEt, getScalableDims());
288 Type elementType)
const {
289 return VectorType::get(shape.value_or(
getShape()), elementType,
299 .Case<RankedTensorType, UnrankedTensorType>(
300 [](
auto type) {
return type.getElementType(); });
304 return !llvm::isa<UnrankedTensorType>(*
this);
308 return llvm::cast<RankedTensorType>(*this).getShape();
312 Type elementType)
const {
313 if (llvm::dyn_cast<UnrankedTensorType>(*
this)) {
315 return RankedTensorType::get(*
shape, elementType);
316 return UnrankedTensorType::get(elementType);
319 auto rankedTy = llvm::cast<RankedTensorType>(*
this);
321 return RankedTensorType::get(rankedTy.getShape(), elementType,
322 rankedTy.getEncoding());
323 return RankedTensorType::get(
shape.value_or(rankedTy.getShape()), elementType,
324 rankedTy.getEncoding());
328 Type elementType)
const {
329 return ::llvm::cast<RankedTensorType>(
cloneWith(
shape, elementType));
341 return emitError() <<
"invalid tensor element type: " << elementType;
350 return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
352 !llvm::isa<BuiltinDialect>(type.
getDialect());
364 if (s < 0 && ShapedType::isStatic(s))
365 return emitError() <<
"invalid tensor dimension size";
366 if (
auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
389 [](
auto type) {
return type.getElementType(); });
393 return !llvm::isa<UnrankedMemRefType>(*
this);
397 return llvm::cast<MemRefType>(*this).getShape();
401 Type elementType)
const {
402 if (llvm::dyn_cast<UnrankedMemRefType>(*
this)) {
417FailureOr<PtrLikeTypeInterface>
419 std::optional<Type> elementType)
const {
421 if (llvm::dyn_cast<UnrankedMemRefType>(*
this))
422 return cast<PtrLikeTypeInterface>(
423 UnrankedMemRefType::get(eTy, memorySpace));
428 return cast<PtrLikeTypeInterface>(
static_cast<MemRefType
>(builder));
432 Type elementType)
const {
441 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
442 return rankedMemRefTy.getMemorySpace();
443 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
447 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
448 return rankedMemRefTy.getMemorySpaceAsInt();
449 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
456std::optional<llvm::SmallDenseSet<unsigned>>
460 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
461 llvm::SmallDenseSet<unsigned> unusedDims;
462 unsigned reducedIdx = 0;
463 for (
unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
465 int64_t origSize = originalShape[originalIdx];
467 if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
468 (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
469 ShapedType::isDynamic(origSize))) {
473 if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
478 unusedDims.insert(originalIdx);
485 if (reducedIdx != reducedRank)
492 ShapedType candidateReducedType) {
493 if (originalType == candidateReducedType)
496 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
497 ShapedType candidateReducedShapedType =
498 llvm::cast<ShapedType>(candidateReducedType);
503 candidateReducedShapedType.getShape();
504 unsigned originalRank = originalShape.size(),
505 candidateReducedRank = candidateReducedShape.size();
506 if (candidateReducedRank > originalRank)
509 auto optionalUnusedDimsMask =
513 if (!optionalUnusedDimsMask)
516 if (originalShapedType.getElementType() !=
517 candidateReducedShapedType.getElementType())
529 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
533 if (!isa<BuiltinDialect>(memorySpace.
getDialect()))
541 if (memorySpace == 0)
544 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
548 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
549 if (intMemorySpace && intMemorySpace.getValue() == 0)
559 assert(llvm::isa<IntegerAttr>(memorySpace) &&
560 "Using `getMemorySpaceInteger` with non-Integer attribute");
562 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
565unsigned MemRefType::getMemorySpaceAsInt()
const {
570 MemRefLayoutAttrInterface layout,
584MemRefType MemRefType::getChecked(
586 Type elementType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
597 elementType, layout, memorySpace);
609 auto layout = AffineMapAttr::get(map);
629 auto layout = AffineMapAttr::get(map);
635 elementType, layout, memorySpace);
639 AffineMap map,
unsigned memorySpaceInd) {
647 auto layout = AffineMapAttr::get(map);
660 unsigned memorySpaceInd) {
668 auto layout = AffineMapAttr::get(map);
675 elementType, layout, memorySpace);
680 MemRefLayoutAttrInterface layout,
683 return emitError() <<
"invalid memref element type";
686 for (int64_t s :
shape)
687 if (s < 0 && ShapedType::isStatic(s))
688 return emitError() <<
"invalid memref size";
690 assert(layout &&
"missing layout specification");
695 return emitError() <<
"unsupported memory space Attribute";
700bool MemRefType::areTrailingDimsContiguous(int64_t n) {
701 assert(n <= getRank() &&
702 "number of dimensions to check must not exceed rank");
703 return n <= getNumContiguousTrailingDims();
706int64_t MemRefType::getNumContiguousTrailingDims() {
707 const int64_t n = getRank();
710 if (getLayout().isIdentity())
728 int64_t dimProduct = 1;
729 for (int64_t i = n - 1; i >= 0; --i) {
732 if (strides[i] != dimProduct)
734 if (
shape[i] == ShapedType::kDynamic)
736 dimProduct *=
shape[i];
742MemRefType MemRefType::canonicalizeStridedLayout() {
743 AffineMap m = getLayout().getAffineMap();
755 if (
auto cst = llvm::dyn_cast<AffineConstantExpr>(m.
getResult(0)))
756 if (cst.getValue() == 0)
771 auto simplifiedLayoutExpr =
773 if (expr != simplifiedLayoutExpr)
776 simplifiedLayoutExpr)));
781 int64_t &offset)
const {
782 return getLayout().getStridesAndOffset(
getShape(), strides, offset);
785std::pair<SmallVector<int64_t>, int64_t>
786MemRefType::getStridesAndOffset()
const {
791 assert(succeeded(status) &&
"Invalid use of check-free getStridesAndOffset");
792 return {strides, offset};
795bool MemRefType::isStrided() {
799 return succeeded(res);
802bool MemRefType::isLastDimUnitStride() {
806 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
813unsigned UnrankedMemRefType::getMemorySpaceAsInt()
const {
821 return emitError() <<
"invalid memref element type";
824 return emitError() <<
"unsupported memory space Attribute";
834ArrayRef<Type> TupleType::getTypes()
const {
return getImpl()->getTypes(); }
841 for (
Type type : getTypes()) {
842 if (
auto nestedTuple = llvm::dyn_cast<TupleType>(type))
843 nestedTuple.getFlattenedTypes(types);
845 types.push_back(type);
850size_t TupleType::size()
const {
return getImpl()->size(); }
863 assert(!exprs.empty() &&
"expected exprs");
865 assert(!maps.empty() &&
"Expected one non-empty map");
866 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
869 bool dynamicPoisonBit =
false;
871 for (
auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
872 int64_t size = std::get<1>(en);
877 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
879 auto result = llvm::checkedMul(runningSize, size);
882 dynamicPoisonBit =
true;
887 dynamicPoisonBit =
true;
896 exprs.reserve(sizes.size());
897 for (
auto dim : llvm::seq<unsigned>(0, sizes.size()))
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)
static Type getElementType(Type type)
Determine the element type of type.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
Clone this type with the given shape and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if operations may be created for unregistered dialects.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setMemorySpace(Attribute newMemorySpace)
Builder & setElementType(Type newElementType)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
TensorType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
Include the generated interface declarations.
bool isValidVectorTypeElementType(::mlir::Type type)
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
llvm::function_ref< Fn > function_ref
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)