#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"

#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}

  // Return the target (native) tile shape for `op`, or std::nullopt if the op
  // should not be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG("Get unroll shape for: " << *op);
    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG("--no filter constraint -> BAIL");
      return std::nullopt;
    }
    assert(options.nativeShape &&
           "expects the native shape for native shape call back function.");
    auto nativeShape = options.nativeShape(op);
    return nativeShape;
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }
  // Re-assemble a set of small values back into a single value of type
  // `destTy`.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTy, srcs /* + blocking marker attributes (elided here) */);
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
  }
  // Split a single value `src` into a set of smaller values of the requested
  // `destTypes`.
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTypes, src /* + blocking marker attributes (elided here) */);
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
  }
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    // Add a constant offset `b` to an OpFoldResult `a`, folding when `a` is
    // itself a constant.
    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      if (std::optional<int64_t> maybeInt = getConstantIntValue(a))
        return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
      auto aV = llvm::cast<Value>(a);
      auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
      return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
    };

    // Only the trailing `rank` offsets are adjusted per tile.
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(tdescTy.getShape(), *targetShape)) {
      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
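// Reading aid: the Unroll* patterns below all follow the same recipe as
// UnrollCreateNdOp above:
//   1. getTargetShape(op): query the per-op native tile shape; bail out with
//      failure() when none is provided.
//   2. getUnrolledTypes(...): derive the per-tile types.
//   3. pack(...): split each large operand into per-tile values.
//   4. Emit one small XeGPU op per tile.
//   5. unpack(...) the per-tile results and replace the original op; ops
//      without results (stores, prefetches) are simply erased.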
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expect every operand to be a 2-D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // The target shape must provide three values: M, K, and N.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    // Split a 2-D vector operand into a list of blockSize tiles.
    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) -> SmallVector<Value> {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(computeProduct(*grids), newVecTy);
      return pack(val, convertedTypes, blockSize, loc, rewriter);
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);
    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Nothing to do if any operand failed to pack, or if every operand is
    // already a single tile.
    SmallVector<SmallVector<Value>> ranges({aVals, bVals});
    if (c)
      ranges.push_back(cVals);
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with the accumulator tile

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
                                                op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
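// Worked example (hypothetical sizes): A is 32x32, B is 32x32, and the target
// shape is {M=8, K=16, N=16}. Then mIters = 32/8 = 4, kIters = 32/16 = 2, and
// nIters = 32/16 = 2, so the loop nest above emits 4 * 2 * 2 = 16 small
// xegpu.dpas ops; each k-iteration feeds its result back as the accumulator
// of the next, so only mIters * nIters = 8 values (one per C tile) reach
// newOps.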
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    auto indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetIndiceShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSize();
    // When the chunk size is larger than 1, the index vector has one fewer
    // dimension than the descriptor, so drop the chunk dimension here.
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    SmallVector<Value> newOps;
    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          // Offset the original indices by the start of the blocked chunk.
          Value inc = rewriter.create<arith::ConstantIndexOp>(
              loc, i * blockedChunkSize);
          Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
          Value offsetIndice =
              rewriter.create<arith::AddIOp>(loc, indice, incVec);

          auto newOp = rewriter.create<xegpu::CreateDescOp>(
              loc, newTdescTy, op.getSource(), offsetIndice);
          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = rewriter.create<xegpu::CreateDescOp>(
            loc, newTdescTy, op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
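// Worked example (hypothetical sizes): if originalChunkSize = 8 and the
// target shape ends in a blocked chunk of 2, then numNewChunks = 8 / 2 = 4,
// so each converted index vector above spawns four create-desc ops whose
// indices are offset by 0, 2, 4, and 6.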
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t originalChunkSize = tdescTy.getChunkSize();
    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<int64_t> targetMaskShape(*targetShape);
    if (originalChunkSize > 1) {
      // The mask is 1-D: drop the chunk dimension, then replicate each mask
      // tile once per blocked chunk.
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      SmallVector<Value> convertedMasks1D = pack(
          op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto mask : convertedMasks1D) {
        for (int64_t i = 0; i < numNewChunks; ++i)
          convertedMasks.push_back(mask);
      }
      // Adjust the target shape of the loaded value accordingly.
      std::swap((*targetShape)[0], (*targetShape)[1]);
      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = rewriter.create<xegpu::LoadGatherOp>(
          loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t originalChunkSize = tdescTy.getChunkSize();
    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    if (originalChunkSize > 1) {
      // The mask is 1-D: unroll it along the first dim only, then replicate
      // each mask tile once per blocked chunk.
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
      SmallVector<Value> convertedMasks1D = pack(
          op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);

      for (auto mask : convertedMasks1D) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          convertedMasks.push_back(mask);
        }
      }
      // Adjust the target shape of the stored value accordingly.
      std::swap((*targetShape)[0], (*targetShape)[1]);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
      convertedMasks =
          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      rewriter.create<xegpu::StoreScatterOp>(
          loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (tdescTy.getRank() > 2)
      return failure();
    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    auto offsetVec = op.getOffsets();
    VectorType offsetVecTy = offsetVec.getType();

    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsetVec;
    int64_t originalChunkSize = tdescTy.getChunkSize();
    if (originalChunkSize > 1) {
      // The offsets are 1-D: drop the chunk dimension, then replicate each
      // offset tile once per blocked chunk.
      SmallVector<int64_t> shape1D(targetShape->begin(),
                                   targetShape->end() - 1);
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
      SmallVector<Value> convertedOffsetVec1D =
          pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      for (auto offset : convertedOffsetVec1D) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          convertedOffsetVec.push_back(offset);
        }
      }
    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
                                                       options);
}
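// Sketch: one way a client might drive these patterns directly, assuming the
// greedy rewrite driver entry point applyPatternsGreedily and reusing the
// hypothetical makeExampleUnrollOptions helper sketched earlier.
static LogicalResult runExampleUnrolling(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  xegpu::populateXeGPUUnrollPatterns(patterns, makeExampleUnrollOptions());
  return applyPatternsGreedily(root, std::move(patterns));
}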