MLIR 23.0.0git
XeGPUUnroll.cpp
Go to the documentation of this file.
//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for unrolling XeGPU operations. They follow the
// same concept and design as the vector unroll patterns and complement them.
//
//===----------------------------------------------------------------------===//

22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/DebugLog.h"
24
25namespace mlir {
26namespace xegpu {
27#define GEN_PASS_DEF_XEGPUUNROLL
28#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
29} // namespace xegpu
30} // namespace mlir
31
32#define DEBUG_TYPE "xegpu-unroll"
33
34using namespace mlir;
35
36namespace {
37
38template <typename SourceOp>
39struct UnrollPattern : public OpRewritePattern<SourceOp> {
40 UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
41 PatternBenefit benefit = 1)
42 : OpRewritePattern<SourceOp>(context, benefit), options(options) {}
43
44protected:
45 /// Return the target shape for the given `op`. Return std::nullopt if the
46 /// op shouldn't be or cannot be unrolled.
47 std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
48 LDBG() << "Get unroll shape for: " << *op;
49
50 if (options.filterConstraint && failed(options.filterConstraint(op))) {
51 LDBG() << "--no filter constraint -> BAIL";
52 return std::nullopt;
53 }
54
55 assert(options.nativeShape &&
56 "expects the native shape for native shape call back function.");
57 auto nativeShape = options.nativeShape(op);
58 return nativeShape;
59 }
60
61 SmallVector<Type> getUnrolledTypes(ShapedType type,
62 ArrayRef<int64_t> tileShape,
63 bool returnSingleType = false) const {
64 return options.getUnrolledTypes(type, tileShape, returnSingleType);
65 }
66
67 /// Emulate the the unpack behavior using insert_strided_slice for VectorType
68 /// values and unrealized_conversion_cast for TensorDescType values.
69 Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
70 Location loc, PatternRewriter &rewriter) const {
71 if (auto vecTy = dyn_cast<VectorType>(destTy)) {
72 auto shape = vecTy.getShape();
73 return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
74 }
75
76 if (isa<xegpu::TensorDescType>(destTy)) {
77 auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
78 rewriter.getUnitAttr());
79 auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
80 rewriter.getDenseI64ArrayAttr(blockSize));
81 auto castOp = UnrealizedConversionCastOp::create(
82 rewriter, loc, destTy, srcs,
83 ArrayRef<NamedAttribute>({attr, blkAttr}));
84 return castOp.getResult(0);
85 }
86
87 llvm_unreachable("Unexpected destTy.");
88 return Value();
89 }
90
91 /// Emulate the the pack behavior using extract_strided_slice for VectorType
92 /// values and unrealized_conversion_cast for TensorDescType values.
93 SmallVector<Value> pack(Value src, TypeRange destTypes,
94 ArrayRef<int64_t> blockSize, Location loc,
95 PatternRewriter &rewriter) const {
96 if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
97 return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
98 blockSize);
99 }
100
101 if (isa<xegpu::TensorDescType>(src.getType())) {
102 auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
103 rewriter.getUnitAttr());
104 auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
105 rewriter.getDenseI64ArrayAttr(blockSize));
106 auto castOp = UnrealizedConversionCastOp::create(
107 rewriter, loc, destTypes, src,
108 ArrayRef<NamedAttribute>({attr, blkAttr}));
109 return castOp.getResults();
110 }
111
112 llvm_unreachable("Unexpected src type.");
113 return SmallVector<Value>();
114 }
115
116 /// Helper to pack operands for DPAS-like operations with early return if
117 /// no unrolling is needed.
118 SmallVector<Value> packOperandForDpas(Value operand,
119 ArrayRef<int64_t> blockSize,
120 Location loc,
121 PatternRewriter &rewriter) const {
122 auto vecType = cast<VectorType>(operand.getType());
123 std::optional<SmallVector<int64_t>> grids =
124 computeShapeRatio(vecType.getShape(), blockSize);
125 assert(grids && "Expecting grids to be computed.");
126 auto numNewOps = computeProduct(*grids);
127 if (numNewOps == 1)
128 return SmallVector<Value>({operand});
129 VectorType newVecTy =
130 vecType.cloneWith(blockSize, vecType.getElementType());
131 SmallVector<Type> convertedTypes(numNewOps, newVecTy);
132 return pack(operand, convertedTypes, blockSize, loc, rewriter);
133 }
134
135private:
136 const char *const packAttrName = "__xegpu_blocking_pack__";
137 const char *const unpackAttrName = "__xegpu_blocking_unpack__";
138 const char *const blockAttrName = "__xegpu_blocking_tile_shape__";
139
141};
142
143// Generic helper function for unrolling operations with offsets.
144//
145// Iterates over tile offsets within the tensor descriptor shape and calls
146// the provided createOp function for each computed offset. This is used by
147// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
148// have explicit offsets that need to be adjusted for each unrolled tile.
149SmallVector<Value> computeUnrolledOffsets(
150 SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
151 ArrayRef<int64_t> targetShape,
152 const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
153 Location loc, PatternRewriter &rewriter) {
154 int64_t rank = tdescTy.getRank();
155 ArrayRef<int64_t> shape = tdescTy.getShape();
156
157 auto addi = [&](OpFoldResult a, int64_t b) -> Value {
158 std::optional<int64_t> maybeInt = getConstantIntValue(a);
159 if (maybeInt) {
160 return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
161 } else {
162 auto aV = llvm::cast<Value>(a);
163 auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
164 return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
165 }
166 };
167
168 SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
169 llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
170 auto validIdxes =
171 llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
172
173 SmallVector<Value> newOps;
174 for (SmallVector<int64_t> offsets :
175 StaticTileOffsetRange(shape, targetShape)) {
176
177 for (auto [idx, oldOff, offset] :
178 llvm::zip(validIdxes, oldOffsets, offsets))
179 mixedOffsets[idx] = addi(oldOff, offset);
180
181 auto newOp = createOp(mixedOffsets);
182 newOps.push_back(newOp);
183 }
184 return newOps;
185}
186
187struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
188 using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
189 LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
190 PatternRewriter &rewriter) const override {
191 Location loc = op.getLoc();
192 xegpu::TensorDescType tdescTy = op.getType();
193
194 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
195 if (!targetShape)
196 return failure();
197
198 SmallVector<Value> newOps;
199
200 auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
201 auto newOp =
202 xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy, op.getSource(),
203 op.getMixedSizes(), op.getMixedStrides());
204 newOps.push_back(newOp);
205 Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
206 rewriter.replaceOp(op, castOp);
207
208 return success();
209 }
210};
211
212struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
213 using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
214 LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
215 PatternRewriter &rewriter) const override {
216 Location loc = op.getLoc();
217 xegpu::TensorDescType tdescTy = op.getTensorDescType();
218
219 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
220 if (!targetShape)
221 return failure();
222
223 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
224 if (layout)
225 layout = layout.dropInstData();
226
227 SmallVector<Type> convertedTdescTypes =
228 getUnrolledTypes(tdescTy, *targetShape, /*returnSingleType*/ true);
229
230 SmallVector<Value> convertedTdesc = pack(
231 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
232
233 auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
234 xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
235 op.getL1HintAttr(), op.getL2HintAttr(),
236 op.getL3HintAttr(), layout);
237 // return dummy Value to satisfy function's signature
238 return nullptr;
239 };
240
241 computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
242 createPrefetch, loc, rewriter);
243
244 rewriter.eraseOp(op);
245 return success();
246 }
247};
248
249struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
250 using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
251 LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
252 PatternRewriter &rewriter) const override {
253
254 Location loc = op.getLoc();
255 VectorType valueTy = op.getType();
256 xegpu::TensorDescType tdescTy = op.getTensorDescType();
257
258 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
259 if (!targetShape)
260 return failure();
261
262 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
263 if (layout)
264 layout = layout.dropInstData();
265
266 Type elemTy = tdescTy.getElementType();
267 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
268
269 SmallVector<Type> convertedTdescTypes =
270 getUnrolledTypes(tdescTy, *targetShape, /*returnSingleType*/ true);
271
272 SmallVector<Value> convertedTdescs = pack(
273 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
274 SmallVector<Value> newOps;
275
276 auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
277 return xegpu::LoadNdOp::create(
278 rewriter, loc, newValueTy, convertedTdescs[0], offsets,
279 op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
280 op.getL2HintAttr(), op.getL3HintAttr(), layout);
281 };
282 newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
283 createLoad, loc, rewriter);
284
285 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
286
287 rewriter.replaceOp(op, castOp);
288 return success();
289 }
290};
291
292struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
293 using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
294 LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
295 PatternRewriter &rewriter) const override {
296 Location loc = op.getLoc();
297 VectorType valueTy = op.getValueType();
298 xegpu::TensorDescType tdescTy = op.getTensorDescType();
299
300 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
301 if (!targetShape)
302 return failure();
303
304 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
305 if (layout)
306 layout = layout.dropInstData();
307
308 SmallVector<Type> convertedValTypes =
309 getUnrolledTypes(valueTy, *targetShape);
310 SmallVector<Type> convertedTdescTypes =
311 getUnrolledTypes(tdescTy, *targetShape, /*returnSingleType*/ true);
312
313 SmallVector<Value> convertedTdescs = pack(
314 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
315
316 SmallVector<Value> convertedValues =
317 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
318
319 size_t valueIndex = 0;
320 auto createStore = [&](SmallVector<OpFoldResult> offsets) {
321 xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
322 convertedTdescs[0], offsets, op.getL1HintAttr(),
323 op.getL2HintAttr(), op.getL3HintAttr(), layout);
324 // return dummy Value to satisfy function's signature
325 return nullptr;
326 };
327
328 computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
329 createStore, loc, rewriter);
330
331 rewriter.eraseOp(op);
332 return success();
333 }
334};
335
336struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
337 using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
338 LogicalResult matchAndRewrite(xegpu::DpasOp op,
339 PatternRewriter &rewriter) const override {
340 Location loc = op.getLoc();
341
342 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
343 if (!targetShape || targetShape->size() != 3)
344 return failure();
345 auto M = (*targetShape)[0];
346 auto K = (*targetShape)[1];
347 auto N = (*targetShape)[2];
348
349 int64_t aBlockSize[2] = {M, K};
350 int64_t bBlockSize[2] = {K, N};
351 int64_t cBlockSize[2] = {M, N};
352
353 auto a = op.getLhs();
354 auto b = op.getRhs();
355 auto c = op.getAcc();
356
357 SmallVector<Value> aVals = packOperandForDpas(a, aBlockSize, loc, rewriter);
358 SmallVector<Value> bVals = packOperandForDpas(b, bBlockSize, loc, rewriter);
359 SmallVector<Value> cVals;
360 if (c)
361 cVals = packOperandForDpas(c, cBlockSize, loc, rewriter);
362
363 auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
364 : SmallVector<ValueRange>({aVals, bVals});
365 if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
366 llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
367 return failure();
368
369 VectorType resultTy = op.getResult().getType();
370 auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());
371
372 auto aShape = a.getType().getShape();
373 auto bShape = b.getType().getShape();
374 int64_t mIters = aShape[0] / M;
375 int64_t kIters = aShape[1] / K;
376 int64_t nIters = bShape[1] / N;
377
378 SmallVector<Value> newOps;
379 for (int64_t i = 0; i < mIters; ++i) {
380 for (int64_t j = 0; j < nIters; ++j) {
381 Value tmpC;
382 if (c)
383 tmpC = cVals[i * nIters + j];
384
385 for (int64_t k = 0; k < kIters; ++k) {
386 Value aVec = aVals[i * kIters + k];
387 Value bVec = bVals[k * nIters + j];
388 SmallVector<Value> operands({aVec, bVec});
389 if (tmpC)
390 operands.push_back(tmpC);
391
392 tmpC =
393 xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
394 xegpu::dropInstDataOnAttrs(op->getAttrs()));
395 }
396 newOps.push_back(tmpC);
397 }
398 }
399 Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
400 rewriter.replaceOp(op, castOp);
401 return success();
402 }
403};
404
405struct UnrollDpasMxOp : public UnrollPattern<xegpu::DpasMxOp> {
406 using UnrollPattern<xegpu::DpasMxOp>::UnrollPattern;
407 LogicalResult matchAndRewrite(xegpu::DpasMxOp op,
408 PatternRewriter &rewriter) const override {
409 Location loc = op.getLoc();
410
411 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
412 if (!targetShape || targetShape->size() != 4)
413 return failure();
414 auto M = (*targetShape)[0];
415 auto K = (*targetShape)[1];
416 auto N = (*targetShape)[2];
417 auto S = (*targetShape)[3];
418
419 int64_t aBlockSize[2] = {M, K};
420 int64_t bBlockSize[2] = {K, N};
421 int64_t cBlockSize[2] = {M, N};
422 int64_t aScaleBlockSize[2] = {M, S};
423 int64_t bScaleBlockSize[2] = {S, N};
424
425 auto a = op.getA();
426 auto b = op.getB();
427 auto c = op.getAcc();
428 auto ascale = dyn_cast<TypedValue<VectorType>>(op.getScaleA());
429 auto bscale = dyn_cast<TypedValue<VectorType>>(op.getScaleB());
430
431 SmallVector<Value> aVals = packOperandForDpas(a, aBlockSize, loc, rewriter);
432 SmallVector<Value> bVals = packOperandForDpas(b, bBlockSize, loc, rewriter);
433 SmallVector<Value> cVals;
434 if (c)
435 cVals = packOperandForDpas(c, cBlockSize, loc, rewriter);
436 SmallVector<Value> aScaleVals;
437 if (ascale)
438 aScaleVals = packOperandForDpas(ascale, aScaleBlockSize, loc, rewriter);
439 SmallVector<Value> bScaleVals;
440 if (bscale)
441 bScaleVals = packOperandForDpas(bscale, bScaleBlockSize, loc, rewriter);
442
443 VectorType resultTy = op.getResult().getType();
444 auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());
445
446 auto aShape = a.getType().getShape();
447 auto bShape = b.getType().getShape();
448 int64_t mIters = aShape[0] / M;
449 int64_t kIters = aShape[1] / K;
450 int64_t nIters = bShape[1] / N;
451
452 SmallVector<Value> newOps;
453 xegpu::DpasMxOp newDpasMxOp;
454 for (int64_t i = 0; i < mIters; ++i) {
455 for (int64_t j = 0; j < nIters; ++j) {
456 Value tmpC;
457 if (c)
458 tmpC = cVals[i * nIters + j];
459
460 for (int64_t k = 0; k < kIters; ++k) {
461 Value aVec = aVals[i * kIters + k];
462 Value bVec = bVals[k * nIters + j];
463 SmallVector<Value> operands({aVec, bVec});
464 if (tmpC)
465 operands.push_back(tmpC);
466 if (ascale)
467 operands.push_back(aScaleVals[i * kIters + k]);
468 if (bscale)
469 operands.push_back(bScaleVals[k * nIters + j]);
470
471 newDpasMxOp = xegpu::DpasMxOp::create(
472 rewriter, loc, vecTy, operands,
473 xegpu::dropInstDataOnAttrs(op->getAttrs()));
474 tmpC = newDpasMxOp.getResult();
475 }
476 newOps.push_back(newDpasMxOp);
477 }
478 }
479 Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
480 rewriter.replaceOp(op, castOp);
481 return success();
482 }
483};
484
485/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
486/// load).
487/// It unrolls the offsets and mask operands accordingly, and creates multiple
488/// LoadGatherOp with the unrolled operands.
489struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
490 using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
491 LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
492 PatternRewriter &rewriter) const override {
493 Location loc = op.getLoc();
494 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
495 Value offsets = op.getOffsets();
496 Value mask = op.getMask();
497
498 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
499 if (!targetShape)
500 return failure();
501
502 SmallVector<int64_t> targetMaskShape(*targetShape);
503 int64_t chunkSize = 1;
504 if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
505 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
506 chunkSize = intAttr.getInt();
507 }
508
509 // Unroll mask and offsets with correct shape
510 VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
511 VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
512 Type elemTy = valueTy.getElementType();
513 VectorType newValueTy = VectorType::get(*targetShape, elemTy);
514
515 SmallVector<Type> convertedMaskTypes;
516 SmallVector<Value> convertedMasks;
517 SmallVector<Type> convertedOffsetTypes;
518 SmallVector<Value> convertedOffsets;
519
520 if (chunkSize > 1) {
521 // For chunked loads, mask and offsets have one less dimension
522 targetMaskShape.pop_back();
523 int64_t blockedChunkSize = targetShape->back();
524 int64_t numNewChunks = chunkSize / blockedChunkSize;
525 chunkSize = blockedChunkSize;
526
527 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
528 convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
529
530 SmallVector<Value> convertedMasksBase =
531 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
532 SmallVector<Value> convertedOffsetsBase =
533 pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
534
535 for (auto maskVal : convertedMasksBase)
536 convertedMasks.append(numNewChunks, maskVal);
537
538 for (auto [baseOffset, offsetType] :
539 llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
540 for (int64_t i = 0; i < numNewChunks; ++i) {
541 Value inc = arith::ConstantIndexOp::create(rewriter, loc,
542 i * blockedChunkSize);
543 Value incVec =
544 vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
545 Value offsetVal =
546 arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
547 convertedOffsets.push_back(offsetVal);
548 }
549 }
550 } else {
551 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
552 convertedMasks =
553 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
554
555 convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
556 convertedOffsets =
557 pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
558 }
559
560 auto layout = op.getLayoutAttr();
561 if (layout)
562 layout = layout.dropInstData();
563
564 SmallVector<Value> newOps;
565 for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
566 auto newOp = xegpu::LoadGatherOp::create(
567 rewriter, loc, newValueTy, op.getSource(), o, m,
568 rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
569 op.getL2HintAttr(), op.getL3HintAttr(), layout);
570 newOps.push_back(newOp);
571 }
572
573 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
574 rewriter.replaceOp(op, castOp);
575 return success();
576 }
577};
578
579/// This pattern handles the unrolling of StoreScatterOp with offsets (scattered
580/// store).
581/// It unrolls the offsets and mask operands accordingly, and creates multiple
582/// StoreScatterOp with the unrolled operands.
583struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
584 using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
585 LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
586 PatternRewriter &rewriter) const override {
587 Location loc = op.getLoc();
588 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
589 Value offsets = op.getOffsets();
590 Value mask = op.getMask();
591
592 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
593 if (!targetShape)
594 return failure();
595
596 int64_t chunkSize = 1;
597 if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
598 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
599 chunkSize = intAttr.getInt();
600 }
601
602 SmallVector<int64_t> targetMaskShape(*targetShape);
603 VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
604 VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
605
606 SmallVector<Type> convertedMaskTypes;
607 SmallVector<Value> convertedMasks;
608 SmallVector<Type> convertedOffsetTypes;
609 SmallVector<Value> convertedOffsets;
610
611 if (chunkSize > 1) {
612 targetMaskShape.pop_back();
613 int64_t blockedChunkSize = targetShape->back();
614 int64_t numNewChunks = chunkSize / blockedChunkSize;
615 chunkSize = blockedChunkSize;
616
617 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
618 convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
619
620 SmallVector<Value> convertedMasksBase =
621 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
622 SmallVector<Value> convertedOffsetsBase =
623 pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
624
625 for (auto maskVal : convertedMasksBase)
626 convertedMasks.append(numNewChunks, maskVal);
627
628 for (auto [baseOffset, offsetType] :
629 llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
630 for (int64_t i = 0; i < numNewChunks; ++i) {
631 Value inc = arith::ConstantIndexOp::create(rewriter, loc,
632 i * blockedChunkSize);
633 Value incVec =
634 vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
635 Value offsetVal =
636 arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
637 convertedOffsets.push_back(offsetVal);
638 }
639 }
640 } else {
641 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
642 convertedMasks =
643 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
644
645 convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
646 convertedOffsets =
647 pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
648 }
649
650 SmallVector<Type> convertedValTypes =
651 getUnrolledTypes(valueTy, *targetShape);
652 SmallVector<Value> convertedValues =
653 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
654
655 auto layout = op.getLayoutAttr();
656 if (layout)
657 layout = layout.dropInstData();
658
659 for (auto [v, o, m] :
660 llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
661 xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
662 rewriter.getI64IntegerAttr(chunkSize),
663 op.getL1HintAttr(), op.getL2HintAttr(),
664 op.getL3HintAttr(), layout);
665 }
666
667 rewriter.eraseOp(op);
668 return success();
669 }
670};
671
672struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
673 using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
674 LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
675 PatternRewriter &rewriter) const override {
676 Location loc = op.getLoc();
677 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
678 assert(valueTy && "the value type must be vector type!");
679
680 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
681 if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
682 return failure();
683
684 Type elemTy = valueTy.getElementType();
685 ArrayRef<int64_t> shape = valueTy.getShape();
686 auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
687
688 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
689
690 SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
692 for (SmallVector<int64_t> offsets :
693 StaticTileOffsetRange(shape, *targetShape)) {
694 auto adds = xegpu::addElementwise(
695 rewriter, loc, mixedOffsets,
696 getAsIndexOpFoldResult(op.getContext(), offsets));
697 offsetsList.push_back(adds);
698 }
699
700 SmallVector<Value> newOps;
701 layout = layout.dropInstData();
702 for (SmallVector<OpFoldResult> offsets : offsetsList) {
703 auto newOp = xegpu::LoadMatrixOp::create(
704 rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
705 newOps.push_back(newOp);
706 }
707 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
708 rewriter.replaceOp(op, castOp);
709 return success();
710 }
711};
712
713struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
714 using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
715 LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
716 PatternRewriter &rewriter) const override {
717 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
718 if (!targetShape)
719 return failure();
720
721 Location loc = op.getLoc();
722 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
723 assert(valueTy && "the value type must be vector type!");
724 ArrayRef<int64_t> shape = valueTy.getShape();
725 auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
726
727 SmallVector<Type> convertedValTypes =
728 getUnrolledTypes(valueTy, *targetShape);
729 SmallVector<Value> convertedValues =
730 pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);
731
732 SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
734 for (SmallVector<int64_t> offsets :
735 StaticTileOffsetRange(shape, *targetShape)) {
736 auto adds = xegpu::addElementwise(
737 rewriter, loc, mixedOffsets,
738 getAsIndexOpFoldResult(op.getContext(), offsets));
739 offsetsList.push_back(adds);
740 }
741
742 for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
743 xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
744 layout.dropInstData());
745
746 rewriter.eraseOp(op);
747 return success();
748 }
749};
750
751/// UnrollConvertLayoutOp pattern for unrolling xegpu::ConvertLayoutOp
752/// operations. It first check whether the convert layout op has valid layouts
753/// after inst_data stripped. If it does, it will unroll the vector into
754/// multiple smaller vectors according to the target shape, and create multiple
755/// ConvertLayoutOp with the unrolled vectors and the stripped layouts.
756struct UnrollConvertLayoutOp : public UnrollPattern<xegpu::ConvertLayoutOp> {
757 using UnrollPattern<xegpu::ConvertLayoutOp>::UnrollPattern;
758 LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
759 PatternRewriter &rewriter) const override {
760 Location loc = op.getLoc();
761 Type valType = op.getType();
762
763 xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
764 xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
765 if (!inputLayout || !targetLayout)
766 return rewriter.notifyMatchFailure(op, "missing layout attributes.");
767
768 if (valType.isIntOrFloat()) {
769 rewriter.replaceOp(op, op.getSource());
770 assert(!inputLayout.dropInstData() && !targetLayout.dropInstData() &&
771 "unexpected layout attributes for scalar type");
772 return success();
773 }
774
775 if (inputLayout.getEffectiveInstDataAsInt().empty() ||
776 targetLayout.getEffectiveInstDataAsInt().empty())
777 return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");
778
779 inputLayout = inputLayout.dropInstData();
780 targetLayout = targetLayout.dropInstData();
781
782 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
783 assert(valueTy && "the value type must be vector type!");
784
785 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
786 if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
787 return failure();
788
789 Value newSource = op.getSource();
790 SmallVector<Value> newOps;
791 if (inputLayout && targetLayout) {
792 SmallVector<Type> convertedValTypes =
793 getUnrolledTypes(valueTy, *targetShape);
794 SmallVector<Value> convertedValues =
795 pack(op.getOperand(), convertedValTypes, *targetShape, loc, rewriter);
796 for (auto [v, t] : llvm::zip(convertedValues, convertedValTypes)) {
797 auto newOp = xegpu::ConvertLayoutOp::create(rewriter, loc, t, v,
798 inputLayout, targetLayout);
799 newOps.push_back(newOp);
800 }
801 newSource = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
802 }
803
804 rewriter.replaceOp(op, newSource);
805 return success();
806 }
807};
808
809/// Unrolls vector.multi_reduction by sequentially reducing tiles with
810/// elementwise arith operations first, then a single multi_reduction
811/// per non-reduced tile position. This avoids generating long chains of
812/// multi_reduction ops (as the upstream pattern does) and is more efficient.
813///
814/// Example:
815/// vector.multi_reduction <32x64xf16> to <32xf16> (tile_shape=32, 32)
816/// -- Upstream pattern generates:
817/// %tmp1 = vector.multi_reduction %tile0, %zero_acc <32x32xf16> to <32xf16>
818/// %res = vector.multi_reduction %tmp1, %tile1 <32x32xf16> to <32xf16>
819/// -- This pattern generates:
820/// %tmp1 = arith.reduction %tile0, %tile1 <32x32xf16> -> <32x32xf16> //
821/// elementwise %res = vector.multi_reduction %tmp1, %zero_acc <32x32xf16> to
822/// <32xf16>
823struct UnrollMultiReductionOp
824 : public UnrollPattern<vector::MultiDimReductionOp> {
825 UnrollMultiReductionOp(MLIRContext *context,
827 PatternBenefit benefit = 2)
828 : UnrollPattern<vector::MultiDimReductionOp>(context, options, benefit) {}
829
830 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
831 PatternRewriter &rewriter) const override {
832 VectorType srcTy = reductionOp.getSourceVectorType();
833 ArrayRef<int64_t> srcShape = srcTy.getShape();
834 int64_t srcRank = srcTy.getRank();
835
836 Location loc = reductionOp.getLoc();
837 Value source = reductionOp.getSource();
838 Value acc = reductionOp.getAcc();
839 vector::CombiningKind kind = reductionOp.getKind();
840
841 // Result must be a vector (not scalar).
842 auto resultType = dyn_cast<VectorType>(reductionOp.getDestType());
843 if (!resultType)
844 return failure();
845
846 std::optional<SmallVector<int64_t>> targetShapeOpt =
847 getTargetShape(reductionOp);
848 if (!targetShapeOpt ||
849 static_cast<int64_t>(targetShapeOpt->size()) != srcRank)
850 return failure();
851
852 SmallVector<int64_t> targetShape = *targetShapeOpt;
853
854 // Check divisibility for all dimensions.
855 for (int64_t i = 0; i < srcRank; ++i) {
856 if (srcShape[i] % targetShape[i] != 0)
857 return failure();
858 }
859
860 SmallVector<bool> reductionMask = reductionOp.getReductionMask();
861 // Identify reduced and kept dimensions from the reduction mask.
862 SmallVector<int64_t> reducedDims, keptDims;
863 for (int64_t i = 0; i < srcRank; ++i) {
864 if (reductionMask[i])
865 reducedDims.push_back(i);
866 else
867 keptDims.push_back(i);
868 }
869
870 // Compute the number of tiles along each reduced dimension and their
871 // product
872 SmallVector<int64_t> numReducedTilesPerDim;
873 for (int64_t d : reducedDims)
874 numReducedTilesPerDim.push_back(srcShape[d] / targetShape[d]);
875
876 // Build kept shapes for iterating over non-reduced dimensions.
877 SmallVector<int64_t> keptShape, keptTileShape;
878 for (int64_t d : keptDims) {
879 keptShape.push_back(srcShape[d]);
880 keptTileShape.push_back(targetShape[d]);
881 }
882
883 // Initialize the result vector for assembly.
884 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
885 rewriter.getZeroAttr(resultType));
886
887 // Iterate over all tile positions in the kept dimensions.
888 // Ex: [off0, off1, _ _ off4]
889 // blanks are offsets for the reduced dims, they will be
890 // generated in the inner loop below
891 for (SmallVector<int64_t> keptOffsets :
892 StaticTileOffsetRange(keptShape, keptTileShape)) {
893
894 // Reconstruct full-rank base offsets with 0 for reduced dims.
895 // Ex: [off0, off1, 0, 0, off4]
896 SmallVector<int64_t> baseOffsets(srcRank, 0);
897 for (auto [idx, dim] : llvm::enumerate(keptDims))
898 baseOffsets[dim] = keptOffsets[idx];
899
900 // Generate the full tile indices for the reduced dimensions.
901 // Ex: if reduceDimShapes = [32, 64] and
902 // reducedDimTargetShapes = [16, 16], then reducedTileCoords:
903 // [(0, 0), (0, 1), (0, 2), (0, 3),
904 // (1, 0), (1, 1), (1, 2), (1, 3)]
905 auto reducedTileCoords = StaticTileOffsetRange(
906 numReducedTilesPerDim, SmallVector<int64_t>(reducedDims.size(), 1));
907
908 // Step 1: Fill "blanks" in the offsets for the reduced dimensions
909 // using 'reducedTileCoords' and extract according tiles.
910 // Ex: tiles = [source[off0, off1, off2_red, off3_red, off4], ...]
911 SmallVector<Value> tiles;
912 for (SmallVector<int64_t> reducedTileIdx : reducedTileCoords) {
913 SmallVector<int64_t> offsets(baseOffsets);
914 for (auto [idx, dim] : llvm::enumerate(reducedDims))
915 offsets[dim] = reducedTileIdx[idx] * targetShape[dim];
916 SmallVector<int64_t> strides(srcRank, 1);
917 Value tile = vector::ExtractStridedSliceOp::create(
918 rewriter, loc, source, offsets, targetShape, strides);
919 tiles.push_back(tile);
920 }
921
922 // Step 2: Sequentially reduce tiles using elementwise arith operations.
923 Value reduced = tiles[0];
924 for (size_t i = 1; i < tiles.size(); ++i)
925 reduced =
926 vector::makeArithReduction(rewriter, loc, kind, reduced, tiles[i]);
927
928 // Step 3: Perform a single multi_reduction with the accumulator slice.
929 SmallVector<int64_t> accStrides(keptTileShape.size(), 1);
930 Value accSlice = vector::ExtractStridedSliceOp::create(
931 rewriter, loc, acc, keptOffsets, keptTileShape, accStrides);
932
933 auto newReduction = vector::MultiDimReductionOp::create(
934 rewriter, loc, reduced, accSlice, reductionMask, kind);
935
936 // Step 4: Insert the reduced result into the output vector.
937 SmallVector<int64_t> dstStrides(keptTileShape.size(), 1);
938 result = vector::InsertStridedSliceOp::create(
939 rewriter, loc, newReduction, result, keptOffsets, dstStrides);
940 }
941
942 rewriter.replaceOp(reductionOp, result);
943 return success();
944 }
945};
946
947} // namespace
948
951 patterns
952 .add<UnrollCreateNdOp, UnrollPrefetchNdOp, UnrollLoadNdOp,
953 UnrollStoreNdOp, UnrollDpasOp, UnrollDpasMxOp, UnrollLoadMatrixOp,
954 UnrollStoreMatrixOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
955 UnrollConvertLayoutOp, UnrollMultiReductionOp>(patterns.getContext(),
956 options);
957}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
UnitAttr getUnitAttr()
Definition Builders.cpp:102
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of the given shape from a set of values using vector.insert_strided_slice.
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll XeGPU operations to smaller shapes.
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_strided_slice.
SmallVector< OpFoldResult > addElementwise(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with same length.
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explanatory.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options to control the XeGPU unrolling.
Definition Transforms.h:29
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.