MLIR 23.0.0git
XeGPUSgToWiDistributeExperimental.cpp
Go to the documentation of this file.
1//===- XeGPUSgToWiDistributeExperimental.cpp - XeGPU SG to WI Pass --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
20#include "mlir/IR/Builders.h"
22#include "mlir/IR/BuiltinOps.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Operation.h"
26#include "mlir/IR/Value.h"
27#include "mlir/IR/ValueRange.h"
29#include "llvm/ADT/SetVector.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/raw_ostream.h"
32#include <optional>
33
34namespace mlir {
35namespace xegpu {
36#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
37#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
38} // namespace xegpu
39} // namespace mlir
40
41using namespace mlir;
42
43#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
44#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
45
46namespace {
47
48/// Casts the given vector value `v` to the expected vector type `expectedTy`.
49static Value castValueTo(ConversionPatternRewriter &rewriter,
50 TypedValue<VectorType> v, VectorType expectedTy) {
51 // If the type matches, simply return the value itself.
52 if (v.getType() == expectedTy)
53 return v;
54 // If only shape differs, use shape cast.
55 if (isa<VectorType>(v.getType()) &&
56 v.getType().getNumElements() == expectedTy.getNumElements())
57 return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
58
59 // Else create an unrealized cast.
60 auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
61 expectedTy, ValueRange{v});
62 return newOp.getResult(0);
63}
64
65/// A vector::MultiDimReductionOp at subgroup level in expected form if, it has
66/// exactly 1 reduction dimension, it had valid result layout attribute, and
67/// result type can be distributed to lanes using the layout.
68static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
69 auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
70 // If no layout, not valid.
71 if (!resLayout || !resLayout.isForSubgroup())
72 return false;
73 // Scalar result (e.g., vector<32xf32> to f32) is valid.
74 if (op.getType().isIntOrFloat())
75 return op.getReductionDims().size() == 1;
76 VectorType resTy = dyn_cast<VectorType>(op.getType());
77 if (!resTy)
78 return false;
79 // Compute the distributed result vector type based on the layout.
80 FailureOr<VectorType> resDistTypeOrFailure =
81 getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
82 if (failed(resDistTypeOrFailure))
83 return false;
84 return op.getReductionDims().size() == 1;
85}
86
/// A vector::MultiDimReductionOp is doing lane-local reduction if each workitem
/// is doing its own local reduction. In this case the result layout ensures
/// that result vector is distributed to lanes, i.e. the result vector type is
/// different from the distributed result vector type.
static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
  // Must be valid MultiDimReductionOp.
  assert(isValidSubgroupMultiReductionOp(op) && "Expecting a valid subgroup "
                                                "MultiDimReductionOp");
  auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
  VectorType resTy = dyn_cast<VectorType>(op.getType());
  // NOTE(review): the .value() below is only safe because the assert above
  // guarantees distribution succeeds; in a no-assert build an invalid op
  // would fault here — confirm callers always pre-validate.
  auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
  // Lane-local iff distribution actually changed the result vector type.
  return resTy != resDistTypeOrFailure.value();
}
100
101/// Given a vector type and its distributed vector type, return the list of
102/// dimensions that are distributed.
103static SmallVector<int64_t> getDistributedDims(VectorType originalType,
104 VectorType distributedType) {
105 assert(originalType.getRank() == distributedType.getRank() &&
106 "original and distributed vector types must have the same rank");
107 SmallVector<int64_t> distributedDims;
108 for (int64_t i = 0; i < originalType.getRank(); ++i) {
109 if (distributedType.getDimSize(i) != originalType.getDimSize(i))
110 distributedDims.push_back(i);
111 }
112 return distributedDims;
113}
114
/// Distributes a subgroup-level CreateNdDesc op to workitem-level CreateNdDesc
/// op. This simply drops the layout attribute from the tensor descriptor type.
struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::TensorDescType resultType = op.getType();
    // If no layout, nothing to do.
    if (!resultType.getLayout())
      return failure();

    // Recreate the descriptor with the same operands and attributes but with
    // a layout-free descriptor type: workitem-level descriptors carry no
    // distribution layout.
    auto newOp = xegpu::CreateNdDescOp::create(
        rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
        op->getAttrs());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
