MLIR 23.0.0git
XeGPUUnroll.cpp
Go to the documentation of this file.
1//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains patterns for unrolling XeGPU operations. It follows a
10// similar concept and design as vector unroll patterns, serving as a complement
11// to them.
12//
13//===----------------------------------------------------------------------===//
14
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/DebugLog.h"
25
// Pull in the pass boilerplate from Passes.h.inc; defining
// GEN_PASS_DEF_XEGPUUNROLL first selects the XeGPUUnroll pass definition
// (standard MLIR generated-pass convention).
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

// Debug category for the LDBG() messages emitted in this file.
#define DEBUG_TYPE "xegpu-unroll"

using namespace mlir;
36
37namespace {
38
39// Forward declaration for use inside UnrollPattern below.
41unrollByTile(SmallVector<OpFoldResult> mixedOffsets,
42 xegpu::TensorDescType tdescTy, ArrayRef<int64_t> targetShape,
43 const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
44 Location loc, PatternRewriter &rewriter);
45
46template <typename SourceOp>
47struct UnrollPattern : public OpRewritePattern<SourceOp> {
48 UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
49 PatternBenefit benefit = 1)
50 : OpRewritePattern<SourceOp>(context, benefit), options(options) {}
51
52protected:
53 /// Return the target shape for the given `op`. Return std::nullopt if the
54 /// op shouldn't be or cannot be unrolled.
55 std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
56 LDBG() << "Get unroll shape for: " << *op;
57
58 if (options.filterConstraint && failed(options.filterConstraint(op))) {
59 LDBG() << "--no filter constraint -> BAIL";
60 return std::nullopt;
61 }
62
63 assert(options.nativeShape &&
64 "expects the native shape for native shape call back function.");
65 auto nativeShape = options.nativeShape(op);
66 return nativeShape;
67 }
68
69 SmallVector<Type> getUnrolledTypes(ShapedType type,
70 ArrayRef<int64_t> tileShape) const {
71 return options.getUnrolledTypes(type, tileShape);
72 }
73
74 /// Emulate the the unpack behavior using insert_strided_slice for VectorType
75 /// values and unrealized_conversion_cast for TensorDescType values.
76 Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
77 Location loc, PatternRewriter &rewriter) const {
78 if (auto vecTy = dyn_cast<VectorType>(destTy)) {
79 auto shape = vecTy.getShape();
80 return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
81 }
82
83 if (isa<xegpu::TensorDescType>(destTy)) {
84 auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
85 rewriter.getUnitAttr());
86 auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
87 rewriter.getDenseI64ArrayAttr(blockSize));
88 auto castOp = UnrealizedConversionCastOp::create(
89 rewriter, loc, destTy, srcs,
90 ArrayRef<NamedAttribute>({attr, blkAttr}));
91 return castOp.getResult(0);
92 }
93
94 llvm_unreachable("Unexpected destTy.");
95 return Value();
96 }
97
98 /// Emulate the the pack behavior using extract_strided_slice for VectorType
99 /// values and unrealized_conversion_cast for TensorDescType values.
100 SmallVector<Value> pack(Value src, TypeRange destTypes,
101 ArrayRef<int64_t> blockSize, Location loc,
102 PatternRewriter &rewriter) const {
103 if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
104 return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
105 blockSize);
106 }
107
108 if (isa<xegpu::TensorDescType>(src.getType())) {
109 auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
110 rewriter.getUnitAttr());
111 auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
112 rewriter.getDenseI64ArrayAttr(blockSize));
113 auto castOp = UnrealizedConversionCastOp::create(
114 rewriter, loc, destTypes, src,
115 ArrayRef<NamedAttribute>({attr, blkAttr}));
116 return castOp.getResults();
117 }
118
119 llvm_unreachable("Unexpected src type.");
120 return SmallVector<Value>();
121 }
122
123 /// Helper for the rank > 2 case shared by Load/Store/PrefetchNd unroll
124 /// patterns. The matching CreateNdDesc unroll pattern produces one tdesc
125 /// per batch tile (the batch offset is baked into its base pointer via
126 /// memref.subview), so here we only need to iterate the inner 2D offsets
127 /// for each batch tdesc.
128 ///
129 /// Packs `srcTdesc` into one tdesc per batch tile, then iterates the inner
130 /// 2D tile offsets for each batch tdesc and invokes `createOp` with
131 /// (batchTdesc, fullOffsets), where fullOffsets is `batchRank` zeros
132 /// followed by the inner offsets. Returns the values produced by createOp,
133 /// flattened across (batch, inner) iteration order.
134 SmallVector<Value> unrollNdBatch(
135 Value srcTdesc, xegpu::TensorDescType tdescTy,
136 ArrayRef<int64_t> targetShape, ArrayRef<OpFoldResult> mixedOffsets,
137 int64_t batchRank,
139 Location loc, PatternRewriter &rewriter) const {
140 ArrayRef<int64_t> shape = tdescTy.getShape();
141 SmallVector<int64_t> innerShape(shape.begin() + batchRank, shape.end());
142 SmallVector<int64_t> innerTarget(targetShape.begin() + batchRank,
143 targetShape.end());
144
145 SmallVector<Type> batchTdescTypes = getUnrolledTypes(tdescTy, targetShape);
146 SmallVector<Value> batchTdescs =
147 pack(srcTdesc, batchTdescTypes, targetShape, loc, rewriter);
148
149 auto innerTdescTy = xegpu::TensorDescType::get(
150 tdescTy.getContext(), innerShape, tdescTy.getElementType(),
151 tdescTy.getEncoding(), /*layout=*/nullptr);
152
153 SmallVector<OpFoldResult> innerOffsets(mixedOffsets.begin() + batchRank,
154 mixedOffsets.end());
155
156 SmallVector<Value> newOps;
157 for (Value batchTdesc : batchTdescs) {
158 auto wrappedCreate = [&](SmallVector<OpFoldResult> offsets) -> Value {
159 SmallVector<OpFoldResult> fullOffsets(batchRank,
160 rewriter.getIndexAttr(0));
161 fullOffsets.append(offsets.begin(), offsets.end());
162 return createOp(batchTdesc, fullOffsets);
163 };
164 auto perBatch = unrollByTile(innerOffsets, innerTdescTy, innerTarget,
165 wrappedCreate, loc, rewriter);
166 newOps.append(perBatch.begin(), perBatch.end());
167 }
168 return newOps;
169 }
170
171 /// Helper to pack operands for DPAS-like operations with early return if
172 /// no unrolling is needed.
173 SmallVector<Value> packOperandForDpas(Value operand,
174 ArrayRef<int64_t> blockSize,
175 Location loc,
176 PatternRewriter &rewriter) const {
177 auto vecType = cast<VectorType>(operand.getType());
178 std::optional<SmallVector<int64_t>> grids =
179 computeShapeRatio(vecType.getShape(), blockSize);
180 assert(grids && "Expecting grids to be computed.");
181 auto numNewOps = computeProduct(*grids);
182 if (numNewOps == 1)
183 return SmallVector<Value>({operand});
184 VectorType newVecTy =
185 vecType.cloneWith(blockSize, vecType.getElementType());
186 SmallVector<Type> convertedTypes(numNewOps, newVecTy);
187 return pack(operand, convertedTypes, blockSize, loc, rewriter);
188 }
189
190private:
191 const char *const packAttrName = "__xegpu_blocking_pack__";
192 const char *const unpackAttrName = "__xegpu_blocking_unpack__";
193 const char *const blockAttrName = "__xegpu_blocking_tile_shape__";
194
196};
197
198// Walks tile offsets within the tensor descriptor shape and emits one op per
199// tile by calling `createOp` with the per-tile offsets. Used by LoadNd,
200// StoreNd, CreateNdDesc, and PrefetchNd unrollers, which all need to adjust
201// their explicit offsets for each unrolled tile.
203unrollByTile(SmallVector<OpFoldResult> mixedOffsets,
204 xegpu::TensorDescType tdescTy, ArrayRef<int64_t> targetShape,
205 const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
206 Location loc, PatternRewriter &rewriter) {
207 int64_t rank = tdescTy.getRank();
208 ArrayRef<int64_t> shape = tdescTy.getShape();
209
210 auto addi = [&](OpFoldResult a, int64_t b) -> Value {
211 std::optional<int64_t> maybeInt = getConstantIntValue(a);
212 if (maybeInt) {
213 return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
214 } else {
215 auto aV = llvm::cast<Value>(a);
216 auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
217 return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
218 }
219 };
220
221 SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
222 llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
223 auto validIdxes =
224 llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
225
226 SmallVector<Value> newOps;
227 for (SmallVector<int64_t> offsets :
228 StaticTileOffsetRange(shape, targetShape)) {
229
230 for (auto [idx, oldOff, offset] :
231 llvm::zip(validIdxes, oldOffsets, offsets))
232 mixedOffsets[idx] = addi(oldOff, offset);
233
234 auto newOp = createOp(mixedOffsets);
235 newOps.push_back(newOp);
236 }
237 return newOps;
238}
239
240struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
241 using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
242 LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
243 PatternRewriter &rewriter) const override {
244 Location loc = op.getLoc();
245 xegpu::TensorDescType tdescTy = op.getType();
246
247 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
248 if (!targetShape)
249 return failure();
250
251 int64_t rank = tdescTy.getRank();
252 int64_t batchRank = rank - 2;
253
254 // For rank <= 2 or non-memref source: existing single-tdesc behavior.
255 if (batchRank <= 0 || !isa<MemRefType>(op.getSourceType())) {
256 SmallVector<Value> newOps;
257 auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
258 auto newOp = xegpu::CreateNdDescOp::create(
259 rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
260 op.getMixedStrides());
261 newOps.push_back(newOp);
262 Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
263 rewriter.replaceOp(op, castOp);
264 return success();
265 }
266
267 // For rank > 2 with memref source: create one tdesc per batch tile via
268 // memref.subview. Each subview slices the batch dimensions, so the
269 // resulting tdesc has the batch offset baked into its base pointer.
270 // The inner dimensions remain full-size for reuse across multiple
271 // load/store operations with different offsets.
272 ArrayRef<int64_t> shape = tdescTy.getShape();
273 SmallVector<int64_t> batchBlockSize(targetShape->begin(),
274 targetShape->begin() + batchRank);
275 batchBlockSize.append(shape.begin() + batchRank, shape.end());
276
277 auto newTdescTy =
278 cast<xegpu::TensorDescType>(getUnrolledTypes(tdescTy, *targetShape)[0]);
279
280 SmallVector<Value> newOps;
281 for (SmallVector<int64_t> batchOffsets :
282 StaticTileOffsetRange(shape, batchBlockSize)) {
283 // Build memref.subview operands. The subview slices contiguously along
284 // each batch dimension (no gaps), so the subview's element stride is 1
285 // for every dim. This is unrelated to the source memref's strides, which
286 // describe the layout of the original buffer and are propagated by the
287 // SubViewOp builder onto the resulting memref type.
288 SmallVector<OpFoldResult> subviewOffsets;
289 for (int64_t off : batchOffsets)
290 subviewOffsets.push_back(rewriter.getIndexAttr(off));
291
292 SmallVector<OpFoldResult> subviewSizes;
293 for (int64_t d : batchBlockSize)
294 subviewSizes.push_back(rewriter.getIndexAttr(d));
295
296 SmallVector<OpFoldResult> subviewStrides(rank, rewriter.getIndexAttr(1));
297
298 auto subview = memref::SubViewOp::create(rewriter, loc, op.getSource(),
299 subviewOffsets, subviewSizes,
300 subviewStrides);
301
302 auto newOp = xegpu::CreateNdDescOp::create(
303 rewriter, loc, newTdescTy,
304 cast<TypedValue<MemRefType>>(subview.getResult()));
305 newOps.push_back(newOp);
306 }
307
308 Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
309 rewriter.replaceOp(op, castOp);
310 return success();
311 }
312};
313
314struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
315 using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
316 LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
317 PatternRewriter &rewriter) const override {
318 Location loc = op.getLoc();
319 xegpu::TensorDescType tdescTy = op.getTensorDescType();
320
321 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
322 if (!targetShape)
323 return failure();
324
325 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
326 if (layout)
327 layout = layout.dropInstData();
328
329 int64_t rank = tdescTy.getRank();
330 int64_t batchRank = rank - 2;
331
332 if (batchRank <= 0) {
333 SmallVector<Type> convertedTdescTypes =
334 getUnrolledTypes(tdescTy, *targetShape);
335 SmallVector<Value> convertedTdesc = pack(
336 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
337
338 auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
339 xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
340 op.getL1HintAttr(), op.getL2HintAttr(),
341 op.getL3HintAttr(), layout);
342 return nullptr;
343 };
344 unrollByTile(op.getMixedOffsets(), tdescTy, *targetShape, createPrefetch,
345 loc, rewriter);
346 } else {
347 // Rank > 2: batch tdescs cover [batchTarget..., innerShape...].
348 // Each batch tdesc is reused for multiple inner prefetches via offsets.
349 auto createPrefetch =
350 [&](Value tdesc, SmallVector<OpFoldResult> fullOffsets) -> Value {
351 xegpu::PrefetchNdOp::create(rewriter, loc, tdesc, fullOffsets,
352 op.getL1HintAttr(), op.getL2HintAttr(),
353 op.getL3HintAttr(), layout);
354 return nullptr;
355 };
356 this->unrollNdBatch(op.getTensorDesc(), tdescTy, *targetShape,
357 op.getMixedOffsets(), batchRank, createPrefetch, loc,
358 rewriter);
359 }
360
361 rewriter.eraseOp(op);
362 return success();
363 }
364};
365
366struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
367 using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
368 LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
369 PatternRewriter &rewriter) const override {
370
371 Location loc = op.getLoc();
372 VectorType valueTy = op.getType();
373 xegpu::TensorDescType tdescTy = op.getTensorDescType();
374
375 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
376 if (!targetShape)
377 return failure();
378
379 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
380 if (layout)
381 layout = layout.dropInstData();
382
383 Type elemTy = tdescTy.getElementType();
384 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
385
386 int64_t rank = tdescTy.getRank();
387 int64_t batchRank = rank - 2;
388 SmallVector<Value> newOps;
389
390 if (batchRank <= 0) {
391 // Rank <= 2: original behavior with single tdesc.
392 SmallVector<Type> convertedTdescTypes =
393 getUnrolledTypes(tdescTy, *targetShape);
394 SmallVector<Value> convertedTdescs = pack(
395 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
396
397 auto createLoad = [&](SmallVector<OpFoldResult> offsets) -> Value {
398 return xegpu::LoadNdOp::create(
399 rewriter, loc, newValueTy, convertedTdescs[0], offsets,
400 op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
401 op.getL2HintAttr(), op.getL3HintAttr(), layout);
402 };
403 newOps = unrollByTile(op.getMixedOffsets(), tdescTy, *targetShape,
404 createLoad, loc, rewriter);
405 } else {
406 // Rank > 2: batch tdescs cover [batchTarget..., innerShape...].
407 // Each batch tdesc is reused for multiple inner loads via offsets.
408 auto createLoad = [&](Value tdesc,
409 SmallVector<OpFoldResult> fullOffsets) -> Value {
410 return xegpu::LoadNdOp::create(
411 rewriter, loc, newValueTy, tdesc, fullOffsets, op.getPackedAttr(),
412 op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
413 op.getL3HintAttr(), layout);
414 };
415 newOps = this->unrollNdBatch(op.getTensorDesc(), tdescTy, *targetShape,
416 op.getMixedOffsets(), batchRank, createLoad,
417 loc, rewriter);
418 }
419
420 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
421 rewriter.replaceOp(op, castOp);
422 return success();
423 }
424};
425
426struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
427 using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
428 LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
429 PatternRewriter &rewriter) const override {
430 Location loc = op.getLoc();
431 VectorType valueTy = op.getValueType();
432 xegpu::TensorDescType tdescTy = op.getTensorDescType();
433
434 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
435 if (!targetShape)
436 return failure();
437
438 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
439 if (layout)
440 layout = layout.dropInstData();
441
442 SmallVector<Type> convertedValTypes =
443 getUnrolledTypes(valueTy, *targetShape);
444
445 SmallVector<Value> convertedValues =
446 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
447
448 int64_t rank = tdescTy.getRank();
449 int64_t batchRank = rank - 2;
450 size_t valueIndex = 0;
451
452 if (batchRank <= 0) {
453 SmallVector<Type> convertedTdescTypes =
454 getUnrolledTypes(tdescTy, *targetShape);
455 SmallVector<Value> convertedTdescs = pack(
456 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
457
458 auto createStore = [&](SmallVector<OpFoldResult> offsets) {
459 xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
460 convertedTdescs[0], offsets,
461 op.getL1HintAttr(), op.getL2HintAttr(),
462 op.getL3HintAttr(), layout);
463 return (Value) nullptr;
464 };
465 unrollByTile(op.getMixedOffsets(), tdescTy, *targetShape, createStore,
466 loc, rewriter);
467 } else {
468 // Rank > 2: batch tdescs cover [batchTarget..., innerShape...].
469 // Each batch tdesc is reused for multiple inner stores via offsets.
470 // valueIndex advances across (batch, inner) iterations in the same
471 // order unrollNdBatch invokes the callback, so it stays in sync with
472 // the pre-packed convertedValues.
473 auto createStore = [&](Value tdesc,
474 SmallVector<OpFoldResult> fullOffsets) -> Value {
475 xegpu::StoreNdOp::create(
476 rewriter, loc, convertedValues[valueIndex++], tdesc, fullOffsets,
477 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), layout);
478 return nullptr;
479 };
480 this->unrollNdBatch(op.getTensorDesc(), tdescTy, *targetShape,
481 op.getMixedOffsets(), batchRank, createStore, loc,
482 rewriter);
483 }
484
485 rewriter.eraseOp(op);
486 return success();
487 }
488};
489
490struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
491 using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
492 LogicalResult matchAndRewrite(xegpu::DpasOp op,
493 PatternRewriter &rewriter) const override {
494 Location loc = op.getLoc();
495
496 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
497 if (!targetShape || targetShape->size() < 3)
498 return failure();
499
500 // targetShape is [batch..., M, K, N]
501 int64_t tsRank = targetShape->size();
502 auto M = (*targetShape)[tsRank - 3];
503 auto K = (*targetShape)[tsRank - 2];
504 auto N = (*targetShape)[tsRank - 1];
505 ArrayRef<int64_t> batchDims(targetShape->data(), tsRank - 3);
506
507 // Build block sizes including batch dimensions.
508 SmallVector<int64_t> aBlockSize(batchDims);
509 aBlockSize.push_back(M);
510 aBlockSize.push_back(K);
511 SmallVector<int64_t> bBlockSize(batchDims);
512 bBlockSize.push_back(K);
513 bBlockSize.push_back(N);
514 SmallVector<int64_t> cBlockSize(batchDims);
515 cBlockSize.push_back(M);
516 cBlockSize.push_back(N);
517
518 auto a = op.getLhs();
519 auto b = op.getRhs();
520 auto c = op.getAcc();
521
522 SmallVector<Value> aVals = packOperandForDpas(a, aBlockSize, loc, rewriter);
523 SmallVector<Value> bVals = packOperandForDpas(b, bBlockSize, loc, rewriter);
524 SmallVector<Value> cVals;
525 if (c)
526 cVals = packOperandForDpas(c, cBlockSize, loc, rewriter);
527
528 auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
529 : SmallVector<ValueRange>({aVals, bVals});
530 if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
531 llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
532 return failure();
533
534 VectorType resultTy = op.getResult().getType();
535 auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());
536
537 auto aShape = a.getType().getShape();
538 auto bShape = b.getType().getShape();
539
540 // Compute iteration counts. Batch dims only iterate over M and N (not
541 // K-reduction), so compute batch iterations from the C block size.
542 int64_t batchRank = batchDims.size();
543 int64_t mIters = aShape[batchRank] / M;
544 int64_t kIters = aShape[batchRank + 1] / K;
545 int64_t nIters = bShape[batchRank + 1] / N;
546
547 // Compute batch iterations (product of batch dim ratios).
548 int64_t batchIters = 1;
549 for (int64_t d = 0; d < batchRank; ++d)
550 batchIters *= aShape[d] / batchDims[d];
551
552 SmallVector<Value> newOps;
553 for (int64_t batch = 0; batch < batchIters; ++batch) {
554 for (int64_t i = 0; i < mIters; ++i) {
555 for (int64_t j = 0; j < nIters; ++j) {
556 Value tmpC;
557 if (c)
558 tmpC = cVals[batch * (mIters * nIters) + i * nIters + j];
559
560 for (int64_t k = 0; k < kIters; ++k) {
561 Value aVec = aVals[batch * (mIters * kIters) + i * kIters + k];
562 Value bVec = bVals[batch * (kIters * nIters) + k * nIters + j];
563 SmallVector<Value> operands({aVec, bVec});
564 if (tmpC)
565 operands.push_back(tmpC);
566
567 tmpC = xegpu::DpasOp::create(
568 rewriter, loc, vecTy, operands,
569 xegpu::dropInstDataOnAttrs(op->getAttrs()));
570 }
571 newOps.push_back(tmpC);
572 }
573 }
574 }
575 Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
576 rewriter.replaceOp(op, castOp);
577 return success();
578 }
579};
580
581struct UnrollDpasMxOp : public UnrollPattern<xegpu::DpasMxOp> {
582 using UnrollPattern<xegpu::DpasMxOp>::UnrollPattern;
583 LogicalResult matchAndRewrite(xegpu::DpasMxOp op,
584 PatternRewriter &rewriter) const override {
585 Location loc = op.getLoc();
586
587 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
588 if (!targetShape || targetShape->size() < 4)
589 return failure();
590
591 // targetShape is [batch..., M, K, N, S]
592 int64_t tsRank = targetShape->size();
593 auto M = (*targetShape)[tsRank - 4];
594 auto K = (*targetShape)[tsRank - 3];
595 auto N = (*targetShape)[tsRank - 2];
596 auto S = (*targetShape)[tsRank - 1];
597 ArrayRef<int64_t> batchDims(targetShape->data(), tsRank - 4);
598
599 SmallVector<int64_t> aBlockSize(batchDims);
600 aBlockSize.push_back(M);
601 aBlockSize.push_back(K);
602 SmallVector<int64_t> bBlockSize(batchDims);
603 bBlockSize.push_back(K);
604 bBlockSize.push_back(N);
605 SmallVector<int64_t> cBlockSize(batchDims);
606 cBlockSize.push_back(M);
607 cBlockSize.push_back(N);
608 SmallVector<int64_t> aScaleBlockSize(batchDims);
609 aScaleBlockSize.push_back(M);
610 aScaleBlockSize.push_back(S);
611 SmallVector<int64_t> bScaleBlockSize(batchDims);
612 bScaleBlockSize.push_back(S);
613 bScaleBlockSize.push_back(N);
614
615 auto a = op.getA();
616 auto b = op.getB();
617 auto c = op.getAcc();
618 auto ascale = dyn_cast<TypedValue<VectorType>>(op.getScaleA());
619 auto bscale = dyn_cast<TypedValue<VectorType>>(op.getScaleB());
620
621 SmallVector<Value> aVals = packOperandForDpas(a, aBlockSize, loc, rewriter);
622 SmallVector<Value> bVals = packOperandForDpas(b, bBlockSize, loc, rewriter);
623 SmallVector<Value> cVals;
624 if (c)
625 cVals = packOperandForDpas(c, cBlockSize, loc, rewriter);
626 SmallVector<Value> aScaleVals;
627 if (ascale)
628 aScaleVals = packOperandForDpas(ascale, aScaleBlockSize, loc, rewriter);
629 SmallVector<Value> bScaleVals;
630 if (bscale)
631 bScaleVals = packOperandForDpas(bscale, bScaleBlockSize, loc, rewriter);
632
633 VectorType resultTy = op.getResult().getType();
634 auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());
635
636 auto aShape = a.getType().getShape();
637 auto bShape = b.getType().getShape();
638 int64_t batchRank = batchDims.size();
639 int64_t mIters = aShape[batchRank] / M;
640 int64_t kIters = aShape[batchRank + 1] / K;
641 int64_t nIters = bShape[batchRank + 1] / N;
642
643 int64_t batchIters = 1;
644 for (int64_t d = 0; d < batchRank; ++d)
645 batchIters *= aShape[d] / batchDims[d];
646
647 SmallVector<Value> newOps;
648 xegpu::DpasMxOp newDpasMxOp;
649 for (int64_t batch = 0; batch < batchIters; ++batch) {
650 for (int64_t i = 0; i < mIters; ++i) {
651 for (int64_t j = 0; j < nIters; ++j) {
652 Value tmpC;
653 if (c)
654 tmpC = cVals[batch * (mIters * nIters) + i * nIters + j];
655
656 for (int64_t k = 0; k < kIters; ++k) {
657 Value aVec = aVals[batch * (mIters * kIters) + i * kIters + k];
658 Value bVec = bVals[batch * (kIters * nIters) + k * nIters + j];
659 SmallVector<Value> operands({aVec, bVec});
660 if (tmpC)
661 operands.push_back(tmpC);
662 if (ascale)
663 operands.push_back(
664 aScaleVals[batch * (mIters * kIters) + i * kIters + k]);
665 if (bscale)
666 operands.push_back(
667 bScaleVals[batch * (kIters * nIters) + k * nIters + j]);
668
669 newDpasMxOp = xegpu::DpasMxOp::create(
670 rewriter, loc, vecTy, operands,
671 xegpu::dropInstDataOnAttrs(op->getAttrs()));
672 tmpC = newDpasMxOp.getResult();
673 }
674 newOps.push_back(newDpasMxOp);
675 }
676 }
677 }
678 Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
679 rewriter.replaceOp(op, castOp);
680 return success();
681 }
682};
683
684/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
685/// load).
686/// It unrolls the offsets and mask operands accordingly, and creates multiple
687/// LoadGatherOp with the unrolled operands.
688struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
689 using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
690 LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
691 PatternRewriter &rewriter) const override {
692 Location loc = op.getLoc();
693 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
694 Value offsets = op.getOffsets();
695 Value mask = op.getMask();
696
697 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
698 if (!targetShape)
699 return failure();
700
701 SmallVector<int64_t> targetMaskShape(*targetShape);
702 int64_t chunkSize = 1;
703 if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
704 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
705 chunkSize = intAttr.getInt();
706 }
707
708 // Unroll mask and offsets with correct shape
709 VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
710 VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
711 Type elemTy = valueTy.getElementType();
712 VectorType newValueTy = VectorType::get(*targetShape, elemTy);
713
714 SmallVector<Type> convertedMaskTypes;
715 SmallVector<Value> convertedMasks;
716 SmallVector<Type> convertedOffsetTypes;
717 SmallVector<Value> convertedOffsets;
718
719 if (chunkSize > 1) {
720 // For chunked loads, mask and offsets have one less dimension
721 targetMaskShape.pop_back();
722 int64_t blockedChunkSize = targetShape->back();
723 int64_t numNewChunks = chunkSize / blockedChunkSize;
724 chunkSize = blockedChunkSize;
725
726 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
727 convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
728
729 SmallVector<Value> convertedMasksBase =
730 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
731 SmallVector<Value> convertedOffsetsBase =
732 pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
733
734 for (auto maskVal : convertedMasksBase)
735 convertedMasks.append(numNewChunks, maskVal);
736
737 for (auto [baseOffset, offsetType] :
738 llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
739 for (int64_t i = 0; i < numNewChunks; ++i) {
740 Value inc = arith::ConstantIndexOp::create(rewriter, loc,
741 i * blockedChunkSize);
742 Value incVec =
743 vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
744 Value offsetVal =
745 arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
746 convertedOffsets.push_back(offsetVal);
747 }
748 }
749 } else {
750 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
751 convertedMasks =
752 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
753
754 convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
755 convertedOffsets =
756 pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
757 }
758
759 auto layout = op.getLayoutAttr();
760 if (layout)
761 layout = layout.dropInstData();
762
763 SmallVector<Value> newOps;
764 for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
765 auto newOp = xegpu::LoadGatherOp::create(
766 rewriter, loc, newValueTy, op.getSource(), o, m,
767 rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
768 op.getL2HintAttr(), op.getL3HintAttr(), layout);
769 newOps.push_back(newOp);
770 }
771
772 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
773 rewriter.replaceOp(op, castOp);
774 return success();
775 }
776};
777
778/// This pattern handles the unrolling of StoreScatterOp with offsets (scattered
779/// store).
780/// It unrolls the offsets and mask operands accordingly, and creates multiple
781/// StoreScatterOp with the unrolled operands.
782struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
783 using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
784 LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
785 PatternRewriter &rewriter) const override {
786 Location loc = op.getLoc();
787 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
788 Value offsets = op.getOffsets();
789 Value mask = op.getMask();
790
791 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
792 if (!targetShape)
793 return failure();
794
795 int64_t chunkSize = 1;
796 if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
797 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
798 chunkSize = intAttr.getInt();
799 }
800
801 SmallVector<int64_t> targetMaskShape(*targetShape);
802 VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
803 VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
804
805 SmallVector<Type> convertedMaskTypes;
806 SmallVector<Value> convertedMasks;
807 SmallVector<Type> convertedOffsetTypes;
808 SmallVector<Value> convertedOffsets;
809
810 if (chunkSize > 1) {
811 targetMaskShape.pop_back();
812 int64_t blockedChunkSize = targetShape->back();
813 int64_t numNewChunks = chunkSize / blockedChunkSize;
814 chunkSize = blockedChunkSize;
815
816 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
817 convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
818
819 SmallVector<Value> convertedMasksBase =
820 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
821 SmallVector<Value> convertedOffsetsBase =
822 pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
823
824 for (auto maskVal : convertedMasksBase)
825 convertedMasks.append(numNewChunks, maskVal);
826
827 for (auto [baseOffset, offsetType] :
828 llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
829 for (int64_t i = 0; i < numNewChunks; ++i) {
830 Value inc = arith::ConstantIndexOp::create(rewriter, loc,
831 i * blockedChunkSize);
832 Value incVec =
833 vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
834 Value offsetVal =
835 arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
836 convertedOffsets.push_back(offsetVal);
837 }
838 }
839 } else {
840 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
841 convertedMasks =
842 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
843
844 convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
845 convertedOffsets =
846 pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
847 }
848
849 SmallVector<Type> convertedValTypes =
850 getUnrolledTypes(valueTy, *targetShape);
851 SmallVector<Value> convertedValues =
852 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
853
854 auto layout = op.getLayoutAttr();
855 if (layout)
856 layout = layout.dropInstData();
857
858 for (auto [v, o, m] :
859 llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
860 xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
861 rewriter.getI64IntegerAttr(chunkSize),
862 op.getL1HintAttr(), op.getL2HintAttr(),
863 op.getL3HintAttr(), layout);
864 }
865
866 rewriter.eraseOp(op);
867 return success();
868 }
869};
870
871struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
872 using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
873 LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
874 PatternRewriter &rewriter) const override {
875 Location loc = op.getLoc();
876 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
877 assert(valueTy && "the value type must be vector type!");
878
879 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
880 if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
881 return failure();
882
883 Type elemTy = valueTy.getElementType();
884 ArrayRef<int64_t> shape = valueTy.getShape();
885 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
886
887 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
888
889 SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
891 for (SmallVector<int64_t> offsets :
892 StaticTileOffsetRange(shape, *targetShape)) {
893 auto adds = xegpu::addElementwise(
894 rewriter, loc, mixedOffsets,
895 getAsIndexOpFoldResult(op.getContext(), offsets));
896 offsetsList.push_back(adds);
897 }
898
899 SmallVector<Value> newOps;
900 if (layout)
901 layout = layout.dropInstData();
902 for (SmallVector<OpFoldResult> offsets : offsetsList) {
903 auto newOp = xegpu::LoadMatrixOp::create(
904 rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
905 newOps.push_back(newOp);
906 }
907 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
908 rewriter.replaceOp(op, castOp);
909 return success();
910 }
911};
912
913struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
914 using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
915 LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
916 PatternRewriter &rewriter) const override {
917 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
918 if (!targetShape)
919 return failure();
920
921 Location loc = op.getLoc();
922 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
923 assert(valueTy && "the value type must be vector type!");
924 ArrayRef<int64_t> shape = valueTy.getShape();
925 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
926 if (layout)
927 layout = layout.dropInstData();
928
929 SmallVector<Type> convertedValTypes =
930 getUnrolledTypes(valueTy, *targetShape);
931 SmallVector<Value> convertedValues =
932 pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);
933
934 SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
936 for (SmallVector<int64_t> offsets :
937 StaticTileOffsetRange(shape, *targetShape)) {
938 auto adds = xegpu::addElementwise(
939 rewriter, loc, mixedOffsets,
940 getAsIndexOpFoldResult(op.getContext(), offsets));
941 offsetsList.push_back(adds);
942 }
943
944 for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
945 xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
946 layout);
947
948 rewriter.eraseOp(op);
949 return success();
950 }
951};
952
953/// UnrollConvertLayoutOp pattern for unrolling xegpu::ConvertLayoutOp
954/// operations. It first check whether the convert layout op has valid layouts
955/// after inst_data stripped. If it does, it will unroll the vector into
956/// multiple smaller vectors according to the target shape, and create multiple
957/// ConvertLayoutOp with the unrolled vectors and the stripped layouts.
958struct UnrollConvertLayoutOp : public UnrollPattern<xegpu::ConvertLayoutOp> {
959 using UnrollPattern<xegpu::ConvertLayoutOp>::UnrollPattern;
960 LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
961 PatternRewriter &rewriter) const override {
962 Location loc = op.getLoc();
963 Type valType = op.getType();
964
965 xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
966 xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
967 if (!inputLayout || !targetLayout)
968 return rewriter.notifyMatchFailure(op, "missing layout attributes.");
969
970 if (valType.isIntOrFloat()) {
971 rewriter.replaceOp(op, op.getSource());
972 assert(!inputLayout.dropInstData() && !targetLayout.dropInstData() &&
973 "unexpected layout attributes for scalar type");
974 return success();
975 }
976
977 if (inputLayout.getEffectiveInstDataAsInt().empty() ||
978 targetLayout.getEffectiveInstDataAsInt().empty())
979 return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");
980
981 inputLayout = inputLayout.dropInstData();
982 targetLayout = targetLayout.dropInstData();
983
984 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
985 assert(valueTy && "the value type must be vector type!");
986
987 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
988 if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
989 return failure();
990
991 Value newSource = op.getSource();
992 SmallVector<Value> newOps;
993 if (inputLayout && targetLayout) {
994 SmallVector<Type> convertedValTypes =
995 getUnrolledTypes(valueTy, *targetShape);
996 SmallVector<Value> convertedValues =
997 pack(op.getOperand(), convertedValTypes, *targetShape, loc, rewriter);
998 for (auto [v, t] : llvm::zip(convertedValues, convertedValTypes)) {
999 auto newOp = xegpu::ConvertLayoutOp::create(rewriter, loc, t, v,
1000 inputLayout, targetLayout);
1001 newOps.push_back(newOp);
1002 }
1003 newSource = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
1004 }
1005
1006 rewriter.replaceOp(op, newSource);
1007 return success();
1008 }
1009};
1010
1011/// Unrolls vector.multi_reduction by sequentially reducing tiles with
1012/// elementwise arith operations first, then a single multi_reduction
1013/// per non-reduced tile position. This avoids generating long chains of
1014/// multi_reduction ops (as the upstream pattern does) and is more efficient.
1015///
1016/// Example:
1017/// vector.multi_reduction <32x64xf16> to <32xf16> (tile_shape=32, 32)
1018/// -- Upstream pattern generates:
1019/// %tmp1 = vector.multi_reduction %tile0, %zero_acc <32x32xf16> to <32xf16>
1020/// %res = vector.multi_reduction %tmp1, %tile1 <32x32xf16> to <32xf16>
1021/// -- This pattern generates:
1022/// %tmp1 = arith.reduction %tile0, %tile1 <32x32xf16> -> <32x32xf16> //
1023/// elementwise %res = vector.multi_reduction %tmp1, %zero_acc <32x32xf16> to
1024/// <32xf16>
1025struct UnrollMultiReductionOp
1026 : public UnrollPattern<vector::MultiDimReductionOp> {
1027 UnrollMultiReductionOp(MLIRContext *context,
1029 PatternBenefit benefit = 2)
1030 : UnrollPattern<vector::MultiDimReductionOp>(context, options, benefit) {}
1031
1032 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
1033 PatternRewriter &rewriter) const override {
1034 VectorType srcTy = reductionOp.getSourceVectorType();
1035 ArrayRef<int64_t> srcShape = srcTy.getShape();
1036 int64_t srcRank = srcTy.getRank();
1037
1038 Location loc = reductionOp.getLoc();
1039 Value source = reductionOp.getSource();
1040 Value acc = reductionOp.getAcc();
1041 vector::CombiningKind kind = reductionOp.getKind();
1042
1043 // Result must be a vector (not scalar).
1044 auto resultType = dyn_cast<VectorType>(reductionOp.getDestType());
1045 if (!resultType)
1046 return failure();
1047
1048 std::optional<SmallVector<int64_t>> targetShapeOpt =
1049 getTargetShape(reductionOp);
1050 if (!targetShapeOpt ||
1051 static_cast<int64_t>(targetShapeOpt->size()) != srcRank)
1052 return failure();
1053
1054 SmallVector<int64_t> targetShape = *targetShapeOpt;
1055
1056 // Check divisibility for all dimensions.
1057 for (int64_t i = 0; i < srcRank; ++i) {
1058 if (srcShape[i] % targetShape[i] != 0)
1059 return failure();
1060 }
1061
1062 SmallVector<bool> reductionMask = reductionOp.getReductionMask();
1063 // Identify reduced and kept dimensions from the reduction mask.
1064 SmallVector<int64_t> reducedDims, keptDims;
1065 for (int64_t i = 0; i < srcRank; ++i) {
1066 if (reductionMask[i])
1067 reducedDims.push_back(i);
1068 else
1069 keptDims.push_back(i);
1070 }
1071
1072 // Compute the number of tiles along each reduced dimension and their
1073 // product
1074 SmallVector<int64_t> numReducedTilesPerDim;
1075 for (int64_t d : reducedDims)
1076 numReducedTilesPerDim.push_back(srcShape[d] / targetShape[d]);
1077
1078 // Build kept shapes for iterating over non-reduced dimensions.
1079 SmallVector<int64_t> keptShape, keptTileShape;
1080 for (int64_t d : keptDims) {
1081 keptShape.push_back(srcShape[d]);
1082 keptTileShape.push_back(targetShape[d]);
1083 }
1084
1085 // Initialize the result vector for assembly.
1086 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1087 rewriter.getZeroAttr(resultType));
1088
1089 // Iterate over all tile positions in the kept dimensions.
1090 // Ex: [off0, off1, _ _ off4]
1091 // blanks are offsets for the reduced dims, they will be
1092 // generated in the inner loop below
1093 for (SmallVector<int64_t> keptOffsets :
1094 StaticTileOffsetRange(keptShape, keptTileShape)) {
1095
1096 // Reconstruct full-rank base offsets with 0 for reduced dims.
1097 // Ex: [off0, off1, 0, 0, off4]
1098 SmallVector<int64_t> baseOffsets(srcRank, 0);
1099 for (auto [idx, dim] : llvm::enumerate(keptDims))
1100 baseOffsets[dim] = keptOffsets[idx];
1101
1102 // Generate the full tile indices for the reduced dimensions.
1103 // Ex: if reduceDimShapes = [32, 64] and
1104 // reducedDimTargetShapes = [16, 16], then reducedTileCoords:
1105 // [(0, 0), (0, 1), (0, 2), (0, 3),
1106 // (1, 0), (1, 1), (1, 2), (1, 3)]
1107 auto reducedTileCoords = StaticTileOffsetRange(
1108 numReducedTilesPerDim, SmallVector<int64_t>(reducedDims.size(), 1));
1109
1110 // Step 1: Fill "blanks" in the offsets for the reduced dimensions
1111 // using 'reducedTileCoords' and extract according tiles.
1112 // Ex: tiles = [source[off0, off1, off2_red, off3_red, off4], ...]
1113 SmallVector<Value> tiles;
1114 for (SmallVector<int64_t> reducedTileIdx : reducedTileCoords) {
1115 SmallVector<int64_t> offsets(baseOffsets);
1116 for (auto [idx, dim] : llvm::enumerate(reducedDims))
1117 offsets[dim] = reducedTileIdx[idx] * targetShape[dim];
1118 SmallVector<int64_t> strides(srcRank, 1);
1119 Value tile = vector::ExtractStridedSliceOp::create(
1120 rewriter, loc, source, offsets, targetShape, strides);
1121 tiles.push_back(tile);
1122 }
1123
1124 // Step 2: Sequentially reduce tiles using elementwise arith operations.
1125 Value reduced = tiles[0];
1126 for (size_t i = 1; i < tiles.size(); ++i)
1127 reduced =
1128 vector::makeArithReduction(rewriter, loc, kind, reduced, tiles[i]);
1129
1130 // Step 3: Perform a single multi_reduction with the accumulator slice.
1131 SmallVector<int64_t> accStrides(keptTileShape.size(), 1);
1132 Value accSlice = vector::ExtractStridedSliceOp::create(
1133 rewriter, loc, acc, keptOffsets, keptTileShape, accStrides);
1134
1135 auto newReduction = vector::MultiDimReductionOp::create(
1136 rewriter, loc, reduced, accSlice, reductionMask, kind);
1137
1138 // Step 4: Insert the reduced result into the output vector.
1139 SmallVector<int64_t> dstStrides(keptTileShape.size(), 1);
1140 result = vector::InsertStridedSliceOp::create(
1141 rewriter, loc, newReduction, result, keptOffsets, dstStrides);
1142 }
1143
1144 rewriter.replaceOp(reductionOp, result);
1145 return success();
1146 }
1147};
1148
1149} // namespace
1150
1153 patterns
1154 .add<UnrollCreateNdOp, UnrollPrefetchNdOp, UnrollLoadNdOp,
1155 UnrollStoreNdOp, UnrollDpasOp, UnrollDpasMxOp, UnrollLoadMatrixOp,
1156 UnrollStoreMatrixOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
1157 UnrollConvertLayoutOp, UnrollMultiReductionOp>(patterns.getContext(),
1158 options);
1159}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
UnitAttr getUnitAttr()
Definition Builders.cpp:102
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll xegpu operations to a smaller shapes.
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
SmallVector< OpFoldResult > addElementwise(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with same length.
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options to control the XeGPU unrolling.
Definition Transforms.h:29
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.