10 #include "TypeDetail.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/TypeSwitch.h"
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
39 void BuiltinDialect::registerTypes() {
41 #define GET_TYPEDEF_LIST
42 #include "mlir/IR/BuiltinTypes.cpp.inc"
54 return emitError() <<
"invalid element type for complex";
65 SignednessSemantics signedness) {
66 if (width > IntegerType::kMaxWidth) {
67 return emitError() <<
"integer bitwidth is limited to "
68 << IntegerType::kMaxWidth <<
" bits";
73 unsigned IntegerType::getWidth()
const {
return getImpl()->width; }
75 IntegerType::SignednessSemantics IntegerType::getSignedness()
const {
76 return getImpl()->signedness;
79 IntegerType IntegerType::scaleElementBitwidth(
unsigned scale) {
90 if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
91 Float8E4M3FNUZType, Float8E4M3B11FNUZType>(*
this))
93 if (llvm::isa<Float16Type, BFloat16Type>(*
this))
95 if (llvm::isa<Float32Type, FloatTF32Type>(*
this))
97 if (llvm::isa<Float64Type>(*
this))
99 if (llvm::isa<Float80Type>(*
this))
101 if (llvm::isa<Float128Type>(*
this))
103 llvm_unreachable(
"unexpected float type");
108 if (llvm::isa<Float8E5M2Type>(*
this))
109 return APFloat::Float8E5M2();
110 if (llvm::isa<Float8E4M3FNType>(*
this))
111 return APFloat::Float8E4M3FN();
112 if (llvm::isa<Float8E5M2FNUZType>(*
this))
113 return APFloat::Float8E5M2FNUZ();
114 if (llvm::isa<Float8E4M3FNUZType>(*
this))
115 return APFloat::Float8E4M3FNUZ();
116 if (llvm::isa<Float8E4M3B11FNUZType>(*
this))
117 return APFloat::Float8E4M3B11FNUZ();
118 if (llvm::isa<BFloat16Type>(*
this))
119 return APFloat::BFloat();
120 if (llvm::isa<Float16Type>(*
this))
121 return APFloat::IEEEhalf();
122 if (llvm::isa<FloatTF32Type>(*
this))
123 return APFloat::FloatTF32();
124 if (llvm::isa<Float32Type>(*
this))
125 return APFloat::IEEEsingle();
126 if (llvm::isa<Float64Type>(*
this))
127 return APFloat::IEEEdouble();
128 if (llvm::isa<Float80Type>(*
this))
129 return APFloat::x87DoubleExtended();
130 if (llvm::isa<Float128Type>(*
this))
131 return APFloat::IEEEquad();
132 llvm_unreachable(
"non-floating point type used");
139 if (isF16() || isBF16()) {
152 return APFloat::semanticsPrecision(getFloatSemantics());
159 unsigned FunctionType::getNumInputs()
const {
return getImpl()->numInputs; }
162 return getImpl()->getInputs();
165 unsigned FunctionType::getNumResults()
const {
return getImpl()->numResults; }
168 return getImpl()->getResults();
177 FunctionType FunctionType::getWithArgsAndResults(
184 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
185 return clone(newArgTypes, newResultTypes);
190 FunctionType::getWithoutArgsAndResults(
const BitVector &argIndices,
191 const BitVector &resultIndices) {
196 return clone(newArgTypes, newResultTypes);
205 StringAttr dialect, StringRef typeData) {
207 return emitError() <<
"invalid dialect namespace '" << dialect <<
"'";
214 <<
"`!" << dialect <<
"<\"" << typeData <<
"\">"
215 <<
"` type created with unregistered dialect. If this is "
216 "intended, please call allowUnregisteredDialects() on the "
217 "MLIRContext, or use -allow-unregistered-dialect with "
218 "the MLIR opt tool used";
231 if (!isValidElementType(elementType))
233 <<
"vector elements must be int/index/float type but got "
236 if (any_of(shape, [](int64_t i) {
return i <= 0; }))
238 <<
"vector types must have positive constant sizes but got "
241 if (scalableDims.size() != shape.size())
242 return emitError() <<
"number of dims must match, got "
243 << scalableDims.size() <<
" and " << shape.size();
248 VectorType VectorType::scaleElementBitwidth(
unsigned scale) {
252 if (
auto scaledEt = et.scaleElementBitwidth(scale))
255 if (
auto scaledEt = et.scaleElementBitwidth(scale))
261 Type elementType)
const {
272 .Case<RankedTensorType, UnrankedTensorType>(
273 [](
auto type) {
return type.getElementType(); });
279 return llvm::cast<RankedTensorType>(*this).getShape();
283 Type elementType)
const {
284 if (llvm::dyn_cast<UnrankedTensorType>(*
this)) {
290 auto rankedTy = llvm::cast<RankedTensorType>(*
this);
293 rankedTy.getEncoding());
295 rankedTy.getEncoding());
299 Type elementType)
const {
300 return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
304 return ::llvm::cast<RankedTensorType>(cloneWith(shape,
getElementType()));
312 return emitError() <<
"invalid tensor element type: " << elementType;
321 return llvm::isa<ComplexType,
FloatType, IntegerType, OpaqueType, VectorType,
323 !llvm::isa<BuiltinDialect>(type.
getDialect());
334 for (int64_t s : shape)
335 if (s < 0 && !ShapedType::isDynamic(s))
336 return emitError() <<
"invalid tensor dimension size";
337 if (
auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
360 [](
auto type) {
return type.getElementType(); });
366 return llvm::cast<MemRefType>(*this).getShape();
370 Type elementType)
const {
371 if (llvm::dyn_cast<UnrankedMemRefType>(*
this)) {
387 Type elementType)
const {
388 return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
392 return ::llvm::cast<MemRefType>(cloneWith(shape,
getElementType()));
396 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
397 return rankedMemRefTy.getMemorySpace();
398 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
402 if (
auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*
this))
403 return rankedMemRefTy.getMemorySpaceAsInt();
404 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
419 std::optional<llvm::SmallDenseSet<unsigned>>
422 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
423 llvm::SmallDenseSet<unsigned> unusedDims;
424 unsigned reducedIdx = 0;
425 for (
unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
427 if (reducedIdx < reducedRank &&
428 originalShape[originalIdx] == reducedShape[reducedIdx]) {
433 unusedDims.insert(originalIdx);
436 if (originalShape[originalIdx] != 1)
440 if (reducedIdx != reducedRank)
447 ShapedType candidateReducedType) {
448 if (originalType == candidateReducedType)
451 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
452 ShapedType candidateReducedShapedType =
453 llvm::cast<ShapedType>(candidateReducedType);
458 candidateReducedShapedType.getShape();
459 unsigned originalRank = originalShape.size(),
460 candidateReducedRank = candidateReducedShape.size();
461 if (candidateReducedRank > originalRank)
464 auto optionalUnusedDimsMask =
468 if (!optionalUnusedDimsMask)
471 if (originalShapedType.getElementType() !=
472 candidateReducedShapedType.getElementType())
484 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
488 if (!isa<BuiltinDialect>(memorySpace.
getDialect()))
496 if (memorySpace == 0)
503 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
504 if (intMemorySpace && intMemorySpace.getValue() == 0)
514 assert(llvm::isa<IntegerAttr>(memorySpace) &&
515 "Using `getMemorySpaceInteger` with non-Integer attribute");
517 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
525 MemRefLayoutAttrInterface layout,
539 MemRefType MemRefType::getChecked(
541 Type elementType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
551 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
552 elementType, layout, memorySpace);
589 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
590 elementType, layout, memorySpace);
594 AffineMap map,
unsigned memorySpaceInd) {
615 unsigned memorySpaceInd) {
629 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
630 elementType, layout, memorySpace);
635 MemRefLayoutAttrInterface layout,
638 return emitError() <<
"invalid memref element type";
641 for (int64_t s : shape)
642 if (s < 0 && !ShapedType::isDynamic(s))
643 return emitError() <<
"invalid memref size";
645 assert(layout &&
"missing layout specification");
650 return emitError() <<
"unsupported memory space Attribute";
667 return emitError() <<
"invalid memref element type";
670 return emitError() <<
"unsupported memory space Attribute";
682 strides[dim.getPosition()] =
683 strides[dim.getPosition()] + multiplicativeFactor;
685 offset = offset + e * multiplicativeFactor;
710 strides[dim.getPosition()] =
711 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
718 if (bin.getLHS().isSymbolicOrConstant())
719 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
721 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
727 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
729 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
733 llvm_unreachable(
"unexpected binary operation");
750 SmallVectorImpl<AffineExpr> &strides,
752 AffineMap m = t.getLayout().getAffineMap();
760 strides.assign(t.getRank(), zero);
765 if (t.getRank() == 0)
771 assert(
false &&
"unexpected failure: extract strides in canonical layout");
787 for (
auto &stride : strides)
811 if (
auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
812 llvm::append_range(strides, strided.getStrides());
813 offset = strided.getOffset();
824 offset = cst.getValue();
826 offset = ShapedType::kDynamic;
827 for (
auto e : strideExprs) {
829 strides.push_back(c.getValue());
831 strides.push_back(ShapedType::kDynamic);
836 std::pair<SmallVector<int64_t>, int64_t>
842 assert(
succeeded(status) &&
"Invalid use of check-free getStridesAndOffset");
843 return {strides, offset};
851 ArrayRef<Type> TupleType::getTypes()
const {
return getImpl()->getTypes(); }
858 for (
Type type : getTypes()) {
859 if (
auto nestedTuple = llvm::dyn_cast<TupleType>(type))
860 nestedTuple.getFlattenedTypes(types);
862 types.push_back(type);
867 size_t TupleType::size()
const {
return getImpl()->size(); }
879 AffineMap m = t.getLayout().getAffineMap();
892 if (cst.getValue() == 0)
900 if (t.getShape().empty())
908 auto simplifiedLayoutExpr =
910 if (expr != simplifiedLayoutExpr)
923 assert(!exprs.empty() &&
"expected exprs");
925 assert(!maps.empty() &&
"Expected one non-empty map");
926 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
929 bool dynamicPoisonBit =
false;
930 int64_t runningSize = 1;
931 for (
auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
932 int64_t size = std::get<1>(en);
937 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
940 assert(runningSize > 0 &&
"integer overflow in size computation");
942 dynamicPoisonBit =
true;
951 exprs.reserve(sizes.size());
952 for (
auto dim : llvm::seq<unsigned>(0, sizes.size()))
968 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns whether this type is ranked, i.e. whether it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
static FloatType getF64(MLIRContext *ctx)
FloatType scaleElementBitwidth(unsigned scale)
Get or create a new FloatType with bitwidth scaled by scale.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
static FloatType getF32(MLIRContext *ctx)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if creating operations for unregistered dialects is allowed.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setElementType(Type newElementType)
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setMemorySpace(Attribute newMemorySpace)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns whether this type is ranked, i.e. whether it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Detect if any of the given parameter types has a sub-element handler.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This class represents an efficient way to signal success or failure.