20 #include "llvm/ADT/APSInt.h" 21 #include "llvm/ADT/Sequence.h" 22 #include "llvm/Support/Endian.h" 31 #define GET_ATTRDEF_CLASSES 32 #include "mlir/IR/BuiltinAttributes.cpp.inc" 38 void BuiltinDialect::registerAttributes() {
41 DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
42 IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
43 SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
50 void ArrayAttr::walkImmediateSubElements(
57 SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute(
58 ArrayRef<std::pair<size_t, Attribute>> replacements)
const {
59 std::vector<Attribute> vector = getValue().vec();
60 for (
auto &it : replacements) {
61 vector[it.first] = it.second;
63 return get(getContext(), vector);
74 template <
bool inPlace>
78 switch (value.size()) {
87 storage.assign({value[0]});
90 bool isSorted = value[0] < value[1];
93 std::swap(storage[0], storage[1]);
94 }
else if (isSorted) {
95 storage.assign({value[0], value[1]});
97 storage.assign({value[1], value[0]});
103 storage.assign(value.begin(), value.end());
105 bool isSorted = llvm::is_sorted(value);
108 llvm::array_pod_sort(storage.begin(), storage.end());
116 static Optional<NamedAttribute>
118 const Optional<NamedAttribute> none{llvm::None};
119 if (value.size() < 2)
122 if (value.size() == 2)
123 return value[0].getName() == value[1].getName() ? value[0] : none;
125 const auto *it = std::adjacent_find(value.begin(), value.end(),
127 return l.
getName() == r.getName();
129 return it != value.end() ? *it : none;
136 "DictionaryAttr element names must be unique");
143 "DictionaryAttr element names must be unique");
147 Optional<NamedAttribute>
155 DictionaryAttr DictionaryAttr::get(
MLIRContext *context,
158 return DictionaryAttr::getEmpty(context);
162 if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
165 "DictionaryAttr element names must be unique");
166 return Base::get(context, value);
170 DictionaryAttr DictionaryAttr::getWithSorted(
MLIRContext *context,
173 return DictionaryAttr::getEmpty(context);
175 assert(llvm::is_sorted(
177 "expected attribute values to be sorted");
179 "DictionaryAttr element names must be unique");
180 return Base::get(context, value);
184 Attribute DictionaryAttr::get(StringRef name)
const {
186 return it.second ? it.first->getValue() :
Attribute();
188 Attribute DictionaryAttr::get(StringAttr name)
const {
190 return it.second ? it.first->getValue() :
Attribute();
194 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name)
const {
196 return it.second ? *it.first : Optional<NamedAttribute>();
198 Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name)
const {
200 return it.second ? *it.first : Optional<NamedAttribute>();
211 DictionaryAttr::iterator DictionaryAttr::begin()
const {
212 return getValue().begin();
214 DictionaryAttr::iterator DictionaryAttr::end()
const {
215 return getValue().end();
217 size_t DictionaryAttr::size()
const {
return getValue().size(); }
219 DictionaryAttr DictionaryAttr::getEmptyUnchecked(
MLIRContext *context) {
223 void DictionaryAttr::walkImmediateSubElements(
227 walkAttrsFn(attr.getValue());
230 SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute(
231 ArrayRef<std::pair<size_t, Attribute>> replacements)
const {
232 std::vector<NamedAttribute> vec = getValue().vec();
233 for (
auto &it : replacements)
234 vec[it.first].setValue(it.second);
238 return getWithSorted(getContext(), vec);
245 StringAttr StringAttr::getEmptyStringAttrUnchecked(
MLIRContext *context) {
246 return Base::get(context,
"", NoneType::get(context));
250 StringAttr StringAttr::get(
MLIRContext *context,
const Twine &twine) {
252 if (twine.isTriviallyEmpty())
255 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
259 StringAttr StringAttr::get(
const Twine &twine,
Type type) {
261 return Base::get(type.
getContext(), twine.toStringRef(tempStr), type);
264 StringRef StringAttr::getValue()
const {
return getImpl()->value; }
266 Dialect *StringAttr::getReferencedDialect()
const {
267 return getImpl()->referencedDialect;
274 double FloatAttr::getValueAsDouble()
const {
275 return getValueAsDouble(getValue());
277 double FloatAttr::getValueAsDouble(APFloat value) {
278 if (&value.getSemantics() != &APFloat::IEEEdouble()) {
279 bool losesInfo =
false;
280 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
283 return value.convertToDouble();
287 Type type, APFloat value) {
290 return emitError() <<
"expected floating point type";
295 <<
"FloatAttr type doesn't match the type implied by its value";
304 SymbolRefAttr SymbolRefAttr::get(
MLIRContext *ctx, StringRef value,
306 return get(StringAttr::get(ctx, value), nestedRefs);
310 return get(ctx,
value, {}).cast<FlatSymbolRefAttr>();
314 return get(
value, {}).cast<FlatSymbolRefAttr>();
320 assert(symName &&
"value does not have a valid symbol name");
321 return SymbolRefAttr::get(symName);
324 StringAttr SymbolRefAttr::getLeafReference()
const {
326 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
333 int64_t IntegerAttr::getInt()
const {
334 assert((getType().isIndex() || getType().isSignlessInteger()) &&
335 "must be signless integer");
336 return getValue().getSExtValue();
339 int64_t IntegerAttr::getSInt()
const {
340 assert(getType().isSignedInteger() &&
"must be signed integer");
341 return getValue().getSExtValue();
344 uint64_t IntegerAttr::getUInt()
const {
345 assert(getType().isUnsignedInteger() &&
"must be unsigned integer");
346 return getValue().getZExtValue();
351 APSInt IntegerAttr::getAPSInt()
const {
352 assert(!getType().isSignlessInteger() &&
353 "Signless integers don't carry a sign for APSInt");
354 return APSInt(getValue(), getType().isUnsignedInteger());
358 Type type, APInt value) {
359 if (IntegerType integerType = type.
dyn_cast<IntegerType>()) {
360 if (integerType.getWidth() != value.getBitWidth())
361 return emitError() <<
"integer type bit width (" << integerType.getWidth()
362 <<
") doesn't match value bit width (" 363 << value.getBitWidth() <<
")";
366 if (type.
isa<IndexType>())
368 return emitError() <<
"expected integer or index type";
371 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type,
bool value) {
372 auto attr = Base::get(type.
getContext(), type, APInt(1, value));
381 auto *storage =
reinterpret_cast<IntegerAttrStorage *
>(
impl);
382 return storage->value.getBoolValue();
386 IntegerAttr intAttr = attr.
dyn_cast<IntegerAttr>();
387 return intAttr && intAttr.getType().isSignlessInteger(1);
395 StringAttr dialect, StringRef attrData,
398 return emitError() <<
"invalid dialect namespace '" << dialect <<
"'";
405 <<
"#" << dialect <<
"<\"" << attrData <<
"\"> : " << type
406 <<
" attribute created with unregistered dialect. If this is " 407 "intended, please call allowUnregisteredDialects() on the " 408 "MLIRContext, or use -allow-unregistered-dialect with " 409 "the MLIR opt tool used";
422 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
429 static void setBit(
char *rawData,
size_t bitPos,
bool value) {
431 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT));
433 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT));
437 static bool getBit(
const char *rawData,
size_t bitPos) {
438 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0;
445 assert(llvm::support::endian::system_endianness() ==
446 llvm::support::endianness::big);
447 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
452 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
453 std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
454 numFilledWords, result);
458 size_t lastWordPos = numFilledWords;
460 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
461 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
462 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
466 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
467 valueLE.begin(), result + lastWordPos,
468 (numBytes - lastWordPos) * CHAR_BIT, 1);
475 assert(llvm::support::endian::system_endianness() ==
476 llvm::support::endianness::big);
477 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
484 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
486 inArray, numFilledWords,
487 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
492 size_t lastWordPos = numFilledWords;
494 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
495 inArray + lastWordPos, inArrayLE.begin(),
496 (numBytes - lastWordPos) * CHAR_BIT, 1);
500 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
502 const_cast<char *
>(
reinterpret_cast<const char *
>(result.getRawData())) +
504 APInt::APINT_BITS_PER_WORD, 1);
508 static void writeBits(
char *rawData,
size_t bitPos, APInt value) {
509 size_t bitWidth = value.getBitWidth();
513 return setBit(rawData, bitPos, value.isOneValue());
516 assert((bitPos % CHAR_BIT) == 0 &&
"expected bitPos to be 8-bit aligned");
517 if (llvm::support::endian::system_endianness() ==
518 llvm::support::endianness::big) {
525 rawData + (bitPos / CHAR_BIT));
527 std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
528 llvm::divideCeil(bitWidth, CHAR_BIT),
529 rawData + (bitPos / CHAR_BIT));
535 static APInt
readBits(
const char *rawData,
size_t bitPos,
size_t bitWidth) {
538 return APInt(1,
getBit(rawData, bitPos) ? 1 : 0);
541 assert((bitPos % CHAR_BIT) == 0 &&
"expected bitPos to be 8-bit aligned");
542 APInt result(bitWidth, 0);
543 if (llvm::support::endian::system_endianness() ==
544 llvm::support::endianness::big) {
551 llvm::divideCeil(bitWidth, CHAR_BIT), result);
553 std::copy_n(rawData + (bitPos / CHAR_BIT),
554 llvm::divideCeil(bitWidth, CHAR_BIT),
556 reinterpret_cast<const char *>(result.getRawData())));
563 template <
typename Values>
565 return (values.size() == 1) ||
566 (type.getNumElements() ==
static_cast<int64_t
>(values.size()));
576 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
578 : llvm::indexed_accessor_iterator<AttributeElementIterator,
const void *,
584 Type eltTy = owner.getElementType();
585 if (
auto intEltTy = eltTy.dyn_cast<IntegerType>())
587 if (eltTy.isa<IndexType>())
589 if (
auto floatEltTy = eltTy.dyn_cast<
FloatType>()) {
592 return FloatAttr::get(eltTy, *floatIt);
594 if (
auto complexTy = eltTy.dyn_cast<ComplexType>()) {
595 auto complexEltTy = complexTy.getElementType();
597 if (complexEltTy.isa<IntegerType>()) {
598 auto value = *complexIntIt;
599 auto real = IntegerAttr::get(complexEltTy, value.real());
600 auto imag = IntegerAttr::get(complexEltTy, value.imag());
601 return ArrayAttr::get(complexTy.getContext(),
607 auto value = *complexFloatIt;
608 auto real = FloatAttr::get(complexEltTy, value.real());
609 auto imag = FloatAttr::get(complexEltTy, value.imag());
610 return ArrayAttr::get(complexTy.getContext(),
613 if (owner.isa<DenseStringElementsAttr>()) {
615 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
617 llvm_unreachable(
"unexpected element type");
623 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
629 return getBit(getData(), getDataIndex());
635 DenseElementsAttr::IntElementIterator::IntElementIterator(
650 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
653 std::complex<APInt>, std::complex<APInt>,
654 std::complex<APInt>>(
663 size_t offset = getDataIndex() * storageWidth * 2;
664 return {
readBits(getData(), offset, bitWidth),
665 readBits(getData(), offset + storageWidth, bitWidth)};
675 using KeyTy = std::tuple<ShapedType, DenseArrayBaseAttr::EltType,
678 DenseArrayBaseAttr::EltType
eltType,
683 return (
getType() == std::get<0>(tblgenKey)) &&
684 (
eltType == std::get<1>(tblgenKey)) &&
685 (
elements == std::get<2>(tblgenKey));
689 return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey),
690 std::get<2>(tblgenKey));
695 auto type = std::get<0>(tblgenKey);
696 auto eltType = std::get<1>(tblgenKey);
697 auto elements = std::get<2>(tblgenKey);
698 if (!elements.empty()) {
699 char *alloc =
static_cast<char *
>(
700 allocator.
allocate(elements.size(),
alignof(uint64_t)));
701 std::uninitialized_copy(elements.begin(), elements.end(), alloc);
713 return getImpl()->eltType;
717 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>)
const {
718 return cast<DenseI8ArrayAttr>().asArrayRef().begin();
721 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>)
const {
722 return cast<DenseI16ArrayAttr>().asArrayRef().begin();
725 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>)
const {
726 return cast<DenseI32ArrayAttr>().asArrayRef().begin();
729 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>)
const {
730 return cast<DenseI64ArrayAttr>().asArrayRef().begin();
732 const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>)
const {
733 return cast<DenseF32ArrayAttr>().asArrayRef().begin();
736 DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>)
const {
737 return cast<DenseF64ArrayAttr>().asArrayRef().begin();
744 void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os)
const {
746 case DenseArrayBaseAttr::EltType::I8:
747 this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
749 case DenseArrayBaseAttr::EltType::I16:
750 this->cast<DenseI16ArrayAttr>().printWithoutBraces(os);
752 case DenseArrayBaseAttr::EltType::I32:
753 this->cast<DenseI32ArrayAttr>().printWithoutBraces(os);
755 case DenseArrayBaseAttr::EltType::I64:
756 this->cast<DenseI64ArrayAttr>().printWithoutBraces(os);
758 case DenseArrayBaseAttr::EltType::F32:
759 this->cast<DenseF32ArrayAttr>().printWithoutBraces(os);
761 case DenseArrayBaseAttr::EltType::F64:
762 this->cast<DenseF64ArrayAttr>().printWithoutBraces(os);
765 llvm_unreachable(
"<unknown DenseArrayBaseAttr>");
770 printWithoutBraces(os);
774 template <
typename T>
779 template <
typename T>
782 llvm::interleaveComma(values, os);
789 llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
792 template <
typename T>
795 printWithoutBraces(os);
801 template <
typename T>
809 if (
parser.parseFloat(doubleVal))
817 return parser.parseFloat(value);
821 template <
typename T>
827 if (parseDenseArrayAttrElt(parser, value))
829 data.push_back(value);
837 template <
typename T>
841 Attribute result = parseWithoutBraces(parser, odsType);
848 template <
typename T>
851 assert((raw.size() %
sizeof(T)) == 0);
852 return ArrayRef<T>(
reinterpret_cast<const T *
>(raw.data()),
853 raw.size() /
sizeof(T));
858 template <
typename T>
859 struct denseArrayAttrEltTypeBuilder;
861 struct denseArrayAttrEltTypeBuilder<int8_t> {
862 constexpr
static auto eltType = DenseArrayBaseAttr::EltType::I8;
863 static ShapedType getShapedType(
MLIRContext *context, int64_t shape) {
864 return VectorType::get(shape, IntegerType::get(context, 8));
868 struct denseArrayAttrEltTypeBuilder<int16_t> {
869 constexpr
static auto eltType = DenseArrayBaseAttr::EltType::I16;
870 static ShapedType getShapedType(
MLIRContext *context, int64_t shape) {
871 return VectorType::get(shape, IntegerType::get(context, 16));
875 struct denseArrayAttrEltTypeBuilder<int32_t> {
876 constexpr
static auto eltType = DenseArrayBaseAttr::EltType::I32;
877 static ShapedType getShapedType(
MLIRContext *context, int64_t shape) {
878 return VectorType::get(shape, IntegerType::get(context, 32));
882 struct denseArrayAttrEltTypeBuilder<int64_t> {
883 constexpr
static auto eltType = DenseArrayBaseAttr::EltType::I64;
884 static ShapedType getShapedType(
MLIRContext *context, int64_t shape) {
885 return VectorType::get(shape, IntegerType::get(context, 64));
889 struct denseArrayAttrEltTypeBuilder<float> {
890 constexpr
static auto eltType = DenseArrayBaseAttr::EltType::F32;
891 static ShapedType getShapedType(
MLIRContext *context, int64_t shape) {
892 return VectorType::get(shape, Float32Type::get(context));
896 struct denseArrayAttrEltTypeBuilder<double> {
897 constexpr
static auto eltType = DenseArrayBaseAttr::EltType::F64;
898 static ShapedType getShapedType(
MLIRContext *context, int64_t shape) {
899 return VectorType::get(shape, Float64Type::get(context));
905 template <
typename T>
909 denseArrayAttrEltTypeBuilder<T>::getShapedType(context, content.size());
910 auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
911 auto rawArray =
ArrayRef<char>(
reinterpret_cast<const char *
>(content.data()),
912 content.size() *
sizeof(T));
913 return Base::get(context, shapedType, eltType, rawArray)
914 .template cast<DenseArrayAttr<T>>();
917 template <
typename T>
921 denseArrayAttrEltTypeBuilder<T>::eltType;
951 auto eltType = type.getElementType();
952 if (!type.getElementType().isIntOrIndexOrFloat()) {
954 stringValues.reserve(values.size());
956 assert(attr.
isa<StringAttr>() &&
957 "expected string value for non integer/index/float element");
958 stringValues.push_back(attr.
cast<StringAttr>().getValue());
960 return get(type, stringValues);
969 llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
971 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
973 "expected attribute value to have element type");
975 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
976 else if (
eltType.isa<IntegerType, IndexType>())
977 intVal = values[i].cast<IntegerAttr>().getValue();
979 llvm_unreachable(
"unexpected element type");
981 assert(intVal.getBitWidth() == bitWidth &&
982 "expected value to have same bitwidth as element type");
983 writeBits(data.data(), i * storageBitWidth, intVal);
987 if (values.size() == 1 && values[0].getType().isInteger(1))
988 data[0] = data[0] ? -1 : 0;
990 return DenseIntOrFPElementsAttr::getRaw(type, data);
996 assert(type.getElementType().isInteger(1));
998 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
1000 if (!values.empty()) {
1001 bool isSplat =
true;
1002 bool firstValue = values[0];
1003 for (
int i = 0, e = values.size(); i != e; ++i) {
1004 isSplat &= values[i] == firstValue;
1005 setBit(buff.data(), i, values[i]);
1011 buff[0] = values[0] ? -1 : 0;
1015 return DenseIntOrFPElementsAttr::getRaw(type, buff);
1020 assert(!type.getElementType().isIntOrFloat());
1021 return DenseStringElementsAttr::get(type, values);
1029 assert(type.getElementType().isIntOrIndex());
1032 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1035 ArrayRef<std::complex<APInt>> values) {
1037 assert(complex.getElementType().isa<IntegerType>());
1040 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
1042 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
1050 assert(type.getElementType().isa<
FloatType>());
1053 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1057 ArrayRef<std::complex<APFloat>> values) {
1059 assert(complex.getElementType().isa<
FloatType>());
1064 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
1072 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
1078 bool &detectedSplat) {
1080 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
1081 int64_t numElements = type.getNumElements();
1084 detectedSplat = numElements == 1;
1087 if (storageWidth == 1) {
1090 if (rawBuffer.size() == 1) {
1091 auto rawByte =
static_cast<uint8_t
>(rawBuffer[0]);
1092 if (rawByte == 0 || rawByte == 0xff) {
1093 detectedSplat =
true;
1099 return rawBufferWidth == llvm::alignTo<8>(numElements);
1104 if (rawBufferWidth == storageWidth) {
1105 detectedSplat =
true;
1110 return rawBufferWidth == storageWidth * numElements;
1120 static_cast<size_t>(dataEltSize * CHAR_BIT))
1129 auto intType = type.
dyn_cast<IntegerType>();
1134 if (intType.isSignless())
1136 return intType.isSigned() ? isSigned : !isSigned;
1142 int64_t dataEltSize,
1143 bool isInt,
bool isSigned) {
1144 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
1149 int64_t dataEltSize,
1152 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
1157 bool isSigned)
const {
1161 bool isSigned)
const {
1181 "expected complex integral type");
1182 return {
getType(), ComplexIntElementIterator(*
this, 0),
1186 -> ComplexIntElementIterator {
1188 "expected complex integral type");
1189 return ComplexIntElementIterator(*
this, 0);
1193 "expected complex integral type");
1203 return {
getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
1204 FloatElementIterator(elementSemantics, raw_int_end())};
1218 assert(eltTy.
isa<
FloatType>() &&
"expected complex float type");
1219 const auto &semantics = eltTy.
cast<
FloatType>().getFloatSemantics();
1221 {semantics, {*
this, 0}},
1227 assert(eltTy.
isa<
FloatType>() &&
"expected complex float type");
1228 return {eltTy.
cast<
FloatType>().getFloatSemantics(), {*
this, 0}};
1233 assert(eltTy.
isa<
FloatType>() &&
"expected complex float type");
1251 ShapedType curType =
getType();
1252 if (curType == newType)
1255 assert(newType.getElementType() == curType.getElementType() &&
1256 "expected the same element type");
1257 assert(newType.getNumElements() == curType.getNumElements() &&
1258 "expected the same number of elements");
1259 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1263 assert(isSplat() &&
"expected a splat type");
1265 ShapedType curType =
getType();
1266 if (curType == newType)
1269 assert(newType.getElementType() == curType.getElementType() &&
1270 "expected the same element type");
1271 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1279 ShapedType curType =
getType();
1280 Type curElType = curType.getElementType();
1281 if (curElType == newElType)
1286 "expected element types with the same bitwidth");
1287 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
1294 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1299 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1307 return getType().getElementType();
1311 return getType().getNumElements();
1319 template <
typename APRangeT>
1321 APRangeT &&values) {
1322 size_t numValues = llvm::size(values);
1323 data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT));
1325 for (
auto it = values.begin(), e = values.end(); it != e;
1326 ++it, offset += storageWidth) {
1327 assert((*it).getBitWidth() <= storageWidth);
1332 if (numValues == 1 && (*values.begin()).
getBitWidth() == 1)
1333 data[0] = data[0] ? -1 : 0;
1340 size_t storageWidth,
1342 std::vector<char> data;
1343 auto unwrapFloat = [](
const APFloat &val) {
return val.bitcastToAPInt(); };
1345 return DenseIntOrFPElementsAttr::getRaw(type, data);
1352 size_t storageWidth,
1354 std::vector<char> data;
1356 return DenseIntOrFPElementsAttr::getRaw(type, data);
1361 assert((type.isa<RankedTensorType, VectorType>()) &&
1362 "type must be ranked tensor or vector");
1363 assert(type.hasStaticShape() &&
"type must have static shape");
1364 bool isSplat =
false;
1365 bool isValid = isValidRawBuffer(type, data, isSplat);
1368 return Base::get(type.getContext(), type, data, isSplat);
1376 int64_t dataEltSize,
1381 dataEltSize / 2, isInt, isSigned));
1383 int64_t numElements = data.size() / dataEltSize;
1385 assert(numElements == 1 || numElements == type.getNumElements());
1386 return getRaw(type, data);
1393 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type,
ArrayRef<char> data,
1394 int64_t dataEltSize,
bool isInt,
1399 int64_t numElements = data.size() / dataEltSize;
1400 assert(numElements == 1 || numElements == type.getNumElements());
1402 return getRaw(type, data);
1405 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1406 const char *inRawData,
char *outRawData,
size_t elementBitWidth,
1407 size_t numElements) {
1408 using llvm::support::ulittle16_t;
1409 using llvm::support::ulittle32_t;
1410 using llvm::support::ulittle64_t;
1412 assert(llvm::support::endian::system_endianness() ==
1413 llvm::support::endianness::big);
1417 switch (elementBitWidth) {
1419 const ulittle16_t *inRawDataPos =
1420 reinterpret_cast<const ulittle16_t *
>(inRawData);
1421 uint16_t *outDataPos =
reinterpret_cast<uint16_t *
>(outRawData);
1422 std::copy_n(inRawDataPos, numElements, outDataPos);
1426 const ulittle32_t *inRawDataPos =
1427 reinterpret_cast<const ulittle32_t *
>(inRawData);
1428 uint32_t *outDataPos =
reinterpret_cast<uint32_t *
>(outRawData);
1429 std::copy_n(inRawDataPos, numElements, outDataPos);
1433 const ulittle64_t *inRawDataPos =
1434 reinterpret_cast<const ulittle64_t *
>(inRawData);
1435 uint64_t *outDataPos =
reinterpret_cast<uint64_t *
>(outRawData);
1436 std::copy_n(inRawDataPos, numElements, outDataPos);
1440 size_t nBytes = elementBitWidth / CHAR_BIT;
1441 for (
size_t i = 0; i < nBytes; i++)
1442 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1448 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1451 size_t numElements = type.getNumElements();
1452 Type elementType = type.getElementType();
1453 if (ComplexType complexTy = elementType.
dyn_cast<ComplexType>()) {
1454 elementType = complexTy.getElementType();
1455 numElements = numElements * 2;
1458 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1459 inRawData.size() <= outRawData.size());
1460 if (elementBitWidth <= CHAR_BIT)
1461 std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size());
1463 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1464 elementBitWidth, numElements);
1471 template <
typename Fn,
typename Attr>
1473 Type newElementType,
1478 ShapedType newArrayType;
1479 if (inType.isa<RankedTensorType>())
1480 newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1481 else if (inType.isa<UnrankedTensorType>())
1482 newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1483 else if (
auto vType = inType.dyn_cast<VectorType>())
1484 newArrayType = VectorType::get(vType.getShape(), newElementType,
1485 vType.getNumScalableDims());
1487 assert(newArrayType &&
"Unhandled tensor type");
1489 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1490 data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT));
1493 auto processElt = [&](decltype(*attr.begin()) value,
size_t index) {
1494 auto newInt = mapping(value);
1495 assert(newInt.getBitWidth() == bitWidth);
1496 writeBits(data.data(), index * storageBitWidth, newInt);
1500 if (attr.isSplat()) {
1501 processElt(*attr.begin(), 0);
1502 return newArrayType;
1506 uint64_t elementIdx = 0;
1507 for (
auto value : attr)
1508 processElt(value, elementIdx++);
1509 return newArrayType;
1518 return getRaw(newArrayType, elementData);
1536 return getRaw(newArrayType, elementData);
1549 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1550 Dialect *dialect = getContext()->getLoadedDialect(getDialect());
1553 auto *
interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect);
1556 return failed(interface->decode(*
this, result));
1561 StringAttr dialect, StringRef value,
1564 return emitError() <<
"invalid dialect namespace '" << dialect <<
"'";
1573 APFloat SparseElementsAttr::getZeroAPFloat()
const {
1575 return APFloat(
eltType.getFloatSemantics());
1579 APInt SparseElementsAttr::getZeroAPInt()
const {
1581 return APInt::getZero(
eltType.getWidth());
1585 Attribute SparseElementsAttr::getZeroAttr()
const {
1590 return FloatAttr::get(
eltType, 0);
1593 if (getValues().isa<DenseStringElementsAttr>())
1594 return StringAttr::get(
"",
eltType);
1597 return IntegerAttr::get(
eltType, 0);
1602 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices()
const {
1603 std::vector<ptrdiff_t> flatSparseIndices;
1608 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1609 if (sparseIndices.isSplat()) {
1611 *sparseIndexValues.begin());
1612 flatSparseIndices.push_back(getFlattenedIndex(indices));
1613 return flatSparseIndices;
1617 auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1618 size_t rank =
getType().getRank();
1619 for (
size_t i = 0, e = numSparseIndices; i != e; ++i)
1620 flatSparseIndices.push_back(getFlattenedIndex(
1621 {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1622 return flatSparseIndices;
1629 ShapedType valuesType = values.
getType();
1630 if (valuesType.getRank() != 1)
1631 return emitError() <<
"expected 1-d tensor for sparse element values";
1634 ShapedType indicesType = sparseIndices.getType();
1635 auto emitShapeError = [&]() {
1636 return emitError() <<
"expected shape ([" << type.getShape()
1637 <<
"]); inferred shape of indices literal ([" 1638 << indicesType.getShape()
1639 <<
"]); inferred shape of values literal ([" 1640 << valuesType.getShape() <<
"])";
1643 size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1644 if (indicesRank == 2) {
1645 if (indicesType.getDimSize(1) !=
static_cast<int64_t
>(rank))
1646 return emitShapeError();
1647 }
else if (indicesRank != 1 || rank != 1) {
1648 return emitShapeError();
1651 int64_t numSparseIndices = indicesType.getDimSize(0);
1652 if (numSparseIndices != valuesType.getDimSize(0))
1653 return emitShapeError();
1658 <<
"sparse index #" << indexNum
1659 <<
" is not contained within the value shape, with index=[" << index
1660 <<
"], and type=" << type;
1664 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1665 if (sparseIndices.isSplat()) {
1667 if (!ElementsAttr::isValidIndex(type, indices))
1668 return emitIndexError(0, indices);
1673 for (
size_t i = 0, e = numSparseIndices; i != e; ++i) {
1676 if (!ElementsAttr::isValidIndex(type, index))
1677 return emitIndexError(i, index);
1687 void TypeAttr::walkImmediateSubElements(
1690 walkTypesFn(getValue());
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
void printWithoutBraces(raw_ostream &os) const
Print the short form 42, 100, -1 without any braces or type prefix.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Operation is a basic unit of execution within MLIR.
static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value)
Parse a single element: generic template for int types, specialized for floating points below...
ComplexFloatElementIterator complex_float_value_end() const
This class represents a diagnostic that is inflight and set to be reported.
std::complex< APInt > operator*() const
Accesses the raw std::complex<APInt> value at this iterator position.
A symbol reference with a reference path containing a single element.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
ComplexIntElementIterator complex_value_begin() const
A utility iterator that allows walking over the internal raw APInt values.
Type getType() const
Get the type of this attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
AttrClass getAttrOfType(StringAttr name)
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
This is a utility allocator used to allocate memory for instances of derived types.
Attribute operator*() const
Accesses the Attribute value at this iterator position.
static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, Type newElementType, llvm::SmallVectorImpl< char > &data)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, char *result)
Copy actual numBytes data from value (APInt) to char array(result) for BE format. ...
A utility iterator that allows walking over the internal raw complex APInt values.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape...
iterator_range_impl< ComplexFloatElementIterator > getComplexFloatValues() const
static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, bool isSigned)
Check the information for a C++ data type, check if this type is valid for the current attribute...
static bool classof(Attribute attr)
Methods for support type inquiry through isa, cast, and dyn_cast.
DenseArrayBaseAttrStorage(ShapedType type, DenseArrayBaseAttr::EltType eltType, ::llvm::ArrayRef< char > elements)
static constexpr const bool value
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
NamedAttribute represents a combination of a name and an Attribute value.
static bool hasSameElementsOrSplat(ShapedType type, const Values &values)
Returns true if 'values' corresponds to a splat, i.e.
static size_t getDenseElementStorageWidth(size_t origWidth)
Get the bitwidth of a dense element type within the buffer.
bool getValue() const
Return the boolean value of this attribute.
virtual ParseResult parseLSquare()=0
Parse a [ token.
T * allocate()
Allocate an instance of the provided type.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.
static void setBit(char *rawData, size_t bitPos, bool value)
Set a bit to a specific value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
An attribute representing a reference to a dense vector or tensor object containing strings...
static void writeBits(char *rawData, size_t bitPos, APInt value)
Writes value to the bit position bitPos in array rawData.
Type getElementType() const
Return the element type of this DenseElementsAttr.
An attribute that represents a reference to a dense vector or tensor object.
bool operator==(const KeyTy &tblgenKey) const
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Iterator for walking over APFloat values.
static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth)
Reads the next bitWidth bits from the bit position bitPos in array rawData.
An attribute representing a reference to a dense vector or tensor object.
StringAttr getName() const
Return the name of the attribute.
Operation::operand_range getIndices(Operation *op)
static DenseArrayAttr get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
DenseElementsAttr bitcast(Type newElType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has bitcast eleme...
::llvm::hash_code hashKey(const KeyTy &tblgenKey)
An attribute representing a reference to a dense vector or tensor object.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
static DenseElementsAttr getRawIntOrFloat(ShapedType type, ArrayRef< char > data, int64_t dataEltSize, bool isInt, bool isSigned)
Overload of the raw 'get' method that asserts that the given type is of integer or floating-point typ...
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APInt &)> mapping) const
Generates a new DenseElementsAttr by mapping each int value to a new underlying APInt.
static bool classof(Attribute attr)
Method for support type inquiry through isa, cast and dyn_cast.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
static bool getBit(const char *rawData, size_t bitPos)
Return the value of the specified bit.
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
size_t getDenseElementBitWidth(Type eltType)
Return the bit width which DenseElementsAttr should use for this type.
virtual ParseResult parseRSquare()=0
Parse a ] token.
This base class exposes generic asm parser hooks, usable across the various derived parsers...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static LLVM_ATTRIBUTE_UNUSED bool isComplexOfIntType(Type type)
Return if the given complex type has an integer element type.
bool operator*() const
Accesses the bool value at this iterator position.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
static Optional< NamedAttribute > findDuplicateElement(ArrayRef< NamedAttribute > value)
Returns an entry with a duplicate name from the given sorted array of named attributes.
DenseArrayBaseAttr::EltType eltType
static void writeAPIntsToBuffer(size_t storageWidth, std::vector< char > &data, APRangeT &&values)
Utility method to write a range of APInt values to a buffer.
FloatElementIterator float_value_end() const
Type getType() const
Return the type of this attribute.
iterator_range_impl< FloatElementIterator > getFloatValues() const
Return the held element values as a range of APFloat.
std::tuple< ShapedType, DenseArrayBaseAttr::EltType, ::llvm::ArrayRef< char > > KeyTy
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static DenseArrayBaseAttrStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey)
void print(AsmPrinter &printer) const
Print the short form [42, 100, -1] without any type prefix.
static int64_t getNumElements(ShapedType type)
MLIRContext is the top-level object for a collection of MLIR operations.
Custom storage to ensure proper memory alignment for the allocation of DenseArray of any element type...
bool allowsUnregisteredDialects()
Return true if we allow to create operation for unregistered dialects.
int64_t getNumElements() const
Returns the number of elements held by this attribute.
iterator_range_impl< ComplexIntElementIterator > getComplexIntValues() const
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Base storage class appearing in an attribute.
static bool classof(Attribute attr)
Support for isa<>/cast<>.
This class provides iterator utilities for an ElementsAttr range.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
This base class exposes generic asm printer hooks, usable across the various derived printers...
::llvm::ArrayRef< char > elements
ComplexFloatElementIterator complex_float_value_begin() const
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ParseResult parseDenseArrayAttrElt< float >(AsmParser &parser, float &value)
static Attribute parseWithoutBraces(AsmParser &parser, Type odsType)
Parse the short form 42, 100, -1 without any type prefix or braces.
static Attribute parse(AsmParser &parser, Type odsType)
Parse the short form [42, 100, -1] without any type prefix.
APInt operator*() const
Accesses the raw APInt value at this iterator position.
static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef< char > data, int64_t dataEltSize, bool isInt, bool isSigned)
Overload of the raw 'get' method that asserts that the given type is of complex type.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
MLIRContext * getContext() const
This class represents success/failure for parsing-like operations that find it important to chain tog...
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute...
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Iterator for walking over complex APFloat values.
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator po...
static bool dictionaryAttrSort(ArrayRef< NamedAttribute > value, SmallVectorImpl< NamedAttribute > &storage)
Helper function that does either an in place sort or sorts from source array into destination...
ArrayRef< StringRef > getRawStringData() const
Return the raw StringRef data held by this attribute.
FloatElementIterator float_value_begin() const
static bool classof(Attribute attr)
Method for supporting type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr)
Method for supporting type inquiry through isa, cast and dyn_cast.
ParseResult parseDenseArrayAttrElt< double >(AsmParser &parser, double &value)
bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APInt &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APFloat &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
An attribute that represents a reference to a dense integer vector or tensor object.
ComplexIntElementIterator complex_value_end() const
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, APInt &result)
Copy numBytes data from inArray(char array) to result(APINT) for BE format.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.