#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

// ...

#define DEBUG_TYPE "vector-narrow-type-emulation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                  Location loc, Value mask,
                                                  int numSrcElems,
                                                  int numSrcElemsPerDest,
                                                  int numFrontPadElems = 0) {

  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

  // ...
  int numDestElems =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;

  // Walk up the use-def chain to find the op that defines the mask.
  // ...
  while (maskOp &&
         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
             maskOp)) {
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getVector().getDefiningOp();
      extractOps.push_back(extractOp);
    }
    // ...
  }

  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
          maskOp))
    return failure();

  // ...
  maskShape.back() = numDestElems;
  // ...
  std::optional<Operation *> newMask =
      // ...
          .Case<vector::CreateMaskOp>(
              [&](auto createMaskOp) -> std::optional<Operation *> {
                // ...
                s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                // ...
                    rewriter, loc, s0, origIndex);
                newMaskOperands.push_back(/*...*/);
                // ...
                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                             newMaskOperands);
              })
          .Case<vector::ConstantMaskOp>(
              [&](auto constantMaskOp) -> std::optional<Operation *> {
                // ...
                    constantMaskOp.getMaskDimSizes());
                int64_t &maskIndex = maskDimSizes.back();
                // ...
                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                               maskDimSizes);
              })
          .Case<arith::ConstantOp>([&](auto constantOp)
                                       -> std::optional<Operation *> {
            // Only 1-D constant masks are handled here.
            if (maskShape.size() != 1)
              return std::nullopt;
            // ...
                cast<DenseIntElementsAttr>(constantOp.getValue());
            // Pad the front, then OR-compress each group of
            // numSrcElemsPerDest mask bits into a single bit.
            // ...
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);

            // ...
            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j) {
                combinedValue |= paddedMaskValues[i + j];
              }
              compressedMaskValues.push_back(combinedValue);
            }
            return rewriter.create<arith::ConstantOp>(/*...*/);
          });
  // ...

  while (!extractOps.empty()) {
    newMask = rewriter.create<vector::ExtractOp>(
        loc, (*newMask)->getResults()[0],
        extractOps.back().getMixedPosition());
    extractOps.pop_back();
  }
  // ...
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
                                        Value source, int64_t frontOffset,
                                        int64_t subvecSize) {
  auto vectorType = cast<VectorType>(source.getType());
  assert(vectorType.getRank() == 1 && "expected 1-D source types");
  assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
         "subvector out of bounds");

  // Nothing to extract if the source already has the requested size.
  if (vectorType.getNumElements() == subvecSize)
    return source;

  // ...
  auto resultVectorType = /*...*/;
  return rewriter
      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
                                             offsets, sizes, strides)
      /*...*/;
}
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {
  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
         "expected source and dest to be vector type");
  // ...
  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
                                                       dest, offsets, strides);
}
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
                                         Value source, Value dest,
                                         OpFoldResult offset,
                                         int64_t numElementsToExtract) {
  assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
  for (int i = 0; i < numElementsToExtract; ++i) {
    Value extractLoc =
        (i == 0) ? offset.dyn_cast<Value>()
                 : rewriter.create<arith::AddIOp>(
                       /*...*/
                       rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp =
        rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
  }
  return dest;
}
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
                                        Value source, Value dest,
                                        OpFoldResult destOffsetVar,
                                        size_t length) {
  assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
  assert(length > 0 && "length must be greater than 0");
  Value destOffsetVal = /*...*/;
  for (size_t i = 0; i < length; ++i) {
    auto insertLoc =
        i == 0 ? destOffsetVal
               : rewriter.create<arith::AddIOp>(
                     /*...*/
                     rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
  }
  return dest;
}
static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
                                      Value base,
                                      OpFoldResult linearizedIndices,
                                      int64_t numContainerElemsToLoad,
                                      Type emulatedElemTy,
                                      Type containerElemTy) {
  // ...
  auto newLoad = rewriter.create<vector::LoadOp>(
      /*...*/);
  return rewriter.create<vector::BitCastOp>(
      /*...*/);
}
static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
                                     VectorType downcastType,
                                     VectorType upcastType, Value mask,
                                     Value trueValue, Value falseValue) {
  assert(
      downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
          upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
      "expected input and output number of bits to match");
  if (trueValue.getType() != downcastType) {
    trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
  }
  if (falseValue.getType() != downcastType) {
    falseValue =
        builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
  }
  Value selectedType =
      builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
  return builder.create<vector::BitCastOp>(loc, upcastType, selectedType);
}
static void atomicRMW(OpBuilder &builder, Location loc,
                      MemRefValue linearizedMemref, Value storeIdx,
                      VectorValue valueToStore, Value mask) {
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  // ...
  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
      /*...*/);
  Value origValue = atomicOp.getCurrentValue();
  // ...
  Value origVecValue = builder.create<vector::FromElementsOp>(
      /*...*/);

  // Merge the new sub-byte bits into the original byte under the mask.
  Value maskedValue =
      downcastSelectAndUpcast(builder, loc, /*...*/,
                              oneElemVecType, mask, valueToStore, origVecValue);
  auto scalarMaskedValue =
      builder.create<vector::ExtractOp>(loc, maskedValue, 0);
  builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}
static void nonAtomicRMW(OpBuilder &builder, Location loc,
                         MemRefValue linearizedMemref, Value linearizedIndex,
                         VectorValue valueToStore, Value mask) {
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  auto oneElemVecType = /*...*/;
  Value origVecValue = builder.create<vector::LoadOp>(
      loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
                                                   origVecValue);

  // ...
  Value maskedValue =
      downcastSelectAndUpcast(builder, loc, /*...*/,
                              oneElemVecType, mask, valueToStore, origVecValue);
  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
                                  /*...*/);
}
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                  Location loc, VectorValue vector,
                                  int64_t extractOffset,
                                  int64_t sliceNumElements,
                                  int64_t insertOffset) {
  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
  auto vectorElementType = vector.getType().getElementType();
  // ...
  assert(
      sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
      "sliceNumElements * vector element size must be less than or equal to 8");
  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
         "vector element must be a valid sub-byte type");
  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
      /*...*/);
  // ...
  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
                                              extractOffset, sliceNumElements);
  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
                                   insertOffset);
}
// ...
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
      : /*...*/, disableAtomicRMW(disableAtomicRMW) {}

  LogicalResult
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();

    auto valueToStore = cast<VectorValue>(op.getValueToStore());
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int numSrcElemsPerDest = containerBits / emulatedBits;

    // ...
    auto origElements = valueToStore.getType().getNumElements();
    bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    // ...
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            /*...*/);

    std::optional<int64_t> foldedNumFrontPadElems = /*...*/;

    if (!foldedNumFrontPadElems) {
      return rewriter.notifyMatchFailure(
          op, "subbyte store emulation: dynamic front padding size is "
              "not yet implemented");
    }

    auto memrefBase = cast<MemRefValue>(adaptor.getBase());

    // ...
    bool emulationRequiresPartialStores =
        !isAlignedEmulation || *foldedNumFrontPadElems != 0;
    if (!emulationRequiresPartialStores) {
      // Basic case: the emulated vector covers whole container elements.
      auto numElements = origElements / numSrcElemsPerDest;
      auto bitCast = rewriter.create<vector::BitCastOp>(
          /*...*/
          op.getValueToStore());
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          op, bitCast.getResult(), memrefBase,
          /*...*/);
      return success();
    }

    // Partial-store path. `storeFunc` (definition elided) is either
    // atomicRMW or nonAtomicRMW, depending on `disableAtomicRMW`.
    // ...
    Value currentDestIndex = /*...*/;
    auto currentSourceIndex = 0;

    // ...
    auto subWidthStoreMaskType = /*...*/;

    // 1. Partial width store for the leading byte.
    auto frontSubWidthStoreElem =
        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
    if (frontSubWidthStoreElem > 0) {
      // ...
      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                    /*...*/);
        frontSubWidthStoreElem = origElements;
      } else {
        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
                    *foldedNumFrontPadElems, true);
      }
      auto frontMask = rewriter.create<arith::ConstantOp>(
          /*...*/);

      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
      auto value = extractSliceIntoByte(rewriter, loc, valueToStore, /*...*/,
                                        frontSubWidthStoreElem,
                                        *foldedNumFrontPadElems);

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(value), frontMask.getResult());
    }

    if (currentSourceIndex >= origElements) {
      // ...
    }

    // ...
    currentDestIndex = rewriter.create<arith::AddIOp>(
        /*...*/);

    // 2. Full-width stores for the middle part.
    int64_t fullWidthStoreSize =
        (origElements - currentSourceIndex) / numSrcElemsPerDest;
    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
    if (fullWidthStoreSize > 0) {
      auto fullWidthStorePart = staticallyExtractSubvector(
          rewriter, loc, valueToStore, currentSourceIndex,
          numNonFullWidthElements);

      auto originType = cast<VectorType>(fullWidthStorePart.getType());
      // ...
      auto storeType = VectorType::get(
          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                        fullWidthStorePart);
      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
                                       currentDestIndex);

      currentSourceIndex += numNonFullWidthElements;
      currentDestIndex = rewriter.create<arith::AddIOp>(
          /*...*/
          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
    }

    // 3. Partial width store for the trailing byte.
    auto remainingElements = origElements - currentSourceIndex;
    if (remainingElements != 0) {
      auto subWidthStorePart =
          extractSliceIntoByte(rewriter, loc, valueToStore,
                               currentSourceIndex, remainingElements, 0);

      // ...
      std::fill_n(maskValues.begin(), remainingElements, 1);
      auto backMask = rewriter.create<arith::ConstantOp>(
          /*...*/);

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(subWidthStorePart), backMask.getResult());
    }

    // ...
  }

  const bool disableAtomicRMW;
};
struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {
  // ...

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }

    int emulatedPerContainerElem = containerBits / emulatedBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % emulatedPerContainerElem != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    // ...
    std::tie(linearizedInfo, linearizedIndicesOfr) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            /*...*/);
    Value linearizedIndices = /*...*/;

    // ...
    FailureOr<Operation *> newMask = getCompressedMaskOp(
        rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
    // ...

    auto numElements = (origElements + emulatedPerContainerElem - 1) /
                       emulatedPerContainerElem;
    auto newType = VectorType::get(numElements, containerElemTy);
    auto passThru = rewriter.create<arith::ConstantOp>(
        /*...*/);

    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
    Value valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
    valueToStore = rewriter.create<arith::SelectOp>(
        loc, op.getMask(), op.getValueToStore(), valueToStore);
    valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
        valueToStore);
    return success();
  }
};
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
  // ...

  LogicalResult
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    // ...
    auto origElements = op.getVectorType().getNumElements();
    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    // ...
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            /*...*/);

    std::optional<int64_t> foldedIntraVectorOffset = /*...*/;

    // ...
    int64_t maxintraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
                                        emulatedPerContainerElem);
    Value result =
        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                           numElements, emulatedElemTy, containerElemTy);

    if (!foldedIntraVectorOffset) {
      auto resultVector = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      // ...
    } else if (!isFullyAligned) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    // ...
  }
};
struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {
  // ...

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();

    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    // ...
    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    // ...
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            /*...*/);

    std::optional<int64_t> foldedIntraVectorOffset = /*...*/;

    // ...
    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    FailureOr<Operation *> newMask =
        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
                            emulatedPerContainerElem, maxIntraDataOffset);
    // ...

    Value passthru = op.getPassThru();

    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);
    auto loadType = VectorType::get(numElements, containerElemTy);
    auto newBitcastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);

    auto emptyVector = rewriter.create<arith::ConstantOp>(
        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
      // ...
    } else if (!isAlignedEmulation) {
      passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                           *foldedIntraVectorOffset);
    }
    Value newPassThru =
        rewriter.create<vector::BitCastOp>(loc, loadType, passthru);

    // Generate the new masked load.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, loadType, adaptor.getBase(),
        /*...*/
        newMask.value()->getResult(0), newPassThru);

    // ...
    Value bitCast =
        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);

    Value mask = op.getMask();
    auto newSelectMaskType = VectorType::get(
        numElements * emulatedPerContainerElem, rewriter.getI1Type());
    // ...
    auto emptyMask = rewriter.create<arith::ConstantOp>(
        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
    if (!foldedIntraVectorOffset) {
      // ...
    } else if (!isAlignedEmulation) {
      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                       *foldedIntraVectorOffset);
    }

    Value result =
        rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
      result = dynamicallyExtractSubVector(
          rewriter, loc, result, op.getPassThru(),
          /*...*/);
    } else if (!isAlignedEmulation) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    // ...
  }
};
struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {
  // ...

  LogicalResult
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getSource().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();

    // ...
    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                      adaptor.getPadding());

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());

    // ...
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            /*...*/);

    std::optional<int64_t> foldedIntraVectorOffset = /*...*/;

    // ...
    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
        /*...*/);

    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc,
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
        newRead);

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      // ...
    } else if (!isFullyAligned) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    // ...
  }
};
struct SourceElementRange {
  /// The index of the source vector element that contributes bits.
  int64_t sourceElementIdx;
  /// The bit range [sourceBitBegin, sourceBitEnd) taken from that element.
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;
};

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  /// Given the index of an entry in this list, compute how many bits all the
  /// preceding entries contribute, i.e. the left-shift amount for that entry.
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    int64_t res = 0;
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
    return res;
  }
};

// ...

struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());
    return numVectors;
  }

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
};

// ...

struct BitCastRewriter {
  /// Per-step metadata: shuffle indices, masks and shift amounts.
  struct Metadata {
    // ...
  };

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  /// Verify that the preconditions for the rewrite are met.
  LogicalResult commonPrecondition(PatternRewriter &rewriter,
                                   VectorType preconditionType, Operation *op);

  /// Precompute the metadata needed by each rewrite step.
  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  /// Rewrite one step of the sequence (shuffle, and, shift right, shift left,
  /// or into the running result).
  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
                           Value initialValue, Value runningResult,
                           const BitCastRewriter::Metadata &metadata);

  // ...
  BitCastBitsEnumerator enumerator;
};
[[maybe_unused]] static raw_ostream &
operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
    }
    // ...
  }
  return os;
}
BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG("sourceVectorType: " << sourceVectorType);

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG("targetVectorType: " << targetVectorType);

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  // Walk the result bits and record, for each target element, which source
  // elements and bit ranges contribute to it.
  // ...
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
}
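// Illustrative addition (not part of the upstream file): the same enumeration
// on plain integers, for a hypothetical vector<8xi3> -> vector<3xi8> bitcast.
// Each target byte ends up covered by a list of (sourceElementIdx, bitBegin,
// bitEnd) ranges, exactly as the constructor above computes.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static void printBitCastEnumeration(int64_t sourceBitWidth,
                                    int64_t numSourceElems,
                                    int64_t targetBitWidth) {
  int64_t bitwidth = sourceBitWidth * numSourceElems; // 24 bits in the example
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    std::printf("target[%lld] <- source[%lld] bits [%lld..%lld)\n",
                (long long)resultElement, (long long)sourceElementIdx,
                (long long)sourceBitInElement,
                (long long)(sourceBitInElement + step));
    resultBit += step;
  }
}
// printBitCastEnumeration(3, 8, 8) prints, e.g., target[0] <- source[0] bits
// [0..3), then source[1] bits [0..3), then source[2] bits [0..2), and so on.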
BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG("\n" << enumerator.sourceElementRanges);
}
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!preconditionType || preconditionType.isScalable())
    return rewriter.notifyMatchFailure(op, /*...*/);

  // ...
  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)
    return rewriter.notifyMatchFailure(op, /*...*/);

  return success();
}

LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
    return rewriter.notifyMatchFailure(op, /*...*/);

  if (!preconditionType || preconditionType.getRank() != 1)
    return rewriter.notifyMatchFailure(op, /*...*/);

  // ...
}

static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType subByteVecType,
                                                   VectorType dstType,
                                                   Operation *op) {
  if (!subByteVecType || !dstType)
    return rewriter.notifyMatchFailure(op, /*...*/);
  unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();

  if (dstElemBitwidth < 8)
    return rewriter.notifyMatchFailure(
        op, "the bitwidth of dstType must be greater than or equal to 8");
  if (dstElemBitwidth % srcElemBitwidth != 0)
    return rewriter.notifyMatchFailure(op, /*...*/);
  if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
    return rewriter.notifyMatchFailure(
        op, "only src bitwidth of 2 or 4 is supported at this moment");

  const int numSrcElemsPerByte = 8 / srcElemBitwidth;
  if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
    return rewriter.notifyMatchFailure(
        op, "the trailing dimension of the input vector of sub-bytes must be a "
            "multiple of 8 / <sub-byte-width>");

  return success();
}
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    // ...
    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
                                  : 0;
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
                          : 0;
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
                          : 0;
      auto mask = IntegerAttr::get(
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
                                  bitLo, bitHi));
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(
          /*...*/);

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(
          /*...*/);
    }
    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
  }
  return result;
}
Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  // Create vector.shuffle from the metadata.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);

  // Intersect with the mask.
  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = rewriter.create<arith::ConstantOp>(
      /*...*/);
  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);

  // Align the bits on the right.
  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
      /*...*/);
  Value shiftedRight =
      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);

  // Shift the bits left into their final position.
  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
      /*...*/);
  Value shiftedLeft =
      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);

  runningResult =
      runningResult
          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
          : shiftedLeft;

  return runningResult;
}
static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
                                      Value subByteVec) {
  auto srcVecType = cast<VectorType>(subByteVec.getType());
  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
  assert(8 % srcBitwidth == 0 &&
         "Unsupported sub-byte type (not a divisor of i8)");
  int64_t numSrcElemsPerByte = 8 / srcBitwidth;
  // ...
  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
  // ...
  return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
}
static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
                                                  Location loc, Value src,
                                                  int bitIdx, int numBits) {
  auto srcType = cast<VectorType>(src.getType());
  Value shl = src;
  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  if (bitsToShiftLeft != 0) {
    Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
        /*...*/);
    shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
  }

  int8_t bitsToShiftRight = 8 - numBits;
  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
      /*...*/);
  Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
  return shr;
}
static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
                                              Location loc, Value src,
                                              int bitIdx, int numBits) {
  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  auto srcType = cast<VectorType>(src.getType());
  int8_t bitsToShiftRight = bitIdx;
  Value shr = src;
  if (bitsToShiftRight != 0) {
    Value shiftRightValues = rewriter.create<arith::ConstantOp>(
        /*...*/);
    shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
  }
  if (bitIdx + numBits == 8) {
    return shr;
  }
  uint8_t lowBitsMask = (1 << numBits) - 1;
  Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
      /*...*/);
  return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
}
static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
                              Value srcValue, const ExtractNBitsFn &extFn) {
  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // View the i4 vector as bytes, then extract the low and high nibbles.
  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
  Value high = extFn(rewriter, loc, i8Vector, 4, 4);

  // Interleave the low and high i8 elements to restore the original order.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
                              Value srcValue, const ExtractNBitsFn &extFn) {
  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(2) &&
         "Expected i2 type");

  // View the i2 vector as bytes, then extract the four 2-bit fields per byte.
  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);

  // Interleave the four element vectors to restore the original order.
  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
}
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  // De-interleave the i8 vector into even and odd elements.
  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);

  // Zero out the upper nibble of the even (low) half.
  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
      /*...*/);
  Value zeroOutLow = rewriter.create<arith::AndIOp>(
      loc, deinterleaveOp.getRes1(), zeroOutMask);

  // Shift the odd (high) half into the upper nibble.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      /*...*/);
  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
                                                 shiftValues);

  // Merge the two nibbles and bitcast back to i4.
  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
  // ...
  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
}
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
  // ...

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
    // The source must be the result of an arith.trunci.
    auto truncOp =
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
    // ...

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value truncValue = truncOp.getIn();
    auto shuffledElementType = /*...*/;
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::TruncIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    } else {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    }
    return success();
  }
};
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  // ...

  LogicalResult matchAndRewrite(ExtOpType extOp,
                                PatternRewriter &rewriter) const override {
    // The source must be the result of a vector.bitcast.
    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
    // ...

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType = /*...*/;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing =
        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    } else {
      rewriter.replaceOpWithNewOp<ExtOpType>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    }
    return success();
  }
};
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
  // ...

  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = conversionOp.getIn();
    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
    // ...

    // Perform the rewrite.
    Location loc = conversionOp.getLoc();
    // ...
    Value subByteExt;
    switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
      // ...
    }

    // Finalize the rewrite.
    rewriter.replaceOpWithNewOp<ConversionOpType>(
        conversionOp, conversionOp.getType(), subByteExt);
    return success();
  }
};
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
  // ...

  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = truncOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!srcVecType || !dstVecType)
      return failure();
    // ...

    // Truncating to i2 is not supported by this pattern.
    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
      return failure();

    // Truncate to i8 first, then pack the i8 values into i4.
    Location loc = truncOp.getLoc();
    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
    Value i8Trunc =
        rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8Trunc);

    // Finalize the rewrite.
    rewriter.replaceOp(truncOp, subByteTrunc);
    return success();
  }
};
    // ...
    constexpr unsigned minNativeBitwidth = 8;
    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
      return rewriter.notifyMatchFailure(transposeOp,
                                         "not a sub-byte transpose");
    }

    // Perform the rewrite: extend to a native bitwidth, transpose, then
    // truncate back to the sub-byte result type.
    Location loc = transposeOp.getLoc();
    // ...
    auto srcNativeVecType = srcSubByteVecType.cloneWith(
        /*...*/);
    Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
                                                  transposeOp.getVector());
    Value newTranspose = rewriter.create<vector::TransposeOp>(
        loc, extOp, transposeOp.getPermutation());
    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
    // ...
void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns, bool disableAtomicRMW) {

  // Populate the `vector.*` load/store conversion patterns.
  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());

  // ...
  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
}

void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    /*...*/);

  // ...
  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                              /*...*/);
  patterns
      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
          /*...*/);
}
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter, Location loc, VectorValue vector, int64_t extractOffset, int64_t sliceNumElements, int64_t insertOffset)
Extract sliceNumElements from source vector at extractOffset, and insert it into an empty vector at i...
std::function< Value(PatternRewriter &, Location, Value, int, int)> ExtractNBitsFn
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc, Value srcValue)
Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise ops to avoid leaving LLVM t...
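For illustration only, a minimal standalone C++ sketch of the nibble packing this truncation performs (even element to the low nibble, odd element to the high nibble); names are ad hoc and this is not the MLIR implementation:

#include <cstdint>
#include <vector>

std::vector<uint8_t> packI8PairsToBytes(const std::vector<int8_t> &vals) {
  std::vector<uint8_t> packed;
  for (size_t i = 0; i + 1 < vals.size(); i += 2) {
    uint8_t low = static_cast<uint8_t>(vals[i]) & 0x0F;    // zero the high nibble
    uint8_t high = static_cast<uint8_t>(vals[i + 1]) << 4; // shift into place
    packed.push_back(static_cast<uint8_t>(low | high));
  }
  return packed;
}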
TypedValue< MemRefType > MemRefValue
static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base, OpFoldResult linearizedIndices, int64_t numContainerElemsToLoad, Type emulatedElemTy, Type containerElemTy)
Emulate a vector load for emulatedElemTy using containerElemTy
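For illustration only, the index arithmetic behind the emulated load, on plain integers; names are ad hoc and the real lowering builds affine/arith ops instead:

#include <cstdint>

struct EmulatedLoadPlan {
  int64_t containerIndex;       // linearized index into the wider memref
  int64_t intraContainerOffset; // first emulated element inside that container
  int64_t numContainersToLoad;  // how many container elements to load
};

EmulatedLoadPlan planEmulatedLoad(int64_t emulatedIndex,
                                  int64_t numEmulatedElems, int emulatedBits,
                                  int containerBits) {
  const int scale = containerBits / emulatedBits; // e.g. 8 / 4 = 2
  EmulatedLoadPlan plan;
  plan.containerIndex = emulatedIndex / scale;
  plan.intraContainerOffset = emulatedIndex % scale;
  plan.numContainersToLoad =
      (plan.intraContainerOffset + numEmulatedElems + scale - 1) / scale;
  return plan;
}

For example, loading 5 i4 elements starting at emulated index 3 touches bytes 1 through 4 (containerIndex 1, intra offset 1, 3 containers).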
TypedValue< VectorType > VectorValue
static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc, VectorType downcastType, VectorType upcastType, Value mask, Value trueValue, Value falseValue)
Downcast two values to downcastType, then select values based on mask, and casts the result to upcast...
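For illustration only, what the downcast-select-upcast achieves at the packed byte level: expanding a per-lane i1 mask to a per-bit mask and blending two bytes is equivalent to selecting i4 lanes and bitcasting back (assuming lane 0 sits in the low nibble; this is a sketch, not the MLIR implementation):

#include <cstdint>

uint8_t selectSubByteLanes(uint8_t trueByte, uint8_t falseByte,
                           const bool laneMask[2] /* two i4 lanes per byte */) {
  uint8_t bitMask = static_cast<uint8_t>((laneMask[0] ? 0x0Fu : 0x00u) |
                                         (laneMask[1] ? 0xF0u : 0x00u));
  return static_cast<uint8_t>((trueByte & bitMask) | (falseByte & ~bitMask));
}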
static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc, Value srcValue, const ExtractNBitsFn &extFn)
Rewrite the i4 -> i8 extension into a sequence of shuffles and bitwise ops to avoid leaving LLVM to s...
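For illustration only, the same nibble split and interleave on plain bytes, assuming element 2*i is packed in the low nibble of byte i (little-endian sub-byte packing); this is a sketch, not the MLIR implementation:

#include <cstdint>
#include <vector>

std::vector<int8_t> signExtendPackedI4(const std::vector<uint8_t> &bytes) {
  std::vector<int8_t> out;
  out.reserve(bytes.size() * 2);
  for (uint8_t b : bytes) {
    // Low nibble: shift left then arithmetic shift right to sign-extend.
    int8_t low = static_cast<int8_t>(static_cast<uint8_t>(b << 4)) >> 4;
    // High nibble: arithmetic shift right of the signed byte.
    int8_t high = static_cast<int8_t>(b) >> 4;
    out.push_back(low);  // interleave: low lane first,
    out.push_back(high); // then high lane.
  }
  return out;
}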
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc, Value source, int64_t frontOffset, int64_t subvecSize)
Extracts 1-D subvector from a 1-D vector.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc, Value src, Value dest, int64_t offset)
Inserts 1-D subvector into a 1-D vector by overwriting the elements starting at offset.
static void atomicRMW(OpBuilder &builder, Location loc, MemRefValue linearizedMemref, Value storeIdx, VectorValue valueToStore, Value mask)
Emits memref.generic_atomic_rmw op to store a subbyte-sized value to a byte in linearizedMemref,...
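For illustration only, the byte-level effect of that read-modify-write, expressed as a CAS loop in plain C++ (the actual op builds a memref.generic_atomic_rmw region; this is just a sketch of why atomicity is needed when several sub-byte stores share a byte):

#include <atomic>
#include <cstdint>

void storeSubByteAtomically(std::atomic<uint8_t> &byte, uint8_t newBits,
                            uint8_t bitMask /* bits covered by the store */) {
  uint8_t oldVal = byte.load();
  uint8_t desired;
  do {
    desired = static_cast<uint8_t>((oldVal & ~bitMask) | (newBits & bitMask));
  } while (!byte.compare_exchange_weak(oldVal, desired));
}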
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, VectorType subByteVecType, VectorType dstType, Operation *op)
Verify that subByteVecType and dstType are aligned.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter, VectorType preconditionType, Operation *op)
Verify that the precondition type meets the common preconditions for any conversion.
static void nonAtomicRMW(OpBuilder &builder, Location loc, MemRefValue linearizedMemref, Value linearizedIndex, VectorValue valueToStore, Value mask)
Generate a non-atomic read-modify-write sequence for storing to the emulated type.
static FailureOr< Operation * > getCompressedMaskOp(OpBuilder &rewriter, Location loc, Value mask, int numSrcElems, int numSrcElemsPerDest, int numFrontPadElems=0)
Returns a compressed mask for the emulated vector.
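For illustration only, a plain C++ model of the constant-mask path of this helper: pad the front, then OR-reduce each group of numSrcElemsPerDest mask bits into one container-level bit (a sketch, not the MLIR implementation):

#include <cassert>
#include <vector>

std::vector<bool> compressMask(const std::vector<bool> &srcMask,
                               int numSrcElemsPerDest, int numFrontPadElems) {
  assert(numFrontPadElems < numSrcElemsPerDest);
  int numSrcElems = static_cast<int>(srcMask.size());
  int numDestElems = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                     numSrcElemsPerDest;
  std::vector<bool> padded(numFrontPadElems, false);
  padded.insert(padded.end(), srcMask.begin(), srcMask.end());
  padded.resize(static_cast<size_t>(numDestElems) * numSrcElemsPerDest, false);

  std::vector<bool> compressed;
  for (size_t i = 0; i < padded.size(); i += numSrcElemsPerDest) {
    bool combined = false;
    for (int j = 0; j < numSrcElemsPerDest; ++j)
      combined = combined || padded[i + j];
    compressed.push_back(combined);
  }
  return compressed;
}

For example, a mask over 5 i4 elements with one element of front padding compresses to 3 byte-level mask bits.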
static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc, Value subByteVec)
Bitcasts the aligned subByteVec vector to a vector of i8.
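For illustration only, how four i2 lanes map onto one byte after such a bitcast, assuming lane k occupies bits [2k, 2k+2) (little-endian sub-byte packing); a sketch, not the MLIR implementation:

#include <cstdint>

uint8_t packFourI2(const uint8_t lanes[4]) {
  return static_cast<uint8_t>((lanes[0] & 0x3) | ((lanes[1] & 0x3) << 2) |
                              ((lanes[2] & 0x3) << 4) | ((lanes[3] & 0x3) << 6));
}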
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc, Value source, Value dest, OpFoldResult offset, int64_t numElementsToExtract)
Extracts a 1-D subvector from a 1-D source vector, with index at offset and size numElementsToExtract...
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc, Value source, Value dest, OpFoldResult destOffsetVar, size_t length)
Inserts a 1-D subvector into a 1-D dest vector at index destOffsetVar.
static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter, Location loc, Value src, int bitIdx, int numBits)
Extracts an unsigned N-bit sequence from each element of a vector of bytes, starting at the specified...
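For illustration only, the unsigned variant on a single byte: shift right to drop the lower bits, then mask unless the field already touches the top of the byte (a sketch of the per-element arithmetic, not the MLIR implementation):

#include <cstdint>

uint8_t extractUnsignedField(uint8_t byte, int bitIdx, int numBits) {
  uint8_t shr = static_cast<uint8_t>(byte >> bitIdx);
  if (bitIdx + numBits == 8)
    return shr; // the high bits are already gone, no mask needed
  uint8_t lowBitsMask = static_cast<uint8_t>((1u << numBits) - 1);
  return static_cast<uint8_t>(shr & lowBitsMask);
}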
static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc, Value srcValue, const ExtractNBitsFn &extFn)
Rewrite the i2 -> i8 extension into a sequence of shuffles and bitwise ops to avoid leaving LLVM to s...
static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter, Location loc, Value src, int bitIdx, int numBits)
Extracts a signed N-bit sequence from each element of a vector of bytes, starting at the specified bi...
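For illustration only, the signed variant on a single byte: left-align the field, then use an arithmetic shift right so its sign bit is replicated (a sketch of the per-element arithmetic, not the MLIR implementation):

#include <cstdint>

int8_t extractSignedField(uint8_t byte, int bitIdx, int numBits) {
  int bitsToShiftLeft = 8 - numBits - bitIdx;
  int8_t shl =
      static_cast<int8_t>(static_cast<uint8_t>(byte << bitsToShiftLeft));
  return static_cast<int8_t>(shl >> (8 - numBits)); // arithmetic shift
}

For example, extractSignedField(0xF0, 4, 4) returns -1, the sign-extended value of the high nibble.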
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Appends patterns for rewriting vector operations over narrow types with ops over wider types.
void populateVectorTransposeNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Appends patterns for emulating a sub-byte vector transpose.
void populateVectorNarrowTypeEmulationPatterns(const arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns, bool disableAtomicRMW=false)
Appends patterns for emulating vector operations over narrow types with ops over wider types.