20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/DebugLog.h"
25#define GEN_PASS_DEF_XEGPUUNROLL
26#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
30#define DEBUG_TYPE "xegpu-unroll"
36template <
typename SourceOp>
46 LDBG() <<
"Get unroll shape for: " << *op;
48 if (
options.filterConstraint && failed(
options.filterConstraint(op))) {
49 LDBG() <<
"--no filter constraint -> BAIL";
54 "expects the native shape for native shape call back function.");
55 auto nativeShape =
options.nativeShape(op);
61 bool returnSingleType =
false)
const {
62 return options.getUnrolledTypes(type, tileShape, returnSingleType);
69 if (
auto vecTy = dyn_cast<VectorType>(destTy)) {
70 auto shape = vecTy.getShape();
74 if (isa<xegpu::TensorDescType>(destTy)) {
79 auto castOp = UnrealizedConversionCastOp::create(
80 rewriter, loc, destTy, srcs,
82 return castOp.getResult(0);
85 llvm_unreachable(
"Unexpected destTy.");
94 if (
auto vecTy = dyn_cast<VectorType>(src.
getType())) {
99 if (isa<xegpu::TensorDescType>(src.
getType())) {
104 auto castOp = UnrealizedConversionCastOp::create(
105 rewriter, loc, destTypes, src,
107 return castOp.getResults();
110 llvm_unreachable(
"Unexpected src type.");
115 const char *
const packAttrName =
"__xegpu_blocking_pack__";
116 const char *
const unpackAttrName =
"__xegpu_blocking_unpack__";
117 const char *
const blockAttrName =
"__xegpu_blocking_tile_shape__";
133 int64_t rank = tdescTy.getRank();
141 auto aV = llvm::cast<Value>(a);
143 return rewriter.
createOrFold<arith::AddIOp>(loc, aV, bV);
148 llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
150 llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
156 for (
auto [idx, oldOff, offset] :
157 llvm::zip(validIdxes, oldOffsets, offsets))
158 mixedOffsets[idx] = addi(oldOff, offset);
160 auto newOp = createOp(mixedOffsets);
161 newOps.push_back(newOp);
166struct UnrollCreateNdOp :
public UnrollPattern<xegpu::CreateNdDescOp> {
167 using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
168 LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
171 xegpu::TensorDescType tdescTy = op.getType();
173 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
179 auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
180 bool hasOffsets = op.getMixedOffsets().size() != 0;
182 auto newOp = xegpu::CreateNdDescOp::create(
183 rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
184 op.getMixedStrides());
185 newOps.push_back(newOp);
188 return xegpu::CreateNdDescOp::create(
189 rewriter, loc, newTdescTy, op.getSource(), offsets,
190 op.getMixedSizes(), op.getMixedStrides());
193 newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
194 *targetShape, createOp, loc, rewriter);
196 Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
203struct UnrollUpdateNdOffsetOp :
public UnrollPattern<xegpu::UpdateNdOffsetOp> {
204 using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
205 LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
208 xegpu::TensorDescType tdescTy = op.getTensorDescType();
210 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
215 getUnrolledTypes(tdescTy, *targetShape);
217 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
220 for (
auto t : convertedTdesc) {
221 auto newOp = xegpu::UpdateNdOffsetOp::create(
222 rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
223 newOps.push_back(newOp);
225 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
231struct UnrollPrefetchNdOp :
public UnrollPattern<xegpu::PrefetchNdOp> {
232 using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
233 LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
236 xegpu::TensorDescType tdescTy = op.getTensorDescType();
238 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
242 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
244 layout = layout.dropInstData();
245 int64_t offsetSize =
static_cast<int64_t>(op.getOffsets().size());
246 bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
249 tdescTy, *targetShape, hasOffsets);
252 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
255 for (
auto t : convertedTdesc)
256 xegpu::PrefetchNdOp::create(rewriter, loc,
TypeRange(), t,
260 xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
261 op.getL1HintAttr(), op.getL2HintAttr(),
262 op.getL3HintAttr(), layout);
267 computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
268 createPrefetch, loc, rewriter);
276struct UnrollLoadNdOp :
public UnrollPattern<xegpu::LoadNdOp> {
277 using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
278 LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
282 VectorType valueTy = op.getType();
283 xegpu::TensorDescType tdescTy = op.getTensorDescType();
285 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
289 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
291 layout = layout.dropInstData();
292 int64_t offsetSize =
static_cast<int64_t>(op.getOffsets().size());
293 bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
295 Type elemTy = tdescTy.getElementType();
296 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
299 tdescTy, *targetShape, hasOffsets);
302 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
306 for (
auto t : convertedTdescs) {
308 xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
310 newOps.push_back(newOp);
314 return xegpu::LoadNdOp::create(
315 rewriter, loc, newValueTy, convertedTdescs[0], offsets,
316 op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
317 op.getL2HintAttr(), op.getL3HintAttr(), layout);
319 newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
320 *targetShape, createLoad, loc, rewriter);
323 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
330struct UnrollStoreNdOp :
public UnrollPattern<xegpu::StoreNdOp> {
331 using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
332 LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
335 VectorType valueTy = op.getValueType();
336 xegpu::TensorDescType tdescTy = op.getTensorDescType();
338 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
342 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
344 layout = layout.dropInstData();
345 int64_t offsetSize =
static_cast<int64_t>(op.getOffsets().size());
346 bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
349 getUnrolledTypes(valueTy, *targetShape);
351 tdescTy, *targetShape, hasOffsets);
354 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
357 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
359 for (
auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
360 xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
361 op.getL2HintAttr(), op.getL3HintAttr());
363 size_t valueIndex = 0;
365 xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
366 convertedTdescs[0], offsets,
367 op.getL1HintAttr(), op.getL2HintAttr(),
368 op.getL3HintAttr(), layout);
373 computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
374 createStore, loc, rewriter);
382struct UnrollDpasOp :
public UnrollPattern<xegpu::DpasOp> {
383 using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
384 LogicalResult matchAndRewrite(xegpu::DpasOp op,
389 if (llvm::any_of(op->getOperandTypes(), [&](
Type type) {
390 auto vecTy = dyn_cast<VectorType>(type);
391 return !vecTy || vecTy.getRank() != 2;
397 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
398 if (!targetShape || targetShape->size() != 3)
400 auto M = (*targetShape)[0];
401 auto K = (*targetShape)[1];
402 auto N = (*targetShape)[2];
404 int64_t aBlockSize[2] = {M, K};
405 int64_t bBlockSize[2] = {K, N};
406 int64_t cBlockSize[2] = {M, N};
410 VectorType type = val.getType();
411 std::optional<SmallVector<int64_t>> grids =
413 assert(grids &&
"Expecting grids to be computed.");
417 VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
420 pack(val, convertedTypes, blockSize, loc, rewriter);
424 auto a = op.getLhs();
425 auto b = op.getRhs();
426 auto c = op.getAcc();
428 auto aShape = a.getType().getShape();
429 auto bShape =
b.getType().getShape();
432 aVals = packWrapper(a, aBlockSize);
433 bVals = packWrapper(
b, bBlockSize);
436 cVals = packWrapper(c, cBlockSize);
442 if (llvm::any_of(ranges, [](
auto &v) {
return v.size() == 0; }) ||
443 llvm::all_of(ranges, [](
auto &v) {
return v.size() == 1; }))
446 VectorType resultTy = op.getResult().getType();
447 auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());
449 int64_t mIters = aShape[0] / M;
450 int64_t kIters = aShape[1] / K;
451 int64_t nIters = bShape[1] / N;
454 for (
int64_t i = 0; i < mIters; ++i) {
458 tmpC = cVals[i * nIters +
j];
460 for (
int64_t k = 0; k < kIters; ++k) {
461 Value aVec = aVals[i * kIters + k];
462 Value bVec = bVals[k * nIters +
j];
465 operands.push_back(tmpC);
468 xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
471 newOps.push_back(tmpC);
474 Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
480struct UnrollLoadGatherOp :
public UnrollPattern<xegpu::LoadGatherOp> {
481 using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
482 LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
486 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
487 xegpu::TensorDescType tdescTy = op.getTensorDescType();
490 if (!tdescTy || op.getOffsets())
493 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
498 int originalChunkSize = op.getChunkSize().value_or(1);
500 VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
502 Type elemTy = tdescTy.getElementType();
503 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
506 getUnrolledTypes(tdescTy, *targetShape);
508 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
513 if (originalChunkSize > 1) {
514 targetMaskShape.pop_back();
515 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
516 int64_t blockedChunkSize = targetShape->back();
517 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
520 for (
auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
522 convertedMasks.append(numNewChunks, mask);
524 newValueTy = valueTy.cloneWith(*targetShape, elemTy);
526 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
527 convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
532 for (
auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
533 auto newOp = xegpu::LoadGatherOp::create(
534 rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
535 op.getL2HintAttr(), op.getL3HintAttr());
536 newOps.push_back(newOp);
539 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
549struct UnrollLoadGatherOpWithOffset
550 :
public UnrollPattern<xegpu::LoadGatherOp> {
551 using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
552 LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
555 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
556 Value offsets = op.getOffsets();
557 Value mask = op.getMask();
563 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
569 if (
auto chunkSizeAttr = op->getAttr(
"chunk_size")) {
570 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
571 chunkSize = intAttr.getInt();
575 VectorType maskTy = llvm::dyn_cast<VectorType>(mask.
getType());
576 VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.
getType());
577 Type elemTy = valueTy.getElementType();
578 VectorType newValueTy = VectorType::get(*targetShape, elemTy);
587 targetMaskShape.pop_back();
588 int64_t blockedChunkSize = targetShape->back();
589 int64_t numNewChunks = chunkSize / blockedChunkSize;
590 chunkSize = blockedChunkSize;
592 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
593 convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
596 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
598 pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
600 for (
auto maskVal : convertedMasksBase)
601 convertedMasks.append(numNewChunks, maskVal);
603 for (
auto [baseOffset, offsetType] :
604 llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
605 for (
int64_t i = 0; i < numNewChunks; ++i) {
607 i * blockedChunkSize);
609 vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
611 arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
612 convertedOffsets.push_back(offsetVal);
616 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
618 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
620 convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
622 pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
625 auto layout = op.getLayoutAttr();
627 layout = layout.dropInstData();
630 for (
auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
631 auto newOp = xegpu::LoadGatherOp::create(
632 rewriter, loc, newValueTy, op.getSource(), o, m,
634 op.getL2HintAttr(), op.getL3HintAttr(), layout);
635 newOps.push_back(newOp);
638 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
648struct UnrollStoreScatterOpWithOffsets
649 :
public UnrollPattern<xegpu::StoreScatterOp> {
650 using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
651 LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
654 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
655 Value offsets = op.getOffsets();
656 Value mask = op.getMask();
662 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
667 if (
auto chunkSizeAttr = op->getAttr(
"chunk_size")) {
668 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
669 chunkSize = intAttr.getInt();
673 VectorType maskTy = llvm::dyn_cast<VectorType>(mask.
getType());
674 VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.
getType());
682 targetMaskShape.pop_back();
683 int64_t blockedChunkSize = targetShape->back();
684 int64_t numNewChunks = chunkSize / blockedChunkSize;
685 chunkSize = blockedChunkSize;
687 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
688 convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
691 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
693 pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
695 for (
auto maskVal : convertedMasksBase)
696 convertedMasks.append(numNewChunks, maskVal);
698 for (
auto [baseOffset, offsetType] :
699 llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
700 for (
int64_t i = 0; i < numNewChunks; ++i) {
702 i * blockedChunkSize);
704 vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
706 arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
707 convertedOffsets.push_back(offsetVal);
711 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
713 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
715 convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
717 pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
721 getUnrolledTypes(valueTy, *targetShape);
723 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
725 auto layout = op.getLayoutAttr();
727 layout = layout.dropInstData();
729 for (
auto [v, o, m] :
730 llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
731 xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
733 op.getL1HintAttr(), op.getL2HintAttr(),
734 op.getL3HintAttr(), layout);
742struct UnrollPrefetchOp :
public UnrollPattern<xegpu::PrefetchOp> {
743 using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
744 LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
747 xegpu::TensorDescType tdescTy = op.getTensorDescType();
750 if (!tdescTy || op.getOffsets())
753 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
758 getUnrolledTypes(tdescTy, *targetShape);
760 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
762 for (
auto t : convertedTdesc)
763 xegpu::PrefetchOp::create(rewriter, loc,
TypeRange(), t,
771struct UnrollStoreScatterOp :
public UnrollPattern<xegpu::StoreScatterOp> {
772 using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
773 LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
777 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
778 xegpu::TensorDescType tdescTy = op.getTensorDescType();
781 if (!tdescTy || op.getOffsets())
784 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
789 int originalChunkSize = op.getChunkSize().value_or(1);
791 VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
794 getUnrolledTypes(tdescTy, *targetShape);
796 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
801 if (originalChunkSize > 1) {
802 targetMaskShape.pop_back();
803 int64_t blockedChunkSize = targetShape->back();
804 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
805 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
808 for (
auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
810 convertedMasks.append(numNewChunks, mask);
812 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
813 convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
818 getUnrolledTypes(valueTy, *targetShape);
820 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
822 for (
size_t i = 0; i < convertedValues.size(); ++i) {
823 Value v = convertedValues[i];
824 Value t = convertedTdescs[i];
825 Value m = op.getMask() ? convertedMasks[i] :
nullptr;
826 xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
827 op.getL2HintAttr(), op.getL3HintAttr());
835struct UnrollLoadMatrixOp :
public UnrollPattern<xegpu::LoadMatrixOp> {
836 using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
837 LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
840 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
841 assert(valueTy &&
"the value type must be vector type!");
843 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
844 if (!targetShape || targetShape->size() != (
size_t)valueTy.getRank())
847 Type elemTy = valueTy.getElementType();
849 auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
851 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
858 rewriter, loc, mixedOffsets,
860 offsetsList.push_back(adds);
864 layout = layout.dropInstData();
866 auto newOp = xegpu::LoadMatrixOp::create(
867 rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
868 newOps.push_back(newOp);
870 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
876struct UnrollStoreMatrixOp :
public UnrollPattern<xegpu::StoreMatrixOp> {
877 using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
878 LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
880 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
885 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
886 assert(valueTy &&
"the value type must be vector type!");
888 auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
891 getUnrolledTypes(valueTy, *targetShape);
893 pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);
900 rewriter, loc, mixedOffsets,
902 offsetsList.push_back(adds);
905 for (
auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
906 xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
907 layout.dropInstData());
919struct UnrollConvertLayoutOp :
public UnrollPattern<xegpu::ConvertLayoutOp> {
920 using UnrollPattern<xegpu::ConvertLayoutOp>::UnrollPattern;
921 LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
924 Type valType = op.getType();
926 xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
927 xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
928 if (!inputLayout || !targetLayout)
933 assert(!inputLayout.dropInstData() && !targetLayout.dropInstData() &&
934 "unexpected layout attributes for scalar type");
938 if (inputLayout.getEffectiveInstDataAsInt().empty() ||
939 targetLayout.getEffectiveInstDataAsInt().empty())
942 inputLayout = inputLayout.dropInstData();
943 targetLayout = targetLayout.dropInstData();
945 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
946 assert(valueTy &&
"the value type must be vector type!");
948 std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(op);
949 if (!targetShape || targetShape->size() != (
size_t)valueTy.getRank())
952 Value newSource = op.getSource();
954 if (inputLayout && targetLayout) {
956 getUnrolledTypes(valueTy, *targetShape);
958 pack(op.getOperand(), convertedValTypes, *targetShape, loc, rewriter);
959 for (
auto [v, t] : llvm::zip(convertedValues, convertedValTypes)) {
960 auto newOp = xegpu::ConvertLayoutOp::create(rewriter, loc, t, v,
961 inputLayout, targetLayout);
962 newOps.push_back(newOp);
964 newSource = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
977 .
add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
978 UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollLoadGatherOp,
979 UnrollStoreScatterOp, UnrollPrefetchOp, UnrollLoadMatrixOp,
980 UnrollStoreMatrixOp, UnrollLoadGatherOpWithOffset,
981 UnrollStoreScatterOpWithOffsets, UnrollConvertLayoutOp>(
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll xegpu operations to a smaller shapes.
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
SmallVector< OpFoldResult > addElementwise(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with same length.
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options to control the XeGPU unrolling.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.