34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/DebugLog.h"
36 #include "llvm/Support/MathExtras.h"
37 #include "llvm/Support/raw_ostream.h"
45 #define DEBUG_TYPE "vector-narrow-type-emulation"
81 int numSrcElemsPerDest,
82 int numFrontPadElems = 0) {
84 assert(numFrontPadElems < numSrcElemsPerDest &&
85 "numFrontPadElems must be less than numSrcElemsPerDest");
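// The number of destination (compressed) mask elements is the front padding
// plus the source element count, divided by numSrcElemsPerDest and rounded up.
// E.g. numFrontPadElems = 1, numSrcElems = 7, numSrcElemsPerDest = 4 gives
// (1 + 7 + 4 - 1) / 4 = 2 destination elements.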
88 (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
96 !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
98 if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
99 maskOp = extractOp.getSource().getDefiningOp();
100 extractOps.push_back(extractOp);
104 if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
112 maskShape.back() = numDestElems;
114 std::optional<Operation *> newMask =
116 .Case<vector::CreateMaskOp>(
117 [&](auto createMaskOp) -> std::optional<Operation *> {
127 s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
130 rewriter, loc, s0, origIndex);
132 newMaskOperands.push_back(
134 return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
137 .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
138 -> std::optional<Operation *> {
141 int64_t &maskIndex = maskDimSizes.back();
144 return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
147 .Case<arith::ConstantOp>([&](auto constantOp)
148 -> std::optional<Operation *> {
150 if (maskShape.size() != 1)
167 cast<DenseIntElementsAttr>(constantOp.getValue());
169 paddedMaskValues.append(originalMask.template value_begin<bool>(),
170 originalMask.template value_end<bool>());
171 paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
175 for (size_t i = 0; i < paddedMaskValues.size();
176 i += numSrcElemsPerDest) {
177 bool combinedValue = false;
178 for (int j = 0; j < numSrcElemsPerDest; ++j) {
179 combinedValue |= paddedMaskValues[i + j];
181 compressedMaskValues.push_back(combinedValue);
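// Each compressed mask element is thus the OR of the numSrcElemsPerDest
// source mask bits it covers, e.g. a source mask [1, 1, 1, 0, 0, 0, 0, 0]
// with numSrcElemsPerDest = 4 compresses to [1, 0].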
183 return arith::ConstantOp::create(
191 while (!extractOps.empty()) {
193 vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0],
194 extractOps.back().getMixedPosition());
195 extractOps.pop_back();
215 Value src, int64_t offset,
216 int64_t numElemsToExtract) {
217 auto vectorType = cast<VectorType>(src.getType());
218 assert(vectorType.getRank() == 1 && "expected source to be a rank-1 vector");
219 assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
220 "subvector out of bounds");
224 if (vectorType.getNumElements() == numElemsToExtract)
231 auto resultVectorType =
233 return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
234 src, offsets, sizes, strides)
249 [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
250 [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
251 assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
252 "expected source and dest to be rank-1 vector types");
255 if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
260 return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
261 dest, offsets, strides);
286 int64_t numElemsToExtract) {
287 auto srcVecTy = cast<VectorType>(src.getType());
288 assert(srcVecTy.getRank() == 1 && "expected source to be a rank-1 vector");
292 assert(numElemsToExtract <= srcVecTy.getNumElements() &&
293 "subvector out of bounds");
297 if (srcVecTy.getNumElements() == numElemsToExtract)
300 for (int i = 0; i < numElemsToExtract; ++i) {
302 (i == 0) ? dyn_cast<Value>(offset)
303 : arith::AddIOp::create(
305 dyn_cast<Value>(offset),
307 auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc);
308 dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i);
330 int64_t numElemsToInsert) {
331 auto srcVecTy = cast<VectorType>(src.getType());
332 auto destVecTy = cast<VectorType>(dest.getType());
333 assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
334 "expected source and dest to be rank-1 vector types");
337 assert(numElemsToInsert > 0 &&
338 "the number of elements to insert must be greater than 0");
342 assert(numElemsToInsert <= destVecTy.getNumElements() &&
343 "subvector out of bounds");
346 for (int64_t i = 0; i < numElemsToInsert; ++i) {
348 i == 0 ? destOffsetVal
349 : arith::AddIOp::create(
352 auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i);
353 dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc);
367 int64_t numContainerElemsToLoad,
369 Type containerElemTy) {
372 auto newLoad = vector::LoadOp::create(
373 rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy),
375 return vector::BitCastOp::create(
385 VectorType downcastType,
386 VectorType upcastType, Value mask,
389 downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
390 upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
391 "expected input and output number of bits to match");
392 if (trueValue.getType() != downcastType) {
394 vector::BitCastOp::create(builder, loc, downcastType, trueValue);
396 if (falseValue.getType() != downcastType) {
398 vector::BitCastOp::create(builder, loc, downcastType, falseValue);
401 arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);
403 return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
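// The select is performed on the downcast (emulated element) type so the mask
// applies per sub-byte element; the merged result is then bitcast back to the
// wider container type.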
422 assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
426 auto atomicOp = memref::GenericAtomicRMWOp::create(
427 builder, loc, linearizedMemref, ValueRange{storeIdx});
428 Value origValue = atomicOp.getCurrentValue();
436 Value origVecValue = vector::FromElementsOp::create(
437 builder, loc, oneElemVecType, ValueRange{origValue});
442 oneElemVecType, mask, valueToStore, origVecValue);
443 auto scalarMaskedValue =
444 vector::ExtractOp::create(builder, loc, maskedValue, 0);
445 memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
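// The read-modify-write happens inside the memref.generic_atomic_rmw region:
// the current byte is wrapped into a one-element vector, merged with the new
// sub-byte data under the mask, and the merged scalar is yielded atomically.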
453 assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
455 auto oneElemVecType =
458 vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
460 origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),
465 oneElemVecType, mask, valueToStore, origVecValue);
466 vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
481 int64_t extractOffset,
482 int64_t sliceNumElements,
483 int64_t insertOffset) {
484 assert(vector.getType().getRank() == 1 && "expected 1-D vector");
485 auto vectorElementType = vector.getType().getElementType();
489 sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
490 "sliceNumElements * vector element size must be less than or equal to 8");
491 assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
492 "vector element must be a valid sub-byte type");
493 auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
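// E.g. i2 elements give 4 emulated elements per byte, i4 elements give 2.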
494 auto emptyByteVector = arith::ConstantOp::create(
500 extractOffset, sliceNumElements);
553 ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
555 disableAtomicRMW(disableAtomicRMW) {}
558 matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
561 if (op.getValueToStore().getType().getRank() != 1)
563 "only 1-D vectors are supported ATM");
565 auto loc = op.getLoc();
567 auto valueToStore = cast<VectorValue>(op.getValueToStore());
568 auto containerElemTy =
569 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
570 Type emulatedElemTy = op.getValueToStore().getType().getElementType();
572 int containerBits = containerElemTy.getIntOrFloatBitWidth();
575 if (containerBits % emulatedBits != 0) {
577 op, "impossible to pack emulated elements into container elements "
578 "(bit-wise misalignment)");
580 int emulatedPerContainerElem = containerBits / emulatedBits;
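// E.g. emulating i4 values inside i8 containers gives
// emulatedPerContainerElem == 2; i2 inside i8 gives 4.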
595 auto origElements = valueToStore.getType().getNumElements();
597 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
602 auto trailingDim = op.getBase().getType().getShape().back();
603 bool trailingDimsMatch =
604 ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
606 auto stridedMetadata =
607 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
613 std::tie(linearizedInfo, linearizedIndices) =
615 rewriter, loc, emulatedBits, containerBits,
616 stridedMetadata.getConstifiedMixedOffset(),
617 stridedMetadata.getConstifiedMixedSizes(),
618 stridedMetadata.getConstifiedMixedStrides(),
621 std::optional<int64_t> foldedNumFrontPadElems =
622 (isDivisibleInSize && trailingDimsMatch)
626 if (!foldedNumFrontPadElems) {
628 op, "subbyte store emulation: dynamic front padding size is "
629 "not yet implemented");
632 auto memrefBase = cast<MemRefValue>(adaptor.getBase());
664 bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
666 if (!emulationRequiresPartialStores) {
668 auto numElements = origElements / emulatedPerContainerElem;
669 auto bitCast = vector::BitCastOp::create(
671 op.getValueToStore());
673 op, bitCast.getResult(), memrefBase,
709 Value currentDestIndex =
712 auto currentSourceIndex = 0;
715 auto subWidthStoreMaskType =
724 auto frontSubWidthStoreElem =
725 (emulatedPerContainerElem - *foldedNumFrontPadElems) %
726 emulatedPerContainerElem;
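// frontSubWidthStoreElem is the number of emulated elements that land in the
// first, partially filled container byte and therefore need a masked
// read-modify-write. E.g. for i4-in-i8 (2 elements per byte) with one front
// pad element, (2 - 1) % 2 == 1 element is written via the sub-width path.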
727 if (frontSubWidthStoreElem > 0) {
729 if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
730 std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
732 frontSubWidthStoreElem = origElements;
734 std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
735 *foldedNumFrontPadElems, true);
737 auto frontMask = arith::ConstantOp::create(
741 currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
744 frontSubWidthStoreElem, *foldedNumFrontPadElems);
746 storeFunc(rewriter, loc, memrefBase, currentDestIndex,
747 cast<VectorValue>(value), frontMask.getResult());
750 if (currentSourceIndex >= origElements) {
758 currentDestIndex = arith::AddIOp::create(
764 int64_t fullWidthStoreSize =
765 (origElements - currentSourceIndex) / emulatedPerContainerElem;
766 int64_t numNonFullWidthElements =
767 fullWidthStoreSize * emulatedPerContainerElem;
768 if (fullWidthStoreSize > 0) {
770 rewriter, loc, valueToStore, currentSourceIndex,
771 numNonFullWidthElements);
773 auto originType = cast<VectorType>(fullWidthStorePart.getType());
776 {originType.getNumElements() / emulatedPerContainerElem},
778 auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
780 vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,
783 currentSourceIndex += numNonFullWidthElements;
784 currentDestIndex = arith::AddIOp::create(
785 rewriter, loc, rewriter.getIndexType(), currentDestIndex,
792 auto remainingElements = origElements - currentSourceIndex;
793 if (remainingElements != 0) {
794 auto subWidthStorePart =
796 currentSourceIndex, remainingElements, 0);
800 std::fill_n(maskValues.begin(), remainingElements, 1);
801 auto backMask = arith::ConstantOp::create(
805 storeFunc(rewriter, loc, memrefBase, currentDestIndex,
806 cast<VectorValue>(subWidthStorePart), backMask.getResult());
814 const bool disableAtomicRMW;
828 struct ConvertVectorMaskedStore final
833 matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
837 if (op.getValueToStore().getType().getRank() != 1)
839 op, "Memref in vector.maskedstore op must be flattened beforehand.");
841 auto loc = op.getLoc();
842 auto containerElemTy =
843 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
844 Type emulatedElemTy = op.getValueToStore().getType().getElementType();
846 int containerBits = containerElemTy.getIntOrFloatBitWidth();
849 if (containerBits % emulatedBits != 0) {
851 op, "impossible to pack emulated elements into container elements "
852 "(bit-wise misalignment)");
855 int emulatedPerContainerElem = containerBits / emulatedBits;
856 int origElements = op.getValueToStore().getType().getNumElements();
857 if (origElements % emulatedPerContainerElem != 0)
860 auto stridedMetadata =
861 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
864 std::tie(linearizedInfo, linearizedIndicesOfr) =
866 rewriter, loc, emulatedBits, containerBits,
867 stridedMetadata.getConstifiedMixedOffset(),
868 stridedMetadata.getConstifiedMixedSizes(),
869 stridedMetadata.getConstifiedMixedStrides(),
871 Value linearizedIndices =
907 rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
911 auto numElements = (origElements + emulatedPerContainerElem - 1) /
912 emulatedPerContainerElem;
914 auto passThru = arith::ConstantOp::create(rewriter, loc, newType,
917 auto newLoad = vector::MaskedLoadOp::create(
918 rewriter, loc, newType, adaptor.getBase(), linearizedIndices,
919 newMask.value()->getResult(0), passThru);
921 auto newBitCastType =
922 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
924 vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad);
925 valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(),
926 op.getValueToStore(), valueToStore);
928 vector::BitCastOp::create(rewriter, loc, newType, valueToStore);
931 op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
956 matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
959 if (op.getVectorType().getRank() != 1)
961 op, "Memref in emulated vector ops must be flattened beforehand.");
963 auto loc = op.getLoc();
964 auto containerElemTy =
965 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
966 Type emulatedElemTy = op.getType().getElementType();
968 int containerBits = containerElemTy.getIntOrFloatBitWidth();
971 if (containerBits % emulatedBits != 0) {
973 op, "impossible to pack emulated elements into container elements "
974 "(bit-wise misalignment)");
976 int emulatedPerContainerElem = containerBits / emulatedBits;
1005 auto origElements = op.getVectorType().getNumElements();
1007 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1009 auto stridedMetadata =
1010 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1014 std::tie(linearizedInfo, linearizedIndices) =
1016 rewriter, loc, emulatedBits, containerBits,
1017 stridedMetadata.getConstifiedMixedOffset(),
1018 stridedMetadata.getConstifiedMixedSizes(),
1019 stridedMetadata.getConstifiedMixedStrides(),
1022 std::optional<int64_t> foldedIntraVectorOffset =
1023 isDivisibleInSize ? 0
1027 int64_t maxintraDataOffset =
1028 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1030 emulatedPerContainerElem);
1033 numElements, emulatedElemTy, containerElemTy);
1035 if (!foldedIntraVectorOffset) {
1036 auto resultVector = arith::ConstantOp::create(
1037 rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1041 } else if (!isDivisibleInSize) {
1043 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1060 struct ConvertVectorMaskedLoad final
1065 matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
1067 if (op.getVectorType().getRank() != 1)
1069 op, "Memref in emulated vector ops must be flattened beforehand.");
1071 auto loc = op.getLoc();
1073 auto containerElemTy =
1074 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1075 Type emulatedElemTy = op.getType().getElementType();
1077 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1080 if (containerBits % emulatedBits != 0) {
1082 op, "impossible to pack emulated elements into container elements "
1083 "(bit-wise misalignment)");
1085 int emulatedPerContainerElem = containerBits / emulatedBits;
1129 auto origType = op.getVectorType();
1130 auto origElements = origType.getNumElements();
1132 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1134 auto stridedMetadata =
1135 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1138 std::tie(linearizedInfo, linearizedIndices) =
1140 rewriter, loc, emulatedBits, containerBits,
1141 stridedMetadata.getConstifiedMixedOffset(),
1142 stridedMetadata.getConstifiedMixedSizes(),
1143 stridedMetadata.getConstifiedMixedStrides(),
1146 std::optional<int64_t> foldedIntraVectorOffset =
1147 isDivisibleInSize ? 0
1150 int64_t maxIntraDataOffset =
1151 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1152 FailureOr<Operation *> newMask =
1154 emulatedPerContainerElem, maxIntraDataOffset);
1158 Value passthru = op.getPassThru();
1161 emulatedPerContainerElem);
1163 auto newBitcastType =
1164 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
1166 auto emptyVector = arith::ConstantOp::create(
1167 rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
1168 if (!foldedIntraVectorOffset) {
1172 } else if (!isDivisibleInSize) {
1174 *foldedIntraVectorOffset);
1177 vector::BitCastOp::create(rewriter, loc, loadType, passthru);
1180 auto newLoad = vector::MaskedLoadOp::create(
1181 rewriter, loc, loadType, adaptor.getBase(),
1183 newMask.value()->getResult(0), newPassThru);
1188 vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad);
1190 Value mask = op.getMask();
1192 numElements * emulatedPerContainerElem, rewriter.getI1Type());
1195 arith::ConstantOp::create(rewriter, loc, newSelectMaskType,
1197 if (!foldedIntraVectorOffset) {
1201 } else if (!isDivisibleInSize) {
1203 *foldedIntraVectorOffset);
1207 arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru);
1208 if (!foldedIntraVectorOffset) {
1210 rewriter, loc, result, op.getPassThru(),
1212 } else if (!isDivisibleInSize) {
1214 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1237 static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
1238 Type multiByteScalarTy) {
1239 assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
1241 int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1244 assert(subByteBits < 8 && "Not a sub-byte scalar type!");
1245 assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1246 assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
1248 int elemsPerMultiByte = multiByteBits / subByteBits;
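// E.g. a vector<8xi4> fits an i32 container: elemsPerMultiByte = 32 / 4 = 8
// and the trailing dimension 8 is a multiple of 8.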
1250 return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1258 struct ConvertVectorTransferRead final
1263 matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
1268 if (op.getVectorType().getRank() != 1)
1270 op, "Memref in emulated vector ops must be flattened beforehand.");
1272 auto loc = op.getLoc();
1273 auto containerElemTy =
1274 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1275 Type emulatedElemTy = op.getType().getElementType();
1277 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1280 if (containerBits % emulatedBits != 0) {
1282 op, "impossible to pack emulated elements into container elements "
1283 "(bit-wise misalignment)");
1285 int emulatedPerContainerElem = containerBits / emulatedBits;
1287 auto origElements = op.getVectorType().getNumElements();
1290 bool isDivisibleInSize =
1291 fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
1295 Value padding = adaptor.getPadding();
1297 padding = arith::BitcastOp::create(
1304 arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding);
1306 auto stridedMetadata =
1307 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1311 std::tie(linearizedInfo, linearizedIndices) =
1313 rewriter, loc, emulatedBits, containerBits,
1314 stridedMetadata.getConstifiedMixedOffset(),
1315 stridedMetadata.getConstifiedMixedSizes(),
1316 stridedMetadata.getConstifiedMixedStrides(),
1319 std::optional<int64_t> foldedIntraVectorOffset =
1320 isDivisibleInSize ? 0
1323 int64_t maxIntraDataOffset =
1324 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1326 emulatedPerContainerElem);
1328 auto newRead = vector::TransferReadOp::create(
1334 auto bitCast = vector::BitCastOp::create(
1336 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
1339 Value result = bitCast->getResult(0);
1340 if (!foldedIntraVectorOffset) {
1341 auto zeros = arith::ConstantOp::create(
1342 rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1346 } else if (!isDivisibleInSize) {
1348 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1365 struct SourceElementRange {
1367 int64_t sourceElementIdx;
1369 int64_t sourceBitBegin;
1370 int64_t sourceBitEnd;
1373 struct SourceElementRangeList : public SmallVector<SourceElementRange> {
1379 int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
1381 for (int64_t i = 0; i < shuffleIdx; ++i)
1382 res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
1401 struct BitCastBitsEnumerator {
1402 BitCastBitsEnumerator(VectorType sourceVectorType,
1403 VectorType targetVectorType);
1405 int64_t getMaxNumberOfEntries() {
1406 int64_t numVectors = 0;
1407 for (const auto &l : sourceElementRanges)
1408 numVectors = std::max(numVectors, (int64_t)l.size());
1412 VectorType sourceVectorType;
1413 VectorType targetVectorType;
1488 struct BitCastRewriter {
1495 BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
1499 VectorType preconditionType, Operation *op);
1503 precomputeMetadata(IntegerType shuffledElementType);
1509 const BitCastRewriter::Metadata &metadata);
1514 BitCastBitsEnumerator enumerator;
1519 [[maybe_unused]] static raw_ostream &
1521 for (const auto &l : vec) {
1523 os << "{ " << it.value().sourceElementIdx << ": b@["
1524 << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
1525 << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
1532 BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
1533 VectorType targetVectorType)
1534 : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
1536 assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
1537 "requires 1-D non-scalable vector type");
1538 assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
1539 "requires 1-D non-scalable vector type");
1540 int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
1541 int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
1542 LDBG() << "sourceVectorType: " << sourceVectorType;
1544 int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
1545 int64_t mostMinorTargetDim = targetVectorType.getShape().back();
1546 LDBG() << "targetVectorType: " << targetVectorType;
1548 int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
1549 (void)mostMinorSourceDim;
1550 assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
1551 "source and target bitwidths must match");
1555 for (int64_t resultBit = 0; resultBit < bitwidth;) {
1556 int64_t resultElement = resultBit / targetBitWidth;
1557 int64_t resultBitInElement = resultBit % targetBitWidth;
1558 int64_t sourceElementIdx = resultBit / sourceBitWidth;
1559 int64_t sourceBitInElement = resultBit % sourceBitWidth;
1560 int64_t step = std::min(sourceBitWidth - sourceBitInElement,
1561 targetBitWidth - resultBitInElement);
1562 sourceElementRanges[resultElement].push_back(
1563 {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
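// E.g. for a vector<4xi8> -> vector<2xi16> bitcast, result element 0 is
// assembled from source element 0 (bits [0, 8), no shift) and source
// element 1 (bits [0, 8), shifted left by 8).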
1568 BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
1569 VectorType targetVectorType)
1570 : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
1571 LDBG() << "\n" << enumerator.sourceElementRanges;
1577 VectorType preconditionType,
1579 if (!preconditionType || preconditionType.isScalable())
1584 unsigned bitwidth = preconditionType.getElementTypeBitWidth();
1585 if (bitwidth % 8 != 0)
1591 LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1592 VectorType preconditionType,
1594 if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
1597 if (!preconditionType || preconditionType.getRank() != 1)
1635 VectorType subByteVecTy,
1639 "container element type is not a scalar");
1646 unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1650 assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");
1653 if (subByteBits != 2 && subByteBits != 4)
1655 op, "only 2-bit and 4-bit sub-byte types are supported at the moment");
1658 if (containerBits % subByteBits != 0)
1662 if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
1664 op, "not possible to fit this sub-byte vector type into a vector of "
1665 "the given multi-byte type");
1671 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
1673 for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1674 shuffleIdx < e; ++shuffleIdx) {
1679 for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
1680 int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1681 ? srcEltRangeList[shuffleIdx].sourceElementIdx
1683 shuffles.push_back(sourceElement);
1685 int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1686 ? srcEltRangeList[shuffleIdx].sourceBitBegin
1688 int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1689 ? srcEltRangeList[shuffleIdx].sourceBitEnd
1692 shuffledElementType,
1693 llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1695 masks.push_back(mask);
1697 int64_t shiftRight = bitLo;
1698 shiftRightAmounts.push_back(
1701 int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1702 shiftLeftAmounts.push_back(
1706 result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1711 Value BitCastRewriter::genericRewriteStep(
1713 Value runningResult, const BitCastRewriter::Metadata &metadata) {
1715 auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue,
1716 initialValue, metadata.shuffles);
1719 VectorType shuffledVectorType = shuffleOp.getResultVectorType();
1720 auto constOp = arith::ConstantOp::create(
1723 Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp);
1726 auto shiftRightConstantOp = arith::ConstantOp::create(
1729 Value shiftedRight =
1730 arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp);
1733 auto shiftLeftConstantOp = arith::ConstantOp::create(
1737 arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp);
1741 ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft)
1744 return runningResult;
1755 auto srcVecType = cast<VectorType>(subByteVec.getType());
1756 int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1757 assert(8 % srcBitwidth == 0 &&
1758 "Unsupported sub-byte type (not a divisor of i8)");
1759 int64_t numSrcElemsPerByte = 8 / srcBitwidth;
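// E.g. a vector<16xi2> is bitcast to vector<4xi8>: 4 i2 elements per byte.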
1762 vecShape.back() = vecShape.back() / numSrcElemsPerByte;
1764 return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec);
1785 int bitIdx, int numBits) {
1786 auto srcType = cast<VectorType>(src.getType());
1788 int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1789 assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1790 "Invalid bitIdx range");
1791 if (bitsToShiftLeft != 0) {
1792 Value shiftLeftValues = arith::ConstantOp::create(
1794 shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues);
1797 int8_t bitsToShiftRight = 8 - numBits;
1798 Value shiftRightValues = arith::ConstantOp::create(
1800 Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues);
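// Sign extension: shift the selected bits to the top of the byte, then
// arithmetic-shift-right by (8 - numBits). E.g. for the low i4 nibble
// (bitIdx = 0, numBits = 4): shl 4 followed by ashr 4; for the high nibble
// (bitIdx = 4) only the ashr 4 is needed.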
1827 int bitIdx, int numBits) {
1828 assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1829 "Invalid bitIdx range");
1830 auto srcType = cast<VectorType>(src.getType());
1831 int8_t bitsToShiftRight = bitIdx;
1833 if (bitsToShiftRight != 0) {
1834 Value shiftRightValues = arith::ConstantOp::create(
1836 shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues);
1838 if (bitIdx + numBits == 8) {
1841 uint8_t lowBitsMask = (1 << numBits) - 1;
1842 Value lowBitsMaskValues = arith::ConstantOp::create(
1844 return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues);
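// Zero extension: logical-shift-right by bitIdx, then mask with
// (1 << numBits) - 1. The AND is skipped when bitIdx + numBits == 8 because
// the logical shift already cleared the upper bits.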
1854 [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
1855 assert(srcVecType.getElementType().isSignlessInteger(4) &&
1856 "Expected i4 type");
1863 Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1864 Value high = extFn(rewriter, loc, i8Vector, 4, 4);
1867 return vector::InterleaveOp::create(rewriter, loc, low, high);
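// E.g. extending vector<8xi4>: the packed bytes are viewed as vector<4xi8>,
// the low and high nibbles are extracted and extended separately via extFn,
// and vector.interleave restores the original element order.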
1874 [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
1875 assert(srcVecType.getElementType().isSignlessInteger(2) &&
1876 "Expected i2 type");
1883 Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
1885 Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
1887 Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
1889 Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
1900 Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2);
1901 Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3);
1902 return vector::InterleaveOp::create(rewriter, loc, interleave02,
1910 VectorType srcVecType = cast<VectorType>(srcValue.getType());
1911 assert(srcVecType.getElementType().isSignlessInteger(8) &&
1912 "Expected i8 type");
1915 auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue);
1918 constexpr int8_t i8LowBitMask = 0x0F;
1919 VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
1920 Value zeroOutMask = arith::ConstantOp::create(
1922 Value zeroOutLow = arith::AndIOp::create(
1923 rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask);
1926 constexpr int8_t bitsToShift = 4;
1927 auto shiftValues = arith::ConstantOp::create(
1929 Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(),
1933 auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh);
1936 auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
1937 return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp);
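// Each merged byte packs two adjacent source elements: the even-indexed
// element in the low nibble and the odd-indexed element (shifted left by 4)
// in the high nibble, so the final bitcast to i4 preserves element order.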
1951 bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
1956 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1957 VectorType targetVectorType = bitCastOp.getResultVectorType();
1958 BitCastRewriter bcr(sourceVectorType, targetVectorType);
1959 if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
1963 Value truncValue = truncOp.getIn();
1964 auto shuffledElementType =
1966 Value runningResult;
1967 for (const BitCastRewriter::Metadata &metadata :
1968 bcr.precomputeMetadata(shuffledElementType)) {
1969 runningResult = bcr.genericRewriteStep(
1970 rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
1974 bool narrowing = targetVectorType.getElementTypeBitWidth() <=
1975 shuffledElementType.getIntOrFloatBitWidth();
1977 if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1978 rewriter.replaceOp(bitCastOp, runningResult);
1981 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1984 if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1985 rewriter.replaceOp(bitCastOp, runningResult);
1988 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
2005 template <typename ExtOpType>
2015 auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
2020 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
2021 VectorType targetVectorType = bitCastOp.getResultVectorType();
2022 BitCastRewriter bcr(sourceVectorType, targetVectorType);
2023 if (failed(bcr.commonPrecondition(
2024 rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
2028 Value runningResult;
2029 Value sourceValue = bitCastOp.getSource();
2030 auto shuffledElementType =
2032 for (const BitCastRewriter::Metadata &metadata :
2033 bcr.precomputeMetadata(shuffledElementType)) {
2034 runningResult = bcr.genericRewriteStep(
2035 rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
2040 cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
2041 shuffledElementType.getIntOrFloatBitWidth();
2044 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2047 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2088 template <typename ConversionOpType, bool isSigned>
2095 Value srcValue = conversionOp.getIn();
2096 VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
2097 VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
2105 rewriter, srcVecType,
2110 Location loc = conversionOp.getLoc();
2114 switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
2127 conversionOp, conversionOp.getType(), subByteExt);
2155 Value srcValue = truncOp.getIn();
2156 auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
2157 auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
2158 if (!srcVecType || !dstVecType)
2165 if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
2171 rewriter, dstVecType,
2177 auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
2179 arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue);
2185 rewriter.replaceOp(truncOp, subByteTrunc);
2211 constexpr unsigned minNativeBitwidth = 8;
2212 VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
2213 if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
2214 srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
2216 "not a sub-byte transpose");
2220 Location loc = transposeOp.getLoc();
2225 auto srcNativeVecType = srcSubByteVecType.cloneWith(
2227 Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType,
2228 transposeOp.getVector());
2229 Value newTranspose = vector::TransposeOp::create(
2230 rewriter, loc, extOp, transposeOp.getPermutation());
2231 VectorType dstSubByteVecType = transposeOp.getResultVectorType();
2245 void vector::populateVectorNarrowTypeEmulationPatterns(
2246 const arith::NarrowTypeEmulationConverter &typeConverter,
2250 patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
2251 ConvertVectorMaskedStore, ConvertVectorTransferRead>(
2252 typeConverter, patterns.getContext());
2257 patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
2260 void vector::populateVectorNarrowTypeRewritePatterns(
2263 patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
2264 RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
2270 patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, true>,
2271 RewriteAlignedSubByteIntExt<arith::SIToFPOp, true>,
2272 RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
2276 .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, false>,
2277 RewriteAlignedSubByteIntExt<arith::UIToFPOp, false>>(
2282 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
2287 void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
2288 arith::NarrowTypeEmulationConverter &typeConverter,
2291 vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);