#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-narrow-type-emulation"
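/// Returns a compressed mask for the emulated vector: one i1 per container
/// element, set if any of the source elements it covers is set. For example,
/// when emulating i4 with i8 (numSrcElemsPerDest = 2) and no front padding, a
/// mask over 8 source elements [1, 1, 1, 0, 0, 0, 0, 0] compresses to
/// [1, 1, 0, 0].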
                                   int numSrcElemsPerDest,
                                   int numFrontPadElems = 0) {

  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /

      !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(

    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getSource().getDefiningOp();
      extractOps.push_back(extractOp);

  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(

  maskShape.back() = numDestElems;
  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
  std::optional<Operation *> newMask =
          [&](vector::CreateMaskOp createMaskOp)
              -> std::optional<Operation *> {
            s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                rewriter, loc, s0, origIndex);
            newMaskOperands.push_back(
            return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
          .Case([&](vector::ConstantMaskOp constantMaskOp)
                    -> std::optional<Operation *> {
            int64_t &maskIndex = maskDimSizes.back();
            maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
            return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
          .Case([&](arith::ConstantOp constantOp)
                    -> std::optional<Operation *> {
            if (maskShape.size() != 1)
                cast<DenseIntElementsAttr>(constantOp.getValue());
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j) {
                combinedValue |= paddedMaskValues[i + j];
              compressedMaskValues.push_back(combinedValue);
            return arith::ConstantOp::create(

  while (!extractOps.empty()) {
        vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0],
                                  extractOps.back().getMixedPosition());
    extractOps.pop_back();
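/// Extracts a 1-D subvector from a 1-D source vector at a static offset, using
/// vector.extract_strided_slice.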
  auto vectorType = cast<VectorType>(src.getType());
  assert(vectorType.getRank() == 1 && "expected source to be a rank-1 vector");
  assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
         "subvector out of bounds");

  if (vectorType.getNumElements() == numElemsToExtract)

  auto resultVectorType =
      VectorType::get({numElemsToExtract}, vectorType.getElementType());
  return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
                                               src, offsets, sizes, strides)
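/// Inserts a 1-D subvector into a 1-D destination vector at a static offset,
/// using vector.insert_strided_slice.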
  [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");

  if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)

  return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
                                              dest, offsets, strides);
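/// Extracts a 1-D subvector from a 1-D source vector at a dynamic offset, by
/// extracting and re-inserting one element at a time.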
  auto srcVecTy = cast<VectorType>(src.getType());
  assert(srcVecTy.getRank() == 1 && "expected source to be a rank-1 vector");
  assert(numElemsToExtract <= srcVecTy.getNumElements() &&
         "subvector out of bounds");

  if (srcVecTy.getNumElements() == numElemsToExtract)

  for (int i = 0; i < numElemsToExtract; ++i) {
        (i == 0) ? dyn_cast<Value>(offset)
                 : arith::AddIOp::create(
                       dyn_cast<Value>(offset),
    auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc);
    dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i);
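/// Inserts a 1-D subvector into a 1-D destination vector at a dynamic offset,
/// element by element.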
  auto srcVecTy = cast<VectorType>(src.getType());
  auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");

  assert(numElemsToInsert > 0 &&
         "the number of elements to insert must be greater than 0");
  assert(numElemsToInsert <= destVecTy.getNumElements() &&
         "subvector out of bounds");

  for (int64_t i = 0; i < numElemsToInsert; ++i) {
        i == 0 ? destOffsetVal
               : arith::AddIOp::create(
    auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i);
    dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc);
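/// Emulates a vector load of `emulatedElemTy` by loading
/// `numContainerElemsToLoad` elements of `containerElemTy` and bitcasting the
/// result to the emulated element type.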
                                       int64_t numContainerElemsToLoad,
                                       Type containerElemTy) {

  auto newLoad = vector::LoadOp::create(
      rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy),
  return vector::BitCastOp::create(
      VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
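/// Downcasts the two input values to `downcastType`, selects between them
/// based on `mask`, and casts the selection back to `upcastType`.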
                                      VectorType downcastType,
                                      VectorType upcastType, Value mask,
      downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
          upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
      "expected input and output number of bits to match");
  if (trueValue.getType() != downcastType) {
        vector::BitCastOp::create(builder, loc, downcastType, trueValue);
  if (falseValue.getType() != downcastType) {
        vector::BitCastOp::create(builder, loc, downcastType, falseValue);
      arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);
  return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
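/// Emits a memref.generic_atomic_rmw that merges the sub-byte `valueToStore`
/// into the byte at `storeIdx` of `linearizedMemref` under `mask`.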
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  auto atomicOp = memref::GenericAtomicRMWOp::create(
      builder, loc, linearizedMemref, ValueRange{storeIdx});
  Value origValue = atomicOp.getCurrentValue();

  auto oneElemVecType = VectorType::get({1}, origValue.getType());
  Value origVecValue = vector::FromElementsOp::create(
      builder, loc, oneElemVecType, ValueRange{origValue});

      oneElemVecType, mask, valueToStore, origVecValue);
  auto scalarMaskedValue =
      vector::ExtractOp::create(builder, loc, maskedValue, 0);
  memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
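/// Generates a non-atomic read-modify-write sequence (load, select under the
/// mask, store) for storing to the emulated type.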
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  auto oneElemVecType =
      VectorType::get({1}, linearizedMemref.getType().getElementType());
      vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
  origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),

      oneElemVecType, mask, valueToStore, origVecValue);
  vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
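/// Extracts `sliceNumElements` from the source vector at `extractOffset` and
/// inserts them into an otherwise zeroed single-byte vector at `insertOffset`.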
  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
  auto vectorElementType = vector.getType().getElementType();

      sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
      "sliceNumElements * vector element size must be less than or equal to 8");
  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
         "vector element must be a valid sub-byte type");
  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
  auto emptyByteVector = arith::ConstantOp::create(
      VectorType::get({emulatedPerContainerElem}, vectorElementType),
      rewriter.getZeroAttr(
          VectorType::get({emulatedPerContainerElem}, vectorElementType)));
                                   extractOffset, sliceNumElements);
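/// Emulates vector.store on sub-byte element types by bitcasting the value to
/// the container element type; stores that do not cover whole container
/// elements fall back to masked partial stores implemented with an atomic or
/// non-atomic read-modify-write.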
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW,
                     bool assumeAligned)
      : OpConversionPattern<vector::StoreOp>(context),
        disableAtomicRMW(disableAtomicRMW), assumeAligned(assumeAligned) {}

  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();

    auto valueToStore = cast<VectorValue>(op.getValueToStore());
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = valueToStore.getType().getNumElements();

    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

      if (!isDivisibleInSize)
        return rewriter.notifyMatchFailure(
            op, "the source vector does not fill whole container elements "
                "(not divisible in size)");

      auto stridedMetadata =
          memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
      OpFoldResult linearizedIndices;
      std::tie(std::ignore, linearizedIndices) =
              rewriter, loc, emulatedBits, containerBits,
              stridedMetadata.getConstifiedMixedOffset(),
              stridedMetadata.getConstifiedMixedSizes(),
              stridedMetadata.getConstifiedMixedStrides(),

      auto memrefBase = cast<MemRefValue>(adaptor.getBase());
      int numElements = origElements / emulatedPerContainerElem;
      auto bitCast = vector::BitCastOp::create(
          rewriter, loc, VectorType::get(numElements, containerElemTy),
          op.getValueToStore());
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          op, bitCast.getResult(), memrefBase,

    auto trailingDim = op.getBase().getType().getShape().back();
    bool trailingDimsMatch =
        ShapedType::isDynamic(trailingDim) || trailingDim == origElements;

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedNumFrontPadElems =
        (isDivisibleInSize && trailingDimsMatch)

    if (!foldedNumFrontPadElems) {
      return rewriter.notifyMatchFailure(
          op, "subbyte store emulation: dynamic front padding size is "
              "not yet implemented");

    auto memrefBase = cast<MemRefValue>(adaptor.getBase());

    bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;

    if (!emulationRequiresPartialStores) {

      auto numElements = origElements / emulatedPerContainerElem;
      auto bitCast = vector::BitCastOp::create(
          rewriter, loc, VectorType::get(numElements, containerElemTy),
          op.getValueToStore());
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          op, bitCast.getResult(), memrefBase,

    Value currentDestIndex =
    auto currentSourceIndex = 0;

    auto subWidthStoreMaskType =
        VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());

    auto frontSubWidthStoreElem =
        (emulatedPerContainerElem - *foldedNumFrontPadElems) %
        emulatedPerContainerElem;
    if (frontSubWidthStoreElem > 0) {
      SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
        frontSubWidthStoreElem = origElements;
        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
                    *foldedNumFrontPadElems, true);
      auto frontMask = arith::ConstantOp::create(

      currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
          frontSubWidthStoreElem, *foldedNumFrontPadElems);

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(value), frontMask.getResult());

    if (currentSourceIndex >= origElements) {
      rewriter.eraseOp(op);

    currentDestIndex = arith::AddIOp::create(
        rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne);

    int64_t fullWidthStoreSize =
        (origElements - currentSourceIndex) / emulatedPerContainerElem;
    int64_t numNonFullWidthElements =
        fullWidthStoreSize * emulatedPerContainerElem;
    if (fullWidthStoreSize > 0) {
          rewriter, loc, valueToStore, currentSourceIndex,
          numNonFullWidthElements);

      auto originType = cast<VectorType>(fullWidthStorePart.getType());
      auto storeType = VectorType::get(
          {originType.getNumElements() / emulatedPerContainerElem},
      auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
      vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,

      currentSourceIndex += numNonFullWidthElements;
      currentDestIndex = arith::AddIOp::create(
          rewriter, loc, rewriter.getIndexType(), currentDestIndex,

    auto remainingElements = origElements - currentSourceIndex;
    if (remainingElements != 0) {
      auto subWidthStorePart =
          currentSourceIndex, remainingElements, 0);

      auto maskValues = SmallVector<bool>(emulatedPerContainerElem, false);
      std::fill_n(maskValues.begin(), remainingElements, 1);
      auto backMask = arith::ConstantOp::create(

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(subWidthStorePart), backMask.getResult());

    rewriter.eraseOp(op);

  const bool disableAtomicRMW;
  const bool assumeAligned;
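/// Emulates vector.maskedstore on sub-byte element types: loads the affected
/// container elements under a compressed mask, selects the new values under
/// the original mask, and stores the packed result back.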
struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {

  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Memref in vector.maskedstore op must be flattened beforehand.");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");

    int emulatedPerContainerElem = containerBits / emulatedBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % emulatedPerContainerElem != 0)

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
    OpFoldResult linearizedIndicesOfr;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndicesOfr) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
    Value linearizedIndices =

        rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);

    auto numElements = (origElements + emulatedPerContainerElem - 1) /
                       emulatedPerContainerElem;
    auto newType = VectorType::get(numElements, containerElemTy);
    auto passThru = arith::ConstantOp::create(rewriter, loc, newType,
                                              rewriter.getZeroAttr(newType));

    auto newLoad = vector::MaskedLoadOp::create(
        rewriter, loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
        vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad);
    valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(),
                                           op.getValueToStore(), valueToStore);
        vector::BitCastOp::create(rewriter, loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
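/// Emulates vector.load on sub-byte element types by loading the covering
/// container elements, bitcasting to the emulated type, and extracting the
/// requested subvector when the load is not container-aligned.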
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {

  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Memref in emulated vector ops must be flattened beforehand.");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();

    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0

    int64_t maxintraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
                                        emulatedPerContainerElem);
        numElements, emulatedElemTy, containerElemTy);

    if (!foldedIntraVectorOffset) {
      auto resultVector = arith::ConstantOp::create(
          rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
    } else if (!isDivisibleInSize) {
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    rewriter.replaceOp(op, result);
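/// Emulates vector.maskedload on sub-byte element types using a widened
/// masked load of the container type followed by a select against the
/// pass-through value.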
struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {

  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Memref in emulated vector ops must be flattened beforehand.");

    auto loc = op.getLoc();

    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();

    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    FailureOr<Operation *> newMask =
            emulatedPerContainerElem, maxIntraDataOffset);

    Value passthru = op.getPassThru();

    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);
    auto loadType = VectorType::get(numElements, containerElemTy);
    auto newBitcastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);

    auto emptyVector = arith::ConstantOp::create(
        rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
    } else if (!isDivisibleInSize) {
          *foldedIntraVectorOffset);
        vector::BitCastOp::create(rewriter, loc, loadType, passthru);

    auto newLoad = vector::MaskedLoadOp::create(
        rewriter, loc, loadType, adaptor.getBase(),
        newMask.value()->getResult(0), newPassThru);

        vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad);

    Value mask = op.getMask();
    auto newSelectMaskType = VectorType::get(
        numElements * emulatedPerContainerElem, rewriter.getI1Type());
        arith::ConstantOp::create(rewriter, loc, newSelectMaskType,
                                  rewriter.getZeroAttr(newSelectMaskType));
    if (!foldedIntraVectorOffset) {
    } else if (!isDivisibleInSize) {
          *foldedIntraVectorOffset);

        arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
          rewriter, loc, result, op.getPassThru(),
    } else if (!isDivisibleInSize) {
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    rewriter.replaceOp(op, result);
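/// Returns true if a vector of the sub-byte element type packs evenly into the
/// multi-byte container scalar type, i.e. its trailing dimension is a multiple
/// of the packing factor.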
static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
                                       Type multiByteScalarTy) {
  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");

  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();

  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");

  int elemsPerMultiByte = multiByteBits / subByteBits;

  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
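/// Emulates vector.transfer_read on sub-byte element types: the padding value
/// is bitcast and zero-extended to the container type, the read is performed
/// in container elements, and the result is bitcast back and sliced to the
/// requested length.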
struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {

  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Memref in emulated vector ops must be flattened beforehand.");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();

    bool isDivisibleInSize =
        fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);

    Value padding = adaptor.getPadding();
      padding = arith::BitcastOp::create(
          IntegerType::get(rewriter.getContext(),
        arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding);

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);

    auto newRead = vector::TransferReadOp::create(
        rewriter, loc, VectorType::get(numElements, containerElemTy),

    auto bitCast = vector::BitCastOp::create(
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = arith::ConstantOp::create(
          rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
    } else if (!isDivisibleInSize) {
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    rewriter.replaceOp(op, result);
struct SourceElementRange {

  int64_t sourceElementIdx;

  int64_t sourceBitBegin;
  int64_t sourceBitEnd;

struct SourceElementRangeList : public SmallVector<SourceElementRange> {

  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;

struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;

struct BitCastRewriter {

    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  LogicalResult commonPrecondition(PatternRewriter &rewriter,
                                   VectorType preconditionType, Operation *op);

  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
                           Value initialValue, Value runningResult,
                           const BitCastRewriter::Metadata &metadata);

  BitCastBitsEnumerator enumerator;
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG() << "sourceVectorType: " << sourceVectorType;

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG() << "targetVectorType: " << targetVectorType;

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG() << "\n" << enumerator.sourceElementRanges;

                                                  VectorType preconditionType,
  if (!preconditionType || preconditionType.isScalable())

  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)

LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)

  if (!preconditionType || preconditionType.getRank() != 1)

                                                    VectorType subByteVecTy,
         "container element type is not a scalar");

  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();

  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");

  if (subByteBits != 2 && subByteBits != 4)
        op, "only 2-bit and 4-bit sub-byte types are supported at the moment");

  if (containerBits % subByteBits != 0)

  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
        op, "not possible to fit this sub-byte vector type into a vector of "
            "the given multi-byte type");
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
      IntegerAttr mask = IntegerAttr::get(
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftRight));

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftLeft));

    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {

  auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue,
                                             initialValue, metadata.shuffles);

  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = arith::ConstantOp::create(
  Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp);

  auto shiftRightConstantOp = arith::ConstantOp::create(
  Value shiftedRight =
      arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp);

  auto shiftLeftConstantOp = arith::ConstantOp::create(
      arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp);

          ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft)

  return runningResult;
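/// Bitcasts the aligned sub-byte `subByteVec` to a vector of i8, shrinking the
/// trailing dimension by the number of sub-byte elements per byte.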
  auto srcVecType = cast<VectorType>(subByteVec.getType());
  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
  assert(8 % srcBitwidth == 0 &&
         "Unsupported sub-byte type (not a divisor of i8)");
  int64_t numSrcElemsPerByte = 8 / srcBitwidth;

  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
  return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec);
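/// Extracts a signed N-bit sequence from each byte of a vector of i8, starting
/// at the given bit index, and sign-extends it to i8. For example, bitIdx = 0
/// and numBits = 4 shifts each byte left by 4 and arithmetic-shifts it right
/// by 4, yielding the sign-extended low nibble.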
                                                   int bitIdx, int numBits) {
  auto srcType = cast<VectorType>(src.getType());

  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  if (bitsToShiftLeft != 0) {
    Value shiftLeftValues = arith::ConstantOp::create(
    shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues);

  int8_t bitsToShiftRight = 8 - numBits;
  Value shiftRightValues = arith::ConstantOp::create(
  Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues);
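/// Extracts an unsigned N-bit sequence from each byte of a vector of i8,
/// starting at the given bit index, and zero-extends it to i8.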
                                               int bitIdx, int numBits) {
  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  auto srcType = cast<VectorType>(src.getType());
  int8_t bitsToShiftRight = bitIdx;
  if (bitsToShiftRight != 0) {
    Value shiftRightValues = arith::ConstantOp::create(
    shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues);
  if (bitIdx + numBits == 8) {
  uint8_t lowBitsMask = (1 << numBits) - 1;
  Value lowBitsMaskValues = arith::ConstantOp::create(
  return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues);
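/// Rewrites the i4 -> i8 extension into a sequence of shuffles and bitwise
/// ops: the low and high nibbles are extracted per byte and interleaved.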
  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
  Value high = extFn(rewriter, loc, i8Vector, 4, 4);

  return vector::InterleaveOp::create(rewriter, loc, low, high);
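/// Rewrites the i2 -> i8 extension into a sequence of shuffles and bitwise
/// ops: the four 2-bit fields are extracted per byte and interleaved back into
/// the original element order.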
  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(2) &&
         "Expected i2 type");

  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);

  Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2);
  Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3);
  return vector::InterleaveOp::create(rewriter, loc, interleave02,
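/// Rewrites the i8 -> i4 truncation into a deinterleave and a series of
/// bitwise ops: the low nibbles are masked, the high nibbles shifted into
/// place, the two OR-ed together, and the result bitcast to i4.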
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue);

  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = arith::ConstantOp::create(
  Value zeroOutLow = arith::AndIOp::create(
      rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask);

  constexpr int8_t bitsToShift = 4;
  auto shiftValues = arith::ConstantOp::create(
  Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(),

  auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh);

  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp);
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();

    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))

    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);

    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  using OpRewritePattern<ExtOpType>::OpRewritePattern;

  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
      : OpRewritePattern<ExtOpType>(context, benefit) {}

  LogicalResult matchAndRewrite(ExtOpType extOp,
                                PatternRewriter &rewriter) const override {

    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();

    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))

    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);

        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
  using OpRewritePattern<ConversionOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                PatternRewriter &rewriter) const override {

    Value srcValue = conversionOp.getIn();
    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());

            rewriter, srcVecType,

    Location loc = conversionOp.getLoc();
    switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
        conversionOp, conversionOp.getType(), subByteExt);
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {

  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
                                PatternRewriter &rewriter) const override {

    Value srcValue = truncOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!srcVecType || !dstVecType)

    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)

            rewriter, dstVecType,

    Location loc = truncOp.getLoc();
    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
        arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue);

    rewriter.replaceOp(truncOp, subByteTrunc);
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {

  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {

    constexpr unsigned minNativeBitwidth = 8;
    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
                                         "not a sub-byte transpose");

    Location loc = transposeOp.getLoc();

    auto srcNativeVecType = srcSubByteVecType.cloneWith(
    Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType,
                                         transposeOp.getVector());
    Value newTranspose = vector::TransposeOp::create(
        rewriter, loc, extOp, transposeOp.getPermutation());
    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns, bool disableAtomicRMW, bool assumeAligned) {

  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());

  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW,

void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {

  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),

  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),

      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, false>>(

void vector::populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {

void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,

  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);