135
/// Distributes a subgroup-level LoadNd op to workitem-level LoadNd op. Output
/// of workitem-level LoadNd op is 1D. ShapeCast is added to restore the
/// original rank.
struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If no layout, nothing to do.
    if (!layout)
      return failure();
    // Check if the layout attached to the tensor descriptor is same as the
    // anchor layout. Otherwise, this is a conflict.
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    // The target micro-architecture decides whether a transposed load is
    // required, so a chip attribute must be present.
    auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          op, "xegpu::LoadNdOp require target attribute attached to "
              "determine transpose "
              "requirement");
    // "Supported" type: the (1D) vector the hardware load actually produces.
    auto supportedWiResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    // "Expected" type: the per-lane vector the surrounding IR expects
    // according to the lane layout.
    auto expectedWiResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
    if (failed(supportedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute the workitem vector type for LoadNdOp");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute expected workitem vector type from lane layout");
    // Build the workitem-level load with the hardware-supported result type;
    // the layout operand is dropped (nullptr) at workitem level.
    auto newOp = xegpu::LoadNdOp::create(
        rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
        op.getL3HintAttr(), /**layout**/ nullptr);
    // Set the packed attribute if the layout requires it.
    // NOTE(review): cast<LayoutAttr> asserts if the anchor layout is not a
    // LayoutAttr (e.g. a slice layout) — confirm that cannot occur here.
    newOp.setPacked(xegpu::requirePacked(cast<xegpu::LayoutAttr>(layout)));
    // Set the transpose attribute if the layout requires it.
    if (xegpu::requireTranspose(cast<xegpu::LayoutAttr>(layout), uArch))
      newOp.setTranspose(DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    // Cast the 1D hardware result back to the rank consumers expect.
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};
186
/// Distributes a subgroup-level StoreNd op to workitem-level StoreNd op. Stored
/// value in workitem-level StoreNd op is 1D. ShapeCast is added to cast the
/// incoming value to 1D.
struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If no layout, nothing to do.
    if (!layout)
      return failure();
    // Check if the layout attached to the tensor descriptor and value layout is
    // same as the anchor layout. Otherwise, this is a conflict.
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    // Operand 0 is the stored value; its layout must agree with the anchor.
    auto valueLayout = xegpu::getDistributeLayoutAttr(op->getOpOperand(0));
    if (valueLayout != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on value and anchor");
    // The (1D) per-workitem value type the hardware store expects, derived
    // from the tensor descriptor.
    auto supportedWiValueTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    if (failed(supportedWiValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute wi vector type for StoreNdOp value from tensor "
          "descriptor");

    // Build the workitem-level store with the value cast to the 1D supported
    // type; the layout operand is dropped (nullptr) at workitem level.
    xegpu::StoreNdOp::create(
        rewriter, op.getLoc(),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getValue()),
                    supportedWiValueTyOrFailure.value()),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
        op.getL2HintAttr(), op.getL3HintAttr(), /**layout**/ nullptr);
    // The store has no results; simply erase the subgroup-level op.
    rewriter.eraseOp(op);
    return success();
  }
};
227
228/// Distributes a subgroup-level Dpas op to workitem-level Dpas op. All inpputs
229/// and output of workitem-level Dpas op are 1D. Necessary casts are added to
230/// convert the inputs and output to/from 1D.
231struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
232 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
233
234 LogicalResult
235 matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
236 ConversionPatternRewriter &rewriter) const override {
237 // llvm::errs() << "DpasOpPattern matchAndRewrite called\n";
238 // Check if the op has A, B and CD layouts attached.
239 auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
240 auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
241 auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
242 if (!layoutA || !layoutB || !layoutCd)
243 return failure();
244 // llvm::errs() << "tryning to calculate wi types for dpas op\n";
245 auto wiResultTyOrFailure =
246 xegpu::getDistributedVectorType(op.getType(), layoutCd);
247 auto wiATypeOrFailure =
248 xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
249 auto wiBTypeOrFailure =
250 xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
251 auto expectedWiResultTyOrFailure =
252 xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
253 if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
254 failed(wiBTypeOrFailure))
255 return rewriter.notifyMatchFailure(
256 op, "failed to calculate supported workitem vector types for DpasOp "
257 "from layouts");
258 if (failed(expectedWiResultTyOrFailure))
259 return rewriter.notifyMatchFailure(
260 op, "unable to compute expected workitem vector type for DpasOp from "
261 "lane layout");
262 auto newOp = xegpu::DpasOp::create(
263 rewriter, op->getLoc(), wiResultTyOrFailure.value(),
264 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getLhs()),
265 wiATypeOrFailure.value()),
266 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getRhs()),
267 wiBTypeOrFailure.value()),
268 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getAcc()),
269 wiResultTyOrFailure.value()),
270 /** layoutA**/ nullptr,
271 /** layoutB**/ nullptr, /** layoutCd**/ nullptr);
272 // Explicitly set the new types to enable correct type materializations.
273 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
274 expectedWiResultTyOrFailure.value()));
275 return success();
276 }
277};
278
279/// Distributes elementwise ops to workitem-level elementwise ops. This
280/// currently handles elementwise ops with single result only.
281struct SgToWiElementWise : public ConversionPattern {
282 SgToWiElementWise(TypeConverter &typeConverter, MLIRContext *ctx)
283 : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
284
285 LogicalResult
286 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
287 ConversionPatternRewriter &rewriter) const override {
288 // Only match ops with elementwise trait and single result.
290 return failure();
291
292 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
293 if (!resultType)
294 return rewriter.notifyMatchFailure(
295 op, "operation result is not a vector type");
296
297 xegpu::DistributeLayoutAttr layout =
298 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
299 if (!layout || !layout.isForSubgroup())
300 return rewriter.notifyMatchFailure(
301 op, "operation result does not have subgroup distribute layout");
302
303 auto wiShapeOrFailure =
304 xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
305
306 if (failed(wiShapeOrFailure))
307 return rewriter.notifyMatchFailure(
308 op, "unable to compute workitem vector type from the layout");
309
310 VectorType newResultType = wiShapeOrFailure.value();
311 OperationState state(op->getLoc(), op->getName());
312 state.addOperands(operands);
313 state.addTypes(newResultType);
314 // Copy all attributes except for DistributeLayoutAttr.
315 for (auto attr : op->getAttrs()) {
316 if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
317 state.addAttribute(attr.getName(), attr.getValue());
318 }
319 Operation *newOp = rewriter.create(state);
320
321 rewriter.replaceOp(op, newOp->getResult(0));
322 return success();
323 }
324};
325
326/// Distributes a subgroup-level arith ConstantOp to workitem-level arith
327/// ConstantOp.
328struct SgToWiArithConstant : public OpConversionPattern<arith::ConstantOp> {
329 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
330
331 LogicalResult
332 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
333 ConversionPatternRewriter &rewriter) const override {
334 auto resultType = dyn_cast<VectorType>(op.getType());
335 if (!resultType)
336 return failure();
337
338 // Only handle dense vector constants
339 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
340 if (!dense)
341 return rewriter.notifyMatchFailure(
342 op, "only dense splat vector constants are supported");
343
344 xegpu::DistributeLayoutAttr layout =
345 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
346 if (!layout || !layout.isForSubgroup())
347 return rewriter.notifyMatchFailure(
348 op, "operation result does not have subgroup distribute layout");
349
350 auto wiShapeOrFailure =
351 xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
352
353 if (failed(wiShapeOrFailure))
354 return rewriter.notifyMatchFailure(
355 op, "unable to compute workitem vector type from the layout");
356
357 VectorType newResultType = wiShapeOrFailure.value();
358 auto sclarValue = dense.getSplatValue<Attribute>();
359 auto newDenseAttr = DenseElementsAttr::get(newResultType, sclarValue);
360
361 auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
362 newDenseAttr);
363 rewriter.replaceOp(op, newOp.getResult());
364 return success();
365 }
366};
367
/// Distributes a subgroup-level PrefetchNd op to workitem-level PrefetchNd op.
struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If no layout, nothing to do.
    if (!layout)
      return failure();

    // Recreate the prefetch on the converted descriptor; the layout operand
    // is dropped (nullptr) at workitem level.
    xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
                                op.getMixedOffsets(), op.getL1HintAttr(),
                                op.getL2HintAttr(), op.getL3HintAttr(),
                                /**layout**/ nullptr);
    // Prefetch has no results; erase the subgroup-level op.
    rewriter.eraseOp(op);
    return success();
  }
};
388
389/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
390///
391/// Example 1 (1D, no chunk size):
392/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
393/// %mask = producer_op : vector<16xi1>
394/// %offset = producer_op : vector<16xindex>
395/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
396/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
397/// Distributed to:
398/// %mask = producer_op : vector<1xi1>
399/// %offset = producer_op : vector<1xindex>
400/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
401/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
402///
403/// Example 2 (2D with chunk size, same mask & offset):
404/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
405/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
406/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
407/// Distributed to:
408/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
409/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
410///
411/// Example 3 (3D with leading unit dims):
412/// layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
413/// %mask = producer_op : vector<1x1x16xi1>
414/// %offset = producer_op : vector<1x1x16xindex>
415/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
416/// vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
417/// Distributed to:
418/// %mask = producer_op : vector<1x1x1xi1>
419/// %offset = producer_op : vector<1x1x1xindex>
420/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
421/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
422struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
423 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
424
425 LogicalResult
426 matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
427 ConversionPatternRewriter &rewriter) const override {
428 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
429 if (!layout)
430 return failure();
431
432 VectorType origResultTy = op.getValueType();
433 if (!origResultTy)
434 return failure();
435
436 // Check that leading dimensions are unit.
437 int chunkSize = op.getChunkSize().value_or(1);
438 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
439 ArrayRef<int64_t> shape = origResultTy.getShape();
440 if (llvm::any_of(
441 shape.take_front(origResultTy.getRank() - effectiveVecRank),
442 [](int64_t d) { return d != 1; }))
443 return rewriter.notifyMatchFailure(
444 op, "Only unit dimensions allowed for the leading "
445 "dimensions of the load vector!");
446
447 auto distResultTyOrFailure =
448 xegpu::getDistVecTypeBasedOnLaneLayout(layout, origResultTy);
449 if (failed(distResultTyOrFailure))
450 return rewriter.notifyMatchFailure(
451 op,
452 "unable to compute expected workitem vector type from lane layout");
453
454 VectorType distResultTy = distResultTyOrFailure.value();
455 VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
456 distResultTy.getElementType());
457
458 // Flatten offsets and mask to 1D to match the 1D result type.
459 Value distOffsets = adaptor.getOffsets();
460 auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
461 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
462 distOffsetsTy.getElementType());
463 distOffsets = castValueTo(
464 rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
465
466 Value distMask = adaptor.getMask();
467 auto distMaskTy = cast<VectorType>(distMask.getType());
468 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
469 distMaskTy.getElementType());
470 distMask =
471 castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
472
473 Value distSource = adaptor.getSource();
474 auto newOp = xegpu::LoadGatherOp::create(
475 rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
476 distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
477 op.getL3HintAttr(), /*layout=*/nullptr);
478
479 Value result = newOp->getResult(0);
480 if (distResultTy1D != distResultTy)
481 result = castValueTo(rewriter, cast<TypedValue<VectorType>>(result),
482 distResultTy);
483 rewriter.replaceOp(op, result);
484 return success();
485 }
486};
487
/// This pattern distributes a subgroup-level vector.reduction op to
/// workitem-level. This require shuffling the data across the workitems (using
/// gpu::ShuffleOp) and reducing in stages until all workitems have the final
/// result.
struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
  using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto layout = xegpu::getDistributeLayoutAttr(op.getVector());

    // If no layout, nothing to do.
    if (!layout || !layout.isForSubgroup())
      return failure();

    VectorType srcVecType = op.getSourceVectorType();
    // Only rank 1 vectors supported.
    if (srcVecType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Only rank 1 reductions can be distributed.");
    // Lane layout must have the same rank as the vector.
    if (layout.getRank() != srcVecType.getRank())
      return rewriter.notifyMatchFailure(
          op, "Layout rank does not match vector rank.");

    // Get the subgroup size from the layout.
    int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
    // The target uArch provides the hardware subgroup size to validate the
    // layout against.
    const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          op, "xegpu::ReductionOp require target attribute attached to "
              "determine subgroup size");

    // Only subgroup-sized vectors supported: the lane layout must equal the
    // hardware subgroup size and the vector must divide evenly across lanes.
    if (sgSize != uArch->getSubgroupSize() ||
        srcVecType.getShape()[0] % sgSize != 0)
      return rewriter.notifyMatchFailure(op,
                                         "Invalid layout or reduction vector "
                                         "dimension must match subgroup size.");

    if (!op.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Reduction distribution currently only supports floats and "
              "integer types.");

    // Get the distributed vector (per work-item portion).
    Value laneValVec = adaptor.getVector();

    // Distribute and reduce across work-items in the subgroup.
    Value fullReduce = xegpu::subgroupReduction(
        op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);

    // If there's an accumulator, combine it with the reduced value.
    if (adaptor.getAcc())
      fullReduce = vector::makeArithReduction(
          rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());

    rewriter.replaceOp(op, fullReduce);
    return success();
  }
};
550
/// This pattern distributes a subgroup-level vector.multi_reduction op to
/// workitem-level only if the reduction is lane-local. This means that
/// reduction dimension is not distributed to lanes and each lane does its own
/// local reduction.
struct SgToWiMultiDimReduction
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // NOTE(review): this copy of the source appears to be missing a line here
    // (likely `Value result;`) — `result` is assigned below with no visible
    // declaration. Confirm against the original file.
    ArrayRef<int64_t> reductionDims = op.getReductionDims();
    assert(reductionDims.size() == 1 &&
           "Expecting single reduction dimension for subgroup multi "
           "reduction op");
    // For rank > 2, ensure leading dimensions are unit.
    VectorType sourceType = op.getSourceVectorType();
    int64_t rank = sourceType.getRank();
    if (rank > 2) {
      ArrayRef<int64_t> shape = sourceType.getShape();
      if (llvm::any_of(shape.take_front(rank - 2),
                       [](int64_t d) { return d != 1; }))
        return rewriter.notifyMatchFailure(
            op, "only unit leading dimensions are supported for "
                "multi_reduction with rank > 2");
    }
    // Handle scalar result: full reduction of a distributed vector to a
    // scalar. First do a local vector reduction, then cross-lane shuffles.
    if (op.getType().isIntOrFloat()) {
      auto reductionDim = reductionDims[0];
      VectorType origSourceType = op.getSourceVectorType();
      int64_t reductionDimSize = origSourceType.getShape()[reductionDim];
      // Local reduction to scalar, then cross-lane butterfly shuffles.
      result =
          xegpu::subgroupReduction(op.getLoc(), rewriter, adaptor.getSource(),
                                   op.getKind(), reductionDimSize);
      // Combine with accumulator if present.
      if (adaptor.getAcc())
        result = vector::makeArithReduction(rewriter, op.getLoc(), op.getKind(),
                                            result, adaptor.getAcc());
    } else if (isReductionLaneLocal(op)) {
      // For lane-local reduction, lower to a sequence of vector.reduction ops
      // over 1D slices extracted from the distributed source vector. This is
      // required so we dont have 2D source vectors at xegpu-linearize.
      auto reductionDim = reductionDims[0];
      // NOTE(review): the first line of the helper call that assigns `result`
      // is missing in this copy of the source; the dangling arguments below
      // belong to it. Confirm the helper name against the original file.
          cast<TypedValue<VectorType>>(adaptor.getSource()),
          cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
          reductionDim, op.getLoc(), rewriter);
    } else {
      auto reductionDim = reductionDims[0];
      VectorType sourceType = op.getSourceVectorType();
      int64_t reductionDimSize = sourceType.getShape()[reductionDim];
      // NOTE(review): same truncation as above — the call that assigns
      // `result` is missing its first line in this copy of the source.
          cast<TypedValue<VectorType>>(adaptor.getSource()),
          cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
          reductionDim, reductionDimSize, op.getLoc(), rewriter);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
