#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-narrow-type-emulation"
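
// Illustrative sketch (assumed i4 emulated in an i8 container; not part of the
// upstream file): with 8 / 4 = 2 emulated elements per container element, a
// load such as
//   %v = vector.load %mem[%r, %c0] : memref<3x8xi4>, vector<8xi4>
// is emulated on a linearized i8 memref as, roughly,
//   %b = vector.load %mem_i8[%idx] : memref<12xi8>, vector<4xi8>
//   %v = vector.bitcast %b : vector<4xi8> to vector<8xi4>
// Stores that do not cover whole container elements additionally need a
// read-modify-write of the affected byte(s).
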
                                                  int numSrcElemsPerDest,
                                                  int numFrontPadElems = 0) {

  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /

         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getVector().getDefiningOp();
      extractOps.push_back(extractOp);

  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(

  maskShape.back() = numDestElems;

  std::optional<Operation *> newMask =
          .Case<vector::CreateMaskOp>(
              [&](auto createMaskOp) -> std::optional<Operation *> {
                  s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                      rewriter, loc, s0, origIndex);
                  newMaskOperands.push_back(
                return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
                                            -> std::optional<Operation *> {
            int64_t &maskIndex = maskDimSizes.back();
            return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
          .Case<arith::ConstantOp>([&](auto constantOp)
                                       -> std::optional<Operation *> {
            if (maskShape.size() != 1)
                cast<DenseIntElementsAttr>(constantOp.getValue());
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);

            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j) {
                combinedValue |= paddedMaskValues[i + j];
              compressedMaskValues.push_back(combinedValue);
            return arith::ConstantOp::create(

  while (!extractOps.empty()) {
        vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0],
                                  extractOps.back().getMixedPosition());
    extractOps.pop_back();
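
// Illustrative example (assumed numbers): with numSrcElemsPerDest = 2 (i4
// packed into i8) and no front padding, `vector.create_mask %c5 : vector<8xi1>`
// compresses to `vector.create_mask %c3 : vector<4xi1>`, since
// ceilDiv(5, 2) = 3; a container element is active if any of the emulated
// elements it covers is active.
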
                                         Value src, int64_t offset,
                                         int64_t numElemsToExtract) {
  auto vectorType = cast<VectorType>(src.getType());
  assert(vectorType.getRank() == 1 && "expected source to be a rank-1 vector");
  assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
         "subvector out of bounds");

  if (vectorType.getNumElements() == numElemsToExtract)

  auto resultVectorType =
  return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
                                               src, offsets, sizes, strides)
  [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");

  if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)

  return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
                                              dest, offsets, strides);
                                          int64_t numElemsToExtract) {
  auto srcVecTy = cast<VectorType>(src.getType());
  assert(srcVecTy.getRank() == 1 && "expected source to be a rank-1 vector");
  assert(numElemsToExtract <= srcVecTy.getNumElements() &&
         "subvector out of bounds");

  if (srcVecTy.getNumElements() == numElemsToExtract)

  for (int i = 0; i < numElemsToExtract; ++i) {
        (i == 0) ? dyn_cast<Value>(offset)
                 : arith::AddIOp::create(
                       dyn_cast<Value>(offset),
    auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc);
    dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i);
                                         int64_t numElemsToInsert) {
  auto srcVecTy = cast<VectorType>(src.getType());
  auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");
  assert(numElemsToInsert > 0 &&
         "the number of elements to insert must be greater than 0");
  assert(numElemsToInsert <= destVecTy.getNumElements() &&
         "subvector out of bounds");

  for (int64_t i = 0; i < numElemsToInsert; ++i) {
        i == 0 ? destOffsetVal
               : arith::AddIOp::create(
    auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i);
    dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc);
                                      int64_t numContainerElemsToLoad,
                                      Type containerElemTy) {
  auto newLoad = vector::LoadOp::create(
      rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy),
  return vector::BitCastOp::create(
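
// Illustrative sketch (assumed types): loading 8 x i4 through an i8 container,
// emulatedVectorLoad emits roughly
//   %c = vector.load %base[%idx] : memref<12xi8>, vector<4xi8>
//   %v = vector.bitcast %c : vector<4xi8> to vector<8xi4>
// i.e. numContainerElemsToLoad = 4 container elements are loaded and then
// bitcast back to the emulated element type.
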
                                     VectorType downcastType,
                                     VectorType upcastType, Value mask,
      downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
          upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
      "expected input and output number of bits to match");
  if (trueValue.getType() != downcastType) {
        vector::BitCastOp::create(builder, loc, downcastType, trueValue);
  if (falseValue.getType() != downcastType) {
        vector::BitCastOp::create(builder, loc, downcastType, falseValue);
      arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);

  return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  auto atomicOp = memref::GenericAtomicRMWOp::create(
      builder, loc, linearizedMemref, ValueRange{storeIdx});
  Value origValue = atomicOp.getCurrentValue();

  Value origVecValue = vector::FromElementsOp::create(
      builder, loc, oneElemVecType, ValueRange{origValue});

      oneElemVecType, mask, valueToStore, origVecValue);
  auto scalarMaskedValue =
      vector::ExtractOp::create(builder, loc, maskedValue, 0);
  memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
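
// Illustrative sketch (assumed i4-in-i8 emulation): a partial store into a
// single byte is expressed as an atomic read-modify-write so the two i4
// halves of the same byte can be written concurrently without racing:
//   %old = memref.generic_atomic_rmw %mem[%idx] : memref<12xi8> {
//   ^bb0(%cur: i8):
//     // bitcast %cur to vector<2xi4>, arith.select with the sub-byte mask,
//     // bitcast the merged value back to i8 ...
//     memref.atomic_yield %merged : i8
//   }
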
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  auto oneElemVecType =
      vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
  origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),
      oneElemVecType, mask, valueToStore, origVecValue);
  vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
                                   int64_t extractOffset,
                                   int64_t sliceNumElements,
                                   int64_t insertOffset) {
  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
  auto vectorElementType = vector.getType().getElementType();
      sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
      "sliceNumElements * vector element size must be less than or equal to 8");
  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
         "vector element must be a valid sub-byte type");
  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
  auto emptyByteVector = arith::ConstantOp::create(
                                    extractOffset, sliceNumElements);
  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
        disableAtomicRMW(disableAtomicRMW) {}

  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
    if (op.getValueToStore().getType().getRank() != 1)
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();

    auto valueToStore = cast<VectorValue>(op.getValueToStore());
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = valueToStore.getType().getNumElements();
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto trailingDim = op.getBase().getType().getShape().back();
    bool trailingDimsMatch =
        ShapedType::isDynamic(trailingDim) || trailingDim == origElements;

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());

    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedNumFrontPadElems =
        (isDivisibleInSize && trailingDimsMatch)

    if (!foldedNumFrontPadElems) {
          op, "subbyte store emulation: dynamic front padding size is "
              "not yet implemented");

    auto memrefBase = cast<MemRefValue>(adaptor.getBase());

    bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;

    if (!emulationRequiresPartialStores) {
      auto numElements = origElements / emulatedPerContainerElem;
      auto bitCast = vector::BitCastOp::create(
          op.getValueToStore());
          op, bitCast.getResult(), memrefBase,

    Value currentDestIndex =
    auto currentSourceIndex = 0;

    auto subWidthStoreMaskType =

    auto frontSubWidthStoreElem =
        (emulatedPerContainerElem - *foldedNumFrontPadElems) %
        emulatedPerContainerElem;
    if (frontSubWidthStoreElem > 0) {
      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
        frontSubWidthStoreElem = origElements;
        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
                    *foldedNumFrontPadElems, true);
      auto frontMask = arith::ConstantOp::create(

      currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
                               frontSubWidthStoreElem, *foldedNumFrontPadElems);
      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(value), frontMask.getResult());

    if (currentSourceIndex >= origElements) {

    currentDestIndex = arith::AddIOp::create(

    int64_t fullWidthStoreSize =
        (origElements - currentSourceIndex) / emulatedPerContainerElem;
    int64_t numNonFullWidthElements =
        fullWidthStoreSize * emulatedPerContainerElem;
    if (fullWidthStoreSize > 0) {
          rewriter, loc, valueToStore, currentSourceIndex,
          numNonFullWidthElements);

      auto originType = cast<VectorType>(fullWidthStorePart.getType());
          {originType.getNumElements() / emulatedPerContainerElem},
      auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
      vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,

      currentSourceIndex += numNonFullWidthElements;
      currentDestIndex = arith::AddIOp::create(
          rewriter, loc, rewriter.getIndexType(), currentDestIndex,

    auto remainingElements = origElements - currentSourceIndex;
    if (remainingElements != 0) {
      auto subWidthStorePart =
          currentSourceIndex, remainingElements, 0);

      std::fill_n(maskValues.begin(), remainingElements, 1);
      auto backMask = arith::ConstantOp::create(

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(subWidthStorePart), backMask.getResult());

  const bool disableAtomicRMW;
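
// Illustrative sketch (assumed i4-in-i8 emulation, fully aligned case): with
// no front padding ConvertVectorStore lowers
//   vector.store %v, %mem[%c0] : memref<8xi4>, vector<8xi4>
// to, roughly,
//   %b = vector.bitcast %v : vector<8xi4> to vector<4xi8>
//   vector.store %b, %mem_i8[%c0] : memref<4xi8>, vector<4xi8>
// Unaligned leading/trailing elements instead go through `storeFunc`, i.e.
// the atomic or non-atomic read-modify-write helpers above.
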
struct ConvertVectorMaskedStore final
  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
    if (op.getValueToStore().getType().getRank() != 1)
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");

    int emulatedPerContainerElem = containerBits / emulatedBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % emulatedPerContainerElem != 0)

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
    std::tie(linearizedInfo, linearizedIndicesOfr) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
    Value linearizedIndices =

        rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);

    auto numElements = (origElements + emulatedPerContainerElem - 1) /
                       emulatedPerContainerElem;
    auto passThru = arith::ConstantOp::create(rewriter, loc, newType,

    auto newLoad = vector::MaskedLoadOp::create(
        rewriter, loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
        vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad);
    valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(),
                                           op.getValueToStore(), valueToStore);
        vector::BitCastOp::create(rewriter, loc, newType, valueToStore);

        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
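
// Illustrative sketch (assumed i4-in-i8 emulation): a sub-byte mask cannot
// address individual i4 lanes of an i8 word, so the masked store is emulated
// as a load-select-store sequence, roughly:
//   %old  = vector.maskedload %mem_i8[%idx], %cmask, %pass : ... into vector<4xi8>
//   %oldv = vector.bitcast %old : vector<4xi8> to vector<8xi4>
//   %sel  = arith.select %mask, %value, %oldv : vector<8xi1>, vector<8xi4>
//   %new  = vector.bitcast %sel : vector<8xi4> to vector<4xi8>
//   vector.maskedstore %mem_i8[%idx], %cmask, %new : ...
// where %cmask is the container-level mask produced by getCompressedMaskOp.
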
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
    if (op.getVectorType().getRank() != 1)
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());

    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0

    int64_t maxintraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
                                    emulatedPerContainerElem);
        numElements, emulatedElemTy, containerElemTy);

    if (!foldedIntraVectorOffset) {
      auto resultVector = arith::ConstantOp::create(
          rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
    } else if (!isDivisibleInSize) {
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
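
// Illustrative sketch (assumed i4-in-i8 emulation): an aligned
// `vector.load ... : vector<8xi4>` becomes a vector<4xi8> load plus a bitcast
// (see emulatedVectorLoad above). When the load starts in the middle of a
// container element, one extra container element is loaded and the requested
// elements are extracted afterwards: e.g. 3 x i4 starting at intra-byte
// offset 1 loads 2 x i8 (= 4 x i4) and then extracts elements [1, 4) of the
// bitcast result.
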
struct ConvertVectorMaskedLoad final
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
    if (op.getVectorType().getRank() != 1)
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();

    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();

    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    FailureOr<Operation *> newMask =
                            emulatedPerContainerElem, maxIntraDataOffset);

    Value passthru = op.getPassThru();
                                emulatedPerContainerElem);

    auto newBitcastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);

    auto emptyVector = arith::ConstantOp::create(
        rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
    } else if (!isDivisibleInSize) {
                                           *foldedIntraVectorOffset);
        vector::BitCastOp::create(rewriter, loc, loadType, passthru);

    auto newLoad = vector::MaskedLoadOp::create(
        rewriter, loc, loadType, adaptor.getBase(),
        newMask.value()->getResult(0), newPassThru);

        vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad);

    Value mask = op.getMask();
        numElements * emulatedPerContainerElem, rewriter.getI1Type());

        arith::ConstantOp::create(rewriter, loc, newSelectMaskType,
    if (!foldedIntraVectorOffset) {
    } else if (!isDivisibleInSize) {
                                          *foldedIntraVectorOffset);

        arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
          rewriter, loc, result, op.getPassThru(),
    } else if (!isDivisibleInSize) {
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
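
// Illustrative sketch (assumed i4-in-i8 emulation): for
//   %r = vector.maskedload %mem[%idx], %mask, %pass : ... into vector<8xi4>
// whole i8 container elements are loaded with the compressed mask and bitcast
// back to i4; the original i4-granularity %mask is then re-applied with
// arith.select so that masked-off lanes still observe the pass-through values
// rather than whatever data happened to share the same byte.
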
static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
                                       Type multiByteScalarTy) {
  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");

  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();

  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");

  int elemsPerMultiByte = multiByteBits / subByteBits;

  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
}
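
// Illustrative example (assumed values): with subByteVecTy = vector<8xi4> and
// multiByteScalarTy = i32, elemsPerMultiByte = 32 / 4 = 8 and 8 % 8 == 0, so
// the vector fits; vector<6xi4> with an i32 container would not (6 % 8 != 0).
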
struct ConvertVectorTransferRead final
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
    if (op.getVectorType().getRank() != 1)
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    if (containerBits % emulatedBits != 0) {
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();

    bool isDivisibleInSize =
        fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);

    Value padding = adaptor.getPadding();
      padding = arith::BitcastOp::create(
        arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding);

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());

    std::tie(linearizedInfo, linearizedIndices) =
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
                                    emulatedPerContainerElem);

    auto newRead = vector::TransferReadOp::create(

    auto bitCast = vector::BitCastOp::create(
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = arith::ConstantOp::create(
          rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
    } else if (!isDivisibleInSize) {
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
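
// Illustrative sketch (assumed i4-in-i8 emulation): a 1-D transfer_read of
// vector<8xi4> is rewritten into a transfer_read of vector<4xi8> (with the
// padding value extended to the i8 container type) followed by a
// vector.bitcast back to vector<8xi4>; when the start is not
// container-aligned, a slightly larger vector is read and the requested
// elements are extracted from the bitcast result.
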
struct SourceElementRange {
  int64_t sourceElementIdx;
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;

struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());

  VectorType sourceVectorType;
  VectorType targetVectorType;
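
// Illustrative example (assumed types): for vector<4xi8> -> vector<8xi4>,
// every target i4 element is covered by exactly one source element, so each
// SourceElementRangeList has a single entry, e.g. target element 1 reads bits
// [4, 8) of source element 0. For vector<2xi8> -> vector<1xi16>, target
// element 0 reads bits [0, 8) of source element 0 and bits [0, 8) of source
// element 1, with left-shift amounts 0 and 8 respectively.
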
struct BitCastRewriter {
  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

                                   VectorType preconditionType, Operation *op);

  precomputeMetadata(IntegerType shuffledElementType);

                           const BitCastRewriter::Metadata &metadata);

  BitCastBitsEnumerator enumerator;

[[maybe_unused]] static raw_ostream &
  for (const auto &l : vec) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG() << "sourceVectorType: " << sourceVectorType;

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG() << "targetVectorType: " << targetVectorType;

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG() << "\n" << enumerator.sourceElementRanges;

                                                  VectorType preconditionType,
  if (!preconditionType || preconditionType.isScalable())

  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)

LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)

  if (!preconditionType || preconditionType.getRank() != 1)
                                                    VectorType subByteVecTy,
                           "container element type is not a scalar");

  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();

  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");

  if (subByteBits != 2 && subByteBits != 4)
        op, "only 2-bit and 4-bit sub-byte types are supported at this moment");

  if (containerBits % subByteBits != 0)

  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
        op, "not possible to fit this sub-byte vector type into a vector of "
            "the given multi-byte type");
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {

    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(

    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
Value BitCastRewriter::genericRewriteStep(
    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue,
                                             initialValue, metadata.shuffles);

  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = arith::ConstantOp::create(
  Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp);

  auto shiftRightConstantOp = arith::ConstantOp::create(
  Value shiftedRight =
      arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp);

  auto shiftLeftConstantOp = arith::ConstantOp::create(
      arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp);

          ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft)

  return runningResult;
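
// Illustrative sketch (assumed i8 shuffled element type): each rewrite step
// gathers the relevant source bytes with vector.shuffle, isolates the wanted
// bit range with arith.andi, moves it into place with arith.shrui /
// arith.shli, and accumulates the steps with arith.ori, roughly:
//   %s = vector.shuffle %src, %src [...]
//   %m = arith.andi %s, %mask : vector<...xi8>
//   %r = arith.shrui %m, %shr : vector<...xi8>
//   %l = arith.shli %r, %shl : vector<...xi8>
//   %acc = arith.ori %accPrev, %l : vector<...xi8>
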
  auto srcVecType = cast<VectorType>(subByteVec.getType());
  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
  assert(8 % srcBitwidth == 0 &&
         "Unsupported sub-byte type (not a divisor of i8)");
  int64_t numSrcElemsPerByte = 8 / srcBitwidth;

  vecShape.back() = vecShape.back() / numSrcElemsPerByte;

  return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec);
                                                 int bitIdx, int numBits) {
  auto srcType = cast<VectorType>(src.getType());

  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  if (bitsToShiftLeft != 0) {
    Value shiftLeftValues = arith::ConstantOp::create(
    shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues);

  int8_t bitsToShiftRight = 8 - numBits;
  Value shiftRightValues = arith::ConstantOp::create(
  Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues);
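
// Illustrative example (assumed 4-bit fields in a byte 0bHHHHLLLL):
//  - low nibble  (bitIdx = 0): shl by 8 - 4 - 0 = 4 gives 0bLLLL0000, then
//    arithmetic shift right by 8 - 4 = 4 gives 0bSSSSLLLL (S = sign bit of L),
//    i.e. a sign-extended i4 value in an i8 lane.
//  - high nibble (bitIdx = 4): no left shift is needed; the arithmetic shift
//    right by 4 gives 0bSSSSHHHH (S = sign bit of H).
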
                                             int bitIdx, int numBits) {
  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  auto srcType = cast<VectorType>(src.getType());
  int8_t bitsToShiftRight = bitIdx;
  if (bitsToShiftRight != 0) {
    Value shiftRightValues = arith::ConstantOp::create(
    shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues);
  if (bitIdx + numBits == 8) {
  uint8_t lowBitsMask = (1 << numBits) - 1;
  Value lowBitsMaskValues = arith::ConstantOp::create(
  return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues);
  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
  Value high = extFn(rewriter, loc, i8Vector, 4, 4);

  return vector::InterleaveOp::create(rewriter, loc, low, high);
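
// Illustrative example (assumed vector<8xi4> input [a0, ..., a7]): after the
// bitcast to vector<4xi8>, extFn extends the low nibbles to [a0, a2, a4, a6]
// and the high nibbles to [a1, a3, a5, a7] (as i8), and vector.interleave
// restores the original element order, producing the extended
// [a0, a1, a2, a3, a4, a5, a6, a7].
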
  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(2) &&
         "Expected i2 type");

  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);

  Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2);
  Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3);
  return vector::InterleaveOp::create(rewriter, loc, interleave02,
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue);

  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = arith::ConstantOp::create(
  Value zeroOutLow = arith::AndIOp::create(
      rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask);

  constexpr int8_t bitsToShift = 4;
  auto shiftValues = arith::ConstantOp::create(
  Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(),

  auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh);

  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp);
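
// Illustrative example (assumed vector<8xi8> -> vector<8xi4>): deinterleave
// splits [a0, ..., a7] into evens [a0, a2, a4, a6] and odds [a1, a3, a5, a7];
// the evens are masked with 0x0F, the odds are shifted left by 4, and the two
// halves are OR-ed so each byte packs two truncated values, e.g. byte 0
// becomes (a1 << 4) | (a0 & 0xF); a final bitcast reinterprets the bytes as
// vector<8xi4>.
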
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();

    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))

    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);

    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
template <typename ExtOpType>
    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();

    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))

    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);

        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();

          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
template <typename ConversionOpType, bool isSigned>
    Value srcValue = conversionOp.getIn();
    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());

            rewriter, srcVecType,

    Location loc = conversionOp.getLoc();

    switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {

        conversionOp, conversionOp.getType(), subByteExt);
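
// Illustrative sketch (assumed signed i4 -> i32 extension): the pattern
// rewrites
//   %e = arith.extsi %v : vector<8xi4> to vector<8xi32>
// into the aligned i4 -> i8 extension above (bitcast + shifts + interleave)
// followed by a single full-width extension:
//   %x = ...  // vector<8xi8> produced by the i4 -> i8 helper
//   %e = arith.extsi %x : vector<8xi8> to vector<8xi32>
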
    Value srcValue = truncOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!srcVecType || !dstVecType)

    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)

            rewriter, dstVecType,

    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
        arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue);

    rewriter.replaceOp(truncOp, subByteTrunc);
    constexpr unsigned minNativeBitwidth = 8;
    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
                                         "not a sub-byte transpose");

    Location loc = transposeOp.getLoc();

    auto srcNativeVecType = srcSubByteVecType.cloneWith(
    Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType,
                                         transposeOp.getVector());
    Value newTranspose = vector::TransposeOp::create(
        rewriter, loc, extOp, transposeOp.getPermutation());
    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
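
// Illustrative sketch (assumed i4 elements): a sub-byte transpose is done via
// a temporary extension to a natively supported width, e.g.
//   vector.transpose %v, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
// becomes roughly
//   %e = arith.extsi %v : vector<8x16xi4> to vector<8x16xi8>
//   %t = vector.transpose %e, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
//   %r = arith.trunci %t : vector<16x8xi8> to vector<16x8xi4>
// so the transpose itself operates on i8 elements.
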
void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,

  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());

  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);

void vector::populateVectorNarrowTypeRewritePatterns(
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),

  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),

      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, false>>(

void vector::populateVectorTransposeNarrowTypeRewritePatterns(