10 #include "TypeDetail.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/TypeSwitch.h"
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
36 #include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
43 void BuiltinDialect::registerTypes() {
45 #define GET_TYPEDEF_LIST
46 #include "mlir/IR/BuiltinTypes.cpp.inc"
58 return emitError() <<
"invalid element type for complex";
69 SignednessSemantics signedness) {
70 if (width > IntegerType::kMaxWidth) {
71 return emitError() <<
"integer bitwidth is limited to "
72 << IntegerType::kMaxWidth <<
" bits";
77 unsigned IntegerType::getWidth()
const {
return getImpl()->width; }
79 IntegerType::SignednessSemantics IntegerType::getSignedness()
const {
80 return getImpl()->signedness;
83 IntegerType IntegerType::scaleElementBitwidth(
unsigned scale) {
94 #define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
95 const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
96 return APFloat::SEM(); \
116 #undef FLOAT_TYPE_SEMANTICS
118 FloatType Float16Type::scaleElementBitwidth(
unsigned scale)
const {
126 FloatType BFloat16Type::scaleElementBitwidth(
unsigned scale)
const {
134 FloatType Float32Type::scaleElementBitwidth(
unsigned scale)
const {
144 unsigned FunctionType::getNumInputs()
const {
return getImpl()->numInputs; }
147 return getImpl()->getInputs();
150 unsigned FunctionType::getNumResults()
const {
return getImpl()->numResults; }
153 return getImpl()->getResults();
162 FunctionType FunctionType::getWithArgsAndResults(
169 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
170 return clone(newArgTypes, newResultTypes);
175 FunctionType::getWithoutArgsAndResults(
const BitVector &argIndices,
176 const BitVector &resultIndices) {
181 return clone(newArgTypes, newResultTypes);
190 StringAttr dialect, StringRef typeData) {
192 return emitError() <<
"invalid dialect namespace '" << dialect <<
"'";
199 <<
"`!" << dialect <<
"<\"" << typeData <<
"\">"
200 <<
"` type created with unregistered dialect. If this is "
201 "intended, please call allowUnregisteredDialects() on the "
202 "MLIRContext, or use -allow-unregistered-dialect with "
203 "the MLIR opt tool used";
213 bool VectorType::isValidElementType(
Type t) {
214 return isValidVectorTypeElementType(t);
220 if (!isValidElementType(elementType))
222 <<
"vector elements must be int/index/float type but got "
225 if (any_of(shape, [](int64_t i) {
return i <= 0; }))
227 <<
"vector types must have positive constant sizes but got "
230 if (scalableDims.size() != shape.size())
231 return emitError() <<
"number of dims must match, got "
232 << scalableDims.size() <<
" and " << shape.size();
237 VectorType VectorType::scaleElementBitwidth(
unsigned scale) {
241 if (
auto scaledEt = et.scaleElementBitwidth(scale))
244 if (
auto scaledEt = et.scaleElementBitwidth(scale))
250 Type elementType)
const {
261 .Case<RankedTensorType, UnrankedTensorType>(
262 [](
auto type) {
return type.getElementType(); });
266 return !llvm::isa<UnrankedTensorType>(*
this);
270 return llvm::cast<RankedTensorType>(*this).getShape();
274 Type elementType)
const {
275 if (llvm::dyn_cast<UnrankedTensorType>(*
this)) {
281 auto rankedTy = llvm::cast<RankedTensorType>(*
this);
284 rankedTy.getEncoding());
286 rankedTy.getEncoding());
290 Type elementType)
const {
291 return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
295 return ::llvm::cast<RankedTensorType>(cloneWith(shape,
getElementType()));
303 return emitError() <<
"invalid tensor element type: " << elementType;
312 return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
314 !llvm::isa<BuiltinDialect>(type.
getDialect());
325 for (int64_t s : shape)
326 if (s < 0 && !ShapedType::isDynamic(s))
327 return emitError() <<
"invalid tensor dimension size";
328 if (
auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
329 if (failed(v.verifyEncoding(shape, elementType,
emitError)))
351 [](
auto type) {
return type.getElementType(); });
355 return !llvm::isa<UnrankedMemRefType>(*
this);
359 return llvm::cast<MemRefType>(*this).getShape();
363 Type elementType)
const {
364 if (llvm::dyn_cast<UnrankedMemRefType>(*
this)) {
380 Type elementType)
const {
381 return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
385 return ::llvm::cast<MemRefType>(cloneWith(shape,
getElementType()));
389 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
390 return rankedMemRefTy.getMemorySpace();
391 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
395 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
396 return rankedMemRefTy.getMemorySpaceAsInt();
397 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
404 std::optional<llvm::SmallDenseSet<unsigned>>
408 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
409 llvm::SmallDenseSet<unsigned> unusedDims;
410 unsigned reducedIdx = 0;
411 for (
unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
413 int64_t origSize = originalShape[originalIdx];
415 if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
416 (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
417 ShapedType::isDynamic(origSize))) {
421 if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
426 unusedDims.insert(originalIdx);
433 if (reducedIdx != reducedRank)
440 ShapedType candidateReducedType) {
441 if (originalType == candidateReducedType)
444 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
445 ShapedType candidateReducedShapedType =
446 llvm::cast<ShapedType>(candidateReducedType);
451 candidateReducedShapedType.getShape();
452 unsigned originalRank = originalShape.size(),
453 candidateReducedRank = candidateReducedShape.size();
454 if (candidateReducedRank > originalRank)
457 auto optionalUnusedDimsMask =
461 if (!optionalUnusedDimsMask)
464 if (originalShapedType.getElementType() !=
465 candidateReducedShapedType.getElementType())
477 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
481 if (!isa<BuiltinDialect>(memorySpace.
getDialect()))
489 if (memorySpace == 0)
496 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
497 if (intMemorySpace && intMemorySpace.getValue() == 0)
507 assert(llvm::isa<IntegerAttr>(memorySpace) &&
508 "Using `getMemorySpaceInteger` with non-Integer attribute");
510 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
518 MemRefLayoutAttrInterface layout,
532 MemRefType MemRefType::getChecked(
534 Type elementType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
544 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
545 elementType, layout, memorySpace);
582 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
583 elementType, layout, memorySpace);
587 AffineMap map,
unsigned memorySpaceInd) {
608 unsigned memorySpaceInd) {
622 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
623 elementType, layout, memorySpace);
628 MemRefLayoutAttrInterface layout,
631 return emitError() <<
"invalid memref element type";
634 for (int64_t s : shape)
635 if (s < 0 && !ShapedType::isDynamic(s))
636 return emitError() <<
"invalid memref size";
638 assert(layout &&
"missing layout specification");
639 if (failed(layout.verifyLayout(shape,
emitError)))
643 return emitError() <<
"unsupported memory space Attribute";
648 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
649 if (!isLastDimUnitStride())
652 auto memrefShape =
getShape().take_back(n);
653 if (ShapedType::isDynamicShape(memrefShape))
656 if (getLayout().isIdentity())
671 for (
auto dim : llvm::reverse(memrefShape.drop_front(1))) {
673 flattenedDims.push_back(dimProduct);
676 strides = strides.drop_back(1);
677 return llvm::equal(strides, llvm::reverse(flattenedDims));
680 MemRefType MemRefType::canonicalizeStridedLayout() {
681 AffineMap m = getLayout().getAffineMap();
688 if (m.getNumResults() > 1)
692 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
693 if (
auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0)))
694 if (cst.getValue() == 0)
709 auto simplifiedLayoutExpr =
711 if (expr != simplifiedLayoutExpr)
714 simplifiedLayoutExpr)));
724 if (
auto dim = dyn_cast<AffineDimExpr>(e))
725 strides[dim.getPosition()] =
726 strides[dim.getPosition()] + multiplicativeFactor;
728 offset = offset + e * multiplicativeFactor;
739 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
751 auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
753 strides[dim.getPosition()] =
754 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
761 if (bin.getLHS().isSymbolicOrConstant())
762 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
764 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
770 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
772 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
773 return success(succeeded(res1) && succeeded(res2));
776 llvm_unreachable(
"unexpected binary operation");
793 SmallVectorImpl<AffineExpr> &strides,
795 AffineMap m = t.getLayout().getAffineMap();
797 if (m.getNumResults() != 1 && !m.isIdentity())
803 strides.assign(t.getRank(), zero);
806 if (m.isIdentity()) {
808 if (t.getRank() == 0)
814 assert(
false &&
"unexpected failure: extract strides in canonical layout");
827 unsigned numDims = m.getNumDims();
828 unsigned numSymbols = m.getNumSymbols();
830 for (
auto &stride : strides)
839 if (
auto strided = llvm::dyn_cast<StridedLayoutAttr>(getLayout())) {
840 llvm::append_range(strides, strided.getStrides());
841 offset = strided.getOffset();
851 if (
auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
852 offset = cst.getValue();
854 offset = ShapedType::kDynamic;
855 for (
auto e : strideExprs) {
856 if (
auto c = llvm::dyn_cast<AffineConstantExpr>(e))
857 strides.push_back(c.getValue());
859 strides.push_back(ShapedType::kDynamic);
869 assert(succeeded(status) &&
"Invalid use of check-free getStridesAndOffset");
870 return {strides, offset};
873 bool MemRefType::isStrided() {
877 return succeeded(res);
880 bool MemRefType::isLastDimUnitStride() {
884 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
899 return emitError() <<
"invalid memref element type";
902 return emitError() <<
"unsupported memory space Attribute";
912 ArrayRef<Type> TupleType::getTypes()
const {
return getImpl()->getTypes(); }
919 for (
Type type : getTypes()) {
920 if (
auto nestedTuple = llvm::dyn_cast<TupleType>(type))
921 nestedTuple.getFlattenedTypes(types);
923 types.push_back(type);
928 size_t TupleType::size()
const {
return getImpl()->size(); }
941 assert(!exprs.empty() &&
"expected exprs");
943 assert(!maps.empty() &&
"Expected one non-empty map");
944 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
947 bool dynamicPoisonBit =
false;
948 int64_t runningSize = 1;
949 for (
auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
950 int64_t size = std::get<1>(en);
955 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
958 assert(runningSize > 0 &&
"integer overflow in size computation");
960 dynamicPoisonBit =
true;
969 exprs.reserve(sizes.size());
970 for (
auto dim : llvm::seq<unsigned>(0, sizes.size()))
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
static LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expressions.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if we allow creating operations for unregistered dialects.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setElementType(Type newElementType)
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setMemorySpace(Attribute newMemorySpace)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)