614
/// Helper to compute distributed coordinates for matrix ops.
/// When not using subgroup_block_io, each workitem computes its own
/// coordinates based on the layout and lane ID.
/// Returns an empty vector on failure.
static SmallVector<Value> computeDistributedCoordsForMatrixOp(
    ConversionPatternRewriter &rewriter, Location loc,
    xegpu::DistributeLayoutAttr layout, ArrayRef<int64_t> payloadShape,
    ValueRange origOffsets) {
  // Per-lane id used to derive this workitem's coordinates; no upper bound
  // attribute is attached.
  Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
                                       /*upperBound=*/mlir::IntegerAttr());
  auto maybeCoords =
      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
  // An empty return signals failure to the caller.
  if (failed(maybeCoords))
    return {};
  assert(maybeCoords.value().size() == 1 &&
         "Expected one set of distributed offsets");
  // NOTE(review): the statement defining `ofrVec` is missing its first line
  // in this copy of the source (presumably a call that adds the distributed
  // coords to the original offsets). Confirm against the original file.
      rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
      getAsOpFoldResult(origOffsets));
  return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
}
635
/// This pattern distributes a subgroup-level LoadMatrix op to workitem-level.
struct SgToWiLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto layout = op.getLayoutAttr();
    // If no layout, nothing to do.
    if (!layout)
      return failure();

    VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
    if (!sgPayloadTy)
      return rewriter.notifyMatchFailure(
          op, "the matrix op payload must be a vector type");

    auto loc = op.getLoc();
    auto offsets = op.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(op, "the load op must have offsets");

    FailureOr<VectorType> distPayloadTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
    if (failed(distPayloadTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "Failed to distribute matrix op payload based on layout.");

    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loc, offsets);

    // Without subgroup_block_io, each lane computes its own coordinates from
    // the layout and its lane id; with it, the subgroup-level offsets are
    // used as-is.
    SmallVector<Value> newCoords = offsetsAsValues;
    if (!op.getSubgroupBlockIoAttr()) {
      newCoords = computeDistributedCoordsForMatrixOp(
          rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
      if (newCoords.empty())
        return rewriter.notifyMatchFailure(
            op, "Failed to compute distributed coordinates.");
    }

    // All offsets become dynamic values on the new op, so the const-offset
    // array is filled with kDynamic sentinels.
    SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
                                         ShapedType::kDynamic);
    DenseI64ArrayAttr newConstOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(newConstOffsets);

    // The layout is dropped (empty attr) at workitem level.
    auto newOp = xegpu::LoadMatrixOp::create(
        rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
        ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
        xegpu::DistributeLayoutAttr{});
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
689
/// Distributes a subgroup-level vector.transpose op to workitem-level.
struct SgToWiVectorTranspose : public OpConversionPattern<vector::TransposeOp> {
  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Both the source operand and the result must carry layouts so we can
    // verify they are consistent transposes of each other.
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getTemporaryLayout(op->getOpOperand(0));
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          op, "the source or result vector of the transpose op lacks layout "
              "attribute");
    ArrayRef<int64_t> perm = op.getPermutation();
    // Result layout must be a transpose of source layout.
    // NOTE(review): the second argument line of this isTransposeOf call is
    // missing in this copy of the source. Confirm against the original file.
    if (!resultLayout.isTransposeOf(sourceLayout, perm,
      return rewriter.notifyMatchFailure(
          op, "the source or result vector layouts must be transposes of "
              "each other");
    FailureOr<VectorType> distributedResultTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
    if (failed(distributedResultTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "Failed to distribute the result vector type in "
              "vector::Transpose op");
    // Recreate the transpose on the distributed operand, then cast the result
    // to the distributed result type if needed.
    auto newOp = vector::TransposeOp::create(rewriter, op.getLoc(),
                                             adaptor.getVector(), perm);
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       distributedResultTypeOrFailure.value()));
    return success();
  }
};
725
726/// Distributes a subgroup-level vector.bitcast op to workitem-level.
727/// Bitcast only impacts the innermost dimension of the source/result vectors.
728struct SgToWiVectorBitcast : public OpConversionPattern<vector::BitCastOp> {
729 using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
730
731 LogicalResult
732 matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
733 ConversionPatternRewriter &rewriter) const override {
734 xegpu::DistributeLayoutAttr resultLayout =
735 xegpu::getTemporaryLayout(op->getOpResult(0));
736 if (!resultLayout)
737 return rewriter.notifyMatchFailure(
738 op, "result vector of the bitcast op lacks layout attribute");
739 FailureOr<VectorType> distributedResultTypeOrFailure =
740 getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
741 if (failed(distributedResultTypeOrFailure))
742 return rewriter.notifyMatchFailure(
743 op, "Failed to distribute the result vector type in "
744 "vector::BitCast op");
745 auto newOp = vector::BitCastOp::create(
746 rewriter, op.getLoc(), distributedResultTypeOrFailure.value(),
747 adaptor.getSource());
748 rewriter.replaceOp(op, newOp.getResult());
749 return success();
750 }
751};
752
753/// Distributes a subgroup-level vector.create_mask or vector.constant_mask op
754/// to workitem-level. Uses `computeDistributedCoords()` to obtain the
755/// coordinates each workitem owns, then compares each coordinate against the
756/// original mask bounds using `arith.cmpi slt`. The per-element boolean
757/// results are assembled into the distributed mask vector.
758///
759/// For multi-dimensional masks, the element is in-bounds when ALL dimensions
760/// satisfy `coord[i] < bound[i]`.
761///
762/// Example (1D):
763/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
764/// %mask = vector.create_mask %m0 : vector<16xi1>
765/// For lane k, computeDistributedCoords gives coord = [k], so:
766/// %in_bounds = arith.cmpi slt, %coord, %m0 → i1
767/// %mask = vector.broadcast %in_bounds : i1 to vector<1xi1>
768///
769/// Example (2D):
770/// layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>
771/// %mask = vector.create_mask %m0, %m1 : vector<8x4xi1>
772/// Each WI owns a 1x2 slice. computeDistributedCoords returns 2 coords:
773/// [[r0, c0], [r0, c1]]
774/// For each coord: in_bounds = (r < m0) && (c < m1)
775/// %mask = vector.from_elements %bit0, %bit1 : vector<1x2xi1>
776template <typename OpType,
777 typename = std::enable_if_t<llvm::is_one_of<
778 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
779struct SgToWiCreateMask : public OpConversionPattern<OpType> {
780 using OpConversionPattern<OpType>::OpConversionPattern;
781
782 LogicalResult
783 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
784 ConversionPatternRewriter &rewriter) const override {
785 xegpu::DistributeLayoutAttr layout =
786 xegpu::getTemporaryLayout(op->getOpResult(0));
787 if (!layout || !layout.isForSubgroup())
788 return rewriter.notifyMatchFailure(
789 op, "operation result does not have subgroup distribute layout");
790
791 VectorType origType = op.getType();
792 FailureOr<VectorType> distTypeOrFailure =
793 getDistVecTypeBasedOnLaneLayout(layout, origType);
794 if (failed(distTypeOrFailure))
795 return rewriter.notifyMatchFailure(
796 op, "unable to compute workitem vector type from the layout");
797
798 VectorType distType = distTypeOrFailure.value();
799 Location loc = op.getLoc();
800
801 // Materialize the original mask bounds as Values.
802 SmallVector<Value> origBounds;
803 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
804 origBounds.append(op.getOperands().begin(), op.getOperands().end());
805 } else {
806 auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
807 for (auto dimSize : dimSizes)
808 origBounds.push_back(
809 arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
810 }
811
812 ArrayRef<int64_t> origShape = origType.getShape();
813
814 // Use computeDistributedCoords to get the coordinates each WI owns.
815 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
816 /*upperBound=*/mlir::IntegerAttr());
817 auto maybeCoordsVec =
818 layout.computeDistributedCoords(rewriter, loc, laneId, origShape);
819 if (failed(maybeCoordsVec))
820 return rewriter.notifyMatchFailure(
821 op, "failed to compute distributed coordinates from layout");
822
823 SmallVector<SmallVector<Value>> coordsVec = maybeCoordsVec.value();
824 int64_t numElements = distType.getNumElements();
825 assert(static_cast<int64_t>(coordsVec.size()) == numElements &&
826 "number of coordinate sets must match number of distributed "
827 "elements");
828
829 // For each element, compare all coordinates against bounds.
830 Value trueVal =
831 arith::ConstantIntOp::create(rewriter, loc, /*value=*/1, /*width=*/1);
832 SmallVector<Value> maskBits;
833 for (auto &coords : coordsVec) {
834 Value inBounds = trueVal;
835 for (size_t i = 0; i < coords.size(); ++i) {
836 Value cmp = arith::CmpIOp::create(
837 rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]);
838 inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp);
839 }
840 maskBits.push_back(inBounds);
841 }
842
843 // Build the distributed mask vector.
845 if (numElements == 1) {
846 result =
847 vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]);
848 } else {
849 result =
850 vector::FromElementsOp::create(rewriter, loc, distType, maskBits);
851 }
852 rewriter.replaceOp(op, result);
853 return success();
854 }
855};
856
857/// This pattern distributes a subgroup-level StoreMatrix op to workitem-level.
struct SgToWiStoreMatrix : public OpConversionPattern<xegpu::StoreMatrixOp> {
  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto layout = op.getLayoutAttr();
    // If no layout, nothing to do.
    if (!layout)
      return failure();

