MLIR 23.0.0git
XeGPUUnroll.cpp
Go to the documentation of this file.
1//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains patterns for unrolling XeGPU operations. It follows a
10// similar concept and design as vector unroll patterns, serving as a complement
11// to them.
12//
13//===----------------------------------------------------------------------===//
14
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/DebugLog.h"
22
23namespace mlir {
24namespace xegpu {
25#define GEN_PASS_DEF_XEGPUUNROLL
26#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
27} // namespace xegpu
28} // namespace mlir
29
30#define DEBUG_TYPE "xegpu-unroll"
31
32using namespace mlir;
33
34namespace {
35
36template <typename SourceOp>
37struct UnrollPattern : public OpRewritePattern<SourceOp> {
38 UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
39 PatternBenefit benefit = 1)
40 : OpRewritePattern<SourceOp>(context, benefit), options(options) {}
41
42protected:
43 /// Return the target shape for the given `op`. Return std::nullopt if the
44 /// op shouldn't be or cannot be unrolled.
45 std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
46 LDBG() << "Get unroll shape for: " << *op;
47
48 if (options.filterConstraint && failed(options.filterConstraint(op))) {
49 LDBG() << "--no filter constraint -> BAIL";
50 return std::nullopt;
51 }
52
53 assert(options.nativeShape &&
54 "expects the native shape for native shape call back function.");
55 auto nativeShape = options.nativeShape(op);
56 return nativeShape;
57 }
58
59 SmallVector<Type> getUnrolledTypes(ShapedType type,
60 ArrayRef<int64_t> tileShape,
61 bool returnSingleType = false) const {
62 return options.getUnrolledTypes(type, tileShape, returnSingleType);
63 }
64
65 /// Emulate the the unpack behavior using insert_strided_slice for VectorType
66 /// values and unrealized_conversion_cast for TensorDescType values.
67 Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
68 Location loc, PatternRewriter &rewriter) const {
69 if (auto vecTy = dyn_cast<VectorType>(destTy)) {
70 auto shape = vecTy.getShape();
71 return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
72 }
73
74 if (isa<xegpu::TensorDescType>(destTy)) {
75 auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
76 rewriter.getUnitAttr());
77 auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
78 rewriter.getDenseI64ArrayAttr(blockSize));
79 auto castOp = UnrealizedConversionCastOp::create(
80 rewriter, loc, destTy, srcs,
81 ArrayRef<NamedAttribute>({attr, blkAttr}));
82 return castOp.getResult(0);
83 }
84
85 llvm_unreachable("Unexpected destTy.");
86 return Value();
87 }
88
89 /// Emulate the the pack behavior using extract_strided_slice for VectorType
90 /// values and unrealized_conversion_cast for TensorDescType values.
91 SmallVector<Value> pack(Value src, TypeRange destTypes,
92 ArrayRef<int64_t> blockSize, Location loc,
93 PatternRewriter &rewriter) const {
94 if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
95 return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
96 blockSize);
97 }
98
99 if (isa<xegpu::TensorDescType>(src.getType())) {
100 auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
101 rewriter.getUnitAttr());
102 auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
103 rewriter.getDenseI64ArrayAttr(blockSize));
104 auto castOp = UnrealizedConversionCastOp::create(
105 rewriter, loc, destTypes, src,
106 ArrayRef<NamedAttribute>({attr, blkAttr}));
107 return castOp.getResults();
108 }
109
110 llvm_unreachable("Unexpected src type.");
111 return SmallVector<Value>();
112 }
113
114 /// Helper to pack operands for DPAS-like operations with early return if
115 /// no unrolling is needed.
116 SmallVector<Value> packOperandForDpas(Value operand,
117 ArrayRef<int64_t> blockSize,
118 Location loc,
119 PatternRewriter &rewriter) const {
120 auto vecType = cast<VectorType>(operand.getType());
121 std::optional<SmallVector<int64_t>> grids =
122 computeShapeRatio(vecType.getShape(), blockSize);
123 assert(grids && "Expecting grids to be computed.");
124 auto numNewOps = computeProduct(*grids);
125 if (numNewOps == 1)
126 return SmallVector<Value>({operand});
127 VectorType newVecTy =
128 vecType.cloneWith(blockSize, vecType.getElementType());
129 SmallVector<Type> convertedTypes(numNewOps, newVecTy);
130 return pack(operand, convertedTypes, blockSize, loc, rewriter);
131 }
132
133private:
134 const char *const packAttrName = "__xegpu_blocking_pack__";
135 const char *const unpackAttrName = "__xegpu_blocking_unpack__";
136 const char *const blockAttrName = "__xegpu_blocking_tile_shape__";
137
139};
140
141// Generic helper function for unrolling operations with offsets.
142//
143// Iterates over tile offsets within the tensor descriptor shape and calls
144// the provided createOp function for each computed offset. This is used by
145// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
146// have explicit offsets that need to be adjusted for each unrolled tile.
147SmallVector<Value> computeUnrolledOffsets(
148 SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
149 ArrayRef<int64_t> targetShape,
150 const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
151 Location loc, PatternRewriter &rewriter) {
152 int64_t rank = tdescTy.getRank();
153 ArrayRef<int64_t> shape = tdescTy.getShape();
154
155 auto addi = [&](OpFoldResult a, int64_t b) -> Value {
156 std::optional<int64_t> maybeInt = getConstantIntValue(a);
157 if (maybeInt) {
158 return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
159 } else {
160 auto aV = llvm::cast<Value>(a);
161 auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
162 return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
163 }
164 };
165
166 SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
167 llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
168 auto validIdxes =
169 llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
170
171 SmallVector<Value> newOps;
172 for (SmallVector<int64_t> offsets :
173 StaticTileOffsetRange(shape, targetShape)) {
174
175 for (auto [idx, oldOff, offset] :
176 llvm::zip(validIdxes, oldOffsets, offsets))
177 mixedOffsets[idx] = addi(oldOff, offset);
178
179 auto newOp = createOp(mixedOffsets);
180 newOps.push_back(newOp);
181 }
182 return newOps;
183}
184
185struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
186 using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
187 LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
188 PatternRewriter &rewriter) const override {
189 Location loc = op.getLoc();
190 xegpu::TensorDescType tdescTy = op.getType();
191
192 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
193 if (!targetShape)
194 return failure();
195
196 SmallVector<Value> newOps;
197
198 auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
199 auto newOp =
200 xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy, op.getSource(),
201 op.getMixedSizes(), op.getMixedStrides());
202 newOps.push_back(newOp);
203 Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
204 rewriter.replaceOp(op, castOp);
205
206 return success();
207 }
208};
209
210struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
211 using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
212 LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
213 PatternRewriter &rewriter) const override {
214 Location loc = op.getLoc();
215 xegpu::TensorDescType tdescTy = op.getTensorDescType();
216
217 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
218 if (!targetShape)
219 return failure();
220
221 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
222 if (layout)
223 layout = layout.dropInstData();
224
225 SmallVector<Type> convertedTdescTypes =
226 getUnrolledTypes(tdescTy, *targetShape, /*returnSingleType*/ true);
227
228 SmallVector<Value> convertedTdesc = pack(
229 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
230
231 auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
232 xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
233 op.getL1HintAttr(), op.getL2HintAttr(),
234 op.getL3HintAttr(), layout);
235 // return dummy Value to satisfy function's signature
236 return nullptr;
237 };
238
239 computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
240 createPrefetch, loc, rewriter);
241
242 rewriter.eraseOp(op);
243 return success();
244 }
245};
246
247struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
248 using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
249 LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
250 PatternRewriter &rewriter) const override {
251
252 Location loc = op.getLoc();
253 VectorType valueTy = op.getType();
254 xegpu::TensorDescType tdescTy = op.getTensorDescType();
255
256 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
257 if (!targetShape)
258 return failure();
259
260 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
261 if (layout)
262 layout = layout.dropInstData();
263
264 Type elemTy = tdescTy.getElementType();
265 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
266
267 SmallVector<Type> convertedTdescTypes =
268 getUnrolledTypes(tdescTy, *targetShape, /*returnSingleType*/ true);
269
270 SmallVector<Value> convertedTdescs = pack(
271 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
272 SmallVector<Value> newOps;
273
274 auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
275 return xegpu::LoadNdOp::create(
276 rewriter, loc, newValueTy, convertedTdescs[0], offsets,
277 op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
278 op.getL2HintAttr(), op.getL3HintAttr(), layout);
279 };
280 newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
281 createLoad, loc, rewriter);
282
283 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
284
285 rewriter.replaceOp(op, castOp);
286 return success();
287 }
288};
289
290struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
291 using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
292 LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
293 PatternRewriter &rewriter) const override {
294 Location loc = op.getLoc();
295 VectorType valueTy = op.getValueType();
296 xegpu::TensorDescType tdescTy = op.getTensorDescType();
297
298 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
299 if (!targetShape)
300 return failure();
301
302 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
303 if (layout)
304 layout = layout.dropInstData();
305
306 SmallVector<Type> convertedValTypes =
307 getUnrolledTypes(valueTy, *targetShape);
308 SmallVector<Type> convertedTdescTypes =
309 getUnrolledTypes(tdescTy, *targetShape, /*returnSingleType*/ true);
310
311 SmallVector<Value> convertedTdescs = pack(
312 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
313
314 SmallVector<Value> convertedValues =
315 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
316
317 size_t valueIndex = 0;
318 auto createStore = [&](SmallVector<OpFoldResult> offsets) {
319 xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
320 convertedTdescs[0], offsets, op.getL1HintAttr(),
321 op.getL2HintAttr(), op.getL3HintAttr(), layout);
322 // return dummy Value to satisfy function's signature
323 return nullptr;
324 };
325
326 computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
327 createStore, loc, rewriter);
328
329 rewriter.eraseOp(op);
330 return success();
331 }
332};
333
/// Unrolls xegpu::DpasOp into a grid of [M, K, N]-tile DPAS ops, chaining
/// partial products along the K dimension into the accumulator.
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expect a 3-element target shape [M, K, N] describing the native tile.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    // Tile shapes for A (MxK), B (KxN), and the accumulator/result (MxN).
    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    // Split each operand into tile-sized sub-vectors (row-major tile order).
    SmallVector<Value> aVals = packOperandForDpas(a, aBlockSize, loc, rewriter);
    SmallVector<Value> bVals = packOperandForDpas(b, bBlockSize, loc, rewriter);
    SmallVector<Value> cVals;
    if (c)
      cVals = packOperandForDpas(c, cBlockSize, loc, rewriter);

    // Bail if any operand failed to pack, or if no operand was actually
    // split (everything already fits a single tile -> nothing to unroll).
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    // Number of tiles along each dimension of the tile grid.
    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();
    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j];

        // Reduce over K: each new DPAS accumulates into the previous result.
        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC =
              xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
                                    xegpu::dropInstDataOnAttrs(op->getAttrs()));
        }
        newOps.push_back(tmpC);
      }
    }
    // Recombine the per-tile results into the original MxN result vector.
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
402
/// Unrolls xegpu::DpasMxOp (scaled DPAS) into [M, K, N]-tile ops, including
/// the per-tile A/B scale operands blocked by [M, S] and [S, N] respectively.
struct UnrollDpasMxOp : public UnrollPattern<xegpu::DpasMxOp> {
  using UnrollPattern<xegpu::DpasMxOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasMxOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expect a 4-element target shape [M, K, N, S]; S blocks the scales.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 4)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];
    auto S = (*targetShape)[3];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};
    int64_t aScaleBlockSize[2] = {M, S};
    int64_t bScaleBlockSize[2] = {S, N};

    auto a = op.getA();
    auto b = op.getB();
    auto c = op.getAcc();
    // Scales may be absent or non-vector; dyn_cast yields null in that case.
    auto ascale = dyn_cast<TypedValue<VectorType>>(op.getScaleA());
    auto bscale = dyn_cast<TypedValue<VectorType>>(op.getScaleB());

    // Split every present operand into tile-sized sub-vectors.
    SmallVector<Value> aVals = packOperandForDpas(a, aBlockSize, loc, rewriter);
    SmallVector<Value> bVals = packOperandForDpas(b, bBlockSize, loc, rewriter);
    SmallVector<Value> cVals;
    if (c)
      cVals = packOperandForDpas(c, cBlockSize, loc, rewriter);
    SmallVector<Value> aScaleVals;
    if (ascale)
      aScaleVals = packOperandForDpas(ascale, aScaleBlockSize, loc, rewriter);
    SmallVector<Value> bScaleVals;
    if (bscale)
      bScaleVals = packOperandForDpas(bscale, bScaleBlockSize, loc, rewriter);

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    // Tile-grid extents derived from the full operand shapes.
    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();
    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    xegpu::DpasMxOp newDpasMxOp;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j];

        // Reduce over K, chaining the accumulator through each new op.
        // Operand order: A, B, [acc], [scaleA], [scaleB].
        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);
          if (ascale)
            operands.push_back(aScaleVals[i * kIters + k]);
          if (bscale)
            operands.push_back(bScaleVals[k * nIters + j]);

          newDpasMxOp = xegpu::DpasMxOp::create(
              rewriter, loc, vecTy, operands,
              xegpu::dropInstDataOnAttrs(op->getAttrs()));
          tmpC = newDpasMxOp.getResult();
        }
        // After the K loop, newDpasMxOp holds the final (i, j) tile result.
        newOps.push_back(newDpasMxOp);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
