//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for unrolling XeGPU operations. It follows a
// concept and design similar to the vector unroll patterns and serves as a
// complement to them.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <numeric>

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

namespace {

template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}

protected:
  /// Return the target shape for the given `op`. Return std::nullopt if the
  /// op shouldn't be or cannot be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG("");
    LDBG("Get unroll shape for: " << *op);

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG("--filter constraint failed -> BAIL");
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects a nativeShape callback to provide the native shape.");
    auto nativeShape = options.nativeShape(op);
    return nativeShape;
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }

  /// Emulate the unpack behavior using insert_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTy, srcs, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
    return Value();
  }

  /// Emulate the pack behavior using extract_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTypes, src, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
    return SmallVector<Value>();
  }

private:
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};

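/// Unrolls a CreateNdDescOp into multiple CreateNdDescOps, one per tile of the
/// target shape, adjusting the base offsets for each tile. The new descriptors
/// are grouped back into a value of the original descriptor type via `unpack`.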
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      std::optional<int64_t> maybeInt = getConstantIntValue(a);
      if (maybeInt) {
        return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
      } else {
        auto aV = llvm::cast<Value>(a);
        auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
      }
    };

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();

    // For n-D memrefs where n > rank, we need to handle the last `rank`
    // dimensions only, and keep the first `n-rank` dimensions as is.
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {

      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

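/// Unrolls an UpdateNdOffsetOp by applying the same offset update to every
/// unrolled tensor descriptor produced by `pack`.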
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

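/// Unrolls a PrefetchNdOp by emitting one prefetch per unrolled tensor
/// descriptor and erasing the original op.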
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

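/// Unrolls a LoadNdOp by loading one tile-shaped vector per unrolled tensor
/// descriptor and reassembling the tiles into the original result vector.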
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

    rewriter.replaceOp(op, castOp);
    return success();
  }
};

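/// Unrolls a StoreNdOp by splitting both the stored value and the tensor
/// descriptor into tiles and emitting one store per tile.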
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

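/// Unrolls a DpasOp into a grid of tile-sized DpasOps: the operands are
/// blocked into MxK, KxN, and MxN tiles, and partial products are accumulated
/// over the K dimension before the results are reassembled.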
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // A vector of 3 elements should be returned, representing M, K, N
    // respectively.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Skip the operation if any operand has an invalid blocking size (empty),
    // or if every operand's original shape already matches its blocking size
    // (size == 1).
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with acc

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
                                                op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

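/// Unrolls a CreateDescOp (scattered descriptor) by splitting its index
/// vector. When chunkSize > 1, additional index vectors are generated, each
/// shifted by a multiple of the blocked chunk size, since one large chunked
/// access is broken into several smaller ones.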
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetIndiceShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSize();
    // IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1.
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    SmallVector<Value> newOps;

    // More indices are needed when chunkSize > 1, since a big load from one
    // address could be broken into multiple smaller loads.
    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          // Compute the offset
          Value inc = rewriter.create<arith::ConstantIndexOp>(
              loc, i * blockedChunkSize);
          Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
          Value offsetIndice =
              rewriter.create<arith::AddIOp>(loc, indice, incVec);

          auto newOp = rewriter.create<xegpu::CreateDescOp>(
              loc, newTdescTy, op.getSource(), offsetIndice);

          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = rewriter.create<xegpu::CreateDescOp>(
            loc, newTdescTy, op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

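/// Unrolls a LoadGatherOp by splitting the tensor descriptor and the mask.
/// When chunkSize > 1, each 1-D mask is reused for every sub-chunk and the
/// result tile shape is swapped to account for the transpose effect.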
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSize();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      SmallVector<Value> convertedMasks1D = pack(
          op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto mask : convertedMasks1D) {
        for (int64_t i = 0; i < numNewChunks; ++i)
          convertedMasks.push_back(mask);
      }
      // This is to handle the transpose effect when chunkSize > 1.
      std::swap((*targetShape)[0], (*targetShape)[1]);
      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = rewriter.create<xegpu::LoadGatherOp>(
          loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

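/// Unrolls a scattered PrefetchOp by emitting one prefetch per unrolled tensor
/// descriptor and erasing the original op.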
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

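/// Unrolls a StoreScatterOp by splitting the stored value, the tensor
/// descriptor, and the mask into tiles and emitting one store per tile,
/// handling chunkSize > 1 analogously to UnrollLoadGatherOp.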
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetIndiceShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSize();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
      SmallVector<Value> convertedMasks1D = pack(
          op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);

      for (auto mask : convertedMasks1D) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          convertedMasks.push_back(mask);
        }
      }
      // This is to handle the transpose effect when chunkSize > 1.
      std::swap((*targetShape)[0], (*targetShape)[1]);

    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
      convertedMasks =
          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      rewriter.create<xegpu::StoreScatterOp>(
          loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};

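/// Unrolls an UpdateOffsetOp by splitting the tensor descriptor and the offset
/// vector; when chunkSize > 1 each 1-D offset vector is reused for every
/// sub-chunk.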
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (tdescTy.getRank() > 2)
      return failure();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
    VectorType offsetVecTy = offsetVec.getType();
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsetVec;
    SmallVector<Value> newOps;
    int64_t originalChunkSize = tdescTy.getChunkSize();
    if (originalChunkSize > 1) {
      SmallVector<int64_t> shape1D(targetShape->begin(),
                                   targetShape->end() - 1);
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
      SmallVector<Value> convertedOffsetVec1D =
          pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto offset : convertedOffsetVec1D) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          convertedOffsetVec.push_back(offset);
        }
      }

    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

} // namespace

void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
                                                       options);
}
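
// Illustrative only (not part of the upstream file): a minimal sketch of how a
// pass might drive these patterns. The 8x16 native shape and the unrolled-type
// callback below are assumptions for the example, not values mandated by the
// XeGPU dialect.
//
//   void unrollXeGPUOps(Operation *root) {
//     xegpu::UnrollOptions options;
//     // Block every unrollable op to 8x16 tiles; a real pass would derive the
//     // shape from the op and the target architecture.
//     options.nativeShape =
//         [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//       return SmallVector<int64_t>{8, 16};
//     };
//     // One tile-shaped type per tile of the original shaped type.
//     options.getUnrolledTypes = [](ShapedType type,
//                                   ArrayRef<int64_t> tileShape) {
//       std::optional<SmallVector<int64_t>> ratio =
//           computeShapeRatio(type.getShape(), tileShape);
//       int64_t numTiles = ratio ? computeProduct(*ratio) : 1;
//       Type tileTy = type.cloneWith(tileShape, type.getElementType());
//       return SmallVector<Type>(numTiles, tileTy);
//     };
//     RewritePatternSet patterns(root->getContext());
//     xegpu::populateXeGPUUnrollPatterns(patterns, options);
//     (void)applyPatternsGreedily(root, std::move(patterns));
//   }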