20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/DebugLog.h"
25#define GEN_PASS_DEF_XEGPUUNROLL
26#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
30#define DEBUG_TYPE "xegpu-unroll"
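// Overview of the UnrollPattern<SourceOp> base class excerpted below (much of
// its body is elided). It extends OpRewritePattern<SourceOp> and provides the
// helpers shared by all unroll patterns:
//  * getTargetShape(op)    - consults options.filterConstraint /
//                            options.nativeShape for the per-instruction tile
//                            shape, or bails out.
//  * getUnrolledTypes(..)  - delegates to options.getUnrolledTypes to compute
//                            the unrolled (tile-shaped) types.
//  * unpack(..) / pack(..) - convert between one large value and the set of
//                            small unrolled values; tensor_desc values go
//                            through an UnrealizedConversionCastOp tagged with
//                            the __xegpu_blocking_* attribute names defined in
//                            the class.
//  * computeUnrolledOffsets(..) - folds per-tile offsets into the op's mixed
//                            offsets and invokes a callback to create each
//                            unrolled op.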
template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  // ... (constructor taking the xegpu::UnrollOptions elided)

  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG() << "Get unroll shape for: " << *op;
    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG() << "--no filter constraint -> BAIL";
      return std::nullopt;
    }
    assert(options.nativeShape &&
           "expects the native shape for native shape call back function.");
    auto nativeShape = options.nativeShape(op);
    // ...
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape,
                                     bool returnSingleType = false) const {
    return options.getUnrolledTypes(type, tileShape, returnSingleType);
  }

  // Reassemble the unrolled values `srcs` into one value of type `destTy`.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      auto shape = vecTy.getShape();
      // ...
    }
    if (isa<xegpu::TensorDescType>(destTy)) {
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTy, srcs, /*...attributes elided...*/);
      return castOp.getResult(0);
    }
    llvm_unreachable("Unexpected destTy.");
  }

  // Split `src` into unrolled values with types `destTypes`.
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      // ...
    }
    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTypes, src, /*...attributes elided...*/);
      return castOp.getResults();
    }
    llvm_unreachable("Unexpected src type.");
  }

  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};

// Helper used by the *_nd patterns: adds the per-tile offsets to the trailing
// `rank` entries of the op's mixed offsets and calls `createOp` once per tile.
SmallVector<Value> computeUnrolledOffsets(
    /* mixedOffsets, tdescTy, targetShape, createOp, loc, rewriter */) {
  int64_t rank = tdescTy.getRank();
  // ...
  // Local helper `addi` (header elided): adds a per-tile offset to an
  // OpFoldResult via arith.addi, folding where possible.
    auto aV = llvm::cast<Value>(a);
    // ...
    return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
  // ...
  SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
      llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
  auto validIdxes =
      llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

  SmallVector<Value> newOps;
  // For each tile's constant offsets (loop header elided):
    for (auto [idx, oldOff, offset] :
         llvm::zip(validIdxes, oldOffsets, offsets))
      mixedOffsets[idx] = addi(oldOff, offset);

    auto newOp = createOp(mixedOffsets);
    newOps.push_back(newOp);
  // ...
  return newOps;
}
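// Unrolls xegpu.create_nd_tdesc: the descriptor is recreated with the unrolled
// (tile-shaped) tensor_desc type. When the op carries no offsets a single
// descriptor with the unrolled type is created; otherwise one descriptor per
// tile is created via computeUnrolledOffsets. The results are folded back to
// the original type with unpack().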
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Value> newOps;
    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    bool hasOffsets = op.getMixedOffsets().size() != 0;
    if (!hasOffsets) {
      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    } else {
      auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
        return xegpu::CreateNdDescOp::create(
            rewriter, loc, newTdescTy, op.getSource(), offsets,
            op.getMixedSizes(), op.getMixedStrides());
      };
      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                      *targetShape, createOp, loc, rewriter);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
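// Unrolls xegpu.update_nd_offset: the incoming tensor_desc is packed into its
// unrolled pieces, the same offset update is applied to each piece, and the
// results are unpacked back to the original descriptor type.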
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
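// Unrolls xegpu.prefetch_nd: one prefetch is emitted per unrolled tensor_desc.
// When the op carries explicit offsets, the per-tile offsets are recomputed
// with computeUnrolledOffsets and inst_data is dropped from the layout
// attribute.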
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, hasOffsets);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    if (!hasOffsets) {
      for (auto t : convertedTdesc)
        xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
                                    op->getAttrs());
    } else {
      auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
        xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr(), layout);
        // ...
      };
      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createPrefetch, loc, rewriter);
    }

    rewriter.eraseOp(op);
    return success();
  }
};
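// Unrolls xegpu.load_nd: each unrolled tensor_desc (or each unrolled offset
// set, when offsets are present on the load) produces a small load whose
// result type is the tile-shaped vector; the partial results are then
// reassembled into the original vector type.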
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, hasOffsets);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    if (!hasOffsets) {
      for (auto t : convertedTdescs) {
        auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
                                             op->getAttrs());
        newOps.push_back(newOp);
      }
    } else {
      auto createLoad = [&](SmallVector<OpFoldResult> offsets) -> Value {
        return xegpu::LoadNdOp::create(
            rewriter, loc, newValueTy, convertedTdescs[0], offsets,
            op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
            op.getL2HintAttr(), op.getL3HintAttr(), layout);
      };
      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                      *targetShape, createLoad, loc, rewriter);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
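// Unrolls xegpu.store_nd: the stored vector and the tensor_desc are packed
// into matching tiles and one small store is emitted per tile. With explicit
// offsets, the per-tile offsets are recomputed and a running valueIndex picks
// the value tile for each store.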
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, hasOffsets);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    if (!hasOffsets) {
      for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
        xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
                                 op.getL2HintAttr(), op.getL3HintAttr());
    } else {
      size_t valueIndex = 0;
      auto createStore = [&](SmallVector<OpFoldResult> offsets) -> Value {
        xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
                                 convertedTdescs[0], offsets,
                                 op.getL1HintAttr(), op.getL2HintAttr(),
                                 op.getL3HintAttr(), layout);
        // ...
      };
      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createStore, loc, rewriter);
    }

    rewriter.eraseOp(op);
    return success();
  }
};
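// Unrolls xegpu.dpas using a 3-element native shape [M, K, N]: A is blocked
// into MxK tiles, B into KxN tiles, and C / the result into MxN tiles. For
// every (i, j) result tile the pattern chains kIters dpas ops along K, feeding
// each partial result back in as the accumulator, using the flattened indexing
// aVals[i * kIters + k], bVals[k * nIters + j], and cVals[i * nIters + j].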
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      // ...
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);
    if (c)
      cVals = packWrapper(c, cBlockSize);

    // `ranges` groups aVals, bVals and cVals (declaration elided). Bail out if
    // any of them is empty, or if nothing is actually split.
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j];
        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);
          tmpC = xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
                                       op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
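// Unrolls xegpu.create_tdesc (scattered descriptors): the index vector is
// split into tiles, and when the chunk size is blocked the per-chunk element
// offset (i * blockedChunkSize) is broadcast and added to each index tile
// before recreating the descriptor.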
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    auto indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetIndiceShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    // The indice vector is one dim lower than the tdesc when chunkSize > 1.
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    SmallVector<Value> newOps;
    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          // Compute the element offset of this blocked chunk.
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
          Value offsetIndice =
              arith::AddIOp::create(rewriter, loc, indice, incVec);
          auto newOp = xegpu::CreateDescOp::create(
              rewriter, loc, newTdescTy, op.getSource(), offsetIndice);
          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy,
                                                 op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
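// Unrolls xegpu.load (gather form that still uses a scattered tensor_desc,
// i.e. no explicit offsets): descriptors and masks are split into tiles and
// one small gather is issued per tile. For chunked descriptors each mask tile
// is replicated numNewChunks times so the masks line up with the blocked
// chunks.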
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // This pattern only handles the tensor_desc form (no explicit offsets).
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      // Replicate each mask tile for every blocked chunk it covers.
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);

      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
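// Unrolls the offset-based form of xegpu.load (gather from a plain source plus
// an offsets vector): masks and offsets are split into tiles; when a
// chunk_size larger than 1 is present, the offsets of the extra chunks are
// materialized by adding i * blockedChunkSize, and the chunk_size of the new
// loads is reduced to the blocked chunk size.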
struct UnrollLoadGatherOpWithOffset
    : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();
    // ...

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    // Check the chunk size (defaults to 1).
    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
    Type elemTy = valueTy.getElementType();
    VectorType newValueTy = VectorType::get(*targetShape, elemTy);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;

    SmallVector<int64_t> targetMaskShape(*targetShape);
    if (chunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    auto layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    SmallVector<Value> newOps;
    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, op.getSource(), o, m,
          rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr(), layout);
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
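// Offset-based counterpart for xegpu.store (scatter): values, offsets and
// masks are split into matching tiles and one small store is emitted per tile,
// with the same chunk_size re-blocking as the gather pattern above.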
struct UnrollStoreScatterOpWithOffsets
    : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();
    // ...

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    // Check the chunk size (defaults to 1).
    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;

    SmallVector<int64_t> targetMaskShape(*targetShape);
    if (chunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    auto layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    for (auto [v, o, m] :
         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                    rewriter.getI64IntegerAttr(chunkSize),
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr(), layout);
    }

    rewriter.eraseOp(op);
    return success();
  }
};
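// Unrolls the scattered xegpu.prefetch (tensor_desc form): the descriptor is
// packed into tiles and one prefetch is emitted per tile.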
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t,
                                op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
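// Unrolls xegpu.store (scatter form that still uses a scattered tensor_desc):
// values, descriptors and masks are split into matching tiles and stored one
// tile at a time, replicating mask tiles across blocked chunks as needed.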
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // This pattern only handles the tensor_desc form (no explicit offsets).
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);

      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};
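// Unrolls xegpu.update_offset on scattered descriptors: descriptor and offset
// vectors are split into matching tiles and updated tile by tile, replicating
// offset tiles across blocked chunks when the chunk size is larger than 1.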
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    auto offsetVec = op.getOffsets();
    VectorType offsetVecTy = offsetVec.getType();
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsetVec;
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1) {
      SmallVector<int64_t> targetOffsetShape(*targetShape);
      targetOffsetShape.pop_back();
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      for (auto offset : pack(offsetVec, convertedOffsetTypes,
                              targetOffsetShape, loc, rewriter))
        convertedOffsetVec.append(numNewChunks, offset);
    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          xegpu::UpdateOffsetOp::create(rewriter, loc, t.getType(), t, o);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
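// Unrolls xegpu.load_matrix: one load per tile, with the tile offsets added
// element-wise to the op's mixed offsets and inst_data dropped from the
// layout; the tile results are reassembled into the original vector.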
struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
  using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    assert(valueTy && "the value type must be vector type!");

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
      return failure();

    Type elemTy = valueTy.getElementType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    // For each tile, add its constant offsets to the op's mixed offsets
    // (tile iteration and `mixedOffsets` declaration elided).
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    // ...
      auto adds = xegpu::addElementwise(
          rewriter, loc, mixedOffsets,
          /*...this tile's constant offsets as OpFoldResults...*/);
      offsetsList.push_back(adds);
    // ...

    SmallVector<Value> newOps;
    layout = layout.dropInstData();
    for (SmallVector<OpFoldResult> offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(
          rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
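// Unrolls xegpu.store_matrix: the data vector is packed into tiles and each
// tile is stored at its element-wise adjusted offsets, mirroring
// UnrollLoadMatrixOp.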
struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
  using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
    assert(valueTy && "the value type must be vector type!");
    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);

    // As in UnrollLoadMatrixOp: per-tile offsets are added element-wise to the
    // op's mixed offsets (tile iteration and `mixedOffsets` elided).
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    // ...
      auto adds = xegpu::addElementwise(
          rewriter, loc, mixedOffsets,
          /*...this tile's constant offsets as OpFoldResults...*/);
      offsetsList.push_back(adds);
    // ...

    for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
                                   layout.dropInstData());

    rewriter.eraseOp(op);
    return success();
  }
};
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns
      .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
           UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
           UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
           UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
          patterns.getContext(), options);
}
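// A minimal usage sketch (not part of the upstream file): driving these
// patterns from a pass. The UnrollOptions callbacks referenced by
// UnrollPattern (filterConstraint, nativeShape, getUnrolledTypes) must be
// configured first; their exact member/setter spelling is defined in
// mlir/Dialect/XeGPU/Transforms/Transforms.h, and `MyXeGPUUnrollPass` below is
// a hypothetical name.
//
//   void MyXeGPUUnrollPass::runOnOperation() {
//     MLIRContext *ctx = &getContext();
//     xegpu::UnrollOptions options;
//     // ... set options.nativeShape / options.getUnrolledTypes /
//     //     options.filterConstraint callbacks here ...
//     RewritePatternSet patterns(ctx);
//     xegpu::populateXeGPUUnrollPatterns(patterns, options);
//     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
//   }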