    // The payload must be a vector; scalar stores are not handled here.
    VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
    if (!sgPayloadTy)
      return rewriter.notifyMatchFailure(
          op, "the matrix op payload must be a vector type");

    auto loc = op.getLoc();
    auto offsets = op.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(op, "the store op must have offsets");

    // Shrink the subgroup-level payload type to the per-workitem fragment.
    FailureOr<VectorType> distPayloadTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
    if (failed(distPayloadTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "Failed to distribute matrix op payload based on layout.");

    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loc, offsets);

    // Without subgroup_block_io each lane addresses its own fragment, so the
    // subgroup offsets must be rebased to per-lane coordinates. With
    // subgroup_block_io the original offsets are kept unchanged.
    SmallVector<Value> newCoords = offsetsAsValues;
    if (!op.getSubgroupBlockIoAttr()) {
      newCoords = computeDistributedCoordsForMatrixOp(
          rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
      if (newCoords.empty())
        return rewriter.notifyMatchFailure(
            op, "Failed to compute distributed coordinates.");
    }

    // All coordinates are now SSA values, so mark every constant offset slot
    // as dynamic.
    SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
                                         ShapedType::kDynamic);
    DenseI64ArrayAttr newConstOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(newConstOffsets);

    // Recreate the store with the distributed payload and per-lane
    // coordinates; the layout attribute is dropped at workitem level.
    xegpu::StoreMatrixOp::create(
        rewriter, loc, TypeRange{},
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getData()),
                    distPayloadTyOrFailure.value()),
        adaptor.getMemDesc(), ValueRange(newCoords), newConstOffsetsAttr,
        op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
    rewriter.eraseOp(op);
    return success();
  }
};
912
913/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
914/// workitem-level.
915///
916/// Example 1 (1D, no chunk size):
917/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
918/// %mask = producer_op : vector<16xi1>
919/// %offset = producer_op : vector<16xindex>
920/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
921/// memref<256xf16>, vector<16xindex>, vector<16xi1>
922/// Distributed to:
923/// %mask = producer_op : vector<1xi1>
924/// %offset = producer_op : vector<1xindex>
925/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
926/// memref<256xf16>, vector<1xindex>, vector<1xi1>
927///
928/// Example 2 (2D with chunk size, same mask & offset):
929/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
930/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
931/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
932/// Distributed to:
933/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
934/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
935///
936/// Example 3 (3D with leading unit dims):
937/// layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
938/// %mask = producer_op : vector<1x1x16xi1>
939/// %offset = producer_op : vector<1x1x16xindex>
940/// xegpu.store %payload, %src[%offset], %mask : vector<1x1x16xf16>,
941/// memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
942/// Distributed to:
943/// %mask = producer_op : vector<1x1x1xi1>
944/// %offset = producer_op : vector<1x1x1xindex>
945/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
946/// memref<256xf16>, vector<1xindex>, vector<1xi1>
947struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
948 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
949
950 LogicalResult
951 matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
952 ConversionPatternRewriter &rewriter) const override {
953 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
954 if (!layout)
955 return failure();
956
957 VectorType origValueTy = op.getValueType();
958 if (!origValueTy)
959 return failure();
960
961 // Check that all leading dimensions are unit dimensions.
962 int chunkSize = op.getChunkSize().value_or(1);
963 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
964 ArrayRef<int64_t> shape = origValueTy.getShape();
965 if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
966 [](int64_t d) { return d != 1; }))
967 return rewriter.notifyMatchFailure(
968 op, "Only unit dimensions allowed for the leading "
969 "dimensions of the store vector!");
970
971 auto distValueTyOrFailure =
972 xegpu::getDistVecTypeBasedOnLaneLayout(layout, origValueTy);
973 if (failed(distValueTyOrFailure))
974 return rewriter.notifyMatchFailure(
975 op,
976 "unable to compute expected workitem vector type from lane layout");
977
978 VectorType distValueTy = distValueTyOrFailure.value();
979 VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
980 distValueTy.getElementType());
981
982 Value distValue = adaptor.getValue();
983 if (distValue.getType() != distValueTy1D)
984 distValue = castValueTo(rewriter, cast<TypedValue<VectorType>>(distValue),
985 distValueTy1D);
986
987 // Flatten offsets and mask to 1D to match the 1D value type.
988 Value distOffsets = adaptor.getOffsets();
989 auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
990 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
991 distOffsetsTy.getElementType());
992 distOffsets = castValueTo(
993 rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
994
995 Value distMask = adaptor.getMask();
996 auto distMaskTy = cast<VectorType>(distMask.getType());
997 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
998 distMaskTy.getElementType());
999 distMask =
1000 castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
1001
1002 Value distDest = adaptor.getDest();
1003 xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
1004 distOffsets, distMask, op.getChunkSizeAttr(),
1005 op.getL1HintAttr(), op.getL2HintAttr(),
1006 op.getL3HintAttr(), /*layout=*/nullptr);
1007 rewriter.eraseOp(op);
1008 return success();
1009 }
1010};
1011
1012/// Distribute a vector::StepOp to workitem-level.
1013/// The layout must have exactly 1 effective lane dimension.
1014/// We completely resolve the vector::StepOp by computing the lane_data-sized
1015/// subranges.
struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
  using OpConversionPattern<vector::StepOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The result layout drives the distribution.
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getResult(0));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "the result vector of the step op lacks subgroup layout");

    auto loc = op.getLoc();
    auto stepResultVecTy = op.getResult().getType();
    auto wiShapeOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, stepResultVecTy);
    if (failed(wiShapeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");
    VectorType newVecTy = wiShapeOrFailure.value();

    // Each lane reconstructs its own values, so the lane id is needed to
    // compute which coordinates of the SG-level vector this lane owns.
    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
                                         /*upperBound=*/mlir::IntegerAttr());
    auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
        rewriter, loc, laneId, stepResultVecTy.getShape());
    if (failed(laneDataBlockCoords))
      return rewriter.notifyMatchFailure(
          op, "failed to compute lane data block coordinates");

    auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
    // Length of one contiguous lane_data block. The pattern doc states the
    // layout has exactly one effective lane dimension, hence index [0].
    auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
    // One coordinate set per lane_data block is expected.
    assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
           newVecTy.getNumElements() / laneDataBlockLength);
    SmallVector<Value> stepVals;
    // For each lane_data block, reconstruct its sub-range
    // from the range of SG-level vector.step. Example: vector.step
    // {slice<layout<lane_layout=[2,4,2], lane_data=[1,2,1]>, dims=[0,2]>} :
    // vector<16xindex>
    // Each logical lane holds 4 elements as 2 blocks of 2 elements each.
    // The blocks are round-robin distributed, so logical lane id 0
    // holds values [0,1, 8,9].
    // NOTE: the loop variable intentionally shadows the FailureOr above; here
    // it denotes one coordinate set per block.
    for (auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
      // A step op's value equals its index, so the block's start coordinate
      // is also its first value; the remaining values are consecutive
      // increments of it.
      auto laneDataBlockStartCoord = laneDataBlockCoords[0];
      stepVals.push_back(laneDataBlockStartCoord);
      for (int i = 1; i < laneDataBlockLength; ++i) {
        auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
        stepVals.push_back(arith::AddIOp::create(
            rewriter, loc, laneDataBlockStartCoord, offset));
      }
    }
    assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
           "Expecting the number of step values to match the number of "
           "elements in the vector");
    // Assemble the per-lane values into the distributed result vector.
    auto stepOpVal =
        vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
    rewriter.replaceOp(op, stepOpVal);
    return success();
  }
};
1075
1076/// Distributes a subgroup-level vector.extract op to workitem-level. Only
1077/// handles sub-vector extraction (result is VectorType, not scalar).
1078struct SgToWiVectorExtract : public OpConversionPattern<vector::ExtractOp> {
1079 using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
1080
1081 LogicalResult
1082 matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
1083 ConversionPatternRewriter &rewriter) const override {
1084 // Only handle vector results (not scalar extraction).
1085 auto resultType = dyn_cast<VectorType>(op.getType());
1086 if (!resultType)
1087 return rewriter.notifyMatchFailure(op, "scalar extract not supported");
1088
1089 xegpu::DistributeLayoutAttr layout =
1090 xegpu::getTemporaryLayout(op->getOpResult(0));
1091 if (!layout || !layout.isForSubgroup())
1092 return failure();
1093
1094 // This implementation assumes distribution only happens on the innermost
1095 // dimension. Verify that lane_layout[0...n-2] are all unit.
1096 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1097 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1098 [](int64_t v) { return v != 1; }))
1099 return rewriter.notifyMatchFailure(
1100 op, "only innermost dimension distribution is supported for "
1101 "vector.extract");
1102
1103 auto newOp = vector::ExtractOp::create(
1104 rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
1105 rewriter.replaceOp(op, newOp.getResult());
1106 return success();
1107 }
1108};
1109
1110/// This pattern distributes a subgroup-level ShapeCast op to workitem-level.
1111struct SgToWiVectorShapeCast : public OpConversionPattern<vector::ShapeCastOp> {
1112 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1113
1114 LogicalResult
1115 matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
1116 ConversionPatternRewriter &rewriter) const override {
1117 xegpu::DistributeLayoutAttr resultLayout =
1118 xegpu::getTemporaryLayout(op->getOpResult(0));
1119 if (!resultLayout || !resultLayout.isForSubgroup())
1120 return rewriter.notifyMatchFailure(
1121 op, "the result vector of the shape_cast op lacks subgroup layout");
1122
1123 auto resultDistTypeOrFailure = xegpu::getDistVecTypeBasedOnLaneLayout(
1124 resultLayout, op.getResultVectorType());
1125 if (failed(resultDistTypeOrFailure))
1126 return rewriter.notifyMatchFailure(
1127 op, "failed to get distributed vector type for result");
1128
1129 Value source = adaptor.getSource();
1130 auto newShapeCast = vector::ShapeCastOp::create(
1131 rewriter, op.getLoc(), resultDistTypeOrFailure.value(), source);
1132 rewriter.replaceOp(op, newShapeCast);
1133 return success();
1134 }
1135};
1136
1137/// Distributes a subgroup-level vector.extract_strided_slice op to
1138/// workitem-level. If the result is distributed, the offsets and sizes are
1139/// adjusted to match the distributed types.
struct SgToWiVectorExtractStridedSlice
    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
  using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return failure();

    VectorType resultType = op.getType();
    auto distResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, resultType);
    if (failed(distResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute distributed vector type from lane layout");
    VectorType distResultTy = *distResultTyOrFailure;

    // Dimensions whose size shrank between the SG-level and distributed
    // result types, i.e. the dims actually split across lanes.
    SmallVector<int64_t> distributedDims =
        getDistributedDims(resultType, distResultTy);

    // Collect updated sizes, offsets, strides. Pad to full source rank.
    int64_t sourceRank = op.getSourceVectorType().getRank();
    SmallVector<Attribute> updatedSizes =
        llvm::map_to_vector(op.getSizes(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
        op.getOffsets(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
        op.getStrides(), [](Attribute attr) { return attr; });
    // Missing trailing dims take the full source extent with offset 0 and
    // stride 1.
    for (int64_t i = op.getSizes().size(); i < sourceRank; ++i) {
      updatedSizes.push_back(
          rewriter.getI64IntegerAttr(op.getSourceVectorType().getDimSize(i)));
      updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
      updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
    }

    // If the result is distributed, adjust offsets and sizes in the
    // distributed dimension.
    if (!distributedDims.empty()) {
      if (distributedDims.size() != 1)
        return rewriter.notifyMatchFailure(
            op, "only single dimension distribution is supported");
      int64_t distDim = distributedDims[0];
      // The subgroup size comes from the target micro-architecture, looked
      // up via the chip string attached to the op (or its ancestors).
      const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
      if (!uArch)
        return rewriter.notifyMatchFailure(
            op, "target attribute required to determine subgroup size");
      int subgroupSize = uArch->getSubgroupSize();
      auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            op, "source of extract_strided_slice lacks distribution layout");
      // The slice arithmetic below is only valid when the distributed extent
      // and the offset divide evenly across the lanes.
      int sourceDistrDimSize = op.getSourceVectorType().getShape()[distDim];
      if (sourceDistrDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "source size along distributed dim is not a multiple of "
                "subgroup size");
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      // Only check lane_data for the distributed dimension. Non-distributed
      // dimensions may have non-unit lane_data (e.g., packed layouts).
      if (distDim < static_cast<int64_t>(sourceLaneData.size()) &&
          sourceLaneData[distDim] != 1)
        return rewriter.notifyMatchFailure(
            op, "expecting unit lane data along the distributed dimension");
      int64_t distrDimOffset =
          cast<IntegerAttr>(updatedOffsets[distDim]).getInt();
      if (distrDimOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "offset along distributed dim is not a multiple of "
                "subgroup size");
      // Adjust sizes and offsets for the distributed dimension.
      updatedSizes[distDim] =
          rewriter.getI64IntegerAttr(distResultTy.getDimSize(distDim));
      updatedOffsets[distDim] =
          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
    }

    // Recreate the extract on the distributed source with adjusted bounds.
    auto newOp = vector::ExtractStridedSliceOp::create(
        rewriter, op.getLoc(), distResultTy, adaptor.getSource(),
        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
        ArrayAttr::get(rewriter.getContext(), updatedSizes),
        ArrayAttr::get(rewriter.getContext(), updatedStrides));
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
1228
1229/// This pattern distributes a subgroup-level `vector.broadcast` op to
1230/// workitem-level. The pattern supports three cases:
1231///
1232/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
1233/// vector must have a slice layout of the result. If the distributed source
1234/// and target vector types are identical, this lowers to a no-op; otherwise,
1235/// it remains a broadcast but operates on distributed vectors.
1236///
1237/// 2) Broadcast a same-rank vector with identical layouts for source and
1238/// target: The source vector must have unit dimensions, and lane_data must
1239/// be unit size for those unit dims. This always lowers to a no-op.
1240///
1241/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast
1242/// from scalar to distributed result type.
1243///
1244/// Example 1 (low-rank to high-rank broadcast):
1245/// ```
1246/// %0 = "some_op"() {layout_result_0 =
1247/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
1248/// dims = [0]>} : () -> vector<16xf16>
1249/// %1 = vector.broadcast %0 {layout_result_0 =
1250/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
1251/// : vector<16xf16> to vector<16x16xf16>
1252/// ```
1253/// is distributed to:
1254/// ```
1255/// %0 = "some_op"() : () -> vector<1xf16>
1256/// %1 = vector.broadcast %0 : vector<1xf16> to vector<16x1xf16>
1257/// ```
1258///
1259/// Example 2 (same-rank broadcast, no-op):
1260/// ```
1261/// %0 = "some_op"() {layout_result_0 =
1262/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
1263/// : () -> vector<16x1xf16>
1264/// %1 = vector.broadcast %0 {layout_result_0 =
1265/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
1266/// : vector<16x1xf16> to vector<16x16xf16>
1267/// ```
1268/// is distributed to (no-op, source already matches distributed result type):
1269/// ```
1270/// %0 = "some_op"() : () -> vector<16x1xf16>
1271/// // broadcast is eliminated, %0 is used directly
1272/// ```
1273///
1274/// Example 3 (scalar to vector broadcast):
1275/// ```
1276/// %0 = "some_op"() : () -> f16
1277/// %1 = vector.broadcast %0 {layout_result_0 =
1278/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
1279/// : f16 to vector<16x16xf16>
1280/// ```
1281/// is distributed to:
1282/// ```
1283/// %0 = "some_op"() : f16
1284/// %1 = vector.broadcast %0 : f16 to vector<16x1xf16>
1285/// ```
struct SgToWiBroadcast : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(cast<OpResult>(op.getResult()));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "result does not have subgroup distribute layout");

    VectorType destType = op.getResultVectorType();
    // Null when the broadcast source is a scalar (Case 3).
    VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());

    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getTemporaryLayout(op->getOpOperand(0));

    if (sourceType) {
      int64_t rankDiff = destType.getRank() - sourceType.getRank();
      if (rankDiff > 0) {
        // Case 1: Low-rank to high-rank broadcast. The source layout must be
        // a slice of the result layout; a mismatch is surfaced as a warning
        // rather than a match failure.
        if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
          op.emitWarning(
              "broadcast source layout must be a slice of result layout");
      } else if (rankDiff == 0) {
        // Case 2: Same-rank broadcast.
        auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
        SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
                                               broadcastUnitDimsSet.end());
        // NOTE(review): sourceLayout is dereferenced unconditionally here; a
        // same-rank broadcast whose source lacks a layout would fault —
        // confirm the layout-propagation pass guarantees it is present.
        assert(sourceLayout.isEqualTo(
                   sourceLayout.setUnitDimData(broadcastUnitDims)) &&
               "The sg_data for unit dimensions should be set as 1");
        sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
      }
    } else {
      // Case 3: Scalar to vector broadcast.
      if (sourceLayout)
        return rewriter.notifyMatchFailure(
            op, "broadcast from scalar must not have a layout attribute");
    }

    auto destDistType =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
    if (failed(destDistType))
      return rewriter.notifyMatchFailure(
          op, "failed to distribute the result vector type");

    Value source = adaptor.getSource();
    // If the adapted source already matches the dest dist type, it's a no-op.
    if (source.getType() == destDistType.value()) {
      rewriter.replaceOp(op, source);
      return success();
    }

    // Otherwise keep the broadcast, but on distributed types.
    auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                             destDistType.value(), source);
    rewriter.replaceOp(op, newOp);
    return success();
  }
};
1347
1348/// Distributes a subgroup-level vector.insert_strided_slice op to
1349/// workitem-level. If the dest is distributed, the offsets are adjusted to
1350/// match the distributed types.
struct SgToWiVectorInsertStridedSlice
    : public OpConversionPattern<vector::InsertStridedSliceOp> {
  using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return failure();

    VectorType destType = op.getDestVectorType();
    auto distDestTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
    if (failed(distDestTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute distributed vector type from lane layout");
    VectorType distDestTy = *distDestTyOrFailure;

    // Dimensions whose size shrank between the SG-level and distributed dest
    // types, i.e. the dims actually split across lanes.
    SmallVector<int64_t> destDistributedDims =
        getDistributedDims(destType, distDestTy);

    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
        op.getOffsets(), [](Attribute attr) { return attr; });

    if (!destDistributedDims.empty()) {
      if (destDistributedDims.size() != 1)
        return rewriter.notifyMatchFailure(
            op, "only single dimension distribution is supported");
      int64_t destDistDim = destDistributedDims[0];

      // The subgroup size comes from the target micro-architecture, looked
      // up via the chip string attached to the op (or its ancestors).
      const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
      if (!uArch)
        return rewriter.notifyMatchFailure(
            op, "target attribute required to determine subgroup size");
      int subgroupSize = uArch->getSubgroupSize();

      VectorType srcType = op.getSourceVectorType();
      // The distributed dim must be in the last k (source rank) dims of dest.
      int64_t sourceDistDim =
          destDistDim - (destType.getRank() - srcType.getRank());
      if (sourceDistDim < 0)
        return rewriter.notifyMatchFailure(
            op, "distributed dimension must be in the last k dims of dest");

      // Operand 0 is the value to store, operand 1 the destination.
      auto destLayout = xegpu::getTemporaryLayout(op->getOpOperand(1));
      auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
      if (!destLayout || !sourceLayout ||
          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
          sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            op, "source or dest of insert_strided_slice lacks distribution "
                "layout");

      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      // Only check lane_data for the distributed dimension. Non-distributed
      // dimensions may have non-unit lane_data (e.g., packed layouts).
      if ((destDistDim < static_cast<int64_t>(destLaneData.size()) &&
           destLaneData[destDistDim] != 1) ||
          (sourceDistDim < static_cast<int64_t>(sourceLaneData.size()) &&
           sourceLaneData[sourceDistDim] != 1))
        return rewriter.notifyMatchFailure(
            op, "expecting unit lane data along the distributed dimension");

      // The offset arithmetic below is only valid when both the source extent
      // and the insertion offset divide evenly across the lanes.
      int64_t srcDistrDimSize = srcType.getDimSize(sourceDistDim);
      if (srcDistrDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "source distributed dim size is not a multiple of "
                "subgroup size");

      int64_t destDistrDimOffset =
          cast<IntegerAttr>(op.getOffsets()[destDistDim]).getInt();
      if (destDistrDimOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "offset along distributed dim is not a multiple of "
                "subgroup size");
      // Adjust offset for the distributed dimension.
      updatedOffsets[destDistDim] =
          rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
    }

    // Recreate the insert on the distributed operands with adjusted offsets;
    // strides are unchanged.
    auto newOp = vector::InsertStridedSliceOp::create(
        rewriter, op.getLoc(), distDestTy, adaptor.getValueToStore(),
        adaptor.getDest(),
        ArrayAttr::get(rewriter.getContext(), updatedOffsets), op.getStrides());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
1442
1443/// Distributes a subgroup-level vector.insert op to workitem-level. Only
1444/// handles sub-vector insertion (value to store is VectorType, not scalar).
1445struct SgToWiVectorInsert : public OpConversionPattern<vector::InsertOp> {
1446 using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
1447
1448 LogicalResult
1449 matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
1450 ConversionPatternRewriter &rewriter) const override {
1451 // Only handle vector value-to-store (not scalar insertion).
1452 auto valueType = dyn_cast<VectorType>(op.getValueToStoreType());
1453 if (!valueType)
1454 return rewriter.notifyMatchFailure(op, "scalar insert not supported");
1455
1456 xegpu::DistributeLayoutAttr layout =
1457 xegpu::getTemporaryLayout(op->getOpResult(0));
1458 if (!layout || !layout.isForSubgroup())
1459 return failure();
1460
1461 // verify that the outer k dimensions (for offsets)
1462 // don't have non-unit lane_layout.
1463 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1464 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1465 [](int64_t v) { return v != 1; }))
1466 return rewriter.notifyMatchFailure(
1467 op, "only innermost dimension distribution is supported for "
1468 "vector.insert");
1469
1470 auto newOp = vector::InsertOp::create(
1471 rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
1472 op.getMixedPosition());
1473 rewriter.replaceOp(op, newOp.getResult());
1474 return success();
1475 }
1476};
1477
1478/// Folds a subgroup-level ConvertLayout op with compatible lane layouts.
struct SgToWiConvertLayout
    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputLayout = op.getInputLayoutAttr();
    auto targetLayout = op.getTargetLayoutAttr();
    Type valType = op.getResult().getType();

    // A layout conversion on a scalar carries no data movement; forward the
    // source unchanged.
    if (valType.isIntOrFloat()) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }

    auto resShape = cast<VectorType>(valType).getShape();
    SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
    // If the two layouts yield the same per-lane distribution for this shape,
    // the convert is a no-op at workitem level and folds to its source.
    // NOTE(review): the trailing argument(s) of this isCompatibleWith call
    // are not visible in this excerpt — confirm against the full source.
    if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
      return rewriter.notifyMatchFailure(
          op, "lowering incompatible convert_layout not yet supported");
    }

    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};
