19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/Support/DebugLog.h"
24 #define GEN_PASS_DEF_XEGPUUNROLL
25 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
29 #define DEBUG_TYPE "xegpu-unroll"
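
// Base unroll pattern: queries UnrollOptions for the per-op target ("native")
// tile shape and provides pack/unpack helpers that split a value into tiles
// and reassemble tiles into the original type via UnrealizedConversionCastOp.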
template <typename SourceOp>

    LDBG() << "Get unroll shape for: " << *op;

    LDBG() << "--no filter constraint -> BAIL";
53 "expects the native shape for native shape call back function.");
54 auto nativeShape =
options.nativeShape(op);
      bool returnSingleType = false) const {
    return options.getUnrolledTypes(type, tileShape, returnSingleType);
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      auto shape = vecTy.getShape();

    if (isa<xegpu::TensorDescType>(destTy)) {

      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTy, srcs,
      return castOp.getResult(0);

    llvm_unreachable("Unexpected destTy.");
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {

    if (isa<xegpu::TensorDescType>(src.getType())) {

      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTypes, src,
      return castOp.getResults();

    llvm_unreachable("Unexpected src type.");
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";
  int64_t rank = tdescTy.getRank();

    auto aV = llvm::cast<Value>(a);
    return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);

      llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
      llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    for (auto [idx, oldOff, offset] :
         llvm::zip(validIdxes, oldOffsets, offsets))
      mixedOffsets[idx] = addi(oldOff, offset);

    auto newOp = createOp(mixedOffsets);
    newOps.push_back(newOp);
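
// Unrolls xegpu.create_nd_tdesc: builds one descriptor per tile of the target
// shape, adjusting the offsets per tile when the op carries offsets, and packs
// the results back into a value of the original descriptor type.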
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getType();
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    bool hasOffsets = op.getMixedOffsets().size() != 0;

      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);

        return xegpu::CreateNdDescOp::create(
            rewriter, loc, newTdescTy, op.getSource(), offsets,
            op.getMixedSizes(), op.getMixedStrides());

      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                      *targetShape, createOp, loc, rewriter);

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
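
// Unrolls xegpu.update_nd_offset by applying the same offset update to each
// unrolled tensor descriptor.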
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getTensorDescType();
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

        getUnrolledTypes(tdescTy, *targetShape);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc) {
      auto newOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
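
// Unrolls xegpu.prefetch_nd: issues one prefetch per unrolled descriptor, or
// per unrolled offset when the op carries explicit offsets.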
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getTensorDescType();
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

        tdescTy, *targetShape, hasOffsets);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

      for (auto t : convertedTdesc)
        xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,

        xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
                                    op.getL1HintAttr(), op.getL2HintAttr(),

      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createPrefetch, loc, rewriter);
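
// Unrolls xegpu.load_nd into one load per tile and reassembles the loaded
// tiles into the original vector value.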
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

        tdescTy, *targetShape, hasOffsets);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

      for (auto t : convertedTdescs) {
        auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
        newOps.push_back(newOp);

        return xegpu::LoadNdOp::create(
            rewriter, loc, newValueTy, convertedTdescs[0], offsets,
            op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
            op.getL2HintAttr(), op.getL3HintAttr());

      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                      *targetShape, createLoad, loc, rewriter);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
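
// Unrolls xegpu.store_nd: splits the stored vector into tiles and emits one
// store per (value tile, descriptor) pair.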
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

        getUnrolledTypes(valueTy, *targetShape);
        tdescTy, *targetShape, hasOffsets);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

      for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
        xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
                                 op.getL2HintAttr(), op.getL3HintAttr());

      size_t valueIndex = 0;
        xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
                                 convertedTdescs[0], offsets,
                                 op.getL1HintAttr(), op.getL2HintAttr(),

      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createStore, loc, rewriter);
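
// Unrolls xegpu.dpas over an mIters x nIters x kIters grid of blocks,
// accumulating the partial products along K for each (i, j) output block.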
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)

    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
      assert(grids && "Expecting grids to be computed.");

      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
          pack(val, convertedTypes, blockSize, loc, rewriter);

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);
      cVals = packWrapper(c, cBlockSize);

    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))

    VectorType resultTy = op.getResult().getType();

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
          tmpC = cVals[i * nIters + j];

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
            operands.push_back(tmpC);

          tmpC = xegpu::DpasOp::create(rewriter, loc, vecTy, operands,

        newOps.push_back(tmpC);

    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
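
// Unrolls xegpu.create_tdesc (scattered descriptors): splits the index vector
// per tile and, for chunked descriptors, derives per-chunk indices by adding
// broadcast multiples of the blocked chunk size.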
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getType();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
              i * blockedChunkSize);
              vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
              arith::AddIOp::create(rewriter, loc, indice, incVec);

          auto newOp = xegpu::CreateDescOp::create(
              rewriter, loc, newTdescTy, op.getSource(), offsetIndice);
          newOps.push_back(newOp);

      for (auto indice : convertedIndiceVec) {
        auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy,
                                                 op.getSource(), indice);
        newOps.push_back(newOp);

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
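
// Unrolls xegpu.load (gather form with a tensor descriptor): splits the mask
// to match the unrolled descriptors (replicating it per chunk for chunked
// descriptors) and emits one gather per tile.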
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

        getUnrolledTypes(tdescTy, *targetShape);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
        convertedMasks.append(numNewChunks, mask);

      newValueTy = valueTy.cloneWith(*targetShape, elemTy);

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,

    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
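
// Unrolls the offset-based form of xegpu.load: a gather from a plain source
// with explicit offset and mask vectors instead of a tensor descriptor.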
struct UnrollLoadGatherOpWithOffset
    : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
    Type elemTy = valueTy.getElementType();

      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
              i * blockedChunkSize);
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);

    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, op.getSource(), o, m,
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
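
// Unrolls the offset-based form of xegpu.store: a scatter to a plain
// destination with explicit offset and mask vectors instead of a tensor
// descriptor.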
struct UnrollStoreScatterOpWithOffsets
    : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());

      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
              i * blockedChunkSize);
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);

        getUnrolledTypes(valueTy, *targetShape);
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (auto [v, o, m] :
         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                    op.getL1HintAttr(), op.getL2HintAttr(),
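
// Unrolls xegpu.prefetch (tensor-descriptor form): issues one prefetch per
// unrolled descriptor, forwarding the original op attributes.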
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

        getUnrolledTypes(tdescTy, *targetShape);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t, op->getAttrs());
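
// Unrolls xegpu.store (scatter form with a tensor descriptor): pairs each
// value tile with its unrolled descriptor and mask tile.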
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

        getUnrolledTypes(tdescTy, *targetShape);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);

      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
        convertedMasks.append(numNewChunks, mask);

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,

        getUnrolledTypes(valueTy, *targetShape);
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
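
// Unrolls xegpu.update_offset: splits the offset vector to match the unrolled
// scattered descriptors, replicating it per chunk when the descriptor is
// chunked.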
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

        getUnrolledTypes(tdescTy, *targetShape);
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    VectorType offsetVecTy = offsetVec.getType();

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1) {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto offset : pack(offsetVec, convertedOffsetTypes,
                              targetOffsetShape, loc, rewriter))
        convertedOffsetVec.append(numNewChunks, offset);

      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);

    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
          xegpu::UpdateOffsetOp::create(rewriter, loc, t.getType(), t, o);
      newOps.push_back(newOp);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
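
// Unrolls xegpu.load_matrix into per-tile loads at adjusted offsets, dropping
// the per-instruction data from the layout (dropInstData) of the unrolled ops.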
struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
  using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    assert(valueTy && "the value type must be vector type!");

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())

    Type elemTy = valueTy.getElementType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

          rewriter, loc, mixedOffsets,
      offsetsList.push_back(adds);

    layout = layout.dropInstData();

      auto newOp = xegpu::LoadMatrixOp::create(
          rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
      newOps.push_back(newOp);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
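
// Unrolls xegpu.store_matrix: stores each value tile at its adjusted offsets,
// likewise dropping the per-instruction data from the layout.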
struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
  using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
    assert(valueTy && "the value type must be vector type!");

    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());

        getUnrolledTypes(valueTy, *targetShape);
        pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);

          rewriter, loc, mixedOffsets,
      offsetsList.push_back(adds);

    for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
                                   layout.dropInstData());
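
// Pattern registration (populateXeGPUUnrollPatterns): adds all of the unroll
// patterns above to the rewrite pattern set.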
      .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
           UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
           UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
           UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(