MLIR 22.0.0git
XeGPUWgToSgDistribute.cpp
1//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
9
10#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
11
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/GPU/IR/GPUDialect.h"
14#include "mlir/Dialect/Index/IR/IndexOps.h"
15#include "mlir/Dialect/Math/IR/Math.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/SCF/Transforms/Patterns.h"
18#include "mlir/Dialect/Utils/IndexingUtils.h"
19#include "mlir/Dialect/Vector/IR/VectorOps.h"
20#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
21#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
22#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include <optional>
25
26namespace mlir {
27namespace xegpu {
28#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
29#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
30} // namespace xegpu
31} // namespace mlir
32
33using namespace mlir;
34
35namespace {
36
37// Retrieve the sg_id_range RangeAttr from the nearest enclosing scf.if op, if specified.
38static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
39 Operation *parent = op->getParentOfType<scf::IfOp>();
40 while (parent) {
41 if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
42 parent->getAttr("sg_id_range")))
43 return attr;
44 parent = parent->getParentOfType<scf::IfOp>();
45 }
46 return {};
47}
48
49static std::pair<SmallVector<int64_t>, int>
50getSgShapeAndCount(ArrayRef<int64_t> shape,
51 xegpu::DistributeLayoutAttr layout) {
52 int count = 1;
53 SmallVector<int64_t> sgShape = llvm::to_vector(shape);
54 if (layout && layout.isForWorkgroup()) {
55 SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
56 if (!layout.getEffectiveSgDataAsInt().empty())
57 sgShape = layout.getEffectiveSgDataAsInt();
58 else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
59 sgShape = *maybeDerivedSgData;
60 SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
61 // Clamp distUnit to the original shape to handle cases where data is
62 // shared among subgroups, which may cause distUnit to exceed the original
63 // shape.
64 for (size_t i = 0; i < distUnit.size(); ++i)
65 distUnit[i] = std::min(shape[i], distUnit[i]);
66 count = computeProduct(shape) / computeProduct(distUnit);
67 }
68 return std::make_pair(sgShape, count);
69}
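// For example, with shape = [24, 24], sg_layout = [4, 4] and sg_data = [2, 2]
// (the case documented for WgToSgCreateNdOp below):
//   sgShape  = [2, 2]
//   distUnit = min([24, 24], [4*2, 4*2]) = [8, 8]
//   count    = (24 * 24) / (8 * 8) = 9
// i.e. each subgroup owns one 2x2 tile in each of the nine 8x8 distribution
// units.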
70
71/// Utility helper for deriving the list of offsets for each sub-TensorDesc
72/// or sub-MemDesc to be accessed by the current subgroup (sgId), based on the
73/// associated distribute layout attribute, the shape, the subgroup id, and
74/// the original offsets of the op.
75template <
76 typename OpType,
77 typename = std::enable_if_t<llvm::is_one_of<
78 OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
79 xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
80static LogicalResult
81genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
82               SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
83 Location loc = op.getLoc();
84 SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
85 // Not applicable to ops without offset operands.
86 if (origOffsets.empty())
87 return failure();
88
89 // LoadMatrixOp/StoreMatrixOp carry the layout attribute directly; the Nd ops derive it from their tensor descriptor via getDescLayoutAttr().
90 xegpu::DistributeLayoutAttr layout;
91 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
92 std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
93 layout = op.getLayoutAttr();
94 } else {
95 layout = op.getDescLayoutAttr();
96 }
97
98 // not applicable to ops without workgroup layout attributes
99 if (!layout || !layout.isForWorkgroup())
100 return failure();
101
102 Value sgId =
103 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
104
105 // verify and adjust the sgId if the range specifier is present
106 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
107 if (sgIdRange) {
108 int64_t startOfRange = sgIdRange.getStart().getInt();
109 int64_t endOfRange = sgIdRange.getEnd().getInt();
110 // verify the RangeAttr against the layout attribute
111 if (layout.getNumSubgroups() != endOfRange - startOfRange)
112 return rewriter.notifyMatchFailure(
113 op, "sg_layout size must match the sg_id_range");
114 // adjust the sgId if necessary
115 if (startOfRange > 0) {
116 Value startOfRangeVal =
117 arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
118 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
119 }
120 }
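    // For example, with sg_id_range = [4, 8) over a layout of 4 subgroups,
    // hardware subgroup ids 4..7 are remapped to 0..3 before the
    // layout-based coordinate computation below.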
121
122 // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
123 // descriptors to be accessed, based on the layout information.
124 ArrayRef<int64_t> wgShape = op.getDataShape();
125 auto maybeDescOffsets =
126 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
127 if (failed(maybeDescOffsets))
128 return failure();
129
130 // Compute the final global offsets for each accessed sub-tensor
131 // or sub-memory descriptor.
132 for (const auto &sgOffsets : *maybeDescOffsets) {
133 SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
134 rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
135 offsetsList.push_back(std::move(newOffsets));
136 }
137
138 // callback(offsetsList);
139 return success();
140}
141
142/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
143/// from a workgroup descriptor. It replaces the offsets and sizes with
144/// appropriate values for the subgroup.
145/// It uses round-robin assignment to distribute the work to the subgroups.
146/// The following create_nd_desc operation:
147/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
148/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
149/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
150/// is converted to 9 subgroup level operations based on the sg_layout &
151/// sg_data:
152/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
153/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
154/// lane_data = [1, 1]>>
155///
156/// The sg_layout and sg_data attributes are dropped after the pass as they are
157/// no longer needed.
158///
159/// 24x24 matrix distribution example:
160/// sg_layout = [4, 4], sg_data = [2, 2]
161/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
162/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
163///
164/// +-----+-----+-----+
165/// | 8x8 | 8x8 | 8x8 |  <- 3 tiles across
166/// +-----+-----+-----+
167/// | 8x8 | 8x8 | 8x8 |  <- 3 tiles down
168/// +-----+-----+-----+
169/// | 8x8 | 8x8 | 8x8 |
170/// +-----+-----+-----+
171///
172/// Each 8x8 tile is further subdivided among subgroups:
173/// +-----------------+
174/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups across (each handles 2 columns)
175/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups down (each handles 2 rows)
176/// | 2x2 2x2 2x2 2x2 |
177/// | 2x2 2x2 2x2 2x2 |
178/// +-----------------+
179///
180/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
181/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
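///
/// For instance, assuming the default row-major subgroup order, subgroup 5
/// maps to coordinate (1, 1) in sg_layout = [4, 4], i.e. local offset
/// (1*2, 1*2) = (2, 2) within each distribution unit, so it creates its nine
/// descriptors at offsets (2, 2), (2, 10), (2, 18), (10, 2), ..., (18, 18).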
182
183/// The pass currently has the entire distribution logic in the
184/// WgToSgCreateNdOp pattern, and all the other patterns just follow.
185/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
186/// ops in the pass.
187struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
188 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
189
190 LogicalResult
191 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
192 ConversionPatternRewriter &rewriter) const override {
193 SmallVector<SmallVector<OpFoldResult>> offsetsList;
194 if (failed(genOffsetsList(rewriter, op, offsetsList)))
195 return failure();
196
197 MLIRContext *ctx = op.getContext();
198 xegpu::TensorDescType tdescTy = op.getType();
199 ArrayRef<int64_t> wgShape = tdescTy.getShape();
200 Type elemTy = tdescTy.getElementType();
201 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
202 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
203 auto newTdescTy =
204 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
205 layout.dropSgLayoutAndData());
206
207 SmallVector<Value> newOps;
208 for (auto offsets : offsetsList) {
209 auto newOp = xegpu::CreateNdDescOp::create(
210 rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
211 op.getMixedSizes(), op.getMixedStrides());
212
213 newOps.push_back(newOp);
214 }
215 rewriter.replaceOpWithMultiple(op, {newOps});
216
217 return success();
218 }
219};
220
221// This pattern transforms the CreateNdDescOp without offsets to create a
222// subgroup descriptor from a workgroup descriptor
223struct WgToSgCreateNdOpNoOffset
224 : public OpConversionPattern<xegpu::CreateNdDescOp> {
225 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
226
227 LogicalResult
228 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
229 ConversionPatternRewriter &rewriter) const override {
230
231 // Check no offsets are specified.
232 if (!op.getMixedOffsets().empty())
233 return failure();
234
235 Location loc = op.getLoc();
236 MLIRContext *ctx = op.getContext();
237 xegpu::TensorDescType tdescTy = op.getType();
238 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
239 if (!layout || !layout.isForWorkgroup())
240 return failure();
241
242 Type elemTy = tdescTy.getElementType();
243 ArrayRef<int64_t> wgShape = tdescTy.getShape();
244
245 SmallVector<int64_t> sgShape;
246 int count;
247 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
248 xegpu::TensorDescType newTdescTy =
249 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
250 layout.dropSgLayoutAndData());
251
252 SmallVector<Value> newCreateNdOps(count);
253 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
254 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
255 op.getSource(), op.getMixedSizes(),
256 op.getMixedStrides());
257 });
258
259 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
260 return success();
261 }
262};
263
264/// This pattern transforms the LoadNdOp to load subgroup data.
265struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
266 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
267 LogicalResult
268 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
269 ConversionPatternRewriter &rewriter) const override {
270 if (!op.getMixedOffsets().empty())
271 return failure();
272
273 SmallVector<Value> newLoadOps;
274 for (auto src : adaptor.getTensorDesc()) {
275 xegpu::TensorDescType tdescTy =
276 dyn_cast<xegpu::TensorDescType>(src.getType());
277 ArrayRef<int64_t> srcShape = tdescTy.getShape();
278 VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
279 auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy,
280 src, op->getAttrs());
281 newLoadOps.push_back(newLoadOp);
282 }
283 rewriter.replaceOpWithMultiple(op, {newLoadOps});
284 return mlir::success();
285 }
286};
287
288/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
289/// It creates a StoreNdOp to store the updated values to the new subgroup
290/// src tensor descriptors.
291struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
292 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
293 LogicalResult
294 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
295 ConversionPatternRewriter &rewriter) const override {
296 if (!op.getMixedOffsets().empty())
297 return failure();
298
299 for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
300 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
301 op.getL2HintAttr(), op.getL3HintAttr());
302
303 rewriter.eraseOp(op);
304 return success();
305 }
306};
307
308// This pattern transforms the LoadNdOp with explicit offsets to load
309// subgroup data.
310struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
311 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
312 LogicalResult
313 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
314 ConversionPatternRewriter &rewriter) const override {
315
316 SmallVector<SmallVector<OpFoldResult>> offsetsList;
317 if (failed(genOffsetsList(rewriter, op, offsetsList)))
318 return failure();
319
320 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
321 if (layout)
322 layout = layout.dropSgLayoutAndData();
323 SmallVector<Value> newOps;
324 for (auto [tdesc, offsets] :
325 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
326 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
327 VectorType newResTy =
328 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
329 auto newOp = xegpu::LoadNdOp::create(
330 rewriter, op.getLoc(), newResTy, tdesc, offsets,
331 /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
332 op.getL2HintAttr(), op.getL3HintAttr(), layout);
333 newOps.push_back(newOp);
334 }
335 rewriter.replaceOpWithMultiple(op, {newOps});
336
337 return success();
338 }
339};
340
341// This pattern transforms the StoreNdOp with explicit offsets to store
342// subgroup data.
343struct WgToSgStoreNdOpWithOffset
344 : public OpConversionPattern<xegpu::StoreNdOp> {
345 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
346 LogicalResult
347 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
348 ConversionPatternRewriter &rewriter) const override {
349 SmallVector<SmallVector<OpFoldResult>> offsetsList;
350 if (failed(genOffsetsList(rewriter, op, offsetsList)))
351 return failure();
352
353 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
354 if (layout)
355 layout = layout.dropSgLayoutAndData();
356 for (auto [v, tdesc, offsets] :
357 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
358 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
359 op.getL1HintAttr(), op.getL2HintAttr(),
360 op.getL3HintAttr(), layout);
361 }
362 rewriter.eraseOp(op);
363
364 return success();
365 }
366};
367
368// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
369// subgroup data.
370struct WgToSgPrefetchNdOpWithOffset
371 : public OpConversionPattern<xegpu::PrefetchNdOp> {
372 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
373 LogicalResult
374 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
375 ConversionPatternRewriter &rewriter) const override {
376 SmallVector<SmallVector<OpFoldResult>> offsetsList;
377 if (failed(genOffsetsList(rewriter, op, offsetsList)))
378 return failure();
379
380 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
381 if (layout)
382 layout = layout.dropSgLayoutAndData();
383 for (auto [tdesc, offsets] :
384 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
385 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
386 op.getL1HintAttr(), op.getL2HintAttr(),
387 op.getL3HintAttr(), layout);
388 }
389 rewriter.eraseOp(op);
390
391 return success();
392 }
393};
394
395/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
396/// subgroup descriptor. It creates an UpdateNdOffsetOp to update the
397/// offsets of the new subgroup src tensor descriptors.
398struct WgToSgUpdateNdOffsetOp
399 : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
400 using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
401 LogicalResult
402 matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
403 ConversionPatternRewriter &rewriter) const override {
404 llvm::SmallVector<Value> newUpdateTileOffsetOps;
405 for (auto tDesc : adaptor.getTensorDesc()) {
406 auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
407 rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
408 op.getConstOffsets());
409 newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
410 }
411
412 rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
413 return success();
414 }
415};
416
417/// This pattern transforms the DpasOp to work at subgroup level.
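/// It forms the cross product of the distributed A and B tiles: e.g., if the
/// adaptor provides A tiles {A0, A1} and B tiles {B0, B1}, four subgroup-level
/// dpas ops are created, one per (Ai, Bj) pair, each producing a tile of shape
/// Ai.rows x Bj.cols (accumulator tiles, when present, are consumed in order).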
418struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
419 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
420 LogicalResult
421 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
422 ConversionPatternRewriter &rewriter) const override {
423 Location loc = op.getLoc();
424 VectorType resultTy = op.getResult().getType();
425 if (resultTy.getRank() != 2)
426 return failure();
427
428 auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult());
429 if (!originalLayout)
430 return failure();
431
432 size_t i = 0;
433 SmallVector<Value> newDpasOps;
434 for (auto aVec : adaptor.getLhs()) {
435 for (auto bVec : adaptor.getRhs()) {
436
437 llvm::SmallVector<Value> operands({aVec, bVec});
438 Value tmpC;
439 if (op.getAcc()) {
440 tmpC = adaptor.getAcc()[i++];
441 operands.push_back(tmpC);
442 }
443
444 ArrayRef<int64_t> aVecShape =
445 llvm::cast<VectorType>(aVec.getType()).getShape();
446 ArrayRef<int64_t> bVecShape =
447 llvm::cast<VectorType>(bVec.getType()).getShape();
448 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
449 resultTy.getElementType());
450 tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
451 xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
452 originalLayout.dropSgLayoutAndData());
453
454 newDpasOps.push_back(tmpC);
455 }
456 }
457 rewriter.replaceOpWithMultiple(op, {newDpasOps});
458 return success();
459 }
460};
461
462/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
463struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
464 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
465 LogicalResult
466 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
467 ConversionPatternRewriter &rewriter) const override {
468
469 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
470 if ((offsetSize != 0) || op.getConstOffsetsAttr())
471 return failure();
472
473 for (auto src : adaptor.getTensorDesc())
474 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
475 op->getAttrs());
476 rewriter.eraseOp(op);
477 return success();
478 }
479};
480
481/// This pattern transforms vector.broadcast ops to work at subgroup level.
482struct WgToSgVectorBroadcastOp
483 : public OpConversionPattern<vector::BroadcastOp> {
484 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
485
486 LogicalResult
487 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
488 ConversionPatternRewriter &rewriter) const override {
489
490 VectorType resultType = op.getResult().getType();
491 ArrayRef<int64_t> wgShape = resultType.getShape();
492
493 xegpu::DistributeLayoutAttr layout =
494 xegpu::getDistributeLayoutAttr(op.getResult());
495 if (!layout || !layout.isForWorkgroup())
496 return failure();
497
498 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
499 VectorType newResultType =
500 VectorType::get(sgShape, resultType.getElementType());
501
502 if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
503 return failure();
504
505 SmallVector<Value> newBroadcastOps;
506 for (auto operand : adaptor.getOperands().front()) {
507 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
508 newResultType, operand);
509 xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
510 layout.dropSgLayoutAndData());
511
512 newBroadcastOps.push_back(newBroadcast.getResult());
513 }
514 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
515 return success();
516 }
517};
518
519// This pattern transforms elementwise ops to work at subgroup level.
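// For example, an arith.addf on vector<32x64xf32> whose result carries
// #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16]> is rewritten into one
// subgroup-level arith.addf per distributed operand tuple (two here), each on
// vector<16x16xf32>, with sg_layout/sg_data dropped from the result layout.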
520struct WgToSgElementwiseOp : public ConversionPattern {
521 WgToSgElementwiseOp(MLIRContext *ctx)
522 : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
524 LogicalResult
525 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
526 ConversionPatternRewriter &rewriter) const override {
527 // Only match ops with elementwise trait and single result.
528 if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
529 return failure();
530
531 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
532 assert(resultType && "Expected result to be a VectorType");
533
534 ArrayRef<int64_t> wgShape = resultType.getShape();
535
536 xegpu::DistributeLayoutAttr layout =
537 xegpu::getDistributeLayoutAttr(op->getResult(0));
538 if (!layout || !layout.isForWorkgroup())
539 return failure();
540
541 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
543 size_t numVariants = operands.empty() ? 0 : operands.front().size();
544
545 if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
546 return operandVec.size() != numVariants;
547 }))
548 return failure();
549
550 SmallVector<Value> newResults;
551 VectorType newResultType =
552 VectorType::get(sgShape, resultType.getElementType());
553
554 for (size_t i = 0; i < numVariants; ++i) {
555 SmallVector<Value> opOperands;
556 for (auto &operandVec : operands)
557 opOperands.push_back(operandVec[i]);
558
559 OperationState state(op->getLoc(), op->getName());
560 state.addOperands(opOperands);
561 state.addTypes(newResultType);
562 // Copy all attributes, but update "layout_result_0" to drop
563 // sgLayout/sgData
564 for (auto attr : op->getAttrs()) {
565 if (auto layout =
566 dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
567 if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
568 !layout.getEffectiveInstDataAsInt().empty())
569 state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
570 } else {
571 state.addAttribute(attr.getName(), attr.getValue());
572 }
573 }
574 Operation *newOp = rewriter.create(state);
575 newResults.push_back(newOp->getResult(0));
576 }
577
578 rewriter.replaceOpWithMultiple(op, {newResults});
579 return success();
580 }
581};
582
583// clang-format off
584// Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
585// If input_layout and target_layout have identical sg_layout and sg_data,
586// the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
587// dropped. For example:
588// #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
589// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
590// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
591// becomes:
592// #a = #xegpu.layout<inst_data = [16, 16]>
593// #b = #xegpu.layout<inst_data = [8, 16]>
594// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
595// (vector<16x16xf32> is determined by sg_data = [16, 16])
596//
597// If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
598// For example:
599// #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
600// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
601// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
602// is lowered to:
603// #a = #xegpu.layout<inst_data = [16, 16]>
604// #b = #xegpu.layout<inst_data = [8, 16]>
605// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16xf32>, mem_desc<32x64xf32>
606// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
607// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
608// clang-format on
609struct WgToSgConvertLayoutOp
610 : public OpConversionPattern<xegpu::ConvertLayoutOp> {
611 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
612 LogicalResult
613 matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
614 ConversionPatternRewriter &rewriter) const override {
615 // TODO: currently, we only support LayoutAttr
616 auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
617 auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
618
619 if (!input || !target || !input.isForWorkgroup() ||
620 !target.isForWorkgroup())
621 return rewriter.notifyMatchFailure(
622 op, "Input and target layouts must have subgroup layout");
623
624 DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
625 DenseI32ArrayAttr inputSgData = input.getSgData();
626 DenseI32ArrayAttr inputOrder = input.getOrder();
627 DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
628 DenseI32ArrayAttr targetSgData = target.getSgData();
629 DenseI32ArrayAttr targetOrder = target.getOrder();
630
631 // TODO: currently we only support for optimal case, where input and
632 // output has the same sg_layout and sg_data, so SLM is not involved.
633 if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
634 inputOrder != targetOrder)
635 return failure();
636
637 input = input.dropSgLayoutAndData();
638 target = target.dropSgLayoutAndData();
639
640 SmallVector<Value> newOps(adaptor.getSource());
641 if (input && target) {
642 // keep the ConvertLayoutOp for rest fields, e.g., inst_data.
643 for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
644 auto newOp = xegpu::ConvertLayoutOp::create(
645 rewriter, op.getLoc(), src.getType(), src, input, target);
646 newOps[i] = newOp;
647 }
648 }
649 rewriter.replaceOpWithMultiple(op, {newOps});
650 return success();
651 }
652};
653
654// Handles UnrealizedConversionCastOp generated during
655// SCFStructuralTypeConversions (step 1). This op may appear as either a
656// target or source materialization for Vector values, e.g.:
657// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
658// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
659// It could be either a 1:N or an N:1 cast. In both cases, the pattern
660// simply forwards the inputs to the outputs using the 1:1 or 1:N interface.
661// For example, the following scf.for op
662// ```
663// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
664// %n = use(%arg1): vector<128x128xf16>
665// scf.yield %n : vector<128x128xf16>
666// }
667// ```
668// Could be converted to:
669// ```
670// %1 = unrealized_conversion_cast %0
671// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
672// %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
673// -> (vector<16x16xf16>, vector<16x16xf16>) {
674// %m = unrealized_conversion_cast %arg1, %arg2
675// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
676// %n = use(%m): vector<128x128xf16>
677// %b = unrealized_conversion_cast %n
678// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
679// scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
680// }
681// %cast = unrealized_conversion_cast %for#0, %for#1
682// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
683// ```
684// TODO: remove it when context-aware type converter is ready.
685struct UnrealizedConversionCastOpPattern
686 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
687 using OpConversionPattern<
688 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
689
690 mlir::LogicalResult
691 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
692 ConversionPatternRewriter &rewriter) const override {
693 SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
694
695 auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
696 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
697
698 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
699 !llvm::all_equal(ValueRange(inputs).getTypes()))
700 return failure();
701
702 // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
703 // It is generated by source materialization (e.g., inits to scf forOp).
704 // The input values provided by the adaptor should already be distributed,
705 // and their types should correspond exactly to the result types of the
706 // operation.
707 if (op.getNumOperands() == 1 &&
708 llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
709 rewriter.replaceOp(op, inputs);
710 return success();
711 }
712
713 // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
714 // It is generated by target materialization (e.g., arguments/results
715 // of scf forOp). All input values must have the same vector type, and
716 // their shape must be evenly divisible by the output vector's shape
717 // (determined by the nature of the workgroup to subgroup distribution).
718 // TODO: it is not safe to do such forward, since such N:1 cast could be
719 // from others.
720 if (op.getNumResults() == 1 &&
721 computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
722 rewriter.replaceOpWithMultiple(op, {inputs});
723 return success();
724 }
725
726 return mlir::failure();
727 }
728};
729
730// This pattern distributes arith.constant op into subgroup-level constants
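// For a splat constant the value is simply rematerialized at the subgroup
// shape. For a non-splat index constant with constant strides, e.g.
//   arith.constant dense<[0, 1, 2, ..., 15]> : vector<16xindex>
// with sg_layout = [4] and sg_data = [4], each subgroup takes the base tile
// dense<[0, 1, 2, 3]> and adds a broadcast of its element offset times the
// stride (the subgroup at offset 8 produces [8, 9, 10, 11]).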
731struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
732 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
733
734 LogicalResult
735 matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
736 ConversionPatternRewriter &rewriter) const override {
737 auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
738 auto vecType = dyn_cast<VectorType>(op.getType());
739 if (!vecAttr || !vecType)
740 return failure();
741
742 xegpu::DistributeLayoutAttr layout =
743 xegpu::getDistributeLayoutAttr(op.getResult());
744 if (!layout || !layout.isForWorkgroup())
745 return failure();
746
747 ArrayRef<int64_t> wgShape = vecType.getShape();
748 SmallVector<int64_t> sgShape;
749 int count;
750 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
751
752 auto newType = VectorType::get(sgShape, vecType.getElementType());
753 Location loc = op.getLoc();
754 auto eltType = vecType.getElementType();
755
756 auto setLayout = [&](Value val) {
757 xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
758 layout.dropSgLayoutAndData());
759 };
760
761 if (vecAttr.isSplat()) {
762 // Splat: single value for all subgroups
763 Attribute singleVal = vecAttr.getSplatValue<Attribute>();
764 auto sgAttr = DenseElementsAttr::get(newType, singleVal);
765 auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
766 setLayout(cstOp->getResult(0));
767 rewriter.replaceOp(op, cstOp);
768 return success();
769 } else if (sgShape == wgShape) { // if the entire vector is shared by all
770 // subgroups, don't distribute
771 auto newConstOp =
772 arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
773 setLayout(newConstOp->getResult(0));
774 rewriter.replaceOp(op, newConstOp);
775 return success();
776 } else {
777 // Non-splat constant
778 // Only supports 1D & 2D
779 // TODO: support other cases that require SLM access
780 if (!eltType.isIndex())
781 return rewriter.notifyMatchFailure(
782 op, "Unsupported element type for non-splat constant op.");
783
784 if (wgShape.size() > 2)
785 return rewriter.notifyMatchFailure(
786 op, "Only 1D & 2D vector constant supported");
787
788 SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
789 int64_t rowStride = 0, colStride = 0;
790 int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
791 int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
792
793 // Compute colStride and rowStride, and check for constant strides.
794 if (cols > 1) {
795 colStride = cast<IntegerAttr>(values[1]).getInt() -
796 cast<IntegerAttr>(values[0]).getInt();
797 }
798 if (rows > 1) {
799 rowStride = cast<IntegerAttr>(values[cols]).getInt() -
800 cast<IntegerAttr>(values[0]).getInt();
801 }
802
803 for (int64_t r = 0; r < rows; ++r) {
804 for (int64_t c = 0; c < cols; ++c) {
805 int64_t idx = r * cols + c;
806 // Check column stride
807 if (c > 0 && cols > 1) {
808 int64_t prevIdx = r * cols + (c - 1);
809 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
810 cast<IntegerAttr>(values[prevIdx]).getInt();
811 if (diff != colStride)
812 return rewriter.notifyMatchFailure(
813 op, "Non-constant column stride in constant op.");
814 }
815 // Check row stride
816 if (r > 0 && rows > 1) {
817 int64_t prevIdx = (r - 1) * cols + c;
818 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
819 cast<IntegerAttr>(values[prevIdx]).getInt();
820 if (diff != rowStride)
821 return rewriter.notifyMatchFailure(
822 op, "Non-constant row stride in constant op.");
823 }
824 }
825 }
826
827 // Create a constant for the base tile.
828 // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
829 // For 1D case, extract the first sgShape[0] elements.
830 SmallVector<Attribute> baseTileValues;
831 int baseTileCols = sgShape[sgShape.size() - 1];
832 int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
833 for (int64_t r = 0; r < baseTileRows; ++r) {
834 for (int64_t c = 0; c < baseTileCols; ++c) {
835 baseTileValues.push_back(values[r * cols + c]);
836 }
837 }
838
839 auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
840 baseTileValues);
841 auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
842
843 // Get subgroup id
844 Value sgId =
845 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
846 auto sgOffsets =
847 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
848 if (failed(sgOffsets))
849 return failure();
850
851 SmallVector<Value, 2> strideConsts;
852 strideConsts.push_back(
853 arith::ConstantIndexOp::create(rewriter, loc, colStride));
854 if (rows > 1)
855 strideConsts.insert(
856 strideConsts.begin(),
857 arith::ConstantIndexOp::create(rewriter, loc, rowStride));
858
859 SmallVector<Value> newConstOps;
860 for (auto offsets : *sgOffsets) {
861 // Multiply offset with stride, broadcast it and add to baseConstVec
862 Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
863 for (size_t i = 0; i < strideConsts.size(); ++i) {
864 Value mul =
865 arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
866 offsets[i], strideConsts[i]);
867 mulOffset = arith::AddIOp::create(
868 rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
869 }
870 // Broadcast to baseConstVec size
871 auto bcastOffset = vector::BroadcastOp::create(
872 rewriter, loc, baseConstVec.getType(), mulOffset);
873 auto finalConst =
874 arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
875 setLayout(baseConstVec);
876 setLayout(bcastOffset);
877 setLayout(finalConst);
878 newConstOps.push_back(finalConst);
879 }
880 rewriter.replaceOpWithMultiple(op, {newConstOps});
881 return success();
882 }
883 }
884};
885
886// This pattern transforms the LoadGatherOp with explicit offsets to load
887// subgroup data
888struct WgToSgLoadGatherOpWithOffset
889 : public OpConversionPattern<xegpu::LoadGatherOp> {
890 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
891 LogicalResult
892 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
893 ConversionPatternRewriter &rewriter) const override {
894
895 if (!op.getOffsets())
896 return failure();
897
898 Location loc = op.getLoc();
899 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
900 if (!resultType)
901 return failure();
902 ArrayRef<int64_t> wgShape = resultType.getShape();
903
904 xegpu::DistributeLayoutAttr layout =
905 xegpu::getDistributeLayoutAttr(op.getResult());
906 if (!layout || !layout.isForWorkgroup())
907 return failure();
908
909 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
910
911 // The offsets need to be distributed
912 auto offsetsVecType =
913 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
914 auto maskVecType =
915 dyn_cast<VectorType>(adaptor.getMask().front().getType());
916 if (!offsetsVecType || !maskVecType ||
917 offsetsVecType.getShape() != maskVecType.getShape()) {
918 return rewriter.notifyMatchFailure(op,
919 "offsets have not been distributed");
920 }
921
922 SmallVector<Value> newLoadOps;
923 auto chunkSizeAttr =
924 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
925 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
926 for (auto [offsets, mask] :
927 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
928 auto newLayout = layout.dropSgLayoutAndData();
929 auto newLoadOp = xegpu::LoadGatherOp::create(
930 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
931 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
932 newLayout);
933 xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout);
934 newLoadOps.push_back(newLoadOp);
935 }
936 rewriter.replaceOpWithMultiple(op, {newLoadOps});
937 return success();
938 }
939};
940
941// This pattern transforms the StoreScatterOp with explicit offsets to store
942// subgroup data
943struct WgToSgStoreScatterOpWithOffset
944 : public OpConversionPattern<xegpu::StoreScatterOp> {
945 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
946 LogicalResult
947 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
948 ConversionPatternRewriter &rewriter) const override {
949
950 if (!op.getOffsets())
951 return failure();
952
953 Location loc = op.getLoc();
954 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
955 if (!valueType)
956 return failure();
957
958 xegpu::DistributeLayoutAttr layout =
959 xegpu::getDistributeLayoutAttr(op.getOperand(0));
960 if (!layout || !layout.isForWorkgroup())
961 return failure();
962
963 // The offsets need to be distributed
964 auto offsetsVecType =
965 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
966 auto maskVecType =
967 dyn_cast<VectorType>(adaptor.getMask().front().getType());
968 if (!offsetsVecType || !maskVecType ||
969 offsetsVecType.getShape() != maskVecType.getShape()) {
970 return rewriter.notifyMatchFailure(op,
971 "offsets have not been distributed");
972 }
973
974 auto chunkSizeOpt = op.getChunkSize();
975 int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
976 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
977 for (auto [val, offs, mask] : llvm::zip(
978 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
979 auto store = xegpu::StoreScatterOp::create(
980 rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
981 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
982 layout.dropSgLayoutAndData());
983 // Update the layout attribute to drop sg_layout and sg_data.
984 for (OpOperand &operand : store->getOpOperands()) {
985 // Skip for operand one (memref)
986 if (operand.getOperandNumber() == 1)
987 continue;
988 xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
989 }
990 }
991 rewriter.eraseOp(op);
992 return success();
993 }
994};
995
996struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
997 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
998 LogicalResult
999 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
1000 ConversionPatternRewriter &rewriter) const override {
1001
1002 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1003 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1004 return failure();
1005
1006 ArrayRef<int64_t> wgShape = op.getDataShape();
1007 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
1008 assert(valueTy && "the value type must be vector type!");
1009 Type elemTy = valueTy.getElementType();
1010
1011 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1012 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1013 VectorType newResTy = VectorType::get(sgShape, elemTy);
1014 SmallVector<Value> newOps;
1015 for (auto offsets : offsetsList) {
1016 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
1017 op.getMemDesc(), offsets,
1018 layout.dropSgLayoutAndData());
1019 newOps.push_back(newOp);
1020 }
1021 rewriter.replaceOpWithMultiple(op, {newOps});
1022
1023 return success();
1024 }
1025};
1026
1027struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
1028 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
1029 LogicalResult
1030 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
1031 ConversionPatternRewriter &rewriter) const override {
1032
1033 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1034 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1035 return failure();
1036
1037 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1038 for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1039 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1040 offsets, layout.dropSgLayoutAndData());
1041 rewriter.eraseOp(op);
1042 return success();
1043 }
1044};
1045
1046// This pattern distributes the vector.step ops to work at subgroup level
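// Each subgroup materializes the base steps for its local shape and adds its
// distributed offset: e.g. a wg-level vector.step : vector<16xindex> with
// sg_layout = [4], sg_data = [4] becomes vector.step : vector<4xindex> plus a
// broadcast of the subgroup offset (the subgroup at offset 8 yields
// [8, 9, 10, 11]).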
1047struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
1048 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1049 LogicalResult
1050 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1051 ConversionPatternRewriter &rewriter) const override {
1052 xegpu::DistributeLayoutAttr layout =
1053 xegpu::getDistributeLayoutAttr(op.getResult());
1054 if (!layout || !layout.isForWorkgroup())
1055 return failure();
1056
1057 Location loc = op.getLoc();
1058 VectorType type = op.getResult().getType();
1059 auto wgShape = type.getShape();
1060 std::optional<SmallVector<int64_t>> sgShape =
1061 getSgShapeAndCount(wgShape, layout).first;
1062 if (!sgShape)
1063 return failure();
1064
1065 Value sgId =
1066 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1067 auto sgOffsets =
1068 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1069 if (failed(sgOffsets))
1070 return failure();
1071
1072 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
1073 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1074 SmallVector<Value> newOps;
1075 for (auto offsets : *sgOffsets) {
1076 // Broadcast the offset scalar to a vector & add to the base steps
1077 auto bcastOffset =
1078 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1079 auto finalSteps =
1080 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1081 xegpu::setDistributeLayoutAttr(steps->getResult(0),
1082 layout.dropSgLayoutAndData());
1083 xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
1084 layout.dropSgLayoutAndData());
1085 xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
1086 layout.dropSgLayoutAndData());
1087 newOps.push_back(finalSteps);
1088 }
1089
1090 rewriter.replaceOpWithMultiple(op, {newOps});
1091 return success();
1092 }
1093};
1094
1095// This pattern transforms vector.shape_cast ops to work at subgroup level.
1096struct WgToSgVectorShapeCastOp
1097 : public OpConversionPattern<vector::ShapeCastOp> {
1098 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1099
1100 LogicalResult
1101 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
1102 ConversionPatternRewriter &rewriter) const override {
1103
1104 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1105 if (!resultType)
1106 return failure();
1107
1108 ArrayRef<int64_t> wgShape = resultType.getShape();
1109 xegpu::DistributeLayoutAttr layout =
1110 xegpu::getDistributeLayoutAttr(op.getResult());
1111 if (!layout || !layout.isForWorkgroup())
1112 return failure();
1113
1114 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1115 VectorType newResultType =
1116 VectorType::get(sgShape, resultType.getElementType());
1117
1118 // TODO: Add check for compatible layouts in layout attr.
1119 auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
1120 if (!srcType)
1121 return failure();
1122
1123 // Check that shape_cast only adds or removes unit dimensions.
1124 auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
1125 // Remove all 1s from both shapes and compare the rest.
1126 SmallVector<int64_t> srcNonUnit, dstNonUnit;
1127 for (int64_t d : src)
1128 if (d != 1)
1129 srcNonUnit.push_back(d);
1130 for (int64_t d : dst)
1131 if (d != 1)
1132 dstNonUnit.push_back(d);
1133 return srcNonUnit == dstNonUnit;
1134 };
1135
1136 if (!onlyUnitDims(srcType.getShape(), sgShape))
1137 return failure();
1138
1139 // For rank reducing or increasing shape_cast ops, the lower rank layout
1140 // must be a slice of higher rank layout.
1141 int64_t sourceRank = srcType.getRank();
1142 int64_t resultRank = sgShape.size();
1143 xegpu::DistributeLayoutAttr sourceLayout =
1144 xegpu::getDistributeLayoutAttr(op.getSource());
1145 if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
1146 return failure();
1147 if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
1148 return failure();
1149
1150 SmallVector<Value> newShapeCastOps;
1151 for (auto src : adaptor.getSource()) {
1152 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1153 newResultType, src);
1154 xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
1155 layout.dropSgLayoutAndData());
1156 newShapeCastOps.push_back(newShapeCast.getResult());
1157 }
1158
1159 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1160 return success();
1161 }
1162};
1163
1164/// Pattern for lowering vector.multi_reduction op to subgroup level.
1165/// Current limitation: the sg_layout in each reduced dimension must be 1,
1166/// so that the reduction is local to the subgroup and no cross-subgroup
1167/// communication is needed.
1168/// TODO: Add cases to handle more general situations which require SLM access.
1169struct WgToSgMultiDimReductionOp
1170 : public OpConversionPattern<vector::MultiDimReductionOp> {
1171 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1172
1173 LogicalResult
1174 matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1175 ConversionPatternRewriter &rewriter) const override {
1176 VectorType srcType = op.getSourceVectorType();
1177 VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
1178 if (!dstType)
1179 return failure();
1180
1181 auto srcShape = srcType.getShape();
1182 xegpu::DistributeLayoutAttr layout =
1183 xegpu::getDistributeLayoutAttr(op.getResult());
1184 if (!layout || !layout.isForWorkgroup())
1185 return failure();
1186
1187 auto reductionDims = llvm::to_vector(op.getReductionDims());
1188
1189 SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
1190 .getParent()
1191 .getEffectiveSgLayoutAsInt();
1192 SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
1193 .getParent()
1194 .getEffectiveSgDataAsInt();
1195
1196 // Check that the sgLayout in the reduced dimension is 1 and
1197 // each sg gets the entire slice to reduce.
1198 for (int64_t dim : reductionDims) {
1199 if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
1200 return rewriter.notifyMatchFailure(
1201 op,
1202 "sgLayout in each reduced dimension must be 1 and sgData in the "
1203 "reduced dim must match srcShape in that dim");
1204 }
1205
1206 SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
1207
1208 VectorType newDstType =
1209 VectorType::get({sgShape}, dstType.getElementType());
1210
1211 SmallVector<Value> newReductions;
1212 for (auto sgSrc : adaptor.getSource()) {
1213 auto newOp = vector::MultiDimReductionOp::create(
1214 rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
1215 adaptor.getAcc()[0], op.getReductionDims());
1216 xegpu::setDistributeLayoutAttr(newOp->getResult(0),
1217 layout.dropSgLayoutAndData());
1218 newReductions.push_back(newOp.getResult());
1219 }
1220
1221 rewriter.replaceOpWithMultiple(op, {newReductions});
1222 return success();
1223 }
1224};
1225
1226// This pattern transforms vector.transpose ops to work at subgroup level.
1227struct WgToSgVectorTransposeOp
1228 : public OpConversionPattern<vector::TransposeOp> {
1229 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1230
1231 LogicalResult
1232 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1233 ConversionPatternRewriter &rewriter) const override {
1234 VectorType resultType = op.getResultVectorType();
1235
1236 ArrayRef<int64_t> wgShape = resultType.getShape();
1237 xegpu::DistributeLayoutAttr layout =
1238 xegpu::getDistributeLayoutAttr(op.getResult());
1239 if (!layout || !layout.isForWorkgroup())
1240 return failure();
1241
1242 xegpu::DistributeLayoutAttr sourceLayout =
1243 xegpu::getDistributeLayoutAttr(op.getVector());
1244 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1245 return failure();
1246
1247 SmallVector<int64_t> sourceSgLayout =
1248 sourceLayout.getEffectiveSgLayoutAsInt();
1249 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1250 DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
1251 DenseI32ArrayAttr resultOrder = layout.getOrder();
1252
1253 if (!sourceOrder || !resultOrder) {
1254 return rewriter.notifyMatchFailure(
1255 op, "Both source and result must have order attributes");
1256 }
1257
1258 ArrayRef<int64_t> permutation = op.getPermutation();
1259 size_t permutationSize = permutation.size();
1260 if (sourceSgLayout.size() != permutationSize ||
1261 resultSgLayout.size() != permutationSize) {
1262 return rewriter.notifyMatchFailure(
1263 op, "Layouts and permutation must have the same rank");
1264 }
1265
1266 // Check that sgLayout, sgData & order are properly transposed for source
1267 // and result
1268 if (!layout.isTransposeOf(sourceLayout, permutation))
1269 return rewriter.notifyMatchFailure(
1270 op, "Result layout is not a valid transpose of source layout "
1271 "according to permutation");
1272
1273 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1274 VectorType newResultType =
1275 VectorType::get(sgShape, resultType.getElementType());
1276 SmallVector<Value> newTransposeOps;
1277 for (auto src : adaptor.getVector()) {
1278 auto newTranspose = vector::TransposeOp::create(
1279 rewriter, op.getLoc(), newResultType, src, permutation);
1280 xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
1281 layout.dropSgLayoutAndData());
1282 newTransposeOps.push_back(newTranspose.getResult());
1283 }
1284
1285 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1286 return success();
1287 }
1288};
1289
1290// Distribute vector mask ops to work at subgroup level.
1291template <typename MaskOpType>
1292struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
1293 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1294
1295 LogicalResult matchAndRewrite(
1296 MaskOpType op,
1297 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1298 ConversionPatternRewriter &rewriter) const override {
1299 xegpu::DistributeLayoutAttr layout =
1300 xegpu::getDistributeLayoutAttr(op.getResult());
1301 if (!layout || !layout.isForWorkgroup())
1302 return failure();
1303
1304 Location loc = op.getLoc();
1305 VectorType type = op.getResult().getType();
1306 auto wgShape = type.getShape();
1307
1308 SmallVector<Value> wgMaskDimSizes;
1309 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1310 for (int64_t maskSize : op.getMaskDimSizes()) {
1311 wgMaskDimSizes.push_back(
1312 arith::ConstantIndexOp::create(rewriter, loc, maskSize));
1313 }
1314 } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1315 wgMaskDimSizes = llvm::to_vector(op.getOperands());
1316 }
1317
1318 Value sgId =
1319 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1320 auto sgOffsets =
1321 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1322 if (failed(sgOffsets))
1323 return failure();
1324
1325 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1326 VectorType resultType = VectorType::get(sgShape, type.getElementType());
1327
1328 // In each dimension, each subgroup computes its local mask size as:
1329 // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
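    // For example, a wg-level mask size of 20 along a dimension distributed
    // as four tiles of 8 (offsets 0, 8, 16, 24) yields local mask sizes of
    // 8, 8, 4 and 0, respectively.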
1330 SmallVector<Value> newCreateMaskOps;
1331 for (auto offsetSet : *sgOffsets) {
1332 SmallVector<Value> maskOperands;
1333
1334 for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
1335 Value dimSizeVal =
1336 arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
1337 Value offset = offsetSet[i];
1338 Value adjustedMaskSize =
1339 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1340 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1341 Value nonNegative =
1342 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1343 Value sgMaskSize =
1344 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1345 maskOperands.push_back(sgMaskSize);
1346 }
1347
1348 auto newCreateMaskOp =
1349 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1350 xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
1351 layout.dropSgLayoutAndData());
1352 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1353 }
1354
1355 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1356 return success();
1357 }
1358};
1359
1360using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1361using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1362} // namespace
1363
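// A minimal usage sketch (assuming a TypeConverter and ConversionTarget
// configured as in XeGPUWgToSgDistributePass::runOnOperation below); clients
// that want to reuse the distribution patterns outside of this pass can do:
//   RewritePatternSet patterns(ctx);
//   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     ...;  // handle the failure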
1364namespace mlir {
1365namespace xegpu {
1366void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
1367 patterns
1368 .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
1369 WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
1370 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
1371 WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
1372 WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1373 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
1374 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
1375 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1376 WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1377 WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1378 patterns.getContext());
1379}
1380} // namespace xegpu
1381} // namespace mlir
1382
1383namespace {
1384struct XeGPUWgToSgDistributePass
1385 : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
1386 void runOnOperation() override;
1387};
1388} // namespace
1389
1390void XeGPUWgToSgDistributePass::runOnOperation() {
1391 // Track existing UnrealizedConversionCastOps
1392 SmallVector<Operation *> existingCastOps;
1393 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1394 existingCastOps.push_back(castOp.getOperation());
1395 });
1396
1397 {
1398 // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
1399 // VectorType operands. This first converts such operands to
1400 // RankedTensorType, propagates the layout attribute into the encoding
1401 // attribute, and finally converts the RankedTensorType to VectorType based
1402 // on the encoding.
1403
1404 TypeConverter converter;
1405 converter.addConversion([&](Type type) -> Type { return type; });
1406 converter.addConversion(
1407 [&](RankedTensorType type,
1408 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1409 Type elemTy = type.getElementType();
1410 ArrayRef<int64_t> shape = type.getShape();
1411
1412 int count;
1413 SmallVector<int64_t> subShape;
1414 std::tie(subShape, count) = getSgShapeAndCount(
1415 shape,
1416 dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
1417
1418 auto newTy = VectorType::get(subShape, elemTy);
1419 result.append(count, newTy);
1420 return success();
1421 });
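    // For example, tensor<64x64xf16, #xegpu.layout<sg_layout = [2, 2],
    // sg_data = [16, 16]>> is converted to 4 values of vector<16x16xf16>
    // (distribution unit [32, 32], count = (64 * 64) / (32 * 32) = 4).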
1422
1423 xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
1424 converter);
1425 }
1426
1427 // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
1428 // as well as XeGPU, Arith, and Vector operations.
1429 MLIRContext *ctx = &getContext();
1430 RewritePatternSet patterns(ctx);
1431 ConversionTarget target(*ctx);
1432 TypeConverter converter;
1433 converter.addConversion([&](Type type) -> Type { return type; });
1434 converter.addConversion(
1435 [&](xegpu::TensorDescType type,
1436 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1437 Type elemTy = type.getElementType();
1438 ArrayRef<int64_t> shape = type.getShape();
1439
1440 int count;
1441 SmallVector<int64_t> subShape;
1442 xegpu::LayoutAttr layout = type.getLayoutAttr();
1443 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1444
1445 if (layout)
1446 layout = layout.dropSgLayoutAndData();
1447
1448 auto newTy = xegpu::TensorDescType::get(
1449 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1450 result.append(count, newTy);
1451 return success();
1452 });
1453
1454 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1455 if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1456 return createOp.getType();
1457 if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1458 return loadOp.getTensorDescType();
1459 if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1460 return storeOp.getTensorDescType();
1461 if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1462 return updateOp.getType();
1463 if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1464 return prefetchOp.getTensorDescType();
1465 return xegpu::TensorDescType();
1466 };
1467
1468 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
1469 return !layout || !layout.isForWorkgroup();
1470 };
1471
1472 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1473 xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1474 xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
1475 auto tdescTy = getTensorDescType(op);
1476 auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1477 return isLegal(layout);
1478 });
1479
1480 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
1481 auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1482 return isLegal(layout);
1483 });
1484
1485 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1486 [=](xegpu::LoadMatrixOp op) -> bool {
1487 return isLegal(op.getLayoutAttr());
1488 });
1489
1490 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1491 [=](xegpu::StoreMatrixOp op) -> bool {
1492 return isLegal(op.getLayoutAttr());
1493 });
1494
1495 target.addDynamicallyLegalOp<arith::ConstantOp>(
1496 [=](arith::ConstantOp op) -> bool {
1497 auto vecType = dyn_cast<VectorType>(op.getType());
1498 if (!vecType)
1499 return true;
1500
1501 auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1502 return isLegal(layout);
1503 });
1504
1505 target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1506 vector::TransposeOp, vector::BroadcastOp,
1507 vector::MultiDimReductionOp,
1508 vector::ConstantMaskOp, vector::CreateMaskOp>(
1509 [=](Operation *op) -> bool {
1510 // Check for either a SliceAttr or LayoutAttr on the result.
1511 auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
1512 return isLegal(layout);
1513 });
1514
1515 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1516 [=](xegpu::LoadGatherOp op) -> bool {
1517 auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1518 return isLegal(layout);
1519 });
1520
1521 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1522 [=](xegpu::StoreScatterOp op) -> bool {
1523 auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0));
1524 return isLegal(layout);
1525 });
1526
1527 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1528 [=](xegpu::ConvertLayoutOp op) -> bool {
1529 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1530 });
1531
1532 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1533 [=](Operation *op) -> std::optional<bool> {
1534 // Only handle elementwise mappable ops
1535 if (!OpTrait::hasElementwiseMappableTraits(op))
1536 return true;
1537
1538 VectorType resultType =
1539 dyn_cast<VectorType>(op->getResult(0).getType());
1540 if (!resultType)
1541 return true;
1542
1543 // Check if all operands are vectors of the same shape
1544 // TODO: Support other types.
1545 for (Value operand : op->getOperands()) {
1546 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1547 if (!operandType || operandType.getShape() != resultType.getShape()) {
1548 return true;
1549 }
1550 }
1551
1552 xegpu::DistributeLayoutAttr layout =
1553 xegpu::getDistributeLayoutAttr(op->getResult(0));
1554 return isLegal(layout);
1555 });
1556
1557 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1558 [=](UnrealizedConversionCastOp op) {
1559 return llvm::is_contained(existingCastOps, op.getOperation());
1560 });
1561
1562 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1563
1563 xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
1564 scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
1565 target);
1566
1567 if (failed(
1568 applyPartialConversion(getOperation(), target, std::move(patterns))))
1569 return signalPassFailure();
1570
1571 // Remove sg_layout and sg_data attributes from the Layout
1572 // attribute for each VectorType result of the operation.
1573 // For Structured Control Flow ops, the layout is simply removed,
1574 // since in the 1:N case, the layouts for the new results are missing.
1575 // A layout propagation pass will be activated afterwards.
1576 getOperation()->walk([](Operation *op) {
1577 for (OpResult result : op->getOpResults()) {
1578 std::string name = xegpu::getLayoutName(result);
1579 if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
1580 op->removeAttr(name);
1581 if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
1582 if (auto newLayout = layout.dropSgLayoutAndData())
1583 op->setAttr(name, newLayout);
1584 }
1585 }
1586 }
1587 });
1588}
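The RankedTensorType and TensorDescType converters above split one workgroup-level value into `count` subgroup-level copies. A minimal sketch of that arithmetic, using the IndexingUtils helpers documented below; the shapes are made up for illustration and the header path is an assumption:

// Hypothetical shapes: a 256x128 workgroup vector, an 8x4 subgroup grid
// (sg_layout) and a 32x32 tile per subgroup (sg_data).
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>

void sgShapeAndCountSketch() {
  llvm::SmallVector<int64_t> shape = {256, 128};  // workgroup-level shape
  llvm::SmallVector<int64_t> sgLayout = {8, 4};   // subgroup grid
  llvm::SmallVector<int64_t> sgData = {32, 32};   // tile owned by one subgroup

  // One "round" of all subgroups covers sgLayout * sgData elements.
  llvm::SmallVector<int64_t> distUnit =
      mlir::computeElementwiseMul(sgLayout, sgData);  // {256, 128}

  // Number of subgroup-level values each workgroup-level value is split into.
  int64_t count = mlir::computeProduct(shape) / mlir::computeProduct(distUnit);
  assert(count == 1 && "each subgroup owns exactly one 32x32 tile here");
}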
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
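A minimal sketch of building a dense constant with this overload; the 2x2 i32 vector type and helper name are illustrative only:

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

mlir::DenseElementsAttr makeDenseConstant(mlir::MLIRContext *ctx) {
  mlir::Builder b(ctx);
  auto vecTy = mlir::VectorType::get({2, 2}, b.getI32Type());
  // One attribute value per element of the 2x2 vector.
  llvm::SmallVector<mlir::Attribute> values(4, b.getI32IntegerAttr(7));
  return mlir::DenseElementsAttr::get(vecTy, values);
}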
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
result_range getOpResults()
Definition Operation.h:420
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition Operation.h:600
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
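A minimal sketch of the trait check used by the ArithDialect/MathDialect legality callback above; the header path and wrapper name are assumptions:

#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"

static bool treatAsElementwise(mlir::Operation *op) {
  // Ops without the elementwise-mappable traits are left untouched by the
  // distribution patterns.
  return mlir::OpTrait::hasElementwiseMappableTraits(op);
}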
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
std::string getLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
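A minimal sketch of looking up a result's layout by its canonical attribute name, as the cleanup walk at the end of runOnOperation() does; the include paths and helper name are assumptions:

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Operation.h"

static mlir::xegpu::LayoutAttr lookupResultLayout(mlir::OpResult result) {
  // The layout is stored as a discardable attribute on the defining op.
  std::string name = mlir::xegpu::getLayoutName(result);
  return result.getOwner()->getAttrOfType<mlir::xegpu::LayoutAttr>(name);
}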
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for, using SCF structural type conversion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
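A minimal sketch mirroring the legality callbacks above: a result still needs distribution only while it carries a workgroup-level layout. The include path and function name are assumptions:

#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Operation.h"

static bool resultNeedsDistribution(mlir::Operation *op) {
  // Assumes `op` has at least one result.
  auto layout = mlir::xegpu::getDistributeLayoutAttr(op->getResult(0));
  return layout && layout.isForWorkgroup();
}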
void setDistributeLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout, bool respectPermLayout=false)
Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictio...
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU workgroup to subgroup distribution into patterns.
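A minimal sketch of reusing the populate function outside this pass; the header path and wrapper name are assumptions:

#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

void collectWgToSgPatterns(mlir::RewritePatternSet &patterns) {
  mlir::xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  // A ConversionTarget with the legality rules shown in runOnOperation()
  // is still required before calling applyPartialConversion.
}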
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
Include the generated interface declarations.
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
int64_t computeProduct(ArrayRef< int64_t > basis)
Return the product of the elements of basis.
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
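A minimal sketch combining this helper with the ConstantIndexOp builder documented above; the include paths and function name are assumptions:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"

mlir::OpFoldResult exampleOffset(mlir::OpBuilder &b, mlir::Location loc) {
  mlir::Value c4 = mlir::arith::ConstantIndexOp::create(b, loc, 4).getResult();
  // Produced by a constant op, so the Attribute form is returned; otherwise
  // the Value itself would be wrapped.
  return mlir::getAsOpFoldResult(c4);
}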
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
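A minimal sketch of deriving per-subgroup data from a workgroup shape and a subgroup grid, as the converters above do when sg_data is not given; the header path and shapes are assumptions:

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>

void shapeRatioSketch() {
  // A 128x64 workgroup shape split across a 4x2 subgroup grid -> 32x32 tiles.
  std::optional<llvm::SmallVector<int64_t>> ratio =
      mlir::computeShapeRatio({128, 64}, {4, 2});
  // ratio == {32, 32}; std::nullopt if the shapes do not divide evenly.
  (void)ratio;
}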
This represents an operation in an abstracted form, suitable for use with the builder APIs.