#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-narrow-type-emulation"
                                        int numSrcElemsPerDest,
                                        int numFrontPadElems = 0) {

  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /

         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(

    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getSource().getDefiningOp();
      extractOps.push_back(extractOp);

  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(

  maskShape.back() = numDestElems;
  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
  std::optional<Operation *> newMask =
          [&](vector::CreateMaskOp createMaskOp)
              -> std::optional<Operation *> {
            s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                rewriter, loc, s0, origIndex);
            newMaskOperands.push_back(
            return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
          .Case([&](vector::ConstantMaskOp constantMaskOp)
                    -> std::optional<Operation *> {
            int64_t &maskIndex = maskDimSizes.back();
            maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
            return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
          .Case([&](arith::ConstantOp constantOp)
                    -> std::optional<Operation *> {
            if (maskShape.size() != 1)
                cast<DenseIntElementsAttr>(constantOp.getValue());
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j) {
                combinedValue |= paddedMaskValues[i + j];
              compressedMaskValues.push_back(combinedValue);
            return arith::ConstantOp::create(

  while (!extractOps.empty()) {
    vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0],
                              extractOps.back().getMixedPosition());
    extractOps.pop_back();
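// Worked example (illustrative): with numSrcElems = 6, numSrcElemsPerDest = 2
// and numFrontPadElems = 1, a constant mask [1, 1, 1, 0, 0, 0] is padded to
// [0, 1, 1, 1, 0, 0, 0, 0] and then OR-reduced in groups of 2, yielding the
// compressed container mask [1, 1, 0, 0].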
  auto vectorType = cast<VectorType>(src.getType());
  assert(vectorType.getRank() == 1 &&
         "expected source to be a rank-1 vector");
  assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
         "subvector out of bounds");

  if (vectorType.getNumElements() == numElemsToExtract)

  auto resultVectorType =
      VectorType::get({numElemsToExtract}, vectorType.getElementType());
  return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
                                               src, offsets, sizes, strides)
  [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");

  if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)

  return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
                                              dest, offsets, strides);
  auto srcVecTy = cast<VectorType>(src.getType());
  assert(srcVecTy.getRank() == 1 &&
         "expected source to be a rank-1 vector");
  assert(numElemsToExtract <= srcVecTy.getNumElements() &&
         "subvector out of bounds");

  if (srcVecTy.getNumElements() == numElemsToExtract)

  for (int i = 0; i < numElemsToExtract; ++i) {
        (i == 0) ? dyn_cast<Value>(offset)
                 : arith::AddIOp::create(
                       dyn_cast<Value>(offset),
    auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc);
    dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i);
  auto srcVecTy = cast<VectorType>(src.getType());
  auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");
  assert(numElemsToInsert > 0 &&
         "the number of elements to insert must be greater than 0");
  assert(numElemsToInsert <= destVecTy.getNumElements() &&
         "subvector out of bounds");

  for (int64_t i = 0; i < numElemsToInsert; ++i) {
        i == 0 ? destOffsetVal
               : arith::AddIOp::create(
    auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i);
    dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc);
                                      int64_t numContainerElemsToLoad,
                                      Type containerElemTy) {
  auto newLoad = vector::LoadOp::create(
      rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy),
  return vector::BitCastOp::create(
      VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
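// Illustrative example: to emulate a load of 8 x i4 from an i8 container,
// numContainerElemsToLoad = 4, so a vector.load of vector<4xi8> is emitted
// and then bitcast to vector<8xi4> (4 * emulatedPerContainerElem elements).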
                                     VectorType downcastType,
                                     VectorType upcastType, Value mask,
      downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
          upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
      "expected input and output number of bits to match");
  if (trueValue.getType() != downcastType) {
        vector::BitCastOp::create(builder, loc, downcastType, trueValue);
  if (falseValue.getType() != downcastType) {
        vector::BitCastOp::create(builder, loc, downcastType, falseValue);
      arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);
  return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
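// Illustrative example: with downcastType = vector<8xi4> and upcastType =
// vector<4xi8>, both sides carry 32 bits; any operand not already of
// downcastType is bitcast down to the sub-byte granularity, arith.select
// picks lanes according to the per-element mask, and the result is bitcast
// back to the container (upcast) type.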
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  auto atomicOp = memref::GenericAtomicRMWOp::create(
      builder, loc, linearizedMemref, ValueRange{storeIdx});
  Value origValue = atomicOp.getCurrentValue();

  auto oneElemVecType = VectorType::get({1}, origValue.getType());
  Value origVecValue = vector::FromElementsOp::create(
      builder, loc, oneElemVecType, ValueRange{origValue});

      oneElemVecType, mask, valueToStore, origVecValue);
  auto scalarMaskedValue =
      vector::ExtractOp::create(builder, loc, maskedValue, 0);
  memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
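// Illustrative summary of the sequence above: memref.generic_atomic_rmw reads
// the current container element, wraps it into a one-element vector with
// vector.from_elements, selects the masked sub-byte lanes from valueToStore
// over the original value, extracts the resulting scalar, and yields it with
// memref.atomic_yield.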
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  auto oneElemVecType =
      VectorType::get({1}, linearizedMemref.getType().getElementType());
      vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
  origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),
      oneElemVecType, mask, valueToStore, origVecValue);
  vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
  auto vectorElementType = vector.getType().getElementType();
      sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
      "sliceNumElements * vector element size must be less than or equal to 8");
  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
         "vector element must be a valid sub-byte type");
  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
  auto emptyByteVector = arith::ConstantOp::create(
      VectorType::get({emulatedPerContainerElem}, vectorElementType),
      rewriter.getZeroAttr(
          VectorType::get({emulatedPerContainerElem}, vectorElementType)));
      extractOffset, sliceNumElements);
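// Illustrative example: with i2 elements, emulatedPerContainerElem = 4;
// extracting sliceNumElements = 3 elements at extractOffset = 1 copies them
// into a zero-initialized vector<4xi2> (one byte worth of elements) at the
// given insertOffset, ready to be masked-stored into a single container byte.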
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
      : OpConversionPattern<vector::StoreOp>(context),
        disableAtomicRMW(disableAtomicRMW) {}

  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();

    auto valueToStore = cast<VectorValue>(op.getValueToStore());
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = valueToStore.getType().getNumElements();
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto trailingDim = op.getBase().getType().getShape().back();
    bool trailingDimsMatch =
        ShapedType::isDynamic(trailingDim) || trailingDim == origElements;

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedNumFrontPadElems =
        (isDivisibleInSize && trailingDimsMatch)

    if (!foldedNumFrontPadElems) {
      return rewriter.notifyMatchFailure(
          op, "subbyte store emulation: dynamic front padding size is "
              "not yet implemented");

    auto memrefBase = cast<MemRefValue>(adaptor.getBase());

    bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;

    if (!emulationRequiresPartialStores) {
      auto numElements = origElements / emulatedPerContainerElem;
      auto bitCast = vector::BitCastOp::create(
          rewriter, loc, VectorType::get(numElements, containerElemTy),
          op.getValueToStore());
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          op, bitCast.getResult(), memrefBase,

    Value currentDestIndex =
    auto currentSourceIndex = 0;

    auto subWidthStoreMaskType =
        VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());

    auto frontSubWidthStoreElem =
        (emulatedPerContainerElem - *foldedNumFrontPadElems) %
        emulatedPerContainerElem;
    if (frontSubWidthStoreElem > 0) {
      SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
        frontSubWidthStoreElem = origElements;
        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
                    *foldedNumFrontPadElems, true);

      auto frontMask = arith::ConstantOp::create(

      currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
          frontSubWidthStoreElem, *foldedNumFrontPadElems);

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(value), frontMask.getResult());

    if (currentSourceIndex >= origElements) {
      rewriter.eraseOp(op);

    currentDestIndex = arith::AddIOp::create(
        rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne);

    int64_t fullWidthStoreSize =
        (origElements - currentSourceIndex) / emulatedPerContainerElem;
    int64_t numNonFullWidthElements =
        fullWidthStoreSize * emulatedPerContainerElem;
    if (fullWidthStoreSize > 0) {
          rewriter, loc, valueToStore, currentSourceIndex,
          numNonFullWidthElements);

      auto originType = cast<VectorType>(fullWidthStorePart.getType());
      auto storeType = VectorType::get(
          {originType.getNumElements() / emulatedPerContainerElem},
      auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
      vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,

      currentSourceIndex += numNonFullWidthElements;
      currentDestIndex = arith::AddIOp::create(
          rewriter, loc, rewriter.getIndexType(), currentDestIndex,

    auto remainingElements = origElements - currentSourceIndex;
    if (remainingElements != 0) {
      auto subWidthStorePart =
          currentSourceIndex, remainingElements, 0);

      auto maskValues = SmallVector<bool>(emulatedPerContainerElem, false);
      std::fill_n(maskValues.begin(), remainingElements, true);
      auto backMask = arith::ConstantOp::create(

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(subWidthStorePart), backMask.getResult());

    rewriter.eraseOp(op);

  const bool disableAtomicRMW;
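// Illustrative walk-through of the partial-store path above (assuming i2
// elements in an i8 container, *foldedNumFrontPadElems = 2, origElements = 7):
//   1. Front: frontSubWidthStoreElem = (4 - 2) % 4 = 2, so the first two
//      source elements are RMW-stored into the tail of the first byte.
//   2. Full width: fullWidthStoreSize = (7 - 2) / 4 = 1, so the next four
//      elements are bitcast to one i8 and written with a plain vector.store.
//   3. Back: remainingElements = 1, so the last element is RMW-stored into
//      the front of the following byte under a [1, 0, 0, 0] mask.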
struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {

  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Memref in vector.maskedstore op must be flattened beforehand.");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");

    int emulatedPerContainerElem = containerBits / emulatedBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % emulatedPerContainerElem != 0)

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
    OpFoldResult linearizedIndicesOfr;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndicesOfr) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
    Value linearizedIndices =

        rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);

    auto numElements = (origElements + emulatedPerContainerElem - 1) /
                       emulatedPerContainerElem;
    auto newType = VectorType::get(numElements, containerElemTy);
    auto passThru = arith::ConstantOp::create(rewriter, loc, newType,
                                              rewriter.getZeroAttr(newType));

    auto newLoad = vector::MaskedLoadOp::create(
        rewriter, loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
        vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad);
    valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(),
                                           op.getValueToStore(), valueToStore);
        vector::BitCastOp::create(rewriter, loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
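// Illustrative example: a vector.maskedstore of vector<8xi4> into an i8 memref
// becomes (1) a compressed vector<4xi1> mask, (2) a vector.maskedload of
// vector<4xi8> under that mask, (3) a bitcast to vector<8xi4> plus an
// arith.select with the original mask to merge in the new values, and (4) a
// bitcast back to vector<4xi8> that is written with vector.maskedstore.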
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {

  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Memref in emulated vector ops must be flattened beforehand.");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);
        numElements, emulatedElemTy, containerElemTy);

    if (!foldedIntraVectorOffset) {
      auto resultVector = arith::ConstantOp::create(
          rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
    } else if (!isDivisibleInSize) {
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    rewriter.replaceOp(op, result);
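// Illustrative example: an unaligned load of vector<5xi4> whose emulated
// offset falls mid-byte loads ceil((1 + 5) / 2) = 3 container bytes, bitcasts
// them to vector<6xi4>, and then extracts the 5 requested elements starting
// at the intra-byte offset (statically when the offset is known, dynamically
// otherwise).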
struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {

  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Memref in emulated vector ops must be flattened beforehand.");

    auto loc = op.getLoc();

    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    FailureOr<Operation *> newMask =
            emulatedPerContainerElem, maxIntraDataOffset);

    Value passthru = op.getPassThru();

    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);
    auto loadType = VectorType::get(numElements, containerElemTy);
    auto newBitcastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);

    auto emptyVector = arith::ConstantOp::create(
        rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
    } else if (!isDivisibleInSize) {
          *foldedIntraVectorOffset);
        vector::BitCastOp::create(rewriter, loc, loadType, passthru);

    auto newLoad = vector::MaskedLoadOp::create(
        rewriter, loc, loadType, adaptor.getBase(),
        newMask.value()->getResult(0), newPassThru);

        vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad);

    Value mask = op.getMask();
    auto newSelectMaskType = VectorType::get(
        numElements * emulatedPerContainerElem, rewriter.getI1Type());
        arith::ConstantOp::create(rewriter, loc, newSelectMaskType,
                                  rewriter.getZeroAttr(newSelectMaskType));
    if (!foldedIntraVectorOffset) {
    } else if (!isDivisibleInSize) {
          *foldedIntraVectorOffset);

        arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
          rewriter, loc, result, op.getPassThru(),
    } else if (!isDivisibleInSize) {
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    rewriter.replaceOp(op, result);
static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
                                       Type multiByteScalarTy) {
  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");

  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();

  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");

  int elemsPerMultiByte = multiByteBits / subByteBits;

  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
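// Illustrative example: vector<8xi4> fits into an i32 container (8 i4
// elements per i32, and 8 % 8 == 0), while vector<6xi4> does not (6 % 8 != 0)
// although it still fits into an i8 container (6 % 2 == 0).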
struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {

  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Memref in emulated vector ops must be flattened beforehand.");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();

    bool isDivisibleInSize =
        fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);

    Value padding = adaptor.getPadding();
      padding = arith::BitcastOp::create(
          IntegerType::get(rewriter.getContext(),
        arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding);

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);

    auto newRead = vector::TransferReadOp::create(
        rewriter, loc, VectorType::get(numElements, containerElemTy),

    auto bitCast = vector::BitCastOp::create(
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = arith::ConstantOp::create(
          rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
    } else if (!isDivisibleInSize) {
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    rewriter.replaceOp(op, result);
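// Note on the padding handling above: the sub-byte padding scalar is first
// bitcast to an integer of the same sub-byte width and then zero-extended
// (arith.extui) to the container element type, so the new transfer_read on
// the container memref carries a padding value of the container type.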
struct SourceElementRange {
  int64_t sourceElementIdx;
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
struct BitCastRewriter {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  LogicalResult commonPrecondition(PatternRewriter &rewriter,
                                   VectorType preconditionType, Operation *op);

  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
                           Value initialValue, Value runningResult,
                           const BitCastRewriter::Metadata &metadata);

  BitCastBitsEnumerator enumerator;
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires a 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires a 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG() << "sourceVectorType: " << sourceVectorType;

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG() << "targetVectorType: " << targetVectorType;

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
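// Worked example (illustrative): enumerating vector<4xi4> -> vector<2xi8>,
// target element 0 draws bits [0..4) of source element 0 and bits [0..4) of
// source element 1 (left-shifted by 4 when recombined), so its range list is
// { 0: b@[0..4) } { 1: b@[0..4) }; target element 1 is built the same way
// from source elements 2 and 3.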
BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG() << "\n" << enumerator.sourceElementRanges;

                                                  VectorType preconditionType,
  if (!preconditionType || preconditionType.isScalable())

  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)

LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)

  if (!preconditionType || preconditionType.getRank() != 1)
                                                   VectorType subByteVecTy,
             "container element type is not a scalar");

  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();

  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");

  if (subByteBits != 2 && subByteBits != 4)
        op, "only 2-bit and 4-bit sub-byte types are supported at this moment");

  if (containerBits % subByteBits != 0)

  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
        op, "not possible to fit this sub-byte vector type into a vector of "
            "the given multi-byte type");
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
      IntegerAttr mask = IntegerAttr::get(
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftRight));

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftLeft));

    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue,
                                             initialValue, metadata.shuffles);

  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = arith::ConstantOp::create(
  Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp);

  auto shiftRightConstantOp = arith::ConstantOp::create(
  Value shiftedRight =
      arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp);

  auto shiftLeftConstantOp = arith::ConstantOp::create(
      arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp);

          ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft)

  return runningResult;
  auto srcVecType = cast<VectorType>(subByteVec.getType());
  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
  assert(8 % srcBitwidth == 0 &&
         "Unsupported sub-byte type (not a divisor of i8)");
  int64_t numSrcElemsPerByte = 8 / srcBitwidth;

  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
  return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec);
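// Illustrative example: vector<8xi2> is reinterpreted as vector<2xi8>
// (numSrcElemsPerByte = 4, so the trailing dimension shrinks from 8 to 2).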
                                                   int bitIdx, int numBits) {
  auto srcType = cast<VectorType>(src.getType());
  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  if (bitsToShiftLeft != 0) {
    Value shiftLeftValues = arith::ConstantOp::create(
    shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues);

  int8_t bitsToShiftRight = 8 - numBits;
  Value shiftRightValues = arith::ConstantOp::create(
  Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues);
                                               int bitIdx, int numBits) {
  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  auto srcType = cast<VectorType>(src.getType());
  int8_t bitsToShiftRight = bitIdx;

  if (bitsToShiftRight != 0) {
    Value shiftRightValues = arith::ConstantOp::create(
    shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues);

  if (bitIdx + numBits == 8) {

  uint8_t lowBitsMask = (1 << numBits) - 1;
  Value lowBitsMaskValues = arith::ConstantOp::create(
  return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues);
  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
  Value high = extFn(rewriter, loc, i8Vector, 4, 4);

  return vector::InterleaveOp::create(rewriter, loc, low, high);
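// Illustrative example: within each packed byte, bits [0, 4) hold the
// even-indexed i4 element and bits [4, 8) the odd-indexed one, so
// interleaving the extended low and high nibbles restores the original
// element order in the resulting i8 vector.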
  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(2) &&
         "Expected i2 type");

  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);

  Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2);
  Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3);
  return vector::InterleaveOp::create(rewriter, loc, interleave02,
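// Why two levels of interleave (illustrative): extracting the 2-bit fields at
// bit offsets 0, 2, 4 and 6 yields vec0 = [e0, e4, ...], vec1 = [e1, e5, ...],
// vec2 = [e2, e6, ...] and vec3 = [e3, e7, ...]; interleave(vec0, vec2) gives
// [e0, e2, e4, e6, ...], interleave(vec1, vec3) gives [e1, e3, e5, e7, ...],
// and the final interleave restores [e0, e1, e2, e3, ...].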
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue);

  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = arith::ConstantOp::create(
  Value zeroOutLow = arith::AndIOp::create(
      rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask);

  constexpr int8_t bitsToShift = 4;
  auto shiftValues = arith::ConstantOp::create(
  Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(),

  auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh);

  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp);
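// Illustrative summary: vector.deinterleave splits the i8 vector into the
// future low nibbles (even elements) and high nibbles (odd elements); the low
// half is masked with 0x0F, the high half is shifted left by 4, and the OR of
// the two packs two i4 values per byte, which the final bitcast exposes as an
// i4 vector of twice the length.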
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();

    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))

    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);

    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();

      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  using OpRewritePattern<ExtOpType>::OpRewritePattern;

  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
      : OpRewritePattern<ExtOpType>(context, benefit) {}

  LogicalResult matchAndRewrite(ExtOpType extOp,
                                PatternRewriter &rewriter) const override {

    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();

    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))

    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);

        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();

          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);

          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
  using OpRewritePattern<ConversionOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                PatternRewriter &rewriter) const override {

    Value srcValue = conversionOp.getIn();
    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());

            rewriter, srcVecType,

    Location loc = conversionOp.getLoc();

    switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {

        conversionOp, conversionOp.getType(), subByteExt);
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {

  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
                                PatternRewriter &rewriter) const override {

    Value srcValue = truncOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!srcVecType || !dstVecType)

    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)

            rewriter, dstVecType,

    Location loc = truncOp.getLoc();
    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
        arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue);

    rewriter.replaceOp(truncOp, subByteTrunc);
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {

  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {

    constexpr unsigned minNativeBitwidth = 8;
    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
                                         "not a sub-byte transpose");

    Location loc = transposeOp.getLoc();

    auto srcNativeVecType = srcSubByteVecType.cloneWith(
    Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType,
                                         transposeOp.getVector());
    Value newTranspose = vector::TransposeOp::create(
        rewriter, loc, extOp, transposeOp.getPermutation());
    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns, bool disableAtomicRMW) {

  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());

  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);

void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),

  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),

      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, false>>(

void vector::populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {

void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);