#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-unroll"

using namespace mlir;

namespace {
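/// Base class for the unrolling patterns below. It queries the user-provided
/// UnrollOptions callbacks for the per-op native (tile) shape and supplies the
/// pack/unpack helpers used to split values into tiles and recombine the
/// results.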
template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}

  std::optional<SmallVector<int64_t>> getTargetShape(SourceOp op) const {
    LDBG() << "Get unroll shape for: " << *op;

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG() << "--no filter constraint -> BAIL";
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects the nativeShape callback to be provided.");
    auto nativeShape = options.nativeShape(op);
    return nativeShape;
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }
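  /// Reassembles a set of tile-sized values `srcs` into a single value of type
  /// `destTy`: vector values are recombined via
  /// xegpu::createVectorWithShapeFromValues, while tensor descriptors are
  /// funneled through a tagged unrealized_conversion_cast that a later step
  /// resolves.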
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTy, srcs,
          ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
  }
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTypes, src,
          ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
  }
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      std::optional<int64_t> maybeInt = getConstantIntValue(a);
      if (maybeInt)
        return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
      auto aV = llvm::cast<Value>(a);
      auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
      return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
    };

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {
      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, loc, newTdescTy, op.getSource(), mixedOffsets,
          op.getMixedSizes(), op.getMixedStrides());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
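/// Unrolls xegpu.update_nd_offset by applying the same offset update to every
/// unrolled tensor descriptor.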
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
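/// Unrolls xegpu.prefetch_nd into one prefetch per unrolled tensor descriptor.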
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
                                  op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
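/// Unrolls xegpu.load_nd into one load per unrolled tensor descriptor and
/// recombines the loaded tiles into the original result vector shape.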
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
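/// Unrolls xegpu.store_nd by splitting both the stored value and the tensor
/// descriptor and emitting one store per tile.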
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
                               op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};
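/// Unrolls xegpu.dpas over an {m, k, n} target shape: A, B, and the optional
/// accumulator C are blocked into {m, k}, {k, n}, and {m, n} tiles, the k
/// dimension is reduced by chaining DPAS ops through the accumulator, and the
/// partial results are recombined into the original result vector.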
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // The target shape is expected to be {m, k, n}.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) -> SmallVector<Value> {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return {val};
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);
    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Nothing to do if any operand failed to block, or if every operand ends
    // up as a single block.
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // Initialize with the accumulator.

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
                                       op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
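/// Unrolls xegpu.create_tdesc (scattered descriptors). When the chunk size is
/// blocked, each unrolled index vector is reused numNewChunks times with a
/// per-chunk constant offset added to it.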
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    TypedValue<VectorType> indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetIndiceShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    // The index vector is one dimension lower than tdescTy when chunkSize > 1.
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    SmallVector<Value> newOps;
    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          // Compute the per-chunk offset and add it to the indices.
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
          Value offsetIndice =
              arith::AddIOp::create(rewriter, loc, indice, incVec);

          auto newOp = xegpu::CreateDescOp::create(
              rewriter, loc, newTdescTy, op.getSource(), offsetIndice);
          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy,
                                                 op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
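/// Unrolls xegpu.load (gather) over scattered tensor descriptors, splitting the
/// mask alongside the descriptors; when the chunk size is blocked, each mask
/// tile is reused across the resulting chunks.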
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      // Reuse each mask tile across the chunks it covers.
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);

      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
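/// Unrolls xegpu.prefetch (scattered) into one prefetch per unrolled tensor
/// descriptor.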
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
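/// Unrolls xegpu.store (scatter) by splitting the stored value, the tensor
/// descriptors, and the mask, and emitting one store per tile.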
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);

      // Reuse each mask tile across the chunks it covers.
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};
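/// Unrolls xegpu.update_offset (scattered descriptors), reusing each blocked
/// offset vector across the chunks it covers.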
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    TypedValue<VectorType> offsetVec = op.getOffsets();
    VectorType offsetVecTy = offsetVec.getType();
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsetVec;
    SmallVector<Value> newOps;
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1) {
      SmallVector<int64_t> targetOffsetShape(*targetShape);
      targetOffsetShape.pop_back();
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      // Reuse each offset vector across the chunks it covers.
      for (auto offset : pack(offsetVec, convertedOffsetTypes,
                              targetOffsetShape, loc, rewriter))
        convertedOffsetVec.append(numNewChunks, offset);
    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          xegpu::UpdateOffsetOp::create(rewriter, loc, t.getType(), t, o);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

} // namespace
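/// Collect the set of patterns that unroll XeGPU operations to smaller shapes.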
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
                                                       options);
}
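// A minimal usage sketch (not part of this file). It assumes the UnrollOptions
// members exercised above (nativeShape, getUnrolledTypes, filterConstraint)
// are directly assignable std::function members and that the patterns are
// driven by the greedy pattern rewriter; the 8x16 tile shape, the `ctx` and
// `moduleOp` names, and the VectorType-only type callback are illustrative
// placeholders, not values taken from this pass.
//
//   xegpu::UnrollOptions options;
//   options.nativeShape =
//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//     return SmallVector<int64_t>{8, 16}; // illustrative tile shape
//   };
//   options.getUnrolledTypes = [](ShapedType type, ArrayRef<int64_t> tile) {
//     auto ratio = computeShapeRatio(type.getShape(), tile);
//     int64_t count = ratio ? computeProduct(*ratio) : 1;
//     // Vector-only sketch; a real callback must also handle TensorDescType.
//     Type tileTy = VectorType::get(tile, type.getElementType());
//     return SmallVector<Type>(count, tileTy);
//   };
//   RewritePatternSet patterns(ctx);
//   xegpu::populateXeGPUUnrollPatterns(patterns, options);
//   (void)applyPatternsGreedily(moduleOp, std::move(patterns));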