#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-narrow-type-emulation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
/// Returns a compressed mask for the emulated vector: each result element is
/// the OR of the `numSrcElemsPerDest` source mask elements it covers.
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                  Location loc, Value mask,
                                                  int numSrcElems,
                                                  int numSrcElemsPerDest,
                                                  int numFrontPadElems = 0) {

  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

  int numDestElems =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;

  // Walk up the use-def chain to find the op that creates the mask, recording
  // any intervening `vector.extract` ops.
  Operation *maskOp = mask.getDefiningOp();
  SmallVector<vector::ExtractOp, 2> extractOps;
  while (maskOp &&
         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
             maskOp)) {
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getVector().getDefiningOp();
      extractOps.push_back(extractOp);
    }
  }

  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
          maskOp))
    return failure();

  // Compute the "compressed" mask type: only the trailing dimension shrinks.
  SmallVector<int64_t> maskShape(
      cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
  maskShape.back() = numDestElems;
  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());

  std::optional<Operation *> newMask =
      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
          .Case<vector::CreateMaskOp>(
              [&](auto createMaskOp) -> std::optional<Operation *> {
                // Scale the trailing mask operand down by the compression
                // ratio, rounding up to cover the front padding.
                AffineExpr s0;
                bindSymbols(rewriter.getContext(), s0);
                s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                OpFoldResult origIndex =
                    getAsOpFoldResult(createMaskOp.getOperands().back());
                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
                    rewriter, loc, s0, origIndex);
                SmallVector<Value> newMaskOperands(
                    createMaskOp.getOperands().drop_back());
                newMaskOperands.push_back(
                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                             newMaskOperands);
              })
          .Case<vector::ConstantMaskOp>(
              [&](auto constantMaskOp) -> std::optional<Operation *> {
                // Scale the trailing mask dimension size the same way.
                SmallVector<int64_t> maskDimSizes(
                    constantMaskOp.getMaskDimSizes());
                int64_t &maskIndex = maskDimSizes.back();
                maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
                                             numSrcElemsPerDest);
                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                               maskDimSizes);
              })
          .Case<arith::ConstantOp>([&](auto constantOp)
                                       -> std::optional<Operation *> {
            // TODO: Support multiple dimensions.
            if (maskShape.size() != 1)
              return std::nullopt;
            // Pad the mask with zeros in the front (for the offset) and in the
            // back (up to a multiple of numSrcElemsPerDest), then OR-combine
            // each group of numSrcElemsPerDest elements.
            auto originalMask =
                cast<DenseIntElementsAttr>(constantOp.getValue());
            SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);

            SmallVector<bool> compressedMaskValues;
            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j)
                combinedValue |= paddedMaskValues[i + j];
              compressedMaskValues.push_back(combinedValue);
            }
            return rewriter.create<arith::ConstantOp>(
                loc,
                DenseElementsAttr::get(newMaskType, compressedMaskValues));
          });

  if (!newMask)
    return failure();

  // Replay the recorded `vector.extract` ops on the compressed mask.
  while (!extractOps.empty()) {
    newMask = rewriter.create<vector::ExtractOp>(
        loc, (*newMask)->getResults()[0],
        extractOps.back().getMixedPosition());
    extractOps.pop_back();
  }

  return *newMask;
}
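The constant-mask branch is plain integer arithmetic. A minimal standalone sketch of the padding-and-OR compression above (plain C++, not part of the pass; the input values are chosen for illustration):

// Compress a boolean mask by numSrcElemsPerDest, with one front pad element.
#include <cstdio>
#include <vector>

int main() {
  const int numSrcElemsPerDest = 2, numFrontPadElems = 1;
  std::vector<bool> mask = {false, true, false, true, false, false};
  int numDestElems =
      (numFrontPadElems + (int)mask.size() + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;

  // Pad in front for the offset and in back to a multiple of the ratio:
  // [0,1,0,1,0,0] becomes [0,0,1,0,1,0,0,0].
  std::vector<bool> padded(numFrontPadElems, false);
  padded.insert(padded.end(), mask.begin(), mask.end());
  padded.resize(numDestElems * numSrcElemsPerDest, false);

  // OR-combine each group of numSrcElemsPerDest elements.
  std::vector<bool> compressed;
  for (size_t i = 0; i < padded.size(); i += numSrcElemsPerDest) {
    bool combined = false;
    for (int j = 0; j < numSrcElemsPerDest; ++j)
      combined = combined || padded[i + j];
    compressed.push_back(combined);
  }
  for (bool b : compressed)
    std::printf("%d ", (int)b); // prints: 0 1 1 0
  return 0;
}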
/// Extracts a 1-D subvector from a 1-D vector. Wraps
/// `vector.extract_strided_slice`.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
                                        VectorType extractType, Value source,
                                        int64_t frontOffset,
                                        int64_t subvecSize) {
  auto vectorType = cast<VectorType>(source.getType());
  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
         "expected 1-D source and destination types");
  assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
         "subvector out of bounds");

  // No extraction is needed if the subvector spans the whole source.
  if (vectorType.getNumElements() == subvecSize)
    return source;

  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter
      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
                                             sizes, strides)
      ->getResult(0);
}
/// Inserts a 1-D subvector into a 1-D vector by overwriting the elements
/// starting at `offset`. Wraps `vector.insert_strided_slice`.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {
  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
         "expected source and dest to be vector type");
  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(),
                                                       src, dest, offsets,
                                                       strides);
}
/// Extracts a 1-D subvector from a 1-D `source` vector, with index at
/// `offset` and size `numElementsToExtract`, inserting the elements into
/// `dest` one by one.
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
                                         TypedValue<VectorType> source,
                                         Value dest, OpFoldResult offset,
                                         int64_t numElementsToExtract) {
  for (int i = 0; i < numElementsToExtract; ++i) {
    Value extractLoc =
        (i == 0) ? offset.dyn_cast<Value>()
                 : rewriter.create<arith::AddIOp>(
                       loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
                       rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp =
        rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
  }
  return dest;
}
/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`,
/// element by element.
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
                                        TypedValue<VectorType> source,
                                        Value dest, OpFoldResult destOffsetVar,
                                        size_t length) {
  assert(length > 0 && "length must be greater than 0");
  Value destOffsetVal =
      getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
  for (size_t i = 0; i < length; ++i) {
    auto insertLoc =
        i == 0 ? destOffsetVal
               : rewriter.create<arith::AddIOp>(
                     loc, rewriter.getIndexType(), destOffsetVal,
                     rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
  }
  return dest;
}
/// Returns the op sequence for an emulated sub-byte data type vector load:
/// a wide `vector.load` followed by a `vector.bitcast` back to the original
/// element type.
static TypedValue<VectorType>
emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                   OpFoldResult linearizedIndices,
                   int64_t numEmulatedElementsToLoad, Type origElemType,
                   Type emulatedElemType) {
  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
               origElemType.getIntOrFloatBitWidth();
  auto newLoad = rewriter.create<vector::LoadOp>(
      loc, VectorType::get(numEmulatedElementsToLoad, emulatedElemType), base,
      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
  return rewriter.create<vector::BitCastOp>(
      loc, VectorType::get(numEmulatedElementsToLoad * scale, origElemType),
      newLoad);
}
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getValueToStore().getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // ... (see the in-tree comment for a worked example)

    auto origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = origElements / scale;
    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc, VectorType::get(numElements, newElementType),
        op.getValueToStore());

    rewriter.replaceOpWithNewOp<vector::StoreOp>(
        op, bitCast.getResult(), adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
    return success();
  }
};
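The aligned store path reduces to a ratio computation. A minimal sketch of the precondition and sizing logic above (plain C++, with i4 elements in i8 containers chosen for illustration):

#include <cassert>
#include <cstdio>

int main() {
  const int srcBits = 4, dstBits = 8;  // i4 elements, i8 backing memory
  const int origElements = 8;          // vector<8xi4>
  assert(dstBits % srcBits == 0 && "only dstBits % srcBits == 0 supported");
  const int scale = dstBits / srcBits; // two source elements per byte
  assert(origElements % scale == 0 && "aligned path requires divisibility");
  const int numElements = origElements / scale;
  // vector<8xi4> is stored as vector<4xi8> after a vector.bitcast.
  std::printf("store vector<%dxi%d> as vector<%dxi%d>\n", origElements,
              srcBits, numElements, dstBits);
  return 0;
}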
struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getValueToStore().getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    int scale = dstBits / srcBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndicesOfr;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndicesOfr) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
    Value linearizedIndices =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);

    // ... (the masked store is emulated as a masked load, a select with the
    // original value, and a masked store on the wider type; see the in-tree
    // comment for the caveats)

    FailureOr<Operation *> newMask =
        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
    if (failed(newMask))
      return failure();

    auto numElements = (origElements + scale - 1) / scale;
    auto newType = VectorType::get(numElements, newElementType);
    auto passThru = rewriter.create<arith::ConstantOp>(
        loc, newType, rewriter.getZeroAttr(newType));

    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
    Value valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
    valueToStore = rewriter.create<arith::SelectOp>(
        loc, op.getMask(), op.getValueToStore(), valueToStore);
    valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), valueToStore);
    return success();
  }
};
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // ... (see the in-tree comment for the aligned vs. unaligned cases)

    auto origElements = op.getVectorType().getNumElements();
    bool isUnalignedEmulation = origElements % scale != 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isUnalignedEmulation
            ? getConstantIntValue(linearizedInfo.intraDataOffset)
            : 0;

    // If the intra-element offset is unknown at compile time, conservatively
    // load one extra emulated element to cover any possible offset.
    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
    Value result =
        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                           numElements, oldElementType, newElementType);

    if (!foldedIntraVectorOffset) {
      auto resultVector = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
          resultVector, linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
      result = staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                          *foldedIntraVectorOffset,
                                          origElements);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
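A minimal sketch of the unaligned sizing rule above (plain C++, values chosen for illustration): when the first source element does not start at a container boundary, one extra container element may be needed.

#include <cstdio>

int main() {
  const int scale = 2;        // two i4 elements per i8 container
  const int origElements = 4; // vector<4xi4>
  for (int intraOffset = 0; intraOffset < scale; ++intraOffset) {
    // Mirrors llvm::divideCeil(maxIntraDataOffset + origElements, scale).
    int numElements = (intraOffset + origElements + scale - 1) / scale;
    std::printf("intra offset %d -> load %d container elements\n",
                intraOffset, numElements);
  }
  return 0; // prints: offset 0 -> 2 elements, offset 1 -> 3 elements
}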
struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // ... (see the in-tree comment for a worked i4 -> i8 example)

    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    bool isUnalignedEmulation = origElements % scale != 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isUnalignedEmulation
            ? getConstantIntValue(linearizedInfo.intraDataOffset)
            : 0;

    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    FailureOr<Operation *> newMask = getCompressedMaskOp(
        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
    if (failed(newMask))
      return failure();

    Value passthru = op.getPassThru();

    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
    auto loadType = VectorType::get(numElements, newElementType);
    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);

    auto emptyVector = rewriter.create<arith::ConstantOp>(
        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
      passthru = dynamicallyInsertSubVector(
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
          emptyVector, linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
      passthru = staticallyInsertSubvector(rewriter, loc, passthru,
                                           emptyVector,
                                           *foldedIntraVectorOffset);
    }
    Value newPassThru =
        rewriter.create<vector::BitCastOp>(loc, loadType, passthru);

    // Generate the new masked load.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, loadType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newMask.value()->getResult(0), newPassThru);

    // Set the part that was not effectively loaded from memory to the
    // pass-through value, using a select over the expanded mask.
    Value bitCast =
        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);

    Value mask = op.getMask();
    auto newSelectMaskType =
        VectorType::get(numElements * scale, rewriter.getI1Type());
    // TODO: try to fold if the mask is constant.
    auto emptyMask = rewriter.create<arith::ConstantOp>(
        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
    if (!foldedIntraVectorOffset) {
      mask = dynamicallyInsertSubVector(
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
          linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                       *foldedIntraVectorOffset);
    }

    Value result =
        rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
      result = dynamicallyExtractSubVector(
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
          op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
      result = staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                          *foldedIntraVectorOffset,
                                          origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};
struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    auto origElements = op.getVectorType().getNumElements();

    bool isUnalignedEmulation = origElements % scale != 0;

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                      adaptor.getPadding());

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isUnalignedEmulation
            ? getConstantIntValue(linearizedInfo.intraDataOffset)
            : 0;

    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc, VectorType::get(numElements * scale, oldElementType), newRead);

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
    } else if (isUnalignedEmulation) {
      result = staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                          *foldedIntraVectorOffset,
                                          origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};
/// A bit range within one source vector element that contributes to a target
/// vector element.
struct SourceElementRange {
  /// The index of the source vector element that contributes bits.
  int64_t sourceElementIdx;
  /// The bit range within that source element.
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;
};

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  /// Given the index of a SourceElementRange in this list, compute the number
  /// of bits to shift left so the range lands in its final position.
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    int64_t res = 0;
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
    return res;
  }
};
/// Enumerates, for each target element of a `vector.bitcast`, the source
/// elements and bit ranges that contribute to it.
struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());
    return numVectors;
  }

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
};
/// Rewrites a `vector.bitcast` into a sequence of shuffle, mask, and shift
/// ops, driven by the metadata computed from a BitCastBitsEnumerator.
struct BitCastRewriter {
  /// Per-step metadata: shuffle indices, masks, and shift amounts.
  struct Metadata {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
  };

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  /// Verify that the preconditions for the rewrite are met.
  LogicalResult commonPrecondition(PatternRewriter &rewriter,
                                   VectorType preconditionType, Operation *op);

  /// Precompute the metadata for all rewrite steps.
  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  /// Rewrite one step of the shuffle / mask / shift sequence.
  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
                           Value initialValue, Value runningResult,
                           const BitCastRewriter::Metadata &metadata);

private:
  /// The underlying enumerator providing the bit decompositions.
  BitCastBitsEnumerator enumerator;
};
[[maybe_unused]] static raw_ostream &
operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
    }
    os << "\n";
  }
  return os;
}
BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG("sourceVectorType: " << sourceVectorType);

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG("targetVectorType: " << targetVectorType);

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  // Walk the result bits and record, for each target element, which source
  // element and bit range feeds it.
  sourceElementRanges =
      SmallVector<SourceElementRangeList>(mostMinorTargetDim);
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
}
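The enumeration loop is self-contained arithmetic and can be restated standalone. A sketch (plain C++, not part of the pass) for an illustrative vector<4xi8> to vector<8xi4> bitcast, printing the bit ranges each target element draws from:

#include <algorithm>
#include <cstdio>
#include <vector>

struct Range { long srcElem, bitBegin, bitEnd; };

int main() {
  const long sourceBitWidth = 8, sourceDim = 4; // vector<4xi8>
  const long targetBitWidth = 4, targetDim = 8; // vector<8xi4>
  const long bitwidth = targetBitWidth * targetDim;
  if (bitwidth != sourceBitWidth * sourceDim)
    return 1; // source and target bitwidths must match

  std::vector<std::vector<Range>> ranges(targetDim);
  for (long resultBit = 0; resultBit < bitwidth;) {
    long resultElement = resultBit / targetBitWidth;
    long resultBitInElement = resultBit % targetBitWidth;
    long sourceElementIdx = resultBit / sourceBitWidth;
    long sourceBitInElement = resultBit % sourceBitWidth;
    // Advance by whichever boundary (source or target element) comes first.
    long step = std::min(sourceBitWidth - sourceBitInElement,
                         targetBitWidth - resultBitInElement);
    ranges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
  for (size_t i = 0; i < ranges.size(); ++i)
    for (const Range &r : ranges[i])
      std::printf("target[%zu] <- source[%ld] bits [%ld..%ld)\n", i,
                  r.srcElem, r.bitBegin, r.bitEnd);
  return 0;
}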
BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG("\n" << enumerator.sourceElementRanges);
}
/// Verify that the precondition type meets the common preconditions for any
/// conversion.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!preconditionType || preconditionType.isScalable())
    return rewriter.notifyMatchFailure(op, "scalable vector");

  // TODO: consider relaxing this restriction in the future if we find ways
  // to really work with sub-byte elements across the MLIR/LLVM boundary.
  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)
    return rewriter.notifyMatchFailure(op, "not a multiple of byte");

  return success();
}
LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
    return rewriter.notifyMatchFailure(op, "types are not vector");

  if (!preconditionType || preconditionType.getRank() != 1)
    return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");

  return commonConversionPrecondition(rewriter, preconditionType, op);
}
/// Verify that the source and destination element types meet the
/// preconditions for the supported aligned conversion cases.
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType srcType,
                                                   VectorType dstType,
                                                   Operation *op) {
  if (!srcType || !dstType)
    return failure();
  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();

  // Only i4 sources extending to byte-or-wider integer multiples are
  // supported for now.
  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
      (dstElemBitwidth % srcElemBitwidth) != 0)
    return rewriter.notifyMatchFailure(op, "unsupported aligned case");

  if ((srcType.getShape().back() % 2) != 0)
    return rewriter.notifyMatchFailure(
        op, "Not an even number of i4 elements in trailing dim");

  return success();
}
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

    // Create the attribute quantities for the shuffle / mask / shift ops.
    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
                                  : 0;
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
                          : 0;
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
                          : 0;
      IntegerAttr mask = IntegerAttr::get(
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
                                  bitLo, bitHi));
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftRight));

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftLeft));
    }

    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
  }
  return result;
}
Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  // Create vector.shuffle from the metadata.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);

  // Intersect with the mask.
  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);

  // Align right on 0.
  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
  Value shiftedRight =
      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);

  // Shift bits left into their final position.
  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
  Value shiftedLeft =
      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);

  runningResult =
      runningResult
          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
          : shiftedLeft;

  return runningResult;
}
/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
/// bitwise ops that take advantage of the high i8 bits of each byte.
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
                                    Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
  constexpr int64_t i4Toi8BitwidthFactor = 2;
  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
  Value i8Vector =
      rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);

  // 2. Extend i4 elements to i8 elements using shifts. The low i4 elements of
  // each byte land in one vector and the high i4 elements in another.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);

  // 3. Interleave low and high i8 elements.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
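The shift pair above is the classic sign-extension idiom: shifting left by 4 and arithmetically shifting right by 4 recovers the low nibble with its sign, while an arithmetic right shift alone recovers the high nibble. A scalar sketch (plain C++; arithmetic right shift of negative values is assumed, as on all mainstream targets):

#include <cstdint>
#include <cstdio>

int main() {
  int8_t byte = (int8_t)0xE7; // high nibble 0xE (-2), low nibble 0x7 (7)
  int8_t low = (int8_t)(byte << 4) >> 4; // -> 7
  int8_t high = byte >> 4;               // -> -2 (arithmetic shift)
  std::printf("low=%d high=%d\n", low, high);
  return 0;
}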
/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
/// bitwise ops that take advantage of the high i8 bits of each byte.
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
                                      Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
  constexpr int64_t i4Toi8BitwidthFactor = 2;
  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
  Value i8Vector =
      rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);

  // 2. Extend i4 elements to i8 elements: mask out the high bits for the low
  // nibbles and logically shift right for the high nibbles.
  constexpr uint8_t lowBitsMask = 15; // the 0b00001111 bit mask
  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
                                             lowBitsMaskValues);
  constexpr int8_t highBitsToShift = 4;
  auto highShiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);

  // 3. Interleave low and high i8 elements.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
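The unsigned variant needs no shift pair: a mask isolates the low nibble and a logical right shift isolates the high nibble. A scalar sketch (plain C++):

#include <cstdint>
#include <cstdio>

int main() {
  uint8_t byte = 0xE7;       // high nibble 0xE (14), low nibble 0x7 (7)
  uint8_t low = byte & 0x0F; // -> 7
  uint8_t high = byte >> 4;  // -> 14 (logical shift on unsigned)
  std::printf("low=%u high=%u\n", low, high);
  return 0;
}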
/// Rewrite the i8 -> i4 truncation into a deinterleave and a series of
/// bitwise ops that take advantage of the i8 storage of the i4 payloads.
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  // 1. De-interleave low and high i8 elements.
  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);

  // 2. Zero out the upper half of each low i8 element.
  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
  Value zeroOutLow = rewriter.create<arith::AndIOp>(
      loc, deinterleaveOp.getRes1(), zeroOutMask);

  // 3. Move each high element's payload into the upper half of its byte.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
                                                 shiftValues);

  // 4. Merge the high and low i4 values.
  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);

  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
}
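A scalar sketch of the packing step above (plain C++): two bytes already holding i4 payloads merge into one byte, the first in the low nibble and the second in the high nibble.

#include <cstdint>
#include <cstdio>

int main() {
  uint8_t lowElem = 0x17;  // payload 0x7; upper bits must be zeroed out
  uint8_t highElem = 0x0E; // payload 0xE
  uint8_t packed = (uint8_t)((lowElem & 0x0F) | (uint8_t)(highElem << 4));
  std::printf("packed=0x%02X\n", packed); // -> 0xE7
  return 0;
}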
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a trunc op.
    auto truncOp =
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
    if (!truncOp)
      return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::TruncIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    } else {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    }
    return success();
  }
};
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  using OpRewritePattern<ExtOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtOpType extOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a bitcast op.
    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
    if (!bitCastOp)
      return rewriter.notifyMatchFailure(extOp, "not a bitcast source");

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing =
        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    } else {
      rewriter.replaceOpWithNewOp<ExtOpType>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    }
    return success();
  }
};
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
  using OpRewritePattern<ConversionOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = conversionOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());

    // ... (check alignedConversionPrecondition, then produce `subByteExt`
    // via rewriteI4ToI8SignedExt or rewriteI4ToI8UnsignedExt depending on
    // `isSigned`)

    rewriter.replaceOpWithNewOp<ConversionOpType>(
        conversionOp, conversionOp.getType(), subByteExt);
    return success();
  }
};
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = truncOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!srcVecType || !dstVecType)
      return failure();

    // ... (check alignedConversionPrecondition on the swapped types)

    // Create a new iX -> i8 truncation op.
    Location loc = truncOp.getLoc();
    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
    Value i8TruncVal =
        rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);

    // Rewrite the i8 -> i4 truncation part and replace the op.
    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
    rewriter.replaceOp(truncOp, subByteTrunc);
    return success();
  }
};
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Precondition: the source must be a sub-byte signless integer vector.
    constexpr unsigned minNativeBitwidth = 8;
    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
      return rewriter.notifyMatchFailure(transposeOp,
                                         "not a sub-byte transpose");
    }

    // Perform the rewrite: extend to a native width, transpose, truncate
    // back. The signed/unsigned interpretation does not matter here, as only
    // the element order changes.
    Location loc = transposeOp.getLoc();
    auto srcNativeVecType = srcSubByteVecType.cloneWith(
        std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
    Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
                                                  transposeOp.getVector());
    Value newTranspose = rewriter.create<vector::TransposeOp>(
        loc, extOp, transposeOp.getPermutation());
    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
    rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp,
                                                 dstSubByteVecType,
                                                 newTranspose);
    return success();
  }
};
void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  // Populate `vector.*` conversion patterns.
  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());
}

void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    benefit);

  // Patterns for the aligned cases get a higher benefit, as they are expected
  // to generate better code.
  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                              benefit.getBenefit() + 1);
  patterns
      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
          patterns.getContext(), benefit.getBenefit() + 1);
}

void vector::populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
}
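A minimal usage sketch of these entry points, assuming the usual pattern-driver boilerplate (the helper name `runNarrowTypeRewrites` is hypothetical, not from this file):

void runNarrowTypeRewrites(Operation *root) {
  MLIRContext *ctx = root->getContext();
  RewritePatternSet patterns(ctx);
  // Collect the narrow-type and sub-byte transpose rewrite patterns.
  vector::populateVectorNarrowTypeRewritePatterns(patterns);
  vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns);
  // Apply them greedily until fixpoint.
  if (failed(applyPatternsAndFoldGreedily(root, std::move(patterns))))
    root->emitWarning("narrow-type rewrite patterns did not converge");
}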