//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for unrolling XeGPU operations. It follows a
// concept and design similar to the vector unroll patterns and serves as a
// complement to them.
//
//===----------------------------------------------------------------------===//
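//
// Overview (an informal sketch derived from the patterns below): each pattern
// queries UnrollOptions for a per-op native tile shape, `pack`s the operands
// into per-tile values, emits one new op per tile, and `unpack`s the per-tile
// results back to the original shape. For vector values, pack/unpack lower to
// vector.extract_strided_slice/insert_strided_slice; for tensor descriptors
// they are modeled with tagged unrealized_conversion_cast ops that are
// expected to be resolved by later blocking logic.
//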

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-unroll"

using namespace mlir;

namespace {

template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}

protected:
  /// Return the target shape for the given `op`. Return std::nullopt if the
  /// op shouldn't be or cannot be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG() << "Get unroll shape for: " << *op;

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG() << "--no filter constraint -> BAIL";
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects the native shape callback function to be provided.");
    auto nativeShape = options.nativeShape(op);
    return nativeShape;
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }

  /// Emulate the unpack behavior using insert_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
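  /// For example (illustrative, with a [8, 16] block size): unpacking two
  /// vector<8x16xf32> values into a vector<16x16xf32> destination emits
  /// vector.insert_strided_slice ops at offsets [0, 0] and [8, 0], while a
  /// TensorDescType destination is rebuilt with a single tagged
  /// unrealized_conversion_cast.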
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTy, srcs,
          ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
    return Value();
  }

  /// Emulate the pack behavior using extract_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
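  /// For example (illustrative, with a [8, 16] block size): packing a
  /// vector<16x16xf32> value yields two vector<8x16xf32> values extracted via
  /// vector.extract_strided_slice at offsets [0, 0] and [8, 0], while a
  /// TensorDescType source is split through a single tagged
  /// unrealized_conversion_cast.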
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTypes, src,
          ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
    return SmallVector<Value>();
  }

private:
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};

struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      std::optional<int64_t> maybeInt = getConstantIntValue(a);
      if (maybeInt) {
        return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
      } else {
        auto aV = llvm::cast<Value>(a);
        auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
      }
    };

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();

    // For n-D memrefs where n > rank, we need to handle the last `rank`
    // dimensions only, and keep the first `n-rank` dimensions as is.
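    // For example (a hypothetical case): with a 3-D memref source and a 2-D
    // tensor_desc, the leading offset is kept unchanged and only the last two
    // offsets are advanced by each tile's relative offset.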
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {

      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, loc, newTdescTy, op.getSource(), mixedOffsets,
          op.getMixedSizes(), op.getMixedStrides());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
                                  op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

    rewriter.replaceOp(op, castOp);
    return success();
  }
};

struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
                               op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // A vector of 3 elements should be returned, representing M, K, N
    // respectively.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Skip the operation if every operand has an invalid blocking size (empty)
    // or if the original shape matches the blocking size (size == 1).
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

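    // Decompose the dpas into an (mIters x nIters) grid of (M x N) result
    // tiles; each tile accumulates kIters partial results along K, i.e.
    // C[i][j] = sum_k A[i][k] * B[k][j] (+ acc[i][j]).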
    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with acc

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
                                       op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetIndiceShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    // indiceVec is one dimension lower than tdescTy when chunkSize is larger
    // than 1.
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    SmallVector<Value> newOps;

    // More indices are needed when chunkSize > 1, since a big load from one
    // address may be broken into multiple smaller loads.
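    // For example (hypothetical sizes): an original chunkSize of 8 blocked to
    // 2 spawns 4 descriptors per unrolled index value, with the index offset
    // by 0, 2, 4, and 6 elements, respectively.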
    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          // Compute the offset
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
          Value offsetIndice =
              arith::AddIOp::create(rewriter, loc, indice, incVec);

          auto newOp = xegpu::CreateDescOp::create(
              rewriter, loc, newTdescTy, op.getSource(), offsetIndice);

          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy,
                                                 op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // TODO: handle the unstructured source case (!tdescTy).
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      // the mask is reused across the chunk_size dimension
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);

      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
/// load). It unrolls the offsets and mask operands accordingly, and creates
/// multiple LoadGatherOps with the unrolled operands.
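/// For example (an illustrative case): a gathered load producing
/// vector<32xf32> with a native shape of [16] is rewritten into two loads of
/// vector<16xf32>, each taking the corresponding half of the offsets and mask
/// vectors.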
struct UnrollLoadGatherOpWithOffset
    : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    // Only handle the case where offsets are present (scattered load)
    if (!offsets)
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    // Unroll mask and offsets with correct shape
    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
    Type elemTy = valueTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;

    if (chunkSize > 1) {
      // For chunked loads, mask and offsets have one less dimension
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, op.getSource(), o, m,
          rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

/// This pattern handles the unrolling of StoreScatterOp with offsets
/// (scattered store). It unrolls the offsets and mask operands accordingly,
/// and creates multiple StoreScatterOps with the unrolled operands.
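/// The decomposition mirrors UnrollLoadGatherOpWithOffset above: value,
/// offsets, and mask are split into tiles of the native shape and one store is
/// emitted per tile, with the mask reused across blocked chunks when
/// chunk_size > 1.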
struct UnrollStoreScatterOpWithOffsets
    : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    // Only handle the case where offsets are present (scattered store)
    if (!offsets)
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    SmallVector<int64_t> targetMaskShape(*targetShape);
    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;

    if (chunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (auto [v, o, m] :
         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                    rewriter.getI64IntegerAttr(chunkSize),
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // TODO: handle the unstructured source case (!tdescTy).
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // TODO: handle the unstructured source case (!tdescTy).
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);

      // the mask is reused across the chunk_size dimension
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
    VectorType offsetVecTy = offsetVec.getType();
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsetVec;
    SmallVector<Value> newOps;
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1) {
      auto targetOffsetShape = ArrayRef<int64_t>(*targetShape).drop_back();
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      // the offset is reused across the chunk_size dimension
      for (auto offset : pack(offsetVec, convertedOffsetTypes,
                              targetOffsetShape, loc, rewriter))
        convertedOffsetVec.append(numNewChunks, offset);

    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          xegpu::UpdateOffsetOp::create(rewriter, loc, t.getType(), t, o);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
  using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
      return failure();

    Type elemTy = valueTy.getElementType();
    ArrayRef<int64_t> shape = valueTy.getShape();
    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());

    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {
      auto adds = xegpu::addElementwise(
          rewriter, loc, mixedOffsets,
          getAsIndexOpFoldResult(op.getContext(), offsets));
      offsetsList.push_back(adds);
    }

    SmallVector<Value> newOps;
    layout = layout.dropInstData();
    for (SmallVector<OpFoldResult> offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(
          rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
  using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Location loc = op.getLoc();
    VectorType valueTy = op.getData().getType();
    ArrayRef<int64_t> shape = valueTy.getShape();
    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {
      auto adds = xegpu::addElementwise(
          rewriter, loc, mixedOffsets,
          getAsIndexOpFoldResult(op.getContext(), offsets));
      offsetsList.push_back(adds);
    }

    for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
                                   layout.dropInstData());

    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

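// A minimal usage sketch (illustrative only; it assumes the UnrollOptions
// callbacks referenced by the patterns above can be assigned directly, and
// uses a hypothetical helper `computeUnrolledTypes` supplied by the caller):
//
//   xegpu::UnrollOptions options;
//   options.nativeShape = [](Operation *op)
//       -> std::optional<SmallVector<int64_t>> {
//     if (isa<xegpu::DpasOp>(op))
//       return SmallVector<int64_t>{8, 16, 16};
//     return std::nullopt;
//   };
//   options.getUnrolledTypes = [](ShapedType type, ArrayRef<int64_t> tile) {
//     return computeUnrolledTypes(type, tile);
//   };
//   RewritePatternSet patterns(ctx);
//   xegpu::populateXeGPUUnrollPatterns(patterns, options);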
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const UnrollOptions &options) {
  patterns
      .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
           UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
           UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
           UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
          patterns.getContext(), options);
}