23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/DebugLog.h"
28#define GEN_PASS_DEF_XEGPUUNROLL
29#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
33#define DEBUG_TYPE "xegpu-unroll"
46template <
typename SourceOp>
56 LDBG() <<
"Get unroll shape for: " << *op;
58 if (
options.filterConstraint && failed(
options.filterConstraint(op))) {
59 LDBG() <<
"--no filter constraint -> BAIL";
64 "expects the native shape for native shape call back function.");
65 auto nativeShape =
options.nativeShape(op);
71 return options.getUnrolledTypes(type, tileShape);
78 if (
auto vecTy = dyn_cast<VectorType>(destTy)) {
79 auto shape = vecTy.getShape();
83 if (isa<xegpu::TensorDescType>(destTy)) {
88 auto castOp = UnrealizedConversionCastOp::create(
89 rewriter, loc, destTy, srcs,
91 return castOp.getResult(0);
94 llvm_unreachable(
"Unexpected destTy.");
103 if (
auto vecTy = dyn_cast<VectorType>(src.
getType())) {
108 if (isa<xegpu::TensorDescType>(src.
getType())) {
113 auto castOp = UnrealizedConversionCastOp::create(
114 rewriter, loc, destTypes, src,
116 return castOp.getResults();
119 llvm_unreachable(
"Unexpected src type.");
135 Value srcTdesc, xegpu::TensorDescType tdescTy,
147 pack(srcTdesc, batchTdescTypes, targetShape, loc, rewriter);
149 auto innerTdescTy = xegpu::TensorDescType::get(
150 tdescTy.getContext(), innerShape, tdescTy.getElementType(),
151 tdescTy.getEncoding(),
nullptr);
157 for (
Value batchTdesc : batchTdescs) {
161 fullOffsets.append(offsets.begin(), offsets.end());
162 return createOp(batchTdesc, fullOffsets);
164 auto perBatch = unrollByTile(innerOffsets, innerTdescTy, innerTarget,
165 wrappedCreate, loc, rewriter);
166 newOps.append(perBatch.begin(), perBatch.end());
177 auto vecType = cast<VectorType>(operand.
getType());
178 std::optional<SmallVector<int64_t>> grids =
180 assert(grids &&
"Expecting grids to be computed.");
184 VectorType newVecTy =
185 vecType.cloneWith(blockSize, vecType.getElementType());
187 return pack(operand, convertedTypes, blockSize, loc, rewriter);
191 const char *
const packAttrName =
"__xegpu_blocking_pack__";
192 const char *
const unpackAttrName =
"__xegpu_blocking_unpack__";
193 const char *
const blockAttrName =
"__xegpu_blocking_tile_shape__";
207 int64_t rank = tdescTy.getRank();
215 auto aV = llvm::cast<Value>(a);
217 return rewriter.
createOrFold<arith::AddIOp>(loc, aV, bV);
222 llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
224 llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
230 for (
auto [idx, oldOff, offset] :
231 llvm::zip(validIdxes, oldOffsets, offsets))
232 mixedOffsets[idx] = addi(oldOff, offset);
234 auto newOp = createOp(mixedOffsets);
235 newOps.push_back(newOp);
240struct UnrollCreateNdOp :
public UnrollPattern<xegpu::CreateNdDescOp> {
241 using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
242 LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
245 xegpu::TensorDescType tdescTy = op.getType();
247 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
251 int64_t rank = tdescTy.getRank();
255 if (batchRank <= 0 || !isa<MemRefType>(op.getSourceType())) {
257 auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
258 auto newOp = xegpu::CreateNdDescOp::create(
259 rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
260 op.getMixedStrides());
261 newOps.push_back(newOp);
262 Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
274 targetShape->begin() + batchRank);
275 batchBlockSize.append(
shape.begin() + batchRank,
shape.end());
278 cast<xegpu::TensorDescType>(getUnrolledTypes(tdescTy, *targetShape)[0]);
289 for (
int64_t off : batchOffsets)
293 for (
int64_t d : batchBlockSize)
298 auto subview = memref::SubViewOp::create(rewriter, loc, op.getSource(),
299 subviewOffsets, subviewSizes,
302 auto newOp = xegpu::CreateNdDescOp::create(
303 rewriter, loc, newTdescTy,
305 newOps.push_back(newOp);
308 Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
314struct UnrollPrefetchNdOp :
public UnrollPattern<xegpu::PrefetchNdOp> {
315 using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
316 LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
319 xegpu::TensorDescType tdescTy = op.getTensorDescType();
321 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
325 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
327 layout = layout.dropInstData();
329 int64_t rank = tdescTy.getRank();
332 if (batchRank <= 0) {
334 getUnrolledTypes(tdescTy, *targetShape);
336 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
339 xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
340 op.getL1HintAttr(), op.getL2HintAttr(),
341 op.getL3HintAttr(), layout);
344 unrollByTile(op.getMixedOffsets(), tdescTy, *targetShape, createPrefetch,
349 auto createPrefetch =
351 xegpu::PrefetchNdOp::create(rewriter, loc, tdesc, fullOffsets,
352 op.getL1HintAttr(), op.getL2HintAttr(),
353 op.getL3HintAttr(), layout);
356 this->unrollNdBatch(op.getTensorDesc(), tdescTy, *targetShape,
357 op.getMixedOffsets(), batchRank, createPrefetch, loc,
366struct UnrollLoadNdOp :
public UnrollPattern<xegpu::LoadNdOp> {
367 using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
368 LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
372 VectorType valueTy = op.getType();
373 xegpu::TensorDescType tdescTy = op.getTensorDescType();
375 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
379 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
381 layout = layout.dropInstData();
383 Type elemTy = tdescTy.getElementType();
384 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
386 int64_t rank = tdescTy.getRank();
390 if (batchRank <= 0) {
393 getUnrolledTypes(tdescTy, *targetShape);
395 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
398 return xegpu::LoadNdOp::create(
399 rewriter, loc, newValueTy, convertedTdescs[0], offsets,
400 op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
401 op.getL2HintAttr(), op.getL3HintAttr(), layout);
403 newOps = unrollByTile(op.getMixedOffsets(), tdescTy, *targetShape,
404 createLoad, loc, rewriter);
408 auto createLoad = [&](
Value tdesc,
410 return xegpu::LoadNdOp::create(
411 rewriter, loc, newValueTy, tdesc, fullOffsets, op.getPackedAttr(),
412 op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
413 op.getL3HintAttr(), layout);
415 newOps = this->unrollNdBatch(op.getTensorDesc(), tdescTy, *targetShape,
416 op.getMixedOffsets(), batchRank, createLoad,
420 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
426struct UnrollStoreNdOp :
public UnrollPattern<xegpu::StoreNdOp> {
427 using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
428 LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
431 VectorType valueTy = op.getValueType();
432 xegpu::TensorDescType tdescTy = op.getTensorDescType();
434 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
438 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
440 layout = layout.dropInstData();
443 getUnrolledTypes(valueTy, *targetShape);
446 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
448 int64_t rank = tdescTy.getRank();
450 size_t valueIndex = 0;
452 if (batchRank <= 0) {
454 getUnrolledTypes(tdescTy, *targetShape);
456 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
459 xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
460 convertedTdescs[0], offsets,
461 op.getL1HintAttr(), op.getL2HintAttr(),
462 op.getL3HintAttr(), layout);
463 return (
Value)
nullptr;
465 unrollByTile(op.getMixedOffsets(), tdescTy, *targetShape, createStore,
473 auto createStore = [&](
Value tdesc,
475 xegpu::StoreNdOp::create(
476 rewriter, loc, convertedValues[valueIndex++], tdesc, fullOffsets,
477 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), layout);
480 this->unrollNdBatch(op.getTensorDesc(), tdescTy, *targetShape,
481 op.getMixedOffsets(), batchRank, createStore, loc,
490struct UnrollDpasOp :
public UnrollPattern<xegpu::DpasOp> {
491 using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
492 LogicalResult matchAndRewrite(xegpu::DpasOp op,
496 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
497 if (!targetShape || targetShape->size() < 3)
501 int64_t tsRank = targetShape->size();
502 auto M = (*targetShape)[tsRank - 3];
503 auto K = (*targetShape)[tsRank - 2];
504 auto N = (*targetShape)[tsRank - 1];
509 aBlockSize.push_back(M);
510 aBlockSize.push_back(K);
512 bBlockSize.push_back(K);
513 bBlockSize.push_back(N);
515 cBlockSize.push_back(M);
516 cBlockSize.push_back(N);
518 auto a = op.getLhs();
519 auto b = op.getRhs();
520 auto c = op.getAcc();
526 cVals = packOperandForDpas(c, cBlockSize, loc, rewriter);
530 if (llvm::any_of(ranges, [](
auto &v) {
return v.size() == 0; }) ||
531 llvm::all_of(ranges, [](
auto &v) {
return v.size() == 1; }))
534 VectorType resultTy = op.getResult().getType();
535 auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());
537 auto aShape = a.getType().getShape();
538 auto bShape =
b.getType().getShape();
542 int64_t batchRank = batchDims.size();
543 int64_t mIters = aShape[batchRank] / M;
544 int64_t kIters = aShape[batchRank + 1] / K;
545 int64_t nIters = bShape[batchRank + 1] / N;
549 for (
int64_t d = 0; d < batchRank; ++d)
550 batchIters *= aShape[d] / batchDims[d];
553 for (
int64_t batch = 0; batch < batchIters; ++batch) {
554 for (
int64_t i = 0; i < mIters; ++i) {
558 tmpC = cVals[batch * (mIters * nIters) + i * nIters +
j];
560 for (
int64_t k = 0; k < kIters; ++k) {
561 Value aVec = aVals[batch * (mIters * kIters) + i * kIters + k];
562 Value bVec = bVals[batch * (kIters * nIters) + k * nIters +
j];
565 operands.push_back(tmpC);
567 tmpC = xegpu::DpasOp::create(
568 rewriter, loc, vecTy, operands,
571 newOps.push_back(tmpC);
575 Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
581struct UnrollDpasMxOp :
public UnrollPattern<xegpu::DpasMxOp> {
582 using UnrollPattern<xegpu::DpasMxOp>::UnrollPattern;
583 LogicalResult matchAndRewrite(xegpu::DpasMxOp op,
587 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
588 if (!targetShape || targetShape->size() < 4)
592 int64_t tsRank = targetShape->size();
593 auto M = (*targetShape)[tsRank - 4];
594 auto K = (*targetShape)[tsRank - 3];
595 auto N = (*targetShape)[tsRank - 2];
596 auto S = (*targetShape)[tsRank - 1];
600 aBlockSize.push_back(M);
601 aBlockSize.push_back(K);
603 bBlockSize.push_back(K);
604 bBlockSize.push_back(N);
606 cBlockSize.push_back(M);
607 cBlockSize.push_back(N);
609 aScaleBlockSize.push_back(M);
610 aScaleBlockSize.push_back(S);
612 bScaleBlockSize.push_back(S);
613 bScaleBlockSize.push_back(N);
617 auto c = op.getAcc();
618 auto ascale = dyn_cast<TypedValue<VectorType>>(op.getScaleA());
619 auto bscale = dyn_cast<TypedValue<VectorType>>(op.getScaleB());
625 cVals = packOperandForDpas(c, cBlockSize, loc, rewriter);
628 aScaleVals = packOperandForDpas(ascale, aScaleBlockSize, loc, rewriter);
631 bScaleVals = packOperandForDpas(bscale, bScaleBlockSize, loc, rewriter);
633 VectorType resultTy = op.getResult().getType();
634 auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());
636 auto aShape = a.getType().getShape();
637 auto bShape =
b.getType().getShape();
638 int64_t batchRank = batchDims.size();
639 int64_t mIters = aShape[batchRank] / M;
640 int64_t kIters = aShape[batchRank + 1] / K;
641 int64_t nIters = bShape[batchRank + 1] / N;
644 for (
int64_t d = 0; d < batchRank; ++d)
645 batchIters *= aShape[d] / batchDims[d];
648 xegpu::DpasMxOp newDpasMxOp;
649 for (
int64_t batch = 0; batch < batchIters; ++batch) {
650 for (
int64_t i = 0; i < mIters; ++i) {
654 tmpC = cVals[batch * (mIters * nIters) + i * nIters +
j];
656 for (
int64_t k = 0; k < kIters; ++k) {
657 Value aVec = aVals[batch * (mIters * kIters) + i * kIters + k];
658 Value bVec = bVals[batch * (kIters * nIters) + k * nIters +
j];
661 operands.push_back(tmpC);
664 aScaleVals[batch * (mIters * kIters) + i * kIters + k]);
667 bScaleVals[batch * (kIters * nIters) + k * nIters +
j]);
669 newDpasMxOp = xegpu::DpasMxOp::create(
670 rewriter, loc, vecTy, operands,
672 tmpC = newDpasMxOp.getResult();
674 newOps.push_back(newDpasMxOp);
678 Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
688struct UnrollLoadGatherOp :
public UnrollPattern<xegpu::LoadGatherOp> {
689 using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
690 LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
693 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
694 Value offsets = op.getOffsets();
695 Value mask = op.getMask();
697 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
703 if (
auto chunkSizeAttr = op->getAttr(
"chunk_size")) {
704 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
705 chunkSize = intAttr.getInt();
709 VectorType maskTy = llvm::dyn_cast<VectorType>(mask.
getType());
710 VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.
getType());
711 Type elemTy = valueTy.getElementType();
712 VectorType newValueTy = VectorType::get(*targetShape, elemTy);
721 targetMaskShape.pop_back();
722 int64_t blockedChunkSize = targetShape->back();
723 int64_t numNewChunks = chunkSize / blockedChunkSize;
724 chunkSize = blockedChunkSize;
726 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
727 convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
730 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
732 pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
734 for (
auto maskVal : convertedMasksBase)
735 convertedMasks.append(numNewChunks, maskVal);
737 for (
auto [baseOffset, offsetType] :
738 llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
739 for (
int64_t i = 0; i < numNewChunks; ++i) {
741 i * blockedChunkSize);
743 vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
745 arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
746 convertedOffsets.push_back(offsetVal);
750 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
752 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
754 convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
756 pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
759 auto layout = op.getLayoutAttr();
761 layout = layout.dropInstData();
764 for (
auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
765 auto newOp = xegpu::LoadGatherOp::create(
766 rewriter, loc, newValueTy, op.getSource(), o, m,
768 op.getL2HintAttr(), op.getL3HintAttr(), layout);
769 newOps.push_back(newOp);
772 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
782struct UnrollStoreScatterOp :
public UnrollPattern<xegpu::StoreScatterOp> {
783 using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
784 LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
787 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
788 Value offsets = op.getOffsets();
789 Value mask = op.getMask();
791 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
796 if (
auto chunkSizeAttr = op->getAttr(
"chunk_size")) {
797 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
798 chunkSize = intAttr.getInt();
802 VectorType maskTy = llvm::dyn_cast<VectorType>(mask.
getType());
803 VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.
getType());
811 targetMaskShape.pop_back();
812 int64_t blockedChunkSize = targetShape->back();
813 int64_t numNewChunks = chunkSize / blockedChunkSize;
814 chunkSize = blockedChunkSize;
816 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
817 convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
820 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
822 pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
824 for (
auto maskVal : convertedMasksBase)
825 convertedMasks.append(numNewChunks, maskVal);
827 for (
auto [baseOffset, offsetType] :
828 llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
829 for (
int64_t i = 0; i < numNewChunks; ++i) {
831 i * blockedChunkSize);
833 vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
835 arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
836 convertedOffsets.push_back(offsetVal);
840 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
842 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
844 convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
846 pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
850 getUnrolledTypes(valueTy, *targetShape);
852 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
854 auto layout = op.getLayoutAttr();
856 layout = layout.dropInstData();
858 for (
auto [v, o, m] :
859 llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
860 xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
862 op.getL1HintAttr(), op.getL2HintAttr(),
863 op.getL3HintAttr(), layout);
871struct UnrollLoadMatrixOp :
public UnrollPattern<xegpu::LoadMatrixOp> {
872 using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
873 LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
876 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
877 assert(valueTy &&
"the value type must be vector type!");
879 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
880 if (!targetShape || targetShape->size() != (
size_t)valueTy.getRank())
883 Type elemTy = valueTy.getElementType();
885 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
887 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
894 rewriter, loc, mixedOffsets,
896 offsetsList.push_back(adds);
901 layout = layout.dropInstData();
903 auto newOp = xegpu::LoadMatrixOp::create(
904 rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
905 newOps.push_back(newOp);
907 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
913struct UnrollStoreMatrixOp :
public UnrollPattern<xegpu::StoreMatrixOp> {
914 using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
915 LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
917 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
922 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
923 assert(valueTy &&
"the value type must be vector type!");
925 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
927 layout = layout.dropInstData();
930 getUnrolledTypes(valueTy, *targetShape);
932 pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);
939 rewriter, loc, mixedOffsets,
941 offsetsList.push_back(adds);
944 for (
auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
945 xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
958struct UnrollConvertLayoutOp :
public UnrollPattern<xegpu::ConvertLayoutOp> {
959 using UnrollPattern<xegpu::ConvertLayoutOp>::UnrollPattern;
960 LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
963 Type valType = op.getType();
965 xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
966 xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
967 if (!inputLayout || !targetLayout)
972 assert(!inputLayout.dropInstData() && !targetLayout.dropInstData() &&
973 "unexpected layout attributes for scalar type");
977 if (inputLayout.getEffectiveInstDataAsInt().empty() ||
978 targetLayout.getEffectiveInstDataAsInt().empty())
981 inputLayout = inputLayout.dropInstData();
982 targetLayout = targetLayout.dropInstData();
984 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
985 assert(valueTy &&
"the value type must be vector type!");
987 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
988 if (!targetShape || targetShape->size() != (
size_t)valueTy.getRank())
991 Value newSource = op.getSource();
993 if (inputLayout && targetLayout) {
995 getUnrolledTypes(valueTy, *targetShape);
997 pack(op.getOperand(), convertedValTypes, *targetShape, loc, rewriter);
998 for (
auto [v, t] : llvm::zip(convertedValues, convertedValTypes)) {
999 auto newOp = xegpu::ConvertLayoutOp::create(rewriter, loc, t, v,
1000 inputLayout, targetLayout);
1001 newOps.push_back(newOp);
1003 newSource = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
1025struct UnrollMultiReductionOp
1026 :
public UnrollPattern<vector::MultiDimReductionOp> {
1030 : UnrollPattern<vector::MultiDimReductionOp>(context,
options, benefit) {}
1032 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
1034 VectorType srcTy = reductionOp.getSourceVectorType();
1036 int64_t srcRank = srcTy.getRank();
1038 Location loc = reductionOp.getLoc();
1039 Value source = reductionOp.getSource();
1041 vector::CombiningKind kind = reductionOp.getKind();
1044 auto resultType = dyn_cast<VectorType>(reductionOp.getDestType());
1048 std::optional<SmallVector<int64_t>> targetShapeOpt =
1050 if (!targetShapeOpt ||
1051 static_cast<int64_t>(targetShapeOpt->size()) != srcRank)
1057 for (
int64_t i = 0; i < srcRank; ++i) {
1058 if (srcShape[i] % targetShape[i] != 0)
1065 for (
int64_t i = 0; i < srcRank; ++i) {
1066 if (reductionMask[i])
1067 reducedDims.push_back(i);
1069 keptDims.push_back(i);
1076 numReducedTilesPerDim.push_back(srcShape[d] / targetShape[d]);
1081 keptShape.push_back(srcShape[d]);
1082 keptTileShape.push_back(targetShape[d]);
1086 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1099 for (
auto [idx, dim] : llvm::enumerate(keptDims))
1100 baseOffsets[dim] = keptOffsets[idx];
1116 for (
auto [idx, dim] : llvm::enumerate(reducedDims))
1117 offsets[dim] = reducedTileIdx[idx] * targetShape[dim];
1119 Value tile = vector::ExtractStridedSliceOp::create(
1120 rewriter, loc, source, offsets, targetShape, strides);
1121 tiles.push_back(
tile);
1125 Value reduced = tiles[0];
1126 for (
size_t i = 1; i < tiles.size(); ++i)
1132 Value accSlice = vector::ExtractStridedSliceOp::create(
1133 rewriter, loc,
acc, keptOffsets, keptTileShape, accStrides);
1135 auto newReduction = vector::MultiDimReductionOp::create(
1136 rewriter, loc, reduced, accSlice, reductionMask, kind);
1140 result = vector::InsertStridedSliceOp::create(
1141 rewriter, loc, newReduction,
result, keptOffsets, dstStrides);
1154 .
add<UnrollCreateNdOp, UnrollPrefetchNdOp, UnrollLoadNdOp,
1155 UnrollStoreNdOp, UnrollDpasOp, UnrollDpasMxOp, UnrollLoadMatrixOp,
1156 UnrollStoreMatrixOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
1157 UnrollConvertLayoutOp, UnrollMultiReductionOp>(patterns.
getContext(),
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector with the given shape from a set of values using vector.insert_strided_slice.
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll xegpu operations to smaller shapes.
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_strided_slice.
SmallVector< OpFoldResult > addElementwise(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with the same length.
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explanatory.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options to control the XeGPU unrolling.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.