#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-unroll"
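// Shared base class UnrollPattern<SourceOp> used by all patterns below. The
// fragments shown here query the unroll options: the nativeShape callback
// supplies the per-op target tile shape (getTargetShape), and
// options.getUnrolledTypes computes the types of the resulting tiles.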
template <typename SourceOp>

    LDBG() << "Get unroll shape for: " << *op;

    LDBG() << "--no filter constraint -> BAIL";

           "expects the native shape for native shape call back function.");
    auto nativeShape = options.nativeShape(op);

    return options.getUnrolledTypes(type, tileShape);
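// Helper that reassembles a set of unrolled values into one value of the
// original type. For tensor-descriptor results the pieces are funneled
// through an UnrealizedConversionCastOp whose single result stands in for
// the original value, presumably tagged with the marker attributes defined
// below so a later step can recognize the cast.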
  if (auto vecTy = dyn_cast<VectorType>(destTy)) {
    assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
           "Expecting blockSize size to match the rank of destTy.");
    auto shape = vecTy.getShape();

  if (isa<xegpu::TensorDescType>(destTy)) {
    auto castOp = UnrealizedConversionCastOp::create(
        rewriter, loc, destTy, srcs,
    return castOp.getResult(0);

  llvm_unreachable("Unexpected destTy.");
  if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
    assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
           "Expecting blockSize size to match the rank of src.");

  if (isa<xegpu::TensorDescType>(src.getType())) {
    auto castOp = UnrealizedConversionCastOp::create(
        rewriter, loc, destTypes, src,
    return castOp.getResults();

  llvm_unreachable("Unexpected src type.");
const char *const packAttrName = "__xegpu_blocking_pack__";
const char *const unpackAttrName = "__xegpu_blocking_unpack__";
const char *const blockAttrName = "__xegpu_blocking_tile_shape__";
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,

    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

      auto aV = llvm::cast<Value>(a);
      return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);

        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, loc, newTdescTy, op.getSource(), mixedOffsets,
          op.getMixedSizes(), op.getMixedStrides());
      newOps.push_back(newOp);

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
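// Unrolls xegpu::UpdateNdOffsetOp by applying the same offset update to each
// unrolled tensor descriptor.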
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,

    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

        getUnrolledTypes(tdescTy, *targetShape);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc) {
      auto newOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
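// Unrolls xegpu::PrefetchNdOp into one prefetch per unrolled tensor
// descriptor; ops that carry explicit offsets are not handled here.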
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,

    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())

        getUnrolledTypes(tdescTy, *targetShape);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
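// Unrolls xegpu::LoadNdOp: each unrolled descriptor is loaded into a vector
// of the target tile shape, and the tiles are reassembled into the original
// result vector.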
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,

    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

        getUnrolledTypes(tdescTy, *targetShape);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdescs) {
          xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
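// Unrolls xegpu::StoreNdOp: the value is split into tiles and each tile is
// stored through its matching unrolled descriptor.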
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,

    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())

        getUnrolledTypes(valueTy, *targetShape);
        getUnrolledTypes(tdescTy, *targetShape);

        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
                               op.getL2HintAttr(), op.getL3HintAttr());
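// Unrolls xegpu::DpasOp using an M x N x K target shape: A, B and the
// optional accumulator C are blocked into MxK, KxN and MxN tiles, and a
// blocked GEMM loop emits one dpas per (i, j, k) tile, accumulating into the
// C tile across the k loop.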
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,

    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)

    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
      assert(grids && "Expecting grids to be computed.");

      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());

          pack(val, convertedTypes, blockSize, loc, rewriter);

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

      cVals = packWrapper(c, cBlockSize);

    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))

    VectorType resultTy = op.getResult().getType();

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
          tmpC = cVals[i * nIters + j];

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];

            operands.push_back(tmpC);

          tmpC = xegpu::DpasOp::create(rewriter, loc, vecTy, operands,

        newOps.push_back(tmpC);

    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
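// Unrolls xegpu::CreateDescOp (scattered descriptors): the index vector is
// split into tiles and, when the descriptor has chunk size > 1, each index
// tile is additionally replicated with a per-chunk increment added to it.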
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,

    xegpu::TensorDescType tdescTy = op.getType();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
              i * blockedChunkSize);
              vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
              arith::AddIOp::create(rewriter, loc, indice, incVec);

          auto newOp = xegpu::CreateDescOp::create(
              rewriter, loc, newTdescTy, op.getSource(), offsetIndice);

          newOps.push_back(newOp);

      for (auto indice : convertedIndiceVec) {
        auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy,
                                                 op.getSource(), indice);
        newOps.push_back(newOp);

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
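// Unrolls the tensor-descriptor form of xegpu::LoadGatherOp: descriptors and
// masks are split into tiles, each mask tile is replicated when chunk size
// is > 1, and one smaller gather load is emitted per tile.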
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,

    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

        getUnrolledTypes(tdescTy, *targetShape);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
        convertedMasks.append(numNewChunks, mask);

      newValueTy = valueTy.cloneWith(*targetShape, elemTy);

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,

    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
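// Variant of the gather-load unrolling for xegpu::LoadGatherOp when the
// addresses come from an explicit offsets vector rather than a tensor
// descriptor; masks and offsets are tiled (and replicated / shifted per
// chunk when chunk size is > 1) before the smaller loads are emitted.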
struct UnrollLoadGatherOpWithOffset
    : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,

    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
    Type elemTy = valueTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
              i * blockedChunkSize);
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);

    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, op.getSource(), o, m,
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
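// Scatter-store counterpart of the pattern above: values, offsets and masks
// are tiled consistently and one xegpu::StoreScatterOp is emitted per tile.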
struct UnrollStoreScatterOpWithOffsets
    : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,

    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());

      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
              i * blockedChunkSize);
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);

        getUnrolledTypes(valueTy, *targetShape);
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (auto [v, o, m] :
         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                    op.getL1HintAttr(), op.getL2HintAttr(),
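// Unrolls the scattered xegpu::PrefetchOp into one prefetch per unrolled
// tensor descriptor; the explicit-offset form is not handled here.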
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,

    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

        getUnrolledTypes(tdescTy, *targetShape);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t, op->getAttrs());
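// Unrolls the descriptor-based xegpu::StoreScatterOp: values, descriptors
// and the optional mask are tiled, and a smaller store is emitted per tile.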
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,

    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

        getUnrolledTypes(tdescTy, *targetShape);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);

      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
        convertedMasks.append(numNewChunks, mask);

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,

        getUnrolledTypes(valueTy, *targetShape);
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;

      xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
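// Unrolls xegpu::UpdateOffsetOp on scattered descriptors, tiling the offset
// vector to match the unrolled descriptors (and replicating it per chunk
// when the chunk size is > 1).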
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,

    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

        getUnrolledTypes(tdescTy, *targetShape);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    VectorType offsetVecTy = offsetVec.getType();

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1) {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto offset : pack(offsetVec, convertedOffsetTypes,
                              targetOffsetShape, loc, rewriter))
        convertedOffsetVec.append(numNewChunks, offset);

      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);

    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
          xegpu::UpdateOffsetOp::create(rewriter, loc, t.getType(), t, o);
      newOps.push_back(newOp);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
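// Unrolls xegpu::LoadMatrixOp: one load of the target tile shape is emitted
// per tile offset, with instruction-level data dropped from the layout
// attribute, and the tiles are reassembled into the original result.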
struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
  using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,

    VectorType valueTy = op.getType();
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())

    Type elemTy = valueTy.getElementType();

    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());

    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

          rewriter, loc, mixedOffsets,
      offsetsList.push_back(adds);

    layout = layout.dropInstData();

      auto newOp = xegpu::LoadMatrixOp::create(
          rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
      newOps.push_back(newOp);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
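// Unrolls xegpu::StoreMatrixOp symmetrically: the data vector is split into
// tiles and each tile is stored at its corresponding offsets.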
struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
  using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    VectorType valueTy = op.getData().getType();

    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());

        getUnrolledTypes(valueTy, *targetShape);
        pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);

          rewriter, loc, mixedOffsets,
      offsetsList.push_back(adds);

    for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
                                   layout.dropInstData());
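// Pattern registration: all unroll patterns defined above are added to the
// rewrite pattern set.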
      .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
           UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
           UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
           UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(