482
/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
/// load).
/// It unrolls the offsets and mask operands accordingly, and creates multiple
/// LoadGatherOp with the unrolled operands.
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    // chunk_size > 1 means each lane loads a contiguous chunk of elements.
    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    // Unroll mask and offsets with correct shape
    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
    Type elemTy = valueTy.getElementType();
    VectorType newValueTy = VectorType::get(*targetShape, elemTy);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;

    if (chunkSize > 1) {
      // For chunked loads, mask and offsets have one less dimension
      targetMaskShape.pop_back();
      // The chunk itself is blocked by the last dim of the target shape; the
      // unrolled ops carry the smaller (blocked) chunk size.
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      // Each base mask serves all chunk splits of the same lanes.
      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      // Each chunk split reads at base offset + i * blockedChunkSize.
      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      // Unchunked: masks and offsets are split the same way as the value.
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    auto layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    // One LoadGather per (offsets, mask) pair; order matches pack order.
    SmallVector<Value> newOps;
    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, op.getSource(), o, m,
          rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr(), layout);
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
576
/// This pattern handles the unrolling of StoreScatterOp with offsets (scattered
/// store).
/// It unrolls the offsets and mask operands accordingly, and creates multiple
/// StoreScatterOp with the unrolled operands.
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    // chunk_size > 1 means each lane stores a contiguous chunk of elements.
    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    SmallVector<int64_t> targetMaskShape(*targetShape);
    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;

    if (chunkSize > 1) {
      // For chunked stores, mask and offsets have one less dimension; the
      // chunk is blocked by the last dim of the target shape.
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      // Each base mask serves all chunk splits of the same lanes.
      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      // Each chunk split writes at base offset + i * blockedChunkSize.
      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      // Unchunked: masks and offsets are split the same way as the value.
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    // The stored value is always split by the full target shape.
    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    auto layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();

    // One StoreScatter per (value, offsets, mask) triple; order matches
    // pack order, so the triples line up.
    for (auto [v, o, m] :
         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                    rewriter.getI64IntegerAttr(chunkSize),
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr(), layout);
    }

    rewriter.eraseOp(op);
    return success();
  }
};
669
670struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
671 using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
672 LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
673 PatternRewriter &rewriter) const override {
674 Location loc = op.getLoc();
675 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
676 assert(valueTy && "the value type must be vector type!");
677
678 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
679 if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
680 return failure();
681
682 Type elemTy = valueTy.getElementType();
683 ArrayRef<int64_t> shape = valueTy.getShape();
684 auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
685
686 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
687
688 SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
690 for (SmallVector<int64_t> offsets :
691 StaticTileOffsetRange(shape, *targetShape)) {
692 auto adds = xegpu::addElementwise(
693 rewriter, loc, mixedOffsets,
694 getAsIndexOpFoldResult(op.getContext(), offsets));
695 offsetsList.push_back(adds);
696 }
697
698 SmallVector<Value> newOps;
699 layout = layout.dropInstData();
700 for (SmallVector<OpFoldResult> offsets : offsetsList) {
701 auto newOp = xegpu::LoadMatrixOp::create(
702 rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
703 newOps.push_back(newOp);
704 }
705 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
706 rewriter.replaceOp(op, castOp);
707 return success();
708 }
709};
710
711struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
712 using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
713 LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
714 PatternRewriter &rewriter) const override {
715 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
716 if (!targetShape)
717 return failure();
718
719 Location loc = op.getLoc();
720 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
721 assert(valueTy && "the value type must be vector type!");
722 ArrayRef<int64_t> shape = valueTy.getShape();
723 auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
724
725 SmallVector<Type> convertedValTypes =
726 getUnrolledTypes(valueTy, *targetShape);
727 SmallVector<Value> convertedValues =
728 pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);
729
730 SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
732 for (SmallVector<int64_t> offsets :
733 StaticTileOffsetRange(shape, *targetShape)) {
734 auto adds = xegpu::addElementwise(
735 rewriter, loc, mixedOffsets,
736 getAsIndexOpFoldResult(op.getContext(), offsets));
737 offsetsList.push_back(adds);
738 }
739
740 for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
741 xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
742 layout.dropInstData());
743
744 rewriter.eraseOp(op);
745 return success();
746 }
747};
748
/// UnrollConvertLayoutOp pattern for unrolling xegpu::ConvertLayoutOp
/// operations. It first check whether the convert layout op has valid layouts
/// after inst_data stripped. If it does, it will unroll the vector into
/// multiple smaller vectors according to the target shape, and create multiple
/// ConvertLayoutOp with the unrolled vectors and the stripped layouts.
struct UnrollConvertLayoutOp : public UnrollPattern<xegpu::ConvertLayoutOp> {
  using UnrollPattern<xegpu::ConvertLayoutOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type valType = op.getType();

    xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
    xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
    if (!inputLayout || !targetLayout)
      return rewriter.notifyMatchFailure(op, "missing layout attributes.");

    // Scalar convert_layout is a no-op: forward the source directly.
    // NOTE(review): the assert runs after replaceOp — it only sanity-checks
    // that scalar ops carry no inst_data; the replacement already happened.
    if (valType.isIntOrFloat()) {
      rewriter.replaceOp(op, op.getSource());
      assert(!inputLayout.dropInstData() && !targetLayout.dropInstData() &&
             "unexpected layout attributes for scalar type");
      return success();
    }

    // Only ops where both sides carry inst_data are targets for unrolling.
    if (inputLayout.getEffectiveInstDataAsInt().empty() ||
        targetLayout.getEffectiveInstDataAsInt().empty())
      return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");

    inputLayout = inputLayout.dropInstData();
    targetLayout = targetLayout.dropInstData();

    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    assert(valueTy && "the value type must be vector type!");

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
      return failure();

    Value newSource = op.getSource();
    SmallVector<Value> newOps;
    // NOTE(review): both layouts were checked non-null above, so this guard
    // is always true; kept as-is to preserve behavior byte-for-byte.
    if (inputLayout && targetLayout) {
      SmallVector<Type> convertedValTypes =
          getUnrolledTypes(valueTy, *targetShape);
      SmallVector<Value> convertedValues =
          pack(op.getOperand(), convertedValTypes, *targetShape, loc, rewriter);
      // One stripped-layout convert per tile, then recombine.
      for (auto [v, t] : llvm::zip(convertedValues, convertedValTypes)) {
        auto newOp = xegpu::ConvertLayoutOp::create(rewriter, loc, t, v,
                                                    inputLayout, targetLayout);
        newOps.push_back(newOp);
      }
      newSource = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    }

    rewriter.replaceOp(op, newSource);
    return success();
  }
};
806
807} // namespace
808
811 patterns.add<UnrollCreateNdOp, UnrollPrefetchNdOp, UnrollLoadNdOp,
812 UnrollStoreNdOp, UnrollDpasOp, UnrollDpasMxOp,
813 UnrollLoadMatrixOp, UnrollStoreMatrixOp, UnrollLoadGatherOp,
814 UnrollStoreScatterOp, UnrollConvertLayoutOp>(
815 patterns.getContext(), options);
816}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
UnitAttr getUnitAttr()
Definition Builders.cpp:102
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll xegpu operations to a smaller shapes.
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
SmallVector< OpFoldResult > addElementwise(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with same length.
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options to control the XeGPU unrolling.
Definition Transforms.h:29
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.