#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

#define DEBUG_TYPE "vector-narrow-type-emulation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
/// Returns a compressed mask for the emulated vector.
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                  Location loc, Value mask,
                                                  int numSrcElems,
                                                  int numSrcElemsPerDest,
                                                  int numFrontPadElems = 0) {

  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

  // Number of destination elements, rounded up so that the front padding and
  // all source elements are covered.
  int numDestElems =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;

  // Look through vector.extract ops to find the op that creates the mask.
  Operation *maskOp = mask.getDefiningOp();
  SmallVector<vector::ExtractOp, 2> extractOps;
  while (maskOp &&
         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
             maskOp)) {
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getVector().getDefiningOp();
      extractOps.push_back(extractOp);
    } else {
      // Not a mask-creating op and not an extract: stop walking so the check
      // below can bail out instead of looping forever.
      break;
    }
  }

  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
          maskOp))
    return failure();

  // Only the trailing dimension of the mask is compressed.
  SmallVector<int64_t> maskShape(
      cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
  maskShape.back() = numDestElems;
  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
  std::optional<Operation *> newMask =
      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
          .Case<vector::CreateMaskOp>(
              [&](auto createMaskOp) -> std::optional<Operation *> {
                OperandRange maskOperands = createMaskOp.getOperands();
                size_t numMaskOperands = maskOperands.size();
                AffineExpr s0;
                bindSymbols(rewriter.getContext(), s0);
                // Fold the front padding into the index and round up to a
                // whole number of destination elements.
                s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                OpFoldResult origIndex =
                    getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
                    rewriter, loc, s0, origIndex);
                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
                newMaskOperands.push_back(
                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                             newMaskOperands);
              })
          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
                                            -> std::optional<Operation *> {
            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
            size_t numMaskOperands = maskDimSizes.size();
            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
                                                 numSrcElemsPerDest);

            // A mask with front padding and more than one dimension cannot be
            // expressed as a constant mask; bail out.
            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
              return std::nullopt;

            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
            newMaskDimSizes.push_back(maskIndex);

            if (numFrontPadElems == 0)
              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                             newMaskDimSizes);

            // With front padding, only elements in [startIndex, maskIndex)
            // are set; materialize the mask as a constant vector.
            SmallVector<bool> newMaskValues;
            for (int64_t i = 0; i < numDestElems; ++i)
              newMaskValues.push_back(i >= startIndex && i < maskIndex);
            auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
                                                      denseAttr);
          })
          .Case<arith::ConstantOp>([&](auto constantOp)
                                       -> std::optional<Operation *> {
            // TODO: add support for multi-dimensional masks.
            if (maskShape.size() != 1)
              return std::nullopt;

            // Pad the original mask values with `false` at the front and
            // round the total up to a whole number of destination elements.
            auto originalMask =
                cast<DenseIntElementsAttr>(constantOp.getValue());
            SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);

            // Compress the mask: a destination element is masked in if any of
            // the source elements it covers is masked in.
            SmallVector<bool> compressedMaskValues;
            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j) {
                combinedValue |= paddedMaskValues[i + j];
              }
              compressedMaskValues.push_back(combinedValue);
            }
            return rewriter.create<arith::ConstantOp>(
                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
          });

  if (!newMask)
    return failure();
  // Replay the vector.extract ops that were looked through above.
  while (!extractOps.empty()) {
    newMask = rewriter.create<vector::ExtractOp>(
        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
    extractOps.pop_back();
  }

  return *newMask;
}
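// Illustrative example (added sketch, not from the original source): when
// emulating i4 with i8 (numSrcElemsPerDest = 2, numFrontPadElems = 0), a mask
//   %mask = vector.create_mask %c5 : vector<8xi1>
// compresses to
//   %new_mask = vector.create_mask %c3 : vector<4xi1>  // ceilDiv(5, 2) == 3
// because an emulated byte must be accessed if either of its i4 halves is.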
/// Extracts 1-D subvector from a 1-D vector. The subvector starts at
/// `frontOffset` and has `subvecSize` elements.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
                                        VectorType extractType, Value source,
                                        int64_t frontOffset,
                                        int64_t subvecSize) {
  auto vectorType = cast<VectorType>(source.getType());
  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
         "expected 1-D source and destination types");
  assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
         "subvector out of bounds");

  // If the subvector covers the entire source, there is nothing to extract.
  if (vectorType.getNumElements() == subvecSize)
    return source;

  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter
      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
                                             sizes, strides)
      ->getResult(0);
}
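// Illustrative example (added sketch, not from the original source):
// extracting 3 elements at front offset 1 from a vector<8xi4> yields
//   %sub = vector.extract_strided_slice %src
//            {offsets = [1], sizes = [3], strides = [1]}
//            : vector<8xi4> to vector<3xi4>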
/// Inserts 1-D subvector into a 1-D vector by overwriting the elements
/// starting at `offset`.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {
  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
         "expected source and dest to be vector type");
  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
                                                       dest, offsets, strides);
}
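// Illustrative example (added sketch, not from the original source): inserting
// a vector<3xi4> into a vector<8xi4> at offset 2 yields
//   %res = vector.insert_strided_slice %src, %dest
//            {offsets = [2], strides = [1]} : vector<3xi4> into vector<8xi4>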
/// Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
/// and size `numElementsToExtract`, and inserts it into the `dest` vector.
/// Because the extraction position is dynamic, elements are moved one by one.
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
                                         TypedValue<VectorType> source,
                                         Value dest, OpFoldResult offset,
                                         int64_t numElementsToExtract) {
  for (int i = 0; i < numElementsToExtract; ++i) {
    Value extractLoc =
        (i == 0) ? offset.dyn_cast<Value>()
                 : rewriter.create<arith::AddIOp>(
                       loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
                       rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp =
        rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
  }
  return dest;
}
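// Illustrative example (added sketch, not from the original source): with a
// dynamic %offset and numElementsToExtract = 2, the loop unrolls to
//   %e0 = vector.extract %source[%offset] : i4 from vector<8xi4>
//   %d0 = vector.insert %e0, %dest [0] : i4 into vector<2xi4>
//   %o1 = arith.addi %offset, %c1 : index
//   %e1 = vector.extract %source[%o1] : i4 from vector<8xi4>
//   %d1 = vector.insert %e1, %d0 [1] : i4 into vector<2xi4>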
/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
                                        TypedValue<VectorType> source,
                                        Value dest, OpFoldResult destOffsetVar,
                                        size_t length) {
  assert(length > 0 && "length must be greater than 0");
  Value destOffsetVal =
      getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
  for (size_t i = 0; i < length; ++i) {
    auto insertLoc = i == 0
                         ? destOffsetVal
                         : rewriter.create<arith::AddIOp>(
                               loc, rewriter.getIndexType(), destOffsetVal,
                               rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
  }
  return dest;
}
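// Added note (not from the original source): this is the mirror image of the
// dynamic extraction above; each source element is extracted at a static
// index and re-inserted at the dynamic positions %off, %off + 1, ...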
/// Returns the op sequence for an emulated sub-byte data type vector load:
/// a wide load of the emulated type followed by a bitcast back to the
/// original element type.
static TypedValue<VectorType>
emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                   OpFoldResult linearizedIndices,
                   int64_t numEmultedElementsToLoad, Type origElemType,
                   Type emulatedElemType) {
  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
               origElemType.getIntOrFloatBitWidth();
  auto newLoad = rewriter.create<vector::LoadOp>(
      loc, VectorType::get(numEmultedElementsToLoad, emulatedElemType), base,
      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
  return rewriter.create<vector::BitCastOp>(
      loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
      newLoad);
}
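// Illustrative example (added sketch, not from the original source): loading
// two emulated i8 elements that back an i4 vector expands to
//   %wide = vector.load %base[%idx] : memref<?xi8>, vector<2xi8>
//   %res  = vector.bitcast %wide : vector<2xi8> to vector<4xi4>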
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getValueToStore().getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // ... (explanatory comment on multi-dimensional memref linearization
    // elided)

    auto origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = origElements / scale;
    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc, VectorType::get(numElements, newElementType),
        op.getValueToStore());

    rewriter.replaceOpWithNewOp<vector::StoreOp>(
        op, bitCast.getResult(), adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
    return success();
  }
};
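// Illustrative example (added sketch, not from the original source): an
// aligned i4 store such as
//   vector.store %v, %m[%i] : memref<8xi4>, vector<4xi4>
// becomes, once the memref has been converted to i8,
//   %bc = vector.bitcast %v : vector<4xi4> to vector<2xi8>
//   vector.store %bc, %m_i8[%lin] : memref<4xi8>, vector<2xi8>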
struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getValueToStore().getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    int scale = dstBits / srcBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndicesOfr;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndicesOfr) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
    Value linearizedIndices =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);

    // ... (explanatory comment on the load/select/store sequence elided)

    FailureOr<Operation *> newMask =
        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
    if (failed(newMask))
      return failure();

    auto numElements = (origElements + scale - 1) / scale;
    auto newType = VectorType::get(numElements, newElementType);
    auto passThru = rewriter.create<arith::ConstantOp>(
        loc, newType, rewriter.getZeroAttr(newType));

    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
    Value valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
    valueToStore = rewriter.create<arith::SelectOp>(
        loc, op.getMask(), op.getValueToStore(), valueToStore);
    valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
        valueToStore);
    return success();
  }
};
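// Illustrative sketch of the sequence above (added, not from the original
// source): a masked i4 store is emulated as a read-modify-write on i8:
//   %old   = vector.maskedload %m_i8[%lin], %new_mask, %zeros
//   %wide  = vector.bitcast %old : vector<2xi8> to vector<4xi4>
//   %mixed = arith.select %orig_mask, %value_to_store, %wide
//   %bc    = vector.bitcast %mixed : vector<4xi4> to vector<2xi8>
//   vector.maskedstore %m_i8[%lin], %new_mask, %bc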
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // ... (explanatory comment on aligned vs. unaligned emulation elided)

    auto origElements = op.getVectorType().getNumElements();
    bool isUnalignedEmulation = origElements % scale != 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isUnalignedEmulation
            ? getConstantIntValue(linearizedInfo.intraDataOffset)
            : 0;

    // Always load enough elements to cover a possibly unaligned front offset.
    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
    Value result =
        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                           numElements, oldElementType, newElementType);

    if (!foldedIntraVectorOffset) {
      auto resultVector = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
          linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
      result = staticallyExtractSubvector(
          rewriter, loc, op.getType(), result, *foldedIntraVectorOffset,
          origElements);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
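// Illustrative example (added sketch, not from the original source): an
// unaligned load of 3 i4 elements whose front intra-byte offset is 1 loads
// ceil((1 + 3) / 2) = 2 bytes and then slices out elements [1, 4):
//   %wide = vector.load %m_i8[%lin] : memref<?xi8>, vector<2xi8>
//   %all  = vector.bitcast %wide : vector<2xi8> to vector<4xi4>
//   %res  = vector.extract_strided_slice %all
//             {offsets = [1], sizes = [3], strides = [1]}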
struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // ... (long worked-example comment elided)

    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    bool isUnalignedEmulation = origElements % scale != 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isUnalignedEmulation
            ? getConstantIntValue(linearizedInfo.intraDataOffset)
            : 0;

    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    FailureOr<Operation *> newMask = getCompressedMaskOp(
        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
    if (failed(newMask))
      return failure();

    Value passthru = op.getPassThru();

    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
    auto loadType = VectorType::get(numElements, newElementType);
    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);

    auto emptyVector = rewriter.create<arith::ConstantOp>(
        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
      passthru = dynamicallyInsertSubVector(
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
          emptyVector, linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
      passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                           *foldedIntraVectorOffset);
    }
    auto newPassThru =
        rewriter.create<vector::BitCastOp>(loc, loadType, passthru);

    // Generating the new masked load.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, loadType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newMask.value()->getResult(0), newPassThru);

    // Set the parts that were not effectively loaded from memory to the
    // pass-through values.
    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);

    Value mask = op.getMask();
    auto newSelectMaskType =
        VectorType::get(numElements * scale, rewriter.getI1Type());
    // TODO: try to fold if op's mask is constant
    auto emptyMask = rewriter.create<arith::ConstantOp>(
        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
    if (!foldedIntraVectorOffset) {
      mask = dynamicallyInsertSubVector(
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
          linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                       *foldedIntraVectorOffset);
    }

    Value result =
        rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
      result = dynamicallyExtractSubVector(
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
          op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
      result = staticallyExtractSubvector(
          rewriter, loc, op.getType(), result, *foldedIntraVectorOffset,
          origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};
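// Illustrative sketch of the flow above (added, not from the original
// source): the pass-through and the mask are first widened to the footprint
// of the emulated load, the compressed mask drives a byte-level maskedload,
// and a final select restores the per-element masking semantics:
//   %pt   = ...insert pass-through into a zero vector...
//   %load = vector.maskedload %m_i8[%lin], %new_mask, bitcast(%pt)
//   %all  = vector.bitcast %load : vector<2xi8> to vector<4xi4>
//   %res  = arith.select %widened_mask, %all, %pt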
struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    auto origElements = op.getVectorType().getNumElements();

    bool isUnalignedEmulation = origElements % scale != 0;

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                      adaptor.getPadding());

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isUnalignedEmulation
            ? getConstantIntValue(linearizedInfo.intraDataOffset)
            : 0;

    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc, VectorType::get(numElements * scale, oldElementType), newRead);

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
    } else if (isUnalignedEmulation) {
      result = staticallyExtractSubvector(
          rewriter, loc, op.getType(), result, *foldedIntraVectorOffset,
          origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};
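// Illustrative example (added sketch, not from the original source): a
// transfer_read of 3 i4 elements with i4 padding %p becomes
//   %p8   = arith.extui %p : i4 to i8
//   %wide = vector.transfer_read %m_i8[%lin], %p8 : memref<?xi8>, vector<2xi8>
//   %all  = vector.bitcast %wide : vector<2xi8> to vector<4xi4>
// followed by extraction of the 3 requested elements.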
/// Models one contiguous range of bits that a target element draws from a
/// single source element.
struct SourceElementRange {
  /// The index of the source vector element that contributes bits.
  int64_t sourceElementIdx;
  /// The bit range drawn from that element: [sourceBitBegin, sourceBitEnd).
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;
};

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  /// Given the index of a SourceElementRange in this list, compute how far
  /// its bits must be shifted left to land in their final position: the sum
  /// of the widths of all preceding ranges.
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    int64_t res = 0;
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
    return res;
  }
};
/// Enumerates, for every target vector element, the ranges of source bits
/// that compose it.
struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());
    return numVectors;
  }

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
};
/// Rewrites a vector.bitcast into a sequence of shuffles and bitwise ops,
/// driven by per-step metadata precomputed from a BitCastBitsEnumerator.
struct BitCastRewriter {
  /// Per-step metadata: shuffle indices plus mask and shift-amount attributes.
  struct Metadata {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
  };

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  /// Verify that the preconditions for the rewrite are met.
  LogicalResult commonPrecondition(PatternRewriter &rewriter,
                                   VectorType preconditionType, Operation *op);

  /// Precompute the metadata for all rewrite steps.
  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  /// Rewrite one step of the sequence, ORing its result into `runningResult`.
  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
                           Value initialValue, Value runningResult,
                           const BitCastRewriter::Metadata &metadata);

private:
  /// The enumerator describing how source bits map to target elements.
  BitCastBitsEnumerator enumerator;
};
[[maybe_unused]] static raw_ostream &
operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
    }
    os << "\n";
  }
  return os;
}
BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG("sourceVectorType: " << sourceVectorType);

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG("targetVectorType: " << targetVectorType);

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  // Walk the result bits from 0 to `bitwidth`, at each step taking the widest
  // run of bits that stays within a single source element and a single target
  // element.
  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
}
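// Worked example (added, not from the original source): for vector<4xi8> ->
// vector<8xi4>, bitwidth = 32 and step = min(8 - b, 4) = 4 at every
// iteration, so target element 0 gets {srcIdx 0, bits [0..4)}, element 1 gets
// {srcIdx 0, bits [4..8)}, element 2 gets {srcIdx 1, bits [0..4)}, and so on.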
BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG("\n" << enumerator.sourceElementRanges);
}
/// Verify that the precondition type meets the common preconditions for any
/// conversion.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!preconditionType || preconditionType.isScalable())
    return rewriter.notifyMatchFailure(op, "scalable vector");

  // TODO: consider relaxing this restriction in the future if we find ways
  // to really work with subbyte elements across the MLIR/LLVM boundary.
  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)
    return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");

  return success();
}

LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
    return rewriter.notifyMatchFailure(op, "types are not vector");

  if (!preconditionType || preconditionType.getRank() != 1)
    return rewriter.notifyMatchFailure(op, "unsupported >1 rank");

  return commonConversionPrecondition(rewriter, preconditionType, op);
}

/// Verify that source and destination element types meet the precondition for
/// the supported aligned conversion cases.
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType srcType,
                                                   VectorType dstType,
                                                   Operation *op) {
  if (!srcType || !dstType)
    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();

  // Only i4 -> (element types of at least 8 bits) are supported for now.
  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
      (dstElemBitwidth % srcElemBitwidth) != 0)
    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");

  if ((srcType.getShape().back() % 2) != 0)
    return rewriter.notifyMatchFailure(
        op, "Not an even number of i4 elements in trailing dim");

  return success();
}
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

    // Create the attribute quantities for the shuffle / mask / shift ops.
    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
                                  : 0;
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
                          : 0;
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
                          : 0;
      IntegerAttr mask = IntegerAttr::get(
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
                                  bitLo, bitHi));
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftRight));

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftLeft));
    }

    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
  }
  return result;
}
Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  // Create vector.shuffle from the metadata.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);

  // Intersect with the mask.
  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);

  // Align the bits on the rightmost position.
  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
  Value shiftedRight =
      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);

  // Shift the bits left into their final position.
  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
  Value shiftedLeft =
      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);

  runningResult =
      runningResult
          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
          : shiftedLeft;

  return runningResult;
}
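// Illustrative sketch (added, not from the original source): each step
// computes, for every result element and the metadata of step k,
//   result |= ((shuffle(src)[k] & mask[k]) >> bitLo[k]) << leftShift[k]
// i.e. it isolates one source bit range, right-aligns it, and ORs it into
// its final position in the running result.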
/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
                                    Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
  constexpr int64_t i4Toi8BitwidthFactor = 2;
  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);

  // 2. Extend i4 elements to i8 elements using shifts. The low i4 elements of
  // each byte land in one vector and the high i4 elements in another.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);

  // 3. Interleave low and high i8 elements.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
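// Worked bit-level example (added, not from the original source): for a byte
// 0xBA packing two i4 values (low nibble A, high nibble B),
//   shl  = byte << 4  // A now occupies the high nibble
//   low  = shl >> 4   // arithmetic shift sign-extends A
//   high = byte >> 4  // arithmetic shift sign-extends B
// and vector.interleave restores the original element order.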
/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
                                      Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
  constexpr int64_t i4Toi8BitwidthFactor = 2;
  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);

  // 2. Extend i4 elements to i8 elements: mask out the high nibble for the
  // low elements and logically shift right for the high elements.
  constexpr uint8_t lowBitsMask = 15; // Equivalent to the [00001111] bit mask.
  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
                                             lowBitsMaskValues);
  constexpr int8_t highBitsToShift = 4;
  auto highShiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);

  // 3. Interleave low and high i8 elements.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
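// Worked bit-level example (added, not from the original source): for a byte
// 0xBA packing two i4 values (low nibble A, high nibble B),
//   low  = byte & 0x0F  // zero-extends A
//   high = byte >> 4    // logical shift zero-extends B
// and vector.interleave restores the original element order.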
/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
/// ops that take advantage of high-level information to avoid leaving LLVM to
/// scramble with peephole optimizations.
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  // 1. De-interleave low and high i8 elements.
  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);

  // 2. Zero out the upper side of each low i8 element.
  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
  Value zeroOutLow = rewriter.create<arith::AndIOp>(
      loc, deinterleaveOp.getRes1(), zeroOutMask);

  // 3. Move the high i4 values to the upper side of the byte.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
                                                 shiftValues);

  // 4. Merge high and low i4 values.
  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);

  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
}
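// Worked example (added, not from the original source): deinterleaving
// [a, b, c, d] yields evens [a, c] and odds [b, d]; each output byte is then
//   byte = (even & 0x0F) | (odd << 4)
// and the final bitcast reinterprets the i8 vector as twice as many i4s.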
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that
/// take advantage of high-level information to avoid leaving LLVM to scramble
/// with peephole optimizations.
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a trunc op.
    auto truncOp =
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
    if (!truncOp)
      return failure();

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::TruncIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    } else {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    }
    return success();
  }
};
/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops.
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  using OpRewritePattern<ExtOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtOpType extOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a bitcast op.
    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
    if (!bitCastOp)
      return failure();

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing =
        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    } else {
      rewriter.replaceOpWithNewOp<ExtOpType>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    }
    return success();
  }
};
/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles
/// and bitwise ops that take advantage of high-level information to avoid
/// leaving LLVM to scramble with peephole optimizations.
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
  using OpRewritePattern<ConversionOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = conversionOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
    // ... (common and aligned precondition checks elided)

    // Perform the rewrite: i4 -> i8 extension, signed or unsigned.
    Value subByteExt =
        isSigned
            ? rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue)
            : rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(),
                                       srcValue);

    // Finalize the rewrite.
    rewriter.replaceOpWithNewOp<ConversionOpType>(
        conversionOp, conversionOp.getType(), subByteExt);
    return success();
  }
};

/// Rewrite the i8 -> i4 part of any truncation into a deinterleave and series
/// of bitwise ops to avoid leaving LLVM to scramble with peephole
/// optimizations.
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
  using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = truncOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!srcVecType || !dstVecType)
      return failure();
    // ... (aligned precondition checks elided)

    // Create a new iX -> i8 truncation op first.
    Location loc = truncOp.getLoc();
    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
    Value i8TruncVal =
        rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);

    // Rewrite the i8 -> i4 truncation part and finalize.
    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
    rewriter.replaceOp(truncOp, subByteTrunc);
    return success();
  }
};
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Precondition: the source must be a sub-byte signless integer vector.
    constexpr unsigned minNativeBitwidth = 8;
    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
      return rewriter.notifyMatchFailure(transposeOp,
                                         "not a sub-byte transpose");
    }

    // Perform the rewrite. Signed/unsigned interpretation shouldn't matter
    // here as we are just transposing the elements.
    Location loc = transposeOp.getLoc();
    auto srcNativeVecType = srcSubByteVecType.cloneWith(
        std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
    Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
                                                  transposeOp.getVector());
    Value newTranspose = rewriter.create<vector::TransposeOp>(
        loc, extOp, transposeOp.getPermutation());
    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
    rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp,
                                                 dstSubByteVecType,
                                                 newTranspose);
    return success();
  }
};
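// Illustrative example (added sketch, not from the original source): a
// transpose on vector<2x4xi4> becomes
//   %ext = arith.extsi %v : vector<2x4xi4> to vector<2x4xi8>
//   %t   = vector.transpose %ext, [1, 0] : vector<2x4xi8> to vector<4x2xi8>
//   %res = arith.trunci %t : vector<4x2xi8> to vector<4x2xi4>
// so the transpose itself runs on a byte-aligned type.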
void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  // Populate `vector.*` conversion patterns.
  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());
}

void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    benefit);

  // Patterns for aligned cases are given a higher benefit, as they are
  // expected to generate better performance.
  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                              benefit.getBenefit() + 1);
  patterns
      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
          patterns.getContext(), benefit.getBenefit() + 1);
}

void vector::populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
}