1507
/// Pass driver for the experimental XeGPU subgroup-to-workitem distribution.
/// Derives from the tablegen-generated base (see GEN_PASS_DEF above); all of
/// the work happens in runOnOperation(), defined out-of-line below.
1508struct XeGPUSgToWiDistributeExperimentalPass
1509 : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
1510 XeGPUSgToWiDistributeExperimentalPass> {
1511 void runOnOperation() override;
1512};
1513
1514} // namespace
1515
/// Pass entry point. Phases, in order:
///   1. recover temporary layout attributes needed by the patterns;
///   2. run a partial dialect conversion that rewrites ops to WI types,
///      materializing SG<->WI bridges as unrealized_conversion_cast ops;
///   3. collapse redundant back-to-back cast pairs into vector.shape_cast;
///   4. erase dead casts introduced by this pass (pre-existing casts are
///      preserved);
///   5. strip the temporary layout attributes.
1516void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
1517
1518 // Recover temporary operand layouts for usage in patterns.
1519 Operation *root = getOperation();
1520 if (!xegpu::recoverTemporaryLayouts(root)) {
1521 signalPassFailure();
1522 return;
1523 }
1524
1525 // Collect existing UnrealizedConversionCastOps. These must be preserved.
// All cleanup loops below consult this set so that only casts introduced by
// this pass are folded or erased.
1526 llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
1527 root->walk(
1528 [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
1529 // Perform a structural type conversion to convert structural ops to have WI
1530 // types. This will insert UnrealizedConversionCastOps to make the IR
1531 // valid.
// Materialization callback shared by source and target materializations:
// bridge between SG-typed and WI-typed values with a 1-result
// unrealized_conversion_cast (cleaned up after conversion).
1532 auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
1533 mlir::ValueRange inputs,
1534 mlir::Location loc) -> mlir::Value {
1535 UnrealizedConversionCastOp castOp =
1536 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1537 return castOp.getResult(0);
1538 };
1539 {
1540 ConversionTarget target(getContext());
1541 TypeConverter typeConverter;
1542 RewritePatternSet patterns(&getContext());
1543 typeConverter.addSourceMaterialization(materializeCast);
1544 typeConverter.addTargetMaterialization(materializeCast);
// NOTE(review): original lines 1545-1546 and 1548 are missing from this
// listing. Judging by the dangling argument lists below and the reference
// index, they presumably are calls to
// scf::populateSCFStructuralTypeConversionsAndLegality and
// xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality — confirm
// against the upstream source.
1547 patterns, target);
1549 typeConverter, patterns, target);
1550 target.addLegalOp<UnrealizedConversionCastOp>();
// Partial conversion: unmatched ops remain untouched, and the result is
// deliberately discarded — the cleanup phases and attribute removal below
// still run regardless of conversion success.
1551 (void)applyPartialConversion(root, target, std::move(patterns));
1552 }
1553 // Structural type conversion can generate some redundant
1554 // UnrealizedConversionCastOps to materialize the SG type from type converted
1555 // WI type. These are redundant at this point and can be eliminated by
1556 // inserting shape casts instead.
1557 // Example:
1558 // %1 = UnrealizedConversionCastOp %0 : vector<16x1xf32> to vector<16x16xf32>
1559 // %2 = UnrealizedConversionCastOp %1 : vector<16x16xf32> to vector<16xf32>
1560 // This can be replaced with:
1561 // %2 = vector.shape_cast %0 : vector<16x1xf32> to vector<16xf32>
1562 OpBuilder builder(root);
1563 root->walk([&](UnrealizedConversionCastOp op) {
1564 // If this op existed before, nothing to do.
1565 if (existingCasts.contains(op))
1566 return;
1567 // number of inputs and outputs must be 1.
1568 if (op.getNumOperands() != 1 || op.getNumResults() != 1)
1569 return;
1570 // Both input and output types must be vector types.
1571 auto singleInput = op.getInputs()[0];
1572 auto inputTy = dyn_cast<VectorType>(singleInput.getType());
1573 auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
1574 if (!inputTy || !outputTy)
1575 return;
1576
1577 // Check if the defining op of the input is also an
1578 // UnrealizedConversionCastOp and it has a single user (which is this
1579 // op).
1580 auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
1581 if (!definingOp || !definingOp->hasOneUse())
1582 return;
1583 auto inputOfDefiningOp = definingOp.getInputs()[0];
1584 // If the input of the defining op and output type are both vector types
1585 // have same number of elements, insert a shape cast.
1586 auto inputOfDefiningOpTy =
1587 dyn_cast<VectorType>(inputOfDefiningOp.getType());
1588 if (inputOfDefiningOpTy &&
1589 inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
1590 builder.setInsertionPoint(op);
1591 auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
1592 outputTy, inputOfDefiningOp);
// Only the uses are rewired here; the now-dead cast pair is removed by the
// dead-cast sweep that follows.
1593 op.replaceAllUsesWith(ValueRange{shapeCast.getResult()});
1594 return;
1595 }
1596 });
1597 // At this point, we will have some dead UnrealizedConversionCastOps. Just
1598 // erase them.
// Iterate to a fixed point: erasing one dead cast can make its producer
// cast dead on the next sweep.
1599 bool changed = true;
1600 while (changed) {
1601 changed = false;
1602 root->walk([&](UnrealizedConversionCastOp op) {
1603 // Skip existing casts.
1604 if (existingCasts.contains(op))
1605 return;
1606 if (op.use_empty()) {
1607 op.erase();
1608 changed = true;
1609 }
1610 });
1611 }
1612
// Drop the temporary layout attributes attached by recoverTemporaryLayouts.
1613 xegpu::removeTemporaryLayoutAttrs(getOperation());
1614}
1615
// NOTE(review): original line 1616 — the signature line, per the reference
// index "void xegpu::populateXeGPUSgToWiDistributeTypeConversions(" — is
// missing from this listing; the definition starts with its continuation
// line below. The conversions are registered most-generic first; MLIR's
// TypeConverter tries them in reverse registration order, so the more
// specific value-aware and TensorDescType conversions take precedence.
1617 TypeConverter &typeConverter) {
1618 // Any type other than TensorDescType and VectorType are legal as is.
1619 typeConverter.addConversion([](Type type) -> std::optional<Type> {
1620 if (!isa<TensorDescType, VectorType>(type))
1621 return type;
// std::nullopt defers TensorDescType/VectorType to the conversions below.
1622 return std::nullopt;
1623 });
1624 // For TensorDescType, drop the layout attribute if any.
1625 typeConverter.addConversion([](TensorDescType type) -> Type {
1626 if (type.getLayoutAttr()) {
1627 return type.dropLayouts();
1628 }
1629 return type;
1630 });
1631 // For VectorType, check if there is a distribute layout attribute on the
1632 // value. If so, convert to the distributed vector type based on the layout.
1633 typeConverter.addConversion([](Value v) -> std::optional<Type> {
1634 auto type = v.getType();
1635 // If value is not vector type, nothing to do.
1636 if (!isa<VectorType>(type))
1637 return std::nullopt;
// Vectors without a subgroup-level distribute layout keep their type.
1638 auto layout = xegpu::getDistributeLayoutAttr(v);
1639 if (!layout || !layout.isForSubgroup())
1640 return type;
1641 // Vector type is distributed based on lane layout.
1642 auto newTyOrFailure =
1643 getDistVecTypeBasedOnLaneLayout(layout, cast<VectorType>(type));
// If distribution fails, keep the original type instead of failing the
// whole type conversion.
1644 if (failed(newTyOrFailure))
1645 return type;
1646 return *newTyOrFailure;
1647 });
1648}
1649
// NOTE(review): original lines 1650 and 1652-1653 are missing from this
// listing. Per the reference index, 1650 is the signature line
// "void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(" and
// 1652 presumably closes the parameter list with "ConversionTarget &target) {"
// — confirm against the upstream source.
//
// This function marks ops as dynamically legal/illegal for the SG-to-WI
// conversion (an op is *illegal*, i.e. must be rewritten, exactly when it
// still carries a layout that the patterns know how to distribute) and then
// registers all SgToWi* rewrite patterns.
1651 TypeConverter &typeConverter, RewritePatternSet &patterns,
1654 // CreateNdDescOp is legal only if its result type has no layout attribute.
1655 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
1656 [&](xegpu::CreateNdDescOp op) { return !op.getType().getLayoutAttr(); });
1657 // Any anchor XeGPU op is legal only if it has no anchor layout.
1658 target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
// Ops not implementing AnchorLayoutInterface are unconditionally legal.
1659 auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
1660 if (!anchorOp)
1661 return true;
1662 return !anchorOp.getAnchorLayout();
1663 });
1664 // Arith constants are legal only if they have no temporary layout attribute.
1665 target.addDynamicallyLegalOp<arith::ConstantOp>(
1666 [=](arith::ConstantOp op) -> bool {
1667 // If the result type is not a vector, it's legal.
1668 if (!isa<VectorType>(op.getResult().getType()))
1669 return true;
1670 return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1671 });
1672 // In math and arith dialects, only handle elementwise ops with a single
1673 // result and with a result layout attribute.
1674 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1675 [=](Operation *op) -> std::optional<bool> {
1676 // Only handle elementwise mappable ops
// NOTE(review): original line 1677 is missing; given the comment above, the
// "return true;" below, and the index entry for
// OpTrait::hasElementwiseMappableTraits, it is presumably the guard
// "if (!OpTrait::hasElementwiseMappableTraits(op))" — confirm upstream.
1678 return true;
1679 // Only handle ops with single vector result
1680 if (op->getNumResults() != 1)
1681 return true;
1682
1683 VectorType resultType =
1684 dyn_cast<VectorType>(op->getResult(0).getType());
1685 if (!resultType)
1686 return true;
1687
1688 // Check if all operands are vectors of the same shape
// Mixed scalar/vector or shape-mismatched elementwise ops are left alone.
1689 for (Value operand : op->getOperands()) {
1690 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1691 if (!operandType || operandType.getShape() != resultType.getShape()) {
1692 return true;
1693 }
1694 }
// Illegal (to be rewritten) only when a temporary layout is present.
1695 return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1696 });
1697 // vector::ReductionOp is legal only if its source has no distribute layout
1698 // attribute.
1699 target.addDynamicallyLegalOp<vector::ReductionOp>(
1700 [=](vector::ReductionOp op) -> bool {
1701 auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
1702 return !layout;
1703 });
1704 // vector::MultiDimReductionOp op legality.
// Legal iff it is NOT in the "expected form" the SgToWiMultiDimReduction
// pattern handles (see isValidSubgroupMultiReductionOp).
1705 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
1706 [=](vector::MultiDimReductionOp op) -> bool {
1707 return !isValidSubgroupMultiReductionOp(op);
1708 });
// Misc vector ops: legal unless their result carries a temporary layout.
1709 target.addDynamicallyLegalOp<vector::CreateMaskOp, vector::ConstantMaskOp,
1710 vector::TransposeOp, vector::BitCastOp,
1711 vector::ShapeCastOp, vector::StepOp,
1712 vector::BroadcastOp>([=](Operation *op) -> bool {
1713 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1714 });
1715 target.addDynamicallyLegalOp<vector::ExtractOp>(
1716 [=](vector::ExtractOp op) -> bool {
// Scalar extracts are always legal; only vector results are distributed.
1717 if (!isa<VectorType>(op.getType()))
1718 return true;
1719 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1720 });
1721 target.addDynamicallyLegalOp<vector::InsertOp>(
1722 [=](vector::InsertOp op) -> bool {
1723 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1724 });
1725 target.addDynamicallyLegalOp<vector::ExtractStridedSliceOp>(
1726 [=](vector::ExtractStridedSliceOp op) -> bool {
1727 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1728 });
1729 target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
1730 [=](vector::InsertStridedSliceOp op) -> bool {
1731 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1732 });
// Everything not covered above is legal by default.
1733 target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
// Register every SG-to-WI distribution pattern defined in this file.
1734 patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
1735 SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
1736 SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
1737 SgToWiMultiDimReduction, SgToWiVectorExtract, SgToWiVectorInsert,
1738 SgToWiVectorExtractStridedSlice, SgToWiVectorInsertStridedSlice,
1739 SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout,
1740 SgToWiVectorTranspose, SgToWiVectorBitcast, SgToWiVectorStep,
1741 SgToWiVectorShapeCast, SgToWiBroadcast,
1742 SgToWiCreateMask<vector::CreateMaskOp>,
1743 SgToWiCreateMask<vector::ConstantMaskOp>>(typeConverter,
1744 patterns.getContext());
1745}
return success()
b getContext())
Attributes are known-constant values of operations.
Definition Attributes.h:25
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:268
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
const uArch * getUArch(llvm::StringRef archName)
bool requirePacked(const DistributeLayoutAttr layout)
Helper function to check if the layout is packed.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void populateXeGPUSgToWiDistributeTypeConversions(TypeConverter &typeConverter)
Define only the type conversions needed for XeGPU subgroup to workitem distribution.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc argumments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requireTranspose(const DistributeLayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
void populateXeGPUSgToWiDistributeTypeConversionAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Defines type conversions and legality for XeGPU subgroup to workitem distribution and appends the req...
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
Value lowerCrossLaneReductionToShuffles(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter)
Lowers cross-lane reductions to shuffle operations on a 2D vector.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
virtual int getSubgroupSize() const =0