#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-narrow-type-emulation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
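using VectorValue = TypedValue<VectorType>;
using MemRefValue = TypedValue<MemRefType>;
using ExtractNBitsFn =
    std::function<Value(PatternRewriter &, Location, Value, int, int)>;

/// Returns a compressed mask for the emulated vector. The mask-producing op
/// may be a vector.create_mask, vector.constant_mask, or arith.constant,
/// possibly reached through a chain of vector.extract ops.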
static FailureOr<Operation *>
getCompressedMaskOp(OpBuilder &rewriter, Location loc, Value mask,
                    int numSrcElems, int numSrcElemsPerDest,
                    int numFrontPadElems = 0) {

  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

  auto numDestElems =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;

  // Walk up the use-def chain to find the op that defines the mask, looking
  // through any enclosing vector.extract ops.
  Operation *maskOp = mask.getDefiningOp();
  SmallVector<vector::ExtractOp, 2> extractOps;
  while (maskOp &&
         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
             maskOp)) {
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getVector().getDefiningOp();
      extractOps.push_back(extractOp);
    }
  }

  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
          maskOp))
    return failure();

  // Compute the "compressed" mask shape; only the trailing dimension changes.
  SmallVector<int64_t> maskShape(
      cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
  maskShape.back() = numDestElems;
  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());

  std::optional<Operation *> newMask =
      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
          .Case<vector::CreateMaskOp>(
              [&](auto createMaskOp) -> std::optional<Operation *> {
                // Shift the trailing mask bound by the front padding and
                // divide it (rounding up) by the compression factor.
                OperandRange maskOperands = createMaskOp.getOperands();
                AffineExpr s0;
                bindSymbols(rewriter.getContext(), s0);
                s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                OpFoldResult origIndex =
                    getAsOpFoldResult(maskOperands.back());
                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
                    rewriter, loc, s0, origIndex);
                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
                newMaskOperands.push_back(
                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                             newMaskOperands);
              })
          .Case<vector::ConstantMaskOp>(
              [&](auto constantMaskOp) -> std::optional<Operation *> {
                SmallVector<int64_t> maskDimSizes(
                    constantMaskOp.getMaskDimSizes());
                int64_t &maskIndex = maskDimSizes.back();
                maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
                                             numSrcElemsPerDest);
                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                               maskDimSizes);
              })
          .Case<arith::ConstantOp>([&](auto constantOp)
                                       -> std::optional<Operation *> {
            // Only 1-D constant masks are supported here.
            if (maskShape.size() != 1)
              return std::nullopt;

            // Shift the mask values by the front padding and compress each
            // group of numSrcElemsPerDest values into a single bit.
            auto originalMask =
                cast<DenseIntElementsAttr>(constantOp.getValue());
            SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);

            SmallVector<bool> compressedMaskValues;
            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j) {
                combinedValue |= paddedMaskValues[i + j];
              }
              compressedMaskValues.push_back(combinedValue);
            }
            return rewriter.create<arith::ConstantOp>(
                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
          })
          .Default([&](Operation *) -> std::optional<Operation *> {
            return std::nullopt;
          });

  if (!newMask)
    return failure();

  // Re-apply any vector.extract ops that were looked through above.
  while (!extractOps.empty()) {
    newMask = rewriter.create<vector::ExtractOp>(
        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
    extractOps.pop_back();
  }

  return *newMask;
}
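/// Extracts a 1-D subvector from a 1-D vector.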
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, int64_t offset,
                                        int64_t numElemsToExtract) {
  auto vectorType = cast<VectorType>(src.getType());
  assert(vectorType.getRank() == 1 && "expected source to be rank-1 vector");
  assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
         "subvector out of bounds");

  // Nothing to extract if the requested subvector spans the whole source.
  if (vectorType.getNumElements() == numElemsToExtract)
    return src;

  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
  auto strides = rewriter.getI64ArrayAttr({1});

  auto resultVectorType =
      VectorType::get({numElemsToExtract}, vectorType.getElementType());
  return rewriter
      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src,
                                             offsets, sizes, strides)
      .getResult();
}
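/// Inserts a 1-D subvector into a 1-D vector.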
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {
  [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");

  // If the subvector fills the destination exactly, no insertion is needed.
  if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
    return src;

  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter.create<vector::InsertStridedSliceOp>(loc, destVecTy, src,
                                                       dest, offsets, strides);
}
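/// Extracts a 1-D subvector from a 1-D vector, at a dynamic offset.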
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
                                         Value src, Value dest,
                                         OpFoldResult offset,
                                         int64_t numElemsToExtract) {
  auto srcVecTy = cast<VectorType>(src.getType());
  assert(srcVecTy.getRank() == 1 && "expected source to be rank-1 vector");
  assert(numElemsToExtract <= srcVecTy.getNumElements() &&
         "subvector out of bounds");

  // Nothing to do if the whole source vector is requested.
  if (srcVecTy.getNumElements() == numElemsToExtract)
    return src;

  // Extract the elements one by one at dynamically computed positions and
  // insert them into the destination vector.
  for (int i = 0; i < numElemsToExtract; ++i) {
    Value extractLoc =
        (i == 0) ? dyn_cast<Value>(offset)
                 : rewriter.create<arith::AddIOp>(
                       loc, rewriter.getIndexType(), dyn_cast<Value>(offset),
                       rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, extractLoc);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
  }
  return dest;
}
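/// Inserts a 1-D subvector into a 1-D vector, at a dynamic offset.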
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
                                        Value src, Value dest,
                                        OpFoldResult offset,
                                        int64_t numElemsToInsert) {
  auto srcVecTy = cast<VectorType>(src.getType());
  auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");
  assert(numElemsToInsert > 0 &&
         "the number of elements to insert must be greater than 0");
  assert(numElemsToInsert <= destVecTy.getNumElements() &&
         "subvector out of bounds");

  // Insert the elements one by one at dynamically computed positions.
  for (int64_t i = 0; i < numElemsToInsert; ++i) {
    auto insertLoc =
        i == 0 ? dyn_cast<Value>(offset)
               : rewriter.create<arith::AddIOp>(
                     loc, rewriter.getIndexType(), dyn_cast<Value>(offset),
                     rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, i);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
  }
  return dest;
}
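/// Emulate a vector load for `emulatedElemTy` using `containerElemTy`.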
static VectorValue
emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                   OpFoldResult linearizedIndices,
                   int64_t numContainerElemsToLoad, Type emulatedElemTy,
                   Type containerElemTy) {
  auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
                                  emulatedElemTy.getIntOrFloatBitWidth();
  auto newLoad = rewriter.create<vector::LoadOp>(
      loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
  return rewriter.create<vector::BitCastOp>(
      loc,
      VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
                      emulatedElemTy),
      newLoad);
}
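/// Downcast two values to `downcastType`, select values based on `mask`, and
/// cast the result to `upcastType`.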
static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
                                     VectorType downcastType,
                                     VectorType upcastType, Value mask,
                                     Value trueValue, Value falseValue) {
  assert(
      downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
          upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
      "expected input and output number of bits to match");
  if (trueValue.getType() != downcastType) {
    trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
  }
  if (falseValue.getType() != downcastType) {
    falseValue =
        builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
  }
  Value selectedType =
      builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
  // Upcast the selected value to the original (wider) element type.
  return builder.create<vector::BitCastOp>(loc, upcastType, selectedType);
}
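/// Emits a memref.generic_atomic_rmw op to store a sub-byte-sized value into
/// a byte in `linearizedMemref`, with a mask selecting which bits to modify.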
static void atomicRMW(OpBuilder &builder, Location loc,
                      MemRefValue linearizedMemref, Value storeIdx,
                      VectorValue valueToStore, Value mask) {
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  // Create an atomic read-modify-write region via memref.generic_atomic_rmw.
  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
      loc, linearizedMemref, ValueRange{storeIdx});
  Value origValue = atomicOp.getCurrentValue();

  // Load the original value as a one-element vector.
  Value origVecValue = builder.create<vector::FromElementsOp>(
      loc, oneElemVecType, ValueRange{origValue});

  // Merge the new bits with the original value under the mask and yield the
  // scalar result.
  Value maskedValue =
      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
                              oneElemVecType, mask, valueToStore, origVecValue);
  auto scalarMaskedValue =
      builder.create<vector::ExtractOp>(loc, maskedValue, 0);
  builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}
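/// Generate a non-atomic read-modify-write sequence for storing to the
/// emulated type; bits outside the mask keep their original contents.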
static void nonAtomicRMW(OpBuilder &builder, Location loc,
                         MemRefValue linearizedMemref, Value linearizedIndex,
                         VectorValue valueToStore, Value mask) {
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  // Load the byte to modify as a one-element vector and bitcast it to the
  // emulated (sub-byte) vector type.
  auto oneElemVecType =
      VectorType::get({1}, linearizedMemref.getType().getElementType());
  Value origVecValue = builder.create<vector::LoadOp>(
      loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
                                                   origVecValue);

  // Merge the new bits with the original value under the mask and store back.
  Value maskedValue =
      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
                              oneElemVecType, mask, valueToStore, origVecValue);
  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
                                  linearizedIndex);
}
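/// Extracts `sliceNumElements` from source `vector` at `extractOffset`, and
/// inserts it into an empty byte-sized vector at `insertOffset`.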
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                  Location loc, VectorValue vector,
                                  int64_t extractOffset,
                                  int64_t sliceNumElements,
                                  int64_t insertOffset) {
  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
  auto vectorElementType = vector.getType().getElementType();
  assert(
      sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
      "sliceNumElements * vector element size must be less than or equal to 8");
  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
         "vector element must be a valid sub-byte type");
  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
  // Create an empty byte-sized vector and copy the slice into it.
  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
      loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
      rewriter.getZeroAttr(
          VectorType::get({emulatedPerContainerElem}, vectorElementType)));
  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
                                              extractOffset, sliceNumElements);
  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
                                   insertOffset);
}
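/// Emulates vector.store on sub-byte element types by packing the emulated
/// elements into wider container elements. Only 1-D vectors are supported.
/// Unaligned stores are split into a leading partial byte, full-width stores,
/// and a trailing partial byte; partial bytes are written with a masked
/// read-modify-write (atomic unless `disableAtomicRMW` is set).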
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
      : OpConversionPattern<vector::StoreOp>(context),
        disableAtomicRMW(disableAtomicRMW) {}

  LogicalResult
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();

    auto valueToStore = cast<VectorValue>(op.getValueToStore());
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = valueToStore.getType().getNumElements();
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto trailingDim = op.getBase().getType().getShape().back();
    bool trailingDimsMatch =
        ShapedType::isDynamic(trailingDim) || trailingDim == origElements;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedNumFrontPadElems =
        (isDivisibleInSize && trailingDimsMatch)
            ? 0
            : getConstantIntValue(linearizedInfo.intraDataOffset);

    if (!foldedNumFrontPadElems) {
      return rewriter.notifyMatchFailure(
          op, "subbyte store emulation: dynamic front padding size is "
              "not yet implemented");
    }

    auto memrefBase = cast<MemRefValue>(adaptor.getBase());

    // A store is "partial" when it does not cover whole container elements.
    bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;

    if (!emulationRequiresPartialStores) {
      // Basic case: storing full bytes.
      auto numElements = origElements / emulatedPerContainerElem;
      auto bitCast = rewriter.create<vector::BitCastOp>(
          loc, VectorType::get(numElements, containerElemTy),
          op.getValueToStore());
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          op, bitCast.getResult(), memrefBase,
          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
      return success();
    }

    // Partial-store path: write the leading partial byte, then the full-width
    // middle part, then the trailing partial byte.
    auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;

    Value currentDestIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
    // Index into the source vector, in emulated elements.
    auto currentSourceIndex = 0;

    auto subWidthStoreMaskType =
        VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());

    // 1. Partial-width store for the leading output byte.
    auto frontSubWidthStoreElem =
        (emulatedPerContainerElem - *foldedNumFrontPadElems) %
        emulatedPerContainerElem;
    if (frontSubWidthStoreElem > 0) {
      SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                    origElements, true);
        frontSubWidthStoreElem = origElements;
      } else {
        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
                    *foldedNumFrontPadElems, true);
      }
      auto frontMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));

      currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
      auto value =
          extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                               frontSubWidthStoreElem, *foldedNumFrontPadElems);

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(value), frontMask.getResult());
    }

    if (currentSourceIndex >= origElements) {
      rewriter.eraseOp(op);
      return success();
    }

    // Advance the destination index past the leading byte.
    currentDestIndex = rewriter.create<arith::AddIOp>(
        loc, rewriter.getIndexType(), currentDestIndex,
        rewriter.create<arith::ConstantIndexOp>(loc, 1));

    // 2. Full-width store for the middle output bytes.
    int64_t fullWidthStoreSize =
        (origElements - currentSourceIndex) / emulatedPerContainerElem;
    int64_t numNonFullWidthElements =
        fullWidthStoreSize * emulatedPerContainerElem;
    if (fullWidthStoreSize > 0) {
      auto fullWidthStorePart = staticallyExtractSubvector(
          rewriter, loc, valueToStore, currentSourceIndex,
          numNonFullWidthElements);

      auto originType = cast<VectorType>(fullWidthStorePart.getType());
      auto storeType = VectorType::get(
          {originType.getNumElements() / emulatedPerContainerElem},
          containerElemTy);
      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                        fullWidthStorePart);
      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
                                       currentDestIndex);

      currentSourceIndex += numNonFullWidthElements;
      currentDestIndex = rewriter.create<arith::AddIOp>(
          loc, rewriter.getIndexType(), currentDestIndex,
          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
    }

    // 3. Partial-width store for the trailing output byte.
    auto remainingElements = origElements - currentSourceIndex;
    if (remainingElements != 0) {
      auto subWidthStorePart =
          extractSliceIntoByte(rewriter, loc, valueToStore, currentSourceIndex,
                               remainingElements, 0);

      // Generate the back mask.
      auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
      std::fill_n(maskValues.begin(), remainingElements, 1);
      auto backMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(subWidthStorePart), backMask.getResult());
    }

    rewriter.eraseOp(op);
    return success();
  }

  const bool disableAtomicRMW;
};
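/// Emulates vector.maskedstore on sub-byte element types. The mask is
/// compressed to container-element granularity, the destination bytes are
/// masked-loaded, the stored value is merged in with arith.select, and the
/// result is written back with a masked store on the container type.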
struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }

    int emulatedPerContainerElem = containerBits / emulatedBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % emulatedPerContainerElem != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndicesOfr;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndicesOfr) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
    Value linearizedIndices =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);

    // Compress the mask so that it matches container-element granularity.
    FailureOr<Operation *> newMask = getCompressedMaskOp(
        rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
    if (failed(newMask))
      return failure();

    auto numElements = (origElements + emulatedPerContainerElem - 1) /
                       emulatedPerContainerElem;
    auto newType = VectorType::get(numElements, containerElemTy);
    auto passThru = rewriter.create<arith::ConstantOp>(
        loc, newType, rewriter.getZeroAttr(newType));

    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
    Value valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
    valueToStore = rewriter.create<arith::SelectOp>(
        loc, op.getMask(), op.getValueToStore(), valueToStore);
    valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
        valueToStore);
    return success();
  }
};
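/// Emulates vector.load on sub-byte element types by loading whole container
/// elements, bitcasting to the emulated type, and extracting the requested
/// subvector (dynamically when the front padding is not a compile-time
/// constant).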
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();
    // Note, per-element-alignment was already verified above.
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);

    // Always load enough elements to cover a potentially unaligned access.
    int64_t maxintraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
                                        emulatedPerContainerElem);
    Value result =
        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                           numElements, emulatedElemTy, containerElemTy);

    if (!foldedIntraVectorOffset) {
      auto resultVector = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, result, resultVector,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
    } else if (!isDivisibleInSize) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
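/// Emulates vector.maskedload on sub-byte element types using a compressed
/// mask on the container type. The pass-through value is pre-inserted so that
/// elements not covered by the original mask keep their pass-through contents
/// after the final arith.select.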
struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();

    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    // Note, per-element-alignment was already verified above.
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    FailureOr<Operation *> newMask =
        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
                            emulatedPerContainerElem, maxIntraDataOffset);
    if (failed(newMask))
      return failure();

    Value passthru = op.getPassThru();

    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);
    auto loadType = VectorType::get(numElements, containerElemTy);
    auto newBitcastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);

    auto emptyVector = rewriter.create<arith::ConstantOp>(
        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
      passthru = dynamicallyInsertSubVector(rewriter, loc, passthru,
                                            emptyVector,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
    } else if (!isDivisibleInSize) {
      passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                           *foldedIntraVectorOffset);
    }
    Value newPassThru =
        rewriter.create<vector::BitCastOp>(loc, loadType, passthru);

    // Generate the new masked load on the container type.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, loadType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newMask.value()->getResult(0), newPassThru);

    // Set the part that was not effectively loaded from memory to the
    // pass-through value.
    Value bitCast =
        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);

    Value mask = op.getMask();
    auto newSelectMaskType = VectorType::get(
        numElements * emulatedPerContainerElem, rewriter.getI1Type());
    auto emptyMask = rewriter.create<arith::ConstantOp>(
        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
    if (!foldedIntraVectorOffset) {
      mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                        linearizedInfo.intraDataOffset,
                                        origElements);
    } else if (!isDivisibleInSize) {
      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                       *foldedIntraVectorOffset);
    }

    Value result =
        rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
      result = dynamicallyExtractSubVector(
          rewriter, loc, result, op.getPassThru(),
          linearizedInfo.intraDataOffset, origElements);
    } else if (!isDivisibleInSize) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};
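/// Returns true if `subByteVecTy` fits wholly into a vector of
/// `multiByteScalarTy`, i.e. its trailing dimension is a multiple of the
/// number of sub-byte elements packed into one multi-byte element.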
static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
                                       Type multiByteScalarTy) {
  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");

  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();

  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");

  int elemsPerMultiByte = multiByteBits / subByteBits;

  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
}
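/// Emulates vector.transfer_read on sub-byte element types by reading
/// container elements (with an extended padding value) and bitcasting /
/// extracting the requested sub-byte elements.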
struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();

    // Note, per-element-alignment was already verified above.
    bool isDivisibleInSize =
        fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                      adaptor.getPadding());

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, VectorType::get(numElements, containerElemTy), adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc,
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
        newRead);

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, result, zeros,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
    } else if (!isDivisibleInSize) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};
struct SourceElementRange {
  /// The index of the source vector element that contributes bits.
  int64_t sourceElementIdx;
  /// The range of bits in the source vector element that contribute.
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;
};

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  /// Sum of the bit widths of all entries before `shuffleIdx`, i.e. the
  /// amount by which that entry must be shifted left in the result.
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    int64_t res = 0;
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
    return res;
  }
};

struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());
    return numVectors;
  }

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
};

struct BitCastRewriter {
  /// Per-step quantities precomputed from the BitCastBitsEnumerator.
  struct Metadata {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
  };

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  LogicalResult commonPrecondition(PatternRewriter &rewriter,
                                   VectorType preconditionType, Operation *op);

  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
                           Value initialValue, Value runningResult,
                           const BitCastRewriter::Metadata &metadata);

  BitCastBitsEnumerator enumerator;
};

[[maybe_unused]] static raw_ostream &
operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
    }
    os << "\n";
  }
  return os;
}

BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG("sourceVectorType: " << sourceVectorType);

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG("targetVectorType: " << targetVectorType);

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  // Walk the result bits and record, for each target element, the ranges of
  // source bits that contribute to it.
  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
}

BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG("\n" << enumerator.sourceElementRanges);
}
/// Verify that the precondition type meets the common preconditions for any
/// conversion.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!preconditionType || preconditionType.isScalable())
    return rewriter.notifyMatchFailure(op, "scalable vector");

  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)
    return rewriter.notifyMatchFailure(op, "unsupported sub-byte bitwidth");

  return success();
}

LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
    return rewriter.notifyMatchFailure(op, "types are not vector");

  if (!preconditionType || preconditionType.getRank() != 1)
    return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");

  return commonConversionPrecondition(rewriter, preconditionType, op);
}
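/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned,
/// i.e. the sub-byte vector can be packed into whole container elements.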
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType subByteVecTy,
                                                   Type containerTy,
                                                   Operation *op) {
  assert(containerTy.isIntOrFloat() &&
         "container element type is not a scalar");

  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
  unsigned containerBits = containerTy.getIntOrFloatBitWidth();

  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");

  if (subByteBits != 2 && subByteBits != 4)
    return rewriter.notifyMatchFailure(
        op, "only 2-bit and 4-bit sub-byte type is supported at this moment");

  if (containerBits % subByteBits != 0)
    return failure();

  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
    return rewriter.notifyMatchFailure(
        op, "not possible to fit this sub-byte vector type into a vector of "
            "the given multi-byte type");

  return success();
}

SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

    // Compute the shuffle indices, masks, and shift amounts for this step.
    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
                                  : 0;
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
                          : 0;
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
                          : 0;
      IntegerAttr mask = IntegerAttr::get(
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
                                  bitLo, bitHi));
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftRight));

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftLeft));
    }

    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
  }
  return result;
}

Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  // Create a vector.shuffle from the metadata.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);

  // Intersect with the mask.
  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);

  // Shift right to align the extracted bits at bit 0.
  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
  Value shiftedRight =
      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);

  // Shift left into the final position within the target element.
  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
  Value shiftedLeft =
      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);

  runningResult =
      runningResult
          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
          : shiftedLeft;

  return runningResult;
}
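/// Bitcasts the aligned `subByteVec` vector to a vector of i8.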
static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
                                      Value subByteVec) {
  auto srcVecType = cast<VectorType>(subByteVec.getType());
  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
  assert(8 % srcBitwidth == 0 &&
         "Unsupported sub-byte type (not a divisor of i8)");
  int64_t numSrcElemsPerByte = 8 / srcBitwidth;
  SmallVector<int64_t> vecShape(srcVecType.getShape());
  // Adjust the last dimension so that the total bit count stays the same.
  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
  return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
}
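/// Extracts a signed N-bit sequence from each element of a vector of bytes,
/// starting at the specified bit index, and sign-extends it to i8.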
static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
                                                  Location loc, Value src,
                                                  int bitIdx, int numBits) {
  auto srcType = cast<VectorType>(src.getType());
  Value shl = src;
  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  if (bitsToShiftLeft != 0) {
    Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
    shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
  }

  int8_t bitsToShiftRight = 8 - numBits;
  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
  Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
  return shr;
}
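/// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
/// starting at the specified bit index, and zero-extends it to i8.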
static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
                                              Location loc, Value src,
                                              int bitIdx, int numBits) {
  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  auto srcType = cast<VectorType>(src.getType());
  int8_t bitsToShiftRight = bitIdx;
  Value shr = src;
  if (bitsToShiftRight != 0) {
    Value shiftRightValues = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
    shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
  }
  if (bitIdx + numBits == 8) {
    return shr;
  }
  uint8_t lowBitsMask = (1 << numBits) - 1;
  Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(srcType, lowBitsMask));
  return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
}
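/// Rewrite the i4 -> i8 extension into a sequence of shuffles and bitwise ops
/// to avoid leaving LLVM to scramble with peephole optimizations.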
static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
                              Value srcValue, const ExtractNBitsFn &extFn) {
  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

  // 2. Extend i4 elements to i8 elements: low i4 values live at bits [0, 3],
  // high i4 values at bits [4, 7] of each byte.
  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
  Value high = extFn(rewriter, loc, i8Vector, 4, 4);

  // 3. Interleave low and high i8 elements.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
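/// Rewrite the i2 -> i8 extension into a sequence of shuffles and bitwise ops
/// to avoid leaving LLVM to scramble with peephole optimizations.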
static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
                              Value srcValue, const ExtractNBitsFn &extFn) {
  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(2) &&
         "Expected i2 type");

  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

  // 2. Extract each i2 element from the byte it lives in.
  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);

  // 3. Interleave all four, first the even elements and then the odd ones.
  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
}
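/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
/// ops to avoid leaving LLVM to scramble with peephole optimizations.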
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  // 1. De-interleave low and high i8 elements.
  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);

  // 2. Zero out the upper half of each low i8 element.
  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
  Value zeroOutLow = rewriter.create<arith::AndIOp>(
      loc, deinterleaveOp.getRes1(), zeroOutMask);

  // 3. Move the high i4 values to the upper half of each byte.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
                                                 shiftValues);

  // 4. Merge high and low i4 values.
  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);

  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
}
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a trunc op.
    auto truncOp =
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
    if (!truncOp)
      return failure();

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::TruncIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    } else {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    }
    return success();
  }
};
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  using OpRewritePattern<ExtOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtOpType extOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a bitcast op.
    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
    if (!bitCastOp)
      return failure();

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    const bool narrowing =
        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    } else {
      rewriter.replaceOpWithNewOp<ExtOpType>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    }
    return success();
  }
};
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
  using OpRewritePattern<ConversionOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = conversionOp.getIn();
    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());

    // Check general alignment preconditions (i8 acts as the container type).
    if (failed(alignedConversionPrecondition(
            rewriter, srcVecType, rewriter.getI8Type(), conversionOp)))
      return failure();

    // Perform the rewrite, dispatching on the sub-byte source bitwidth.
    Location loc = conversionOp.getLoc();
    const ExtractNBitsFn extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
                                          : extractNBitsPerByteAndExtendToI8;
    Value subByteExt;
    switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
    case 2:
      subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
      break;
    case 4:
      subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
      break;
    default:
      return failure();
    }

    // Finalize the rewrite.
    rewriter.replaceOpWithNewOp<ConversionOpType>(
        conversionOp, conversionOp.getType(), subByteExt);
    return success();
  }
};
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
  using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = truncOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!srcVecType || !dstVecType)
      return failure();

    // i2 truncation is not supported yet.
    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
      return failure();

    // Check general alignment preconditions; here the destination is the
    // sub-byte vector and i8 acts as the container type.
    if (failed(alignedConversionPrecondition(
            rewriter, dstVecType, rewriter.getI8Type(), truncOp)))
      return failure();

    // Perform the rewrite: truncate to i8 first, then rewrite i8 -> i4.
    Location loc = truncOp.getLoc();
    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
    Value i8TruncVal =
        rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);

    // Finalize the rewrite.
    rewriter.replaceOp(truncOp, subByteTrunc);
    return success();
  }
};
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Precondition: the transpose must be on a sub-byte signless integer.
    constexpr unsigned minNativeBitwidth = 8;
    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
      return rewriter.notifyMatchFailure(transposeOp,
                                         "not a sub-byte transpose");
    }

    // Perform the rewrite: extend to the minimal native bitwidth, transpose,
    // and truncate back to the sub-byte type.
    Location loc = transposeOp.getLoc();
    auto srcNativeVecType = srcSubByteVecType.cloneWith(
        std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
    Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
                                                  transposeOp.getVector());
    Value newTranspose = rewriter.create<vector::TransposeOp>(
        loc, extOp, transposeOp.getPermutation());
    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
    rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
                                                 newTranspose);
    return success();
  }
};
void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns, bool disableAtomicRMW) {

  // Populate the `vector.*` conversion patterns.
  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());

  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
}

void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    benefit);

  // Patterns for the aligned cases get a higher benefit, as they are expected
  // to generate better code.
  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                              benefit.getBenefit() + 1);
  patterns
      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, false>>(
          patterns.getContext(), benefit.getBenefit() + 1);
}

void vector::populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
}