//===- XeGPUSgToWiDistributeExperimental.cpp - XeGPU SG to WI Pass --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
20#include "mlir/IR/Builders.h"
22#include "mlir/IR/BuiltinOps.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Operation.h"
26#include "mlir/IR/Value.h"
27#include "mlir/IR/ValueRange.h"
29#include "llvm/ADT/SetVector.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/raw_ostream.h"
32#include <optional>
33
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;
using namespace mlir::xegpu::uArch;

#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace {

/// Casts the given vector value `v` to the expected vector type `expectedTy`.
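/// Behavior sketch (illustrative, not from a test): if `v` already has type
/// `expectedTy`, it is returned unchanged; if only the shape differs (e.g.
/// vector<16x1xf32> requested as vector<16xf32>, same element count), a
/// vector.shape_cast is emitted; any other mismatch becomes an
/// unrealized_conversion_cast to be reconciled by later materializations.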
static Value castValueTo(ConversionPatternRewriter &rewriter,
                         TypedValue<VectorType> v, VectorType expectedTy) {
  // If the type matches, simply return the value itself.
  if (v.getType() == expectedTy)
    return v;
  // If only the shape differs, use a shape cast.
  if (isa<VectorType>(v.getType()) &&
      v.getType().getNumElements() == expectedTy.getNumElements())
    return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);

  // Else create an unrealized cast.
  auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
                                                  expectedTy, ValueRange{v});
  return newOp.getResult(0);
}

/// A subgroup-level vector::MultiDimReductionOp is in the expected form if it
/// has exactly one reduction dimension, its result carries a valid layout
/// attribute, and the result type can be distributed to lanes using that
/// layout.
static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
  auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
  // If no layout, not valid.
  if (!resLayout || !resLayout.isForSubgroup())
    return false;
  // Scalar result (e.g., vector<32xf32> to f32) is valid.
  if (op.getType().isIntOrFloat())
    return op.getReductionDims().size() == 1;
  VectorType resTy = dyn_cast<VectorType>(op.getType());
  if (!resTy)
    return false;
  // Compute the distributed result vector type based on the layout.
  FailureOr<VectorType> resDistTypeOrFailure =
      getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
  if (failed(resDistTypeOrFailure))
    return false;
  return op.getReductionDims().size() == 1;
}

/// A vector::MultiDimReductionOp is lane-local if each workitem performs its
/// own local reduction, i.e. the reduction dimension is not distributed across
/// lanes. In this case the result layout distributes the result vector to
/// lanes, so the distributed result vector type differs from the original
/// result vector type.
static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
  // Must be a valid MultiDimReductionOp.
  assert(isValidSubgroupMultiReductionOp(op) && "Expecting a valid subgroup "
                                                "MultiDimReductionOp");
  auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
  VectorType resTy = dyn_cast<VectorType>(op.getType());
  auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
  return resTy != resDistTypeOrFailure.value();
}

/// Given a vector type and its distributed vector type, return the list of
/// dimensions that are distributed.
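/// Example (illustrative): originalType = vector<16x16xf32> and
/// distributedType = vector<16x1xf32> yields {1}.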
static SmallVector<int64_t> getDistributedDims(VectorType originalType,
                                               VectorType distributedType) {
  assert(originalType.getRank() == distributedType.getRank() &&
         "original and distributed vector types must have the same rank");
  SmallVector<int64_t> distributedDims;
  for (int64_t i = 0; i < originalType.getRank(); ++i) {
    if (distributedType.getDimSize(i) != originalType.getDimSize(i))
      distributedDims.push_back(i);
  }
  return distributedDims;
}

/// Distributes a subgroup-level CreateNdDesc op to a workitem-level
/// CreateNdDesc op. This simply drops the layout attribute from the tensor
/// descriptor type.
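/// Example (illustrative):
///   %td = xegpu.create_nd_tdesc %src : memref<64x64xf16> ->
///     !xegpu.tensor_desc<16x16xf16,
///       #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
/// becomes:
///   %td = xegpu.create_nd_tdesc %src : memref<64x64xf16> ->
///     !xegpu.tensor_desc<16x16xf16>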
struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::TensorDescType resultType = op.getType();
    // If no layout, nothing to do.
    if (!resultType.getLayout())
      return failure();

    auto newOp = xegpu::CreateNdDescOp::create(
        rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
        op->getAttrs());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};

/// Distributes a subgroup-level LoadNd op to a workitem-level LoadNd op. The
/// output of the workitem-level LoadNd op is 1D; a ShapeCast is added to
/// restore the original rank.
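/// Example (illustrative; assuming a 16-lane subgroup and
/// #l = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>):
///   %v = xegpu.load_nd %td : !xegpu.tensor_desc<16x16xf16, #l>
///     -> vector<16x16xf16>
/// becomes:
///   %v0 = xegpu.load_nd %td : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
///   %v1 = vector.shape_cast %v0 : vector<16xf16> to vector<16x1xf16>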
struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If no layout, nothing to do.
    if (!layout)
      return failure();
    // Check that the layout attached to the tensor descriptor is the same as
    // the anchor layout. Otherwise, this is a conflict.
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          op, "xegpu::LoadNdOp requires a target attribute attached to "
              "determine the transpose requirement");
    auto supportedWiResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    auto expectedWiResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
    if (failed(supportedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute the workitem vector type for LoadNdOp");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute expected workitem vector type from lane layout");
    auto newOp = xegpu::LoadNdOp::create(
        rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
        op.getL3HintAttr(), /*layout=*/nullptr);
    // Set the packed attribute if the layout requires it.
    newOp.setPacked(xegpu::requirePacked(cast<xegpu::LayoutAttr>(layout)));
    // Set the transpose attribute if the layout requires it.
    if (xegpu::requireTranspose(cast<xegpu::LayoutAttr>(layout), uArch))
      newOp.setTranspose(DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};

/// Distributes a subgroup-level StoreNd op to a workitem-level StoreNd op. The
/// stored value in the workitem-level StoreNd op is 1D; a ShapeCast is added
/// to cast the incoming value to 1D.
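/// Example (illustrative; assuming a 16-lane subgroup and
/// #l = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>):
///   xegpu.store_nd %v, %td : vector<16x16xf16>,
///     !xegpu.tensor_desc<16x16xf16, #l>
/// becomes:
///   %v1d = vector.shape_cast %v : vector<16x1xf16> to vector<16xf16>
///   xegpu.store_nd %v1d, %td : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>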
struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If no layout, nothing to do.
    if (!layout)
      return failure();
    // Check that the layouts attached to the tensor descriptor and the value
    // are the same as the anchor layout. Otherwise, this is a conflict.
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    auto valueLayout = xegpu::getDistributeLayoutAttr(op->getOpOperand(0));
    if (valueLayout != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on value and anchor");
    auto supportedWiValueTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    if (failed(supportedWiValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute wi vector type for StoreNdOp value from tensor "
          "descriptor");

    xegpu::StoreNdOp::create(
        rewriter, op.getLoc(),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getValue()),
                    supportedWiValueTyOrFailure.value()),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
        op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Distributes a subgroup-level Dpas op to a workitem-level Dpas op. All
/// inputs and the output of the workitem-level Dpas op are 1D. Necessary casts
/// are added to convert the inputs and output to/from 1D.
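/// Example (illustrative; assuming a 16-lane subgroup and a standard
/// 8x16x16 dpas shape):
///   %d = xegpu.dpas %a, %b, %c : vector<8x16xf16>, vector<16x16xf16>,
///     vector<8x16xf32> -> vector<8x16xf32>
/// becomes (per lane, each operand shrunk by the 16 lanes and flattened):
///   %d = xegpu.dpas %a, %b, %c : vector<8xf16>, vector<16xf16>,
///     vector<8xf32> -> vector<8xf32>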
struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check if the op has A, B and CD layouts attached.
    auto layoutA = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAAttr());
    auto layoutB = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutBAttr());
    auto layoutCd =
        dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutCdAttr());
    if (!layoutA || !layoutB || !layoutCd)
      return failure();
    auto wiResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getType(), layoutCd);
    auto wiATypeOrFailure =
        xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
    auto wiBTypeOrFailure =
        xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
    auto expectedWiResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
    if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
        failed(wiBTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "failed to calculate supported workitem vector types for DpasOp "
              "from layouts");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute expected workitem vector type for DpasOp from "
              "lane layout");

    // Validate that bit widths match uArch packed format requirements.
    const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (uArch) {
      // Look up the matrix-multiply instruction description from the uArch.
      const auto *uArchInstruction =
          dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
              uArch->getInstruction(
                  xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
      if (uArchInstruction) {
        auto wiAType = wiATypeOrFailure.value();
        auto wiBType = wiBTypeOrFailure.value();
        // Calculate total packed bit width = element bit width * vector size.
        unsigned aPackedBitWidth =
            wiAType.getElementTypeBitWidth() * wiAType.getNumElements();
        unsigned bPackedBitWidth =
            wiBType.getElementTypeBitWidth() * wiBType.getNumElements();
        unsigned expectedABitSize = uArchInstruction->getPackedFormatBitSizeA();
        unsigned expectedBBitSize = uArchInstruction->getPackedFormatBitSizeB();

        if (aPackedBitWidth % expectedABitSize != 0)
          return rewriter.notifyMatchFailure(
              op,
              "A operand packed bit width must be a multiple of uArch packed "
              "format requirement");
        if (bPackedBitWidth % expectedBBitSize != 0)
          return rewriter.notifyMatchFailure(
              op,
              "B operand packed bit width must be a multiple of uArch packed "
              "format requirement");
      }
    }
    auto newOp = xegpu::DpasOp::create(
        rewriter, op->getLoc(), wiResultTyOrFailure.value(),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getLhs()),
                    wiATypeOrFailure.value()),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getRhs()),
                    wiBTypeOrFailure.value()),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getAcc()),
                    wiResultTyOrFailure.value()),
        /*layoutA=*/nullptr,
        /*layoutB=*/nullptr, /*layoutCd=*/nullptr);
    // Explicitly set the new types to enable correct type materializations.
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};

/// Distributes elementwise ops to workitem-level elementwise ops. This
/// currently handles elementwise ops with a single result only.
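/// Example (illustrative; assuming the result carries
/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>):
///   %r = arith.addf %a, %b : vector<16x16xf32>
/// becomes:
///   %r = arith.addf %a, %b : vector<16x1xf32>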
struct SgToWiElementWise : public ConversionPattern {
  SgToWiElementWise(TypeConverter &typeConverter, MLIRContext *ctx)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only match ops with the elementwise trait and a single result.
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();

    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(
          op, "operation result is not a vector type");

    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");

    auto wiShapeOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);

    if (failed(wiShapeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");

    VectorType newResultType = wiShapeOrFailure.value();
    OperationState state(op->getLoc(), op->getName());
    state.addOperands(operands);
    state.addTypes(newResultType);
    // Copy all attributes except for DistributeLayoutAttr.
    for (auto attr : op->getAttrs()) {
      if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
        state.addAttribute(attr.getName(), attr.getValue());
    }
    Operation *newOp = rewriter.create(state);

    rewriter.replaceOp(op, newOp->getResult(0));
    return success();
  }
};

/// Distributes a subgroup-level arith ConstantOp to a workitem-level arith
/// ConstantOp.
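/// Example (illustrative; assuming the result carries
/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>):
///   %c = arith.constant dense<1.0> : vector<16x16xf32>
/// becomes:
///   %c = arith.constant dense<1.0> : vector<16x1xf32>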
struct SgToWiArithConstant : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<VectorType>(op.getType());
    if (!resultType)
      return failure();

    // Only handle dense splat vector constants.
    auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
    if (!dense)
      return rewriter.notifyMatchFailure(
          op, "only dense splat vector constants are supported");

    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");

    auto wiShapeOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);

    if (failed(wiShapeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");

    VectorType newResultType = wiShapeOrFailure.value();
    auto scalarValue = dense.getSplatValue<Attribute>();
    auto newDenseAttr = DenseElementsAttr::get(newResultType, scalarValue);

    auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
                                           newDenseAttr);
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};

/// Distributes a subgroup-level PrefetchNd op to workitem-level PrefetchNd op.
struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If no layout, nothing to do.
    if (!layout)
      return failure();

    xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
                                op.getMixedOffsets(), op.getL1HintAttr(),
                                op.getL2HintAttr(), op.getL3HintAttr(),
                                /*layout=*/nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
///
/// Example 1 (1D, no chunk size):
///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
///   %mask = producer_op : vector<16xi1>
///   %offset = producer_op : vector<16xindex>
///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
/// Distributed to:
///   %mask = producer_op : vector<1xi1>
///   %offset = producer_op : vector<1xindex>
///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
///
/// Example 2 (2D with chunk size, same mask & offset):
///   layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
/// Distributed to:
///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
///
/// Example 3 (3D with leading unit dims):
///   layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
///   %mask = producer_op : vector<1x1x16xi1>
///   %offset = producer_op : vector<1x1x16xindex>
///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
///     vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
/// Distributed to:
///   %mask = producer_op : vector<1x1x1xi1>
///   %offset = producer_op : vector<1x1x1xindex>
///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    if (!layout)
      return failure();

    VectorType origResultTy = op.getValueType();
    if (!origResultTy)
      return failure();

    // Check that leading dimensions are unit.
    int chunkSize = op.getChunkSize().value_or(1);
    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
    ArrayRef<int64_t> shape = origResultTy.getShape();
    if (llvm::any_of(
            shape.take_front(origResultTy.getRank() - effectiveVecRank),
            [](int64_t d) { return d != 1; }))
      return rewriter.notifyMatchFailure(
          op, "Only unit dimensions allowed for the leading "
              "dimensions of the load vector!");

    auto distResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, origResultTy);
    if (failed(distResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute expected workitem vector type from lane layout");

    VectorType distResultTy = distResultTyOrFailure.value();
    VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
                                                distResultTy.getElementType());

    // Flatten offsets and mask to 1D to match the 1D result type.
    Value distOffsets = adaptor.getOffsets();
    auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
    VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
                                             distOffsetsTy.getElementType());
    distOffsets = castValueTo(
        rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);

    Value distMask = adaptor.getMask();
    auto distMaskTy = cast<VectorType>(distMask.getType());
    VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
                                          distMaskTy.getElementType());
    distMask =
        castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);

    Value distSource = adaptor.getSource();
    auto newOp = xegpu::LoadGatherOp::create(
        rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
        distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
        op.getL3HintAttr(), /*layout=*/nullptr);

    Value result = newOp->getResult(0);
    if (distResultTy1D != distResultTy)
      result = castValueTo(rewriter, cast<TypedValue<VectorType>>(result),
                           distResultTy);
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// This pattern distributes a subgroup-level vector.reduction op to
/// workitem-level. This requires shuffling the data across the workitems
/// (using gpu::ShuffleOp) and reducing in stages until all workitems have the
/// final result.
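/// Example (illustrative; assuming a 16-lane subgroup and
/// #xegpu.layout<lane_layout = [16], lane_data = [1]> on the source):
///   %r = vector.reduction <add>, %v : vector<16xf32> into f32
/// Each lane holds a vector<1xf32> fragment; the partial values are combined
/// across lanes with log2(16) = 4 butterfly gpu.shuffle stages, so every lane
/// ends up with the full sum.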
struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
  using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto layout = xegpu::getDistributeLayoutAttr(op.getVector());

    // If no layout, nothing to do.
    if (!layout || !layout.isForSubgroup())
      return failure();

    VectorType srcVecType = op.getSourceVectorType();
    // Only rank 1 vectors supported.
    if (srcVecType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Only rank 1 reductions can be distributed.");
    // Lane layout must have the same rank as the vector.
    if (layout.getRank() != srcVecType.getRank())
      return rewriter.notifyMatchFailure(
          op, "Layout rank does not match vector rank.");

    // Get the subgroup size from the layout.
    int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
    const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          op, "vector::ReductionOp requires a target attribute attached to "
              "determine subgroup size");

    // The lane layout must match the subgroup size, and the reduction
    // dimension must be a multiple of it.
    if (sgSize != uArch->getSubgroupSize() ||
        srcVecType.getShape()[0] % sgSize != 0)
      return rewriter.notifyMatchFailure(
          op, "Invalid lane layout, or the reduction vector dimension is not "
              "a multiple of the subgroup size.");

    if (!op.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Reduction distribution currently only supports floats and "
              "integer types.");

    // Get the distributed vector (per work-item portion).
    Value laneValVec = adaptor.getVector();

    // Distribute and reduce across work-items in the subgroup.
    Value fullReduce = xegpu::subgroupReduction(
        op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);

    // If there's an accumulator, combine it with the reduced value.
    if (adaptor.getAcc())
      fullReduce = vector::makeArithReduction(
          rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());

    rewriter.replaceOp(op, fullReduce);
    return success();
  }
};

/// This pattern distributes a subgroup-level vector.multi_reduction op to
/// workitem-level only if the reduction is lane-local. This means that the
/// reduction dimension is not distributed to lanes, and each lane does its own
/// local reduction.
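/// Example of a lane-local case (illustrative; assuming the source carries
/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>):
///   %r = vector.multi_reduction <add>, %src, %acc [0]
///     : vector<16x16xf32> to vector<16xf32>
/// Reduction dimension 0 is not distributed, so each lane reduces its own
/// vector<16x1xf32> fragment down to vector<1xf32> with no cross-lane
/// communication.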
struct SgToWiMultiDimReduction
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value result;
    ArrayRef<int64_t> reductionDims = op.getReductionDims();
    assert(reductionDims.size() == 1 &&
           "Expecting single reduction dimension for subgroup multi "
           "reduction op");
    // For rank > 2, ensure leading dimensions are unit.
    VectorType sourceType = op.getSourceVectorType();
    int64_t rank = sourceType.getRank();
    if (rank > 2) {
      ArrayRef<int64_t> shape = sourceType.getShape();
      if (llvm::any_of(shape.take_front(rank - 2),
                       [](int64_t d) { return d != 1; }))
        return rewriter.notifyMatchFailure(
            op, "only unit leading dimensions are supported for "
                "multi_reduction with rank > 2");
    }
    // Handle scalar result: full reduction of a distributed vector to a
    // scalar. First do a local vector reduction, then cross-lane shuffles.
    if (op.getType().isIntOrFloat()) {
      auto reductionDim = reductionDims[0];
      int64_t reductionDimSize = sourceType.getShape()[reductionDim];
      // Local reduction to scalar, then cross-lane butterfly shuffles.
      result =
          xegpu::subgroupReduction(op.getLoc(), rewriter, adaptor.getSource(),
                                   op.getKind(), reductionDimSize);
      // Combine with accumulator if present.
      if (adaptor.getAcc())
        result = vector::makeArithReduction(rewriter, op.getLoc(), op.getKind(),
                                            result, adaptor.getAcc());
    } else if (isReductionLaneLocal(op)) {
      // For lane-local reduction, lower to a sequence of vector.reduction ops
      // over 1D slices extracted from the distributed source vector. This is
      // required so we don't have 2D source vectors at xegpu-linearize.
      auto reductionDim = reductionDims[0];
      // NOTE: the callee name below is an assumed placeholder; the call line
      // was elided in this listing. The helper emits the per-slice
      // vector.reduction sequence.
      result = lowerLaneLocalReduction(
          cast<TypedValue<VectorType>>(adaptor.getSource()),
          cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
          reductionDim, op.getLoc(), rewriter);
    } else {
      auto reductionDim = reductionDims[0];
      int64_t reductionDimSize = sourceType.getShape()[reductionDim];
      // NOTE: assumed placeholder name for the elided cross-lane reduction
      // helper.
      result = lowerCrossLaneReduction(
          cast<TypedValue<VectorType>>(adaptor.getSource()),
          cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
          reductionDim, reductionDimSize, op.getLoc(), rewriter);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Helper to compute distributed coordinates for matrix ops.
/// When not using subgroup_block_io, each workitem computes its own
/// coordinates based on the layout and lane ID.
static SmallVector<Value> computeDistributedCoordsForMatrixOp(
    ConversionPatternRewriter &rewriter, Location loc,
    xegpu::DistributeLayoutAttr layout, ArrayRef<int64_t> payloadShape,
    ValueRange origOffsets) {
  Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
                                       /*upperBound=*/mlir::IntegerAttr());
  auto maybeCoords =
      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
  if (failed(maybeCoords))
    return {};
  assert(maybeCoords.value().size() == 1 &&
         "Expected one set of distributed offsets");
  // Add the per-lane coordinates to the original (right-aligned) offsets.
  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
      rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
      getAsOpFoldResult(origOffsets));
  return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
}

/// This pattern distributes a subgroup-level LoadMatrix op to workitem-level.
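/// Example (illustrative syntax; assuming
/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> and no
/// subgroup_block_io):
///   %v = xegpu.load_matrix %md[%i, %j]
///     : !xegpu.mem_desc<32x32xf16> -> vector<16x16xf16>
/// becomes (each lane adjusts the coordinates using values derived from its
/// gpu.lane_id):
///   %v = xegpu.load_matrix %md[%i2, %j2]
///     : !xegpu.mem_desc<32x32xf16> -> vector<16x1xf16>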
struct SgToWiLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto layout = op.getLayoutAttr();
    // If no layout, nothing to do.
    if (!layout)
      return failure();

    VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
    if (!sgPayloadTy)
      return rewriter.notifyMatchFailure(
          op, "the matrix op payload must be a vector type");

    auto loc = op.getLoc();
    auto offsets = op.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(op, "the load op must have offsets");

    FailureOr<VectorType> distPayloadTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
    if (failed(distPayloadTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "Failed to distribute matrix op payload based on layout.");

    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loc, offsets);

    SmallVector<Value> newCoords = offsetsAsValues;
    if (!op.getSubgroupBlockIoAttr()) {
      newCoords = computeDistributedCoordsForMatrixOp(
          rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
      if (newCoords.empty())
        return rewriter.notifyMatchFailure(
            op, "Failed to compute distributed coordinates.");
    }

    SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
                                         ShapedType::kDynamic);
    DenseI64ArrayAttr newConstOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(newConstOffsets);

    auto newOp = xegpu::LoadMatrixOp::create(
        rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
        ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
        xegpu::DistributeLayoutAttr{});
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};

/// Distributes a subgroup-level vector.transpose op to workitem-level.
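/// Example (illustrative; source layout lane_layout = [1, 16] and result
/// layout lane_layout = [16, 1], both with lane_data = [1, 1]):
///   %t = vector.transpose %v, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
/// becomes:
///   %t = vector.transpose %v, [1, 0] : vector<16x1xf32> to vector<1x16xf32>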
struct SgToWiVectorTranspose : public OpConversionPattern<vector::TransposeOp> {
  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getTemporaryLayout(op->getOpOperand(0));
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          op, "the source or result vector of the transpose op lacks layout "
              "attribute");
    ArrayRef<int64_t> perm = op.getPermutation();
    // Result layout must be a transpose of source layout.
    if (!resultLayout.isTransposeOf(sourceLayout, perm,
                                    xegpu::LayoutKind::Lane))
      return rewriter.notifyMatchFailure(
          op, "the source or result vector layouts must be transposes of "
              "each other");
    FailureOr<VectorType> distributedResultTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
    if (failed(distributedResultTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "Failed to distribute the result vector type in "
              "vector::Transpose op");
    auto newOp = vector::TransposeOp::create(rewriter, op.getLoc(),
                                             adaptor.getVector(), perm);
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       distributedResultTypeOrFailure.value()));
    return success();
  }
};

/// Distributes a subgroup-level vector.bitcast op to workitem-level.
/// Bitcast only impacts the innermost dimension of the source/result vectors.
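/// Example (illustrative; result layout lane_layout = [16], lane_data = [1];
/// source layout lane_layout = [16], lane_data = [2]):
///   %r = vector.bitcast %v : vector<32xf16> to vector<16xi32>
/// becomes:
///   %r = vector.bitcast %v : vector<2xf16> to vector<1xi32>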
struct SgToWiVectorBitcast : public OpConversionPattern<vector::BitCastOp> {
  using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!resultLayout)
      return rewriter.notifyMatchFailure(
          op, "result vector of the bitcast op lacks layout attribute");
    FailureOr<VectorType> distributedResultTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
    if (failed(distributedResultTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "Failed to distribute the result vector type in "
              "vector::BitCast op");
    auto newOp = vector::BitCastOp::create(
        rewriter, op.getLoc(), distributedResultTypeOrFailure.value(),
        adaptor.getSource());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};

/// Distributes a subgroup-level vector.create_mask or vector.constant_mask op
/// to workitem-level. Uses `computeDistributedCoords()` to obtain the
/// coordinates each workitem owns, then compares each coordinate against the
/// original mask bounds using `arith.cmpi slt`. The per-element boolean
/// results are assembled into the distributed mask vector.
///
/// For multi-dimensional masks, the element is in-bounds when ALL dimensions
/// satisfy `coord[i] < bound[i]`.
///
/// Example (1D):
///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
///   %mask = vector.create_mask %m0 : vector<16xi1>
/// For lane k, computeDistributedCoords gives coord = [k], so:
///   %in_bounds = arith.cmpi slt, %coord, %m0 → i1
///   %mask = vector.broadcast %in_bounds : i1 to vector<1xi1>
///
/// Example (2D):
///   layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>
///   %mask = vector.create_mask %m0, %m1 : vector<8x4xi1>
/// Each WI owns a 1x2 slice. computeDistributedCoords returns 2 coords:
///   [[r0, c0], [r0, c1]]
/// For each coord: in_bounds = (r < m0) && (c < m1)
///   %mask = vector.from_elements %bit0, %bit1 : vector<1x2xi1>
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
struct SgToWiCreateMask : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");

    VectorType origType = op.getType();
    FailureOr<VectorType> distTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, origType);
    if (failed(distTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");

    VectorType distType = distTypeOrFailure.value();
    Location loc = op.getLoc();

    // Materialize the original mask bounds as Values.
    SmallVector<Value> origBounds;
    if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
      origBounds.append(op.getOperands().begin(), op.getOperands().end());
    } else {
      auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
      for (auto dimSize : dimSizes)
        origBounds.push_back(
            arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
    }

    ArrayRef<int64_t> origShape = origType.getShape();

    // Use computeDistributedCoords to get the coordinates each WI owns.
    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
                                         /*upperBound=*/mlir::IntegerAttr());
    auto maybeCoordsVec =
        layout.computeDistributedCoords(rewriter, loc, laneId, origShape);
    if (failed(maybeCoordsVec))
      return rewriter.notifyMatchFailure(
          op, "failed to compute distributed coordinates from layout");

    SmallVector<SmallVector<Value>> coordsVec = maybeCoordsVec.value();
    int64_t numElements = distType.getNumElements();
    assert(static_cast<int64_t>(coordsVec.size()) == numElements &&
           "number of coordinate sets must match number of distributed "
           "elements");

    // For each element, compare all coordinates against bounds.
    Value trueVal =
        arith::ConstantIntOp::create(rewriter, loc, /*value=*/1, /*width=*/1);
    SmallVector<Value> maskBits;
    for (auto &coords : coordsVec) {
      Value inBounds = trueVal;
      for (size_t i = 0; i < coords.size(); ++i) {
        Value cmp = arith::CmpIOp::create(
            rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]);
        inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp);
      }
      maskBits.push_back(inBounds);
    }

    // Build the distributed mask vector.
    Value result;
    if (numElements == 1) {
      result =
          vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]);
    } else {
      result =
          vector::FromElementsOp::create(rewriter, loc, distType, maskBits);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// This pattern distributes a subgroup-level StoreMatrix op to workitem-level.
struct SgToWiStoreMatrix : public OpConversionPattern<xegpu::StoreMatrixOp> {
  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto layout = op.getLayoutAttr();
    // If no layout, nothing to do.
    if (!layout)
      return failure();

    VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
    if (!sgPayloadTy)
      return rewriter.notifyMatchFailure(
          op, "the matrix op payload must be a vector type");

    auto loc = op.getLoc();
    auto offsets = op.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(op, "the store op must have offsets");

    FailureOr<VectorType> distPayloadTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
    if (failed(distPayloadTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "Failed to distribute matrix op payload based on layout.");

    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loc, offsets);

    SmallVector<Value> newCoords = offsetsAsValues;
    if (!op.getSubgroupBlockIoAttr()) {
      newCoords = computeDistributedCoordsForMatrixOp(
          rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
      if (newCoords.empty())
        return rewriter.notifyMatchFailure(
            op, "Failed to compute distributed coordinates.");
    }

    SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
                                         ShapedType::kDynamic);
    DenseI64ArrayAttr newConstOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(newConstOffsets);

    xegpu::StoreMatrixOp::create(
        rewriter, loc, TypeRange{},
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getData()),
                    distPayloadTyOrFailure.value()),
        adaptor.getMemDesc(), ValueRange(newCoords), newConstOffsetsAttr,
        op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
    rewriter.eraseOp(op);
    return success();
  }
};

/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
/// workitem-level.
///
/// Example 1 (1D, no chunk size):
///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
///   %mask = producer_op : vector<16xi1>
///   %offset = producer_op : vector<16xindex>
///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
///     memref<256xf16>, vector<16xindex>, vector<16xi1>
/// Distributed to:
///   %mask = producer_op : vector<1xi1>
///   %offset = producer_op : vector<1xindex>
///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
///     memref<256xf16>, vector<1xindex>, vector<1xi1>
///
/// Example 2 (2D with chunk size, same mask & offset):
///   layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
/// Distributed to:
///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
///
/// Example 3 (3D with leading unit dims):
///   layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
///   %mask = producer_op : vector<1x1x16xi1>
///   %offset = producer_op : vector<1x1x16xindex>
///   xegpu.store %payload, %src[%offset], %mask : vector<1x1x16xf16>,
///     memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
/// Distributed to:
///   %mask = producer_op : vector<1x1x1xi1>
///   %offset = producer_op : vector<1x1x1xindex>
///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
///     memref<256xf16>, vector<1xindex>, vector<1xi1>
struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    if (!layout)
      return failure();

    VectorType origValueTy = op.getValueType();
    if (!origValueTy)
      return failure();

    // Check that all leading dimensions are unit dimensions.
    int chunkSize = op.getChunkSize().value_or(1);
    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
    ArrayRef<int64_t> shape = origValueTy.getShape();
    if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
                     [](int64_t d) { return d != 1; }))
      return rewriter.notifyMatchFailure(
          op, "Only unit dimensions allowed for the leading "
              "dimensions of the store vector!");

    auto distValueTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, origValueTy);
    if (failed(distValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute expected workitem vector type from lane layout");

    VectorType distValueTy = distValueTyOrFailure.value();
    VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
                                               distValueTy.getElementType());

    Value distValue = adaptor.getValue();
    if (distValue.getType() != distValueTy1D)
      distValue = castValueTo(rewriter, cast<TypedValue<VectorType>>(distValue),
                              distValueTy1D);

    // Flatten offsets and mask to 1D to match the 1D value type.
    Value distOffsets = adaptor.getOffsets();
    auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
    VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
                                             distOffsetsTy.getElementType());
    distOffsets = castValueTo(
        rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);

    Value distMask = adaptor.getMask();
    auto distMaskTy = cast<VectorType>(distMask.getType());
    VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
                                          distMaskTy.getElementType());
    distMask =
        castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);

    Value distDest = adaptor.getDest();
    xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
                                  distOffsets, distMask, op.getChunkSizeAttr(),
                                  op.getL1HintAttr(), op.getL2HintAttr(),
                                  op.getL3HintAttr(), /*layout=*/nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Distribute a vector::StepOp to workitem-level.
/// The layout must have exactly 1 effective lane dimension.
/// We completely resolve the vector::StepOp by computing the lane_data-sized
/// subranges.
struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
  using OpConversionPattern<vector::StepOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getResult(0));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "the result vector of the step op lacks subgroup layout");

    auto loc = op.getLoc();
    auto stepResultVecTy = op.getResult().getType();
    auto wiShapeOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, stepResultVecTy);
    if (failed(wiShapeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");
    VectorType newVecTy = wiShapeOrFailure.value();

    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
                                         /*upperBound=*/mlir::IntegerAttr());
    auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
        rewriter, loc, laneId, stepResultVecTy.getShape());
    if (failed(laneDataBlockCoords))
      return rewriter.notifyMatchFailure(
          op, "failed to compute lane data block coordinates");

    auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
    auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
    assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
           newVecTy.getNumElements() / laneDataBlockLength);
    SmallVector<Value> stepVals;
    // For each lane_data block, reconstruct its sub-range from the range of
    // the SG-level vector.step. Example:
    //   vector.step {slice<layout<lane_layout = [2,4,2], lane_data = [1,2,1]>,
    //                dims = [0,2]>} : vector<16xindex>
    // Each logical lane holds 4 elements as 2 blocks of 2 elements each. The
    // blocks are round-robin distributed, so logical lane id 0 holds values
    // [0, 1, 8, 9].
    for (auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
      auto laneDataBlockStartCoord = laneDataBlockCoords[0];
      stepVals.push_back(laneDataBlockStartCoord);
      for (int i = 1; i < laneDataBlockLength; ++i) {
        auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
        stepVals.push_back(arith::AddIOp::create(
            rewriter, loc, laneDataBlockStartCoord, offset));
      }
    }
    assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
           "Expecting the number of step values to match the number of "
           "elements in the vector");
    auto stepOpVal =
        vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
    rewriter.replaceOp(op, stepOpVal);
    return success();
  }
};

/// Distributes a subgroup-level vector.extract op to workitem-level. Only
/// handles sub-vector extraction (result is VectorType, not scalar).
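/// Example (illustrative; assuming the result carries
/// #xegpu.layout<lane_layout = [16], lane_data = [1]>):
///   %e = vector.extract %v[2] : vector<16xf32> from vector<8x16xf32>
/// becomes (the position is unchanged; only the innermost dimension is
/// distributed):
///   %e = vector.extract %v[2] : vector<1xf32> from vector<8x1xf32>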
struct SgToWiVectorExtract : public OpConversionPattern<vector::ExtractOp> {
  using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only handle vector results (not scalar extraction).
    auto resultType = dyn_cast<VectorType>(op.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(op, "scalar extract not supported");

    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!layout || !layout.isForSubgroup())
      return failure();

    // This implementation assumes distribution only happens on the innermost
    // dimension. Verify that lane_layout[0...n-2] are all unit.
    auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
    if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
                     [](int64_t v) { return v != 1; }))
      return rewriter.notifyMatchFailure(
          op, "only innermost dimension distribution is supported for "
              "vector.extract");

    auto newOp = vector::ExtractOp::create(
        rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};

/// This pattern distributes a subgroup-level ShapeCast op to workitem-level.
struct SgToWiVectorShapeCast : public OpConversionPattern<vector::ShapeCastOp> {
  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "the result vector of the shape_cast op lacks subgroup layout");

    auto resultDistTypeOrFailure = xegpu::getDistVecTypeBasedOnLaneLayout(
        resultLayout, op.getResultVectorType());
    if (failed(resultDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "failed to get distributed vector type for result");

    Value source = adaptor.getSource();
    auto newShapeCast = vector::ShapeCastOp::create(
        rewriter, op.getLoc(), resultDistTypeOrFailure.value(), source);
    rewriter.replaceOp(op, newShapeCast);
    return success();
  }
};

/// Distributes a subgroup-level vector.extract_strided_slice op to
/// workitem-level. If the result is distributed, the offsets and sizes are
/// adjusted to match the distributed types.
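/// Example (illustrative; assuming a 16-lane subgroup and result layout
/// #xegpu.layout<lane_layout = [16], lane_data = [1]>):
///   %s = vector.extract_strided_slice %v
///     {offsets = [16], sizes = [16], strides = [1]}
///     : vector<32xf32> to vector<16xf32>
/// becomes (offset 16 / sgSize = 1; size 16 distributes to 1):
///   %s = vector.extract_strided_slice %v
///     {offsets = [1], sizes = [1], strides = [1]}
///     : vector<2xf32> to vector<1xf32>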
struct SgToWiVectorExtractStridedSlice
    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
  using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return failure();

    VectorType resultType = op.getType();
    auto distResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, resultType);
    if (failed(distResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute distributed vector type from lane layout");
    VectorType distResultTy = *distResultTyOrFailure;

    SmallVector<int64_t> distributedDims =
        getDistributedDims(resultType, distResultTy);

    // Collect updated sizes, offsets, strides. Pad to full source rank.
    int64_t sourceRank = op.getSourceVectorType().getRank();
    SmallVector<Attribute> updatedSizes =
        llvm::map_to_vector(op.getSizes(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
        op.getOffsets(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
        op.getStrides(), [](Attribute attr) { return attr; });
    for (int64_t i = op.getSizes().size(); i < sourceRank; ++i) {
      updatedSizes.push_back(
          rewriter.getI64IntegerAttr(op.getSourceVectorType().getDimSize(i)));
      updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
      updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
    }

    // If the result is distributed, adjust offsets and sizes in the
    // distributed dimension.
    if (!distributedDims.empty()) {
      if (distributedDims.size() != 1)
        return rewriter.notifyMatchFailure(
            op, "only single dimension distribution is supported");
      int64_t distDim = distributedDims[0];
      const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
      if (!uArch)
        return rewriter.notifyMatchFailure(
            op, "target attribute required to determine subgroup size");
      int subgroupSize = uArch->getSubgroupSize();
      auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            op, "source of extract_strided_slice lacks distribution layout");
      int sourceDistrDimSize = op.getSourceVectorType().getShape()[distDim];
      if (sourceDistrDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "source size along distributed dim is not a multiple of "
                "subgroup size");
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      // Only check lane_data for the distributed dimension. Non-distributed
      // dimensions may have non-unit lane_data (e.g., packed layouts).
      if (distDim < static_cast<int64_t>(sourceLaneData.size()) &&
          sourceLaneData[distDim] != 1)
        return rewriter.notifyMatchFailure(
            op, "expecting unit lane data along the distributed dimension");
      int64_t distrDimOffset =
          cast<IntegerAttr>(updatedOffsets[distDim]).getInt();
      if (distrDimOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "offset along distributed dim is not a multiple of "
                "subgroup size");
      // Adjust sizes and offsets for the distributed dimension.
      updatedSizes[distDim] =
          rewriter.getI64IntegerAttr(distResultTy.getDimSize(distDim));
      updatedOffsets[distDim] =
          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
    }

    auto newOp = vector::ExtractStridedSliceOp::create(
        rewriter, op.getLoc(), distResultTy, adaptor.getSource(),
        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
        ArrayAttr::get(rewriter.getContext(), updatedSizes),
        ArrayAttr::get(rewriter.getContext(), updatedStrides));
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};

/// This pattern distributes a subgroup-level `vector.broadcast` op to
/// workitem-level. The pattern supports three cases:
///
/// 1) Broadcast a low-rank vector to a high-rank vector: The low-rank input
///    vector must have a slice layout of the result. If the distributed source
///    and target vector types are identical, this lowers to a no-op;
///    otherwise, it remains a broadcast but operates on distributed vectors.
///
/// 2) Broadcast a same-rank vector with identical layouts for source and
///    target: The source vector must have unit dimensions, and lane_data must
///    be unit size for those unit dims. This always lowers to a no-op.
///
/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast
///    from scalar to distributed result type.
///
/// Example 1 (low-rank to high-rank broadcast):
/// ```
/// %0 = "some_op"() {layout_result_0 =
///   #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
///   dims = [0]>} : () -> vector<16xf16>
/// %1 = vector.broadcast %0 {layout_result_0 =
///   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///   : vector<16xf16> to vector<16x16xf16>
/// ```
/// is distributed to:
/// ```
/// %0 = "some_op"() : () -> vector<1xf16>
/// %1 = vector.broadcast %0 : vector<1xf16> to vector<16x1xf16>
/// ```
///
/// Example 2 (same-rank broadcast, no-op):
/// ```
/// %0 = "some_op"() {layout_result_0 =
///   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///   : () -> vector<16x1xf16>
/// %1 = vector.broadcast %0 {layout_result_0 =
///   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///   : vector<16x1xf16> to vector<16x16xf16>
/// ```
/// is distributed to (no-op, source already matches distributed result type):
/// ```
/// %0 = "some_op"() : () -> vector<16x1xf16>
/// // broadcast is eliminated, %0 is used directly
/// ```
///
/// Example 3 (scalar to vector broadcast):
/// ```
/// %0 = "some_op"() : () -> f16
/// %1 = vector.broadcast %0 {layout_result_0 =
///   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///   : f16 to vector<16x16xf16>
/// ```
/// is distributed to:
/// ```
/// %0 = "some_op"() : () -> f16
/// %1 = vector.broadcast %0 : f16 to vector<16x1xf16>
/// ```
struct SgToWiBroadcast : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(cast<OpResult>(op.getResult()));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "result does not have subgroup distribute layout");

    VectorType destType = op.getResultVectorType();
    VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());

    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getTemporaryLayout(op->getOpOperand(0));

    if (sourceType) {
      int64_t rankDiff = destType.getRank() - sourceType.getRank();
      if (rankDiff > 0) {
        // Case 1: Low-rank to high-rank broadcast.
        if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
          op.emitWarning(
              "broadcast source layout must be a slice of result layout");
      } else if (rankDiff == 0) {
        // Case 2: Same-rank broadcast.
        auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
        SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
                                               broadcastUnitDimsSet.end());
        assert(sourceLayout.isEqualTo(
                   sourceLayout.setUnitDimData(broadcastUnitDims)) &&
               "The sg_data for unit dimensions should be set as 1");
        sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
      }
    } else {
      // Case 3: Scalar to vector broadcast.
      if (sourceLayout)
        return rewriter.notifyMatchFailure(
            op, "broadcast from scalar must not have a layout attribute");
    }

    auto destDistType =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
    if (failed(destDistType))
      return rewriter.notifyMatchFailure(
          op, "failed to distribute the result vector type");

    Value source = adaptor.getSource();
    // If the adapted source already matches the dest dist type, it's a no-op.
    if (source.getType() == destDistType.value()) {
      rewriter.replaceOp(op, source);
      return success();
    }

    auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                             destDistType.value(), source);
    rewriter.replaceOp(op, newOp);
    return success();
  }
};

1378/// Distributes a subgroup-level vector.insert_strided_slice op to
1379/// workitem-level. If the dest is distributed, the offsets are adjusted to
1380/// match the distributed types.
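/// Illustrative sketch (hypothetical shapes; assumes subgroup size 16 and
/// lane_layout = [1, 16] so only the innermost dim is distributed):
/// ```
/// %r = vector.insert_strided_slice %s, %d {offsets = [2, 16], strides = [1, 1]}
///   : vector<1x16xf32> into vector<4x32xf32>
/// ```
/// is distributed to (the offset along the distributed dim becomes 16/16 = 1):
/// ```
/// %r = vector.insert_strided_slice %s, %d {offsets = [2, 1], strides = [1, 1]}
///   : vector<1x1xf32> into vector<4x2xf32>
/// ```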
1381struct SgToWiVectorInsertStridedSlice
1382 : public OpConversionPattern<vector::InsertStridedSliceOp> {
1383 using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
1384
1385 LogicalResult
1386 matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
1387 ConversionPatternRewriter &rewriter) const override {
1388 xegpu::DistributeLayoutAttr resultLayout =
1389 xegpu::getTemporaryLayout(op->getOpResult(0));
1390 if (!resultLayout || !resultLayout.isForSubgroup())
1391 return failure();
1392
1393 VectorType destType = op.getDestVectorType();
1394 auto distDestTyOrFailure =
1395 xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
1396 if (failed(distDestTyOrFailure))
1397 return rewriter.notifyMatchFailure(
1398 op, "unable to compute distributed vector type from lane layout");
1399 VectorType distDestTy = *distDestTyOrFailure;
1400
1401 SmallVector<int64_t> destDistributedDims =
1402 getDistributedDims(destType, distDestTy);
1403
1404 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1405 op.getOffsets(), [](Attribute attr) { return attr; });
1406
1407 if (!destDistributedDims.empty()) {
1408 if (destDistributedDims.size() != 1)
1409 return rewriter.notifyMatchFailure(
1410 op, "only single dimension distribution is supported");
1411 int64_t destDistDim = destDistributedDims[0];
1412
1413 const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
1414 if (!uArch)
1415 return rewriter.notifyMatchFailure(
1416 op, "target attribute required to determine subgroup size");
1417 int subgroupSize = uArch->getSubgroupSize();
1418
1419 VectorType srcType = op.getSourceVectorType();
1420 // The distributed dim must be in the last k (source rank) dims of dest.
1421 int64_t sourceDistDim =
1422 destDistDim - (destType.getRank() - srcType.getRank());
1423 if (sourceDistDim < 0)
1424 return rewriter.notifyMatchFailure(
1425 op, "distributed dimension must be in the last k dims of dest");
1426
1427 auto destLayout = xegpu::getTemporaryLayout(op->getOpOperand(1));
1428 auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
1429 if (!destLayout || !sourceLayout ||
1430 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1431 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1432 return rewriter.notifyMatchFailure(
1433 op, "source or dest of insert_strided_slice lacks distribution "
1434 "layout");
1435
1436 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1437 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1438 // Only check lane_data for the distributed dimension. Non-distributed
1439 // dimensions may have non-unit lane_data (e.g., packed layouts).
1440 if ((destDistDim < static_cast<int64_t>(destLaneData.size()) &&
1441 destLaneData[destDistDim] != 1) ||
1442 (sourceDistDim < static_cast<int64_t>(sourceLaneData.size()) &&
1443 sourceLaneData[sourceDistDim] != 1))
1444 return rewriter.notifyMatchFailure(
1445 op, "expecting unit lane data along the distributed dimension");
1446
1447 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistDim);
1448 if (srcDistrDimSize % subgroupSize != 0)
1449 return rewriter.notifyMatchFailure(
1450 op, "source distributed dim size is not a multiple of "
1451 "subgroup size");
1452
1453 int64_t destDistrDimOffset =
1454 cast<IntegerAttr>(op.getOffsets()[destDistDim]).getInt();
1455 if (destDistrDimOffset % subgroupSize != 0)
1456 return rewriter.notifyMatchFailure(
1457 op, "offset along distributed dim is not a multiple of "
1458 "subgroup size");
1459 // Adjust offset for the distributed dimension.
1460 updatedOffsets[destDistDim] =
1461 rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1462 }
1463
1464 auto newOp = vector::InsertStridedSliceOp::create(
1465 rewriter, op.getLoc(), distDestTy, adaptor.getValueToStore(),
1466 adaptor.getDest(),
1467 ArrayAttr::get(rewriter.getContext(), updatedOffsets), op.getStrides());
1468 rewriter.replaceOp(op, newOp.getResult());
1469 return success();
1470 }
1471};
1472
1473/// Distributes a subgroup-level vector.insert op to workitem-level. Only
1474/// handles sub-vector insertion (value to store is VectorType, not scalar).
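/// Illustrative sketch (hypothetical shapes; assumes lane_layout = [1, 16],
/// i.e. only the innermost dim is distributed):
/// ```
/// %r = vector.insert %v, %d [2] : vector<16xf16> into vector<4x16xf16>
/// ```
/// is distributed to (same position, distributed operand types):
/// ```
/// %r = vector.insert %v, %d [2] : vector<1xf16> into vector<4x1xf16>
/// ```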
1475struct SgToWiVectorInsert : public OpConversionPattern<vector::InsertOp> {
1476 using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
1477
1478 LogicalResult
1479 matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
1480 ConversionPatternRewriter &rewriter) const override {
1481 // Only handle vector value-to-store (not scalar insertion).
1482 auto valueType = dyn_cast<VectorType>(op.getValueToStoreType());
1483 if (!valueType)
1484 return rewriter.notifyMatchFailure(op, "scalar insert not supported");
1485
1486 xegpu::DistributeLayoutAttr layout =
1487 xegpu::getTemporaryLayout(op->getOpResult(0));
1488 if (!layout || !layout.isForSubgroup())
1489 return failure();
1490
1491    // Verify that all but the innermost dimension have a unit lane_layout,
1492    // i.e. only the innermost dimension may be distributed.
1493 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1494 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1495 [](int64_t v) { return v != 1; }))
1496 return rewriter.notifyMatchFailure(
1497 op, "only innermost dimension distribution is supported for "
1498 "vector.insert");
1499
1500 auto newOp = vector::InsertOp::create(
1501 rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
1502 op.getMixedPosition());
1503 rewriter.replaceOp(op, newOp.getResult());
1504 return success();
1505 }
1506};
1507
1508/// Folds a subgroup-level ConvertLayout op with compatible lane layouts.
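/// Illustrative sketch (hypothetical layouts; attribute bodies abbreviated).
/// If the input and target layouts are compatible at lane level,
/// ```
/// %1 = xegpu.convert_layout %0 <{input_layout = #xegpu.layout<...>,
///   target_layout = #xegpu.layout<...>}> : vector<16x16xf16>
/// ```
/// folds away, and users of %1 simply use the (already distributed) source.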
1509struct SgToWiConvertLayout
1510 : public OpConversionPattern<xegpu::ConvertLayoutOp> {
1511 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
1512
1513 LogicalResult
1514 matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
1515 ConversionPatternRewriter &rewriter) const override {
1516 auto inputLayout = op.getInputLayoutAttr();
1517 auto targetLayout = op.getTargetLayoutAttr();
1518 Type valType = op.getResult().getType();
1519
1520 if (valType.isIntOrFloat()) {
1521 rewriter.replaceOp(op, op.getSource());
1522 return success();
1523 }
1524
1525 auto resShape = cast<VectorType>(valType).getShape();
1526 SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
1527 if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
1528 xegpu::LayoutKind::Lane)) {
1529 return rewriter.notifyMatchFailure(
1530 op, "lowering incompatible convert_layout not yet supported");
1531 }
1532
1533 rewriter.replaceOp(op, adaptor.getSource());
1534 return success();
1535 }
1536};
1537
1538struct XeGPUSgToWiDistributeExperimentalPass
1539    : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
1540          XeGPUSgToWiDistributeExperimentalPass> {
1541 void runOnOperation() override;
1542};
1543
1544} // namespace
1545
1546void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
1547
1548  // Recover temporary operand layouts for use in the patterns.
1549 Operation *root = getOperation();
1550 if (!xegpu::recoverTemporaryLayouts(root)) {
1551 signalPassFailure();
1552 return;
1553 }
1554
1555 // Collect existing UnrealizedConversionCastOps. These must be preserved.
1556 llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
1557 root->walk(
1558 [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
1559  // Perform a structural type conversion so that structural ops (e.g. scf
1560  // ops) carry WI types. This inserts UnrealizedConversionCastOps to keep
1561  // the IR valid.
1562 auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
1563 mlir::ValueRange inputs,
1564 mlir::Location loc) -> mlir::Value {
1565 UnrealizedConversionCastOp castOp =
1566 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1567 return castOp.getResult(0);
1568 };
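  // For illustration (hypothetical types): when a WI-typed value feeds an op
  // that still expects the SG-level type, this materialization produces e.g.
  //   %sg = builtin.unrealized_conversion_cast %wi
  //       : vector<16x1xf32> to vector<16x16xf32>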
1569 {
1570 ConversionTarget target(getContext());
1571 TypeConverter typeConverter;
1572 RewritePatternSet patterns(&getContext());
1573 typeConverter.addSourceMaterialization(materializeCast);
1574 typeConverter.addTargetMaterialization(materializeCast);
1576    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
1577                                                         patterns, target);
1578    xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
1579        typeConverter, patterns, target);
1580 target.addLegalOp<UnrealizedConversionCastOp>();
1581 (void)applyPartialConversion(root, target, std::move(patterns));
1582 }
1583  // The structural type conversion can generate redundant
1584  // UnrealizedConversionCastOps that materialize the SG type from the
1585  // type-converted WI type. At this point these can be eliminated by
1586  // inserting shape casts instead.
1587 // Example:
1588 // %1 = UnrealizedConversionCastOp %0 : vector<16x1xf32> to vector<16x16xf32>
1589 // %2 = UnrealizedConversionCastOp %1 : vector<16x16xf32> to vector<16xf32>
1590 // This can be replaced with:
1591 // %2 = vector.shape_cast %0 : vector<16x1xf32> to vector<16xf32>
1592 OpBuilder builder(root);
1593 root->walk([&](UnrealizedConversionCastOp op) {
1594 // If this op existed before, nothing to do.
1595 if (existingCasts.contains(op))
1596 return;
1597    // The number of inputs and results must each be 1.
1598 if (op.getNumOperands() != 1 || op.getNumResults() != 1)
1599 return;
1600 // Both input and output types must be vector types.
1601 auto singleInput = op.getInputs()[0];
1602 auto inputTy = dyn_cast<VectorType>(singleInput.getType());
1603 auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
1604 if (!inputTy || !outputTy)
1605 return;
1606
1607 // Check if the defining op of the input is also an
1608 // UnrealizedConversionCastOp and it has a single user (which is this
1609 // op).
1610 auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
1611 if (!definingOp || !definingOp->hasOneUse())
1612 return;
1613 auto inputOfDefiningOp = definingOp.getInputs()[0];
1614    // If the defining op's input and this op's output are both vector types
1615    // with the same number of elements, insert a shape cast instead.
1616 auto inputOfDefiningOpTy =
1617 dyn_cast<VectorType>(inputOfDefiningOp.getType());
1618 if (inputOfDefiningOpTy &&
1619 inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
1620 builder.setInsertionPoint(op);
1621 auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
1622 outputTy, inputOfDefiningOp);
1623 op.replaceAllUsesWith(ValueRange{shapeCast.getResult()});
1624 return;
1625 }
1626 });
1627  // At this point, some dead UnrealizedConversionCastOps may remain; erase
1628  // them.
1629 bool changed = true;
1630 while (changed) {
1631 changed = false;
1632 root->walk([&](UnrealizedConversionCastOp op) {
1633 // Skip existing casts.
1634 if (existingCasts.contains(op))
1635 return;
1636 if (op.use_empty()) {
1637 op.erase();
1638 changed = true;
1639 }
1640 });
1641 }
1642
1643 xegpu::removeTemporaryLayoutAttrs(getOperation());
1644}
1645
1646void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
1647    TypeConverter &typeConverter) {
1648  // Any type other than TensorDescType and VectorType is legal as is.
1649 typeConverter.addConversion([](Type type) -> std::optional<Type> {
1650 if (!isa<TensorDescType, VectorType>(type))
1651 return type;
1652 return std::nullopt;
1653 });
1654 // For TensorDescType, drop the layout attribute if any.
1655 typeConverter.addConversion([](TensorDescType type) -> Type {
1656 if (type.getLayoutAttr()) {
1657 return type.dropLayouts();
1658 }
1659 return type;
1660 });
1661 // For VectorType, check if there is a distribute layout attribute on the
1662 // value. If so, convert to the distributed vector type based on the layout.
1663 typeConverter.addConversion([](Value v) -> std::optional<Type> {
1664 auto type = v.getType();
1665 // If value is not vector type, nothing to do.
1666 if (!isa<VectorType>(type))
1667 return std::nullopt;
1668 auto layout = xegpu::getDistributeLayoutAttr(v);
1669 if (!layout || !layout.isForSubgroup())
1670 return type;
1671 // Vector type is distributed based on lane layout.
1672 auto newTyOrFailure =
1673 getDistVecTypeBasedOnLaneLayout(layout, cast<VectorType>(type));
1674 if (failed(newTyOrFailure))
1675 return type;
1676 return *newTyOrFailure;
1677 });
1678}
1679
1680void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
1681    TypeConverter &typeConverter, RewritePatternSet &patterns,
1682    ConversionTarget &target) {
1684 // CreateNdDescOp is legal only if its result type has no layout attribute.
1685 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
1686 [&](xegpu::CreateNdDescOp op) { return !op.getType().getLayoutAttr(); });
1687 // Any anchor XeGPU op is legal only if it has no anchor layout.
1688 target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
1689 auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
1690 if (!anchorOp)
1691 return true;
1692 return !anchorOp.getAnchorLayout();
1693 });
1694 // Arith constants are legal only if they have no temporary layout attribute.
1695 target.addDynamicallyLegalOp<arith::ConstantOp>(
1696 [=](arith::ConstantOp op) -> bool {
1697 // If the result type is not a vector, it's legal.
1698 if (!isa<VectorType>(op.getResult().getType()))
1699 return true;
1700 return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1701 });
1702 // In math and arith dialects, only handle elementwise ops with a single
1703 // result and with a result layout attribute.
1704 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1705 [=](Operation *op) -> std::optional<bool> {
1706        // Only handle elementwise mappable ops.
1707        if (!OpTrait::hasElementwiseMappableTraits(op))
1708          return true;
1709        // Only handle ops with a single vector result.
1710 if (op->getNumResults() != 1)
1711 return true;
1712
1713 VectorType resultType =
1714 dyn_cast<VectorType>(op->getResult(0).getType());
1715 if (!resultType)
1716 return true;
1717
1718 // Check if all operands are vectors of the same shape
1719 for (Value operand : op->getOperands()) {
1720 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1721 if (!operandType || operandType.getShape() != resultType.getShape()) {
1722 return true;
1723 }
1724 }
1725 return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1726 });
1727 // vector::ReductionOp is legal only if its source has no distribute layout
1728 // attribute.
1729 target.addDynamicallyLegalOp<vector::ReductionOp>(
1730 [=](vector::ReductionOp op) -> bool {
1731 auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
1732 return !layout;
1733 });
1734 // vector::MultiDimReductionOp op legality.
1735 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
1736 [=](vector::MultiDimReductionOp op) -> bool {
1737 return !isValidSubgroupMultiReductionOp(op);
1738 });
1739 target.addDynamicallyLegalOp<vector::CreateMaskOp, vector::ConstantMaskOp,
1740 vector::TransposeOp, vector::BitCastOp,
1741 vector::ShapeCastOp, vector::StepOp,
1742 vector::BroadcastOp>([=](Operation *op) -> bool {
1743 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1744 });
1745 target.addDynamicallyLegalOp<vector::ExtractOp>(
1746 [=](vector::ExtractOp op) -> bool {
1747 if (!isa<VectorType>(op.getType()))
1748 return true;
1749 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1750 });
1751 target.addDynamicallyLegalOp<vector::InsertOp>(
1752 [=](vector::InsertOp op) -> bool {
1753 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1754 });
1755 target.addDynamicallyLegalOp<vector::ExtractStridedSliceOp>(
1756 [=](vector::ExtractStridedSliceOp op) -> bool {
1757 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1758 });
1759 target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
1760 [=](vector::InsertStridedSliceOp op) -> bool {
1761 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1762 });
1763 target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
1764 patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
1765 SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
1766 SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
1767 SgToWiMultiDimReduction, SgToWiVectorExtract, SgToWiVectorInsert,
1768 SgToWiVectorExtractStridedSlice, SgToWiVectorInsertStridedSlice,
1769 SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout,
1770 SgToWiVectorTranspose, SgToWiVectorBitcast, SgToWiVectorStep,
1771 SgToWiVectorShapeCast, SgToWiBroadcast,
1772 SgToWiCreateMask<vector::CreateMaskOp>,
1773 SgToWiCreateMask<vector::ConstantMaskOp>>(typeConverter,
1774 patterns.getContext());
1775}