// MLIR 23.0.0git — XeGPUSgToWiDistributeExperimental.cpp
// (doxygen page navigation text converted to a comment so the file stays
// syntactically valid C++)
1//===- XeGPUSgToWiDistributeExperimental.cpp - XeGPU SG to WI Pass --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
20#include "mlir/IR/Builders.h"
22#include "mlir/IR/BuiltinOps.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Operation.h"
26#include "mlir/IR/Value.h"
27#include "mlir/IR/ValueRange.h"
29#include "llvm/ADT/SetVector.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/raw_ostream.h"
32#include <optional>
33
34namespace mlir {
35namespace xegpu {
36#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
37#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
38} // namespace xegpu
39} // namespace mlir
40
41using namespace mlir;
42
43#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
44#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
45
46namespace {
47
48/// Casts the given vector value `v` to the expected vector type `expectedTy`.
49static Value castValueTo(ConversionPatternRewriter &rewriter,
50 TypedValue<VectorType> v, VectorType expectedTy) {
51 // If the type matches, simply return the value itself.
52 if (v.getType() == expectedTy)
53 return v;
54 // If only shape differs, use shape cast.
55 if (isa<VectorType>(v.getType()) &&
56 v.getType().getNumElements() == expectedTy.getNumElements())
57 return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
58
59 // Else create an unrealized cast.
60 auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
61 expectedTy, ValueRange{v});
62 return newOp.getResult(0);
63}
64
65/// Checks if all XeGPU anchor ops and vector results have valid layouts.
66static LogicalResult verifyLayouts(Operation *root) {
67 auto walkResult = root->walk([&](Operation *nestedOp) -> WalkResult {
68 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(nestedOp)) {
69 auto layout = anchorOp.getAnchorLayout();
70 if (!layout) {
71 nestedOp->emitError("expected anchor layout attribute on operation");
72 return WalkResult::interrupt();
73 }
74 return WalkResult::advance();
75 }
76 // For each vector result, check if the op contains a result layout
77 // attribute.
78 for (OpResult result : nestedOp->getResults()) {
79 if (isa<VectorType>(result.getType())) {
81 if (!layout) {
82 nestedOp->emitError(
83 "expected result layout attribute on vector result");
84 return WalkResult::interrupt();
85 }
86 }
87 }
88 return WalkResult::advance();
89 });
90 return walkResult.wasInterrupted() ? failure() : success();
91}
92
93/// A vector::MultiDimReductionOp at subgroup level in expected form if, it has
94/// exactly 1 reduction dimension, it had valid result layout attribute, and
95/// result type can be distributed to lanes using the layout.
96static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
97 auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
98 // If no layout, not valid.
99 if (!resLayout || !resLayout.isForSubgroup())
100 return false;
101 VectorType resTy = dyn_cast<VectorType>(op.getType());
102 if (!resTy)
103 return false;
104 // Compute the distributed result vector type based on the layout.
105 FailureOr<VectorType> resDistTypeOrFailure =
106 getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
107 if (failed(resDistTypeOrFailure))
108 return false;
109 return op.getReductionDims().size() == 1;
110}
111
112/// A vector::MultiDimReductionOp is doing lane-local reduction if each workitem
113/// is doing its own local reduction. In this case the result layout ensures
114/// that result vector is distributed to lanes, i.e. the result vector type is
115/// different from the distributed result vector type.
116static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
117 // Must be valid MultiDimReductionOp.
118 assert(isValidSubgroupMultiReductionOp(op) && "Expecting a valid subgroup "
119 "MultiDimReductionOp");
120 auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
121 VectorType resTy = dyn_cast<VectorType>(op.getType());
122 auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
123 return resTy != resDistTypeOrFailure.value();
124}
125
126/// Given a vector type and its distributed vector type, return the list of
127/// dimensions that are distributed.
128static SmallVector<int64_t> getDistributedDims(VectorType originalType,
129 VectorType distributedType) {
130 assert(originalType.getRank() == distributedType.getRank() &&
131 "original and distributed vector types must have the same rank");
132 SmallVector<int64_t> distributedDims;
133 for (int64_t i = 0; i < originalType.getRank(); ++i) {
134 if (distributedType.getDimSize(i) != originalType.getDimSize(i))
135 distributedDims.push_back(i);
136 }
137 return distributedDims;
138}
139
140/// Distributes a subgroup-level CreateNdDesc op to workitem-level CreateNdDesc
141/// op. This simply drops the layout attribute from the tensor descriptor type.
142struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
143 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
144
145 LogicalResult
146 matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter) const override {
148 xegpu::TensorDescType resultType = op.getType();
149 // If no layout, nothing to do.
150 if (!resultType.getLayout())
151 return failure();
152
153 auto newOp = xegpu::CreateNdDescOp::create(
154 rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
155 op->getAttrs());
156 rewriter.replaceOp(op, newOp.getResult());
157 return success();
158 }
159};
160
/// Distributes a subgroup-level LoadNd op to workitem-level LoadNd op. Output
/// of workitem-level LoadNd op is 1D. ShapeCast is added to restore the
/// original rank.
struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If no layout, nothing to do.
    if (!layout)
      return failure();
    // Check if the layout attached to the tensor descriptor is same as the
    // anchor layout. Otherwise, this is a conflict.
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    // The target microarchitecture decides whether this load needs the
    // transpose bit set (see requireTranspose below).
    auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          op, "xegpu::LoadNdOp require target attribute attached to "
              "determine transpose "
              "requirement");
    // Per-lane type the hardware load produces for this descriptor (1D).
    auto supportedWiResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    // Per-lane type the surrounding IR expects after distribution (may keep
    // the original rank); the two are reconciled with a cast at the end.
    auto expectedWiResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
    if (failed(supportedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute the workitem vector type for LoadNdOp");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute expected workitem vector type from lane layout");
    // Create the workitem-level load. The packed/transpose attributes are
    // initially copied from the original op and then overridden below
    // according to what the layout requires.
    auto newOp = xegpu::LoadNdOp::create(
        rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
        op.getL3HintAttr(), /**layout**/ nullptr);
    // Set the packed attribute if the layout requires it.
    // NOTE(review): cast<xegpu::LayoutAttr> assumes the anchor layout is a
    // plain LayoutAttr — it will assert on other DistributeLayoutAttr kinds;
    // confirm this invariant holds for LoadNd anchors.
    newOp.setPacked(xegpu::requirePacked(cast<xegpu::LayoutAttr>(layout)));
    // Set the transpose attribute if the layout requires it.
    if (xegpu::requireTranspose(cast<xegpu::LayoutAttr>(layout), uArch))
      newOp.setTranspose(DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    // Cast the 1D hardware result back to the expected distributed type.
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};
211
212/// Distributes a subgroup-level StoreNd op to workitem-level StoreNd op. Stored
213/// value in workitem-level StoreNd op is 1D. ShapeCast is added to cast the
214/// incoming value to 1D.
215struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
216 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
217
218 LogicalResult
219 matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
220 ConversionPatternRewriter &rewriter) const override {
221 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
222 // If no layout, nothing to do.
223 if (!layout)
224 return failure();
225 // Check if the layout attached to the tensor descriptor and value layout is
226 // same as the anchor layout. Otherwise, this is a conflict.
227 if (op.getTensorDescType().getLayout() != layout)
228 return rewriter.notifyMatchFailure(
229 op, "conflicting layout attributes on tensor descriptor and anchor");
230 auto valueLayout = xegpu::getDistributeLayoutAttr(op->getOpOperand(0));
231 if (valueLayout != layout)
232 return rewriter.notifyMatchFailure(
233 op, "conflicting layout attributes on value and anchor");
234 auto supportedWiValueTyOrFailure =
235 xegpu::getDistributedVectorType(op.getTensorDescType());
236 if (failed(supportedWiValueTyOrFailure))
237 return rewriter.notifyMatchFailure(
238 op,
239 "unable to compute wi vector type for StoreNdOp value from tensor "
240 "descriptor");
241
242 xegpu::StoreNdOp::create(
243 rewriter, op.getLoc(),
244 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getValue()),
245 supportedWiValueTyOrFailure.value()),
246 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
247 op.getL2HintAttr(), op.getL3HintAttr(), /**layout**/ nullptr);
248 rewriter.eraseOp(op);
249 return success();
250 }
251};
252
253/// Distributes a subgroup-level Dpas op to workitem-level Dpas op. All inpputs
254/// and output of workitem-level Dpas op are 1D. Necessary casts are added to
255/// convert the inputs and output to/from 1D.
256struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
257 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
258
259 LogicalResult
260 matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
261 ConversionPatternRewriter &rewriter) const override {
262 // llvm::errs() << "DpasOpPattern matchAndRewrite called\n";
263 // Check if the op has A, B and CD layouts attached.
264 auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
265 auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
266 auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
267 if (!layoutA || !layoutB || !layoutCd)
268 return failure();
269 // llvm::errs() << "tryning to calculate wi types for dpas op\n";
270 auto wiResultTyOrFailure =
271 xegpu::getDistributedVectorType(op.getType(), layoutCd);
272 auto wiATypeOrFailure =
273 xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
274 auto wiBTypeOrFailure =
275 xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
276 auto expectedWiResultTyOrFailure =
277 xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
278 if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
279 failed(wiBTypeOrFailure))
280 return rewriter.notifyMatchFailure(
281 op, "failed to calculate supported workitem vector types for DpasOp "
282 "from layouts");
283 if (failed(expectedWiResultTyOrFailure))
284 return rewriter.notifyMatchFailure(
285 op, "unable to compute expected workitem vector type for DpasOp from "
286 "lane layout");
287 auto newOp = xegpu::DpasOp::create(
288 rewriter, op->getLoc(), wiResultTyOrFailure.value(),
289 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getLhs()),
290 wiATypeOrFailure.value()),
291 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getRhs()),
292 wiBTypeOrFailure.value()),
293 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getAcc()),
294 wiResultTyOrFailure.value()),
295 /** layoutA**/ nullptr,
296 /** layoutB**/ nullptr, /** layoutCd**/ nullptr);
297 // Explicitly set the new types to enable correct type materializations.
298 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
299 expectedWiResultTyOrFailure.value()));
300 return success();
301 }
302};
303
304/// Distributes elementwise ops to workitem-level elementwise ops. This
305/// currently handles elementwise ops with single result only.
306struct SgToWiElementWise : public ConversionPattern {
307 SgToWiElementWise(TypeConverter &typeConverter, MLIRContext *ctx)
308 : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
309
310 LogicalResult
311 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
312 ConversionPatternRewriter &rewriter) const override {
313 // Only match ops with elementwise trait and single result.
315 return failure();
316
317 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
318 if (!resultType)
319 return rewriter.notifyMatchFailure(
320 op, "operation result is not a vector type");
321
322 xegpu::DistributeLayoutAttr layout =
323 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
324 if (!layout || !layout.isForSubgroup())
325 return rewriter.notifyMatchFailure(
326 op, "operation result does not have subgroup distribute layout");
327
328 auto wiShapeOrFailure =
329 xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
330
331 if (failed(wiShapeOrFailure))
332 return rewriter.notifyMatchFailure(
333 op, "unable to compute workitem vector type from the layout");
334
335 VectorType newResultType = wiShapeOrFailure.value();
336 OperationState state(op->getLoc(), op->getName());
337 state.addOperands(operands);
338 state.addTypes(newResultType);
339 // Copy all attributes except for DistributeLayoutAttr.
340 for (auto attr : op->getAttrs()) {
341 if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
342 state.addAttribute(attr.getName(), attr.getValue());
343 }
344 Operation *newOp = rewriter.create(state);
345
346 rewriter.replaceOp(op, newOp->getResult(0));
347 return success();
348 }
349};
350
351/// Distributes a subgroup-level arith ConstantOp to workitem-level arith
352/// ConstantOp.
353struct SgToWiArithConstant : public OpConversionPattern<arith::ConstantOp> {
354 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
355
356 LogicalResult
357 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter) const override {
359 auto resultType = dyn_cast<VectorType>(op.getType());
360 if (!resultType)
361 return failure();
362
363 // Only handle dense vector constants
364 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
365 if (!dense)
366 return rewriter.notifyMatchFailure(
367 op, "only dense splat vector constants are supported");
368
369 xegpu::DistributeLayoutAttr layout =
370 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
371 if (!layout || !layout.isForSubgroup())
372 return rewriter.notifyMatchFailure(
373 op, "operation result does not have subgroup distribute layout");
374
375 auto wiShapeOrFailure =
376 xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
377
378 if (failed(wiShapeOrFailure))
379 return rewriter.notifyMatchFailure(
380 op, "unable to compute workitem vector type from the layout");
381
382 VectorType newResultType = wiShapeOrFailure.value();
383 auto sclarValue = dense.getSplatValue<Attribute>();
384 auto newDenseAttr = DenseElementsAttr::get(newResultType, sclarValue);
385
386 auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
387 newDenseAttr);
388 rewriter.replaceOp(op, newOp.getResult());
389 return success();
390 }
391};
392
393/// Distributes a subgroup-level PrefetchNd op to workitem-level PrefetchNd op.
394struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
395 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
396
397 LogicalResult
398 matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
399 ConversionPatternRewriter &rewriter) const override {
400 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
401 // If no layout, nothing to do.
402 if (!layout)
403 return failure();
404
405 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
406 op.getMixedOffsets(), op.getL1HintAttr(),
407 op.getL2HintAttr(), op.getL3HintAttr(),
408 /**layout**/ nullptr);
409 rewriter.eraseOp(op);
410 return success();
411 }
412};
413
414/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
415///
416/// Example 1 (1D, no chunk size):
417/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
418/// %mask = producer_op : vector<16xi1>
419/// %offset = producer_op : vector<16xindex>
420/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
421/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
422/// Distributed to:
423/// %mask = producer_op : vector<1xi1>
424/// %offset = producer_op : vector<1xindex>
425/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
426/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
427///
428/// Example 2 (2D with chunk size, same mask & offset):
429/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
430/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
431/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
432/// Distributed to:
433/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
434/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
435///
436/// Example 3 (3D with leading unit dims):
437/// layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
438/// %mask = producer_op : vector<1x1x16xi1>
439/// %offset = producer_op : vector<1x1x16xindex>
440/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
441/// vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
442/// Distributed to:
443/// %mask = producer_op : vector<1x1x1xi1>
444/// %offset = producer_op : vector<1x1x1xindex>
445/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
446/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
447struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
448 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
449
450 LogicalResult
451 matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
452 ConversionPatternRewriter &rewriter) const override {
453 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
454 if (!layout)
455 return failure();
456
457 VectorType origResultTy = op.getValueType();
458 if (!origResultTy)
459 return failure();
460
461 // Check that leading dimensions are unit.
462 int chunkSize = op.getChunkSize().value_or(1);
463 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
464 ArrayRef<int64_t> shape = origResultTy.getShape();
465 if (llvm::any_of(
466 shape.take_front(origResultTy.getRank() - effectiveVecRank),
467 [](int64_t d) { return d != 1; }))
468 return rewriter.notifyMatchFailure(
469 op, "Only unit dimensions allowed for the leading "
470 "dimensions of the load vector!");
471
472 auto distResultTyOrFailure =
473 xegpu::getDistVecTypeBasedOnLaneLayout(layout, origResultTy);
474 if (failed(distResultTyOrFailure))
475 return rewriter.notifyMatchFailure(
476 op,
477 "unable to compute expected workitem vector type from lane layout");
478
479 VectorType distResultTy = distResultTyOrFailure.value();
480 VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
481 distResultTy.getElementType());
482
483 // Flatten offsets and mask to 1D to match the 1D result type.
484 Value distOffsets = adaptor.getOffsets();
485 auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
486 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
487 distOffsetsTy.getElementType());
488 distOffsets = castValueTo(
489 rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
490
491 Value distMask = adaptor.getMask();
492 auto distMaskTy = cast<VectorType>(distMask.getType());
493 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
494 distMaskTy.getElementType());
495 distMask =
496 castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
497
498 Value distSource = adaptor.getSource();
499 auto newOp = xegpu::LoadGatherOp::create(
500 rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
501 distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
502 op.getL3HintAttr(), /*layout=*/nullptr);
503
504 Value result = newOp->getResult(0);
505 if (distResultTy1D != distResultTy)
506 result = castValueTo(rewriter, cast<TypedValue<VectorType>>(result),
507 distResultTy);
508 rewriter.replaceOp(op, result);
509 return success();
510 }
511};
512
513/// This pattern distributes a subgroup-level vector.reduction op to
514/// workitem-level. This require shuffling the data across the workitems (using
515/// gpu::ShuffleOp) and reducing in stages until all workitems have the final
516/// result.
517struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
518 using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
519
520 LogicalResult
521 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
522 ConversionPatternRewriter &rewriter) const override {
523 auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
524
525 // If no layout, nothing to do.
526 if (!layout || !layout.isForSubgroup())
527 return failure();
528
529 VectorType srcVecType = op.getSourceVectorType();
530 // Only rank 1 vectors supported.
531 if (srcVecType.getRank() != 1)
532 return rewriter.notifyMatchFailure(
533 op, "Only rank 1 reductions can be distributed.");
534 // Lane layout must have the same rank as the vector.
535 if (layout.getRank() != srcVecType.getRank())
536 return rewriter.notifyMatchFailure(
537 op, "Layout rank does not match vector rank.");
538
539 // Get the subgroup size from the layout.
540 int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
541 const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
542 if (!uArch)
543 return rewriter.notifyMatchFailure(
544 op, "xegpu::ReductionOp require target attribute attached to "
545 "determine subgroup size");
546
547 // Only subgroup-sized vectors supported.
548 if (sgSize != uArch->getSubgroupSize() ||
549 srcVecType.getShape()[0] % sgSize != 0)
550 return rewriter.notifyMatchFailure(op,
551 "Invalid layout or reduction vector "
552 "dimension must match subgroup size.");
553
554 if (!op.getType().isIntOrFloat())
555 return rewriter.notifyMatchFailure(
556 op, "Reduction distribution currently only supports floats and "
557 "integer types.");
558
559 // Get the distributed vector (per work-item portion).
560 Value laneValVec = adaptor.getVector();
561
562 // Distribute and reduce across work-items in the subgroup.
563 Value fullReduce = xegpu::subgroupReduction(
564 op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
565
566 // If there's an accumulator, combine it with the reduced value.
567 if (adaptor.getAcc())
568 fullReduce = vector::makeArithReduction(
569 rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
570
571 rewriter.replaceOp(op, fullReduce);
572 return success();
573 }
574};
575
/// This pattern distributes a subgroup-level vector.multi_reduction op to
/// workitem-level only if the reduction is lane-local. This means that
/// reduction dimension is not distributed to lanes and each lane does its own
/// local reduction.
struct SgToWiMultiDimReduction
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // NOTE(review): a line appears to be missing here (extraction dropped
    // original line 587) — likely a validity guard and/or the declaration of
    // `result` used below; verify against the upstream source.
    ArrayRef<int64_t> reductionDims = op.getReductionDims();
    assert(reductionDims.size() == 1 &&
           "Expecting single reduction dimension for subgroup multi "
           "reduction op");
    // For rank > 2, ensure leading dimensions are unit.
    VectorType sourceType = op.getSourceVectorType();
    int64_t rank = sourceType.getRank();
    if (rank > 2) {
      ArrayRef<int64_t> shape = sourceType.getShape();
      if (llvm::any_of(shape.take_front(rank - 2),
                       [](int64_t d) { return d != 1; }))
        return rewriter.notifyMatchFailure(
            op, "only unit leading dimensions are supported for "
                "multi_reduction with rank > 2");
    }
    if (isReductionLaneLocal(op)) {
      auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
      VectorType resVecTy = dyn_cast<VectorType>(op.getType());
      auto resDistVecTyOrFailure =
          getDistVecTypeBasedOnLaneLayout(resLayout, resVecTy);
      // For lane local reduction, simply create a new MultiDimReductionOp using
      // adaptor operands and the new result type.
      result = vector::MultiDimReductionOp::create(
          rewriter, op.getLoc(), resDistVecTyOrFailure.value(), op.getKind(),
          adaptor.getSource(), adaptor.getAcc(), op.getReductionDims());
    } else {
      // Cross-lane case: the reduction dimension is distributed to lanes.
      auto reductionDim = reductionDims[0];
      VectorType sourceType = op.getSourceVectorType();
      int64_t reductionDimSize = sourceType.getShape()[reductionDim];
      // NOTE(review): the callee of the call below was lost in extraction
      // (original line 617 is missing) — presumably a cross-lane reduction
      // helper assigning to `result`; verify against the upstream source.
          cast<TypedValue<VectorType>>(adaptor.getSource()),
          cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
          reductionDim, reductionDimSize, op.getLoc(), rewriter);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
626
627/// Helper to compute distributed coordinates for matrix ops.
628/// When not using subgroup_block_io, each workitem computes its own
629/// coordinates based on the layout and lane ID.
630static SmallVector<Value> computeDistributedCoordsForMatrixOp(
631 ConversionPatternRewriter &rewriter, Location loc,
632 xegpu::DistributeLayoutAttr layout, ArrayRef<int64_t> payloadShape,
633 ValueRange origOffsets) {
634 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
635 /*upperBound=*/mlir::IntegerAttr());
636 auto maybeCoords =
637 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
638 if (failed(maybeCoords))
639 return {};
640 assert(maybeCoords.value().size() == 1 &&
641 "Expected one set of distributed offsets");
643 rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
644 getAsOpFoldResult(origOffsets));
645 return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
646}
647
648/// This pattern distributes a subgroup-level LoadMatrix op to workitem-level.
649struct SgToWiLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
650 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
651
652 LogicalResult
653 matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
654 ConversionPatternRewriter &rewriter) const override {
655 auto layout = op.getLayoutAttr();
656 // If no layout, nothing to do.
657 if (!layout)
658 return failure();
659
660 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
661 if (!sgPayloadTy)
662 return rewriter.notifyMatchFailure(
663 op, "the matrix op payload must be a vector type");
664
665 auto loc = op.getLoc();
666 auto offsets = op.getMixedOffsets();
667 if (offsets.empty())
668 return rewriter.notifyMatchFailure(op, "the load op must have offsets");
669
670 FailureOr<VectorType> distPayloadTyOrFailure =
671 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
672 if (failed(distPayloadTyOrFailure))
673 return rewriter.notifyMatchFailure(
674 op, "Failed to distribute matrix op payload based on layout.");
675
676 SmallVector<Value> offsetsAsValues =
677 vector::getAsValues(rewriter, loc, offsets);
678
679 SmallVector<Value> newCoords = offsetsAsValues;
680 if (!op.getSubgroupBlockIoAttr()) {
681 newCoords = computeDistributedCoordsForMatrixOp(
682 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
683 if (newCoords.empty())
684 return rewriter.notifyMatchFailure(
685 op, "Failed to compute distributed coordinates.");
686 }
687
688 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
689 ShapedType::kDynamic);
690 DenseI64ArrayAttr newConstOffsetsAttr =
691 rewriter.getDenseI64ArrayAttr(newConstOffsets);
692
693 auto newOp = xegpu::LoadMatrixOp::create(
694 rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
695 ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
696 xegpu::DistributeLayoutAttr{});
697 rewriter.replaceOp(op, newOp.getResult());
698 return success();
699 }
700};
701
/// Distributes a subgroup-level vector.transpose op to workitem-level.
struct SgToWiVectorTranspose : public OpConversionPattern<vector::TransposeOp> {
  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Both the source operand and the result must carry layouts so that they
    // can be checked for consistency below.
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getTemporaryLayout(op->getOpOperand(0));
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          op, "the source or result vector of the transpose op lacks layout "
              "attribute");
    ArrayRef<int64_t> perm = op.getPermutation();
    // Result layout must be a transpose of source layout.
    // NOTE(review): the trailing argument(s) and closing of this
    // isTransposeOf call were lost in extraction (original line 720 is
    // missing) — verify against the upstream source.
    if (!resultLayout.isTransposeOf(sourceLayout, perm,
      return rewriter.notifyMatchFailure(
          op, "the source or result vector layouts must be transposes of "
              "each other");
    FailureOr<VectorType> distributedResultTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
    if (failed(distributedResultTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "Failed to distribute the result vector type in "
              "vector::Transpose op");
    // The adaptor vector is already distributed; transpose it per workitem
    // and cast to the expected distributed result type.
    auto newOp = vector::TransposeOp::create(rewriter, op.getLoc(),
                                             adaptor.getVector(), perm);
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       distributedResultTypeOrFailure.value()));
    return success();
  }
};
737
738/// Distributes a subgroup-level vector.bitcast op to workitem-level.
739/// Bitcast only impacts the innermost dimension of the source/result vectors.
740struct SgToWiVectorBitcast : public OpConversionPattern<vector::BitCastOp> {
741 using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
742
743 LogicalResult
744 matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
745 ConversionPatternRewriter &rewriter) const override {
746 xegpu::DistributeLayoutAttr resultLayout =
747 xegpu::getTemporaryLayout(op->getOpResult(0));
748 if (!resultLayout)
749 return rewriter.notifyMatchFailure(
750 op, "result vector of the bitcast op lacks layout attribute");
751 FailureOr<VectorType> distributedResultTypeOrFailure =
752 getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
753 if (failed(distributedResultTypeOrFailure))
754 return rewriter.notifyMatchFailure(
755 op, "Failed to distribute the result vector type in "
756 "vector::BitCast op");
757 auto newOp = vector::BitCastOp::create(
758 rewriter, op.getLoc(), distributedResultTypeOrFailure.value(),
759 adaptor.getSource());
760 rewriter.replaceOp(op, newOp.getResult());
761 return success();
762 }
763};
764
765/// Distributes a subgroup-level vector.create_mask or vector.constant_mask op
766/// to workitem-level. Uses `computeDistributedCoords()` to obtain the
767/// coordinates each workitem owns, then compares each coordinate against the
768/// original mask bounds using `arith.cmpi slt`. The per-element boolean
769/// results are assembled into the distributed mask vector.
770///
771/// For multi-dimensional masks, the element is in-bounds when ALL dimensions
772/// satisfy `coord[i] < bound[i]`.
773///
774/// Example (1D):
775/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
776/// %mask = vector.create_mask %m0 : vector<16xi1>
777/// For lane k, computeDistributedCoords gives coord = [k], so:
778/// %in_bounds = arith.cmpi slt, %coord, %m0 → i1
779/// %mask = vector.broadcast %in_bounds : i1 to vector<1xi1>
780///
781/// Example (2D):
782/// layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>
783/// %mask = vector.create_mask %m0, %m1 : vector<8x4xi1>
784/// Each WI owns a 1x2 slice. computeDistributedCoords returns 2 coords:
785/// [[r0, c0], [r0, c1]]
786/// For each coord: in_bounds = (r < m0) && (c < m1)
787/// %mask = vector.from_elements %bit0, %bit1 : vector<1x2xi1>
788template <typename OpType,
789 typename = std::enable_if_t<llvm::is_one_of<
790 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
791struct SgToWiCreateMask : public OpConversionPattern<OpType> {
792 using OpConversionPattern<OpType>::OpConversionPattern;
793
794 LogicalResult
795 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
796 ConversionPatternRewriter &rewriter) const override {
797 xegpu::DistributeLayoutAttr layout =
798 xegpu::getTemporaryLayout(op->getOpResult(0));
799 if (!layout || !layout.isForSubgroup())
800 return rewriter.notifyMatchFailure(
801 op, "operation result does not have subgroup distribute layout");
802
803 VectorType origType = op.getType();
804 FailureOr<VectorType> distTypeOrFailure =
805 getDistVecTypeBasedOnLaneLayout(layout, origType);
806 if (failed(distTypeOrFailure))
807 return rewriter.notifyMatchFailure(
808 op, "unable to compute workitem vector type from the layout");
809
810 VectorType distType = distTypeOrFailure.value();
811 Location loc = op.getLoc();
812
813 // Materialize the original mask bounds as Values.
814 SmallVector<Value> origBounds;
815 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
816 origBounds.append(op.getOperands().begin(), op.getOperands().end());
817 } else {
818 auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
819 for (auto dimSize : dimSizes)
820 origBounds.push_back(
821 arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
822 }
823
824 ArrayRef<int64_t> origShape = origType.getShape();
825
826 // Use computeDistributedCoords to get the coordinates each WI owns.
827 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
828 /*upperBound=*/mlir::IntegerAttr());
829 auto maybeCoordsVec =
830 layout.computeDistributedCoords(rewriter, loc, laneId, origShape);
831 if (failed(maybeCoordsVec))
832 return rewriter.notifyMatchFailure(
833 op, "failed to compute distributed coordinates from layout");
834
835 SmallVector<SmallVector<Value>> coordsVec = maybeCoordsVec.value();
836 int64_t numElements = distType.getNumElements();
837 assert(static_cast<int64_t>(coordsVec.size()) == numElements &&
838 "number of coordinate sets must match number of distributed "
839 "elements");
840
841 // For each element, compare all coordinates against bounds.
842 Value trueVal =
843 arith::ConstantIntOp::create(rewriter, loc, /*value=*/1, /*width=*/1);
844 SmallVector<Value> maskBits;
845 for (auto &coords : coordsVec) {
846 Value inBounds = trueVal;
847 for (size_t i = 0; i < coords.size(); ++i) {
848 Value cmp = arith::CmpIOp::create(
849 rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]);
850 inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp);
851 }
852 maskBits.push_back(inBounds);
853 }
854
855 // Build the distributed mask vector.
857 if (numElements == 1) {
858 result =
859 vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]);
860 } else {
861 result =
862 vector::FromElementsOp::create(rewriter, loc, distType, maskBits);
863 }
864 rewriter.replaceOp(op, result);
865 return success();
866 }
867};
868
869/// This pattern distributes a subgroup-level StoreMatrix op to workitem-level.
870struct SgToWiStoreMatrix : public OpConversionPattern<xegpu::StoreMatrixOp> {
871 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
872
873 LogicalResult
874 matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
875 ConversionPatternRewriter &rewriter) const override {
876 auto layout = op.getLayoutAttr();
877 // If no layout, nothing to do.
878 if (!layout)
879 return failure();
880
881 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
882 if (!sgPayloadTy)
883 return rewriter.notifyMatchFailure(
884 op, "the matrix op payload must be a vector type");
885
886 auto loc = op.getLoc();
887 auto offsets = op.getMixedOffsets();
888 if (offsets.empty())
889 return rewriter.notifyMatchFailure(op, "the store op must have offsets");
890
891 FailureOr<VectorType> distPayloadTyOrFailure =
892 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
893 if (failed(distPayloadTyOrFailure))
894 return rewriter.notifyMatchFailure(
895 op, "Failed to distribute matrix op payload based on layout.");
896
897 SmallVector<Value> offsetsAsValues =
898 vector::getAsValues(rewriter, loc, offsets);
899
900 SmallVector<Value> newCoords = offsetsAsValues;
901 if (!op.getSubgroupBlockIoAttr()) {
902 newCoords = computeDistributedCoordsForMatrixOp(
903 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
904 if (newCoords.empty())
905 return rewriter.notifyMatchFailure(
906 op, "Failed to compute distributed coordinates.");
907 }
908
909 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
910 ShapedType::kDynamic);
911 DenseI64ArrayAttr newConstOffsetsAttr =
912 rewriter.getDenseI64ArrayAttr(newConstOffsets);
913
914 xegpu::StoreMatrixOp::create(
915 rewriter, loc, TypeRange{},
916 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getData()),
917 distPayloadTyOrFailure.value()),
918 adaptor.getMemDesc(), ValueRange(newCoords), newConstOffsetsAttr,
919 op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
920 rewriter.eraseOp(op);
921 return success();
922 }
923};
924
925/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
926/// workitem-level.
927///
928/// Example 1 (1D, no chunk size):
929/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
930/// %mask = producer_op : vector<16xi1>
931/// %offset = producer_op : vector<16xindex>
932/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
933/// memref<256xf16>, vector<16xindex>, vector<16xi1>
934/// Distributed to:
935/// %mask = producer_op : vector<1xi1>
936/// %offset = producer_op : vector<1xindex>
937/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
938/// memref<256xf16>, vector<1xindex>, vector<1xi1>
939///
940/// Example 2 (2D with chunk size, same mask & offset):
941/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
942/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
943/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
944/// Distributed to:
945/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
946/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
947///
948/// Example 3 (3D with leading unit dims):
949/// layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
950/// %mask = producer_op : vector<1x1x16xi1>
951/// %offset = producer_op : vector<1x1x16xindex>
952/// xegpu.store %payload, %src[%offset], %mask : vector<1x1x16xf16>,
953/// memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
954/// Distributed to:
955/// %mask = producer_op : vector<1x1x1xi1>
956/// %offset = producer_op : vector<1x1x1xindex>
957/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
958/// memref<256xf16>, vector<1xindex>, vector<1xi1>
959struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
960 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
961
962 LogicalResult
963 matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
964 ConversionPatternRewriter &rewriter) const override {
965 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
966 if (!layout)
967 return failure();
968
969 VectorType origValueTy = op.getValueType();
970 if (!origValueTy)
971 return failure();
972
973 // Check that all leading dimensions are unit dimensions.
974 int chunkSize = op.getChunkSize().value_or(1);
975 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
976 ArrayRef<int64_t> shape = origValueTy.getShape();
977 if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
978 [](int64_t d) { return d != 1; }))
979 return rewriter.notifyMatchFailure(
980 op, "Only unit dimensions allowed for the leading "
981 "dimensions of the store vector!");
982
983 auto distValueTyOrFailure =
984 xegpu::getDistVecTypeBasedOnLaneLayout(layout, origValueTy);
985 if (failed(distValueTyOrFailure))
986 return rewriter.notifyMatchFailure(
987 op,
988 "unable to compute expected workitem vector type from lane layout");
989
990 VectorType distValueTy = distValueTyOrFailure.value();
991 VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
992 distValueTy.getElementType());
993
994 Value distValue = adaptor.getValue();
995 if (distValue.getType() != distValueTy1D)
996 distValue = castValueTo(rewriter, cast<TypedValue<VectorType>>(distValue),
997 distValueTy1D);
998
999 // Flatten offsets and mask to 1D to match the 1D value type.
1000 Value distOffsets = adaptor.getOffsets();
1001 auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
1002 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
1003 distOffsetsTy.getElementType());
1004 distOffsets = castValueTo(
1005 rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
1006
1007 Value distMask = adaptor.getMask();
1008 auto distMaskTy = cast<VectorType>(distMask.getType());
1009 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
1010 distMaskTy.getElementType());
1011 distMask =
1012 castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
1013
1014 Value distDest = adaptor.getDest();
1015 xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
1016 distOffsets, distMask, op.getChunkSizeAttr(),
1017 op.getL1HintAttr(), op.getL2HintAttr(),
1018 op.getL3HintAttr(), /*layout=*/nullptr);
1019 rewriter.eraseOp(op);
1020 return success();
1021 }
1022};
1023
1024/// Distribute a vector::StepOp to workitem-level.
1025/// The layout must have exactly 1 effective lane dimension.
1026/// We completely resolve the vector::StepOp by computing the lane_data-sized
1027/// subranges.
1028struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
1029 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1030
1031 LogicalResult
1032 matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
1033 ConversionPatternRewriter &rewriter) const override {
1034 xegpu::DistributeLayoutAttr resultLayout =
1035 xegpu::getTemporaryLayout(op->getResult(0));
1036 if (!resultLayout || !resultLayout.isForSubgroup())
1037 return rewriter.notifyMatchFailure(
1038 op, "the result vector of the step op lacks subgroup layout");
1039
1040 auto loc = op.getLoc();
1041 auto stepResultVecTy = op.getResult().getType();
1042 auto wiShapeOrFailure =
1043 xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, stepResultVecTy);
1044 if (failed(wiShapeOrFailure))
1045 return rewriter.notifyMatchFailure(
1046 op, "unable to compute workitem vector type from the layout");
1047 VectorType newVecTy = wiShapeOrFailure.value();
1048
1049 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
1050 /*upperBound=*/mlir::IntegerAttr());
1051 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
1052 rewriter, loc, laneId, stepResultVecTy.getShape());
1053 if (failed(laneDataBlockCoords))
1054 return rewriter.notifyMatchFailure(
1055 op, "failed to compute lane data block coordinates");
1056
1057 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
1058 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
1059 assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
1060 newVecTy.getNumElements() / laneDataBlockLength);
1061 SmallVector<Value> stepVals;
1062 // For each lane_data block, reconstruct its sub-range
1063 // from the range of SG-level vector.step.Example: vector.step
1064 // {slice<layout<lane_layout=[2,4,2], lane_data=[1,2,1]>, dims=[0,2]>} :
1065 // vector<16xindex>
1066 // Each logical lane holds 4 elements as 2 blocks of 2 elements each.
1067 // The blocks are round-robin distributed, so logical lane id 0
1068 // holds values [0,1, 8,9].
1069 for (auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
1070 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
1071 stepVals.push_back(laneDataBlockStartCoord);
1072 for (int i = 1; i < laneDataBlockLength; ++i) {
1073 auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
1074 stepVals.push_back(arith::AddIOp::create(
1075 rewriter, loc, laneDataBlockStartCoord, offset));
1076 }
1077 }
1078 assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
1079 "Expecting the number of step values to match the number of "
1080 "elements in the vector");
1081 auto stepOpVal =
1082 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
1083 rewriter.replaceOp(op, stepOpVal);
1084 return success();
1085 }
1086};
1087
1088/// Distributes a subgroup-level vector.extract op to workitem-level. Only
1089/// handles sub-vector extraction (result is VectorType, not scalar).
1090struct SgToWiVectorExtract : public OpConversionPattern<vector::ExtractOp> {
1091 using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
1092
1093 LogicalResult
1094 matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
1095 ConversionPatternRewriter &rewriter) const override {
1096 // Only handle vector results (not scalar extraction).
1097 auto resultType = dyn_cast<VectorType>(op.getType());
1098 if (!resultType)
1099 return rewriter.notifyMatchFailure(op, "scalar extract not supported");
1100
1101 xegpu::DistributeLayoutAttr layout =
1102 xegpu::getTemporaryLayout(op->getOpResult(0));
1103 if (!layout || !layout.isForSubgroup())
1104 return failure();
1105
1106 // This implementation assumes distribution only happens on the innermost
1107 // dimension. Verify that lane_layout[0...n-2] are all unit.
1108 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1109 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1110 [](int64_t v) { return v != 1; }))
1111 return rewriter.notifyMatchFailure(
1112 op, "only innermost dimension distribution is supported for "
1113 "vector.extract");
1114
1115 auto newOp = vector::ExtractOp::create(
1116 rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
1117 rewriter.replaceOp(op, newOp.getResult());
1118 return success();
1119 }
1120};
1121
1122/// This pattern distributes a subgroup-level ShapeCast op to workitem-level.
1123struct SgToWiVectorShapeCast : public OpConversionPattern<vector::ShapeCastOp> {
1124 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1125
1126 LogicalResult
1127 matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
1128 ConversionPatternRewriter &rewriter) const override {
1129 xegpu::DistributeLayoutAttr resultLayout =
1130 xegpu::getTemporaryLayout(op->getOpResult(0));
1131 if (!resultLayout || !resultLayout.isForSubgroup())
1132 return rewriter.notifyMatchFailure(
1133 op, "the result vector of the shape_cast op lacks subgroup layout");
1134
1135 auto resultDistTypeOrFailure = xegpu::getDistVecTypeBasedOnLaneLayout(
1136 resultLayout, op.getResultVectorType());
1137 if (failed(resultDistTypeOrFailure))
1138 return rewriter.notifyMatchFailure(
1139 op, "failed to get distributed vector type for result");
1140
1141 Value source = adaptor.getSource();
1142 auto newShapeCast = vector::ShapeCastOp::create(
1143 rewriter, op.getLoc(), resultDistTypeOrFailure.value(), source);
1144 rewriter.replaceOp(op, newShapeCast);
1145 return success();
1146 }
1147};
1148
/// Distributes a subgroup-level vector.extract_strided_slice op to
/// workitem-level. If the result is distributed, the offsets and sizes are
/// adjusted to match the distributed types.
struct SgToWiVectorExtractStridedSlice
    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
  using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return failure();

    VectorType resultType = op.getType();
    auto distResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, resultType);
    if (failed(distResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute distributed vector type from lane layout");
    VectorType distResultTy = *distResultTyOrFailure;

    // Dimensions whose size differs between the subgroup-level and the
    // distributed result type, i.e. the lane-distributed dimensions.
    SmallVector<int64_t> distributedDims =
        getDistributedDims(resultType, distResultTy);

    // Collect updated sizes, offsets, strides. Pad to full source rank.
    // The op may specify fewer entries than the source rank; missing dims
    // get full-size, zero-offset, unit-stride entries.
    int64_t sourceRank = op.getSourceVectorType().getRank();
    SmallVector<Attribute> updatedSizes =
        llvm::map_to_vector(op.getSizes(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
        op.getOffsets(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
        op.getStrides(), [](Attribute attr) { return attr; });
    for (int64_t i = op.getSizes().size(); i < sourceRank; ++i) {
      updatedSizes.push_back(
          rewriter.getI64IntegerAttr(op.getSourceVectorType().getDimSize(i)));
      updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
      updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
    }

    // If the result is distributed, adjust offsets and sizes in the
    // distributed dimension.
    if (!distributedDims.empty()) {
      if (distributedDims.size() != 1)
        return rewriter.notifyMatchFailure(
            op, "only single dimension distribution is supported");
      int64_t distDim = distributedDims[0];
      // Subgroup size is taken from the target uArch description.
      const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
      if (!uArch)
        return rewriter.notifyMatchFailure(
            op, "target attribute required to determine subgroup size");
      int subgroupSize = uArch->getSubgroupSize();
      auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            op, "source of extract_strided_slice lacks distribution layout");
      // The slice must decompose evenly across lanes: both the source size
      // and the slice offset along the distributed dim must be multiples of
      // the subgroup size.
      int sourceDistrDimSize = op.getSourceVectorType().getShape()[distDim];
      if (sourceDistrDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "source size along distributed dim is not a multiple of "
                "subgroup size");
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      // Only check lane_data for the distributed dimension. Non-distributed
      // dimensions may have non-unit lane_data (e.g., packed layouts).
      if (distDim < static_cast<int64_t>(sourceLaneData.size()) &&
          sourceLaneData[distDim] != 1)
        return rewriter.notifyMatchFailure(
            op, "expecting unit lane data along the distributed dimension");
      int64_t distrDimOffset =
          cast<IntegerAttr>(updatedOffsets[distDim]).getInt();
      if (distrDimOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "offset along distributed dim is not a multiple of "
                "subgroup size");
      // Adjust sizes and offsets for the distributed dimension.
      updatedSizes[distDim] =
          rewriter.getI64IntegerAttr(distResultTy.getDimSize(distDim));
      updatedOffsets[distDim] =
          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
    }

    auto newOp = vector::ExtractStridedSliceOp::create(
        rewriter, op.getLoc(), distResultTy, adaptor.getSource(),
        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
        ArrayAttr::get(rewriter.getContext(), updatedSizes),
        ArrayAttr::get(rewriter.getContext(), updatedStrides));
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
1240
1241/// This pattern distributes a subgroup-level `vector.broadcast` op to
1242/// workitem-level. The pattern supports three cases:
1243///
1244/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
1245/// vector must have a slice layout of the result. If the distributed source
1246/// and target vector types are identical, this lowers to a no-op; otherwise,
1247/// it remains a broadcast but operates on distributed vectors.
1248///
1249/// 2) Broadcast a same-rank vector with identical layouts for source and
1250/// target: The source vector must have unit dimensions, and lane_data must
1251/// be unit size for those unit dims. This always lowers to a no-op.
1252///
1253/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast
1254/// from scalar to distributed result type.
1255///
1256/// Example 1 (low-rank to high-rank broadcast):
1257/// ```
1258/// %0 = "some_op"() {layout_result_0 =
1259/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
1260/// dims = [0]>} : () -> vector<16xf16>
1261/// %1 = vector.broadcast %0 {layout_result_0 =
1262/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
1263/// : vector<16xf16> to vector<16x16xf16>
1264/// ```
1265/// is distributed to:
1266/// ```
1267/// %0 = "some_op"() : () -> vector<1xf16>
1268/// %1 = vector.broadcast %0 : vector<1xf16> to vector<16x1xf16>
1269/// ```
1270///
1271/// Example 2 (same-rank broadcast, no-op):
1272/// ```
1273/// %0 = "some_op"() {layout_result_0 =
1274/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
1275/// : () -> vector<16x1xf16>
1276/// %1 = vector.broadcast %0 {layout_result_0 =
1277/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
1278/// : vector<16x1xf16> to vector<16x16xf16>
1279/// ```
1280/// is distributed to (no-op, source already matches distributed result type):
1281/// ```
1282/// %0 = "some_op"() : () -> vector<16x1xf16>
1283/// // broadcast is eliminated, %0 is used directly
1284/// ```
1285///
1286/// Example 3 (scalar to vector broadcast):
1287/// ```
1288/// %0 = "some_op"() : () -> f16
1289/// %1 = vector.broadcast %0 {layout_result_0 =
1290/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
1291/// : f16 to vector<16x16xf16>
1292/// ```
1293/// is distributed to:
1294/// ```
1295/// %0 = "some_op"() : f16
1296/// %1 = vector.broadcast %0 : f16 to vector<16x1xf16>
1297/// ```
struct SgToWiBroadcast : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The result layout drives the distribution of the broadcast.
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(cast<OpResult>(op.getResult()));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "result does not have subgroup distribute layout");

    VectorType destType = op.getResultVectorType();
    // Null when broadcasting from a scalar (case 3 below).
    VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());

    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getTemporaryLayout(op->getOpOperand(0));

    if (sourceType) {
      int64_t rankDiff = destType.getRank() - sourceType.getRank();
      if (rankDiff > 0) {
        // Case 1: Low-rank to high-rank broadcast. A layout mismatch is only
        // diagnosed as a warning; the rewrite proceeds regardless.
        if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
          op.emitWarning(
              "broadcast source layout must be a slice of result layout");
      } else if (rankDiff == 0) {
        // Case 2: Same-rank broadcast along unit dimensions.
        // NOTE(review): unlike case 1, sourceLayout is used here without a
        // null check — confirm a layout is always present for same-rank
        // vector sources.
        auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
        SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
                                               broadcastUnitDimsSet.end());
        assert(sourceLayout.isEqualTo(
                   sourceLayout.setUnitDimData(broadcastUnitDims)) &&
               "The sg_data for unit dimensions should be set as 1");
        // NOTE(review): the updated sourceLayout is not read after this
        // assignment — verify whether it is intentionally unused.
        sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
      }
    } else {
      // Case 3: Scalar to vector broadcast; a scalar carries no layout.
      if (sourceLayout)
        return rewriter.notifyMatchFailure(
            op, "broadcast from scalar must not have a layout attribute");
    }

    auto destDistType =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
    if (failed(destDistType))
      return rewriter.notifyMatchFailure(
          op, "failed to distribute the result vector type");

    Value source = adaptor.getSource();
    // If the adapted source already matches the dest dist type, it's a no-op.
    if (source.getType() == destDistType.value()) {
      rewriter.replaceOp(op, source);
      return success();
    }

    auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                             destDistType.value(), source);
    rewriter.replaceOp(op, newOp);
    return success();
  }
};
1359
/// Distributes a subgroup-level vector.insert_strided_slice op to
/// workitem-level. If the dest is distributed, the offsets are adjusted to
/// match the distributed types.
struct SgToWiVectorInsertStridedSlice
    : public OpConversionPattern<vector::InsertStridedSliceOp> {
  using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return failure();

    VectorType destType = op.getDestVectorType();
    auto distDestTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
    if (failed(distDestTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute distributed vector type from lane layout");
    VectorType distDestTy = *distDestTyOrFailure;

    // Dest dimensions whose size differs between the subgroup-level and the
    // distributed dest type, i.e. the lane-distributed dimensions.
    SmallVector<int64_t> destDistributedDims =
        getDistributedDims(destType, distDestTy);

    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
        op.getOffsets(), [](Attribute attr) { return attr; });

    if (!destDistributedDims.empty()) {
      if (destDistributedDims.size() != 1)
        return rewriter.notifyMatchFailure(
            op, "only single dimension distribution is supported");
      int64_t destDistDim = destDistributedDims[0];

      // Subgroup size is taken from the target uArch description.
      const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
      if (!uArch)
        return rewriter.notifyMatchFailure(
            op, "target attribute required to determine subgroup size");
      int subgroupSize = uArch->getSubgroupSize();

      VectorType srcType = op.getSourceVectorType();
      // The distributed dim must be in the last k (source rank) dims of dest.
      int64_t sourceDistDim =
          destDistDim - (destType.getRank() - srcType.getRank());
      if (sourceDistDim < 0)
        return rewriter.notifyMatchFailure(
            op, "distributed dimension must be in the last k dims of dest");

      auto destLayout = xegpu::getTemporaryLayout(op->getOpOperand(1));
      auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
      if (!destLayout || !sourceLayout ||
          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
          sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            op, "source or dest of insert_strided_slice lacks distribution "
                "layout");

      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      // Only check lane_data for the distributed dimension. Non-distributed
      // dimensions may have non-unit lane_data (e.g., packed layouts).
      if ((destDistDim < static_cast<int64_t>(destLaneData.size()) &&
           destLaneData[destDistDim] != 1) ||
          (sourceDistDim < static_cast<int64_t>(sourceLaneData.size()) &&
           sourceLaneData[sourceDistDim] != 1))
        return rewriter.notifyMatchFailure(
            op, "expecting unit lane data along the distributed dimension");

      // Both the inserted slice size and the insertion offset must decompose
      // evenly across the subgroup lanes.
      int64_t srcDistrDimSize = srcType.getDimSize(sourceDistDim);
      if (srcDistrDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "source distributed dim size is not a multiple of "
                "subgroup size");

      int64_t destDistrDimOffset =
          cast<IntegerAttr>(op.getOffsets()[destDistDim]).getInt();
      if (destDistrDimOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "offset along distributed dim is not a multiple of "
                "subgroup size");
      // Adjust offset for the distributed dimension.
      updatedOffsets[destDistDim] =
          rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
    }

    auto newOp = vector::InsertStridedSliceOp::create(
        rewriter, op.getLoc(), distDestTy, adaptor.getValueToStore(),
        adaptor.getDest(),
        ArrayAttr::get(rewriter.getContext(), updatedOffsets), op.getStrides());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
1454
1455/// Distributes a subgroup-level vector.insert op to workitem-level. Only
1456/// handles sub-vector insertion (value to store is VectorType, not scalar).
1457struct SgToWiVectorInsert : public OpConversionPattern<vector::InsertOp> {
1458 using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
1459
1460 LogicalResult
1461 matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
1462 ConversionPatternRewriter &rewriter) const override {
1463 // Only handle vector value-to-store (not scalar insertion).
1464 auto valueType = dyn_cast<VectorType>(op.getValueToStoreType());
1465 if (!valueType)
1466 return rewriter.notifyMatchFailure(op, "scalar insert not supported");
1467
1468 xegpu::DistributeLayoutAttr layout =
1469 xegpu::getTemporaryLayout(op->getOpResult(0));
1470 if (!layout || !layout.isForSubgroup())
1471 return failure();
1472
1473 // verify that the outer k dimensions (for offsets)
1474 // don't have non-unit lane_layout.
1475 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1476 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1477 [](int64_t v) { return v != 1; }))
1478 return rewriter.notifyMatchFailure(
1479 op, "only innermost dimension distribution is supported for "
1480 "vector.insert");
1481
1482 auto newOp = vector::InsertOp::create(
1483 rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
1484 op.getMixedPosition());
1485 rewriter.replaceOp(op, newOp.getResult());
1486 return success();
1487 }
1488};
1489
/// Folds a subgroup-level ConvertLayout op with compatible lane layouts.
struct SgToWiConvertLayout
    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputLayout = op.getInputLayoutAttr();
    auto targetLayout = op.getTargetLayoutAttr();
    auto resShape = cast<VectorType>(op.getResult().getType()).getShape();
    SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());

    // When the two layouts are compatible for this shape, every lane already
    // owns the data the target layout expects, and the conversion folds to
    // its (already converted) input.
    // NOTE(review): one call-argument line was lost in extraction; the
    // trailing `resShapeVec` below is a reconstruction — confirm against the
    // original isCompatibleWith signature.
    if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
                                      resShapeVec)) {
      return rewriter.notifyMatchFailure(
          op, "lowering incompatible convert_layout not yet supported");
    }
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};
1512
/// Pass implementation: distributes XeGPU subgroup-level IR to
/// workitem-level IR. The actual work is in runOnOperation (defined below).
struct XeGPUSgToWiDistributeExperimentalPass
    : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
          XeGPUSgToWiDistributeExperimentalPass> {
  void runOnOperation() override;
};
1518
1519} // namespace
1520
1521void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
1522
1523 // Recover temporary operand layouts for usage in patterns.
1524 Operation *root = getOperation();
1525 if (!xegpu::recoverTemporaryLayouts(root)) {
1526 signalPassFailure();
1527 return;
1528 }
1529
1530 // Verify if all XeGPU anchor ops and vector ops have result layouts.
1531 // TODO: This can be removed once the full layout refactoring is done.
1532 if (failed(verifyLayouts(root))) {
1533 LLVM_DEBUG(DBGS() << "XeGPUSgToWiDistributeExperimentalPass: layout "
1534 "verification failed\n");
1535 signalPassFailure();
1536 return;
1537 }
1538 // Collect existing UnrealizedConversionCastOps. These must be preserved.
1539 llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
1540 root->walk(
1541 [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
1542 // Perform a structural type conversion to convert structural ops to have WI
1543 // types. This will insert UnrealizedConversionCastOps to make the IR
1544 // valid.
1545 auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
1546 mlir::ValueRange inputs,
1547 mlir::Location loc) -> mlir::Value {
1548 UnrealizedConversionCastOp castOp =
1549 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1550 return castOp.getResult(0);
1551 };
1552 {
1553 ConversionTarget target(getContext());
1554 TypeConverter typeConverter;
1555 RewritePatternSet patterns(&getContext());
1556 typeConverter.addSourceMaterialization(materializeCast);
1557 typeConverter.addTargetMaterialization(materializeCast);
1560 patterns, target);
1562 typeConverter, patterns, target);
1563 target.addLegalOp<UnrealizedConversionCastOp>();
1564 (void)applyPartialConversion(root, target, std::move(patterns));
1565 }
1566 // Structural type conversion can generate some redundant
1567 // UnrealizedConversionCastOps to materialize the SG type from type converted
1568 // WI type. These are redundant at this point and can be eliminated by
1569 // inserting shape casts instead.
1570 // Example:
1571 // %1 = UnrealizedConversionCastOp %0 : vector<16x1xf32> to vector<16x16xf32>
1572 // %2 = UnrealizedConversionCastOp %1 : vector<16x16xf32> to vector<16xf32>
1573 // This can be replaced with:
1574 // %2 = vector.shape_cast %0 : vector<16x1xf32> to vector<16xf32>
1575 OpBuilder builder(root);
1576 root->walk([&](UnrealizedConversionCastOp op) {
1577 // If this op existed before, nothing to do.
1578 if (existingCasts.contains(op))
1579 return;
1580 // number of inputs and outputs must be 1.
1581 if (op.getNumOperands() != 1 || op.getNumResults() != 1)
1582 return;
1583 // Both input and output types must be vector types.
1584 auto singleInput = op.getInputs()[0];
1585 auto inputTy = dyn_cast<VectorType>(singleInput.getType());
1586 auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
1587 if (!inputTy || !outputTy)
1588 return;
1589
1590 // Check if the defining op of the input is also an
1591 // UnrealizedConversionCastOp and it has a single user (which is this
1592 // op).
1593 auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
1594 if (!definingOp || !definingOp->hasOneUse())
1595 return;
1596 auto inputOfDefiningOp = definingOp.getInputs()[0];
1597 // If the input of the defining op and output type are both vector types
1598 // have same number of elements, insert a shape cast.
1599 auto inputOfDefiningOpTy =
1600 dyn_cast<VectorType>(inputOfDefiningOp.getType());
1601 if (inputOfDefiningOpTy &&
1602 inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
1603 builder.setInsertionPoint(op);
1604 auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
1605 outputTy, inputOfDefiningOp);
1606 op.replaceAllUsesWith(ValueRange{shapeCast.getResult()});
1607 return;
1608 }
1609 });
1610 // At this point, we will have some dead UnrealizedConversionCastOps. Just
1611 // erase them.
1612 bool changed = true;
1613 while (changed) {
1614 changed = false;
1615 root->walk([&](UnrealizedConversionCastOp op) {
1616 // Skip existing casts.
1617 if (existingCasts.contains(op))
1618 return;
1619 if (op.use_empty()) {
1620 op.erase();
1621 changed = true;
1622 }
1623 });
1624 }
1625}
1626
1628 TypeConverter &typeConverter) {
1629 // Any type other than TensorDescType and VectorType are legal as is.
1630 typeConverter.addConversion([](Type type) -> std::optional<Type> {
1631 if (!isa<TensorDescType, VectorType>(type))
1632 return type;
1633 return std::nullopt;
1634 });
1635 // For TensorDescType, drop the layout attribute if any.
1636 typeConverter.addConversion([](TensorDescType type) -> Type {
1637 if (type.getLayoutAttr()) {
1638 return type.dropLayouts();
1639 }
1640 return type;
1641 });
1642 // For VectorType, check if there is a distribute layout attribute on the
1643 // value. If so, convert to the distributed vector type based on the layout.
1644 typeConverter.addConversion([](Value v) -> std::optional<Type> {
1645 auto type = v.getType();
1646 // If value is not vector type, nothing to do.
1647 if (!isa<VectorType>(type))
1648 return std::nullopt;
1649 auto layout = xegpu::getDistributeLayoutAttr(v);
1650 if (!layout || !layout.isForSubgroup())
1651 return type;
1652 // Vector type is distributed based on lane layout.
1653 auto newTyOrFailure =
1654 getDistVecTypeBasedOnLaneLayout(layout, cast<VectorType>(type));
1655 if (failed(newTyOrFailure))
1656 return type;
1657 return *newTyOrFailure;
1658 });
1659}
1660
1662 TypeConverter &typeConverter, RewritePatternSet &patterns,
1665 // CreateNdDescOp is legal only if its result type has no layout attribute.
1666 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
1667 [&](xegpu::CreateNdDescOp op) { return !op.getType().getLayoutAttr(); });
1668 // Any anchor XeGPU op is legal only if it has no anchor layout.
1669 target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
1670 auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
1671 if (!anchorOp)
1672 return true;
1673 return !anchorOp.getAnchorLayout();
1674 });
1675 // Arith constants are legal only if they have no temporary layout attribute.
1676 target.addDynamicallyLegalOp<arith::ConstantOp>(
1677 [=](arith::ConstantOp op) -> bool {
1678 // If the result type is not a vector, it's legal.
1679 if (!isa<VectorType>(op.getResult().getType()))
1680 return true;
1681 return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1682 });
1683 // In math and arith dialects, only handle elementwise ops with a single
1684 // result and with a result layout attribute.
1685 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1686 [=](Operation *op) -> std::optional<bool> {
1687 // Only handle elementwise mappable ops
1689 return true;
1690 // Only handle ops with single vector result
1691 if (op->getNumResults() != 1)
1692 return true;
1693
1694 VectorType resultType =
1695 dyn_cast<VectorType>(op->getResult(0).getType());
1696 if (!resultType)
1697 return true;
1698
1699 // Check if all operands are vectors of the same shape
1700 for (Value operand : op->getOperands()) {
1701 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1702 if (!operandType || operandType.getShape() != resultType.getShape()) {
1703 return true;
1704 }
1705 }
1706 return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1707 });
1708 // vector::ReductionOp is legal only if its source has no distribute layout
1709 // attribute.
1710 target.addDynamicallyLegalOp<vector::ReductionOp>(
1711 [=](vector::ReductionOp op) -> bool {
1712 auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
1713 return !layout;
1714 });
1715 // vector::MultiDimReductionOp op legality.
1716 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
1717 [=](vector::MultiDimReductionOp op) -> bool {
1718 return !isValidSubgroupMultiReductionOp(op);
1719 });
1720 target.addDynamicallyLegalOp<vector::CreateMaskOp, vector::ConstantMaskOp,
1721 vector::TransposeOp, vector::BitCastOp,
1722 vector::ShapeCastOp, vector::StepOp,
1723 vector::BroadcastOp>([=](Operation *op) -> bool {
1724 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1725 });
1726 target.addDynamicallyLegalOp<vector::ExtractOp>(
1727 [=](vector::ExtractOp op) -> bool {
1728 if (!isa<VectorType>(op.getType()))
1729 return true;
1730 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1731 });
1732 target.addDynamicallyLegalOp<vector::InsertOp>(
1733 [=](vector::InsertOp op) -> bool {
1734 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1735 });
1736 target.addDynamicallyLegalOp<vector::ExtractStridedSliceOp>(
1737 [=](vector::ExtractStridedSliceOp op) -> bool {
1738 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1739 });
1740 target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
1741 [=](vector::InsertStridedSliceOp op) -> bool {
1742 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1743 });
1744 target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
1745 patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
1746 SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
1747 SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
1748 SgToWiMultiDimReduction, SgToWiVectorExtract, SgToWiVectorInsert,
1749 SgToWiVectorExtractStridedSlice, SgToWiVectorInsertStridedSlice,
1750 SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout,
1751 SgToWiVectorTranspose, SgToWiVectorBitcast, SgToWiVectorStep,
1752 SgToWiVectorShapeCast, SgToWiBroadcast,
1753 SgToWiCreateMask<vector::CreateMaskOp>,
1754 SgToWiCreateMask<vector::ConstantMaskOp>>(typeConverter,
1755 patterns.getContext());
1756}
return success()
#define DBGS()
Definition Hoisting.cpp:32
b getContext())
Attributes are known-constant values of operations.
Definition Attributes.h:25
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
result_range getResults()
Definition Operation.h:441
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:262
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
const uArch * getUArch(llvm::StringRef archName)
bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
void populateXeGPUSgToWiDistributeTypeConversions(TypeConverter &typeConverter)
Define only the type conversions needed for XeGPU subgroup to workitem distribution.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
bool requirePacked(const LayoutAttr layout)
Helper function to check if the layout is packed.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
void populateXeGPUSgToWiDistributeTypeConversionAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Defines type conversions and legality for XeGPU subgroup to workitem distribution and appends the req...
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
Value lowerCrossLaneReductionToShuffles(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter)
Lowers cross-lane reductions to shuffle operations on a 2D vector.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
virtual int getSubgroupSize() const =0