MLIR 22.0.0git
XeGPUWgToSgDistribute.cpp
1//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
9
24#include <optional>
25
26namespace mlir {
27namespace xegpu {
28#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
29#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
30} // namespace xegpu
31} // namespace mlir
32
33using namespace mlir;
34
35namespace {
36
37// Retrieve the RangeAttr if it is specified.
38static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
39 Operation *parent = op->getParentOfType<scf::IfOp>();
40 while (parent) {
41 if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
42 parent->getAttr("sg_id_range")))
43 return attr;
44 parent = parent->getParentOfType<scf::IfOp>();
45 }
46 return {};
47}
48
49static std::pair<SmallVector<int64_t>, int>
50getSgShapeAndCount(ArrayRef<int64_t> shape,
51 xegpu::DistributeLayoutAttr layout) {
52 int count = 1;
53 SmallVector<int64_t> sgShape = llvm::to_vector(shape);
54 if (layout && layout.isForWorkgroup()) {
55 SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
56 if (!layout.getEffectiveSgDataAsInt().empty())
57 sgShape = layout.getEffectiveSgDataAsInt();
58 else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
59 sgShape = *maybeDerivedSgData;
60 SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
61 // Clamp distUnit to the original shape to handle cases where data is
62 // shared among subgroups, which may cause distUnit to exceed the original
63 // shape.
64 for (size_t i = 0; i < distUnit.size(); ++i)
65 distUnit[i] = std::min(shape[i], distUnit[i]);
66 count = computeProduct(shape) / computeProduct(distUnit);
67 }
68 return std::make_pair(sgShape, count);
69}
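// Worked example for getSgShapeAndCount above (illustrative numbers, not
// from a specific test): for shape = [24, 24] with sg_layout = [4, 4] and
// sg_data = [2, 2], sgShape = [2, 2] and distUnit = [8, 8], so
// count = (24 * 24) / (8 * 8) = 9, i.e. each subgroup accesses 9 sub-tiles
// in round-robin fashion. If sg_data is absent, sgShape is derived as
// shape / sg_layout instead.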
70
71/// Utility helper for deriving the list of offsets for each sub-TensorDesc
72/// or sub-MemDesc to be accessed by the current subgroup (sgId), based on
73/// the associated distribute layout attribute, the shape, the subgroup id,
74/// and the original offsets of the op.
75template <
76 typename OpType,
77 typename = std::enable_if_t<llvm::is_one_of<
78 OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
79 xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
80static LogicalResult
81genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
82 SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
83 Location loc = op.getLoc();
84 SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
85 // not applicable to ops without offset operands.
86 if (origOffsets.empty())
87 return failure();
88
89 // not applicable to ops without workgroup layout attributes
90 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
91 if (!layout || !layout.isForWorkgroup())
92 return failure();
93
94 Value sgId =
95 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
96
97 // verify and adjust the sgId if the range specifier is present
98 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
99 if (sgIdRange) {
100 int64_t startOfRange = sgIdRange.getStart().getInt();
101 int64_t endOfRange = sgIdRange.getEnd().getInt();
102 // verify the RangeAttr against the layout attribute
103 if (layout.getNumSubgroups() != endOfRange - startOfRange)
104 return rewriter.notifyMatchFailure(
105 op, "sg_layout size must match the sg_id_range");
106 // adjust the sgId if necessary
107 if (startOfRange > 0) {
108 Value startOfRangeVal =
109 arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
110 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
111 }
112 }
113
114 // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
115 // descriptors to be accessed, based on the layout information.
116 ArrayRef<int64_t> wgShape = op.getDataShape();
117 auto maybeDescOffsets =
118 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
119 if (failed(maybeDescOffsets))
120 return failure();
121
122 // Compute the final global offsets for each accessed sub-tensor
123 // or sub-memory descriptor.
124 for (const auto &sgOffsets : *maybeDescOffsets) {
125 SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
126 rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
127 offsetsList.push_back(std::move(newOffsets));
128 }
129
130 // callback(offsetsList);
131 return success();
132}
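// Worked example for genOffsetsList above (illustrative, assuming the
// default row-major subgroup order): with wgShape = [24, 24], sg_layout =
// [4, 4], sg_data = [2, 2] and original offsets [0, 0], the subgroup with
// sgId = 5 (coordinate [1, 1]) gets the distributed coordinates
// {[2, 2], [2, 10], [2, 18], [10, 2], ..., [18, 18]} from
// computeDistributedCoords; offsetsList then holds these coordinates added
// to the original offsets.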
133
134/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
135/// from a workgroup descriptor. It replaces the offsets and sizes with
136/// appropriate values for the subgroup.
137/// It uses round-robin assignment to distribute the work to the subgroups.
138/// The following create_nd_desc operation:
139/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
140/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
141/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
142/// is converted to 9 subgroup level operations based on the sg_layout &
143/// sg_data:
144/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
145/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
146/// lane_data = [1, 1]>>
147///
148/// The sg_layout and sg_data attributes are dropped after the pass as they are
149/// no longer needed.
150///
151/// 24x24 matrix distribution example:
152/// sg_layout = [4, 4], sg_data = [2, 2]
153/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
154/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
155///
156/// +------------------------+
157/// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
158/// |-----+-----+-----|
159/// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
160/// |-----+-----+-----|
161/// | 8x8 | 8x8 | 8x8 |
162/// +------------------------+
163///
164/// Each 8x8 tile is further subdivided among subgroups:
165/// +------------------------+
166/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
167/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
168/// | 2x2 2x2 2x2 2x2 |
169/// | 2x2 2x2 2x2 2x2 |
170/// +------------------------+
171///
172/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
173/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
174
175/// The pass currently keeps the entire distribution logic in the
176/// WgToSgCreateNdOp pattern, and all the other op patterns simply follow it.
177/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
178/// ops in the pass.
179struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
180 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
181
182 LogicalResult
183 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
184 ConversionPatternRewriter &rewriter) const override {
185 SmallVector<SmallVector<OpFoldResult>> offsetsList;
186 if (failed(genOffsetsList(rewriter, op, offsetsList)))
187 return failure();
188
189 MLIRContext *ctx = op.getContext();
190 xegpu::TensorDescType tdescTy = op.getType();
191 ArrayRef<int64_t> wgShape = tdescTy.getShape();
192 Type elemTy = tdescTy.getElementType();
193 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
194 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
195 auto newTdescTy =
196 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
197 layout.dropSgLayoutAndData());
198
199 SmallVector<Value> newOps;
200 for (auto offsets : offsetsList) {
201 auto newOp = xegpu::CreateNdDescOp::create(
202 rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
203 op.getMixedSizes(), op.getMixedStrides());
204
205 newOps.push_back(newOp);
206 }
207 rewriter.replaceOpWithMultiple(op, {newOps});
208
209 return success();
210 }
211};
212
213// This pattern transforms the CreateNdDescOp without offsets to create a
214// subgroup descriptor from a workgroup descriptor
215struct WgToSgCreateNdOpNoOffset
216 : public OpConversionPattern<xegpu::CreateNdDescOp> {
217 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
218
219 LogicalResult
220 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
221 ConversionPatternRewriter &rewriter) const override {
222
223 // Check no offsets are specified.
224 if (!op.getMixedOffsets().empty())
225 return failure();
226
227 Location loc = op.getLoc();
228 MLIRContext *ctx = op.getContext();
229 xegpu::TensorDescType tdescTy = op.getType();
230 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
231 if (!layout || !layout.isForWorkgroup())
232 return failure();
233
234 Type elemTy = tdescTy.getElementType();
235 ArrayRef<int64_t> wgShape = tdescTy.getShape();
236
237 SmallVector<int64_t> sgShape;
238 int count;
239 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
240 xegpu::TensorDescType newTdescTy =
241 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
242 layout.dropSgLayoutAndData());
243
244 SmallVector<Value> newCreateNdOps(count);
245 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
246 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
247 op.getSource(), op.getMixedSizes(),
248 op.getMixedStrides());
249 });
250
251 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
252 return success();
253 }
254};
255
256/// This pattern transforms the LoadNdOp to load subgroup data.
257struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
258 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
259 LogicalResult
260 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
261 ConversionPatternRewriter &rewriter) const override {
262 if (!op.getMixedOffsets().empty())
263 return failure();
264
265 SmallVector<Value> newLoadOps;
266 for (auto src : adaptor.getTensorDesc()) {
267 xegpu::TensorDescType tdescTy =
268 dyn_cast<xegpu::TensorDescType>(src.getType());
269 ArrayRef<int64_t> srcShape = tdescTy.getShape();
270 VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
271 auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy,
272 src, op->getAttrs());
273 newLoadOps.push_back(newLoadOp);
274 }
275 rewriter.replaceOpWithMultiple(op, {newLoadOps});
276 return mlir::success();
277 }
278};
279
280/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
281/// It creates a StoreNdOp to store the updated values to the new subgroup
282/// src tensor descriptors.
283struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
284 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
285 LogicalResult
286 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
287 ConversionPatternRewriter &rewriter) const override {
288 if (!op.getMixedOffsets().empty())
289 return failure();
290
291 for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
292 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
293 op.getL2HintAttr(), op.getL3HintAttr());
294
295 rewriter.eraseOp(op);
296 return success();
297 }
298};
299
300// This pattern transforms the LoadNdOp with explicit offsets to load
301// subgroup data.
302struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
303 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
304 LogicalResult
305 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
306 ConversionPatternRewriter &rewriter) const override {
307
308 SmallVector<SmallVector<OpFoldResult>> offsetsList;
309 if (failed(genOffsetsList(rewriter, op, offsetsList)))
310 return failure();
311
312 SmallVector<Value> newOps;
313 for (auto [tdesc, offsets] :
314 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
315 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
316 VectorType newResTy =
317 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
318 auto newOp = xegpu::LoadNdOp::create(
319 rewriter, op.getLoc(), newResTy, tdesc, offsets,
320 /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
321 op.getL2HintAttr(), op.getL3HintAttr());
322 newOps.push_back(newOp);
323 }
324 rewriter.replaceOpWithMultiple(op, {newOps});
325
326 return success();
327 }
328};
329
330// This pattern transforms the StoreNdOp with explicit offsets to store
331// subgroup data.
332struct WgToSgStoreNdOpWithOffset
333 : public OpConversionPattern<xegpu::StoreNdOp> {
334 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
335 LogicalResult
336 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
337 ConversionPatternRewriter &rewriter) const override {
338 SmallVector<SmallVector<OpFoldResult>> offsetsList;
339 if (failed(genOffsetsList(rewriter, op, offsetsList)))
340 return failure();
341
342 for (auto [v, tdesc, offsets] :
343 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
344 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
345 op.getL1HintAttr(), op.getL2HintAttr(),
346 op.getL3HintAttr());
347 }
348 rewriter.eraseOp(op);
349
350 return success();
351 }
352};
353
354// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
355// subgroup data.
356struct WgToSgPrefetchNdOpWithOffset
357 : public OpConversionPattern<xegpu::PrefetchNdOp> {
358 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
359 LogicalResult
360 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
361 ConversionPatternRewriter &rewriter) const override {
362 SmallVector<SmallVector<OpFoldResult>> offsetsList;
363 if (failed(genOffsetsList(rewriter, op, offsetsList)))
364 return failure();
365
366 for (auto [tdesc, offsets] :
367 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
368 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
369 op.getL1HintAttr(), op.getL2HintAttr(),
370 op.getL3HintAttr());
371 }
372 rewriter.eraseOp(op);
373
374 return success();
375 }
376};
377
378/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
379/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
380/// offsets of the new subgroup src tensor descriptors.
381struct WgToSgUpdateNdOffsetOp
382 : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
383 using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
384 LogicalResult
385 matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
386 ConversionPatternRewriter &rewriter) const override {
387 llvm::SmallVector<Value> newUpdateTileOffsetOps;
388 for (auto tDesc : adaptor.getTensorDesc()) {
389 auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
390 rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
391 op.getConstOffsets());
392 newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
393 }
394
395 rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
396 return success();
397 }
398};
399
400/// This pattern transforms the DpasOp to work at subgroup level.
401struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
402 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
403 LogicalResult
404 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
405 ConversionPatternRewriter &rewriter) const override {
406 Location loc = op.getLoc();
407 VectorType resultTy = op.getResult().getType();
408 if (resultTy.getRank() != 2)
409 return failure();
410
411 auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult());
412 if (!originalLayout)
413 return failure();
414
415 size_t i = 0;
416 SmallVector<Value> newDpasOps;
417 for (auto aVec : adaptor.getLhs()) {
418 for (auto bVec : adaptor.getRhs()) {
419
420 llvm::SmallVector<Value> operands({aVec, bVec});
421 Value tmpC;
422 if (op.getAcc()) {
423 tmpC = adaptor.getAcc()[i++];
424 operands.push_back(tmpC);
425 }
426
427 ArrayRef<int64_t> aVecShape =
428 llvm::cast<VectorType>(aVec.getType()).getShape();
429 ArrayRef<int64_t> bVecShape =
430 llvm::cast<VectorType>(bVec.getType()).getShape();
431 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
432 resultTy.getElementType());
433 tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
434 xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
435 originalLayout.dropSgLayoutAndData());
436
437 newDpasOps.push_back(tmpC);
438 }
439 }
440 rewriter.replaceOpWithMultiple(op, {newDpasOps});
441 return success();
442 }
443};
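// Illustrative note for WgToSgDpasOp above (hypothetical shapes): if the
// adaptor provides a single LHS slice of vector<16x32xf16> and a single RHS
// slice of vector<32x32xf16>, one subgroup-level DPAS producing
// vector<16x32xf32> is created (shape [aShape[0], bShape[1]]). With N LHS
// and M RHS slices, the nested loops emit N * M DPAS ops, consuming the
// distributed accumulator slices in order when an accumulator is present.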
444
445/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
446struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
447 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
448 LogicalResult
449 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
450 ConversionPatternRewriter &rewriter) const override {
451
452 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
453 if ((offsetSize != 0) || op.getConstOffsetsAttr())
454 return failure();
455
456 for (auto src : adaptor.getTensorDesc())
457 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
458 op->getAttrs());
459 rewriter.eraseOp(op);
460 return success();
461 }
462};
463
464/// This pattern transforms vector.broadcast ops to work at subgroup level.
465struct WgToSgVectorBroadcastOp
466 : public OpConversionPattern<vector::BroadcastOp> {
467 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
468
469 LogicalResult
470 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
471 ConversionPatternRewriter &rewriter) const override {
472
473 VectorType resultType = op.getResult().getType();
474 ArrayRef<int64_t> wgShape = resultType.getShape();
475
476 xegpu::DistributeLayoutAttr layout =
477 xegpu::getDistributeLayoutAttr(op.getResult());
478 if (!layout || !layout.isForWorkgroup())
479 return failure();
480
481 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
482 VectorType newResultType =
483 VectorType::get(sgShape, resultType.getElementType());
484
485 if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
486 return failure();
487
488 SmallVector<Value> newBroadcastOps;
489 for (auto operand : adaptor.getOperands().front()) {
490 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
491 newResultType, operand);
492 xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
493 layout.dropSgLayoutAndData());
494
495 newBroadcastOps.push_back(newBroadcast.getResult());
496 }
497 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
498 return success();
499 }
500};
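// Illustrative example for WgToSgVectorBroadcastOp above (hypothetical
// shapes): a vector.broadcast from vector<1x128xf32> to vector<256x128xf32>
// whose result carries sg_layout = [8, 4], sg_data = [32, 32] is rewritten
// into broadcasts to vector<32x32xf32> from the distributed operands (here
// assumed to be vector<1x32xf32>), with the sg fields dropped from the
// result layout.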
501
502// This pattern transforms elementwise ops to work at subgroup level.
503struct WgToSgElementwiseOp : public ConversionPattern {
504 WgToSgElementwiseOp(MLIRContext *ctx)
505 : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
506
507 LogicalResult
508 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
509 ConversionPatternRewriter &rewriter) const override {
510 // Only match ops with elementwise trait and single result.
511 if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
512 return failure();
513
514 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
515 assert(resultType && "Expected result to be a VectorType");
516
517 ArrayRef<int64_t> wgShape = resultType.getShape();
518
519 xegpu::DistributeLayoutAttr layout =
520 xegpu::getDistributeLayoutAttr(op->getResult(0));
521 if (!layout || !layout.isForWorkgroup())
522 return failure();
524 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
526 size_t numVariants = operands.empty() ? 0 : operands.front().size();
528 if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
529 return operandVec.size() != numVariants;
530 }))
531 return failure();
532
533 SmallVector<Value> newResults;
534 VectorType newResultType =
535 VectorType::get(sgShape, resultType.getElementType());
537 for (size_t i = 0; i < numVariants; ++i) {
538 SmallVector<Value> opOperands;
539 for (auto &operandVec : operands)
540 opOperands.push_back(operandVec[i]);
541
542 OperationState state(op->getLoc(), op->getName());
543 state.addOperands(opOperands);
544 state.addTypes(newResultType);
545 // Copy all attributes, but update "layout_result_0" to drop
546 // sgLayout/sgData
547 for (auto attr : op->getAttrs()) {
548 if (auto layout =
549 dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
550 if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
551 !layout.getEffectiveInstDataAsInt().empty())
552 state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
553 } else {
554 state.addAttribute(attr.getName(), attr.getValue());
555 }
556 }
557 Operation *newOp = rewriter.create(state);
558 newResults.push_back(newOp->getResult(0));
559 }
560
561 rewriter.replaceOpWithMultiple(op, {newResults});
562 return success();
563 }
564};
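// Illustrative example for WgToSgElementwiseOp above (hypothetical types):
// an arith.addf on vector<256x128xf32> whose result carries
// #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
// is rewritten into arith.addf ops on vector<32x32xf32>, one per distributed
// operand tuple, each annotated with #xegpu.layout<inst_data = [8, 16]>.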
565
566// clang-format off
567// Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
568// If input_layout and target_layout have identical sg_layout and sg_data,
569// the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
570// dropped. For example:
571// #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
572// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
573// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
574// becomes:
575// #a = #xegpu.layout<inst_data = [16, 16]>
576// #b = #xegpu.layout<inst_data = [8, 16]>
577// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
578// (vector<16x16xf32> is determined by sg_data = [16, 16])
579//
580// If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
581// For example:
582// #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
583// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
584// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
585// is lowered to:
586// #a = #xegpu.layout<inst_data = [16, 16]>
587// #b = #xegpu.layout<inst_data = [8, 16]>
588// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
589// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
590// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
591// clang-format on
592struct WgToSgConvertLayoutOp
593 : public OpConversionPattern<xegpu::ConvertLayoutOp> {
594 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
595 LogicalResult
596 matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
597 ConversionPatternRewriter &rewriter) const override {
598 // TODO: currently, we only support LayoutAttr
599 auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
600 auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
601
602 if (!input || !target || !input.isForWorkgroup() ||
603 !target.isForWorkgroup())
604 return rewriter.notifyMatchFailure(
605 op, "Input and target layouts must have subgroup layout");
606
607 DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
608 DenseI32ArrayAttr inputSgData = input.getSgData();
609 DenseI32ArrayAttr inputOrder = input.getOrder();
610 DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
611 DenseI32ArrayAttr targetSgData = target.getSgData();
612 DenseI32ArrayAttr targetOrder = target.getOrder();
613
614 // TODO: currently we only support the optimal case, where input and
615 // output have the same sg_layout and sg_data, so SLM is not involved.
616 if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
617 inputOrder != targetOrder)
618 return failure();
619
620 input = input.dropSgLayoutAndData();
621 target = target.dropSgLayoutAndData();
622
623 SmallVector<Value> newOps(adaptor.getSource());
624 if (input && target) {
625 // keep the ConvertLayoutOp for rest fields, e.g., inst_data.
626 for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
627 auto newOp = xegpu::ConvertLayoutOp::create(
628 rewriter, op.getLoc(), src.getType(), src, input, target);
629 newOps[i] = newOp;
630 }
631 }
632 rewriter.replaceOpWithMultiple(op, {newOps});
633 return success();
634 }
635};
636
637// Handles UnrealizedConversionCastOp generated during
638// SCFStructuralTypeConversions (step 1). This op may appear as either a
639// target or source materialization for Vector values, e.g.:
640// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
641// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
642// It could be either a 1:N or an N:1 cast. In both cases, the pattern
643// simply forwards the inputs to the outputs using the 1:1 or 1:N interface.
644// For example, the following scf::ForOp
645// ```
646// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
647// %n = use(%arg1): vector<128x128xf16>
648// scf.yield %n : vector<128x128xf16>
649// }
650// ```
651// Could be converted to:
652// ```
653// %1:2 = unrealized_conversion_cast %0
654// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
655// %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
656// -> (vector<16x16xf16>, vector<16x16xf16>) {
657// %m = unrealized_conversion_cast %arg1, %arg2
658// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
659// %n = use(%m): vector<128x128xf16>
660// %b:2 = unrealized_conversion_cast %n
661// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
662// scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
663// }
664// %cast = unrealized_conversion_cast %for#0, %for#1
665// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
666// ```
667// TODO: remove it when context-aware type converter is ready.
668struct UnrealizedConversionCastOpPattern
669 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
670 using OpConversionPattern<
671 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
672
673 mlir::LogicalResult
674 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
675 ConversionPatternRewriter &rewriter) const override {
676 SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
677
678 auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
679 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
680
681 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
682 !llvm::all_equal(ValueRange(inputs).getTypes()))
683 return failure();
684
685 // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
686 // It is generated by source materialization (e.g., inits to scf forOp).
687 // The input values provided by the adaptor should already be distributed,
688 // and their types should correspond exactly to the result types of the
689 // operation.
690 if (op.getNumOperands() == 1 &&
691 llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
692 rewriter.replaceOp(op, inputs);
693 return success();
694 }
695
696 // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
697 // It is generated by target materialization (e.g., arguments/results
698 // of scf forOp). All input values must have the same vector type, and
699 // their shape must be evenly divisible by the output vector's shape
700 // (determined by the nature of the workgroup to subgroup distribution).
701 // TODO: it is not safe to do such forward, since such N:1 cast could be
702 // from others.
703 if (op.getNumResults() == 1 &&
704 computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
705 rewriter.replaceOpWithMultiple(op, {inputs});
706 return success();
707 }
708
709 return mlir::failure();
710 }
711};
712
713// This pattern distributes arith.constant op into subgroup-level constants
714struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
715 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
716
717 LogicalResult
718 matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
719 ConversionPatternRewriter &rewriter) const override {
720 auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
721 auto vecType = dyn_cast<VectorType>(op.getType());
722 if (!vecAttr || !vecType)
723 return failure();
724
725 xegpu::DistributeLayoutAttr layout =
726 xegpu::getDistributeLayoutAttr(op.getResult());
727 if (!layout || !layout.isForWorkgroup())
728 return failure();
729
730 ArrayRef<int64_t> wgShape = vecType.getShape();
731 SmallVector<int64_t> sgShape;
732 int count;
733 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
734
735 auto newType = VectorType::get(sgShape, vecType.getElementType());
736 Location loc = op.getLoc();
737 auto eltType = vecType.getElementType();
738
739 auto setLayout = [&](Value val) {
740 xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
741 layout.dropSgLayoutAndData());
742 };
743
744 if (vecAttr.isSplat()) {
745 // Splat: single value for all subgroups
746 Attribute singleVal = vecAttr.getSplatValue<Attribute>();
747 auto sgAttr = DenseElementsAttr::get(newType, singleVal);
748 auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
749 setLayout(cstOp->getResult(0));
750 rewriter.replaceOp(op, cstOp);
751 return success();
752 } else if (sgShape == wgShape) { // if the entire vector is shared by all
753 // subgroups, don't distribute
754 auto newConstOp =
755 arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
756 setLayout(newConstOp->getResult(0));
757 rewriter.replaceOp(op, newConstOp);
758 return success();
759 } else {
760 // Non-splat constant
761 // Only supports 1D & 2D
762 // TODO: support other cases that require SLM access
763 if (!eltType.isIndex())
764 return rewriter.notifyMatchFailure(
765 op, "Unsupported element type for non-splat constant op.");
766
767 if (wgShape.size() > 2)
768 return rewriter.notifyMatchFailure(
769 op, "Only 1D & 2D vector constant supported");
770
771 SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
772 int64_t rowStride = 0, colStride = 0;
773 int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
774 int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
775
776 // Compute colStride and rowStride, and check for constant strides.
777 if (cols > 1) {
778 colStride = cast<IntegerAttr>(values[1]).getInt() -
779 cast<IntegerAttr>(values[0]).getInt();
780 }
781 if (rows > 1) {
782 rowStride = cast<IntegerAttr>(values[cols]).getInt() -
783 cast<IntegerAttr>(values[0]).getInt();
784 }
785
786 for (int64_t r = 0; r < rows; ++r) {
787 for (int64_t c = 0; c < cols; ++c) {
788 int64_t idx = r * cols + c;
789 // Check column stride
790 if (c > 0 && cols > 1) {
791 int64_t prevIdx = r * cols + (c - 1);
792 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
793 cast<IntegerAttr>(values[prevIdx]).getInt();
794 if (diff != colStride)
795 return rewriter.notifyMatchFailure(
796 op, "Non-constant column stride in constant op.");
797 }
798 // Check row stride
799 if (r > 0 && rows > 1) {
800 int64_t prevIdx = (r - 1) * cols + c;
801 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
802 cast<IntegerAttr>(values[prevIdx]).getInt();
803 if (diff != rowStride)
804 return rewriter.notifyMatchFailure(
805 op, "Non-constant row stride in constant op.");
806 }
807 }
808 }
809
810 // Create a constant for the base tile.
811 // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
812 // For 1D case, extract the first sgShape[0] elements.
813 SmallVector<Attribute> baseTileValues;
814 int baseTileCols = sgShape[sgShape.size() - 1];
815 int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
816 for (int64_t r = 0; r < baseTileRows; ++r) {
817 for (int64_t c = 0; c < baseTileCols; ++c) {
818 baseTileValues.push_back(values[r * cols + c]);
819 }
820 }
821
822 auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
823 baseTileValues);
824 auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
825
826 // Get subgroup id
827 Value sgId =
828 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
829 auto sgOffsets =
830 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
831 if (failed(sgOffsets))
832 return failure();
833
834 SmallVector<Value, 2> strideConsts;
835 strideConsts.push_back(
836 arith::ConstantIndexOp::create(rewriter, loc, colStride));
837 if (rows > 1)
838 strideConsts.insert(
839 strideConsts.begin(),
840 arith::ConstantIndexOp::create(rewriter, loc, rowStride));
841
842 SmallVector<Value> newConstOps;
843 for (auto offsets : *sgOffsets) {
844 // Multiply offset with stride, broadcast it and add to baseConstVec
845 Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
846 for (size_t i = 0; i < strideConsts.size(); ++i) {
847 Value mul =
848 arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
849 offsets[i], strideConsts[i]);
850 mulOffset = arith::AddIOp::create(
851 rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
852 }
853 // Broadcast to baseConstVec size
854 auto bcastOffset = vector::BroadcastOp::create(
855 rewriter, loc, baseConstVec.getType(), mulOffset);
856 auto finalConst =
857 arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
858 setLayout(baseConstVec);
859 setLayout(bcastOffset);
860 setLayout(finalConst);
861 newConstOps.push_back(finalConst);
862 }
863 rewriter.replaceOpWithMultiple(op, {newConstOps});
864 return success();
865 }
866 }
867};
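// Worked example for WgToSgArithConstantOp above (illustrative values): a
// non-splat constant arith.constant dense<[0, 16, 32, 48]> : vector<4xindex>
// with sg_layout = [4], sg_data = [1] has colStride = 16; each subgroup gets
// the base tile dense<[0]> plus a broadcast of offset[0] * 16, reproducing
// its portion of the original values. Splat constants are simply re-created
// with the subgroup shape.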
868
869// This pattern transforms the LoadGatherOp with explicit offsets to load
870// subgroup data
871struct WgToSgLoadGatherOpWithOffset
872 : public OpConversionPattern<xegpu::LoadGatherOp> {
873 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
874 LogicalResult
875 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
876 ConversionPatternRewriter &rewriter) const override {
877
878 if (!op.getOffsets())
879 return failure();
880
881 Location loc = op.getLoc();
882 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
883 if (!resultType)
884 return failure();
885 ArrayRef<int64_t> wgShape = resultType.getShape();
886
887 xegpu::DistributeLayoutAttr layout =
888 xegpu::getDistributeLayoutAttr(op.getResult());
889 if (!layout || !layout.isForWorkgroup())
890 return failure();
891
892 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
893
894 // The offsets need to be distributed
895 auto offsetsVecType =
896 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
897 auto maskVecType =
898 dyn_cast<VectorType>(adaptor.getMask().front().getType());
899 if (!offsetsVecType || !maskVecType ||
900 offsetsVecType.getShape() != maskVecType.getShape()) {
901 return rewriter.notifyMatchFailure(op,
902 "offsets have not been distributed");
903 }
904
905 SmallVector<Value> newLoadOps;
906 auto chunkSizeAttr =
907 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
908 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
909 for (auto [offsets, mask] :
910 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
911 auto newLayout = layout.dropSgLayoutAndData();
912 auto newLoadOp = xegpu::LoadGatherOp::create(
913 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
914 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
915 newLayout);
916 xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout);
917 newLoadOps.push_back(newLoadOp);
918 }
919 rewriter.replaceOpWithMultiple(op, {newLoadOps});
920 return success();
921 }
922};
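// Illustrative example for WgToSgLoadGatherOpWithOffset above (hypothetical
// shapes): a gather load producing vector<256xf32> with sg_layout = [8],
// sg_data = [32] expects its offsets and mask to already be distributed to
// vector<32xindex> / vector<32xi1>; one subgroup-level load per distributed
// offsets/mask pair is emitted, each producing vector<32xf32>.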
923
924// This pattern transforms the StoreScatterOp with explicit offsets to store
925// subgroup data
926struct WgToSgStoreScatterOpWithOffset
927 : public OpConversionPattern<xegpu::StoreScatterOp> {
928 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
929 LogicalResult
930 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
931 ConversionPatternRewriter &rewriter) const override {
932
933 if (!op.getOffsets())
934 return failure();
935
936 Location loc = op.getLoc();
937 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
938 if (!valueType)
939 return failure();
940
941 xegpu::DistributeLayoutAttr layout =
942 xegpu::getDistributeLayoutAttr(op.getOperand(0));
943 if (!layout || !layout.isForWorkgroup())
944 return failure();
945
946 // The offsets need to be distributed
947 auto offsetsVecType =
948 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
949 auto maskVecType =
950 dyn_cast<VectorType>(adaptor.getMask().front().getType());
951 if (!offsetsVecType || !maskVecType ||
952 offsetsVecType.getShape() != maskVecType.getShape()) {
953 return rewriter.notifyMatchFailure(op,
954 "offsets have not been distributed");
955 }
956
957 auto chunkSizeOpt = op.getChunkSize();
958 int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
959 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
960 for (auto [val, offs, mask] : llvm::zip(
961 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
962 auto store = xegpu::StoreScatterOp::create(
963 rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
964 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
965 layout.dropSgLayoutAndData());
966 // Update the layout attribute to drop sg_layout and sg_data.
967 for (OpOperand &operand : store->getOpOperands()) {
968 // Skip for operand one (memref)
969 if (operand.getOperandNumber() == 1)
970 continue;
971 xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
972 }
973 }
974 rewriter.eraseOp(op);
975 return success();
976 }
977};
978
979struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
980 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
981 LogicalResult
982 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
983 ConversionPatternRewriter &rewriter) const override {
984
985 SmallVector<SmallVector<OpFoldResult>> offsetsList;
986 if (failed(genOffsetsList(rewriter, op, offsetsList)))
987 return failure();
988
989 ArrayRef<int64_t> wgShape = op.getDataShape();
990 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
991 assert(valueTy && "the value type must be vector type!");
992 Type elemTy = valueTy.getElementType();
993
994 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
995 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
996 VectorType newResTy = VectorType::get(sgShape, elemTy);
997 SmallVector<Value> newOps;
998 for (auto offsets : offsetsList) {
999 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
1000 op.getMemDesc(), offsets,
1001 layout.dropSgLayoutAndData());
1002 newOps.push_back(newOp);
1003 }
1004 rewriter.replaceOpWithMultiple(op, {newOps});
1005
1006 return success();
1007 }
1008};
1009
1010struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
1011 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
1012 LogicalResult
1013 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
1014 ConversionPatternRewriter &rewriter) const override {
1015
1016 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1017 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1018 return failure();
1019
1020 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1021 for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1022 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1023 offsets, layout.dropSgLayoutAndData());
1024 rewriter.eraseOp(op);
1025 return success();
1026 }
1027};
1028
1029// This pattern distributes the vector.step ops to work at subgroup level
1030struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
1031 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1032 LogicalResult
1033 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1034 ConversionPatternRewriter &rewriter) const override {
1035 xegpu::DistributeLayoutAttr layout =
1036 xegpu::getDistributeLayoutAttr(op.getResult());
1037 if (!layout || !layout.isForWorkgroup())
1038 return failure();
1039
1040 Location loc = op.getLoc();
1041 VectorType type = op.getResult().getType();
1042 auto wgShape = type.getShape();
1043 std::optional<SmallVector<int64_t>> sgShape =
1044 getSgShapeAndCount(wgShape, layout).first;
1045 if (!sgShape)
1046 return failure();
1047
1048 Value sgId =
1049 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1050 auto sgOffsets =
1051 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1052 if (failed(sgOffsets))
1053 return failure();
1054
1055 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
1056 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1057 SmallVector<Value> newOps;
1058 for (auto offsets : *sgOffsets) {
1059 // Broadcast the offset scalar to a vector & add to the base steps
1060 auto bcastOffset =
1061 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1062 auto finalSteps =
1063 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1064 xegpu::setDistributeLayoutAttr(steps->getResult(0),
1065 layout.dropSgLayoutAndData());
1066 xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
1067 layout.dropSgLayoutAndData());
1068 xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
1069 layout.dropSgLayoutAndData());
1070 newOps.push_back(finalSteps);
1071 }
1072
1073 rewriter.replaceOpWithMultiple(op, {newOps});
1074 return success();
1075 }
1076};
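// Illustrative example for WgToSgVectorStepOp above (hypothetical shapes):
// a vector.step producing vector<128xindex> with sg_layout = [4], sg_data =
// [32] becomes a single vector.step : vector<32xindex> plus, for each
// distributed coordinate, a broadcast of the subgroup's element offset
// (e.g. 32 * sgId for the first round) added to that base step vector.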
1077
1078// This pattern transforms vector.shape_cast ops to work at subgroup level.
1079struct WgToSgVectorShapeCastOp
1080 : public OpConversionPattern<vector::ShapeCastOp> {
1081 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1082
1083 LogicalResult
1084 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
1085 ConversionPatternRewriter &rewriter) const override {
1086
1087 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1088 if (!resultType)
1089 return failure();
1090
1091 ArrayRef<int64_t> wgShape = resultType.getShape();
1092 xegpu::DistributeLayoutAttr layout =
1093 xegpu::getDistributeLayoutAttr(op.getResult());
1094 if (!layout || !layout.isForWorkgroup())
1095 return failure();
1096
1097 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1098 VectorType newResultType =
1099 VectorType::get(sgShape, resultType.getElementType());
1100
1101 // TODO: Add check for compatible layouts in layout attr.
1102 auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
1103 if (!srcType)
1104 return failure();
1105
1106 // Check that shape_cast only adds/removes unit dimensions,
1107 auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
1108 // Remove all 1s from both shapes and compare the rest.
1109 SmallVector<int64_t> srcNonUnit, dstNonUnit;
1110 for (int64_t d : src)
1111 if (d != 1)
1112 srcNonUnit.push_back(d);
1113 for (int64_t d : dst)
1114 if (d != 1)
1115 dstNonUnit.push_back(d);
1116 return srcNonUnit == dstNonUnit;
1117 };
1118
1119 if (!onlyUnitDims(srcType.getShape(), sgShape))
1120 return failure();
1121
1122 // For rank reducing or increasing shape_cast ops, the lower rank layout
1123 // must be a slice of higher rank layout.
1124 int64_t sourceRank = srcType.getRank();
1125 int64_t resultRank = sgShape.size();
1126 xegpu::DistributeLayoutAttr sourceLayout =
1127 xegpu::getDistributeLayoutAttr(op.getSource());
1128 if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
1129 return failure();
1130 if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
1131 return failure();
1132
1133 SmallVector<Value> newShapeCastOps;
1134 for (auto src : adaptor.getSource()) {
1135 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1136 newResultType, src);
1137 xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
1138 layout.dropSgLayoutAndData());
1139 newShapeCastOps.push_back(newShapeCast.getResult());
1140 }
1141
1142 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1143 return success();
1144 }
1145};
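// Illustrative example for WgToSgVectorShapeCastOp above (hypothetical
// shapes): a vector.shape_cast from vector<256x128xf32> to
// vector<1x256x128xf32>, where the result layout is
// #xegpu.layout<sg_layout = [1, 8, 4], sg_data = [1, 32, 32]> and the source
// layout is the corresponding slice with the unit dim dropped, is rewritten
// into shape_casts from vector<32x32xf32> to vector<1x32x32xf32>.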
1146
1147/// Pattern for lowering vector.multi_reduction op to subgroup level.
1148/// Current limitation: the sg_layout in the reduced dimension must be 1,
1149/// so that the reduction is local to each subgroup and no cross-subgroup
1150/// communication is needed.
1151/// TODO: Add cases to handle more general situations which require SLM access.
1152struct WgToSgMultiDimReductionOp
1153 : public OpConversionPattern<vector::MultiDimReductionOp> {
1154 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1155
1156 LogicalResult
1157 matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1158 ConversionPatternRewriter &rewriter) const override {
1159 VectorType srcType = op.getSourceVectorType();
1160 VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
1161 if (!dstType)
1162 return failure();
1163
1164 auto srcShape = srcType.getShape();
1165 xegpu::DistributeLayoutAttr layout =
1166 xegpu::getDistributeLayoutAttr(op.getResult());
1167 if (!layout || !layout.isForWorkgroup())
1168 return failure();
1169
1170 auto reductionDims = llvm::to_vector(op.getReductionDims());
1171
1172 SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
1173 .getParent()
1174 .getEffectiveSgLayoutAsInt();
1175 SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
1176 .getParent()
1177 .getEffectiveSgDataAsInt();
1178
1179 // Check that the sgLayout in the reduced dimension is 1 and
1180 // each sg gets the entire slice to reduce.
1181 for (int64_t dim : reductionDims) {
1182 if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
1183 return rewriter.notifyMatchFailure(
1184 op,
1185 "sgLayout in each reduced dimension must be 1 and sgData in the "
1186 "reduced dim must match srcShape in that dim");
1187 }
1188
1189 SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
1190
1191 VectorType newDstType =
1192 VectorType::get({sgShape}, dstType.getElementType());
1193
1194 SmallVector<Value> newReductions;
1195 for (auto sgSrc : adaptor.getSource()) {
1196 auto newOp = vector::MultiDimReductionOp::create(
1197 rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
1198 adaptor.getAcc()[0], op.getReductionDims());
1199 xegpu::setDistributeLayoutAttr(newOp->getResult(0),
1200 layout.dropSgLayoutAndData());
1201 newReductions.push_back(newOp.getResult());
1202 }
1203
1204 rewriter.replaceOpWithMultiple(op, {newReductions});
1205 return success();
1206 }
1207};
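// Illustrative example for WgToSgMultiDimReductionOp above (hypothetical
// shapes): reducing dim 1 of vector<256x128xf32> into vector<256xf32>,
// where the source carries sg_layout = [8, 1], sg_data = [32, 128] and the
// result carries the corresponding slice layout, passes the check above
// (sg_layout is 1 in the reduced dim and sg_data covers it fully), so each
// subgroup reduces its own vector<32x128xf32> slice into vector<32xf32>.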
1208
1209// This pattern transforms vector.transpose ops to work at subgroup level.
1210struct WgToSgVectorTransposeOp
1211 : public OpConversionPattern<vector::TransposeOp> {
1212 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1213
1214 LogicalResult
1215 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1216 ConversionPatternRewriter &rewriter) const override {
1217 VectorType resultType = op.getResultVectorType();
1218
1219 ArrayRef<int64_t> wgShape = resultType.getShape();
1220 xegpu::DistributeLayoutAttr layout =
1221 xegpu::getDistributeLayoutAttr(op.getResult());
1222 if (!layout || !layout.isForWorkgroup())
1223 return failure();
1224
1225 xegpu::DistributeLayoutAttr sourceLayout =
1226 xegpu::getDistributeLayoutAttr(op.getVector());
1227 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1228 return failure();
1229
1230 SmallVector<int64_t> sourceSgLayout =
1231 sourceLayout.getEffectiveSgLayoutAsInt();
1232 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1233 DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
1234 DenseI32ArrayAttr resultOrder = layout.getOrder();
1235
1236 if (!sourceOrder || !resultOrder) {
1237 return rewriter.notifyMatchFailure(
1238 op, "Both source and result must have order attributes");
1239 }
1240
1241 ArrayRef<int64_t> permutation = op.getPermutation();
1242 size_t permutationSize = permutation.size();
1243 if (sourceSgLayout.size() != permutationSize ||
1244 resultSgLayout.size() != permutationSize) {
1245 return rewriter.notifyMatchFailure(
1246 op, "Layouts and permutation must have the same rank");
1247 }
1248
1249 // Check that sgLayout, sgData & order are properly transposed for source
1250 // and result
1251 if (!layout.isTransposeOf(sourceLayout, permutation))
1252 return rewriter.notifyMatchFailure(
1253 op, "Result layout is not a valid transpose of source layout "
1254 "according to permutation");
1255
1256 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1257 VectorType newResultType =
1258 VectorType::get(sgShape, resultType.getElementType());
1259 SmallVector<Value> newTransposeOps;
1260 for (auto src : adaptor.getVector()) {
1261 auto newTranspose = vector::TransposeOp::create(
1262 rewriter, op.getLoc(), newResultType, src, permutation);
1263 xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
1264 layout.dropSgLayoutAndData());
1265 newTransposeOps.push_back(newTranspose.getResult());
1266 }
1267
1268 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1269 return success();
1270 }
1271};
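// Illustrative example for WgToSgVectorTransposeOp above (hypothetical
// shapes): a vector.transpose of vector<256x128xf32> to vector<128x256xf32>
// with permutation [1, 0], source layout sg_layout = [8, 4], sg_data =
// [32, 32] and result layout sg_layout = [4, 8], sg_data = [32, 32] (with
// correspondingly transposed order attributes) is rewritten into
// per-subgroup transposes of vector<32x32xf32>.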
1272
1273// This pattern distributes the vector.constant_mask ops to work at subgroup
1274// level.
1275struct WgToSgVectorConstantMaskOp
1276 : public OpConversionPattern<vector::ConstantMaskOp> {
1277 using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
1278
1279 LogicalResult
1280 matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
1281 ConversionPatternRewriter &rewriter) const override {
1282 xegpu::DistributeLayoutAttr layout =
1283 xegpu::getDistributeLayoutAttr(op.getResult());
1284 if (!layout || !layout.isForWorkgroup())
1285 return failure();
1286
1287 Location loc = op.getLoc();
1288 VectorType type = op.getResult().getType();
1289 auto wgShape = type.getShape();
1290
1291 ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes();
1292
1293 // Get subgroup ID.
1294 Value sgId =
1295 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1296 auto sgOffsets =
1297 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1298 if (failed(sgOffsets))
1299 return failure();
1300
1301 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1302 VectorType resultType = VectorType::get(sgShape, type.getElementType());
1303
1304 // In each dimension, each subgroup computes its local mask size as:
1305 // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
1306 SmallVector<Value> newCreateMaskOps;
1307 for (auto offsetSet : *sgOffsets) {
1308 SmallVector<Value> maskOperands;
1309
1310 for (auto [i, wgMaskSize] : llvm::enumerate(wgMaskDimSizes)) {
1311 Value wgMaskSizeVal =
1312 arith::ConstantIndexOp::create(rewriter, loc, wgMaskSize);
1313 Value dimSizeVal =
1314 arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
1315 Value offset = offsetSet[i];
1316 Value adjustedMaskSize =
1317 arith::SubIOp::create(rewriter, loc, wgMaskSizeVal, offset);
1318 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1319 Value nonNegative =
1320 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1321 Value sgMaskSize =
1322 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1323 maskOperands.push_back(sgMaskSize);
1324 }
1325
1326 auto newCreateMaskOp =
1327 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1328 xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
1329 layout.dropSgLayoutAndData());
1330 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1331 }
1332
1333 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1334 return success();
1335 }
1336};
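// Worked example for WgToSgVectorConstantMaskOp above (illustrative
// values): vector.constant_mask [16, 48] : vector<32x64xi1> with sg_layout
// = [2, 2], sg_data = [16, 32] becomes, for the subgroup at coordinate
// [1, 1] (offsets [16, 32]), a vector.create_mask with sizes
// [min(max(16 - 16, 0), 16), min(max(48 - 32, 0), 32)] = [0, 16] on
// vector<16x32xi1>.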
1337
1338} // namespace
1339
1340namespace mlir {
1341namespace xegpu {
1342void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
1343 patterns
1344 .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
1345 WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
1346 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
1347 WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
1348 WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1349 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
1350 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
1351 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1352 WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1353 WgToSgVectorConstantMaskOp>(patterns.getContext());
1354}
1355} // namespace xegpu
1356} // namespace mlir
1357
1358namespace {
1359struct XeGPUWgToSgDistributePass
1360 : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
1361 void runOnOperation() override;
1362};
1363} // namespace
1364
1365void XeGPUWgToSgDistributePass::runOnOperation() {
1366 // Track existing UnrealizedConversionCastOps
1367 SmallVector<Operation *> existingCastOps;
1368 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1369 existingCastOps.push_back(castOp.getOperation());
1370 });
1371
1372 {
1373 // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
1374 // VectorType operands. This first converts such operands to
1375 // RankedTensorType, propagates the layout attribute into the encoding
1376 // attribute, and finally converts the RankedTensorType to VectorType based
1377 // on the encoding.
1378
1379 TypeConverter converter;
1380 converter.addConversion([&](Type type) -> Type { return type; });
1381 converter.addConversion(
1382 [&](RankedTensorType type,
1383 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1384 Type elemTy = type.getElementType();
1385 ArrayRef<int64_t> shape = type.getShape();
1386
1387 int count;
1388 SmallVector<int64_t> subShape;
1389 std::tie(subShape, count) = getSgShapeAndCount(
1390 shape,
1391 dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
1392
1393 auto newTy = VectorType::get(subShape, elemTy);
1394 result.append(count, newTy);
1395 return success();
1396 });
1397
1398 xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
1399 converter);
1400 }
1401
1402 // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
1403 // as well as XeGPU, Arith, and Vector operations.
1404 MLIRContext *ctx = &getContext();
1405 RewritePatternSet patterns(ctx);
1406 ConversionTarget target(*ctx);
1407 TypeConverter converter;
1408 converter.addConversion([&](Type type) -> Type { return type; });
1409 converter.addConversion(
1410 [&](xegpu::TensorDescType type,
1411 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1412 Type elemTy = type.getElementType();
1413 ArrayRef<int64_t> shape = type.getShape();
1414
1415 int count;
1416 SmallVector<int64_t> subShape;
1417 xegpu::LayoutAttr layout = type.getLayoutAttr();
1418 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1419
1420 if (layout)
1421 layout = layout.dropSgLayoutAndData();
1422
1423 auto newTy = xegpu::TensorDescType::get(
1424 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1425 result.append(count, newTy);
1426 return success();
1427 });
1428
1429 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1430 if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1431 return createOp.getType();
1432 if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1433 return loadOp.getTensorDescType();
1434 if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1435 return storeOp.getTensorDescType();
1436 if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1437 return updateOp.getType();
1438 if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1439 return prefetchOp.getTensorDescType();
1440 return xegpu::TensorDescType();
1441 };
1442
1443 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
1444 return !layout || !layout.isForWorkgroup();
1445 };
1446
1447 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1448 xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1449 xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
1450 auto tdescTy = getTensorDescType(op);
1451 auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1452 return isLegal(layout);
1453 });
1454
1455 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
1456 auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1457 return isLegal(layout);
1458 });
1459
1460 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1461 [=](xegpu::LoadMatrixOp op) -> bool {
1462 return isLegal(op.getLayoutAttr());
1463 });
1464
1465 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1466 [=](xegpu::StoreMatrixOp op) -> bool {
1467 return isLegal(op.getLayoutAttr());
1468 });
1469
1470 target.addDynamicallyLegalOp<arith::ConstantOp>(
1471 [=](arith::ConstantOp op) -> bool {
1472 auto vecType = dyn_cast<VectorType>(op.getType());
1473 if (!vecType)
1474 return true;
1475
1476 auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1477 return isLegal(layout);
1478 });
1479
1480 target.addDynamicallyLegalOp<
1481 vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
1482 vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
1483 [=](Operation *op) -> bool {
1484 // Check for either a SliceAttr or LayoutAttr on the result.
1485 auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
1486 return isLegal(layout);
1487 });
1488
1489 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1490 [=](xegpu::LoadGatherOp op) -> bool {
1491 auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1492 return isLegal(layout);
1493 });
1494
1495 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1496 [=](xegpu::StoreScatterOp op) -> bool {
1497 auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0));
1498 return isLegal(layout);
1499 });
1500
1501 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1502 [=](xegpu::ConvertLayoutOp op) -> bool {
1503 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1504 });
1505
1506 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1507 [=](Operation *op) -> std::optional<bool> {
1508 // Only handle elementwise mappable ops.
1509 if (!OpTrait::hasElementwiseMappableTraits(op))
1510 return true;
1511
1512 VectorType resultType =
1513 dyn_cast<VectorType>(op->getResult(0).getType());
1514 if (!resultType)
1515 return true;
1516
1517 // Check if all operands are vectors of the same shape
1518 // TODO: Support other types.
1519 for (Value operand : op->getOperands()) {
1520 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1521 if (!operandType || operandType.getShape() != resultType.getShape()) {
1522 return true;
1523 }
1524 }
1525
1526 xegpu::DistributeLayoutAttr layout =
1527 xegpu::getDistributeLayoutAttr(op->getResult(0));
1528 return isLegal(layout);
1529 });
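// For example, an arith.addf producing vector<256x128xf32> whose result
// carries a workgroup-level layout (sg_layout present) is reported illegal
// here and gets rewritten into per-subgroup arith.addf ops on smaller
// vectors; the same op without such a layout is left untouched.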
1530
1531 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1532 [=](UnrealizedConversionCastOp op) {
1533 return llvm::is_contained(existingCastOps, op.getOperation());
1534 });
1535
1536 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1537
1538 scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
1539 target);
1540
1541 if (failed(
1542 applyPartialConversion(getOperation(), target, std::move(patterns))))
1543 return signalPassFailure();
1544
1545 // Remove the sg_layout and sg_data fields from the Layout
1546 // attribute of each VectorType result of the operation.
1547 // For Structured Control Flow ops, the layout is simply removed,
1548 // since in the 1:N case the layouts for the new results are missing.
1549 // The layout propagation pass will be activated to recover them.
1550 getOperation()->walk([](Operation *op) {
1551 for (OpResult result : op->getOpResults()) {
1552 std::string name = xegpu::getLayoutName(result);
1553 if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
1554 op->removeAttr(name);
1555 if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
1556 if (auto newLayout = layout.dropSgLayoutAndData())
1557 op->setAttr(name, newLayout);
1558 }
1559 }
1560 }
1561 });
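// Illustrative effect of the walk above (example attribute values): a result
// attribute such as
//   layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32],
//                                   inst_data = [8, 16]>
// is rewritten to layout_result_0 = #xegpu.layout<inst_data = [8, 16]> on
// regular ops, and is removed entirely on scf.if/scf.for/scf.while/
// scf.condition results, which are left for a later layout propagation pass.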
1562}