XeGPUWgToSgDistribute.cpp
1//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
9
24#include <optional>
25
26namespace mlir {
27namespace xegpu {
28#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
29#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
30} // namespace xegpu
31} // namespace mlir
32
33using namespace mlir;
34
35namespace {
36
37// Retrieve the sg_id_range RangeAttr from an enclosing scf.if op, if specified.
38static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
39 Operation *parent = op->getParentOfType<scf::IfOp>();
40 while (parent) {
41 if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
42 parent->getAttr("sg_id_range")))
43 return attr;
44 parent = parent->getParentOfType<scf::IfOp>();
45 }
46 return {};
47}
48
49static std::pair<SmallVector<int64_t>, int>
50getSgShapeAndCount(ArrayRef<int64_t> shape,
51 xegpu::DistributeLayoutAttr layout) {
52 int count = 1;
53 SmallVector<int64_t> sgShape = llvm::to_vector(shape);
54 if (layout && layout.isForWorkgroup()) {
55 SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
56 if (!layout.getEffectiveSgDataAsInt().empty())
57 sgShape = layout.getEffectiveSgDataAsInt();
58 else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
59 sgShape = *maybeDerivedSgData;
60 SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
61 // Clamp distUnit to the original shape to handle cases where data is
62 // shared among subgroups, which may cause distUnit to exceed the original
63 // shape.
64 for (size_t i = 0; i < distUnit.size(); ++i)
65 distUnit[i] = std::min(shape[i], distUnit[i]);
66 count = computeProduct(shape) / computeProduct(distUnit);
67 }
68 return std::make_pair(sgShape, count);
69}
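
// A worked example of getSgShapeAndCount (values chosen for illustration,
// not taken from the code above):
//   shape = [24, 24], sg_layout = [4, 4], sg_data = [2, 2]
//     -> sgShape  = [2, 2]
//     -> distUnit = sg_layout * sg_data = [8, 8]   (no clamping needed)
//     -> count    = (24 * 24) / (8 * 8) = 9 descriptors per subgroup.
// If sg_data is omitted, sgShape is derived as shape / sg_layout, e.g.
//   shape = [32, 64], sg_layout = [4, 4] -> sgShape = [8, 16], count = 1.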
70
71/// Utility helper for deriving the list of offsets for each sub-TensorDesc
72/// or sub-MemDesc to be accessed by the current subgroup (sgId), based on the
73/// associated distribute layout attribute, the shape, the subgroup id, and
74/// the original offsets of the op.
75template <
76 typename OpType,
77 typename = std::enable_if_t<llvm::is_one_of<
78 OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
79 xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
80static LogicalResult
81genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
82 SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
83 Location loc = op.getLoc();
84 SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
85 // Not applicable to ops without offset operands.
86 if (origOffsets.empty())
87 return failure();
88
89 // LoadMatrixOp/StoreMatrixOp carry the layout directly; the nd ops expose it via getDescLayoutAttr().
90 xegpu::DistributeLayoutAttr layout;
91 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
92 std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
93 layout = op.getLayoutAttr();
94 } else {
95 layout = op.getDescLayoutAttr();
96 }
97
98 // not applicable to ops without workgroup layout attributes
99 if (!layout || !layout.isForWorkgroup())
100 return failure();
101
102 Value sgId =
103 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
104
105 // verify and adjust the sgId if the range specifier is present
106 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
107 if (sgIdRange) {
108 int64_t startOfRange = sgIdRange.getStart().getInt();
109 int64_t endOfRange = sgIdRange.getEnd().getInt();
110 // verify the RangeAttr against the layout attribute
111 if (layout.getNumSubgroups() != endOfRange - startOfRange)
112 return rewriter.notifyMatchFailure(
113 op, "sg_layout size must match the sg_id_range");
114 // adjust the sgId if necessary
115 if (startOfRange > 0) {
116 Value startOfRangeVal =
117 arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
118 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
119 }
120 }
121
122 // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
123 // descriptors to be accessed, based on the layout information.
124 ArrayRef<int64_t> wgShape = op.getDataShape();
125 auto maybeDescOffsets =
126 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
127 if (failed(maybeDescOffsets))
128 return failure();
129
130 // Compute the final global offsets for each accessed sub-tensor
131 // or sub-memory descriptor.
132 for (const auto &sgOffsets : *maybeDescOffsets) {
133 SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
134 rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
135 offsetsList.push_back(std::move(newOffsets));
136 }
137
139 return success();
140}
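
// Illustrative walk-through of genOffsetsList (shapes and layout values are
// assumptions for exposition): for wgShape = [256, 128] with
// sg_layout = [8, 4] and sg_data = [32, 32], computeDistributedCoords yields
// one [row, col] coordinate pair per accessed tile, obtained by delinearizing
// sgId over sg_layout and scaling by sg_data. Each pair is then added,
// right-aligned, to the op's original mixed offsets to form one entry of
// offsetsList.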
141
142/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
143/// from a workgroup descriptor. It replaces the offsets and sizes with
144/// appropriate values for the subgroup.
145/// It uses round-robin assignment to distribute the work to the subgroups.
146/// The following create_nd_tdesc operation:
147/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
148/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
149/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
150/// is converted to 9 subgroup level operations based on the sg_layout &
151/// sg_data:
152/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
153/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
154/// lane_data = [1, 1]>>
155///
156/// The sg_layout and sg_data attributes are dropped after the pass as they are
157/// no longer needed.
158///
159/// 24x24 matrix distribution example:
160/// sg_layout = [4, 4], sg_data = [2, 2]
161/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
162/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
163///
164/// +------------------------+
165/// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
166/// |-----+-----+-----|
167/// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
168/// |-----+-----+-----|
169/// | 8x8 | 8x8 | 8x8 |
170/// +------------------------+
171///
172/// Each 8x8 tile is further subdivided among subgroups:
173/// +------------------------+
174/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
175/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
176/// | 2x2 2x2 2x2 2x2 |
177/// | 2x2 2x2 2x2 2x2 |
178/// +------------------------+
179///
180/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
181/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
182
183/// The pass currently has the entire distribution logic in the
184/// WgToSgCreateNdOp pattern; all the other ops simply follow it.
185/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
186/// ops in the pass.
187struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
188 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
189
190 LogicalResult
191 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
192 ConversionPatternRewriter &rewriter) const override {
193 SmallVector<SmallVector<OpFoldResult>> offsetsList;
194 if (failed(genOffsetsList(rewriter, op, offsetsList)))
195 return failure();
196
197 MLIRContext *ctx = op.getContext();
198 xegpu::TensorDescType tdescTy = op.getType();
199 ArrayRef<int64_t> wgShape = tdescTy.getShape();
200 Type elemTy = tdescTy.getElementType();
201 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
202 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
203 auto newTdescTy =
204 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
205 layout.dropSgLayoutAndData());
206
207 SmallVector<Value> newOps;
208 for (auto offsets : offsetsList) {
209 auto newOp = xegpu::CreateNdDescOp::create(
210 rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
211 op.getMixedSizes(), op.getMixedStrides());
212
213 newOps.push_back(newOp);
214 }
215 rewriter.replaceOpWithMultiple(op, {newOps});
216
217 return success();
218 }
219};
220
221// This pattern transforms the CreateNdDescOp without offsets to create a
222// subgroup descriptor from a workgroup descriptor
223struct WgToSgCreateNdOpNoOffset
224 : public OpConversionPattern<xegpu::CreateNdDescOp> {
225 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
226
227 LogicalResult
228 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
229 ConversionPatternRewriter &rewriter) const override {
230
231 // Check no offsets are specified.
232 if (!op.getMixedOffsets().empty())
233 return failure();
234
235 Location loc = op.getLoc();
236 MLIRContext *ctx = op.getContext();
237 xegpu::TensorDescType tdescTy = op.getType();
238 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
239 if (!layout || !layout.isForWorkgroup())
240 return failure();
241
242 Type elemTy = tdescTy.getElementType();
243 ArrayRef<int64_t> wgShape = tdescTy.getShape();
244
245 SmallVector<int64_t> sgShape;
246 int count;
247 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
248 xegpu::TensorDescType newTdescTy =
249 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
250 layout.dropSgLayoutAndData());
251
252 SmallVector<Value> newCreateNdOps(count);
253 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
254 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
255 op.getSource(), op.getMixedSizes(),
256 op.getMixedStrides());
257 });
258
259 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
260 return success();
261 }
262};
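
// Example (illustrative; the offset-less form and layout values are
// assumptions for exposition):
//   %t = xegpu.create_nd_tdesc %src : memref<256x128xf32>
//        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4],
//           sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
// is rewritten to `count` identical subgroup descriptors (count = 1 here,
// since 8*32 x 4*32 covers the whole 256x128 shape):
//   %t0 = xegpu.create_nd_tdesc %src : memref<256x128xf32>
//         -> !xegpu.tensor_desc<32x32xf32,
//            #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>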
263
264/// This pattern transforms the LoadNdOp to load subgroup data.
265struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
266 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
267 LogicalResult
268 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
269 ConversionPatternRewriter &rewriter) const override {
270 if (!op.getMixedOffsets().empty())
271 return failure();
272
273 SmallVector<Value> newLoadOps;
274 for (auto src : adaptor.getTensorDesc()) {
275 xegpu::TensorDescType tdescTy =
276 dyn_cast<xegpu::TensorDescType>(src.getType());
277 ArrayRef<int64_t> srcShape = tdescTy.getShape();
278 VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
279 auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy,
280 src, op->getAttrs());
281 newLoadOps.push_back(newLoadOp);
282 }
283 rewriter.replaceOpWithMultiple(op, {newLoadOps});
284 return mlir::success();
285 }
286};
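
// Example (illustrative): once the descriptor has been distributed, each
// per-subgroup load simply mirrors the subgroup descriptor shape, e.g.
//   %v = xegpu.load_nd %t0
//        : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16],
//          lane_data = [1, 1]>> -> vector<32x32xf32>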
287
288/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
289/// It creates a StoreNdOp to store the updated values to the new subgroup
290/// src tensor descriptors.
291struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
292 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
293 LogicalResult
294 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
295 ConversionPatternRewriter &rewriter) const override {
296 if (!op.getMixedOffsets().empty())
297 return failure();
298
299 for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
300 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
301 op.getL2HintAttr(), op.getL3HintAttr());
302
303 rewriter.eraseOp(op);
304 return success();
305 }
306};
307
308// This pattern transforms the LoadNdOp with explicit offsets to load
309// subgroup data.
310struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
311 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
312 LogicalResult
313 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
314 ConversionPatternRewriter &rewriter) const override {
315
316 SmallVector<SmallVector<OpFoldResult>> offsetsList;
317 if (failed(genOffsetsList(rewriter, op, offsetsList)))
318 return failure();
319
320 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
321 if (layout)
322 layout = layout.dropSgLayoutAndData();
323 SmallVector<Value> newOps;
324 for (auto [tdesc, offsets] :
325 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
326 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
327 VectorType newResTy =
328 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
329 auto newOp = xegpu::LoadNdOp::create(
330 rewriter, op.getLoc(), newResTy, tdesc, offsets,
331 /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
332 op.getL2HintAttr(), op.getL3HintAttr(), layout);
333 newOps.push_back(newOp);
334 }
335 rewriter.replaceOpWithMultiple(op, {newOps});
336
337 return success();
338 }
339};
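
// Example (illustrative, assuming the offset-carrying load_nd form): a
// workgroup-level load such as
//   %v = xegpu.load_nd %t[%y, %x] : !xegpu.tensor_desc<256x128xf32, #wg>
//        -> vector<256x128xf32>
// becomes one subgroup-level load per entry of offsetsList, each reading a
// 32x32 tile at the subgroup-adjusted offsets under the assumed
// sg_layout = [8, 4], sg_data = [32, 32].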
340
341// This pattern transforms the StoreNdOp with explicit offsets to store
342// subgroup data.
343struct WgToSgStoreNdOpWithOffset
344 : public OpConversionPattern<xegpu::StoreNdOp> {
345 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
346 LogicalResult
347 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
348 ConversionPatternRewriter &rewriter) const override {
349 SmallVector<SmallVector<OpFoldResult>> offsetsList;
350 if (failed(genOffsetsList(rewriter, op, offsetsList)))
351 return failure();
352
353 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
354 if (layout)
355 layout = layout.dropSgLayoutAndData();
356 for (auto [v, tdesc, offsets] :
357 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
358 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
359 op.getL1HintAttr(), op.getL2HintAttr(),
360 op.getL3HintAttr(), layout);
361 }
362 rewriter.eraseOp(op);
363
364 return success();
365 }
366};
367
368// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
369// subgroup data.
370struct WgToSgPrefetchNdOpWithOffset
371 : public OpConversionPattern<xegpu::PrefetchNdOp> {
372 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
373 LogicalResult
374 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
375 ConversionPatternRewriter &rewriter) const override {
376 SmallVector<SmallVector<OpFoldResult>> offsetsList;
377 if (failed(genOffsetsList(rewriter, op, offsetsList)))
378 return failure();
379
380 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
381 if (layout)
382 layout = layout.dropSgLayoutAndData();
383 for (auto [tdesc, offsets] :
384 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
385 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
386 op.getL1HintAttr(), op.getL2HintAttr(),
387 op.getL3HintAttr(), layout);
388 }
389 rewriter.eraseOp(op);
390
391 return success();
392 }
393};
394
395/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
396/// subgroup descriptor. It creates an UpdateNdOffsetOp to update the
397/// offsets of the new subgroup src tensor descriptors.
398struct WgToSgUpdateNdOffsetOp
399 : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
400 using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
401 LogicalResult
402 matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
403 ConversionPatternRewriter &rewriter) const override {
404 llvm::SmallVector<Value> newUpdateTileOffsetOps;
405 for (auto tDesc : adaptor.getTensorDesc()) {
406 auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
407 rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
408 op.getConstOffsets());
409 newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
410 }
411
412 rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
413 return success();
414 }
415};
416
417/// This pattern transforms the DpasOp to work at subgroup level.
418struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
419 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
420 LogicalResult
421 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
422 ConversionPatternRewriter &rewriter) const override {
423 Location loc = op.getLoc();
424 VectorType resultTy = op.getResult().getType();
425 if (resultTy.getRank() != 2)
426 return failure();
427
428 auto layoutCd = op.getLayoutCdAttr();
429 auto layoutA = op.getLayoutAAttr();
430 auto layoutB = op.getLayoutBAttr();
431 if (!layoutCd || !layoutA || !layoutB)
432 return failure();
433 size_t i = 0;
434 SmallVector<Value> newDpasOps;
435 for (auto aVec : adaptor.getLhs()) {
436 for (auto bVec : adaptor.getRhs()) {
437
438 llvm::SmallVector<Value> operands({aVec, bVec});
439 Value tmpC;
440 if (op.getAcc()) {
441 tmpC = adaptor.getAcc()[i++];
442 operands.push_back(tmpC);
443 }
444
445 ArrayRef<int64_t> aVecShape =
446 llvm::cast<VectorType>(aVec.getType()).getShape();
447 ArrayRef<int64_t> bVecShape =
448 llvm::cast<VectorType>(bVec.getType()).getShape();
449 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
450 resultTy.getElementType());
451 auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
452 newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
453 newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
454 newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
455
456 newDpasOps.push_back(newDpasOp);
457 }
458 }
459 rewriter.replaceOpWithMultiple(op, {newDpasOps});
460 return success();
461 }
462};
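
// Worked example (shapes chosen for exposition): if A is distributed into
// vector<32x128xf16> pieces and B into vector<128x64xf16> pieces, every
// (aVec, bVec) pair yields one subgroup-level xegpu.dpas producing a
// vector<32x64xf32> tile, i.e. aVecShape[0] x bVecShape[1]; the nested loops
// therefore emit (#lhs pieces) * (#rhs pieces) dpas ops.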
463
464/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
465struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
466 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
467 LogicalResult
468 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
469 ConversionPatternRewriter &rewriter) const override {
470
471 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
472 if ((offsetSize != 0) || op.getConstOffsetsAttr())
473 return failure();
474
475 for (auto src : adaptor.getTensorDesc())
476 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
477 op->getAttrs());
478 rewriter.eraseOp(op);
479 return success();
480 }
481};
482
483/// This pattern transforms vector.broadcast ops to work at subgroup level.
484struct WgToSgVectorBroadcastOp
485 : public OpConversionPattern<vector::BroadcastOp> {
486 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
487
488 LogicalResult
489 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
490 ConversionPatternRewriter &rewriter) const override {
491
492 VectorType resultType = op.getResult().getType();
493 ArrayRef<int64_t> wgShape = resultType.getShape();
494
495 xegpu::DistributeLayoutAttr layout =
496 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
497 if (!layout || !layout.isForWorkgroup())
498 return failure();
499
500 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
501 VectorType newResultType =
502 VectorType::get(sgShape, resultType.getElementType());
503
504 if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
505 return failure();
506
507 SmallVector<Value> newBroadcastOps;
508 for (auto operand : adaptor.getOperands().front()) {
509 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
510 newResultType, operand);
511 xegpu::setTemporaryLayout(newBroadcast->getResult(0),
512 layout.dropSgLayoutAndData());
513
514 newBroadcastOps.push_back(newBroadcast.getResult());
515 }
516 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
517 return success();
518 }
519};
520
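// Example (illustrative layout values): a workgroup-level broadcast such as
//   %b = vector.broadcast %scalar : f32 to vector<256x128xf32>
// whose result layout carries sg_layout = [8, 4], sg_data = [32, 32] is
// rewritten to one
//   %b_sg = vector.broadcast %scalar : f32 to vector<32x32xf32>
// per distributed operand, with the sg_layout/sg_data fields dropped from
// the result layout.
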
521// This pattern transforms elementwise ops to work at subgroup level.
522struct WgToSgElementwiseOp : public ConversionPattern {
523 WgToSgElementwiseOp(MLIRContext *ctx)
524 : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
526 LogicalResult
527 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
528 ConversionPatternRewriter &rewriter) const override {
529 // Only match ops with elementwise trait and single result.
530 if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
531 return failure();
532
533 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
534 assert(resultType && "Expected result to be a VectorType");
535
536 ArrayRef<int64_t> wgShape = resultType.getShape();
537
538 xegpu::DistributeLayoutAttr layout =
539 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
540 if (!layout || !layout.isForWorkgroup())
541 return failure();
543 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
544
545 size_t numVariants = operands.empty() ? 0 : operands.front().size();
546
547 if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
548 return operandVec.size() != numVariants;
549 }))
550 return failure();
551
552 SmallVector<Value> newResults;
553 VectorType newResultType =
554 VectorType::get(sgShape, resultType.getElementType());
556 for (size_t i = 0; i < numVariants; ++i) {
557 SmallVector<Value> opOperands;
558 for (auto &operandVec : operands)
559 opOperands.push_back(operandVec[i]);
560
561 OperationState state(op->getLoc(), op->getName());
562 state.addOperands(opOperands);
563 state.addTypes(newResultType);
564 // Copy all attributes, but update "layout_result_0" to drop
565 // sgLayout/sgData
566 for (auto attr : op->getAttrs()) {
567 if (auto layout =
568 dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
569 if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
570 !layout.getEffectiveInstDataAsInt().empty())
571 state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
572 } else {
573 state.addAttribute(attr.getName(), attr.getValue());
574 }
575 }
576 Operation *newOp = rewriter.create(state);
577 newResults.push_back(newOp->getResult(0));
578 }
579
580 rewriter.replaceOpWithMultiple(op, {newResults});
581 return success();
582 }
583};
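
// Example (illustrative): an elementwise op such as
//   %r = arith.addf %x, %y : vector<256x128xf32>
// whose result layout carries sg_layout = [8, 4], sg_data = [32, 32] is
// rewritten into one subgroup-level op per distributed operand tuple:
//   %r_i = arith.addf %x_i, %y_i : vector<32x32xf32>
// keeping only the inst_data / lane fields in the result layout attribute.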
584
585// clang-format off
586// Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
587// If input_layout and target_layout have identical sg_layout and sg_data,
588// the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
589// dropped. For example:
590// #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
591// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
592// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
593// becomes:
594// #a = #xegpu.layout<inst_data = [16, 16]>
595// #b = #xegpu.layout<inst_data = [8, 16]>
596// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
597// (vector<16x16xf32> is determined by sg_data = [16, 16])
598//
599// If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
600// For example:
601// #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
602// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
603// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
604// is lowered to:
605// #a = #xegpu.layout<inst_data = [16, 16]>
606// #b = #xegpu.layout<inst_data = [8, 16]>
607// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16xf32>, mem_desc<32x64xf32>
608// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
609// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
610// clang-format on
611struct WgToSgConvertLayoutOp
612 : public OpConversionPattern<xegpu::ConvertLayoutOp> {
613 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
614 LogicalResult
615 matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
616 ConversionPatternRewriter &rewriter) const override {
617 // TODO: currently, we only support LayoutAttr
618 auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
619 auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
620
621 if (!input || !target || !input.isForWorkgroup() ||
622 !target.isForWorkgroup())
623 return rewriter.notifyMatchFailure(
624 op, "Input and target layouts must have subgroup layout");
625
626 DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
627 DenseI32ArrayAttr inputSgData = input.getSgData();
628 DenseI32ArrayAttr inputOrder = input.getOrder();
629 DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
630 DenseI32ArrayAttr targetSgData = target.getSgData();
631 DenseI32ArrayAttr targetOrder = target.getOrder();
632
633 // TODO: currently we only support the optimal case, where input and
634 // output have the same sg_layout and sg_data, so SLM is not involved.
635 if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
636 inputOrder != targetOrder)
637 return failure();
638
639 input = input.dropSgLayoutAndData();
640 target = target.dropSgLayoutAndData();
641
642 SmallVector<Value> newOps(adaptor.getSource());
643 if (input && target) {
644 // Keep the ConvertLayoutOp for the remaining fields, e.g., inst_data.
645 for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
646 auto newOp = xegpu::ConvertLayoutOp::create(
647 rewriter, op.getLoc(), src.getType(), src, input, target);
648 newOps[i] = newOp;
649 }
650 }
651 rewriter.replaceOpWithMultiple(op, {newOps});
652 return success();
653 }
654};
655
656// Handles UnrealizedConversionCastOp generated during
657// SCFStructuralTypeConversions (step 1). This op may appear as either a
658// target or source materialization for Vector values, e.g.:
659// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
660// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
661// It can be either a 1:N or an N:1 cast. In both cases, the pattern
662// simply forwards the inputs to the outputs using the 1:1 or 1:N interface.
663// For example, the following scf::forOp
664// ```
665// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
666// %n = use(%arg1): vector<128x128xf16>
667// scf.yield %n : vector<128x128xf16>
668// }
669// ```
670// Could be converted to:
671// ```
672// %1 = unrealized_conversion_cast %0
673// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
674// %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
675// -> (vector<16x16xf16>, vector<16x16xf16>) {
676// %m = unrealized_conversion_cast %arg1, %arg2
677// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
678// %n = use(%m): vector<128x128xf16>
679// %b = unrealized_conversion_cast %n
680// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
681// scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
682// }
683// %cast = unrealized_conversion_cast %for#0, %for#1
684// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
685// ```
686// TODO: remove it when context-aware type converter is ready.
687struct UnrealizedConversionCastOpPattern
688 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
689 using OpConversionPattern<
690 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
691
692 mlir::LogicalResult
693 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
694 ConversionPatternRewriter &rewriter) const override {
695 SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
696
697 auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
698 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
699
700 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
701 !llvm::all_equal(ValueRange(inputs).getTypes()))
702 return failure();
703
704 // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
705 // It is generated by source materialization (e.g., inits to scf forOp).
706 // The input values provided by the adaptor should already be distributed,
707 // and their types should correspond exactly to the result types of the
708 // operation.
709 if (op.getNumOperands() == 1 &&
710 llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
711 rewriter.replaceOp(op, inputs);
712 return success();
713 }
714
715 // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
716 // It is generated by target materialization (e.g., arguments/results
717 // of scf forOp). All input values must have the same vector type, and
718 // their shape must be evenly divisible by the output vector's shape
719 // (determined by the nature of the workgroup to subgroup distribution).
720// TODO: it is not safe to do such forwarding, since such an N:1 cast could
721// also come from other sources.
722 if (op.getNumResults() == 1 &&
723 computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
724 rewriter.replaceOpWithMultiple(op, {inputs});
725 return success();
726 }
727
728 return mlir::failure();
729 }
730};
731
732// This pattern distributes arith.constant op into subgroup-level constants
733struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
734 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
735
736 LogicalResult
737 matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
738 ConversionPatternRewriter &rewriter) const override {
739 auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
740 auto vecType = dyn_cast<VectorType>(op.getType());
741 if (!vecAttr || !vecType)
742 return failure();
743
744 xegpu::DistributeLayoutAttr layout =
745 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
746 if (!layout || !layout.isForWorkgroup())
747 return failure();
748
749 ArrayRef<int64_t> wgShape = vecType.getShape();
750 SmallVector<int64_t> sgShape;
751 int count;
752 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
753
754 auto newType = VectorType::get(sgShape, vecType.getElementType());
755 Location loc = op.getLoc();
756 auto eltType = vecType.getElementType();
757
758 auto setLayout = [&](Value val) {
759 xegpu::setTemporaryLayout(llvm::dyn_cast<OpResult>(val),
760 layout.dropSgLayoutAndData());
761 };
762
763 if (vecAttr.isSplat()) {
764 // Splat: single value for all subgroups
765 Attribute singleVal = vecAttr.getSplatValue<Attribute>();
766 auto sgAttr = DenseElementsAttr::get(newType, singleVal);
767 auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
768 setLayout(cstOp->getResult(0));
769 rewriter.replaceOp(op, cstOp);
770 return success();
771 } else if (sgShape == wgShape) { // if the entire vector is shared by all
772 // subgroups, don't distribute
773 auto newConstOp =
774 arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
775 setLayout(newConstOp->getResult(0));
776 rewriter.replaceOp(op, newConstOp);
777 return success();
778 } else {
779 // Non-splat constant
780 // Only supports 1D & 2D
781 // TODO: support other cases that require SLM access
782 if (!eltType.isIndex())
783 return rewriter.notifyMatchFailure(
784 op, "Unsupported element type for non-splat constant op.");
785
786 if (wgShape.size() > 2)
787 return rewriter.notifyMatchFailure(
788 op, "Only 1D & 2D vector constant supported");
789
790 SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
791 int64_t rowStride = 0, colStride = 0;
792 int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
793 int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
794
795 // Compute colStride and rowStride, and check for constant strides.
796 if (cols > 1) {
797 colStride = cast<IntegerAttr>(values[1]).getInt() -
798 cast<IntegerAttr>(values[0]).getInt();
799 }
800 if (rows > 1) {
801 rowStride = cast<IntegerAttr>(values[cols]).getInt() -
802 cast<IntegerAttr>(values[0]).getInt();
803 }
804
805 for (int64_t r = 0; r < rows; ++r) {
806 for (int64_t c = 0; c < cols; ++c) {
807 int64_t idx = r * cols + c;
808 // Check column stride
809 if (c > 0 && cols > 1) {
810 int64_t prevIdx = r * cols + (c - 1);
811 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
812 cast<IntegerAttr>(values[prevIdx]).getInt();
813 if (diff != colStride)
814 return rewriter.notifyMatchFailure(
815 op, "Non-constant column stride in constant op.");
816 }
817 // Check row stride
818 if (r > 0 && rows > 1) {
819 int64_t prevIdx = (r - 1) * cols + c;
820 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
821 cast<IntegerAttr>(values[prevIdx]).getInt();
822 if (diff != rowStride)
823 return rewriter.notifyMatchFailure(
824 op, "Non-constant row stride in constant op.");
825 }
826 }
827 }
828
829 // Create a constant for the base tile.
830 // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
831 // For 1D case, extract the first sgShape[0] elements.
832 SmallVector<Attribute> baseTileValues;
833 int baseTileCols = sgShape[sgShape.size() - 1];
834 int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
835 for (int64_t r = 0; r < baseTileRows; ++r) {
836 for (int64_t c = 0; c < baseTileCols; ++c) {
837 baseTileValues.push_back(values[r * cols + c]);
838 }
839 }
840
841 auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
842 baseTileValues);
843 auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
844
845 // Get subgroup id
846 Value sgId =
847 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
848 auto sgOffsets =
849 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
850 if (failed(sgOffsets))
851 return failure();
852
853 SmallVector<Value, 2> strideConsts;
854 strideConsts.push_back(
855 arith::ConstantIndexOp::create(rewriter, loc, colStride));
856 if (rows > 1)
857 strideConsts.insert(
858 strideConsts.begin(),
859 arith::ConstantIndexOp::create(rewriter, loc, rowStride));
860
861 SmallVector<Value> newConstOps;
862 for (auto offsets : *sgOffsets) {
863 // Multiply offset with stride, broadcast it and add to baseConstVec
864 Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
865 for (size_t i = 0; i < strideConsts.size(); ++i) {
866 Value mul =
867 arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
868 offsets[i], strideConsts[i]);
869 mulOffset = arith::AddIOp::create(
870 rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
871 }
872 // Broadcast to baseConstVec size
873 auto bcastOffset = vector::BroadcastOp::create(
874 rewriter, loc, baseConstVec.getType(), mulOffset);
875 auto finalConst =
876 arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
877 setLayout(baseConstVec);
878 setLayout(bcastOffset);
879 setLayout(finalConst);
880 newConstOps.push_back(finalConst);
881 }
882 rewriter.replaceOpWithMultiple(op, {newConstOps});
883 return success();
884 }
885 }
886};
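
// Worked example for the non-splat path (values chosen for exposition):
//   %c = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28]> : vector<8xindex>
// with sg_layout = [4], sg_data = [2]: colStride = 4 and the base tile is
//   arith.constant dense<[0, 4]> : vector<2xindex>
// Each subgroup then adds broadcast(offset * 4); e.g. the subgroup whose
// element offset is 6 materializes [0 + 24, 4 + 24] = [24, 28].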
887
888// This pattern transforms the LoadGatherOp with explicit offsets to load
889// subgroup data
890struct WgToSgLoadGatherOpWithOffset
891 : public OpConversionPattern<xegpu::LoadGatherOp> {
892 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
893 LogicalResult
894 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
895 ConversionPatternRewriter &rewriter) const override {
896
897 if (!op.getOffsets())
898 return failure();
899
900 Location loc = op.getLoc();
901 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
902 if (!resultType)
903 return failure();
904 ArrayRef<int64_t> wgShape = resultType.getShape();
905
906 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
907
908 if (!layout || !layout.isForWorkgroup())
909 return failure();
910
911 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
912
913 // The offsets need to be distributed
914 auto offsetsVecType =
915 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
916 auto maskVecType =
917 dyn_cast<VectorType>(adaptor.getMask().front().getType());
918 if (!offsetsVecType || !maskVecType ||
919 offsetsVecType.getShape() != maskVecType.getShape()) {
920 return rewriter.notifyMatchFailure(op,
921 "offsets have not been distributed");
922 }
923
924 SmallVector<Value> newLoadOps;
925 auto chunkSizeAttr =
926 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
927 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
928 for (auto [offsets, mask] :
929 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
930 auto newLayout = layout.dropSgLayoutAndData();
931 auto newLoadOp = xegpu::LoadGatherOp::create(
932 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
933 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
934 newLayout);
935 newLoadOp.setAnchorLayout(newLayout);
936 newLoadOps.push_back(newLoadOp);
937 }
938 rewriter.replaceOpWithMultiple(op, {newLoadOps});
939 return success();
940 }
941};
942
943// This pattern transforms the StoreScatterOp with explicit offsets to store
944// subgroup data
945struct WgToSgStoreScatterOpWithOffset
946 : public OpConversionPattern<xegpu::StoreScatterOp> {
947 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
948 LogicalResult
949 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
950 ConversionPatternRewriter &rewriter) const override {
951
952 if (!op.getOffsets())
953 return failure();
954
955 Location loc = op.getLoc();
956 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
957 if (!valueType)
958 return failure();
959
960 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
961
962 if (!layout || !layout.isForWorkgroup())
963 return failure();
964
965 // The offsets need to be distributed
966 auto offsetsVecType =
967 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
968 auto maskVecType =
969 dyn_cast<VectorType>(adaptor.getMask().front().getType());
970 if (!offsetsVecType || !maskVecType ||
971 offsetsVecType.getShape() != maskVecType.getShape()) {
972 return rewriter.notifyMatchFailure(op,
973 "offsets have not been distributed");
974 }
975
976 auto chunkSizeOpt = op.getChunkSize();
977 int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
978 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
979 for (auto [val, offs, mask] : llvm::zip(
980 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
981 auto store = xegpu::StoreScatterOp::create(
982 rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
983 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
984 layout.dropSgLayoutAndData());
985 // Update the layout attribute to drop sg_layout and sg_data.
986 for (OpOperand &operand : store->getOpOperands()) {
987 // Skip for operand one (memref)
988 if (operand.getOperandNumber() == 1)
989 continue;
990 xegpu::setTemporaryLayout(operand, layout.dropSgLayoutAndData());
991 }
992 }
993 rewriter.eraseOp(op);
994 return success();
995 }
996};
997
998struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
999 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
1000 LogicalResult
1001 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
1002 ConversionPatternRewriter &rewriter) const override {
1003
1004 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1005 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1006 return failure();
1007
1008 ArrayRef<int64_t> wgShape = op.getDataShape();
1009 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
1010 assert(valueTy && "the value type must be vector type!");
1011 Type elemTy = valueTy.getElementType();
1012
1013 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1014 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1015 VectorType newResTy = VectorType::get(sgShape, elemTy);
1016 SmallVector<Value> newOps;
1017 for (auto offsets : offsetsList) {
1018 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
1019 op.getMemDesc(), offsets,
1020 layout.dropSgLayoutAndData());
1021 newOps.push_back(newOp);
1022 }
1023 rewriter.replaceOpWithMultiple(op, {newOps});
1024
1025 return success();
1026 }
1027};
1028
1029struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
1030 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
1031 LogicalResult
1032 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
1033 ConversionPatternRewriter &rewriter) const override {
1034
1035 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1036 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1037 return failure();
1038
1039 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1040 for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1041 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1042 offsets, layout.dropSgLayoutAndData());
1043 rewriter.eraseOp(op);
1044 return success();
1045 }
1046};
1047
1048// This pattern distributes the vector.step ops to work at subgroup level
1049struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
1050 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1051 LogicalResult
1052 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1053 ConversionPatternRewriter &rewriter) const override {
1054 xegpu::DistributeLayoutAttr layout =
1055 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1056 if (!layout || !layout.isForWorkgroup())
1057 return failure();
1058
1059 Location loc = op.getLoc();
1060 VectorType type = op.getResult().getType();
1061 auto wgShape = type.getShape();
1062 std::optional<SmallVector<int64_t>> sgShape =
1063 getSgShapeAndCount(wgShape, layout).first;
1064 if (!sgShape)
1065 return failure();
1066
1067 Value sgId =
1068 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1069 auto sgOffsets =
1070 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1071 if (failed(sgOffsets))
1072 return failure();
1073
1074 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
1075 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1076 SmallVector<Value> newOps;
1077 for (auto offsets : *sgOffsets) {
1078 // Broadcast the offset scalar to a vector & add to the base steps
1079 auto bcastOffset =
1080 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1081 auto finalSteps =
1082 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1083 xegpu::setTemporaryLayout(steps->getResult(0),
1084 layout.dropSgLayoutAndData());
1085 xegpu::setTemporaryLayout(bcastOffset->getResult(0),
1086 layout.dropSgLayoutAndData());
1087 xegpu::setTemporaryLayout(finalSteps->getResult(0),
1088 layout.dropSgLayoutAndData());
1089 newOps.push_back(finalSteps);
1090 }
1091
1092 rewriter.replaceOpWithMultiple(op, {newOps});
1093 return success();
1094 }
1095};
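
// Worked example (layout chosen for exposition): a workgroup-level
//   %s = vector.step : vector<128xindex>
// with sg_layout = [4], sg_data = [32] becomes, for a subgroup at offset %o,
//   %base = vector.step : vector<32xindex>
//   %off  = vector.broadcast %o : index to vector<32xindex>
//   %res  = arith.addi %base, %off : vector<32xindex>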
1096
1097// This pattern transforms vector.shape_cast ops to work at subgroup level.
1098struct WgToSgVectorShapeCastOp
1099 : public OpConversionPattern<vector::ShapeCastOp> {
1100 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1101
1102 LogicalResult
1103 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
1104 ConversionPatternRewriter &rewriter) const override {
1105
1106 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1107 if (!resultType)
1108 return failure();
1109
1110 ArrayRef<int64_t> wgShape = resultType.getShape();
1111 xegpu::DistributeLayoutAttr layout =
1112 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1113 if (!layout || !layout.isForWorkgroup())
1114 return failure();
1115
1116 // Check that srcShape and destShape, if they differ, only differ by
1117 // expansion of unit dimensions.
1118 auto srcType = dyn_cast<VectorType>(op.getSource().getType());
1119 if (!srcType)
1120 return failure();
1121
1122 ArrayRef<int64_t> srcShape = srcType.getShape();
1123 llvm::SetVector<int64_t> expandedUnitDims;
1124
1125 // Check if shapes only differ by expanding unit dimensions (like
1126 // expand_dims)
1127 auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
1128 ArrayRef<int64_t> dst) -> bool {
1129 // All unit dimensions in dst that don't appear in src are the expanded
1130 // unit dimensions
1131 size_t srcIdx = 0;
1132 for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
1133 if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
1134 srcIdx++;
1135 else if (dst[dstIdx] == 1)
1136 expandedUnitDims.insert(dstIdx);
1137 else
1138 return false;
1139 return srcIdx == src.size();
1140 };
1141
1142 xegpu::DistributeLayoutAttr layoutToDistribute = layout;
1143
1144 if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
1145 xegpu::DistributeLayoutAttr sourceLayout =
1146 xegpu::getDistributeLayoutAttr(op.getSource());
1147
1148 auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
1149 return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
1150 return isa<vector::BroadcastOp>(user);
1151 });
1152 };
1153
1154 if (!usedByBroadcastOp(op))
1155 return rewriter.notifyMatchFailure(
1156 op, "ShapeCast ops that expand unit dimensions and are used by "
1157 "non-broadcast operations are not supported.");
1158
1159 if (!sourceLayout.isSliceOf(layout))
1160 return rewriter.notifyMatchFailure(
1161 op, "The ShapeCast op only expands dimensions, the result layout "
1162 "must be a slice of the input layout, or vice versa.");
1163 layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims);
1164 layoutToDistribute =
1165 layoutToDistribute.setUnitDimLayout(expandedUnitDims);
1166 }
1167
1168 SmallVector<int64_t> sgShape =
1169 getSgShapeAndCount(wgShape, layoutToDistribute).first;
1170 VectorType newResultType =
1171 VectorType::get(sgShape, resultType.getElementType());
1172
1173 SmallVector<Value> newShapeCastOps;
1174 for (auto src : adaptor.getSource()) {
1175 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1176 newResultType, src);
1177 xegpu::setTemporaryLayout(newShapeCast->getResult(0),
1178 layout.dropSgLayoutAndData());
1179 newShapeCastOps.push_back(newShapeCast.getResult());
1180 }
1181
1182 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1183 return success();
1184 }
1185};
1186
1187/// Pattern for lowering vector.multi_reduction op to subgroup level.
1188/// Current limitation: the sg_layout in each reduced dimension must be 1,
1189/// so that the reduction is local to a subgroup and no cross-subgroup
1190/// communication is needed.
1191/// TODO: Add cases to handle more general situations which require SLM access.
1192struct WgToSgMultiDimReductionOp
1193 : public OpConversionPattern<vector::MultiDimReductionOp> {
1194 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1195
1196 LogicalResult
1197 matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1198 ConversionPatternRewriter &rewriter) const override {
1199 VectorType srcType = op.getSourceVectorType();
1200 VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
1201 if (!dstType)
1202 return failure();
1203
1204 auto srcShape = srcType.getShape();
1205 xegpu::DistributeLayoutAttr layout =
1206 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1207 if (!layout || !layout.isForWorkgroup())
1208 return failure();
1209
1210 auto reductionDims = llvm::to_vector(op.getReductionDims());
1211
1212 SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
1213 .getParent()
1214 .getEffectiveSgLayoutAsInt();
1215 SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
1216 .getParent()
1217 .getEffectiveSgDataAsInt();
1218
1219 // Check that the sgLayout in the reduced dimension is 1 and
1220 // each sg gets the entire slice to reduce.
1221 for (int64_t dim : reductionDims) {
1222 if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
1223 return rewriter.notifyMatchFailure(
1224 op,
1225 "sgLayout in each reduced dimension must be 1 and sgData in the "
1226 "reduced dim must match srcShape in that dim");
1227 }
1228
1229 SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
1230
1231 VectorType newDstType =
1232 VectorType::get({sgShape}, dstType.getElementType());
1233
1234 SmallVector<Value> newReductions;
1235 for (auto sgSrc : adaptor.getSource()) {
1236 auto newOp = vector::MultiDimReductionOp::create(
1237 rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
1238 adaptor.getAcc()[0], op.getReductionDims());
1239 xegpu::setTemporaryLayout(newOp->getResult(0),
1240 layout.dropSgLayoutAndData());
1241 newReductions.push_back(newOp.getResult());
1242 }
1243
1244 rewriter.replaceOpWithMultiple(op, {newReductions});
1245 return success();
1246 }
1247};
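
// Worked example (illustrative): reducing dim 1 of a vector<256x64xf32>
// whose parent layout has sg_layout = [8, 1], sg_data = [32, 64] satisfies
// the constraint above (sg_layout[1] == 1 and sg_data[1] == srcShape[1]),
// so each subgroup performs a purely local
//   vector.multi_reduction <add>, %src_sg, %acc_sg [1]
//     : vector<32x64xf32> to vector<32xf32>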
1248
1249// This pattern transforms vector.transpose ops to work at subgroup level.
1250struct WgToSgVectorTransposeOp
1251 : public OpConversionPattern<vector::TransposeOp> {
1252 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1253
1254 LogicalResult
1255 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1256 ConversionPatternRewriter &rewriter) const override {
1257 VectorType resultType = op.getResultVectorType();
1258
1259 ArrayRef<int64_t> wgShape = resultType.getShape();
1260 xegpu::DistributeLayoutAttr layout =
1261 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1262 if (!layout || !layout.isForWorkgroup())
1263 return failure();
1264 // TODO-LayoutRefactor: handle the case using getTemporaryLayout
1265 xegpu::DistributeLayoutAttr sourceLayout =
1266 xegpu::getDistributeLayoutAttr(op.getVector());
1267 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1268 return failure();
1269
1270 SmallVector<int64_t> sourceSgLayout =
1271 sourceLayout.getEffectiveSgLayoutAsInt();
1272 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1273 DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
1274 DenseI32ArrayAttr resultOrder = layout.getOrder();
1275
1276 if (!sourceOrder || !resultOrder) {
1277 return rewriter.notifyMatchFailure(
1278 op, "Both source and result must have order attributes");
1279 }
1280
1281 ArrayRef<int64_t> permutation = op.getPermutation();
1282 size_t permutationSize = permutation.size();
1283 if (sourceSgLayout.size() != permutationSize ||
1284 resultSgLayout.size() != permutationSize) {
1285 return rewriter.notifyMatchFailure(
1286 op, "Layouts and permutation must have the same rank");
1287 }
1288
1289 // Check that sgLayout, sgData & order are properly transposed for source
1290 // and result
1291 if (!layout.isTransposeOf(sourceLayout, permutation))
1292 return rewriter.notifyMatchFailure(
1293 op, "Result layout is not a valid transpose of source layout "
1294 "according to permutation");
1295
1296 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1297 VectorType newResultType =
1298 VectorType::get(sgShape, resultType.getElementType());
1299 SmallVector<Value> newTransposeOps;
1300 for (auto src : adaptor.getVector()) {
1301 auto newTranspose = vector::TransposeOp::create(
1302 rewriter, op.getLoc(), newResultType, src, permutation);
1303 xegpu::setTemporaryLayout(newTranspose->getResult(0),
1304 layout.dropSgLayoutAndData());
1305 newTransposeOps.push_back(newTranspose.getResult());
1306 }
1307
1308 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1309 return success();
1310 }
1311};
1312
1313// Distribute vector mask ops to work at subgroup level.
1314template <typename MaskOpType>
1315struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
1316 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1317
1318 LogicalResult matchAndRewrite(
1319 MaskOpType op,
1320 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1321 ConversionPatternRewriter &rewriter) const override {
1322 xegpu::DistributeLayoutAttr layout =
1323 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1324 if (!layout || !layout.isForWorkgroup())
1325 return failure();
1326
1327 Location loc = op.getLoc();
1328 VectorType type = op.getResult().getType();
1329 auto wgShape = type.getShape();
1330
1331 SmallVector<Value> wgMaskDimSizes;
1332 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1333 for (int64_t maskSize : op.getMaskDimSizes()) {
1334 wgMaskDimSizes.push_back(
1335 arith::ConstantIndexOp::create(rewriter, loc, maskSize));
1336 }
1337 } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1338 wgMaskDimSizes = llvm::to_vector(op.getOperands());
1339 }
1340
1341 Value sgId =
1342 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1343 auto sgOffsets =
1344 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1345 if (failed(sgOffsets))
1346 return failure();
1347
1348 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1349 VectorType resultType = VectorType::get(sgShape, type.getElementType());
1350
1351 // In each dimension, each subgroup computes its local mask size as:
1352 // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
1353 SmallVector<Value> newCreateMaskOps;
1354 for (auto offsetSet : *sgOffsets) {
1355 SmallVector<Value> maskOperands;
1356
1357 for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
1358 Value dimSizeVal =
1359 arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
1360 Value offset = offsetSet[i];
1361 Value adjustedMaskSize =
1362 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1363 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1364 Value nonNegative =
1365 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1366 Value sgMaskSize =
1367 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1368 maskOperands.push_back(sgMaskSize);
1369 }
1370
1371 auto newCreateMaskOp =
1372 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1373 xegpu::setTemporaryLayout(newCreateMaskOp->getResult(0),
1374 layout.dropSgLayoutAndData());
1375 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1376 }
1377
1378 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1379 return success();
1380 }
1381};
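
// Worked example of the per-dimension mask-size formula above (values chosen
// for exposition): for a workgroup mask of shape [256, 128] with
// wgMaskDimSizes = [200, 100] and sg_data = [32, 32], the subgroup whose
// offsets are [64, 96] builds a vector.create_mask with sizes
//   [min(max(200 - 64, 0), 32), min(max(100 - 96, 0), 32)] = [32, 4].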
1382
1383using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1384using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1385} // namespace
1386
1387namespace mlir {
1388namespace xegpu {
1389void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
1390 patterns
1391 .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
1392 WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
1393 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
1394 WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
1395 WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1396 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
1397 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
1398 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1399 WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1400 WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1401 patterns.getContext());
1402}
1403} // namespace xegpu
1404} // namespace mlir
1405
1406namespace {
1407struct XeGPUWgToSgDistributePass
1408 : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
1409 void runOnOperation() override;
1410};
1411} // namespace
1412
1413void XeGPUWgToSgDistributePass::runOnOperation() {
1414
1415 // TODO-LayoutRefactor: unify the local propagation for layout preprocessing
1416 // Operation *op = getOperation();
1417 // if (!xegpu::recoverTemporaryLayouts(op)) {
1418 // signalPassFailure();
1419 // return;
1420 // }
1421
1422 // Track existing UnrealizedConversionCastOps
1423 SmallVector<Operation *> existingCastOps;
1424 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1425 existingCastOps.push_back(castOp.getOperation());
1426 });
1427
1428 {
1429 // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
1430 // VectorType operands. This first converts such operands to
1431 // RankedTensorType, propagates the layout attribute into the encoding
1432 // attribute, and finally converts the RankedTensorType to VectorType based
1433 // on the encoding.
1434
1435 TypeConverter converter;
1436 converter.addConversion([&](Type type) -> Type { return type; });
1437 converter.addConversion(
1438 [&](RankedTensorType type,
1439 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1440 Type elemTy = type.getElementType();
1441 ArrayRef<int64_t> shape = type.getShape();
1442
1443 int count;
1444 SmallVector<int64_t> subShape;
1445 std::tie(subShape, count) = getSgShapeAndCount(
1446 shape,
1447 dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
1448
1449 auto newTy = VectorType::get(subShape, elemTy);
1450 result.append(count, newTy);
1451 return success();
1452 });
1453
1454 xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
1455 converter);
1456 }
1457
1458 // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
1459 // as well as XeGPU, Arith, and Vector operations.
1460 MLIRContext *ctx = &getContext();
1461 RewritePatternSet patterns(ctx);
1462 ConversionTarget target(*ctx);
1463 TypeConverter converter;
1464 converter.addConversion([&](Type type) -> Type { return type; });
1465 converter.addConversion(
1466 [&](xegpu::TensorDescType type,
1467 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1468 Type elemTy = type.getElementType();
1469 ArrayRef<int64_t> shape = type.getShape();
1470
1471 int count;
1472 SmallVector<int64_t> subShape;
1473 xegpu::LayoutAttr layout = type.getLayoutAttr();
1474 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1475
1476 if (layout)
1477 layout = layout.dropSgLayoutAndData();
1478
1479 auto newTy = xegpu::TensorDescType::get(
1480 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1481 result.append(count, newTy);
1482 return success();
1483 });
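// Worked example (illustrative types and syntax, not from the upstream file):
// a !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4],
//     sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
// would convert to a single
// !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16],
//     lane_data = [1, 1]>>
// since dropSgLayoutAndData() strips only the subgroup-level fields, keeping
// the lane-level layout for later subgroup-to-lane distribution.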
1484
1485 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1486 if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1487 return createOp.getType();
1488 if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1489 return loadOp.getTensorDescType();
1490 if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1491 return storeOp.getTensorDescType();
1492 if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1493 return updateOp.getType();
1494 if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1495 return prefetchOp.getTensorDescType();
1496 return xegpu::TensorDescType();
1497 };
1498
1499 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
1500 return !layout || !layout.isForWorkgroup();
1501 };
1502
1503 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1504 xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1505 xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
1506 auto tdescTy = getTensorDescType(op);
1507 auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1508 return isLegal(layout);
1509 });
1510
1511 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
1512 auto layout = op.getLayoutCdAttr();
1513 return isLegal(layout);
1514 });
1515
1516 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1517 [=](xegpu::LoadMatrixOp op) -> bool {
1518 return isLegal(op.getLayoutAttr());
1519 });
1520
1521 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1522 [=](xegpu::StoreMatrixOp op) -> bool {
1523 return isLegal(op.getLayoutAttr());
1524 });
1525
1526 target.addDynamicallyLegalOp<arith::ConstantOp>(
1527 [=](arith::ConstantOp op) -> bool {
1528 auto vecType = dyn_cast<VectorType>(op.getType());
1529 if (!vecType)
1530 return true;
1531
1532 auto layout =
1533 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1534 return isLegal(layout);
1535 });
1536
1537 target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1538 vector::TransposeOp, vector::BroadcastOp,
1539 vector::MultiDimReductionOp,
1540 vector::ConstantMaskOp, vector::CreateMaskOp>(
1541 [=](Operation *op) -> bool {
1542 // Check for either a SliceAttr or LayoutAttr on the result.
1543 auto layout =
1544 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1545 return isLegal(layout);
1546 });
1547
1548 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1549 [=](xegpu::LoadGatherOp op) -> bool {
1550 auto layout = op.getLayoutAttr();
1551 return isLegal(layout);
1552 });
1553
1554 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1555 [=](xegpu::StoreScatterOp op) -> bool {
1556 auto layout = op.getLayoutAttr();
1557 return isLegal(layout);
1558 });
1559
1560 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1561 [=](xegpu::ConvertLayoutOp op) -> bool {
1562 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1563 });
1564
1565 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1566 [=](Operation *op) -> std::optional<bool> {
1567 // Only handle elementwise mappable ops
1568 if (!OpTrait::hasElementwiseMappableTraits(op))
1569 return true;
1570
1571 VectorType resultType =
1572 dyn_cast<VectorType>(op->getResult(0).getType());
1573 if (!resultType)
1574 return true;
1575
1576 // Check if all operands are vectors of the same shape
1577 // TODO: Support other types.
1578 for (Value operand : op->getOperands()) {
1579 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1580 if (!operandType || operandType.getShape() != resultType.getShape()) {
1581 return true;
1582 }
1583 }
1584
1585 xegpu::DistributeLayoutAttr layout =
1586 xegpu::getTemporaryLayout(op->getResult(0));
1587 return isLegal(layout);
1588 });
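// Editorial example (illustrative, not from the upstream file): an arith.addf
// producing vector<256x128xf32> whose result carries a workgroup-level layout
// is reported illegal here and gets distributed by WgToSgElementwiseOp,
// whereas the same op on scalars, or on vectors without sg_layout, stays
// legal and is left untouched.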
1589
1590 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1591 [=](UnrealizedConversionCastOp op) {
1592 return llvm::is_contained(existingCastOps, op.getOperation());
1593 });
1594
1595 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1596
1596
1597 scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
1598 target);
1599 xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
1600 if (failed(
1601 applyPartialConversion(getOperation(), target, std::move(patterns))))
1602 return signalPassFailure();
1603
1604 // Remove sg_layout and sg_data attributes from the Layout
1605 // attribute for each VectorType result of the operation.
1606 // For Structured Control Flow ops, the layout is simply removed,
1607 // since in the 1:N case, the layouts for the new results are missing.
1608 // The layout propagation pass will be activated to recover them.
1609 getOperation()->walk([](Operation *op) {
1610 for (OpResult result : op->getOpResults()) {
1611 std::string name = xegpu::getTemporaryLayoutName(result);
1612 if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
1613 op->removeAttr(name);
1614 if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
1615 if (auto newLayout = layout.dropSgLayoutAndData())
1616 op->setAttr(name, newLayout);
1617 }
1618 }
1619 }
1620 });
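// Worked example (illustrative attribute, not from the upstream file): a
// temporary result layout such as
// #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
// becomes #xegpu.layout<inst_data = [8, 16]> on regular ops; on the SCF ops
// listed above the attribute is dropped entirely. If nothing remains after
// removing the subgroup fields, dropSgLayoutAndData() returns a null
// attribute and no layout is re-attached.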
1621}