10 #include "TypeDetail.h" 20 #include "llvm/ADT/APFloat.h" 21 #include "llvm/ADT/BitVector.h" 22 #include "llvm/ADT/Sequence.h" 23 #include "llvm/ADT/Twine.h" 24 #include "llvm/ADT/TypeSwitch.h" 33 #define GET_TYPEDEF_CLASSES 34 #include "mlir/IR/BuiltinTypes.cpp.inc" 40 void BuiltinDialect::registerTypes() {
42 #define GET_TYPEDEF_LIST 43 #include "mlir/IR/BuiltinTypes.cpp.inc" 55 return emitError() <<
"invalid element type for complex";
66 SignednessSemantics signedness) {
67 if (width > IntegerType::kMaxWidth) {
68 return emitError() <<
"integer bitwidth is limited to " 69 << IntegerType::kMaxWidth <<
" bits";
74 unsigned IntegerType::getWidth()
const {
return getImpl()->width; }
76 IntegerType::SignednessSemantics IntegerType::getSignedness()
const {
77 return getImpl()->signedness;
80 IntegerType IntegerType::scaleElementBitwidth(
unsigned scale) {
83 return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
91 if (isa<Float16Type, BFloat16Type>())
93 if (isa<Float32Type>())
95 if (isa<Float64Type>())
97 if (isa<Float80Type>())
99 if (isa<Float128Type>())
101 llvm_unreachable(
"unexpected float type");
106 if (isa<BFloat16Type>())
107 return APFloat::BFloat();
108 if (isa<Float16Type>())
109 return APFloat::IEEEhalf();
110 if (isa<Float32Type>())
111 return APFloat::IEEEsingle();
112 if (isa<Float64Type>())
113 return APFloat::IEEEdouble();
114 if (isa<Float80Type>())
115 return APFloat::x87DoubleExtended();
116 if (isa<Float128Type>())
117 return APFloat::IEEEquad();
118 llvm_unreachable(
"non-floating point type used");
125 if (isF16() || isBF16()) {
138 return APFloat::semanticsPrecision(getFloatSemantics());
145 unsigned FunctionType::getNumInputs()
const {
return getImpl()->numInputs; }
148 return getImpl()->getInputs();
151 unsigned FunctionType::getNumResults()
const {
return getImpl()->numResults; }
154 return getImpl()->getResults();
158 return get(getContext(), inputs, results);
163 FunctionType FunctionType::getWithArgsAndResults(
168 getInputs(), argIndices, argTypes, argStorage);
170 getResults(), resultIndices, resultTypes, resultStorage);
171 return clone(newArgTypes, newResultTypes);
176 FunctionType::getWithoutArgsAndResults(
const BitVector &argIndices,
177 const BitVector &resultIndices) {
180 getInputs(), argIndices, argStorage);
182 getResults(), resultIndices, resultStorage);
183 return clone(newArgTypes, newResultTypes);
186 void FunctionType::walkImmediateSubElements(
189 for (
Type type : llvm::concat<const Type>(getInputs(), getResults()))
195 unsigned numInputs = getNumInputs();
196 return get(getContext(), replTypes.take_front(numInputs),
197 replTypes.drop_front(numInputs));
206 StringAttr dialect, StringRef typeData) {
208 return emitError() <<
"invalid dialect namespace '" << dialect <<
"'";
215 <<
"`!" << dialect <<
"<\"" << typeData <<
"\">" 216 <<
"` type created with unregistered dialect. If this is " 217 "intended, please call allowUnregisteredDialects() on the " 218 "MLIRContext, or use -allow-unregistered-dialect with " 219 "the MLIR opt tool used";
231 unsigned numScalableDims) {
232 if (!isValidElementType(elementType))
234 <<
"vector elements must be int/index/float type but got " 237 if (any_of(shape, [](int64_t i) {
return i <= 0; }))
239 <<
"vector types must have positive constant sizes but got " 245 VectorType VectorType::scaleElementBitwidth(
unsigned scale) {
249 if (
auto scaledEt = et.scaleElementBitwidth(scale))
250 return VectorType::get(
getShape(), scaledEt, getNumScalableDims());
252 if (
auto scaledEt = et.scaleElementBitwidth(scale))
253 return VectorType::get(
getShape(), scaledEt, getNumScalableDims());
257 void VectorType::walkImmediateSubElements(
265 return get(
getShape(), replTypes.front(), getNumScalableDims());
269 Type elementType)
const {
270 return VectorType::get(shape.value_or(
getShape()), elementType,
271 getNumScalableDims());
280 .Case<RankedTensorType, UnrankedTensorType>(
281 [](
auto type) {
return type.getElementType(); });
287 return cast<RankedTensorType>().
getShape();
291 Type elementType)
const {
292 if (
auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
294 return RankedTensorType::get(*shape, elementType);
295 return UnrankedTensorType::get(elementType);
298 auto rankedTy = cast<RankedTensorType>();
300 return RankedTensorType::get(rankedTy.getShape(), elementType,
301 rankedTy.getEncoding());
302 return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
303 rankedTy.getEncoding());
311 return emitError() <<
"invalid tensor element type: " << elementType;
320 return type.
isa<ComplexType,
FloatType, IntegerType, OpaqueType, VectorType,
322 !llvm::isa<BuiltinDialect>(type.
getDialect());
333 for (int64_t s : shape)
335 return emitError() <<
"invalid tensor dimension size";
342 void RankedTensorType::walkImmediateSubElements(
347 walkAttrsFn(encoding);
350 Type RankedTensorType::replaceImmediateSubElements(
352 return get(
getShape(), replTypes.front(),
353 replAttrs.empty() ?
Attribute() : replAttrs.back());
366 void UnrankedTensorType::walkImmediateSubElements(
372 Type UnrankedTensorType::replaceImmediateSubElements(
374 return get(replTypes.front());
384 [](
auto type) {
return type.getElementType(); });
390 return cast<MemRefType>().
getShape();
394 Type elementType)
const {
395 if (
auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
397 return UnrankedMemRefType::get(elementType, getMemorySpace());
411 if (
auto rankedMemRefTy = dyn_cast<MemRefType>())
412 return rankedMemRefTy.getMemorySpace();
413 return cast<UnrankedMemRefType>().getMemorySpace();
417 if (
auto rankedMemRefTy = dyn_cast<MemRefType>())
418 return rankedMemRefTy.getMemorySpaceAsInt();
437 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
438 llvm::SmallDenseSet<unsigned> unusedDims;
439 unsigned reducedIdx = 0;
440 for (
unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
442 if (reducedIdx < reducedRank &&
443 originalShape[originalIdx] == reducedShape[reducedIdx]) {
448 unusedDims.insert(originalIdx);
451 if (originalShape[originalIdx] != 1)
455 if (reducedIdx != reducedRank)
462 ShapedType candidateReducedType) {
463 if (originalType == candidateReducedType)
466 ShapedType originalShapedType = originalType.cast<ShapedType>();
467 ShapedType candidateReducedShapedType =
468 candidateReducedType.cast<ShapedType>();
473 candidateReducedShapedType.getShape();
474 unsigned originalRank = originalShape.size(),
475 candidateReducedRank = candidateReducedShape.size();
476 if (candidateReducedRank > originalRank)
479 auto optionalUnusedDimsMask =
483 if (!optionalUnusedDimsMask)
486 if (originalShapedType.getElementType() !=
487 candidateReducedShapedType.getElementType())
499 if (memorySpace.
isa<IntegerAttr, StringAttr, DictionaryAttr>())
503 if (!isa<BuiltinDialect>(memorySpace.
getDialect()))
511 if (memorySpace == 0)
514 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
519 if (intMemorySpace && intMemorySpace.getValue() == 0)
529 assert(memorySpace.
isa<IntegerAttr>() &&
530 "Using `getMemorySpaceInteger` with non-Integer attribute");
532 return static_cast<unsigned>(memorySpace.
cast<IntegerAttr>().getInt());
547 MemRefLayoutAttrInterface layout,
557 return Base::get(elementType.
getContext(), shape, elementType, layout,
561 MemRefType MemRefType::getChecked(
563 Type elementType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
573 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
574 elementType, layout, memorySpace);
586 Attribute layout = AffineMapAttr::get(map);
591 return Base::get(elementType.
getContext(), shape, elementType, layout,
606 Attribute layout = AffineMapAttr::get(map);
611 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
612 elementType, layout, memorySpace);
616 AffineMap map,
unsigned memorySpaceInd) {
624 Attribute layout = AffineMapAttr::get(map);
630 return Base::get(elementType.
getContext(), shape, elementType, layout,
637 unsigned memorySpaceInd) {
645 Attribute layout = AffineMapAttr::get(map);
651 return Base::getChecked(emitErrorFn, elementType.
getContext(), shape,
652 elementType, layout, memorySpace);
657 MemRefLayoutAttrInterface layout,
660 return emitError() <<
"invalid memref element type";
663 for (int64_t s : shape)
665 return emitError() <<
"invalid memref size";
667 assert(layout &&
"missing layout specification");
672 return emitError() <<
"unsupported memory space Attribute";
677 void MemRefType::walkImmediateSubElements(
681 if (!getLayout().isIdentity())
682 walkAttrsFn(getLayout());
683 walkAttrsFn(getMemorySpace());
688 bool hasLayout = replAttrs.size() > 1;
689 return get(
getShape(), replTypes[0],
690 hasLayout ? replAttrs[0].dyn_cast<MemRefLayoutAttrInterface>()
691 : MemRefLayoutAttrInterface(),
692 hasLayout ? replAttrs[1] : replAttrs[0]);
707 return emitError() <<
"invalid memref element type";
710 return emitError() <<
"unsupported memory space Attribute";
722 strides[dim.getPosition()] =
723 strides[dim.getPosition()] + multiplicativeFactor;
725 offset = offset + e * multiplicativeFactor;
750 strides[dim.getPosition()] =
751 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
758 if (bin.getLHS().isSymbolicOrConstant())
759 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
761 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
767 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
769 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
773 llvm_unreachable(
"unexpected binary operation");
779 AffineMap m = t.getLayout().getAffineMap();
787 strides.assign(t.getRank(), zero);
792 if (t.getRank() == 0)
798 assert(
false &&
"unexpected failure: extract strides in canonical layout");
814 for (
auto &stride : strides)
842 offset = cst.getValue();
844 offset = ShapedType::kDynamicStrideOrOffset;
845 for (
auto e : strideExprs) {
847 strides.push_back(c.getValue());
849 strides.push_back(ShapedType::kDynamicStrideOrOffset);
854 void UnrankedMemRefType::walkImmediateSubElements(
858 walkAttrsFn(getMemorySpace());
861 Type UnrankedMemRefType::replaceImmediateSubElements(
863 return get(replTypes.front(), replAttrs.front());
871 ArrayRef<Type> TupleType::getTypes()
const {
return getImpl()->getTypes(); }
878 for (
Type type : getTypes()) {
879 if (
auto nestedTuple = type.dyn_cast<TupleType>())
880 nestedTuple.getFlattenedTypes(types);
882 types.push_back(type);
887 size_t TupleType::size()
const {
return getImpl()->size(); }
889 void TupleType::walkImmediateSubElements(
892 for (
Type type : getTypes())
898 return get(getContext(), replTypes);
909 unsigned nSymbols = 0;
913 if (offset != MemRefType::getDynamicStrideOrOffset()) {
924 auto dim = en.index();
925 auto stride = en.value();
926 assert(stride != 0 &&
"Invalid stride specification");
930 if (stride != MemRefType::getDynamicStrideOrOffset())
935 expr = expr + d * mult;
947 AffineMap m = t.getLayout().getAffineMap();
960 if (cst.getValue() == 0)
968 if (t.getShape().empty())
976 auto simplifiedLayoutExpr =
978 if (expr != simplifiedLayoutExpr)
991 assert(!exprs.empty() &&
"expected exprs");
993 assert(!maps.empty() &&
"Expected one non-empty map");
994 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
997 bool dynamicPoisonBit =
false;
998 int64_t runningSize = 1;
999 for (
auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
1000 int64_t size = std::get<1>(en);
1005 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
1007 runningSize *= size;
1008 assert(runningSize > 0 &&
"integer overflow in size computation");
1010 dynamicPoisonBit =
true;
1019 auto val = ShapedType::kDynamicStrideOrOffset;
1028 exprs.reserve(sizes.size());
1029 for (
auto dim : llvm::seq<unsigned>(0, sizes.size()))
1058 assert(
false &&
"expected strided memref");
1069 AffineExpr contiguousRowMajor = canonical + offset;
1072 return MemRefType::get(shape, elementType, contiguousRowMajorMap);
1079 if (!memrefType.hasStaticShape())
1083 memrefType.getContext(), memrefType.getShape(),
1084 memrefType.getElementType(), offset);
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Affine binary operation expression.
Include the generated interface declarations.
Dialect & getDialect() const
Get the dialect this type is registered to.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
RHS of mod is always a constant or a symbolic expression with a positive value.
U dyn_cast_or_null() const
MemRefType eraseStridedLayout(MemRefType t)
Return a version of t with a layout that has all dynamic offset and strides.
unsigned getNumSymbols() const
unsigned getNumDims() const
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
This class represents a diagnostic that is inflight and set to be reported.
AffineMap getStridedLinearLayoutMap(MemRefType t)
Return the layout map in strided linear layout AffineMap form.
llvm::Optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
static AffineExpr getOffsetExpr(MemRefType memrefType)
Return the AffineExpr representation of the offset, assuming memRefType is a strided memref...
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
static FloatType getF32(MLIRContext *ctx)
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
TensorType cloneWith(Optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
An integer constant appearing in affine expression.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getResult(unsigned idx) const
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
bool isStrided(MemRefType t)
Return true if the layout for t is compatible with strided semantics.
unsigned getWidth()
Return the bitwidth of this float type.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Attributes are known-constant values of operations.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
BaseMemRefType cloneWith(Optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Base type for affine expression.
MLIRContext * getContext() const
RHS of mul is always a constant or a symbolic expression.
This class provides an abstraction over the various different ranges of value types.
unsigned getNumResults() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued...
RHS of floordiv is always a constant or a symbolic expression.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
RHS of ceildiv is always a constant or a symbolic expression.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context, ArrayRef< int64_t > shape, Type elementType, AffineExpr offset)
Helper to construct a contiguous MemRefType of shape, elementType and offset AffineExpr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class provides a shared interface for ranked and unranked memref types.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Type getElementType() const
Returns the element type of this memref type.
Builder & setElementType(Type newElementType)
This is a builder type that keeps local references to arguments.
A dimensional identifier appearing in an affine expression.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
MLIRContext is the top-level object for a collection of MLIR operations.
Type getElementType() const
Returns the element type of this tensor type.
FloatType scaleElementBitwidth(unsigned scale)
Get or create a new FloatType with bitwidth scaled by scale.
AffineMap makeStridedLinearLayoutMap(ArrayRef< int64_t > strides, int64_t offset, MLIRContext *context)
Given a list of strides (in which MemRefType::getDynamicStrideOrOffset() represents a dynamic value)...
bool allowsUnregisteredDialects()
Return true if we allow to create operation for unregistered dialects.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Builder & setMemorySpace(Attribute newMemorySpace)
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
Builder & setShape(ArrayRef< int64_t > newShape)
static FloatType getF64(MLIRContext *ctx)
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType)
Helper determining if a memref is static-shape and contiguous-row-major layout, while still allowing ...
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.