10 #include "TypeDetail.h"
20 #include "llvm/ADT/APFloat.h"
21 #include "llvm/ADT/BitVector.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/Twine.h"
24 #include "llvm/ADT/TypeSwitch.h"
33 #define GET_TYPEDEF_CLASSES
34 #include "mlir/IR/BuiltinTypes.cpp.inc"
40 void BuiltinDialect::registerTypes() {
42 #define GET_TYPEDEF_LIST
43 #include "mlir/IR/BuiltinTypes.cpp.inc"
55 return emitError() <<
"invalid element type for complex";
66 SignednessSemantics signedness) {
67 if (width > IntegerType::kMaxWidth) {
68 return emitError() <<
"integer bitwidth is limited to "
69 << IntegerType::kMaxWidth <<
" bits";
74 unsigned IntegerType::getWidth()
const {
return getImpl()->width; }
76 IntegerType::SignednessSemantics IntegerType::getSignedness()
const {
77 return getImpl()->signedness;
80 IntegerType IntegerType::scaleElementBitwidth(
unsigned scale) {
83 return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
91 if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
92 Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
94 if (isa<Float16Type, BFloat16Type>())
96 if (isa<Float32Type>())
98 if (isa<Float64Type>())
100 if (isa<Float80Type>())
102 if (isa<Float128Type>())
104 llvm_unreachable(
"unexpected float type");
109 if (isa<Float8E5M2Type>())
110 return APFloat::Float8E5M2();
111 if (isa<Float8E4M3FNType>())
112 return APFloat::Float8E4M3FN();
113 if (isa<Float8E5M2FNUZType>())
114 return APFloat::Float8E5M2FNUZ();
115 if (isa<Float8E4M3FNUZType>())
116 return APFloat::Float8E4M3FNUZ();
117 if (isa<Float8E4M3B11FNUZType>())
118 return APFloat::Float8E4M3B11FNUZ();
119 if (isa<BFloat16Type>())
120 return APFloat::BFloat();
121 if (isa<Float16Type>())
122 return APFloat::IEEEhalf();
123 if (isa<Float32Type>())
124 return APFloat::IEEEsingle();
125 if (isa<Float64Type>())
126 return APFloat::IEEEdouble();
127 if (isa<Float80Type>())
128 return APFloat::x87DoubleExtended();
129 if (isa<Float128Type>())
130 return APFloat::IEEEquad();
131 llvm_unreachable(
"non-floating point type used");
138 if (isF16() || isBF16()) {
151 return APFloat::semanticsPrecision(getFloatSemantics());
158 unsigned FunctionType::getNumInputs()
const {
return getImpl()->numInputs; }
161 return getImpl()->getInputs();
164 unsigned FunctionType::getNumResults()
const {
return getImpl()->numResults; }
167 return getImpl()->getResults();
171 return get(getContext(), inputs, results);
176 FunctionType FunctionType::getWithArgsAndResults(
181 getInputs(), argIndices, argTypes, argStorage);
183 getResults(), resultIndices, resultTypes, resultStorage);
184 return clone(newArgTypes, newResultTypes);
189 FunctionType::getWithoutArgsAndResults(
const BitVector &argIndices,
190 const BitVector &resultIndices) {
193 getInputs(), argIndices, argStorage);
195 getResults(), resultIndices, resultStorage);
196 return clone(newArgTypes, newResultTypes);
205 StringAttr dialect, StringRef typeData) {
207 return emitError() <<
"invalid dialect namespace '" << dialect <<
"'";
214 <<
"`!" << dialect <<
"<\"" << typeData <<
"\">"
215 <<
"` type created with unregistered dialect. If this is "
216 "intended, please call allowUnregisteredDialects() on the "
217 "MLIRContext, or use -allow-unregistered-dialect with "
218 "the MLIR opt tool used";
230 unsigned numScalableDims) {
231 if (!isValidElementType(elementType))
233 <<
"vector elements must be int/index/float type but got "
236 if (any_of(shape, [](int64_t i) {
return i <= 0; }))
238 <<
"vector types must have positive constant sizes but got "
244 VectorType VectorType::scaleElementBitwidth(
unsigned scale) {
248 if (
auto scaledEt = et.scaleElementBitwidth(scale))
249 return VectorType::get(
getShape(), scaledEt, getNumScalableDims());
251 if (
auto scaledEt = et.scaleElementBitwidth(scale))
252 return VectorType::get(
getShape(), scaledEt, getNumScalableDims());
257 Type elementType)
const {
258 return VectorType::get(shape.value_or(
getShape()), elementType,
259 getNumScalableDims());
268 .Case<RankedTensorType, UnrankedTensorType>(
269 [](
auto type) {
return type.getElementType(); });
275 return cast<RankedTensorType>().getShape();
279 Type elementType)
const {
280 if (
auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
282 return RankedTensorType::get(*shape, elementType);
283 return UnrankedTensorType::get(elementType);
286 auto rankedTy = cast<RankedTensorType>();
288 return RankedTensorType::get(rankedTy.getShape(), elementType,
289 rankedTy.getEncoding());
290 return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
291 rankedTy.getEncoding());
299 return emitError() <<
"invalid tensor element type: " << elementType;
308 return type.
isa<ComplexType,
FloatType, IntegerType, OpaqueType, VectorType,
310 !llvm::isa<BuiltinDialect>(type.
getDialect());
321 for (int64_t s : shape)
322 if (s < 0 && !ShapedType::isDynamic(s))
323 return emitError() <<
"invalid tensor dimension size";
347 [](
auto type) {
return type.getElementType(); });
353 return cast<MemRefType>().getShape();
357 Type elementType)
const {
358 if (
auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
360 return UnrankedMemRefType::get(elementType, getMemorySpace());
374 if (
auto rankedMemRefTy = dyn_cast<MemRefType>())
375 return rankedMemRefTy.getMemorySpace();
376 return cast<UnrankedMemRefType>().getMemorySpace();
380 if (
auto rankedMemRefTy = dyn_cast<MemRefType>())
381 return rankedMemRefTy.getMemorySpaceAsInt();
382 return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
397 std::optional<llvm::SmallDenseSet<unsigned>>
400 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
401 llvm::SmallDenseSet<unsigned> unusedDims;
402 unsigned reducedIdx = 0;
403 for (
unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
405 if (reducedIdx < reducedRank &&
406 originalShape[originalIdx] == reducedShape[reducedIdx]) {
411 unusedDims.insert(originalIdx);
414 if (originalShape[originalIdx] != 1)
418 if (reducedIdx != reducedRank)
425 ShapedType candidateReducedType) {
426 if (originalType == candidateReducedType)
429 ShapedType originalShapedType = originalType.cast<ShapedType>();
430 ShapedType candidateReducedShapedType =
431 candidateReducedType.cast<ShapedType>();
436 candidateReducedShapedType.getShape();
437 unsigned originalRank = originalShape.size(),
438 candidateReducedRank = candidateReducedShape.size();
439 if (candidateReducedRank > originalRank)
442 auto optionalUnusedDimsMask =
446 if (!optionalUnusedDimsMask)
449 if (originalShapedType.getElementType() !=
450 candidateReducedShapedType.getElementType())
462 if (memorySpace.
isa<IntegerAttr, StringAttr, DictionaryAttr>())
466 if (!isa<BuiltinDialect>(memorySpace.
getDialect()))
474 if (memorySpace == 0)
477 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
482 if (intMemorySpace && intMemorySpace.getValue() == 0)
492 assert(memorySpace.
isa<IntegerAttr>() &&
493 "Using `getMemorySpaceInteger` with non-Integer attribute");
495 return static_cast<unsigned>(memorySpace.
cast<IntegerAttr>().getInt());
503 MemRefLayoutAttrInterface layout,
513 return Base::get(elementType.
getContext(), shape, elementType, layout,
517 MemRefType MemRefType::getChecked(
519 Type elementType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
529 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
530 elementType, layout, memorySpace);
542 Attribute layout = AffineMapAttr::get(map);
547 return Base::get(elementType.
getContext(), shape, elementType, layout,
562 Attribute layout = AffineMapAttr::get(map);
567 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
568 elementType, layout, memorySpace);
572 AffineMap map,
unsigned memorySpaceInd) {
580 Attribute layout = AffineMapAttr::get(map);
586 return Base::get(elementType.
getContext(), shape, elementType, layout,
593 unsigned memorySpaceInd) {
601 Attribute layout = AffineMapAttr::get(map);
607 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
608 elementType, layout, memorySpace);
613 MemRefLayoutAttrInterface layout,
616 return emitError() <<
"invalid memref element type";
619 for (int64_t s : shape)
620 if (s < 0 && !ShapedType::isDynamic(s))
621 return emitError() <<
"invalid memref size";
623 assert(layout &&
"missing layout specification");
628 return emitError() <<
"unsupported memory space Attribute";
645 return emitError() <<
"invalid memref element type";
648 return emitError() <<
"unsupported memory space Attribute";
660 strides[dim.getPosition()] =
661 strides[dim.getPosition()] + multiplicativeFactor;
663 offset = offset + e * multiplicativeFactor;
688 strides[dim.getPosition()] =
689 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
696 if (bin.getLHS().isSymbolicOrConstant())
697 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
699 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
705 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
707 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
711 llvm_unreachable(
"unexpected binary operation");
728 SmallVectorImpl<AffineExpr> &strides,
730 AffineMap m = t.getLayout().getAffineMap();
738 strides.assign(t.getRank(), zero);
743 if (t.getRank() == 0)
749 assert(
false &&
"unexpected failure: extract strides in canonical layout");
765 for (
auto &stride : strides)
789 if (
auto strided = t.getLayout().dyn_cast<StridedLayoutAttr>()) {
790 llvm::append_range(strides, strided.getStrides());
791 offset = strided.getOffset();
802 offset = cst.getValue();
804 offset = ShapedType::kDynamic;
805 for (
auto e : strideExprs) {
807 strides.push_back(c.getValue());
809 strides.push_back(ShapedType::kDynamic);
814 std::pair<SmallVector<int64_t>, int64_t>
820 assert(
succeeded(status) &&
"Invalid use of check-free getStridesAndOffset");
821 return {strides, offset};
829 ArrayRef<Type> TupleType::getTypes()
const {
return getImpl()->getTypes(); }
836 for (
Type type : getTypes()) {
837 if (
auto nestedTuple = type.dyn_cast<TupleType>())
838 nestedTuple.getFlattenedTypes(types);
840 types.push_back(type);
845 size_t TupleType::size()
const {
return getImpl()->size(); }
857 AffineMap m = t.getLayout().getAffineMap();
870 if (cst.getValue() == 0)
878 if (t.getShape().empty())
886 auto simplifiedLayoutExpr =
888 if (expr != simplifiedLayoutExpr)
901 assert(!exprs.empty() &&
"expected exprs");
903 assert(!maps.empty() &&
"Expected one non-empty map");
904 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
907 bool dynamicPoisonBit =
false;
908 int64_t runningSize = 1;
909 for (
auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
910 int64_t size = std::get<1>(en);
915 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
918 assert(runningSize > 0 &&
"integer overflow in size computation");
920 dynamicPoisonBit =
true;
929 exprs.reserve(sizes.size());
930 for (
auto dim : llvm::seq<unsigned>(0, sizes.size()))
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Attributes are known-constant values of operations.
U dyn_cast_or_null() const
Dialect & getDialect() const
Get the dialect this attribute is registered to.
bool isa() const
Casting utility functions.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
static FloatType getF64(MLIRContext *ctx)
FloatType scaleElementBitwidth(unsigned scale)
Get or create a new FloatType with bitwidth scaled by scale.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
static FloatType getF32(MLIRContext *ctx)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if we allow to create operation for unregistered dialects.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setElementType(Type newElementType)
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setMemorySpace(Attribute newMemorySpace)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Detect if any of the given parameter types has a sub-element handler.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool isStrided(MemRefType t)
Return true if the layout for t is compatible with strided semantics.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This class represents an efficient way to signal success or failure.