MLIR 23.0.0git
XeGPUUnroll.cpp
Go to the documentation of this file.
1//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains patterns for unrolling XeGPU operations. It follows a
10// similar concept and design as vector unroll patterns, serving as a complement
11// to them.
12//
13//===----------------------------------------------------------------------===//
14
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/DebugLog.h"
22
23namespace mlir {
24namespace xegpu {
25#define GEN_PASS_DEF_XEGPUUNROLL
26#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
27} // namespace xegpu
28} // namespace mlir
29
30#define DEBUG_TYPE "xegpu-unroll"
31
32using namespace mlir;
33
34namespace {
35
36template <typename SourceOp>
37struct UnrollPattern : public OpRewritePattern<SourceOp> {
38 UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
39 PatternBenefit benefit = 1)
40 : OpRewritePattern<SourceOp>(context, benefit), options(options) {}
41
42protected:
43 /// Return the target shape for the given `op`. Return std::nullopt if the
44 /// op shouldn't be or cannot be unrolled.
45 std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
46 LDBG() << "Get unroll shape for: " << *op;
47
48 if (options.filterConstraint && failed(options.filterConstraint(op))) {
49 LDBG() << "--no filter constraint -> BAIL";
50 return std::nullopt;
51 }
52
53 assert(options.nativeShape &&
54 "expects the native shape for native shape call back function.");
55 auto nativeShape = options.nativeShape(op);
56 return nativeShape;
57 }
58
59 SmallVector<Type> getUnrolledTypes(ShapedType type,
60 ArrayRef<int64_t> tileShape,
61 bool returnSingleType = false) const {
62 return options.getUnrolledTypes(type, tileShape, returnSingleType);
63 }
64
65 /// Emulate the the unpack behavior using insert_strided_slice for VectorType
66 /// values and unrealized_conversion_cast for TensorDescType values.
67 Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
68 Location loc, PatternRewriter &rewriter) const {
69 if (auto vecTy = dyn_cast<VectorType>(destTy)) {
70 auto shape = vecTy.getShape();
71 return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
72 }
73
74 if (isa<xegpu::TensorDescType>(destTy)) {
75 auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
76 rewriter.getUnitAttr());
77 auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
78 rewriter.getDenseI64ArrayAttr(blockSize));
79 auto castOp = UnrealizedConversionCastOp::create(
80 rewriter, loc, destTy, srcs,
81 ArrayRef<NamedAttribute>({attr, blkAttr}));
82 return castOp.getResult(0);
83 }
84
85 llvm_unreachable("Unexpected destTy.");
86 return Value();
87 }
88
89 /// Emulate the the pack behavior using extract_strided_slice for VectorType
90 /// values and unrealized_conversion_cast for TensorDescType values.
91 SmallVector<Value> pack(Value src, TypeRange destTypes,
92 ArrayRef<int64_t> blockSize, Location loc,
93 PatternRewriter &rewriter) const {
94 if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
95 return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
96 blockSize);
97 }
98
99 if (isa<xegpu::TensorDescType>(src.getType())) {
100 auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
101 rewriter.getUnitAttr());
102 auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
103 rewriter.getDenseI64ArrayAttr(blockSize));
104 auto castOp = UnrealizedConversionCastOp::create(
105 rewriter, loc, destTypes, src,
106 ArrayRef<NamedAttribute>({attr, blkAttr}));
107 return castOp.getResults();
108 }
109
110 llvm_unreachable("Unexpected src type.");
111 return SmallVector<Value>();
112 }
113
114private:
115 const char *const packAttrName = "__xegpu_blocking_pack__";
116 const char *const unpackAttrName = "__xegpu_blocking_unpack__";
117 const char *const blockAttrName = "__xegpu_blocking_tile_shape__";
118
120};
121
122// Generic helper function for unrolling operations with offsets.
123//
124// Iterates over tile offsets within the tensor descriptor shape and calls
125// the provided createOp function for each computed offset. This is used by
126// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
127// have explicit offsets that need to be adjusted for each unrolled tile.
128SmallVector<Value> computeUnrolledOffsets(
129 SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
130 ArrayRef<int64_t> targetShape,
131 const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
132 Location loc, PatternRewriter &rewriter) {
133 int64_t rank = tdescTy.getRank();
134 ArrayRef<int64_t> shape = tdescTy.getShape();
135
136 auto addi = [&](OpFoldResult a, int64_t b) -> Value {
137 std::optional<int64_t> maybeInt = getConstantIntValue(a);
138 if (maybeInt) {
139 return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
140 } else {
141 auto aV = llvm::cast<Value>(a);
142 auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
143 return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
144 }
145 };
146
147 SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
148 llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
149 auto validIdxes =
150 llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
151
152 SmallVector<Value> newOps;
153 for (SmallVector<int64_t> offsets :
154 StaticTileOffsetRange(shape, targetShape)) {
155
156 for (auto [idx, oldOff, offset] :
157 llvm::zip(validIdxes, oldOffsets, offsets))
158 mixedOffsets[idx] = addi(oldOff, offset);
159
160 auto newOp = createOp(mixedOffsets);
161 newOps.push_back(newOp);
162 }
163 return newOps;
164}
165
166struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
167 using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
168 LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
169 PatternRewriter &rewriter) const override {
170 Location loc = op.getLoc();
171 xegpu::TensorDescType tdescTy = op.getType();
172
173 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
174 if (!targetShape)
175 return failure();
176
177 SmallVector<Value> newOps;
178
179 auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
180 bool hasOffsets = op.getMixedOffsets().size() != 0;
181 if (!hasOffsets) {
182 auto newOp = xegpu::CreateNdDescOp::create(
183 rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
184 op.getMixedStrides());
185 newOps.push_back(newOp);
186 } else {
187 auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
188 return xegpu::CreateNdDescOp::create(
189 rewriter, loc, newTdescTy, op.getSource(), offsets,
190 op.getMixedSizes(), op.getMixedStrides());
191 };
192
193 newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
194 *targetShape, createOp, loc, rewriter);
195 }
196 Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
197 rewriter.replaceOp(op, castOp);
198
199 return success();
200 }
201};
202
203struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
204 using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
205 LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
206 PatternRewriter &rewriter) const override {
207 Location loc = op.getLoc();
208 xegpu::TensorDescType tdescTy = op.getTensorDescType();
209
210 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
211 if (!targetShape)
212 return failure();
213
214 SmallVector<Type> convertedTdescTypes =
215 getUnrolledTypes(tdescTy, *targetShape);
216 SmallVector<Value> convertedTdesc = pack(
217 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
218
219 SmallVector<Value> newOps;
220 for (auto t : convertedTdesc) {
221 auto newOp = xegpu::UpdateNdOffsetOp::create(
222 rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
223 newOps.push_back(newOp);
224 }
225 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
226 rewriter.replaceOp(op, castOp);
227 return success();
228 }
229};
230
/// Unrolls xegpu.prefetch_nd into one prefetch per tile. With explicit
/// offsets, a single shared descriptor is combined with per-tile offsets;
/// without, one prefetch is issued per unpacked tile descriptor.
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    // Each unrolled op covers one instruction-sized tile, so the inst_data
    // field of the layout no longer applies and is dropped.
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    // When offsets are present, all tiles reuse a single descriptor type
    // (returnSingleType) and vary only in their offsets.
    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);

    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    if (!hasOffsets) {
      // One prefetch per unpacked tile descriptor.
      for (auto t : convertedTdesc)
        xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
                                    xegpu::dropInstDataOnAttrs(op->getAttrs()));
    } else {
      // One prefetch per unrolled offset, reusing the single descriptor.
      auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
        xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr(), layout);
        // return dummy Value to satisfy function's signature
        return nullptr;
      };

      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createPrefetch, loc, rewriter);
    }

    rewriter.eraseOp(op);
    return success();
  }
};
275
/// Unrolls xegpu.load_nd into per-tile loads whose results are re-assembled
/// into the original vector shape via unpack().
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    // Per-tile ops no longer need the inst_data field of the layout.
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    // Result type of each per-tile load: same element type, blocked shape.
    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    // With offsets, one shared descriptor type is used for all tiles.
    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);

    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
    SmallVector<Value> newOps;

    if (!hasOffsets) {
      // One load per unpacked tile descriptor.
      for (auto t : convertedTdescs) {
        auto newOp =
            xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
                                    xegpu::dropInstDataOnAttrs(op->getAttrs()));
        newOps.push_back(newOp);
      }
    } else {
      // One load per unrolled offset, reusing the single descriptor.
      auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
        return xegpu::LoadNdOp::create(
            rewriter, loc, newValueTy, convertedTdescs[0], offsets,
            op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
            op.getL2HintAttr(), op.getL3HintAttr(), layout);
      };
      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                      *targetShape, createLoad, loc, rewriter);
    }

    // Reassemble the tile results into the original full-sized vector.
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

    rewriter.replaceOp(op, castOp);
    return success();
  }
};
329
/// Unrolls xegpu.store_nd into per-tile stores, pairing each unrolled value
/// tile with its descriptor (or with a shifted offset when offsets exist).
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    // Per-tile ops no longer need the inst_data field of the layout.
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    // With offsets, one shared descriptor type serves all tiles.
    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);

    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    if (!hasOffsets) {
      // Pair value tiles with descriptor tiles one-to-one.
      for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
        xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
                                 op.getL2HintAttr(), op.getL3HintAttr());
    } else {
      // The callback is invoked once per tile in iteration order, so the
      // captured valueIndex walks convertedValues in lockstep with the
      // offsets produced by computeUnrolledOffsets.
      size_t valueIndex = 0;
      auto createStore = [&](SmallVector<OpFoldResult> offsets) {
        xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
                                 convertedTdescs[0], offsets,
                                 op.getL1HintAttr(), op.getL2HintAttr(),
                                 op.getL3HintAttr(), layout);
        // return dummy Value to satisfy function's signature
        return nullptr;
      };

      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createStore, loc, rewriter);
    }

    rewriter.eraseOp(op);
    return success();
  }
};
381
/// Unrolls xegpu.dpas into a grid of smaller DPAS ops blocked by [M, K, N].
/// For each (i, j) output tile, partial products along K are accumulated by
/// chaining each DPAS result into the next as the accumulator operand.
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // expecting every operands is a 2D Vector
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // A vector of 3 elements should be returned, representing M, K, N
    // respectively.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    // Block shapes per operand: A is MxK, B is KxN, C/result is MxN.
    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    // Split `val` into tiles of `blockSize`; if it already matches the block
    // (grid product == 1), return it unchanged as a single-element list.
    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();  // accumulator is optional

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Skip the operation if every operand has an invalid blocking size (empty)
    // or if the original shape matches the blocking size (size == 1).
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    // Tile grid extents along each GEMM dimension.
    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    // A tiles are laid out row-major over (m, k), B over (k, n), and the
    // C/result tiles over (m, n) — hence the linearized index expressions.
    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with acc

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          // Chain: each partial product feeds the next step's accumulator.
          tmpC =
              xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
                                    xegpu::dropInstDataOnAttrs(op->getAttrs()));
        }
        newOps.push_back(tmpC);
      }
    }
    // Reassemble the MxN output tiles into the full result vector.
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
479
480struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
481 using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
482 LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
483 PatternRewriter &rewriter) const override {
484
485 Location loc = op.getLoc();
486 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
487 xegpu::TensorDescType tdescTy = op.getTensorDescType();
488
489 // TODO: handle the unstructure source case (!tdesTy)
490 if (!tdescTy || op.getOffsets())
491 return failure();
492
493 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
494 if (!targetShape)
495 return failure();
496
497 SmallVector<int64_t> targetMaskShape(*targetShape);
498 int originalChunkSize = op.getChunkSize().value_or(1);
499
500 VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
501
502 Type elemTy = tdescTy.getElementType();
503 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
504
505 SmallVector<Type> convertedTdescTypes =
506 getUnrolledTypes(tdescTy, *targetShape);
507 SmallVector<Value> convertedTdescs = pack(
508 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
509
510 SmallVector<Type> convertedMaskTypes;
511 SmallVector<Value> convertedMasks;
512
513 if (originalChunkSize > 1) {
514 targetMaskShape.pop_back();
515 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
516 int64_t blockedChunkSize = targetShape->back();
517 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
518
519 // the mask is reused across the chunk_size dimension
520 for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
521 loc, rewriter))
522 convertedMasks.append(numNewChunks, mask);
523
524 newValueTy = valueTy.cloneWith(*targetShape, elemTy);
525 } else {
526 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
527 convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
528 loc, rewriter);
529 }
530
531 SmallVector<Value> newOps;
532 for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
533 auto newOp = xegpu::LoadGatherOp::create(
534 rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
535 op.getL2HintAttr(), op.getL3HintAttr());
536 newOps.push_back(newOp);
537 }
538
539 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
540 rewriter.replaceOp(op, castOp);
541 return success();
542 }
543};
544
/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
/// load).
/// It unrolls the offsets and mask operands accordingly, and creates multiple
/// LoadGatherOp with the unrolled operands.
struct UnrollLoadGatherOpWithOffset
    : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    // Only handle the case where offsets are present (scattered load)
    if (!offsets)
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    // Unroll mask and offsets with correct shape
    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
    Type elemTy = valueTy.getElementType();
    VectorType newValueTy = VectorType::get(*targetShape, elemTy);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;

    if (chunkSize > 1) {
      // For chunked loads, mask and offsets have one less dimension
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      // Each unrolled op loads only blockedChunkSize elements per lane.
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      // The same mask applies to every chunk split from one base tile.
      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      // Each chunk reads at base + i * blockedChunkSize; broadcast the
      // scalar increment to the offsets vector shape before adding.
      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    // Per-tile ops no longer need the inst_data field of the layout.
    auto layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    // One load per (offsets, mask) pair; results reassembled via unpack().
    SmallVector<Value> newOps;
    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, op.getSource(), o, m,
          rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr(), layout);
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
643
/// This pattern handles the unrolling of StoreScatterOp with offsets (scattered
/// store).
/// It unrolls the offsets and mask operands accordingly, and creates multiple
/// StoreScatterOp with the unrolled operands.
struct UnrollStoreScatterOpWithOffsets
    : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    // Only handle the case where offsets are present (scattered store)
    if (!offsets)
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    SmallVector<int64_t> targetMaskShape(*targetShape);
    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;

    if (chunkSize > 1) {
      // For chunked stores, mask and offsets have one less dimension than
      // the stored value (no chunk dimension).
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      // Each unrolled op stores only blockedChunkSize elements per lane.
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      // The same mask applies to every chunk split from one base tile.
      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      // Each chunk writes at base + i * blockedChunkSize; broadcast the
      // scalar increment to the offsets vector shape before adding.
      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    // Per-tile ops no longer need the inst_data field of the layout.
    auto layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    // One store per (value, offsets, mask) triple; the op has no results.
    for (auto [v, o, m] :
         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                    rewriter.getI64IntegerAttr(chunkSize),
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr(), layout);
    }

    rewriter.eraseOp(op);
    return success();
  }
};
741
742struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
743 using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
744 LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
745 PatternRewriter &rewriter) const override {
746 Location loc = op.getLoc();
747 xegpu::TensorDescType tdescTy = op.getTensorDescType();
748
749 // TODO: handle the unstructure source case (!tdesTy)
750 if (!tdescTy || op.getOffsets())
751 return failure();
752
753 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
754 if (!targetShape)
755 return failure();
756
757 SmallVector<Type> convertedTdescTypes =
758 getUnrolledTypes(tdescTy, *targetShape);
759 SmallVector<Value> convertedTdesc = pack(
760 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
761
762 for (auto t : convertedTdesc)
763 xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t,
764 xegpu::dropInstDataOnAttrs(op->getAttrs()));
765
766 rewriter.eraseOp(op);
767 return success();
768 }
769};
770
/// Unrolls the descriptor-based form of xegpu.store (scatter, no explicit
/// offsets) into one store per tile. With chunk_size > 1 the mask is reused
/// across the chunks split from each base tile.
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // TODO: handle the unstructure source case (!tdesTy)
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int originalChunkSize = op.getChunkSize().value_or(1);

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      // The mask has no chunk dimension; unroll it with the trailing
      // (chunk) dimension stripped from the target shape.
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);

      // the mask is reused across the chunk_size dimension
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    // NOTE(review): indexed loop assumes convertedTdescs (and, when a mask is
    // present, convertedMasks) have at least convertedValues.size() entries.
    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};
