34#include "llvm/ADT/SmallVector.h"
35#include "llvm/Support/DebugLog.h"
36#include "llvm/Support/MathExtras.h"
37#include "llvm/Support/raw_ostream.h"
45#define DEBUG_TYPE "vector-narrow-type-emulation"
81 int numSrcElemsPerDest,
82 int numFrontPadElems = 0) {
84 assert(numFrontPadElems < numSrcElemsPerDest &&
85 "numFrontPadElems must be less than numSrcElemsPerDest");
88 (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
96 !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
98 if (
auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
99 maskOp = extractOp.getSource().getDefiningOp();
100 extractOps.push_back(extractOp);
104 if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
112 maskShape.back() = numDestElems;
113 auto newMaskType = VectorType::get(maskShape, rewriter.
getI1Type());
114 std::optional<Operation *> newMask =
117 [&](vector::CreateMaskOp createMaskOp)
118 -> std::optional<Operation *> {
128 s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
131 rewriter, loc, s0, origIndex);
133 newMaskOperands.push_back(
135 return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
138 .Case([&](vector::ConstantMaskOp constantMaskOp)
139 -> std::optional<Operation *> {
142 int64_t &maskIndex = maskDimSizes.back();
143 maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
145 return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
148 .Case([&](arith::ConstantOp constantOp)
149 -> std::optional<Operation *> {
151 if (maskShape.size() != 1)
168 cast<DenseIntElementsAttr>(constantOp.getValue());
170 paddedMaskValues.append(originalMask.template value_begin<bool>(),
171 originalMask.template value_end<bool>());
172 paddedMaskValues.resize(numDestElems * numSrcElemsPerDest,
false);
176 for (
size_t i = 0; i < paddedMaskValues.size();
177 i += numSrcElemsPerDest) {
178 bool combinedValue =
false;
179 for (
int j = 0;
j < numSrcElemsPerDest; ++
j) {
180 combinedValue |= paddedMaskValues[i +
j];
182 compressedMaskValues.push_back(combinedValue);
184 return arith::ConstantOp::create(
192 while (!extractOps.empty()) {
194 vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0],
195 extractOps.back().getMixedPosition());
196 extractOps.pop_back();
218 auto vectorType = cast<VectorType>(src.
getType());
219 assert(vectorType.getRank() == 1 &&
"expected source to be rank-1-D vector ");
220 assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
221 "subvector out of bounds");
225 if (vectorType.getNumElements() == numElemsToExtract)
232 auto resultVectorType =
233 VectorType::get({numElemsToExtract}, vectorType.getElementType());
234 return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
235 src, offsets, sizes, strides)
250 [[maybe_unused]]
auto srcVecTy = cast<VectorType>(src.
getType());
251 [[maybe_unused]]
auto destVecTy = cast<VectorType>(dest.
getType());
252 assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
253 "expected source and dest to be rank-1 vector types");
256 if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
261 return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
262 dest, offsets, strides);
288 auto srcVecTy = cast<VectorType>(src.
getType());
289 assert(srcVecTy.getRank() == 1 &&
"expected source to be rank-1-D vector ");
293 assert(numElemsToExtract <= srcVecTy.getNumElements() &&
294 "subvector out of bounds");
298 if (srcVecTy.getNumElements() == numElemsToExtract)
301 for (
int i = 0; i < numElemsToExtract; ++i) {
303 (i == 0) ? dyn_cast<Value>(offset)
304 : arith::AddIOp::create(
306 dyn_cast<Value>(offset),
308 auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc);
309 dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i);
332 auto srcVecTy = cast<VectorType>(src.
getType());
333 auto destVecTy = cast<VectorType>(dest.
getType());
334 assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
335 "expected source and dest to be rank-1 vector types");
338 assert(numElemsToInsert > 0 &&
339 "the number of elements to insert must be greater than 0");
343 assert(numElemsToInsert <= destVecTy.getNumElements() &&
344 "subvector out of bounds");
347 for (
int64_t i = 0; i < numElemsToInsert; ++i) {
349 i == 0 ? destOffsetVal
350 : arith::AddIOp::create(
353 auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i);
354 dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc);
368 int64_t numContainerElemsToLoad,
370 Type containerElemTy) {
373 auto newLoad = vector::LoadOp::create(
374 rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy),
376 return vector::BitCastOp::create(
378 VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
386 VectorType downcastType,
387 VectorType upcastType,
Value mask,
390 downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
391 upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
392 "expected input and output number of bits to match");
393 if (trueValue.
getType() != downcastType) {
395 vector::BitCastOp::create(builder, loc, downcastType, trueValue);
397 if (falseValue.
getType() != downcastType) {
399 vector::BitCastOp::create(builder, loc, downcastType, falseValue);
402 arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);
404 return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
423 assert(valueToStore.getType().getRank() == 1 &&
"expected 1-D vector");
427 auto atomicOp = memref::GenericAtomicRMWOp::create(
428 builder, loc, linearizedMemref,
ValueRange{storeIdx});
429 Value origValue = atomicOp.getCurrentValue();
436 auto oneElemVecType = VectorType::get({1}, origValue.getType());
437 Value origVecValue = vector::FromElementsOp::create(
438 builder, loc, oneElemVecType,
ValueRange{origValue});
443 oneElemVecType, mask, valueToStore, origVecValue);
444 auto scalarMaskedValue =
445 vector::ExtractOp::create(builder, loc, maskedValue, 0);
446 memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
454 assert(valueToStore.getType().getRank() == 1 &&
"expected 1-D vector");
456 auto oneElemVecType =
457 VectorType::get({1}, linearizedMemref.getType().
getElementType());
459 vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
461 origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),
466 oneElemVecType, mask, valueToStore, origVecValue);
467 vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
485 assert(
vector.getType().getRank() == 1 &&
"expected 1-D vector");
486 auto vectorElementType =
vector.getType().getElementType();
490 sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
491 "sliceNumElements * vector element size must be less than or equal to 8");
492 assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
493 "vector element must be a valid sub-byte type");
494 auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
495 auto emptyByteVector = arith::ConstantOp::create(
497 VectorType::get({emulatedPerContainerElem}, vectorElementType),
498 rewriter.getZeroAttr(
499 VectorType::get({emulatedPerContainerElem}, vectorElementType)));
501 extractOffset, sliceNumElements);
558struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
561 ConvertVectorStore(MLIRContext *context,
bool disableAtomicRMW,
563 : OpConversionPattern<vector::StoreOp>(context),
564 disableAtomicRMW(disableAtomicRMW), assumeAligned(assumeAligned) {}
567 matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
568 ConversionPatternRewriter &rewriter)
const override {
570 if (op.getValueToStore().getType().getRank() != 1)
571 return rewriter.notifyMatchFailure(op,
572 "only 1-D vectors are supported ATM");
574 auto loc = op.getLoc();
576 auto valueToStore = cast<VectorValue>(op.getValueToStore());
577 auto containerElemTy =
578 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
579 Type emulatedElemTy = op.getValueToStore().getType().getElementType();
581 int containerBits = containerElemTy.getIntOrFloatBitWidth();
584 if (containerBits % emulatedBits != 0) {
585 return rewriter.notifyMatchFailure(
586 op,
"impossible to pack emulated elements into container elements "
587 "(bit-wise misalignment)");
589 int emulatedPerContainerElem = containerBits / emulatedBits;
604 auto origElements = valueToStore.getType().getNumElements();
606 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
612 if (!isDivisibleInSize)
613 return rewriter.notifyMatchFailure(
614 op,
"the source vector does not fill whole container elements "
615 "(not divisible in size)");
617 auto stridedMetadata =
618 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
619 OpFoldResult linearizedIndices;
620 std::tie(std::ignore, linearizedIndices) =
622 rewriter, loc, emulatedBits, containerBits,
623 stridedMetadata.getConstifiedMixedOffset(),
624 stridedMetadata.getConstifiedMixedSizes(),
625 stridedMetadata.getConstifiedMixedStrides(),
627 auto memrefBase = cast<MemRefValue>(adaptor.getBase());
628 int numElements = origElements / emulatedPerContainerElem;
629 auto bitCast = vector::BitCastOp::create(
630 rewriter, loc, VectorType::get(numElements, containerElemTy),
631 op.getValueToStore());
632 rewriter.replaceOpWithNewOp<vector::StoreOp>(
633 op, bitCast.getResult(), memrefBase,
638 auto stridedMetadata =
639 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
641 OpFoldResult linearizedIndices;
642 memref::LinearizedMemRefInfo linearizedInfo;
643 std::tie(linearizedInfo, linearizedIndices) =
645 rewriter, loc, emulatedBits, containerBits,
646 stridedMetadata.getConstifiedMixedOffset(),
647 stridedMetadata.getConstifiedMixedSizes(),
648 stridedMetadata.getConstifiedMixedStrides(),
655 std::optional<int64_t> foldedNumFrontPadElems =
658 if (!foldedNumFrontPadElems) {
659 return rewriter.notifyMatchFailure(
660 op,
"subbyte store emulation: dynamic front padding size is "
661 "not yet implemented");
664 auto memrefBase = cast<MemRefValue>(adaptor.getBase());
696 bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
698 if (!emulationRequiresPartialStores) {
700 auto numElements = origElements / emulatedPerContainerElem;
701 auto bitCast = vector::BitCastOp::create(
702 rewriter, loc, VectorType::get(numElements, containerElemTy),
703 op.getValueToStore());
704 rewriter.replaceOpWithNewOp<vector::StoreOp>(
705 op, bitCast.getResult(), memrefBase,
741 Value currentDestIndex =
744 auto currentSourceIndex = 0;
747 auto subWidthStoreMaskType =
748 VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
756 auto frontSubWidthStoreElem =
757 (emulatedPerContainerElem - *foldedNumFrontPadElems) %
758 emulatedPerContainerElem;
759 if (frontSubWidthStoreElem > 0) {
760 SmallVector<bool> frontMaskValues(emulatedPerContainerElem,
false);
761 if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
762 std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
764 frontSubWidthStoreElem = origElements;
766 std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
767 *foldedNumFrontPadElems,
true);
769 auto frontMask = arith::ConstantOp::create(
773 currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
776 frontSubWidthStoreElem, *foldedNumFrontPadElems);
778 storeFunc(rewriter, loc, memrefBase, currentDestIndex,
779 cast<VectorValue>(value), frontMask.getResult());
782 if (currentSourceIndex >= origElements) {
783 rewriter.eraseOp(op);
790 currentDestIndex = arith::AddIOp::create(
791 rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne);
796 int64_t fullWidthStoreSize =
797 (origElements - currentSourceIndex) / emulatedPerContainerElem;
798 int64_t numNonFullWidthElements =
799 fullWidthStoreSize * emulatedPerContainerElem;
800 if (fullWidthStoreSize > 0) {
802 rewriter, loc, valueToStore, currentSourceIndex,
803 numNonFullWidthElements);
805 auto originType = cast<VectorType>(fullWidthStorePart.getType());
807 auto storeType = VectorType::get(
808 {originType.getNumElements() / emulatedPerContainerElem},
810 auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
812 vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,
815 currentSourceIndex += numNonFullWidthElements;
816 currentDestIndex = arith::AddIOp::create(
817 rewriter, loc, rewriter.getIndexType(), currentDestIndex,
824 auto remainingElements = origElements - currentSourceIndex;
825 if (remainingElements != 0) {
826 auto subWidthStorePart =
828 currentSourceIndex, remainingElements, 0);
831 auto maskValues = SmallVector<bool>(emulatedPerContainerElem,
false);
832 std::fill_n(maskValues.begin(), remainingElements, 1);
833 auto backMask = arith::ConstantOp::create(
837 storeFunc(rewriter, loc, memrefBase, currentDestIndex,
838 cast<VectorValue>(subWidthStorePart), backMask.getResult());
841 rewriter.eraseOp(op);
846 const bool disableAtomicRMW;
847 const bool assumeAligned;
861struct ConvertVectorMaskedStore final
862 : OpConversionPattern<vector::MaskedStoreOp> {
866 matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
867 ConversionPatternRewriter &rewriter)
const override {
870 if (op.getValueToStore().getType().getRank() != 1)
871 return rewriter.notifyMatchFailure(
872 op,
"Memref in vector.maskedstore op must be flattened beforehand.");
874 auto loc = op.getLoc();
875 auto containerElemTy =
876 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
877 Type emulatedElemTy = op.getValueToStore().getType().getElementType();
879 int containerBits = containerElemTy.getIntOrFloatBitWidth();
882 if (containerBits % emulatedBits != 0) {
883 return rewriter.notifyMatchFailure(
884 op,
"impossible to pack emulated elements into container elements "
885 "(bit-wise misalignment)");
888 int emulatedPerContainerElem = containerBits / emulatedBits;
889 int origElements = op.getValueToStore().getType().getNumElements();
890 if (origElements % emulatedPerContainerElem != 0)
893 auto stridedMetadata =
894 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
895 OpFoldResult linearizedIndicesOfr;
896 memref::LinearizedMemRefInfo linearizedInfo;
897 std::tie(linearizedInfo, linearizedIndicesOfr) =
899 rewriter, loc, emulatedBits, containerBits,
900 stridedMetadata.getConstifiedMixedOffset(),
901 stridedMetadata.getConstifiedMixedSizes(),
902 stridedMetadata.getConstifiedMixedStrides(),
904 Value linearizedIndices =
940 rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
944 auto numElements = (origElements + emulatedPerContainerElem - 1) /
945 emulatedPerContainerElem;
946 auto newType = VectorType::get(numElements, containerElemTy);
947 auto passThru = arith::ConstantOp::create(rewriter, loc, newType,
948 rewriter.getZeroAttr(newType));
950 auto newLoad = vector::MaskedLoadOp::create(
951 rewriter, loc, newType, adaptor.getBase(), linearizedIndices,
952 newMask.value()->getResult(0), passThru);
954 auto newBitCastType =
955 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
957 vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad);
958 valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(),
959 op.getValueToStore(), valueToStore);
961 vector::BitCastOp::create(rewriter, loc, newType, valueToStore);
963 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
964 op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
985struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
989 matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
990 ConversionPatternRewriter &rewriter)
const override {
992 if (op.getVectorType().getRank() != 1)
993 return rewriter.notifyMatchFailure(
994 op,
"Memref in emulated vector ops must be flattened beforehand.");
996 auto loc = op.getLoc();
997 auto containerElemTy =
998 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
999 Type emulatedElemTy = op.getType().getElementType();
1001 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1004 if (containerBits % emulatedBits != 0) {
1005 return rewriter.notifyMatchFailure(
1006 op,
"impossible to pack emulated elements into container elements "
1007 "(bit-wise misalignment)");
1009 int emulatedPerContainerElem = containerBits / emulatedBits;
1038 auto origElements = op.getVectorType().getNumElements();
1040 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1042 auto stridedMetadata =
1043 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1045 OpFoldResult linearizedIndices;
1046 memref::LinearizedMemRefInfo linearizedInfo;
1047 std::tie(linearizedInfo, linearizedIndices) =
1049 rewriter, loc, emulatedBits, containerBits,
1050 stridedMetadata.getConstifiedMixedOffset(),
1051 stridedMetadata.getConstifiedMixedSizes(),
1052 stridedMetadata.getConstifiedMixedStrides(),
1055 std::optional<int64_t> foldedIntraVectorOffset =
1056 isDivisibleInSize ? 0
1060 int64_t maxintraDataOffset =
1061 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1062 auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
1063 emulatedPerContainerElem);
1066 numElements, emulatedElemTy, containerElemTy);
1068 if (!foldedIntraVectorOffset) {
1069 auto resultVector = arith::ConstantOp::create(
1070 rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1074 }
else if (!isDivisibleInSize) {
1076 rewriter, loc,
result, *foldedIntraVectorOffset, origElements);
1078 rewriter.replaceOp(op,
result);
1093struct ConvertVectorMaskedLoad final
1094 : OpConversionPattern<vector::MaskedLoadOp> {
1098 matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
1099 ConversionPatternRewriter &rewriter)
const override {
1100 if (op.getVectorType().getRank() != 1)
1101 return rewriter.notifyMatchFailure(
1102 op,
"Memref in emulated vector ops must be flattened beforehand.");
1104 auto loc = op.getLoc();
1106 auto containerElemTy =
1107 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1108 Type emulatedElemTy = op.getType().getElementType();
1110 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1113 if (containerBits % emulatedBits != 0) {
1114 return rewriter.notifyMatchFailure(
1115 op,
"impossible to pack emulated elements into container elements "
1116 "(bit-wise misalignment)");
1118 int emulatedPerContainerElem = containerBits / emulatedBits;
1162 auto origType = op.getVectorType();
1163 auto origElements = origType.getNumElements();
1165 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1167 auto stridedMetadata =
1168 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1169 OpFoldResult linearizedIndices;
1170 memref::LinearizedMemRefInfo linearizedInfo;
1171 std::tie(linearizedInfo, linearizedIndices) =
1173 rewriter, loc, emulatedBits, containerBits,
1174 stridedMetadata.getConstifiedMixedOffset(),
1175 stridedMetadata.getConstifiedMixedSizes(),
1176 stridedMetadata.getConstifiedMixedStrides(),
1179 std::optional<int64_t> foldedIntraVectorOffset =
1180 isDivisibleInSize ? 0
1183 int64_t maxIntraDataOffset =
1184 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1185 FailureOr<Operation *> newMask =
1187 emulatedPerContainerElem, maxIntraDataOffset);
1191 Value passthru = op.getPassThru();
1193 auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1194 emulatedPerContainerElem);
1195 auto loadType = VectorType::get(numElements, containerElemTy);
1196 auto newBitcastType =
1197 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
1199 auto emptyVector = arith::ConstantOp::create(
1200 rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
1201 if (!foldedIntraVectorOffset) {
1205 }
else if (!isDivisibleInSize) {
1207 *foldedIntraVectorOffset);
1210 vector::BitCastOp::create(rewriter, loc, loadType, passthru);
1213 auto newLoad = vector::MaskedLoadOp::create(
1214 rewriter, loc, loadType, adaptor.getBase(),
1216 newMask.value()->getResult(0), newPassThru);
1221 vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad);
1223 Value mask = op.getMask();
1224 auto newSelectMaskType = VectorType::get(
1225 numElements * emulatedPerContainerElem, rewriter.getI1Type());
1228 arith::ConstantOp::create(rewriter, loc, newSelectMaskType,
1229 rewriter.getZeroAttr(newSelectMaskType));
1230 if (!foldedIntraVectorOffset) {
1234 }
else if (!isDivisibleInSize) {
1236 *foldedIntraVectorOffset);
1240 arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru);
1241 if (!foldedIntraVectorOffset) {
1243 rewriter, loc,
result, op.getPassThru(),
1245 }
else if (!isDivisibleInSize) {
1247 rewriter, loc,
result, *foldedIntraVectorOffset, origElements);
1249 rewriter.replaceOp(op,
result);
1270static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
1271 Type multiByteScalarTy) {
1272 assert((isa<IntegerType, FloatType>(multiByteScalarTy)) &&
"Not scalar!");
1274 int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1277 assert(subByteBits < 8 &&
"Not a sub-byte scalar type!");
1278 assert(multiByteBits % 8 == 0 &&
"Not a multi-byte scalar type!");
1279 assert(multiByteBits % subByteBits == 0 &&
"Unalagined element types!");
1281 int elemsPerMultiByte = multiByteBits / subByteBits;
1283 return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1291struct ConvertVectorTransferRead final
1292 : OpConversionPattern<vector::TransferReadOp> {
1296 matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
1297 ConversionPatternRewriter &rewriter)
const override {
1301 if (op.getVectorType().getRank() != 1)
1302 return rewriter.notifyMatchFailure(
1303 op,
"Memref in emulated vector ops must be flattened beforehand.");
1305 auto loc = op.getLoc();
1306 auto containerElemTy =
1307 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1308 Type emulatedElemTy = op.getType().getElementType();
1310 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1313 if (containerBits % emulatedBits != 0) {
1314 return rewriter.notifyMatchFailure(
1315 op,
"impossible to pack emulated elements into container elements "
1316 "(bit-wise misalignment)");
1318 int emulatedPerContainerElem = containerBits / emulatedBits;
1320 auto origElements = op.getVectorType().getNumElements();
1323 bool isDivisibleInSize =
1324 fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
1328 Value padding = adaptor.getPadding();
1330 padding = arith::BitcastOp::create(
1332 IntegerType::get(rewriter.getContext(),
1337 arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding);
1339 auto stridedMetadata =
1340 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1342 OpFoldResult linearizedIndices;
1343 memref::LinearizedMemRefInfo linearizedInfo;
1344 std::tie(linearizedInfo, linearizedIndices) =
1346 rewriter, loc, emulatedBits, containerBits,
1347 stridedMetadata.getConstifiedMixedOffset(),
1348 stridedMetadata.getConstifiedMixedSizes(),
1349 stridedMetadata.getConstifiedMixedStrides(),
1352 std::optional<int64_t> foldedIntraVectorOffset =
1353 isDivisibleInSize ? 0
1356 int64_t maxIntraDataOffset =
1357 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1358 auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1359 emulatedPerContainerElem);
1361 auto newRead = vector::TransferReadOp::create(
1362 rewriter, loc, VectorType::get(numElements, containerElemTy),
1367 auto bitCast = vector::BitCastOp::create(
1369 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
1372 Value
result = bitCast->getResult(0);
1373 if (!foldedIntraVectorOffset) {
1374 auto zeros = arith::ConstantOp::create(
1375 rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1379 }
else if (!isDivisibleInSize) {
1381 rewriter, loc,
result, *foldedIntraVectorOffset, origElements);
1383 rewriter.replaceOp(op,
result);
1398struct SourceElementRange {
1400 int64_t sourceElementIdx;
1402 int64_t sourceBitBegin;
1403 int64_t sourceBitEnd;
1406struct SourceElementRangeList :
public SmallVector<SourceElementRange> {
1412 int64_t computeLeftShiftAmount(int64_t shuffleIdx)
const {
1414 for (int64_t i = 0; i < shuffleIdx; ++i)
1415 res += (*
this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
1434struct BitCastBitsEnumerator {
1435 BitCastBitsEnumerator(VectorType sourceVectorType,
1436 VectorType targetVectorType);
1438 int64_t getMaxNumberOfEntries() {
1439 int64_t numVectors = 0;
1440 for (
const auto &l : sourceElementRanges)
1441 numVectors = std::max(numVectors, (int64_t)l.size());
1445 VectorType sourceVectorType;
1446 VectorType targetVectorType;
1447 SmallVector<SourceElementRangeList> sourceElementRanges;
1521struct BitCastRewriter {
1524 SmallVector<int64_t> shuffles;
1525 SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1528 BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
1531 LogicalResult commonPrecondition(PatternRewriter &rewriter,
1532 VectorType preconditionType, Operation *op);
1535 SmallVector<BitCastRewriter::Metadata>
1536 precomputeMetadata(IntegerType shuffledElementType);
1540 Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
1541 Value initialValue, Value runningResult,
1542 const BitCastRewriter::Metadata &metadata);
1547 BitCastBitsEnumerator enumerator;
1554 for (
const auto &l : vec) {
1555 for (
auto it : llvm::enumerate(l)) {
1556 os <<
"{ " << it.value().sourceElementIdx <<
": b@["
1557 << it.value().sourceBitBegin <<
".." << it.value().sourceBitEnd
1558 <<
") lshl: " << l.computeLeftShiftAmount(it.index()) <<
" } ";
1565BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
1566 VectorType targetVectorType)
1567 : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
1569 assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
1570 "requires -D non-scalable vector type");
1571 assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
1572 "requires -D non-scalable vector type");
1573 int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
1574 int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
1575 LDBG() <<
"sourceVectorType: " << sourceVectorType;
1577 int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
1578 int64_t mostMinorTargetDim = targetVectorType.getShape().back();
1579 LDBG() <<
"targetVectorType: " << targetVectorType;
1581 int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
1582 (
void)mostMinorSourceDim;
1583 assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
1584 "source and target bitwidths must match");
1588 for (
int64_t resultBit = 0; resultBit < bitwidth;) {
1589 int64_t resultElement = resultBit / targetBitWidth;
1590 int64_t resultBitInElement = resultBit % targetBitWidth;
1591 int64_t sourceElementIdx = resultBit / sourceBitWidth;
1592 int64_t sourceBitInElement = resultBit % sourceBitWidth;
1593 int64_t step = std::min(sourceBitWidth - sourceBitInElement,
1594 targetBitWidth - resultBitInElement);
1595 sourceElementRanges[resultElement].push_back(
1596 {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
1601BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
1602 VectorType targetVectorType)
1603 : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
1604 LDBG() <<
"\n" << enumerator.sourceElementRanges;
1610 VectorType preconditionType,
1612 if (!preconditionType || preconditionType.isScalable())
1617 unsigned bitwidth = preconditionType.getElementTypeBitWidth();
1618 if (bitwidth % 8 != 0)
1624LogicalResult BitCastRewriter::commonPrecondition(
PatternRewriter &rewriter,
1625 VectorType preconditionType,
1627 if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
1630 if (!preconditionType || preconditionType.getRank() != 1)
1668 VectorType subByteVecTy,
1672 "container element type is not a scalar");
1679 unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1683 assert(containerBits % 8 == 0 &&
"Not a multi-byte scalar type!");
1686 if (subByteBits != 2 && subByteBits != 4)
1688 op,
"only 2-bit and 4-bit sub-byte type is supported at this moment");
1691 if (containerBits % subByteBits != 0)
1695 if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
1697 op,
"not possible to fit this sub-byte vector type into a vector of "
1698 "the given multi-byte type");
1703SmallVector<BitCastRewriter::Metadata>
1704BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
1705 SmallVector<BitCastRewriter::Metadata>
result;
1706 for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1707 shuffleIdx < e; ++shuffleIdx) {
1708 SmallVector<int64_t> shuffles;
1709 SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1712 for (
auto &srcEltRangeList : enumerator.sourceElementRanges) {
1713 int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1714 ? srcEltRangeList[shuffleIdx].sourceElementIdx
1716 shuffles.push_back(sourceElement);
1718 int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1719 ? srcEltRangeList[shuffleIdx].sourceBitBegin
1721 int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1722 ? srcEltRangeList[shuffleIdx].sourceBitEnd
1724 IntegerAttr mask = IntegerAttr::get(
1725 shuffledElementType,
1726 llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1728 masks.push_back(mask);
1730 int64_t shiftRight = bitLo;
1731 shiftRightAmounts.push_back(
1732 IntegerAttr::get(shuffledElementType, shiftRight));
1734 int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1735 shiftLeftAmounts.push_back(
1736 IntegerAttr::get(shuffledElementType, shiftLeft));
1739 result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1744Value BitCastRewriter::genericRewriteStep(
1745 PatternRewriter &rewriter, Location loc, Value initialValue,
1746 Value runningResult,
const BitCastRewriter::Metadata &metadata) {
1748 auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue,
1749 initialValue, metadata.shuffles);
1752 VectorType shuffledVectorType = shuffleOp.getResultVectorType();
1753 auto constOp = arith::ConstantOp::create(
1756 Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp);
1759 auto shiftRightConstantOp = arith::ConstantOp::create(
1762 Value shiftedRight =
1763 arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp);
1766 auto shiftLeftConstantOp = arith::ConstantOp::create(
1770 arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp);
1774 ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft)
1777 return runningResult;
1788 auto srcVecType = cast<VectorType>(subByteVec.
getType());
1789 int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1790 assert(8 % srcBitwidth == 0 &&
1791 "Unsupported sub-byte type (not a divisor of i8)");
1792 int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1795 vecShape.back() = vecShape.back() / numSrcElemsPerByte;
1796 auto i8VecType = VectorType::get(vecShape, rewriter.
getI8Type());
1797 return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec);
1818 int bitIdx,
int numBits) {
1819 auto srcType = cast<VectorType>(src.
getType());
1821 int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1822 assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1823 "Invalid bitIdx range");
1824 if (bitsToShiftLeft != 0) {
1825 Value shiftLeftValues = arith::ConstantOp::create(
1827 shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues);
1830 int8_t bitsToShiftRight = 8 - numBits;
1831 Value shiftRightValues = arith::ConstantOp::create(
1833 Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues);
1860 int bitIdx,
int numBits) {
1861 assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1862 "Invalid bitIdx range");
1863 auto srcType = cast<VectorType>(src.
getType());
1864 int8_t bitsToShiftRight = bitIdx;
1866 if (bitsToShiftRight != 0) {
1867 Value shiftRightValues = arith::ConstantOp::create(
1869 shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues);
1871 if (bitIdx + numBits == 8) {
1874 uint8_t lowBitsMask = (1 << numBits) - 1;
1875 Value lowBitsMaskValues = arith::ConstantOp::create(
1877 return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues);
1887 [[maybe_unused]]
auto srcVecType = cast<VectorType>(srcValue.
getType());
1888 assert(srcVecType.getElementType().isSignlessInteger(4) &&
1889 "Expected i4 type");
1896 Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1897 Value high = extFn(rewriter, loc, i8Vector, 4, 4);
1900 return vector::InterleaveOp::create(rewriter, loc, low, high);
1907 [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.
getType());
1908 assert(srcVecType.getElementType().isSignlessInteger(2) &&
1909 "Expected i2 type");
1916 Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
1918 Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
1920 Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
1922 Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
1933 Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2);
1934 Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3);
1935 return vector::InterleaveOp::create(rewriter, loc, interleave02,
1943 VectorType srcVecType = cast<VectorType>(srcValue.
getType());
1944 assert(srcVecType.getElementType().isSignlessInteger(8) &&
1945 "Expected i8 type");
1948 auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue);
1951 constexpr int8_t i8LowBitMask = 0x0F;
1952 VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
1953 Value zeroOutMask = arith::ConstantOp::create(
1955 Value zeroOutLow = arith::AndIOp::create(
1956 rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask);
1959 constexpr int8_t bitsToShift = 4;
1960 auto shiftValues = arith::ConstantOp::create(
1962 Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(),
1966 auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh);
1969 auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.
getI4Type());
1970 return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp);
1977struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
1980 LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1981 PatternRewriter &rewriter)
const override {
1984 bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
1989 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1990 VectorType targetVectorType = bitCastOp.getResultVectorType();
1991 BitCastRewriter bcr(sourceVectorType, targetVectorType);
1992 if (
failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
1996 Value truncValue = truncOp.getIn();
1997 auto shuffledElementType =
1999 Value runningResult;
2000 for (
const BitCastRewriter ::Metadata &metadata :
2001 bcr.precomputeMetadata(shuffledElementType)) {
2002 runningResult = bcr.genericRewriteStep(
2003 rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
2007 bool narrowing = targetVectorType.getElementTypeBitWidth() <=
2008 shuffledElementType.getIntOrFloatBitWidth();
2010 if (runningResult.
getType() == bitCastOp.getResultVectorType()) {
2011 rewriter.
replaceOp(bitCastOp, runningResult);
2014 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
2017 if (runningResult.
getType() == bitCastOp.getResultVectorType()) {
2018 rewriter.
replaceOp(bitCastOp, runningResult);
2021 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
2038template <
typename ExtOpType>
2039struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
2040 using OpRewritePattern<ExtOpType>::OpRewritePattern;
2042 RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
2043 : OpRewritePattern<ExtOpType>(context, benefit) {}
2045 LogicalResult matchAndRewrite(ExtOpType extOp,
2046 PatternRewriter &rewriter)
const override {
2048 auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
2053 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
2054 VectorType targetVectorType = bitCastOp.getResultVectorType();
2055 BitCastRewriter bcr(sourceVectorType, targetVectorType);
2056 if (
failed(bcr.commonPrecondition(
2057 rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
2061 Value runningResult;
2062 Value sourceValue = bitCastOp.getSource();
2063 auto shuffledElementType =
2065 for (
const BitCastRewriter::Metadata &metadata :
2066 bcr.precomputeMetadata(shuffledElementType)) {
2067 runningResult = bcr.genericRewriteStep(
2068 rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
2073 cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
2074 shuffledElementType.getIntOrFloatBitWidth();
2077 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2080 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2121template <
typename ConversionOpType,
bool isSigned>
2122struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
2123 using OpRewritePattern<ConversionOpType>::OpRewritePattern;
2125 LogicalResult matchAndRewrite(ConversionOpType conversionOp,
2126 PatternRewriter &rewriter)
const override {
2128 Value srcValue = conversionOp.getIn();
2129 VectorType srcVecType = dyn_cast<VectorType>(srcValue.
getType());
2130 VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
2138 rewriter, srcVecType,
2143 Location loc = conversionOp.getLoc();
2147 switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
2161 if (subByteExt.
getType() == conversionOp.getType())
2162 rewriter.
replaceOp(conversionOp, subByteExt);
2165 conversionOp, conversionOp.getType(), subByteExt);
2187struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
2190 LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
2191 PatternRewriter &rewriter)
const override {
2193 Value srcValue = truncOp.getIn();
2194 auto srcVecType = dyn_cast<VectorType>(srcValue.
getType());
2195 auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
2196 if (!srcVecType || !dstVecType)
2203 if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
2209 rewriter, dstVecType,
2214 Location loc = truncOp.getLoc();
2215 auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.
getI8Type());
2217 srcVecType == i8VecType
2219 : arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue);
2225 rewriter.
replaceOp(truncOp, subByteTrunc);
2242struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
2245 RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
2246 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
2248 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
2249 PatternRewriter &rewriter)
const override {
2251 constexpr unsigned minNativeBitwidth = 8;
2252 VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
2253 if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
2254 srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
2256 "not a sub-byte transpose");
2260 Location loc = transposeOp.getLoc();
2265 auto srcNativeVecType = srcSubByteVecType.cloneWith(
2267 Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType,
2268 transposeOp.getVector());
2269 Value newTranspose = vector::TransposeOp::create(
2270 rewriter, loc, extOp, transposeOp.getPermutation());
2271 VectorType dstSubByteVecType = transposeOp.getResultVectorType();
2285void vector::populateVectorNarrowTypeEmulationPatterns(
2286 const arith::NarrowTypeEmulationConverter &typeConverter,
2287 RewritePatternSet &patterns,
bool disableAtomicRMW,
bool assumeAligned) {
2290 patterns.
add<ConvertVectorLoad, ConvertVectorMaskedLoad,
2291 ConvertVectorMaskedStore, ConvertVectorTransferRead>(
2297 patterns.
insert<ConvertVectorStore>(patterns.
getContext(), disableAtomicRMW,
2301void vector::populateVectorNarrowTypeRewritePatterns(
2302 RewritePatternSet &patterns, PatternBenefit benefit) {
2304 patterns.
add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
2305 RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.
getContext(),
2311 patterns.
add<RewriteAlignedSubByteIntExt<arith::ExtSIOp,
true>,
2312 RewriteAlignedSubByteIntExt<arith::SIToFPOp,
true>,
2313 RewriteAlignedSubByteIntTrunc>(patterns.
getContext(),
2317 .
add<RewriteAlignedSubByteIntExt<arith::ExtUIOp,
false>,
2318 RewriteAlignedSubByteIntExt<arith::UIToFPOp,
false>>(
2323void vector::populateVectorTransposeNarrowTypeRewritePatterns(
2324 RewritePatternSet &patterns, PatternBenefit benefit) {
2325 patterns.
add<RewriteVectorTranspose>(patterns.
getContext(), benefit);
2328void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
2329 arith::NarrowTypeEmulationConverter &typeConverter,
2330 RewritePatternSet &patterns) {
2332 vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
static Type getElementType(Type type)
Determine the element type of type.
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter, Location loc, VectorValue vector, int64_t extractOffset, int64_t sliceNumElements, int64_t insertOffset)
Extract sliceNumElements from source vector at extractOffset, and insert it into an empty vector at i...
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc, Value srcValue)
Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise ops to avoid leaving LLVM t...
std::function< Value(PatternRewriter &, Location, Value, int, int)> ExtractNBitsFn
TypedValue< MemRefType > MemRefValue
static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base, OpFoldResult linearizedIndices, int64_t numContainerElemsToLoad, Type emulatedElemTy, Type containerElemTy)
Emulate a vector load for emulatedElemTy using containerElemTy
TypedValue< VectorType > VectorValue
static FailureOr< Operation * > getCompressedMaskOp(OpBuilder &rewriter, Location loc, Value mask, int numSrcElems, int numSrcElemsPerDest, int numFrontPadElems=0)
Returns a compressed mask for the emulated vector.
static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc, VectorType downcastType, VectorType upcastType, Value mask, Value trueValue, Value falseValue)
Downcast two values to downcastType, then select values based on mask, and casts the result to upcast...
static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc, Value srcValue, const ExtractNBitsFn &extFn)
Rewrite the i4 -> i8 extension into a sequence of shuffles and bitwise ops to avoid leaving LLVM to s...
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc, Value src, Value dest, OpFoldResult offset, int64_t numElemsToInsert)
Inserts 1-D subvector into a 1-D vector.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc, Value src, Value dest, int64_t offset)
Inserts 1-D subvector into a 1-D vector.
static void atomicRMW(OpBuilder &builder, Location loc, MemRefValue linearizedMemref, Value storeIdx, VectorValue valueToStore, Value mask)
Emits memref.generic_atomic_rmw op to store a subbyte-sized value to a byte in linearizedMemref,...
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc, Value src, int64_t offset, int64_t numElemsToExtract)
Extracts 1-D subvector from a 1-D vector.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter, VectorType preconditionType, Operation *op)
Verify that the precondition type meets the common preconditions for any conversion.
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc, Value src, Value dest, OpFoldResult offset, int64_t numElemsToExtract)
Extracts 1-D subvector from a 1-D vector.
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, VectorType subByteVecTy, Type containerTy, Operation *op)
Verify that subByteVecTy (vector) and containerTy (scalar) are aligned.
static void nonAtomicRMW(OpBuilder &builder, Location loc, MemRefValue linearizedMemref, Value linearizedIndex, VectorValue valueToStore, Value mask)
Generate a non-atomic read-modify-write sequence for storing to the emulated type.
static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc, Value subByteVec)
Bitcasts the aligned subByteVec vector to a vector of i8.
static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter, Location loc, Value src, int bitIdx, int numBits)
Extracts an unsigned N-bit sequence from each element of a vector of bytes, starting at the specified...
static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc, Value srcValue, const ExtractNBitsFn &extFn)
Rewrite the i2 -> i8 extension into a sequence of shuffles and bitwise ops to avoid leaving LLVM to s...
static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter, Location loc, Value src, int bitIdx, int numBits)
Extracts a signed N-bit sequence from each element of a vector of bytes, starting at the specified bi...
Base type for affine expression.
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
result_type_range getResultTypes()
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={}, LinearizedDivKind sizeDivKind=LinearizedDivKind::Floor)
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns)
Patterns for flattening all supported multi-dimensional memref operations into one-dimensional memref...
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpFoldResult intraDataOffset
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.