#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-narrow-type-emulation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                  Location loc, Value mask,
                                                  int numSrcElems,
                                                  int numSrcElemsPerDest,
                                                  int numFrontPadElems = 0) {

  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

  int numDestElems =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;

  // Walk up the use-def chain to find the op that defines the mask, looking
  // through vector.extract ops.
  // ...
  while (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
      maskOp)) {
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getVector().getDefiningOp();
      extractOps.push_back(extractOp);
    }
  }

  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
          maskOp))
    return failure();
  // ...
  maskShape.back() = numDestElems;
  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
  std::optional<Operation *> newMask =
      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
          .Case<vector::CreateMaskOp>(
              [&](auto createMaskOp) -> std::optional<Operation *> {
                // ...
                s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
                    rewriter, loc, s0, origIndex);
                newMaskOperands.push_back(
                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                             newMaskOperands);
              })
          .Case<vector::ConstantMaskOp>(
              [&](auto constantMaskOp) -> std::optional<Operation *> {
                SmallVector<int64_t> maskDimSizes(
                    constantMaskOp.getMaskDimSizes());
                int64_t &maskIndex = maskDimSizes.back();
                // ...
                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                               maskDimSizes);
              })
          .Case<arith::ConstantOp>([&](auto constantOp)
                                       -> std::optional<Operation *> {
            // Only 1-D constant masks are supported.
            if (maskShape.size() != 1)
              return std::nullopt;
            // ...
            auto originalMask =
                cast<DenseIntElementsAttr>(constantOp.getValue());
            SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);

            // Compress the mask: a destination (container-level) element is
            // set if any of the source elements it covers is set.
            SmallVector<bool> compressedMaskValues;
            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j)
                combinedValue |= paddedMaskValues[i + j];
              compressedMaskValues.push_back(combinedValue);
            }
            return rewriter.create<arith::ConstantOp>(
                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
          });
  // ...
  while (!extractOps.empty()) {
    newMask = rewriter.create<vector::ExtractOp>(
        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
    extractOps.pop_back();
  }

  return *newMask;
}
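
// Worked example for getCompressedMaskOp (illustrative, not from the original
// source): emulating i4 elements inside i8 containers gives
// numSrcElemsPerDest = 2. With numSrcElems = 5 and numFrontPadElems = 1, the
// emulated data spans ceil((1 + 5) / 2) = 3 container elements. A
// vector.create_mask of size s becomes one of size ceil((s + 1) / 2), and a
// constant mask with 5 leading true values becomes, after front-padding with
// one false and OR-ing each pair, the i8-level mask [1, 1, 1].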
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, int64_t offset,
                                        int64_t numElemsToExtract) {
  auto vectorType = cast<VectorType>(src.getType());
  assert(vectorType.getRank() == 1 &&
         "expected source to be a rank-1 vector");
  assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
         "subvector out of bounds");

  // When extracting all available elements, just use the source vector.
  if (vectorType.getNumElements() == numElemsToExtract)
    return src;

  // ...
  auto resultVectorType =
      VectorType::get({numElemsToExtract}, vectorType.getElementType());
  return rewriter
      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src,
                                             offsets, sizes, strides)
      ->getResult(0);
}
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {
  [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");

  // If overwriting the destination vector entirely, just return the source.
  if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
    return src;

  // ...
  return rewriter.create<vector::InsertStridedSliceOp>(loc, destVecTy, src,
                                                       dest, offsets, strides);
}
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
                                         Value src, Value dest,
                                         OpFoldResult offset,
                                         int64_t numElemsToExtract) {
  auto srcVecTy = cast<VectorType>(src.getType());
  assert(srcVecTy.getRank() == 1 && "expected source to be a rank-1 vector");
  // ...
  assert(numElemsToExtract <= srcVecTy.getNumElements() &&
         "subvector out of bounds");

  // When extracting all available elements, just use the source vector.
  if (srcVecTy.getNumElements() == numElemsToExtract)
    return src;

  for (int i = 0; i < numElemsToExtract; ++i) {
    Value extractLoc =
        (i == 0) ? dyn_cast<Value>(offset)
                 : rewriter.create<arith::AddIOp>(
                       loc, rewriter.getIndexType(), dyn_cast<Value>(offset),
                       rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, extractLoc);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
  }
  return dest;
}
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
                                        Value src, Value dest,
                                        OpFoldResult offset,
                                        int64_t numElemsToInsert) {
  auto srcVecTy = cast<VectorType>(src.getType());
  auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");
  // ...
  assert(numElemsToInsert > 0 &&
         "the number of elements to insert must be greater than 0");
  assert(numElemsToInsert <= destVecTy.getNumElements() &&
         "subvector out of bounds");

  Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
  for (int64_t i = 0; i < numElemsToInsert; ++i) {
    auto insertLoc =
        i == 0 ? destOffsetVal
               : rewriter.create<arith::AddIOp>(
                     loc, rewriter.getIndexType(), destOffsetVal,
                     rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, i);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
  }
  return dest;
}
static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
                                      Value base,
                                      OpFoldResult linearizedIndices,
                                      int64_t numContainerElemsToLoad,
                                      Type emulatedElemTy,
                                      Type containerElemTy) {
  auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
                                  emulatedElemTy.getIntOrFloatBitWidth();
  auto newLoad = rewriter.create<vector::LoadOp>(
      loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
  return rewriter.create<vector::BitCastOp>(
      loc,
      VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
                      emulatedElemTy),
      newLoad);
}
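
// Illustrative example for emulatedVectorLoad (not from the original source):
// loading 4 i4 elements that live in i8 containers loads 2 i8 elements and
// bit-casts the result back to the emulated type, e.g.
//
//   %c = vector.load %base[%linearized_idx] : memref<?xi8>, vector<2xi8>
//   %v = vector.bitcast %c : vector<2xi8> to vector<4xi4>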
static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
                                     VectorType downcastType,
                                     VectorType upcastType, Value mask,
                                     Value trueValue, Value falseValue) {
  assert(
      downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
          upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
      "expected input and output number of bits to match");
  if (trueValue.getType() != downcastType) {
    trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
  }
  if (falseValue.getType() != downcastType) {
    falseValue =
        builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
  }
  Value selectedType =
      builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
  // Upcast the selected value back to the wider type.
  return builder.create<vector::BitCastOp>(loc, upcastType, selectedType);
}
static void atomicRMW(OpBuilder &builder, Location loc,
                      MemRefValue linearizedMemref, Value storeIdx,
                      VectorValue valueToStore, Value mask) {
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  // Create an atomic load-modify-write region using
  // `memref.generic_atomic_rmw`.
  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
      loc, linearizedMemref, ValueRange{storeIdx});
  Value origValue = atomicOp.getCurrentValue();
  // ...
  Value origVecValue = builder.create<vector::FromElementsOp>(
      loc, oneElemVecType, ValueRange{origValue});

  // Construct the final masked value and yield it.
  Value maskedValue =
      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
                              oneElemVecType, mask, valueToStore, origVecValue);
  auto scalarMaskedValue =
      builder.create<vector::ExtractOp>(loc, maskedValue, 0);
  builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}
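
// Illustrative sketch of the IR produced by atomicRMW (not from the original
// source), for an i8-backed memref and a one-byte masked update:
//
//   memref.generic_atomic_rmw %linearized_memref[%store_idx] : memref<?xi8> {
//   ^bb0(%current : i8):
//     %cur_vec = vector.from_elements %current : vector<1xi8>
//     // select the emulated elements under %mask, bitcast back to i8
//     %new = vector.extract %masked[0] : i8 from vector<1xi8>
//     memref.atomic_yield %new : i8
//   }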
static void nonAtomicRMW(OpBuilder &builder, Location loc,
                         MemRefValue linearizedMemref, Value linearizedIndex,
                         VectorValue valueToStore, Value mask) {
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  auto oneElemVecType =
      VectorType::get({1}, linearizedMemref.getType().getElementType());
  Value origVecValue = builder.create<vector::LoadOp>(
      loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
                                                   origVecValue);

  Value maskedValue =
      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
                              oneElemVecType, mask, valueToStore, origVecValue);
  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
                                  linearizedIndex);
}
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                  Location loc, VectorValue vector,
                                  int64_t extractOffset,
                                  int64_t sliceNumElements,
                                  int64_t insertOffset) {
  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
  auto vectorElementType = vector.getType().getElementType();
  // ...
  assert(
      sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
      "sliceNumElements * vector element size must be less than or equal to 8");
  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
         "vector element must be a valid sub-byte type");

  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
      loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
      rewriter.getZeroAttr(
          VectorType::get({emulatedPerContainerElem}, vectorElementType)));
  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
                                              extractOffset, sliceNumElements);
  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
                                   insertOffset);
}
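
// Illustrative example for extractSliceIntoByte (not from the original
// source): with vector<8xi2>, extractOffset = 2, sliceNumElements = 3 and
// insertOffset = 1, the helper extracts elements [2, 5) and places them at
// positions [1, 4) of a zeroed vector<4xi2> (one byte worth of i2 elements):
//
//   %zero  = arith.constant dense<0> : vector<4xi2>
//   %slice = vector.extract_strided_slice %src {offsets = [2], sizes = [3],
//            strides = [1]} : vector<8xi2> to vector<3xi2>
//   %byte  = vector.insert_strided_slice %slice, %zero {offsets = [1],
//            strides = [1]} : vector<3xi2> into vector<4xi2>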
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
      : OpConversionPattern<vector::StoreOp>(context),
        disableAtomicRMW(disableAtomicRMW) {}

  LogicalResult
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();

    auto valueToStore = cast<VectorValue>(op.getValueToStore());
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;
    auto origElements = valueToStore.getType().getNumElements();
    // Note, per-element alignment was already verified above.
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedNumFrontPadElems =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);

    if (!foldedNumFrontPadElems) {
      return rewriter.notifyMatchFailure(
          op, "subbyte store emulation: dynamic front padding size is "
              "not yet implemented");
    }

    auto memrefBase = cast<MemRefValue>(adaptor.getBase());

    // Partial (read-modify-write) stores are only needed when the source size
    // is not a multiple of the container size, or when the store is not
    // aligned to the container boundary.
    bool emulationRequiresPartialStores =
        !isDivisibleInSize || *foldedNumFrontPadElems != 0;
    if (!emulationRequiresPartialStores) {
      // Basic case: storing full bytes.
      auto numElements = origElements / emulatedPerContainerElem;
      auto bitCast = rewriter.create<vector::BitCastOp>(
          loc, VectorType::get(numElements, containerElemTy),
          op.getValueToStore());
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          op, bitCast.getResult(), memrefBase,
          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
      return success();
    }
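
    // Illustrative example for the full-width path above (not from the
    // original source): storing vector<8xi4> into memref<?xi8> at a
    // byte-aligned index lowers to
    //
    //   %b = vector.bitcast %value : vector<8xi4> to vector<4xi8>
    //   vector.store %b, %base[%linearized_idx] : memref<?xi8>, vector<4xi8>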
    // The index into the target memref we are storing to.
    Value currentDestIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
    // The index into the source vector we are currently processing.
    auto currentSourceIndex = 0;

    // Build the mask type used for the read-modify-write sequences.
    auto subWidthStoreMaskType =
        VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());

    auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;

    // 1. Partial-width store for the leading output byte.
    auto frontSubWidthStoreElem =
        (emulatedPerContainerElem - *foldedNumFrontPadElems) %
        emulatedPerContainerElem;
    if (frontSubWidthStoreElem > 0) {
      SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                    origElements, true);
        frontSubWidthStoreElem = origElements;
      } else {
        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
                    *foldedNumFrontPadElems, true);
      }
      auto frontMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));

      currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
      auto value =
          extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                               frontSubWidthStoreElem, *foldedNumFrontPadElems);

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(value), frontMask.getResult());
    }

    if (currentSourceIndex >= origElements) {
      rewriter.eraseOp(op);
      return success();
    }

    // Move the destination index past the leading partial byte.
    currentDestIndex = rewriter.create<arith::AddIOp>(
        loc, rewriter.getIndexType(), currentDestIndex,
        rewriter.create<arith::ConstantIndexOp>(loc, 1));
    // 2. Full-width store for the inner output bytes. After the previous
    // step, the store address is aligned to the emulated width boundary.
    int64_t fullWidthStoreSize =
        (origElements - currentSourceIndex) / emulatedPerContainerElem;
    int64_t numNonFullWidthElements =
        fullWidthStoreSize * emulatedPerContainerElem;
    if (fullWidthStoreSize > 0) {
      auto fullWidthStorePart = staticallyExtractSubvector(
          rewriter, loc, valueToStore, currentSourceIndex,
          numNonFullWidthElements);

      auto originType = cast<VectorType>(fullWidthStorePart.getType());
      auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
      auto storeType = VectorType::get(
          {originType.getNumElements() / emulatedPerContainerElem},
          memrefElemType);
      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                        fullWidthStorePart);
      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
                                       currentDestIndex);

      currentSourceIndex += numNonFullWidthElements;
      currentDestIndex = rewriter.create<arith::AddIOp>(
          loc, rewriter.getIndexType(), currentDestIndex,
          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
    }
    // 3. Partial-width store for the trailing output byte.
    auto remainingElements = origElements - currentSourceIndex;
    if (remainingElements != 0) {
      auto subWidthStorePart =
          extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
                               currentSourceIndex, remainingElements, 0);

      // Generate the back mask.
      auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
      std::fill_n(maskValues.begin(), remainingElements, 1);
      auto backMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(subWidthStorePart), backMask.getResult());
    }

    rewriter.eraseOp(op);
    return success();
  }

  const bool disableAtomicRMW;
};
struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }

    int emulatedPerContainerElem = containerBits / emulatedBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % emulatedPerContainerElem != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndicesOfr;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndicesOfr) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
    Value linearizedIndices =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);

    // ...
    FailureOr<Operation *> newMask = getCompressedMaskOp(
        rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
    if (failed(newMask))
      return failure();

    auto numElements = (origElements + emulatedPerContainerElem - 1) /
                       emulatedPerContainerElem;
    auto newType = VectorType::get(numElements, containerElemTy);
    auto passThru = rewriter.create<arith::ConstantOp>(
        loc, newType, rewriter.getZeroAttr(newType));

    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
    Value valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
    valueToStore = rewriter.create<arith::SelectOp>(
        loc, op.getMask(), op.getValueToStore(), valueToStore);
    valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
        valueToStore);
    return success();
  }
};
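
// Illustrative example for the masked-store pattern above (not from the
// original source): a masked store of vector<6xi4> into memref<3xi8> becomes
// a read-modify-write at the i8 level, roughly:
//
//   %cmask  = <compressed i8-level mask>              : vector<3xi1>
//   %loaded = vector.maskedload %base[%idx], %cmask, %zeros : vector<3xi8>
//   %wide   = vector.bitcast %loaded : vector<3xi8> to vector<6xi4>
//   %sel    = arith.select %orig_mask, %value, %wide : vector<6xi1>, vector<6xi4>
//   %packed = vector.bitcast %sel : vector<6xi4> to vector<3xi8>
//   vector.maskedstore %base[%idx], %cmask, %packed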
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    // ...
    auto origElements = op.getVectorType().getNumElements();
    // Note, per-element alignment was already verified above.
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);

    // Always load enough container elements to cover the original vector.
    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);
    Value result =
        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                           numElements, emulatedElemTy, containerElemTy);

    if (!foldedIntraVectorOffset) {
      auto resultVector = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, result, resultVector,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
    } else if (!isDivisibleInSize) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    // ...
    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    // Note, per-element alignment was already verified above.
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    FailureOr<Operation *> newMask =
        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
                            emulatedPerContainerElem, maxIntraDataOffset);
    if (failed(newMask))
      return failure();

    Value passthru = op.getPassThru();

    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);
    auto loadType = VectorType::get(numElements, containerElemTy);
    auto newBitcastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);

    auto emptyVector = rewriter.create<arith::ConstantOp>(
        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
      passthru = dynamicallyInsertSubVector(rewriter, loc, passthru,
                                            emptyVector,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
    } else if (!isDivisibleInSize) {
      passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                           *foldedIntraVectorOffset);
    }
    Value newPassThru =
        rewriter.create<vector::BitCastOp>(loc, loadType, passthru);

    // Generate the new masked load.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, loadType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newMask.value()->getResult(0), newPassThru);

    // Set the part that was not effectively loaded from memory to the
    // pass-through value.
    Value bitCast =
        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);

    Value mask = op.getMask();
    auto newSelectMaskType = VectorType::get(
        numElements * emulatedPerContainerElem, rewriter.getI1Type());
    auto emptyMask = rewriter.create<arith::ConstantOp>(
        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
    if (!foldedIntraVectorOffset) {
      mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                        linearizedInfo.intraDataOffset,
                                        origElements);
    } else if (!isDivisibleInSize) {
      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                       *foldedIntraVectorOffset);
    }

    Value result =
        rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
      result = dynamicallyExtractSubVector(
          rewriter, loc, result, op.getPassThru(),
          linearizedInfo.intraDataOffset, origElements);
    } else if (!isDivisibleInSize) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};
static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
                                       Type multiByteScalarTy) {
  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");

  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();

  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");

  int elemsPerMultiByte = multiByteBits / subByteBits;
  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
}
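
// Illustrative example (not from the original source): vector<8xi4> fits into
// an i32 container because 32 / 4 = 8 divides the trailing dimension
// (8 % 8 == 0), whereas vector<3xi4> does not (3 % 8 != 0).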
struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getSource().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();

    // Note, per-element alignment was already verified above.
    bool isDivisibleInSize =
        fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                      adaptor.getPadding());

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc,
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
        newRead);

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, result, zeros,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
    } else if (!isDivisibleInSize) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};
struct SourceElementRange {
  /// The index of the source vector element that contributes bits.
  int64_t sourceElementIdx;
  /// The bit range [sourceBitBegin, sourceBitEnd) contributed by that element.
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;
};

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  /// Given the index of an entry in this list, compute the sum of the bit
  /// widths of all preceding entries, i.e. the amount by which the entry must
  /// be shifted left when reassembling the target element.
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    int64_t res = 0;
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
    return res;
  }
};

struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());
    return numVectors;
  }

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
};
struct BitCastRewriter {
  // ...
  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  LogicalResult commonPrecondition(PatternRewriter &rewriter,
                                   VectorType preconditionType, Operation *op);

  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
                           Value initialValue, Value runningResult,
                           const BitCastRewriter::Metadata &metadata);

  BitCastBitsEnumerator enumerator;
};
[[maybe_unused]] static raw_ostream &
operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
    }
    os << "\n";
  }
  return os;
}
BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires a 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires a 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG("sourceVectorType: " << sourceVectorType);

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG("targetVectorType: " << targetVectorType);

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  // Enumerate, for every target element, the ranges of source bits that
  // contribute to it.
  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
}
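
// Illustrative example (not from the original source): for a
// vector<4xi8> -> vector<8xi4> bitcast, bitwidth = 32 and the enumeration
// yields a single source range per i4 result element:
//   result[0] <- src[0] bits [0, 4),  result[1] <- src[0] bits [4, 8),
//   result[2] <- src[1] bits [0, 4),  and so on.
// For vector<8xi4> -> vector<4xi8>, each i8 result element instead combines
// two ranges, e.g. result[0] <- { src[0] bits [0, 4), src[1] bits [0, 4) },
// the second range shifted left by 4 when reassembling.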
BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG("\n" << enumerator.sourceElementRanges);
}

/// Verify that the precondition type meets the common preconditions for any
/// conversion.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!preconditionType || preconditionType.isScalable())
    return rewriter.notifyMatchFailure(op, "scalable vector");

  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)
    return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");

  return success();
}

LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
    return rewriter.notifyMatchFailure(op, "types are not vector");

  if (!preconditionType || preconditionType.getRank() != 1)
    return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");

  return commonConversionPrecondition(rewriter, preconditionType, op);
}
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType subByteVecTy,
                                                   Type containerTy,
                                                   Operation *op) {
  assert(containerTy.isIntOrFloat() &&
         "container element type is not a scalar");
  // ...
  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
  unsigned containerBits = containerTy.getIntOrFloatBitWidth();

  // Enforced by the common preconditions.
  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");

  // Only 2-bit and 4-bit sub-byte element types are currently supported.
  if (subByteBits != 2 && subByteBits != 4)
    return rewriter.notifyMatchFailure(
        op, "only 2-bit and 4-bit sub-byte type is supported at this moment");

  // Condition 1: per-element alignment.
  if (containerBits % subByteBits != 0)
    return rewriter.notifyMatchFailure(op, "unaligned element types");

  // Condition 2: full alignment.
  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
    return rewriter.notifyMatchFailure(
        op, "not possible to fit this sub-byte vector type into a vector of "
            "the given multi-byte type");

  return success();
}
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

    // Create the attribute quantities for the shuffle / mask / shift ops.
    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
                                  : 0;
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
                          : 0;
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
                          : 0;
      IntegerAttr mask = IntegerAttr::get(
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
                                  bitLo, bitHi));
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftRight));

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftLeft));
    }

    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
  }
  return result;
}
Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  // Create vector.shuffle from the metadata.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);

  // Intersect with the mask.
  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);

  // Align right on bit 0.
  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
  Value shiftedRight =
      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);

  // Shift bits left into their final position.
  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
  Value shiftedLeft =
      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);

  runningResult =
      runningResult
          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
          : shiftedLeft;

  return runningResult;
}
static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
                                      Value subByteVec) {
  auto srcVecType = cast<VectorType>(subByteVec.getType());
  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
  assert(8 % srcBitwidth == 0 &&
         "Unsupported sub-byte type (not a divisor of i8)");
  int64_t numSrcElemsPerByte = 8 / srcBitwidth;
  SmallVector<int64_t> vecShape(srcVecType.getShape());
  // Adjust the last dimension so the total bit size remains the same.
  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
  return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
}
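
// Illustrative example (not from the original source): applied to a
// vector<8x16xi2>, this produces
//   vector.bitcast %v : vector<8x16xi2> to vector<8x4xi8>
// since 16 i2 elements pack into 16 / (8 / 2) = 4 bytes per row.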
static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
                                                  Location loc, Value src,
                                                  int bitIdx, int numBits) {
  auto srcType = cast<VectorType>(src.getType());
  Value shl = src;
  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  if (bitsToShiftLeft != 0) {
    Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
    shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
  }

  int8_t bitsToShiftRight = 8 - numBits;
  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
  Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
  return shr;
}
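
// Illustrative example (not from the original source): to sign-extend the i4
// field stored in bits [4, 8) of each byte (bitIdx = 4, numBits = 4), no left
// shift is needed (8 - 4 - 4 = 0) and a single arith.shrsi by 4 replicates
// the sign bit; for the field in bits [0, 4), the byte is first shifted left
// by 4 so its sign bit lands in bit 7, then shifted right arithmetically by 4.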
static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
                                              Location loc, Value src,
                                              int bitIdx, int numBits) {
  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  auto srcType = cast<VectorType>(src.getType());
  int8_t bitsToShiftRight = bitIdx;
  Value shr = src;
  if (bitsToShiftRight != 0) {
    Value shiftRightValues = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
    shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
  }
  if (bitIdx + numBits == 8) {
    return shr;
  }
  uint8_t lowBitsMask = (1 << numBits) - 1;
  Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(srcType, lowBitsMask));
  return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
}
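
// Illustrative example (not from the original source): extracting the i2
// field stored in bits [2, 4) of each byte (bitIdx = 2, numBits = 2) emits an
// arith.shrui by 2 followed by arith.andi with 0b11; when the field occupies
// the top bits (bitIdx + numBits == 8) the mask is skipped because the
// logical shift already cleared the high bits.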
static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
                              Value srcValue, const ExtractNBitsFn &extFn) {
  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // 1. Bitcast vector<Xxi4> -> vector<X/2xi8>.
  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

  // 2. Extend the i4 elements to i8: low nibbles first, then high nibbles.
  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
  Value high = extFn(rewriter, loc, i8Vector, 4, 4);

  // 3. Interleave the low and high i8 elements to restore element order.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
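
// Illustrative example (not from the original source): for vector<8xi4>, the
// bitcast yields vector<4xi8>; extracting the low and high nibbles gives two
// vector<4xi8> values, and vector.interleave recombines them into the
// vector<8xi8> extension in the original element order.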
static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
                              Value srcValue, const ExtractNBitsFn &extFn) {
  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(2) &&
         "Expected i2 type");

  // 1. Bitcast vector<Xxi2> -> vector<X/4xi8>.
  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

  // 2. Extract each 2-bit field of the byte and extend it to i8.
  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);

  // 3. Interleave the extracted vectors to restore the original element
  //    order.
  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
}
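
// Illustrative example (not from the original source): for a byte packing
// elements [0, 1, 2, 3] (element 0 in the lowest two bits), the two-level
// interleave restores the original order:
//   interleave(vec0, vec2) -> [0, 2], interleave(vec1, vec3) -> [1, 3],
//   interleave of those    -> [0, 1, 2, 3].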
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  // 1. De-interleave low and high i8 elements.
  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);

  // 2. Zero out the upper nibble of each low i8 element.
  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
  Value zeroOutLow = rewriter.create<arith::AndIOp>(
      loc, deinterleaveOp.getRes1(), zeroOutMask);

  // 3. Move the high i4 values to the upper nibble of the byte.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
                                                 shiftValues);

  // 4. Merge the high and low i4 values.
  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);

  // 5. Bitcast vector<Xxi8> -> vector<2Xxi4>.
  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
}
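
// Illustrative example (not from the original source): truncating
// vector<8xi8> to vector<8xi4> deinterleaves into two vector<4xi8> halves,
// masks the even elements with 0x0F, shifts the odd elements left by 4, ORs
// the two, and bitcasts the resulting vector<4xi8> back to vector<8xi4>,
// avoiding a scalarized i8 -> i4 truncation in the LLVM backend.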
    // The source must be a trunc op.
    auto truncOp =
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
    if (!truncOp)
      return failure();

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::TruncIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    } else {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    }
    return success();
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  // ...
    // The source must be a bitcast op.
    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
    if (!bitCastOp)
      return failure();

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing =
        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    } else {
      rewriter.replaceOpWithNewOp<ExtOpType>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    }
    return success();
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
  // ...
    // Verify the source and destination types and the alignment
    // preconditions.
    Value srcValue = conversionOp.getIn();
    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
    // ...
    if (failed(alignedConversionPrecondition(
            rewriter, srcVecType,
            /*containerTy=*/rewriter.getI8Type(), conversionOp)))
      return failure();

    // Perform the rewrite: first extend the packed sub-byte elements to i8,
    // dispatching on the sub-byte element width.
    Location loc = conversionOp.getLoc();
    const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
                                 : extractNBitsPerByteAndExtendToI8;
    Value subByteExt;
    switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
    case 2:
      subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
      break;
    case 4:
      subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
      break;
    default:
      return failure();
    }

    // Finalize the rewrite with the original conversion op.
    rewriter.replaceOpWithNewOp<ConversionOpType>(
        conversionOp, conversionOp.getType(), subByteExt);
    Value srcValue = truncOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!srcVecType || !dstVecType)
      return failure();

    // 2-bit truncations are not supported yet.
    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
      return failure();

    // Check the alignment preconditions (note: the sub-byte type here is the
    // destination type).
    if (failed(alignedConversionPrecondition(
            rewriter, dstVecType,
            /*containerTy=*/rewriter.getI8Type(), truncOp)))
      return failure();

    // Create a new iX -> i8 truncation op.
    Location loc = truncOp.getLoc();
    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
    Value i8TruncVal =
        rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);

    // Rewrite the i8 -> i4 part of the truncation.
    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);

    // Finalize the rewrite.
    rewriter.replaceOp(truncOp, subByteTrunc);
    return success();
    constexpr unsigned minNativeBitwidth = 8;
    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
      return rewriter.notifyMatchFailure(transposeOp,
                                         "not a sub-byte transpose");
    }

    // Perform the rewrite: extend to a native element type, transpose, and
    // truncate back to the sub-byte type.
    Location loc = transposeOp.getLoc();
    // ...
    auto srcNativeVecType = srcSubByteVecType.cloneWith(
        std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
    Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
                                                  transposeOp.getVector());
    Value newTranspose = rewriter.create<vector::TransposeOp>(
        loc, extOp, transposeOp.getPermutation());
    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
    rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
                                                 newTranspose);
    return success();
void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns, bool disableAtomicRMW) {
  // Populate `vector.*` load/store conversion patterns.
  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());

  // `vector.store` emulation can optionally avoid `memref.generic_atomic_rmw`
  // and use a plain read-modify-write sequence instead.
  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
}

void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    benefit);

  // Patterns for aligned cases are given a higher benefit, as they are
  // expected to generate better code.
  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                              benefit.getBenefit() + 1);
  patterns
      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
          patterns.getContext(), benefit.getBenefit() + 1);
}