//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for unrolling XeGPU operations. It follows a
// concept and design similar to the vector unroll patterns, serving as a
// complement to them.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-unroll"

using namespace mlir;

namespace {

template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}

protected:
  /// Return the target shape for the given `op`. Return std::nullopt if the
  /// op shouldn't be or cannot be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG() << "Get unroll shape for: " << *op;

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG() << "--filter constraint failed -> BAIL";
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects the nativeShape callback to be provided.");
    auto nativeShape = options.nativeShape(op);
    return nativeShape;
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }

  /// Emulate the unpack behavior using insert_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTy, srcs,
          ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
    return Value();
  }

  /// Emulate the pack behavior using extract_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTypes, src,
          ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
    return SmallVector<Value>();
  }

private:
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};

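/// Unrolls a CreateNdDescOp into one descriptor per `targetShape` tile by
/// adjusting the offsets of the trailing `rank` dimensions, then recombines
/// the new descriptors via `unpack`.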
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      std::optional<int64_t> maybeInt = getConstantIntValue(a);
      if (maybeInt) {
        return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
      } else {
        auto aV = llvm::cast<Value>(a);
        auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
      }
    };

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();

    // For n-D memrefs where n > rank, we need to handle the last `rank`
    // dimensions only, and keep the first `n-rank` dimensions as is.
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {

      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, loc, newTdescTy, op.getSource(), mixedOffsets,
          op.getMixedSizes(), op.getMixedStrides());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

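/// Unrolls an UpdateNdOffsetOp by applying the same offsets to each unrolled
/// tensor descriptor and recombining the results via `unpack`.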
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

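/// Unrolls a PrefetchNdOp into one prefetch per unrolled tensor descriptor.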
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
                                  op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

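/// Unrolls a LoadNdOp into one load per unrolled tensor descriptor and
/// reassembles the loaded tiles into the original vector shape.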
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

    rewriter.replaceOp(op, castOp);
    return success();
  }
};

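/// Unrolls a StoreNdOp by splitting both the stored value and the tensor
/// descriptor into tiles and emitting one store per tile.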
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
                               op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

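/// Unrolls a DpasOp into a grid of [M, K, N]-blocked DPAS ops, accumulating
/// partial results along the K dimension.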
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // A vector of 3 elements should be returned, representing M, K, N
    // respectively.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Bail out if any operand has an invalid blocking size (empty) or if every
    // operand's original shape already matches its blocking size (size == 1).
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with acc

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
                                       op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

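/// Unrolls a scattered CreateDescOp by splitting its index vector. When
/// chunkSize > 1, additional descriptors with shifted indices are created to
/// cover each blocked chunk.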
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetIndiceShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    // indiceVec has one fewer dimension than tdescTy when chunkSize > 1.
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    SmallVector<Value> newOps;

    // More indices are needed when chunkSize > 1, since a large load from one
    // address may be broken into multiple smaller loads.
    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          // Compute the offset.
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
          Value offsetIndice =
              arith::AddIOp::create(rewriter, loc, indice, incVec);

          auto newOp = xegpu::CreateDescOp::create(
              rewriter, loc, newTdescTy, op.getSource(), offsetIndice);

          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy,
                                                 op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

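/// Unrolls a LoadGatherOp into one gather per unrolled tensor descriptor,
/// reusing each mask tile across the chunk_size dimension when chunkSize > 1.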
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // TODO: handle the unstructured source case (!tdescTy).
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      // The mask is reused across the chunk_size dimension.
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);

      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

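/// Unrolls a (scattered) PrefetchOp into one prefetch per unrolled tensor
/// descriptor.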
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // TODO: handle the unstructured source case (!tdescTy).
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

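/// Unrolls a StoreScatterOp by splitting the value, tensor descriptor, and
/// mask into tiles, reusing each mask tile across the chunk_size dimension
/// when chunkSize > 1.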
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // TODO: handle the unstructured source case (!tdescTy).
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);

      // The mask is reused across the chunk_size dimension.
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};

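/// Unrolls an UpdateOffsetOp on a scattered tensor descriptor, reusing each
/// offset tile across the chunk_size dimension when chunkSize > 1.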
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
    VectorType offsetVecTy = offsetVec.getType();
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsetVec;
    SmallVector<Value> newOps;
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1) {
      auto targetOffsetShape = ArrayRef<int64_t>(*targetShape).drop_back();
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      // The offset is reused across the chunk_size dimension.
      for (auto offset : pack(offsetVec, convertedOffsetTypes,
                              targetOffsetShape, loc, rewriter))
        convertedOffsetVec.append(numNewChunks, offset);

    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          xegpu::UpdateOffsetOp::create(rewriter, loc, t.getType(), t, o);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

} // namespace

void mlir::xegpu::populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
                                              const UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
                                                       options);
}