#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-unroll"
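
// The patterns in this file implement XeGPU blocking/unrolling: each supported
// op whose shape is larger than the hardware-native tile shape reported by
// UnrollOptions::nativeShape is rewritten into several smaller ops, and the
// original large values are stitched back together through temporary
// unrealized_conversion_cast ops (see the marker attributes below) that a
// later resolution step is expected to fold away.
//
// UnrollPattern<SourceOp> is the shared base class for all patterns below; it
// wraps the UnrollOptions callbacks and provides the pack/unpack helpers.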
template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
    LDBG() << "Get unroll shape for: " << *op;
    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG() << "--no filter constraint -> BAIL";
      return std::nullopt;
    }
    assert(options.nativeShape &&
           "expects the native shape callback function.");
    auto nativeShape = options.nativeShape(op);
  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape,
                                     bool returnSingleType = false) const {
    return options.getUnrolledTypes(type, tileShape, returnSingleType);
  }
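
  // unpack() below rebuilds one value of `destTy` from the per-tile values in
  // `srcs`: vector results are reassembled with vector.insert_strided_slice
  // (createVectorWithShapeFromValues), while tensor descriptors are funneled
  // through an UnrealizedConversionCastOp tagged with the marker attributes
  // defined further down.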
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      auto shape = vecTy.getShape();

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTy, srcs,
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTypes, src,
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";
    int64_t rank = tdescTy.getRank();

      auto aV = llvm::cast<Value>(a);
      return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);

    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = createOp(mixedOffsets);
      newOps.push_back(newOp);
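
// Unrolls xegpu.create_nd_tdesc. Without offsets a single descriptor of the
// blocked type is created; with offsets, computeUnrolledOffsets emits one
// create_nd_tdesc per tile at shifted offsets. Either way the new descriptors
// replace the original result through unpack().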
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    bool hasOffsets = !op.getMixedOffsets().empty();

      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);

        return xegpu::CreateNdDescOp::create(
            rewriter, loc, newTdescTy, op.getSource(), offsets,
            op.getMixedSizes(), op.getMixedStrides());

      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                      *targetShape, createOp, loc, rewriter);

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc) {
      auto newOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, hasOffsets);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

      for (auto t : convertedTdesc)
        xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,

        xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr(), layout);

      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createPrefetch, loc, rewriter);
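
// UnrollLoadNdOp and UnrollStoreNdOp follow the same two code paths as
// UnrollPrefetchNdOp above: when the op itself carries offsets (hasOffsets),
// the descriptor type is kept whole (returnSingleType) and
// computeUnrolledOffsets emits one new op per tile at shifted offsets;
// otherwise the already-blocked descriptors from pack() are iterated directly.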
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, hasOffsets);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

      for (auto t : convertedTdescs) {
        auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
        newOps.push_back(newOp);
      }

        return xegpu::LoadNdOp::create(
            rewriter, loc, newValueTy, convertedTdescs[0], offsets,
            op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
            op.getL2HintAttr(), op.getL3HintAttr(), layout);

      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                      *targetShape, createLoad, loc, rewriter);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, hasOffsets);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

      for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
        xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
                                 op.getL2HintAttr(), op.getL3HintAttr());

      size_t valueIndex = 0;
        xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
                                 convertedTdescs[0], offsets,
                                 op.getL1HintAttr(), op.getL2HintAttr(),
                                 op.getL3HintAttr(), layout);

      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createStore, loc, rewriter);
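
// Unrolls xegpu.dpas with a three-dimensional native shape [M, K, N]: A is
// blocked into M x K tiles, B into K x N tiles, and C and the result into
// M x N tiles. The nested loops below accumulate along k by feeding the
// previous partial result back in as the accumulator of the next small dpas.
//
// Illustrative numbers (not taken from this file): with 32x32 operands and a
// native shape of [8, 16, 16], mIters = 4, kIters = 2 and nIters = 2, so
// 4 * 2 * 2 = 16 small dpas ops are created and the 4 * 2 = 8 resulting 8x16
// tiles are reassembled into the 32x32 result by unpack().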
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)

    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
          pack(val, convertedTypes, blockSize, loc, rewriter);

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);
      cVals = packWrapper(c, cBlockSize);

    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
          tmpC = cVals[i * nIters + j];
        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
            operands.push_back(tmpC);
          tmpC = xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
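
// Unrolls xegpu.create_tdesc (scattered descriptors). When the descriptor has
// a chunk size greater than one, the trailing dimension is dropped from the
// index-vector shape and each blocked index vector is replicated numNewChunks
// times, with every copy advanced by i * blockedChunkSize so the new
// descriptors cover consecutive chunks of the original one.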
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getType();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
          Value offsetIndice =
              arith::AddIOp::create(rewriter, loc, indice, incVec);
          auto newOp = xegpu::CreateDescOp::create(
              rewriter, loc, newTdescTy, op.getSource(), offsetIndice);
          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy,
                                                 op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);

      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
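
// UnrollLoadGatherOpWithOffset and UnrollStoreScatterOpWithOffsets handle the
// descriptor-free form of gather/scatter, where the op addresses a flat source
// or destination with explicit offset and mask vectors plus an optional
// "chunk_size" attribute. The blocking logic mirrors the chunked case above:
// masks are replicated once per new chunk and offsets are advanced by
// i * blockedChunkSize via a vector.broadcast followed by arith.addi.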
struct UnrollLoadGatherOpWithOffset
    : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
    Type elemTy = valueTy.getElementType();
    VectorType newValueTy = VectorType::get(*targetShape, elemTy);

    if (chunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    auto layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, op.getSource(), o, m,
          op.getL2HintAttr(), op.getL3HintAttr(), layout);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
struct UnrollStoreScatterOpWithOffsets
    : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());

    if (chunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    auto layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    for (auto [v, o, m] :
         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr(), layout);
    }
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t, op->getAttrs());
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);

      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
    }
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    VectorType offsetVecTy = offsetVec.getType();

    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1) {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto offset : pack(offsetVec, convertedOffsetTypes,
                              targetOffsetShape, loc, rewriter))
        convertedOffsetVec.append(numNewChunks, offset);
    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          xegpu::UpdateOffsetOp::create(rewriter, loc, t.getType(), t, o);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
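
// UnrollLoadMatrixOp and UnrollStoreMatrixOp block xegpu.load_matrix and
// xegpu.store_matrix by collecting per-tile offsets into offsetsList (the
// op's mixed offsets plus each tile's relative offset) and emitting one op per
// tile, with inst_data dropped from the layout of the new ops.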
struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
  using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
                                PatternRewriter &rewriter) const override {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    assert(valueTy && "the value type must be vector type!");

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())

    Type elemTy = valueTy.getElementType();

    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());

    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

          rewriter, loc, mixedOffsets,
      offsetsList.push_back(adds);

    layout = layout.dropInstData();

      auto newOp = xegpu::LoadMatrixOp::create(
          rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
      newOps.push_back(newOp);

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
  using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
    assert(valueTy && "the value type must be vector type!");

    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);

          rewriter, loc, mixedOffsets,
      offsetsList.push_back(adds);

    for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
                                   layout.dropInstData());
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const UnrollOptions &options) {
  patterns
      .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
           UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
           UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
           UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
          patterns.getContext(), options);
}
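
// Usage sketch (illustrative, not part of this file): a pass would typically
// configure xegpu::UnrollOptions and hand them to populateXeGPUUnrollPatterns.
// The setter name setNativeShapeFn below is assumed to mirror the option
// member read above (options.nativeShape); check the UnrollOptions declaration
// in the XeGPU Transforms headers for the exact API.
namespace {
void addXeGPUUnrollPatterns(RewritePatternSet &patterns) {
  xegpu::UnrollOptions options;
  options.setNativeShapeFn(
      [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        // Return the hardware-native tile shape for `op`, or std::nullopt to
        // leave the op untouched.
        if (isa<xegpu::DpasOp>(op))
          return SmallVector<int64_t>{8, 16, 16}; // example [M, K, N] tile
        return std::nullopt;
      });
  xegpu::populateXeGPUUnrollPatterns(patterns, options);
}
} // namespace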