#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-unroll"
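// Base pattern shared by all unroll rewrites below. It queries the
// UnrollOptions callbacks for the per-op target (native) tile shape and for
// the unrolled types, and provides pack/unpack helpers that split a value
// into tiles and reassemble the per-tile results.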
template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  // ...

  // Returns the target tile shape for `op`, or std::nullopt if the op should
  // not be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG() << "Get unroll shape for: " << *op;

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG() << "--filter constraint failed -> BAIL";
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects a native shape callback function.");
    auto nativeShape = options.nativeShape(op);
    return nativeShape;
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape,
                                     bool returnSingleType = false) const {
    return options.getUnrolledTypes(type, tileShape, returnSingleType);
  }
  // Reassemble the unrolled values `srcs` into a single value of `destTy`.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }
    if (isa<xegpu::TensorDescType>(destTy)) {
      // The cast is tagged with the __xegpu_blocking_unpack__ / tile-shape
      // marker attributes (their construction is elided here).
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTy, srcs, /*attributes=*/...);
      return castOp.getResult(0);
    }
    llvm_unreachable("Unexpected destTy.");
  }
  // Split `src` into values of `destTypes`, one per tile of `blockSize`.
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }
    if (isa<xegpu::TensorDescType>(src.getType())) {
      // The cast is tagged with the __xegpu_blocking_pack__ / tile-shape
      // marker attributes (their construction is elided here).
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTypes, src, /*attributes=*/...);
      return castOp.getResults();
    }
    llvm_unreachable("Unexpected src type.");
  }
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};
// Shared helper: for each tile of `targetShape` within the descriptor shape,
// shift the trailing `rank` offsets by the tile offset and invoke `createOp`
// with the adjusted offsets, collecting any produced values.
static SmallVector<Value> computeUnrolledOffsets(
    SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
    ArrayRef<int64_t> targetShape,
    llvm::function_ref<Value(SmallVector<OpFoldResult>)> createOp,
    Location loc, PatternRewriter &rewriter) {
  int64_t rank = tdescTy.getRank();
  auto addi = [&](OpFoldResult a, int64_t b) -> Value {
    std::optional<int64_t> maybeInt = getConstantIntValue(a);
    if (maybeInt)
      return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
    auto aV = llvm::cast<Value>(a);
    auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
    return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
  };
  SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
      llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
  auto validIdxes =
      llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
  SmallVector<Value> newOps;
  for (SmallVector<int64_t> offsets :
       StaticTileOffsetRange(tdescTy.getShape(), targetShape)) {
    for (auto [idx, oldOff, offset] :
         llvm::zip(validIdxes, oldOffsets, offsets))
      mixedOffsets[idx] = addi(oldOff, offset);

    auto newOp = createOp(mixedOffsets);
    newOps.push_back(newOp);
  }
  return newOps;
}
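// Unrolls xegpu::CreateNdDescOp to the target tile shape. Without explicit
// offsets a single descriptor of the unrolled type is created; with offsets,
// one descriptor is created per tile via computeUnrolledOffsets.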
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    bool hasOffsets = op.getMixedOffsets().size() != 0;

    SmallVector<Value> newOps;
    if (!hasOffsets) {
      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    } else {
      auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
        return xegpu::CreateNdDescOp::create(
            rewriter, loc, newTdescTy, op.getSource(), offsets,
            op.getMixedSizes(), op.getMixedStrides());
      };
      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                      *targetShape, createOp, loc, rewriter);
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
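// Unrolls xegpu::UpdateNdOffsetOp: the descriptor is split into tiles, the
// same offset update is applied to each tile, and the results are recombined.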
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
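// Unrolls xegpu::PrefetchNdOp: one prefetch per unrolled descriptor, or one
// per tile offset when the op carries explicit offsets.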
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, hasOffsets);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    if (!hasOffsets) {
      for (auto t : convertedTdesc)
        xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
                                    op->getAttrs());
    } else {
      auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
        xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr());
        return Value();
      };
      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createPrefetch, loc, rewriter);
    }

    rewriter.eraseOp(op);
    return success();
  }
};
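// Unrolls xegpu::LoadNdOp into per-tile loads of the target shape; the loaded
// tiles are reassembled into the original result vector.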
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, hasOffsets);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    if (!hasOffsets) {
      for (auto t : convertedTdescs) {
        auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
                                             op->getAttrs());
        newOps.push_back(newOp);
      }
    } else {
      auto createLoad = [&](SmallVector<OpFoldResult> offsets) -> Value {
        return xegpu::LoadNdOp::create(
            rewriter, loc, newValueTy, convertedTdescs[0], offsets,
            op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
            op.getL2HintAttr(), op.getL3HintAttr());
      };
      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                      *targetShape, createLoad, loc, rewriter);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
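// Unrolls xegpu::StoreNdOp: the stored value and the descriptor are split
// into matching tiles and one store is emitted per tile.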
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, hasOffsets);

    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    if (!hasOffsets) {
      for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
        xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
                                 op.getL2HintAttr(), op.getL3HintAttr());
    } else {
      size_t valueIndex = 0;
      auto createStore = [&](SmallVector<OpFoldResult> offsets) -> Value {
        xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
                                 convertedTdescs[0], offsets,
                                 op.getL1HintAttr(), op.getL2HintAttr(),
                                 op.getL3HintAttr());
        return Value();
      };
      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createStore, loc, rewriter);
    }

    rewriter.eraseOp(op);
    return success();
  }
};
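// Unrolls xegpu::DpasOp using a 3D target shape [M, K, N]: A is blocked into
// MxK tiles, B into KxN tiles, and the accumulator/result into MxN tiles;
// partial products are accumulated over the K dimension.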
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) -> SmallVector<Value> {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      // ...
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(computeProduct(*grids), newVecTy);
      return pack(val, convertedTypes, blockSize, loc, rewriter);
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);
    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Bail out if any operand could not be blocked, or if every operand is
    // already a single tile (`ranges` collects the blocked operand lists;
    // its construction is elided here).
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j];
        // Accumulate partial products over the K dimension.
        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);
          tmpC = xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
                                       op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
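// Unrolls xegpu::CreateDescOp (scattered descriptors): the indice vector is
// split per tile and, for chunked descriptors, shifted per blocked chunk.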
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    TypedValue<VectorType> indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    SmallVector<int64_t> targetIndiceShape(*targetShape);
    // The indice vector does not carry the chunk dimension, so drop it.
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    SmallVector<Value> newOps;
    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      // Reuse each unrolled indice vector numNewChunks times, each time
      // shifted by a multiple of the blocked chunk size.
      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
          Value offsetIndice =
              arith::AddIOp::create(rewriter, loc, indice, incVec);
          auto newOp = xegpu::CreateDescOp::create(
              rewriter, loc, newTdescTy, op.getSource(), offsetIndice);
          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy,
                                                 op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
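// Unrolls the tensor-descriptor form of xegpu::LoadGatherOp (no offsets);
// masks are replicated per blocked chunk when the chunk size is blocked.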
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // Only the tensor-descriptor form without offsets is handled here.
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<int64_t> targetMaskShape(*targetShape);
    if (originalChunkSize > 1) {
      // The mask has no chunk dimension; replicate each unrolled mask for
      // every blocked chunk.
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);

      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
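// Unrolls the offset-based form of xegpu::LoadGatherOp: offsets and masks are
// split per tile, and for chunked accesses the offsets are shifted by
// multiples of the blocked chunk size.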
struct UnrollLoadGatherOpWithOffset
    : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    // Only the offset-based (source + offsets) form is handled here.
    if (!valueTy || !offsets)
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    // Pick up the explicit chunk size, if any.
    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
    Type elemTy = valueTy.getElementType();
    VectorType newValueTy = VectorType::get(*targetShape, elemTy);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;
    SmallVector<int64_t> targetMaskShape(*targetShape);

    if (chunkSize > 1) {
      // Masks and offsets have no chunk dimension; replicate them per blocked
      // chunk and shift the offsets by multiples of the blocked chunk size.
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    auto layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    SmallVector<Value> newOps;
    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, op.getSource(), o, m,
          rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr(), layout);
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
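// Unrolls the offset-based form of xegpu::StoreScatterOp, mirroring the
// offset-based load pattern above for the stored values, offsets, and masks.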
struct UnrollStoreScatterOpWithOffsets
    : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    // Only the offset-based (dest + offsets) form is handled here.
    if (!valueTy || !offsets)
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;
    SmallVector<int64_t> targetMaskShape(*targetShape);

    if (chunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    auto layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    for (auto [v, o, m] :
         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                    rewriter.getI64IntegerAttr(chunkSize),
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr(), layout);
    }

    rewriter.eraseOp(op);
    return success();
  }
};
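// Unrolls the tensor-descriptor form of xegpu::PrefetchOp: one prefetch is
// emitted per unrolled descriptor.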
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // Only the tensor-descriptor form without offsets is handled here.
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
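// Unrolls the tensor-descriptor form of xegpu::StoreScatterOp (no offsets):
// values, descriptors, and masks are split into matching tiles.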
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // Only the tensor-descriptor form without offsets is handled here.
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<int64_t> targetMaskShape(*targetShape);
    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);

      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};
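// Unrolls xegpu::UpdateOffsetOp: the scattered descriptor and the offset
// vector are split per tile; offsets are replicated per blocked chunk when
// the chunk size is blocked.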
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    TypedValue<VectorType> offsetVec = op.getOffsets();
    VectorType offsetVecTy = offsetVec.getType();

    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsetVec;
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1) {
      // The offset vector has no chunk dimension; drop it and replicate each
      // unrolled offset for every blocked chunk.
      SmallVector<int64_t> targetOffsetShape(targetShape->begin(),
                                             targetShape->end() - 1);
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto offset : pack(offsetVec, convertedOffsetTypes,
                              targetOffsetShape, loc, rewriter))
        convertedOffsetVec.append(numNewChunks, offset);
    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          xegpu::UpdateOffsetOp::create(rewriter, loc, t.getType(), t, o);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
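// Unrolls xegpu::LoadMatrixOp: one load per tile offset within the result
// shape, followed by reassembly into the original result vector.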
struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
  using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    assert(valueTy && "the value type must be a vector type!");

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
      return failure();

    Type elemTy = valueTy.getElementType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());

    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(valueTy.getShape(), *targetShape)) {
      auto adds = xegpu::addElementwise(
          rewriter, loc, mixedOffsets,
          getAsIndexOpFoldResult(rewriter.getContext(), offsets));
      offsetsList.push_back(adds);
    }

    if (layout)
      layout = layout.dropInstData();

    SmallVector<Value> newOps;
    for (SmallVector<OpFoldResult> offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(
          rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
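// Unrolls xegpu::StoreMatrixOp: the data is split into tiles and one store is
// emitted per tile offset.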
struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
  using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
    assert(valueTy && "the value type must be a vector type!");

    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(valueTy.getShape(), *targetShape)) {
      auto adds = xegpu::addElementwise(
          rewriter, loc, mixedOffsets,
          getAsIndexOpFoldResult(rewriter.getContext(), offsets));
      offsetsList.push_back(adds);
    }

    for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
                                   layout.dropInstData());

    rewriter.eraseOp(op);
    return success();
  }
};
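/// Collect a set of patterns to unroll XeGPU operations into smaller shapes.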
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns
      .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
           UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
           UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
           UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
          patterns.getContext(), options);
}