#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"

#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG("Get unroll shape for: " << *op);

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG("--filter constraint failed -> BAIL");
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects a callback that computes the native shape.");
    auto nativeShape = options.nativeShape(op);
    return nativeShape;
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }
  /// Reassemble the unrolled pieces back into a single value of `destTy`.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTy, srcs, llvm::ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
  }
  /// Split `src` into values of `destTypes`, each of shape `blockSize`.
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTypes, src, llvm::ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
  }
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    // Add `b` to offset `a`, folding to a constant when `a` is static.
    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      std::optional<int64_t> maybeInt = getConstantIntValue(a);
      if (maybeInt)
        return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
      auto aV = llvm::cast<Value>(a);
      auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
      return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
    };

    // Only the trailing `rank` offsets participate in the unrolling; any
    // leading offsets of a higher-rank source are kept as-is.
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {
      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
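// Illustrative effect (shapes picked only for the example): with a native
// shape of 8x16, a single
//   %t = xegpu.create_nd_tdesc %src[%x, %y] : ... -> !xegpu.tensor_desc<16x32xf16>
// is rewritten into four create_nd_tdesc ops whose offsets are (%x, %y)
// shifted by {0, 8} x {0, 16}; the four results are then tied back into one
// value by the attribute-tagged unrealized_conversion_cast built in unpack().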
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
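// UnrollUpdateNdOffsetOp and UnrollPrefetchNdOp (and the load/store patterns
// below) share the same structure: unroll the tensor_desc operand into tiles
// with `pack`, replay the op once per tile, and, when the op produces a
// result, reassemble it with `unpack`; ops without results are simply erased.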
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
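// Each per-tile load produces a vector<targetShape x elemTy>; `unpack`
// reassembles these pieces into the original result vector via
// vector.insert_strided_slice.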
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};
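// xegpu.store_nd has no results, so there is nothing to unpack: the per-tile
// stores stand on their own and the original op is erased.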
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2-D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // The target shape is expected to be [M, K, N].
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Bail out if any operand could not be blocked, or if nothing needs to
    // be unrolled at all.
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with the accumulator tile

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
                                                op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
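// Indexing sketch (numbers illustrative only): aVals/bVals/cVals are laid out
// row-major over their tile grids, so for aShape = 32x64, bShape = 64x64 and
// (M, K, N) = (8, 16, 16) we get mIters = kIters = nIters = 4; result tile
// (i, j) is accumulated over k through tmpC, i.e. kIters chained dpas ops per
// result tile, and the mIters * nIters partial results are unpacked into the
// original MxN result vector.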
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    TypedValue<VectorType> indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetIndiceShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    // The index vector is one dimension lower than tdescTy when chunkSize > 1.
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    SmallVector<Value> newOps;

    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          // Compute the element offset of this sub-chunk.
          Value inc = rewriter.create<arith::ConstantIndexOp>(
              loc, i * blockedChunkSize);
          Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
          Value offsetIndice =
              rewriter.create<arith::AddIOp>(loc, indice, incVec);

          auto newOp = rewriter.create<xegpu::CreateDescOp>(
              loc, newTdescTy, op.getSource(), offsetIndice);
          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = rewriter.create<xegpu::CreateDescOp>(
            loc, newTdescTy, op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
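// Example of the chunked path (numbers illustrative): originalChunkSize = 8
// unrolled to blockedChunkSize = 2 gives numNewChunks = 4, so every index
// tile is reused four times with element offsets +0, +2, +4 and +6, producing
// one xegpu.create_tdesc per sub-chunk.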
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      // The mask is reused across the unrolled chunk dimension.
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);

      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = rewriter.create<xegpu::LoadGatherOp>(
          loc, newValueTy, t, m, op.getL1HintAttr(), op.getL2HintAttr(),
          op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
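// The mask has no chunk dimension, so when the chunk is split the same mask
// tile is appended numNewChunks times and shared by all sub-chunk loads.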
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);

      // The mask is reused across the unrolled chunk dimension.
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      rewriter.create<xegpu::StoreScatterOp>(loc, v, t, m, op.getL1HintAttr(),
                                             op.getL2HintAttr(),
                                             op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    TypedValue<VectorType> offsetVec = op.getOffsets();
    VectorType offsetVecTy = offsetVec.getType();
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsetVec;
    SmallVector<Value> newOps;
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1) {
      SmallVector<int64_t> targetOffsetShape(*targetShape);
      targetOffsetShape.pop_back();
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      // The offset is reused across the unrolled chunk dimension.
      for (auto offset : pack(offsetVec, convertedOffsetTypes,
                              targetOffsetShape, loc, rewriter))
        convertedOffsetVec.append(numNewChunks, offset);
    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
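// As with the gather/scatter masks above, the offset vector has no chunk
// dimension, so each offset tile is reused for all sub-chunk descriptors.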
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
                                                       options);
}
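// A minimal usage sketch (not part of this file). The UnrollOptions member
// names below match how they are read in the patterns above; how they are
// populated (direct assignment vs. setter methods) and the greedy-driver
// entry point vary by MLIR revision, and the shapes are chosen only for
// illustration:
//
//   xegpu::UnrollOptions options;
//   // Return the per-op native tile shape, or std::nullopt to skip the op.
//   options.nativeShape =
//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//     if (isa<xegpu::LoadNdOp>(op))
//       return SmallVector<int64_t>{8, 16};
//     return std::nullopt;
//   };
//   // Return the per-tile types a type unrolls into.
//   options.getUnrolledTypes = [](ShapedType type, ArrayRef<int64_t> tile) {
//     return SmallVector<Type>{type.clone(tile)};
//   };
//   RewritePatternSet patterns(&context);
//   xegpu::populateXeGPUUnrollPatterns(patterns, options);
//   // ...then apply `patterns` with the greedy pattern rewrite driver.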