#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-narrow-type-emulation"
                                            int numSrcElemsPerDest,
                                            int numFrontPadElems = 0) {
  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getSource().getDefiningOp();
      extractOps.push_back(extractOp);
  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
  maskShape.back() = numDestElems;
  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
  std::optional<Operation *> newMask =
          .Case<vector::CreateMaskOp>(
              [&](auto createMaskOp) -> std::optional<Operation *> {
                s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                    rewriter, loc, s0, origIndex);
                newMaskOperands.push_back(
                return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
                                            -> std::optional<Operation *> {
            int64_t &maskIndex = maskDimSizes.back();
            maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
            return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
          .Case<arith::ConstantOp>([&](auto constantOp)
                                       -> std::optional<Operation *> {
            if (maskShape.size() != 1)
                cast<DenseIntElementsAttr>(constantOp.getValue());
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j) {
                combinedValue |= paddedMaskValues[i + j];
              compressedMaskValues.push_back(combinedValue);
            return arith::ConstantOp::create(
  while (!extractOps.empty()) {
        vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0],
                                  extractOps.back().getMixedPosition());
    extractOps.pop_back();
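// Worked example of the compression arithmetic above (illustrative values,
// not taken from the surrounding code): with numSrcElems = 7,
// numSrcElemsPerDest = 4 (e.g. i2 elements packed into i8 containers) and
// numFrontPadElems = 1, the compressed mask spans
//   (1 + 7 + 4 - 1) / 4 = 2
// container elements. A source mask [1,1,1,1,1,0,0] padded by one front
// element becomes [0,1,1,1 | 1,1,0,0], and OR-reducing each group of 4
// yields the compressed container mask [1,1].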
  auto vectorType = cast<VectorType>(src.getType());
  assert(vectorType.getRank() == 1 &&
         "expected source to be a rank-1 vector");
  assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
         "subvector out of bounds");

  if (vectorType.getNumElements() == numElemsToExtract)

  auto resultVectorType =
      VectorType::get({numElemsToExtract}, vectorType.getElementType());
  return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
                                               src, offsets, sizes, strides)
  [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");

  if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)

  return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
                                              dest, offsets, strides);
  auto srcVecTy = cast<VectorType>(src.getType());
  assert(srcVecTy.getRank() == 1 &&
         "expected source to be a rank-1 vector");
  assert(numElemsToExtract <= srcVecTy.getNumElements() &&
         "subvector out of bounds");

  if (srcVecTy.getNumElements() == numElemsToExtract)

  for (int i = 0; i < numElemsToExtract; ++i) {
        (i == 0) ? dyn_cast<Value>(offset)
                 : arith::AddIOp::create(
                       dyn_cast<Value>(offset),
    auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc);
    dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i);
  auto srcVecTy = cast<VectorType>(src.getType());
  auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");
  assert(numElemsToInsert > 0 &&
         "the number of elements to insert must be greater than 0");
  assert(numElemsToInsert <= destVecTy.getNumElements() &&
         "subvector out of bounds");

  for (int64_t i = 0; i < numElemsToInsert; ++i) {
        i == 0 ? destOffsetVal
               : arith::AddIOp::create(
    auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i);
    dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc);
                                      int64_t numContainerElemsToLoad,
                                      Type containerElemTy) {
  auto newLoad = vector::LoadOp::create(
      rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy),
  return vector::BitCastOp::create(
      VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
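// Illustrative shape of the emitted IR (assumed types, not from the code
// above: i2 emulated in i8 containers, numContainerElemsToLoad = 2):
//   %c = vector.load %base[%idx] : memref<?xi8>, vector<2xi8>
//   %v = vector.bitcast %c : vector<2xi8> to vector<8xi2>
// i.e. each load produces numContainerElemsToLoad * emulatedPerContainerElem
// emulated elements.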
                                      VectorType downcastType,
                                      VectorType upcastType, Value mask,
         downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
             upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
         "expected input and output number of bits to match");
  if (trueValue.getType() != downcastType) {
        vector::BitCastOp::create(builder, loc, downcastType, trueValue);
  if (falseValue.getType() != downcastType) {
        vector::BitCastOp::create(builder, loc, downcastType, falseValue);
      arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);
  return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  auto atomicOp = memref::GenericAtomicRMWOp::create(
      builder, loc, linearizedMemref, ValueRange{storeIdx});
  Value origValue = atomicOp.getCurrentValue();

  auto oneElemVecType = VectorType::get({1}, origValue.getType());
  Value origVecValue = vector::FromElementsOp::create(
      builder, loc, oneElemVecType, ValueRange{origValue});

      oneElemVecType, mask, valueToStore, origVecValue);
  auto scalarMaskedValue =
      vector::ExtractOp::create(builder, loc, maskedValue, 0);
  memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  auto oneElemVecType =
      VectorType::get({1}, linearizedMemref.getType().getElementType());
      vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
  origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),
      oneElemVecType, mask, valueToStore, origVecValue);
  vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
  auto vectorElementType = vector.getType().getElementType();
         sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
         "sliceNumElements * vector element size must be less than or equal to 8");
  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
         "vector element must be a valid sub-byte type");

  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
  auto emptyByteVector = arith::ConstantOp::create(
      VectorType::get({emulatedPerContainerElem}, vectorElementType),
      rewriter.getZeroAttr(
          VectorType::get({emulatedPerContainerElem}, vectorElementType)));
      extractOffset, sliceNumElements);
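// Illustrative example (assumed values): for a vector<8xi2> source,
// emulatedPerContainerElem = 8 / 2 = 4, so the empty destination is a
// vector<4xi2> byte; extracting sliceNumElements = 3 elements at
// extractOffset = 2 satisfies the assertion above (3 * 2 = 6 bits <= 8) and
// the slice is inserted into that single byte at insertOffset.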
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
      : OpConversionPattern<vector::StoreOp>(context),
        disableAtomicRMW(disableAtomicRMW) {}

  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();

    auto valueToStore = cast<VectorValue>(op.getValueToStore());
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = valueToStore.getType().getNumElements();
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto trailingDim = op.getBase().getType().getShape().back();
    bool trailingDimsMatch =
        ShapedType::isDynamic(trailingDim) || trailingDim == origElements;

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedNumFrontPadElems =
        (isDivisibleInSize && trailingDimsMatch)

    if (!foldedNumFrontPadElems) {
      return rewriter.notifyMatchFailure(
          op, "subbyte store emulation: dynamic front padding size is "
              "not yet implemented");

    auto memrefBase = cast<MemRefValue>(adaptor.getBase());

    bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;

    if (!emulationRequiresPartialStores) {
      auto numElements = origElements / emulatedPerContainerElem;
      auto bitCast = vector::BitCastOp::create(
          rewriter, loc, VectorType::get(numElements, containerElemTy),
          op.getValueToStore());
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          op, bitCast.getResult(), memrefBase,

    Value currentDestIndex =
    auto currentSourceIndex = 0;

    auto subWidthStoreMaskType =
        VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());

    auto frontSubWidthStoreElem =
        (emulatedPerContainerElem - *foldedNumFrontPadElems) %
        emulatedPerContainerElem;
    if (frontSubWidthStoreElem > 0) {
      SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
        frontSubWidthStoreElem = origElements;
        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
                    *foldedNumFrontPadElems, true);
      auto frontMask = arith::ConstantOp::create(

      currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
                               frontSubWidthStoreElem, *foldedNumFrontPadElems);

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(value), frontMask.getResult());

    if (currentSourceIndex >= origElements) {
      rewriter.eraseOp(op);

    currentDestIndex = arith::AddIOp::create(
        rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne);

    int64_t fullWidthStoreSize =
        (origElements - currentSourceIndex) / emulatedPerContainerElem;
    int64_t numNonFullWidthElements =
        fullWidthStoreSize * emulatedPerContainerElem;
    if (fullWidthStoreSize > 0) {
          rewriter, loc, valueToStore, currentSourceIndex,
          numNonFullWidthElements);

      auto originType = cast<VectorType>(fullWidthStorePart.getType());
      auto storeType = VectorType::get(
          {originType.getNumElements() / emulatedPerContainerElem},
      auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
      vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,

      currentSourceIndex += numNonFullWidthElements;
      currentDestIndex = arith::AddIOp::create(
          rewriter, loc, rewriter.getIndexType(), currentDestIndex,

    auto remainingElements = origElements - currentSourceIndex;
    if (remainingElements != 0) {
      auto subWidthStorePart =
              currentSourceIndex, remainingElements, 0);

      auto maskValues = SmallVector<bool>(emulatedPerContainerElem, false);
      std::fill_n(maskValues.begin(), remainingElements, 1);
      auto backMask = arith::ConstantOp::create(

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(subWidthStorePart), backMask.getResult());

    rewriter.eraseOp(op);

  const bool disableAtomicRMW;
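// Illustrative walk-through of the partial-store path above (assumed values:
// i2 emulated in i8, so emulatedPerContainerElem = 4; storing vector<9xi2>
// with *foldedNumFrontPadElems = 1):
//   1. a masked sub-width store writes the first 3 source elements into the
//      tail of the first container byte,
//   2. one full-width store writes the next 4 elements, bitcast to
//      vector<1xi8>,
//   3. a masked sub-width store writes the remaining 2 elements with back
//      mask [1,1,0,0].
// Whether the masked stores in steps 1 and 3 use memref.generic_atomic_rmw
// or the plain load/store sequence is controlled by disableAtomicRMW.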
struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {

  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Memref in vector.maskedstore op must be flattened beforehand.");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");

    int emulatedPerContainerElem = containerBits / emulatedBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % emulatedPerContainerElem != 0)

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
    OpFoldResult linearizedIndicesOfr;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndicesOfr) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
    Value linearizedIndices =

        rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);

    auto numElements = (origElements + emulatedPerContainerElem - 1) /
                       emulatedPerContainerElem;
    auto newType = VectorType::get(numElements, containerElemTy);
    auto passThru = arith::ConstantOp::create(rewriter, loc, newType,
                                              rewriter.getZeroAttr(newType));

    auto newLoad = vector::MaskedLoadOp::create(
        rewriter, loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
        vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad);
    valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(),
                                           op.getValueToStore(), valueToStore);
        vector::BitCastOp::create(rewriter, loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
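// Illustrative shape of the emulation above (assumed types: i4 emulated in
// i8, origElements = 8, hence numElements = 4): the pattern masked-loads
// vector<4xi8> using the compressed mask, bitcasts it to vector<8xi4>,
// selects between the value to store and the loaded data with the original
// vector<8xi1> mask, bitcasts the result back to vector<4xi8>, and emits a
// vector.maskedstore with the compressed mask, preserving bytes whose
// emulated lanes are masked off.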
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {

  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Memref in emulated vector ops must be flattened beforehand.");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0

    int64_t maxintraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
                                        emulatedPerContainerElem);
                           numElements, emulatedElemTy, containerElemTy);

    if (!foldedIntraVectorOffset) {
      auto resultVector = arith::ConstantOp::create(
          rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
    } else if (!isDivisibleInSize) {
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    rewriter.replaceOp(op, result);
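// Illustrative example of the sizing above (assumed types: i4 emulated in
// i8): loading vector<6xi4> at an intra-byte offset of 1 requires
//   numElements = ceil((1 + 6) / 2) = 4
// container bytes, which bitcast to vector<8xi4>; the requested 6 elements
// are then extracted at offset 1 (statically when the offset folds to a
// constant, dynamically otherwise).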
struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {

  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Memref in emulated vector ops must be flattened beforehand.");

    auto loc = op.getLoc();

    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    FailureOr<Operation *> newMask =
                            emulatedPerContainerElem, maxIntraDataOffset);

    Value passthru = op.getPassThru();

    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);
    auto loadType = VectorType::get(numElements, containerElemTy);
    auto newBitcastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);

    auto emptyVector = arith::ConstantOp::create(
        rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
    } else if (!isDivisibleInSize) {
                                           *foldedIntraVectorOffset);
        vector::BitCastOp::create(rewriter, loc, loadType, passthru);

    auto newLoad = vector::MaskedLoadOp::create(
        rewriter, loc, loadType, adaptor.getBase(),
        newMask.value()->getResult(0), newPassThru);

        vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad);

    Value mask = op.getMask();
    auto newSelectMaskType = VectorType::get(
        numElements * emulatedPerContainerElem, rewriter.getI1Type());
        arith::ConstantOp::create(rewriter, loc, newSelectMaskType,
                                  rewriter.getZeroAttr(newSelectMaskType));
    if (!foldedIntraVectorOffset) {
    } else if (!isDivisibleInSize) {
                                       *foldedIntraVectorOffset);

        arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
          rewriter, loc, result, op.getPassThru(),
    } else if (!isDivisibleInSize) {
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    rewriter.replaceOp(op, result);
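// Illustrative shape of the emulation above (assumed types: i4 emulated in
// i8, vector<6xi4> at a constant intra-byte offset of 1): the pass-through
// is widened into the 8-element bitcast shape and bitcast to vector<4xi8>
// for the masked load with the compressed mask; the loaded bytes are bitcast
// back to vector<8xi4> and selected against the widened original mask so
// that masked-off lanes keep the pass-through values, and the 6 requested
// elements are finally extracted at offset 1.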
static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
                                       Type multiByteScalarTy) {
  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");

  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();

  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");

  int elemsPerMultiByte = multiByteBits / subByteBits;

  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
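// Illustrative examples (assumed values): vector<8xi4> fits in an i32
// container because elemsPerMultiByte = 32 / 4 = 8 divides the trailing
// dimension (8 % 8 == 0), whereas vector<3xi2> does not fit in an i8
// container because 8 / 2 = 4 does not divide 3.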
struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {

  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Memref in emulated vector ops must be flattened beforehand.");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();

    bool isDivisibleInSize =
        fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);

    Value padding = adaptor.getPadding();
      padding = arith::BitcastOp::create(
          IntegerType::get(rewriter.getContext(),
        arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding);

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);

    auto newRead = vector::TransferReadOp::create(
        rewriter, loc, VectorType::get(numElements, containerElemTy),

    auto bitCast = vector::BitCastOp::create(
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = arith::ConstantOp::create(
          rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
    } else if (!isDivisibleInSize) {
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    rewriter.replaceOp(op, result);
struct SourceElementRange {
  int64_t sourceElementIdx;
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;

struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
struct BitCastRewriter {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  LogicalResult commonPrecondition(PatternRewriter &rewriter,
                                   VectorType preconditionType, Operation *op);

  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
                           Value initialValue, Value runningResult,
                           const BitCastRewriter::Metadata &metadata);

  BitCastBitsEnumerator enumerator;
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG() << "sourceVectorType: " << sourceVectorType;

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG() << "targetVectorType: " << targetVectorType;

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
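// Illustrative enumeration (assumed types, not from the surrounding code):
// for a vector<4xi8> -> vector<2xi16> bit-cast, bitwidth = 32 and each
// target element is assembled from two whole source bytes, so
//   sourceElementRanges[0] = { {0: b@[0..8)}, {1: b@[0..8)} }
//   sourceElementRanges[1] = { {2: b@[0..8)}, {3: b@[0..8)} }
// and getMaxNumberOfEntries() == 2, i.e. two shuffle/mask/shift steps are
// needed by BitCastRewriter.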
BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG() << "\n" << enumerator.sourceElementRanges;

                                                 VectorType preconditionType,
  if (!preconditionType || preconditionType.isScalable())

  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)

LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)

  if (!preconditionType || preconditionType.getRank() != 1)
                                                   VectorType subByteVecTy,
         "container element type is not a scalar");

  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();

  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");

  if (subByteBits != 2 && subByteBits != 4)
        op, "only 2-bit and 4-bit sub-byte types are supported at this moment");

  if (containerBits % subByteBits != 0)

  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
        op, "not possible to fit this sub-byte vector type into a vector of "
            "the given multi-byte type");
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
      IntegerAttr mask = IntegerAttr::get(
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftRight));

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftLeft));

    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue,
                                             initialValue, metadata.shuffles);

  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = arith::ConstantOp::create(
  Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp);

  auto shiftRightConstantOp = arith::ConstantOp::create(
  Value shiftedRight =
      arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp);

  auto shiftLeftConstantOp = arith::ConstantOp::create(
      arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp);

          ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft)

  return runningResult;
  auto srcVecType = cast<VectorType>(subByteVec.getType());
  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
  assert(8 % srcBitwidth == 0 &&
         "Unsupported sub-byte type (not a divisor of i8)");
  int64_t numSrcElemsPerByte = 8 / srcBitwidth;

  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
  return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec);
                                                    int bitIdx, int numBits) {
  auto srcType = cast<VectorType>(src.getType());
  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  if (bitsToShiftLeft != 0) {
    Value shiftLeftValues = arith::ConstantOp::create(
    shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues);

  int8_t bitsToShiftRight = 8 - numBits;
  Value shiftRightValues = arith::ConstantOp::create(
  Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues);
                                                int bitIdx, int numBits) {
  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  auto srcType = cast<VectorType>(src.getType());
  int8_t bitsToShiftRight = bitIdx;
  if (bitsToShiftRight != 0) {
    Value shiftRightValues = arith::ConstantOp::create(
    shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues);
  if (bitIdx + numBits == 8) {

  uint8_t lowBitsMask = (1 << numBits) - 1;
  Value lowBitsMaskValues = arith::ConstantOp::create(
  return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues);
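// Illustrative bit arithmetic for the two extraction helpers above (assumed
// values: bitIdx = 2, numBits = 2, i.e. bits [2, 4) of each byte):
//  - signed variant: shift left by 8 - 2 - 2 = 4 so the field occupies the
//    top bits, then arithmetic shift right by 8 - 2 = 6 to sign-extend it
//    into the i8 lane;
//  - unsigned variant: logical shift right by bitIdx = 2, then AND with
//    lowBitsMask = (1 << 2) - 1 = 0x3 (the AND is skipped when
//    bitIdx + numBits == 8, since the shift already cleared the high bits).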
  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
  Value high = extFn(rewriter, loc, i8Vector, 4, 4);

  return vector::InterleaveOp::create(rewriter, loc, low, high);
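// Illustrative example (assumed source vector<8xi4>): the source is first
// bitcast to a vector<4xi8> of packed bytes; extFn(..., 0, 4) extends the
// low nibbles and extFn(..., 4, 4) the high nibbles to i8, and interleaving
// the two vector<4xi8> results restores the original element order as
// vector<8xi8>.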
  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(2) &&
         "Expected i2 type");

  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);

  Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2);
  Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3);
  return vector::InterleaveOp::create(rewriter, loc, interleave02,
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue);

  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = arith::ConstantOp::create(
  Value zeroOutLow = arith::AndIOp::create(
      rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask);

  constexpr int8_t bitsToShift = 4;
  auto shiftValues = arith::ConstantOp::create(
  Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(),

  auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh);

  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp);
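// Illustrative example (assumed source vector<8xi8>, result vector<8xi4>):
// deinterleaving splits the source into even and odd lanes; the even lanes
// are masked with 0x0F to keep their low nibbles, the odd lanes are shifted
// left by 4, and OR-ing the two packs each pair of i4 values into one byte
// of a vector<4xi8>, which the final bitcast reinterprets as vector<8xi4>.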
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();

    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))

    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);

    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  using OpRewritePattern<ExtOpType>::OpRewritePattern;

  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
      : OpRewritePattern<ExtOpType>(context, benefit) {}

  LogicalResult matchAndRewrite(ExtOpType extOp,
                                PatternRewriter &rewriter) const override {
    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();

    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))

    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);

        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();

          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
  using OpRewritePattern<ConversionOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                PatternRewriter &rewriter) const override {
    Value srcValue = conversionOp.getIn();
    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());

            rewriter, srcVecType,

    Location loc = conversionOp.getLoc();

    switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {

        conversionOp, conversionOp.getType(), subByteExt);
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {

  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
                                PatternRewriter &rewriter) const override {
    Value srcValue = truncOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!srcVecType || !dstVecType)

    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)

            rewriter, dstVecType,

    Location loc = truncOp.getLoc();
    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
        arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue);

    rewriter.replaceOp(truncOp, subByteTrunc);
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {

  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    constexpr unsigned minNativeBitwidth = 8;
    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
                                         "not a sub-byte transpose");

    Location loc = transposeOp.getLoc();

    auto srcNativeVecType = srcSubByteVecType.cloneWith(
    Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType,
                                         transposeOp.getVector());
    Value newTranspose = vector::TransposeOp::create(
        rewriter, loc, extOp, transposeOp.getPermutation());
    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns, bool disableAtomicRMW) {

  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());

  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);

void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),

  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),

      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, false>>(

void vector::populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {

void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);