#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-unroll"
// Generic base pattern: queries the unroll options for the native tile shape
// of `op` and provides helpers shared by all unroll patterns below.
template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {

  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG() << "Get unroll shape for: " << *op;
    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG() << "--no filter constraint -> BAIL";
      return std::nullopt;
    }
    assert(options.nativeShape &&
           "expects a callback computing the native shape.");
    auto nativeShape = options.nativeShape(op);

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape,
                                     bool returnSingleType = false) const {
    return options.getUnrolledTypes(type, tileShape, returnSingleType);
  }
  // Reassemble a value of `destTy` from the unrolled pieces in `srcs`.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTy, srcs,
          /*pack/unpack marker attributes, elided*/);
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
  }
  // Split `src` into pieces matching `destTypes`, each of shape `blockSize`.
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTypes, src,
          /*pack/unpack marker attributes, elided*/);
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
  }
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";
// computeUnrolledOffsets: for every unrolled tile, add the tile's relative
// offsets to the trailing `rank` entries of the original mixed offsets and
// call `createOp` to materialize the corresponding unrolled operation.
  int64_t rank = tdescTy.getRank();

  // Fold-aware index addition.
  auto addi = [&](OpFoldResult a, int64_t b) -> Value {
    std::optional<int64_t> maybeInt = getConstantIntValue(a);
    if (maybeInt)
      return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
    auto aV = llvm::cast<Value>(a);
    auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
    return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
  };

  SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
      llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
  auto validIdxes =
      llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

  for (auto [idx, oldOff, offset] :
       llvm::zip(validIdxes, oldOffsets, offsets))
    mixedOffsets[idx] = addi(oldOff, offset);

  auto newOp = createOp(mixedOffsets);
  newOps.push_back(newOp);
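// Illustrative example (not part of the upstream source): for a 2-D tensor
// descriptor with mixedOffsets = [%off0, 32] and an unrolled tile whose
// relative offsets are [8, 16], the zip-loop above rewrites the trailing
// `rank` entries to [%off0 + 8, 48] before `createOp` is invoked, so every
// unrolled operation addresses its own tile.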
// Unroll xegpu.create_nd_tdesc into one descriptor per native tile.
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    bool hasOffsets = !op.getMixedOffsets().empty();
    if (!hasOffsets) {
      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    } else {
      auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
        return xegpu::CreateNdDescOp::create(
            rewriter, loc, newTdescTy, op.getSource(), offsets,
            op.getMixedSizes(), op.getMixedStrides());
      };
      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                      *targetShape, createOp, loc, rewriter);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
// Unroll xegpu.update_nd_offset by updating each unrolled descriptor.
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc) {
      auto newOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
// Unroll xegpu.prefetch_nd: issue one prefetch per unrolled descriptor/tile.
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, hasOffsets);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    if (!hasOffsets) {
      for (auto t : convertedTdesc)
        xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
                                    op->getAttrs());
    } else {
      auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) {
        xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr(), layout);
      };
      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createPrefetch, loc, rewriter);
    }
// Unroll xegpu.load_nd: load each native tile and reassemble the full vector.
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, hasOffsets);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    if (!hasOffsets) {
      for (auto t : convertedTdescs) {
        auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
                                             op->getAttrs());
        newOps.push_back(newOp);
      }
    } else {
      auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
        return xegpu::LoadNdOp::create(
            rewriter, loc, newValueTy, convertedTdescs[0], offsets,
            op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
            op.getL2HintAttr(), op.getL3HintAttr(), layout);
      };
      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                      *targetShape, createLoad, loc, rewriter);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
// Unroll xegpu.store_nd: split the value into native tiles and store each one.
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, hasOffsets);

    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    if (!hasOffsets) {
      for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
        xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
                                 op.getL2HintAttr(), op.getL3HintAttr());
    } else {
      size_t valueIndex = 0;
      auto createStore = [&](SmallVector<OpFoldResult> offsets) {
        xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
                                 convertedTdescs[0], offsets,
                                 op.getL1HintAttr(), op.getL2HintAttr(),
                                 op.getL3HintAttr(), layout);
      };
      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createStore, loc, rewriter);
    }
// Unroll xegpu.dpas into an (M, K, N)-blocked loop nest of smaller dpas ops,
// accumulating partial results along the K dimension.
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    // Only handle the case where all operands are 2-D vectors.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) -> SmallVector<Value> {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(computeProduct(*grids), newVecTy);
      return pack(val, convertedTypes, blockSize, loc, rewriter);
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);
    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Bail out when any operand failed to block, or when every operand
    // already fits in a single block (nothing to unroll).
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j];
        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);
          tmpC = xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
                                       op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
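    // Worked example (illustrative, not from the upstream source): with
    // A = vector<32x16xf16>, B = vector<16x32xf16> and a native shape of
    // M = 8, K = 16, N = 16, we get mIters = 4, kIters = 1 and nIters = 2,
    // i.e. eight small dpas ops. When kIters > 1, the partial products are
    // chained through tmpC along the k-loop before being collected in newOps.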
// Unroll xegpu.create_tdesc (scattered): split the index vector into native
// tiles; when the chunk size is blocked, derive extra per-chunk index offsets.
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getType();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
          Value offsetIndice =
              arith::AddIOp::create(rewriter, loc, indice, incVec);
          auto newOp = xegpu::CreateDescOp::create(
              rewriter, loc, newTdescTy, op.getSource(), offsetIndice);
          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy,
                                                 op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
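    // Example (illustrative): with originalChunkSize = 8 and a blocked chunk
    // size of 2, numNewChunks = 4, so each blocked index vector is reused
    // four times with broadcast increments of 0, 2, 4 and 6 added to the
    // original indices.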
// Unroll xegpu.load (gather) on scattered tensor descriptors without explicit
// offsets; masks are replicated per blocked chunk when the chunk size > 1.
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<int64_t> targetMaskShape(*targetShape);

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      // Each blocked mask covers numNewChunks of the new loads.
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);

      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
// Variant of gather-load unrolling for xegpu.load with explicit offsets
// (no tensor descriptor); offsets and masks are blocked alongside the value.
struct UnrollLoadGatherOpWithOffset
    : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
    Type elemTy = valueTy.getElementType();
    VectorType newValueTy = VectorType::get(*targetShape, elemTy);

    SmallVector<Type> convertedMaskTypes, convertedOffsetTypes;
    SmallVector<Value> convertedMasks, convertedOffsets;
    SmallVector<int64_t> targetMaskShape(*targetShape);

    if (chunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    auto layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, op.getSource(), o, m,
          rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr(), layout);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
// Variant of scatter-store unrolling for xegpu.store with explicit offsets
// (no tensor descriptor); values, offsets and masks are blocked together.
struct UnrollStoreScatterOpWithOffsets
    : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());

    SmallVector<Type> convertedMaskTypes, convertedOffsetTypes;
    SmallVector<Value> convertedMasks, convertedOffsets;
    SmallVector<int64_t> targetMaskShape(*targetShape);

    if (chunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    auto layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    for (auto [v, o, m] :
         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                    rewriter.getI64IntegerAttr(chunkSize),
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr(), layout);
    }
// Unroll xegpu.prefetch on scattered tensor descriptors (without offsets).
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t, op->getAttrs());
// Unroll xegpu.store (scatter) on scattered tensor descriptors without
// explicit offsets; masks are replicated per blocked chunk when needed.
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<int64_t> targetMaskShape(*targetShape);

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);

      // Each blocked mask covers numNewChunks of the new stores.
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
    }
// Unroll xegpu.update_offset: update each unrolled descriptor with its slice
// of the offset vector (replicated per chunk when the chunk size is blocked).
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    VectorType offsetVecTy = offsetVec.getType();
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsetVec;

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1) {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto offset : pack(offsetVec, convertedOffsetTypes,
                              targetOffsetShape, loc, rewriter))
        convertedOffsetVec.append(numNewChunks, offset);
    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          xegpu::UpdateOffsetOp::create(rewriter, loc, t.getType(), t, o);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
// Unroll xegpu.load_matrix into per-tile loads and reassemble the result.
struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
  using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    assert(valueTy && "the value type must be a vector type!");

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
      return failure();

    Type elemTy = valueTy.getElementType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(valueTy.getShape(), *targetShape)) {
      // Shift the original offsets by the tile's relative offsets.
      auto adds = xegpu::addElementwise(
          rewriter, loc, mixedOffsets,
          getAsIndexOpFoldResult(rewriter.getContext(), offsets));
      offsetsList.push_back(adds);
    }

    if (layout)
      layout = layout.dropInstData();

    for (SmallVector<OpFoldResult> offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(
          rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
// Unroll xegpu.store_matrix: split the data into native tiles and store each
// tile at its shifted offsets.
struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
  using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
    assert(valueTy && "the value type must be a vector type!");

    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(valueTy.getShape(), *targetShape)) {
      auto adds = xegpu::addElementwise(
          rewriter, loc, mixedOffsets,
          getAsIndexOpFoldResult(rewriter.getContext(), offsets));
      offsetsList.push_back(adds);
    }

    for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
                                   layout.dropInstData());
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns
      .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
           UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
           UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
           UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
          patterns.getContext(), options);
}
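// Usage sketch (illustrative, not part of this file): one way a pass could
// drive these patterns. The UnrollOptions setter names used below
// (setNativeShapeFn, setUnrolledTypesFn) and applyPatternsGreedily are
// assumptions; check the XeGPU Transforms headers for the exact interface.
//
//   void runOnOperation() override {
//     MLIRContext *ctx = &getContext();
//
//     xegpu::UnrollOptions options;
//     // Assumed callback: use a fixed 8x16 native tile for every op.
//     options.setNativeShapeFn(
//         [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//           return SmallVector<int64_t>{8, 16};
//         });
//     // Assumed callback: split a shaped type into one type per tile.
//     options.setUnrolledTypesFn(
//         [](ShapedType type, ArrayRef<int64_t> tile, bool returnSingleType)
//             -> SmallVector<Type> {
//           Type tileTy = type.cloneWith(tile, type.getElementType());
//           if (returnSingleType)
//             return {tileTy};
//           auto ratio = computeShapeRatio(type.getShape(), tile);
//           return SmallVector<Type>(ratio ? computeProduct(*ratio) : 1,
//                                    tileTy);
//         });
//
//     RewritePatternSet patterns(ctx);
//     xegpu::populateXeGPUUnrollPatterns(patterns, options);
//     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
//   }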