834
835struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
836 using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
837 LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
838 PatternRewriter &rewriter) const override {
839 Location loc = op.getLoc();
840 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
841 assert(valueTy && "the value type must be vector type!");
842
843 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
844 if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
845 return failure();
846
847 Type elemTy = valueTy.getElementType();
848 ArrayRef<int64_t> shape = valueTy.getShape();
849 auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
850
851 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
852
853 SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
855 for (SmallVector<int64_t> offsets :
856 StaticTileOffsetRange(shape, *targetShape)) {
857 auto adds = xegpu::addElementwise(
858 rewriter, loc, mixedOffsets,
859 getAsIndexOpFoldResult(op.getContext(), offsets));
860 offsetsList.push_back(adds);
861 }
862
863 SmallVector<Value> newOps;
864 layout = layout.dropInstData();
865 for (SmallVector<OpFoldResult> offsets : offsetsList) {
866 auto newOp = xegpu::LoadMatrixOp::create(
867 rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
868 newOps.push_back(newOp);
869 }
870 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
871 rewriter.replaceOp(op, castOp);
872 return success();
873 }
874};
875
876struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
877 using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
878 LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
879 PatternRewriter &rewriter) const override {
880 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
881 if (!targetShape)
882 return failure();
883
884 Location loc = op.getLoc();
885 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
886 assert(valueTy && "the value type must be vector type!");
887 ArrayRef<int64_t> shape = valueTy.getShape();
888 auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
889
890 SmallVector<Type> convertedValTypes =
891 getUnrolledTypes(valueTy, *targetShape);
892 SmallVector<Value> convertedValues =
893 pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);
894
895 SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
897 for (SmallVector<int64_t> offsets :
898 StaticTileOffsetRange(shape, *targetShape)) {
899 auto adds = xegpu::addElementwise(
900 rewriter, loc, mixedOffsets,
901 getAsIndexOpFoldResult(op.getContext(), offsets));
902 offsetsList.push_back(adds);
903 }
904
905 for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
906 xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
907 layout.dropInstData());
908
909 rewriter.eraseOp(op);
910 return success();
911 }
912};
913
914/// UnrollConvertLayoutOp pattern for unrolling xegpu::ConvertLayoutOp
915/// operations. It first check whether the convert layout op has valid layouts
916/// after inst_data stripped. If it does, it will unroll the vector into
917/// multiple smaller vectors according to the target shape, and create multiple
918/// ConvertLayoutOp with the unrolled vectors and the stripped layouts.
919struct UnrollConvertLayoutOp : public UnrollPattern<xegpu::ConvertLayoutOp> {
920 using UnrollPattern<xegpu::ConvertLayoutOp>::UnrollPattern;
921 LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
922 PatternRewriter &rewriter) const override {
923 Location loc = op.getLoc();
924 Type valType = op.getType();
925
926 xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
927 xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
928 if (!inputLayout || !targetLayout)
929 return rewriter.notifyMatchFailure(op, "missing layout attributes.");
930
931 if (valType.isIntOrFloat()) {
932 rewriter.replaceOp(op, op.getSource());
933 assert(!inputLayout.dropInstData() && !targetLayout.dropInstData() &&
934 "unexpected layout attributes for scalar type");
935 return success();
936 }
937
938 if (inputLayout.getEffectiveInstDataAsInt().empty() ||
939 targetLayout.getEffectiveInstDataAsInt().empty())
940 return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");
941
942 inputLayout = inputLayout.dropInstData();
943 targetLayout = targetLayout.dropInstData();
944
945 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
946 assert(valueTy && "the value type must be vector type!");
947
948 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
949 if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
950 return failure();
951
952 Value newSource = op.getSource();
953 SmallVector<Value> newOps;
954 if (inputLayout && targetLayout) {
955 SmallVector<Type> convertedValTypes =
956 getUnrolledTypes(valueTy, *targetShape);
957 SmallVector<Value> convertedValues =
958 pack(op.getOperand(), convertedValTypes, *targetShape, loc, rewriter);
959 for (auto [v, t] : llvm::zip(convertedValues, convertedValTypes)) {
960 auto newOp = xegpu::ConvertLayoutOp::create(rewriter, loc, t, v,
961 inputLayout, targetLayout);
962 newOps.push_back(newOp);
963 }
964 newSource = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
965 }
966
967 rewriter.replaceOp(op, newSource);
968 return success();
969 }
970};
971
972} // namespace
973
976 patterns
977 .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
978 UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollLoadGatherOp,
979 UnrollStoreScatterOp, UnrollPrefetchOp, UnrollLoadMatrixOp,
980 UnrollStoreMatrixOp, UnrollLoadGatherOpWithOffset,
981 UnrollStoreScatterOpWithOffsets, UnrollConvertLayoutOp>(
982 patterns.getContext(), options);
983}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
UnitAttr getUnitAttr()
Definition Builders.cpp:102
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll xegpu operations to a smaller shapes.
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
SmallVector< OpFoldResult > addElementwise(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with same length.
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options to control the XeGPU unrolling.
Definition Transforms.h:29
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.