MLIR 23.0.0git
XeGPUSgToLaneDistribute.cpp
Go to the documentation of this file.
1//===- XeGPUSgToLaneDistribute.cpp - XeGPU SG to Lane Pass ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
20#include "mlir/IR/Builders.h"
22#include "mlir/IR/BuiltinOps.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Operation.h"
26#include "mlir/IR/Value.h"
27#include "mlir/IR/ValueRange.h"
29#include "llvm/ADT/SetVector.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/raw_ostream.h"
32#include <optional>
33
34namespace mlir {
35namespace xegpu {
36#define GEN_PASS_DEF_XEGPUSGTOLANEDISTRIBUTE
37#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
38} // namespace xegpu
39} // namespace mlir
40
41using namespace mlir;
42
43#define DEBUG_TYPE "xegpu-sg-to-lane-distribute"
44#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
45
46namespace {
47
48/// Casts the given vector value `v` to the expected vector type `expectedTy`.
49static Value castValueTo(ConversionPatternRewriter &rewriter,
50 TypedValue<VectorType> v, VectorType expectedTy) {
51 // If the type matches, simply return the value itself.
52 if (v.getType() == expectedTy)
53 return v;
54 // If only shape differs, use shape cast.
55 if (isa<VectorType>(v.getType()) &&
56 v.getType().getNumElements() == expectedTy.getNumElements())
57 return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
58
59 // Else create an unrealized cast.
60 auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
61 expectedTy, ValueRange{v});
62 return newOp.getResult(0);
63}
64
65/// A vector::MultiDimReductionOp at subgroup level in expected form if, it has
66/// exactly 1 reduction dimension, it had valid result layout attribute, and
67/// result type can be distributed to lanes using the layout.
68static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
69 auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
70 // If no layout, not valid.
71 if (!resLayout || !resLayout.isForSubgroup())
72 return false;
73 // Scalar result (e.g., vector<32xf32> to f32) is valid.
74 if (op.getType().isIntOrFloat())
75 return op.getReductionDims().size() == 1;
76 VectorType resTy = dyn_cast<VectorType>(op.getType());
77 if (!resTy)
78 return false;
79 // Compute the distributed result vector type based on the layout.
80 FailureOr<VectorType> resDistTypeOrFailure =
81 getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
82 if (failed(resDistTypeOrFailure))
83 return false;
84 return op.getReductionDims().size() == 1;
85}
86
87/// A vector::MultiDimReductionOp is doing lane-local reduction if each lane
88/// is doing its own local reduction. In this case the result layout ensures
89/// that result vector is distributed to lanes, i.e. the result vector type is
90/// different from the distributed result vector type.
91static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
92 // Must be valid MultiDimReductionOp.
93 assert(isValidSubgroupMultiReductionOp(op) && "Expecting a valid subgroup "
94 "MultiDimReductionOp");
95 auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
96 VectorType resTy = dyn_cast<VectorType>(op.getType());
97 auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
98 return resTy != resDistTypeOrFailure.value();
99}
100
101/// Given a vector type and its distributed vector type, return the list of
102/// dimensions that are distributed.
103static SmallVector<int64_t> getDistributedDims(VectorType originalType,
104 VectorType distributedType) {
105 assert(originalType.getRank() == distributedType.getRank() &&
106 "original and distributed vector types must have the same rank");
107 SmallVector<int64_t> distributedDims;
108 for (int64_t i = 0; i < originalType.getRank(); ++i) {
109 if (distributedType.getDimSize(i) != originalType.getDimSize(i))
110 distributedDims.push_back(i);
111 }
112 return distributedDims;
113}
114
115/// Distributes a subgroup-level CreateNdDesc op to lane-level CreateNdDesc
116/// op. This simply drops the layout attribute from the tensor descriptor type.
117struct SgToLaneCreateNdDesc
118 : public OpConversionPattern<xegpu::CreateNdDescOp> {
119 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
120
121 LogicalResult
122 matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
123 ConversionPatternRewriter &rewriter) const override {
124 xegpu::TensorDescType resultType = op.getType();
125 // If no layout, nothing to do.
126 if (!resultType.getLayout())
127 return failure();
128
129 auto newOp = xegpu::CreateNdDescOp::create(
130 rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
131 op->getAttrs());
132 rewriter.replaceOp(op, newOp.getResult());
133 return success();
134 }
135};
136
137/// Distributes a subgroup-level LoadNd op to lane-level LoadNd op. Output
138/// of lane-level LoadNd op is 1D. ShapeCast is added to restore the
139/// original rank.
140struct SgToLaneLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
141 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
142
143 LogicalResult
144 matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
145 ConversionPatternRewriter &rewriter) const override {
146 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
147 // If no layout, nothing to do.
148 if (!layout)
149 return failure();
150 // Check if the layout attached to the tensor descriptor is same as the
151 // anchor layout. Otherwise, this is a conflict.
152 if (op.getTensorDescType().getLayout() != layout)
153 return rewriter.notifyMatchFailure(
154 op, "conflicting layout attributes on tensor descriptor and anchor");
155 auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
156 if (!uArch)
157 return rewriter.notifyMatchFailure(
158 op, "xegpu::LoadNdOp require target attribute attached to "
159 "determine transpose "
160 "requirement");
161 auto supportedLaneResultTyOrFailure =
162 xegpu::getDistributedVectorType(op.getTensorDescType());
163 auto expectedLaneResultTyOrFailure =
164 xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
165 if (failed(supportedLaneResultTyOrFailure))
166 return rewriter.notifyMatchFailure(
167 op, "unable to compute the lane vector type for LoadNdOp");
168 if (failed(expectedLaneResultTyOrFailure))
169 return rewriter.notifyMatchFailure(
170 op, "unable to compute expected lane vector type from lane layout");
171 auto newOp = xegpu::LoadNdOp::create(
172 rewriter, op.getLoc(), supportedLaneResultTyOrFailure.value(),
173 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
174 op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
175 op.getL3HintAttr(), /**layout**/ nullptr);
176 // Set the packed attribute if the layout requires it.
177 newOp.setPacked(xegpu::requirePacked(cast<xegpu::LayoutAttr>(layout)));
178 // Set the transpose attribute if the layout requires it.
179 if (xegpu::requireTranspose(cast<xegpu::LayoutAttr>(layout), uArch))
180 newOp.setTranspose(DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
181 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
182 expectedLaneResultTyOrFailure.value()));
183 return success();
184 }
185};
186
187/// Distributes a subgroup-level StoreNd op to lane-level StoreNd op. Stored
188/// value in lane-level StoreNd op is 1D. ShapeCast is added to cast the
189/// incoming value to 1D.
190struct SgToLaneStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
191 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
192
193 LogicalResult
194 matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
195 ConversionPatternRewriter &rewriter) const override {
196 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
197 // If no layout, nothing to do.
198 if (!layout)
199 return failure();
200 // Check if the layout attached to the tensor descriptor and value layout is
201 // same as the anchor layout. Otherwise, this is a conflict.
202 if (op.getTensorDescType().getLayout() != layout)
203 return rewriter.notifyMatchFailure(
204 op, "conflicting layout attributes on tensor descriptor and anchor");
205 auto valueLayout = xegpu::getDistributeLayoutAttr(op->getOpOperand(0));
206 if (valueLayout != layout)
207 return rewriter.notifyMatchFailure(
208 op, "conflicting layout attributes on value and anchor");
209 auto supportedLaneValueTyOrFailure =
210 xegpu::getDistributedVectorType(op.getTensorDescType());
211 if (failed(supportedLaneValueTyOrFailure))
212 return rewriter.notifyMatchFailure(
213 op,
214 "unable to compute lane vector type for StoreNdOp value from tensor "
215 "descriptor");
216
217 xegpu::StoreNdOp::create(
218 rewriter, op.getLoc(),
219 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getValue()),
220 supportedLaneValueTyOrFailure.value()),
221 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
222 op.getL2HintAttr(), op.getL3HintAttr(), /**layout**/ nullptr);
223 rewriter.eraseOp(op);
224 return success();
225 }
226};
227
228/// Distributes a subgroup-level Dpas op to lane-level Dpas op. All inputs
229/// and output of lane-level Dpas op are 1D. Necessary casts are added to
230/// convert the inputs and output to/from 1D.
///
/// The lane-level operand/result types come from getDistributedVectorType
/// applied to the A, B and CD layouts; the replacement value is cast to the
/// type obtained from getDistVecTypeBasedOnLaneLayout for the CD layout.
/// When a uArch is available and exposes the matrix-multiply instruction, the
/// total bit width of each lane's A and B operands is checked against the
/// instruction's packed-format bit size.
231struct SgToLaneDpas : public OpConversionPattern<xegpu::DpasOp> {
232 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
233
234 LogicalResult
235 matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
236 ConversionPatternRewriter &rewriter) const override {
237 // Check if the op has A, B and CD layouts attached.
 // NOTE(review): llvm::cast<> asserts on a null input, so an op without one
 // of these layout attributes would trip the cast before the null check
 // below can return failure(). dyn_cast_if_present<> looks like the intended
 // form - confirm.
238 auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
239 auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
240 auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
241 if (!layoutA || !layoutB || !layoutCd)
242 return failure();
 // Lane-level types used to build the new Dpas op (result, A, B).
243 auto laneResultTyOrFailure =
244 xegpu::getDistributedVectorType(op.getType(), layoutCd);
245 auto laneATypeOrFailure =
246 xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
247 auto laneBTypeOrFailure =
248 xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
 // Type the replacement value is cast to, derived from the CD lane layout.
249 auto expectedLaneResultTyOrFailure =
250 xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
251 if (failed(laneResultTyOrFailure) || failed(laneATypeOrFailure) ||
252 failed(laneBTypeOrFailure))
253 return rewriter.notifyMatchFailure(
254 op, "failed to calculate supported lane vector types for DpasOp "
255 "from layouts");
256 if (failed(expectedLaneResultTyOrFailure))
257 return rewriter.notifyMatchFailure(
258 op, "unable to compute expected lane vector type for DpasOp from "
259 "lane layout");
260
261 // Validate bit widths match uArch packed format requirements
 // The validation is skipped entirely when no uArch is found for the op.
262 const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
263 if (uArch) {
264 const auto *uArchInstruction =
265 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
 // NOTE(review): the embedded source numbering jumps from 265 to 268 here;
 // the argument of the dyn_cast above is not visible in this listing.
268 if (uArchInstruction) {
269 auto laneAType = laneATypeOrFailure.value();
270 auto laneBType = laneBTypeOrFailure.value();
271 // Calculate total packed bit width = element bit width * vector size
272 unsigned aPackedBitWidth =
273 laneAType.getElementTypeBitWidth() * laneAType.getNumElements();
274 unsigned bPackedBitWidth =
275 laneBType.getElementTypeBitWidth() * laneBType.getNumElements();
276 unsigned expectedABitSize = uArchInstruction->getPackedFormatBitSizeA();
277 unsigned expectedBBitSize = uArchInstruction->getPackedFormatBitSizeB();
278
 // Each operand's per-lane bit width must be a whole multiple of the
 // instruction's packed-format size.
279 if (aPackedBitWidth % expectedABitSize != 0)
280 return rewriter.notifyMatchFailure(
281 op,
282 "A operand packed bit width must be a multiple of uArch packed "
283 "format requirement");
284 if (bPackedBitWidth % expectedBBitSize != 0)
285 return rewriter.notifyMatchFailure(
286 op,
287 "B operand packed bit width must be a multiple of uArch packed "
288 "format requirement");
289 }
 // NOTE(review): the embedded source numbering jumps from 289 to 291; the
 // brace closing `if (uArch)` is presumably on the missing line - confirm.
291
 // Build the lane-level Dpas: every operand is cast to its lane-level type
 // and the new op carries no layout attributes.
292 auto newOp = xegpu::DpasOp::create(
293 rewriter, op->getLoc(), laneResultTyOrFailure.value(),
294 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getLhs()),
295 laneATypeOrFailure.value()),
296 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getRhs()),
297 laneBTypeOrFailure.value()),
298 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getAcc()),
299 laneResultTyOrFailure.value()),
300 /** layoutA**/ nullptr,
301 /** layoutB**/ nullptr, /** layoutCd**/ nullptr);
302 // Cast the new result to the type derived from the CD lane layout before
 // replacing the original op.
303 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
304 expectedLaneResultTyOrFailure.value()));
305 return success();
306 }
307};
309/// Distributes elementwise ops to lane-level elementwise ops. This
310/// currently handles elementwise ops with single result only.
///
/// The matched op is recreated under the same operation name with the
/// converted operands, a result type obtained from
/// getDistVecTypeBasedOnLaneLayout, and every attribute of the original op
/// except those whose value is a DistributeLayoutAttr.
311struct SgToLaneElementWise : public ConversionPattern {
 // NOTE(review): `typeConverter` is accepted but not forwarded to the
 // ConversionPattern base in the visible initializer - confirm this is
 // intentional.
312 SgToLaneElementWise(TypeConverter &typeConverter, MLIRContext *ctx)
313 : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
315 LogicalResult
316 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
317 ConversionPatternRewriter &rewriter) const override {
318 // Only match ops with elementwise trait and single result.
 // NOTE(review): the embedded source numbering jumps from 318 to 320; the
 // condition guarding the `return failure()` below is not visible in this
 // listing.
320 return failure();
321
 // The single result must be a vector.
322 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
323 if (!resultType)
324 return rewriter.notifyMatchFailure(
325 op, "operation result is not a vector type");
326
 // The result must carry a temporary layout that is a subgroup layout.
327 xegpu::DistributeLayoutAttr layout =
328 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
329 if (!layout || !layout.isForSubgroup())
330 return rewriter.notifyMatchFailure(
331 op, "operation result does not have subgroup distribute layout");
332
 // Lane-level result type derived from the layout.
333 auto laneShapeOrFailure =
334 xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
335
336 if (failed(laneShapeOrFailure))
337 return rewriter.notifyMatchFailure(
338 op, "unable to compute lane vector type from the layout");
 // Rebuild the op generically through an OperationState, since the concrete
 // op type is not known to this pattern.
340 VectorType newResultType = laneShapeOrFailure.value();
341 OperationState state(op->getLoc(), op->getName());
342 state.addOperands(operands);
343 state.addTypes(newResultType);
344 // Copy all attributes except for DistributeLayoutAttr.
345 for (auto attr : op->getAttrs()) {
346 if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
347 state.addAttribute(attr.getName(), attr.getValue());
348 }
349 Operation *newOp = rewriter.create(state);
350
351 rewriter.replaceOp(op, newOp->getResult(0));
352 return success();
353 }
354};
355
356/// Distributes a subgroup-level arith ConstantOp to lane-level arith
357/// ConstantOp.
358struct SgToLaneArithConstant : public OpConversionPattern<arith::ConstantOp> {
359 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
360
361 LogicalResult
362 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter) const override {
364 auto resultType = dyn_cast<VectorType>(op.getType());
365 if (!resultType)
366 return failure();
367
368 // Only handle dense vector constants
369 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
370 if (!dense)
371 return rewriter.notifyMatchFailure(
372 op, "only dense splat vector constants are supported");
373
374 xegpu::DistributeLayoutAttr layout =
375 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
376 if (!layout || !layout.isForSubgroup())
377 return rewriter.notifyMatchFailure(
378 op, "operation result does not have subgroup distribute layout");
379
380 auto laneShapeOrFailure =
381 xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
382
383 if (failed(laneShapeOrFailure))
384 return rewriter.notifyMatchFailure(
385 op, "unable to compute lane vector type from the layout");
386
387 VectorType newResultType = laneShapeOrFailure.value();
388 auto sclarValue = dense.getSplatValue<Attribute>();
389 auto newDenseAttr = DenseElementsAttr::get(newResultType, sclarValue);
390
391 auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
392 newDenseAttr);
393 rewriter.replaceOp(op, newOp.getResult());
394 return success();
395 }
396};
397
398/// Distributes a subgroup-level PrefetchNd op to lane-level PrefetchNd op.
399struct SgToLanePrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
400 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
401
402 LogicalResult
403 matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
404 ConversionPatternRewriter &rewriter) const override {
405 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
406 // If no layout, nothing to do.
407 if (!layout)
408 return failure();
409
410 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
411 op.getMixedOffsets(), op.getL1HintAttr(),
412 op.getL2HintAttr(), op.getL3HintAttr(),
413 /**layout**/ nullptr);
414 rewriter.eraseOp(op);
415 return success();
416 }
417};
418
419/// Distributes a subgroup-level LoadGather (xegpu.load) op to lane-level.
420///
421/// Example 1 (1D, no chunk size):
422/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
423/// %mask = producer_op : vector<16xi1>
424/// %offset = producer_op : vector<16xindex>
425/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
426/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
427/// Distributed to:
428/// %mask = producer_op : vector<1xi1>
429/// %offset = producer_op : vector<1xindex>
430/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
431/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
432///
433/// Example 2 (2D with chunk size, same mask & offset):
434/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
435/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
436/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
437/// Distributed to:
438/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
439/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
440///
441/// Example 3 (3D with leading unit dims):
442/// layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
443/// %mask = producer_op : vector<1x1x16xi1>
444/// %offset = producer_op : vector<1x1x16xindex>
445/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
446/// vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
447/// Distributed to:
448/// %mask = producer_op : vector<1x1x1xi1>
449/// %offset = producer_op : vector<1x1x1xindex>
450/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
451/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
452struct SgToLaneLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
453 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
454
455 LogicalResult
456 matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
457 ConversionPatternRewriter &rewriter) const override {
458 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
459 if (!layout)
460 return failure();
461
462 VectorType origResultTy = op.getValueType();
463 if (!origResultTy)
464 return failure();
465
466 // Check that leading dimensions are unit.
467 int chunkSize = op.getChunkSize().value_or(1);
468 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
469 ArrayRef<int64_t> shape = origResultTy.getShape();
470 if (llvm::any_of(
471 shape.take_front(origResultTy.getRank() - effectiveVecRank),
472 [](int64_t d) { return d != 1; }))
473 return rewriter.notifyMatchFailure(
474 op, "Only unit dimensions allowed for the leading "
475 "dimensions of the load vector!");
476
477 auto distResultTyOrFailure =
478 xegpu::getDistVecTypeBasedOnLaneLayout(layout, origResultTy);
479 if (failed(distResultTyOrFailure))
480 return rewriter.notifyMatchFailure(
481 op, "unable to compute expected lane vector type from lane layout");
482
483 VectorType distResultTy = distResultTyOrFailure.value();
484 VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
485 distResultTy.getElementType());
486
487 // Flatten offsets and mask to 1D to match the 1D result type.
488 Value distOffsets = adaptor.getOffsets();
489 auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
490 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
491 distOffsetsTy.getElementType());
492 distOffsets = castValueTo(
493 rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
494
495 Value distMask = adaptor.getMask();
496 auto distMaskTy = cast<VectorType>(distMask.getType());
497 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
498 distMaskTy.getElementType());
499 distMask =
500 castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
501
502 Value distSource = adaptor.getSource();
503 auto newOp = xegpu::LoadGatherOp::create(
504 rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
505 distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
506 op.getL3HintAttr(), /*layout=*/nullptr);
507
508 Value result = newOp->getResult(0);
509 if (distResultTy1D != distResultTy)
510 result = castValueTo(rewriter, cast<TypedValue<VectorType>>(result),
511 distResultTy);
512 rewriter.replaceOp(op, result);
513 return success();
514 }
515};
516
517/// This pattern distributes a subgroup-level vector.reduction op to
518/// lane-level. This require shuffling the data across the lanes (using
519/// gpu::ShuffleOp) and reducing in stages until all lanes have the final
520/// result.
521struct SgToLaneVectorReduction
522 : public OpConversionPattern<vector::ReductionOp> {
523 using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
524
525 LogicalResult
526 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
527 ConversionPatternRewriter &rewriter) const override {
528 auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
529
530 // If no layout, nothing to do.
531 if (!layout || !layout.isForSubgroup())
532 return failure();
533
534 VectorType srcVecType = op.getSourceVectorType();
535 // Only rank 1 vectors supported.
536 if (srcVecType.getRank() != 1)
537 return rewriter.notifyMatchFailure(
538 op, "Only rank 1 reductions can be distributed.");
539 // Lane layout must have the same rank as the vector.
540 if (layout.getRank() != srcVecType.getRank())
541 return rewriter.notifyMatchFailure(
542 op, "Layout rank does not match vector rank.");
543
544 // Get the subgroup size from the layout.
545 int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
546 const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
547 if (!uArch)
548 return rewriter.notifyMatchFailure(
549 op, "xegpu::ReductionOp require target attribute attached to "
550 "determine subgroup size");
551
552 // Only subgroup-sized vectors supported.
553 if (sgSize != uArch->getSubgroupSize() ||
554 srcVecType.getShape()[0] % sgSize != 0)
555 return rewriter.notifyMatchFailure(op,
556 "Invalid layout or reduction vector "
557 "dimension must match subgroup size.");
558
559 if (!op.getType().isIntOrFloat())
560 return rewriter.notifyMatchFailure(
561 op, "Reduction distribution currently only supports floats and "
562 "integer types.");
563
564 // Get the distributed vector (per lane portion).
565 Value laneValVec = adaptor.getVector();
566
567 // Distribute and reduce across lanes in the subgroup.
568 Value fullReduce = xegpu::subgroupReduction(
569 op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
570
571 // If there's an accumulator, combine it with the reduced value.
572 if (adaptor.getAcc())
573 fullReduce = vector::makeArithReduction(
574 rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
575
576 rewriter.replaceOp(op, fullReduce);
577 return success();
578 }
579};
580
581/// This pattern distributes a subgroup-level vector.multi_reduction op to
582/// lane-level only if the reduction is lane-local. This means that
583/// reduction dimension is not distributed to lanes and each lane does its own
584/// local reduction.
///
/// NOTE(review): the visible body also has a scalar-result branch and a final
/// `else` branch that is taken when the reduction is not lane-local, so the
/// "only if lane-local" wording above may be outdated - confirm.
585struct SgToLaneMultiDimReduction
586 : public OpConversionPattern<vector::MultiDimReductionOp> {
587 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
588
589 LogicalResult
590 matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
591 ConversionPatternRewriter &rewriter) const override {
592 Value result;
593 ArrayRef<int64_t> reductionDims = op.getReductionDims();
 // A single reduction dimension is a precondition (asserted, not reported
 // as a match failure).
594 assert(reductionDims.size() == 1 &&
595 "Expecting single reduction dimension for subgroup multi "
596 "reduction op");
597 // For rank > 2, ensure leading dimensions are unit.
598 VectorType sourceType = op.getSourceVectorType();
599 int64_t rank = sourceType.getRank();
600 if (rank > 2) {
601 ArrayRef<int64_t> shape = sourceType.getShape();
602 if (llvm::any_of(shape.take_front(rank - 2),
603 [](int64_t d) { return d != 1; }))
604 return rewriter.notifyMatchFailure(
605 op, "only unit leading dimensions are supported for "
606 "multi_reduction with rank > 2");
607 }
608 // Handle scalar result: full reduction of a distributed vector to a
609 // scalar. First do a local vector reduction, then cross-lane shuffles.
610 if (op.getType().isIntOrFloat()) {
611 auto reductionDim = reductionDims[0];
612 VectorType origSourceType = op.getSourceVectorType();
 // Size of the reduced dimension in the original (subgroup-level) source.
613 int64_t reductionDimSize = origSourceType.getShape()[reductionDim];
614 // Local reduction to scalar, then cross-lane butterfly shuffles.
615 result =
616 xegpu::subgroupReduction(op.getLoc(), rewriter, adaptor.getSource(),
617 op.getKind(), reductionDimSize);
618 // Combine with accumulator if present.
619 if (adaptor.getAcc())
620 result = vector::makeArithReduction(rewriter, op.getLoc(), op.getKind(),
621 result, adaptor.getAcc());
622 } else if (isReductionLaneLocal(op)) {
623 // For lane-local reduction, lower to a sequence of vector.reduction ops
624 // over 1D slices extracted from the distributed source vector. This is
625 // required so we dont have 2D source vectors at xegpu-linearize.
626 auto reductionDim = reductionDims[0];
 // NOTE(review): the embedded source numbering jumps from 626 to 628; the
 // callee (and presumably the assignment to `result`) is not visible in
 // this listing.
628 cast<TypedValue<VectorType>>(adaptor.getSource()),
629 cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
630 reductionDim, op.getLoc(), rewriter);
631 } else {
 // Vector result whose reduction is not lane-local.
632 auto reductionDim = reductionDims[0];
 // This `sourceType` shadows the one declared at the top of the function;
 // both are initialized from op.getSourceVectorType().
633 VectorType sourceType = op.getSourceVectorType();
634 int64_t reductionDimSize = sourceType.getShape()[reductionDim];
 // NOTE(review): the embedded source numbering jumps from 634 to 636; the
 // callee (and presumably the assignment to `result`) is not visible in
 // this listing.
636 cast<TypedValue<VectorType>>(adaptor.getSource()),
637 cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
638 reductionDim, reductionDimSize, op.getLoc(), rewriter);
639 }
640 rewriter.replaceOp(op, result);
641 return success();
642 }
643};
644
645/// Helper to compute distributed coordinates for matrix ops.
646/// When not using subgroup_block_io, each lane computes its own
647/// coordinates based on the layout and lane ID.
///
/// Creates a gpu.lane_id op, asks `layout` for the distributed coordinates of
/// that lane within `payloadShape`, and combines them with `origOffsets`.
/// Returns an empty vector when the layout fails to compute the coordinates;
/// callers use that as the failure signal.
648static SmallVector<Value> computeDistributedCoordsForMatrixOp(
649 ConversionPatternRewriter &rewriter, Location loc,
650 xegpu::DistributeLayoutAttr layout, ArrayRef<int64_t> payloadShape,
651 ValueRange origOffsets) {
 // Lane id is created without an upper bound (a null IntegerAttr is passed).
652 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
653 /*upperBound=*/mlir::IntegerAttr());
654 auto maybeCoords =
655 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
656 if (failed(maybeCoords))
657 return {};
 // Exactly one coordinate set is expected; only element 0 is used below.
658 assert(maybeCoords.value().size() == 1 &&
659 "Expected one set of distributed offsets");
 // NOTE(review): the embedded source numbering jumps from 659 to 661; the
 // declaration of `ofrVec` and the callee that combines the coordinates with
 // the original offsets are not visible in this listing.
661 rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
662 getAsOpFoldResult(origOffsets));
 // Convert the OpFoldResult list back to plain Values.
663 return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
664}
665
666/// This pattern distributes a subgroup-level LoadMatrix op to lane-level.
667struct SgToLaneLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
668 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
669
670 LogicalResult
671 matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
672 ConversionPatternRewriter &rewriter) const override {
673 auto layout = op.getLayoutAttr();
674 // If no layout, nothing to do.
675 if (!layout)
676 return failure();
677
678 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
679 if (!sgPayloadTy)
680 return rewriter.notifyMatchFailure(
681 op, "the matrix op payload must be a vector type");
682
683 auto loc = op.getLoc();
684 auto offsets = op.getMixedOffsets();
685 if (offsets.empty())
686 return rewriter.notifyMatchFailure(op, "the load op must have offsets");
687
688 FailureOr<VectorType> distPayloadTyOrFailure =
689 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
690 if (failed(distPayloadTyOrFailure))
691 return rewriter.notifyMatchFailure(
692 op, "Failed to distribute matrix op payload based on layout.");
693
694 SmallVector<Value> offsetsAsValues =
695 vector::getAsValues(rewriter, loc, offsets);
696
697 SmallVector<Value> newCoords = offsetsAsValues;
698 if (!op.getSubgroupBlockIoAttr()) {
699 newCoords = computeDistributedCoordsForMatrixOp(
700 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
701 if (newCoords.empty())
702 return rewriter.notifyMatchFailure(
703 op, "Failed to compute distributed coordinates.");
704 }
705
706 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
707 ShapedType::kDynamic);
708 DenseI64ArrayAttr newConstOffsetsAttr =
709 rewriter.getDenseI64ArrayAttr(newConstOffsets);
710
711 auto newOp = xegpu::LoadMatrixOp::create(
712 rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
713 ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
714 xegpu::DistributeLayoutAttr{});
715 rewriter.replaceOp(op, newOp.getResult());
716 return success();
717 }
718};
719
720/// Distributes a subgroup-level vector.transpose op to lane-level.
721struct SgToLaneVectorTranspose
722 : public OpConversionPattern<vector::TransposeOp> {
723 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
724
725 LogicalResult
726 matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
727 ConversionPatternRewriter &rewriter) const override {
728 xegpu::DistributeLayoutAttr sourceLayout =
729 xegpu::getTemporaryLayout(op->getOpOperand(0));
730 xegpu::DistributeLayoutAttr resultLayout =
731 xegpu::getTemporaryLayout(op->getOpResult(0));
732 if (!sourceLayout || !resultLayout)
733 return rewriter.notifyMatchFailure(
734 op, "the source or result vector of the transpose op lacks layout "
735 "attribute");
736 ArrayRef<int64_t> perm = op.getPermutation();
737 // Result layout must be a transpose of source layout.
738 if (!resultLayout.isTransposeOf(sourceLayout, perm,
739 xegpu::LayoutKind::Lane))
740 return rewriter.notifyMatchFailure(
741 op, "the source or result vector layouts must be transposes of "
742 "each other");
743 FailureOr<VectorType> distributedResultTypeOrFailure =
744 getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
745 if (failed(distributedResultTypeOrFailure))
746 return rewriter.notifyMatchFailure(
747 op, "Failed to distribute the result vector type in "
748 "vector::Transpose op");
749 auto newOp = vector::TransposeOp::create(rewriter, op.getLoc(),
750 adaptor.getVector(), perm);
751 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
752 distributedResultTypeOrFailure.value()));
753 return success();
754 }
755};
756
757/// Distributes a subgroup-level vector.bitcast op to lane-level.
758/// Bitcast only impacts the innermost dimension of the source/result vectors.
759struct SgToLaneVectorBitcast : public OpConversionPattern<vector::BitCastOp> {
760 using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
761
762 LogicalResult
763 matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
764 ConversionPatternRewriter &rewriter) const override {
765 xegpu::DistributeLayoutAttr resultLayout =
766 xegpu::getTemporaryLayout(op->getOpResult(0));
767 if (!resultLayout)
768 return rewriter.notifyMatchFailure(
769 op, "result vector of the bitcast op lacks layout attribute");
770 FailureOr<VectorType> distributedResultTypeOrFailure =
771 getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
772 if (failed(distributedResultTypeOrFailure))
773 return rewriter.notifyMatchFailure(
774 op, "Failed to distribute the result vector type in "
775 "vector::BitCast op");
776 auto newOp = vector::BitCastOp::create(
777 rewriter, op.getLoc(), distributedResultTypeOrFailure.value(),
778 adaptor.getSource());
779 rewriter.replaceOp(op, newOp.getResult());
780 return success();
781 }
782};
783
784/// Distributes a subgroup-level vector.create_mask or vector.constant_mask op
785/// to lane-level. Uses `computeDistributedCoords()` to obtain the
786/// coordinates each lane owns, then compares each coordinate against the
787/// original mask bounds using `arith.cmpi slt`. The per-element boolean
788/// results are assembled into the distributed mask vector.
789///
790/// For multi-dimensional masks, the element is in-bounds when ALL dimensions
791/// satisfy `coord[i] < bound[i]`.
792///
793/// Example (1D):
794/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
795/// %mask = vector.create_mask %m0 : vector<16xi1>
796/// For lane k, computeDistributedCoords gives coord = [k], so:
797/// %in_bounds = arith.cmpi slt, %coord, %m0 → i1
798/// %mask = vector.broadcast %in_bounds : i1 to vector<1xi1>
799///
800/// Example (2D):
801/// layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>
802/// %mask = vector.create_mask %m0, %m1 : vector<8x4xi1>
803/// Each WI owns a 1x2 slice. computeDistributedCoords returns 2 coords:
804/// [[r0, c0], [r0, c1]]
805/// For each coord: in_bounds = (r < m0) && (c < m1)
806/// %mask = vector.from_elements %bit0, %bit1 : vector<1x2xi1>
807template <typename OpType,
808 typename = std::enable_if_t<llvm::is_one_of<
809 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
810struct SgToLaneCreateMask : public OpConversionPattern<OpType> {
811 using OpConversionPattern<OpType>::OpConversionPattern;
812
813 LogicalResult
814 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
815 ConversionPatternRewriter &rewriter) const override {
816 xegpu::DistributeLayoutAttr layout =
817 xegpu::getTemporaryLayout(op->getOpResult(0));
818 if (!layout || !layout.isForSubgroup())
819 return rewriter.notifyMatchFailure(
820 op, "operation result does not have subgroup distribute layout");
821
822 VectorType origType = op.getType();
823 FailureOr<VectorType> distTypeOrFailure =
824 getDistVecTypeBasedOnLaneLayout(layout, origType);
825 if (failed(distTypeOrFailure))
826 return rewriter.notifyMatchFailure(
827 op, "unable to compute lane vector type from the layout");
828
829 VectorType distType = distTypeOrFailure.value();
830 Location loc = op.getLoc();
831
832 // Materialize the original mask bounds as Values.
833 SmallVector<Value> origBounds;
834 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
835 origBounds.append(op.getOperands().begin(), op.getOperands().end());
836 } else {
837 auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
838 for (auto dimSize : dimSizes)
839 origBounds.push_back(
840 arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
841 }
842
843 ArrayRef<int64_t> origShape = origType.getShape();
844
845 // Use computeDistributedCoords to get the coordinates each WI owns.
846 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
847 /*upperBound=*/mlir::IntegerAttr());
848 auto maybeCoordsVec =
849 layout.computeDistributedCoords(rewriter, loc, laneId, origShape);
850 if (failed(maybeCoordsVec))
851 return rewriter.notifyMatchFailure(
852 op, "failed to compute distributed coordinates from layout");
853
854 SmallVector<SmallVector<Value>> coordsVec = maybeCoordsVec.value();
855 int64_t numElements = distType.getNumElements();
856 assert(static_cast<int64_t>(coordsVec.size()) == numElements &&
857 "number of coordinate sets must match number of distributed "
858 "elements");
859
860 // For each element, compare all coordinates against bounds.
861 Value trueVal =
862 arith::ConstantIntOp::create(rewriter, loc, /*value=*/1, /*width=*/1);
863 SmallVector<Value> maskBits;
864 for (auto &coords : coordsVec) {
865 Value inBounds = trueVal;
866 for (size_t i = 0; i < coords.size(); ++i) {
867 Value cmp = arith::CmpIOp::create(
868 rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]);
869 inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp);
870 }
871 maskBits.push_back(inBounds);
872 }
873
874 // Build the distributed mask vector.
875 Value result;
876 if (numElements == 1) {
877 result =
878 vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]);
879 } else {
880 result =
881 vector::FromElementsOp::create(rewriter, loc, distType, maskBits);
882 }
883 rewriter.replaceOp(op, result);
884 return success();
885 }
886};
887
888/// This pattern distributes a subgroup-level StoreMatrix op to lane-level.
889struct SgToLaneStoreMatrix : public OpConversionPattern<xegpu::StoreMatrixOp> {
890 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
891
892 LogicalResult
893 matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
894 ConversionPatternRewriter &rewriter) const override {
895 auto layout = op.getLayoutAttr();
896 // If no layout, nothing to do.
897 if (!layout)
898 return failure();
899
900 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
901 if (!sgPayloadTy)
902 return rewriter.notifyMatchFailure(
903 op, "the matrix op payload must be a vector type");
904
905 auto loc = op.getLoc();
906 auto offsets = op.getMixedOffsets();
907 if (offsets.empty())
908 return rewriter.notifyMatchFailure(op, "the store op must have offsets");
909
910 FailureOr<VectorType> distPayloadTyOrFailure =
911 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
912 if (failed(distPayloadTyOrFailure))
913 return rewriter.notifyMatchFailure(
914 op, "Failed to distribute matrix op payload based on layout.");
915
916 SmallVector<Value> offsetsAsValues =
917 vector::getAsValues(rewriter, loc, offsets);
918
919 SmallVector<Value> newCoords = offsetsAsValues;
920 if (!op.getSubgroupBlockIoAttr()) {
921 newCoords = computeDistributedCoordsForMatrixOp(
922 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
923 if (newCoords.empty())
924 return rewriter.notifyMatchFailure(
925 op, "Failed to compute distributed coordinates.");
926 }
927
928 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
929 ShapedType::kDynamic);
930 DenseI64ArrayAttr newConstOffsetsAttr =
931 rewriter.getDenseI64ArrayAttr(newConstOffsets);
932
933 xegpu::StoreMatrixOp::create(
934 rewriter, loc, TypeRange{},
935 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getData()),
936 distPayloadTyOrFailure.value()),
937 adaptor.getMemDesc(), ValueRange(newCoords), newConstOffsetsAttr,
938 op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
939 rewriter.eraseOp(op);
940 return success();
941 }
942};
943
944/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
945/// lane-level.
946///
947/// Example 1 (1D, no chunk size):
948/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
949/// %mask = producer_op : vector<16xi1>
950/// %offset = producer_op : vector<16xindex>
951/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
952/// memref<256xf16>, vector<16xindex>, vector<16xi1>
953/// Distributed to:
954/// %mask = producer_op : vector<1xi1>
955/// %offset = producer_op : vector<1xindex>
956/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
957/// memref<256xf16>, vector<1xindex>, vector<1xi1>
958///
959/// Example 2 (2D with chunk size, same mask & offset):
960/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
961/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
962/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
963/// Distributed to:
964/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
965/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
966///
967/// Example 3 (3D with leading unit dims):
968/// layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
969/// %mask = producer_op : vector<1x1x16xi1>
970/// %offset = producer_op : vector<1x1x16xindex>
971/// xegpu.store %payload, %src[%offset], %mask : vector<1x1x16xf16>,
972/// memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
973/// Distributed to:
974/// %mask = producer_op : vector<1x1x1xi1>
975/// %offset = producer_op : vector<1x1x1xindex>
976/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
977/// memref<256xf16>, vector<1xindex>, vector<1xi1>
978struct SgToLaneStoreScatter
979 : public OpConversionPattern<xegpu::StoreScatterOp> {
980 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
981
982 LogicalResult
983 matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
984 ConversionPatternRewriter &rewriter) const override {
985 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
986 if (!layout)
987 return failure();
988
989 VectorType origValueTy = op.getValueType();
990 if (!origValueTy)
991 return failure();
992
993 // Check that all leading dimensions are unit dimensions.
994 int chunkSize = op.getChunkSize().value_or(1);
995 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
996 ArrayRef<int64_t> shape = origValueTy.getShape();
997 if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
998 [](int64_t d) { return d != 1; }))
999 return rewriter.notifyMatchFailure(
1000 op, "Only unit dimensions allowed for the leading "
1001 "dimensions of the store vector!");
1002
1003 auto distValueTyOrFailure =
1004 xegpu::getDistVecTypeBasedOnLaneLayout(layout, origValueTy);
1005 if (failed(distValueTyOrFailure))
1006 return rewriter.notifyMatchFailure(
1007 op, "unable to compute expected lane vector type from lane layout");
1008
1009 VectorType distValueTy = distValueTyOrFailure.value();
1010 VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
1011 distValueTy.getElementType());
1012
1013 Value distValue = adaptor.getValue();
1014 if (distValue.getType() != distValueTy1D)
1015 distValue = castValueTo(rewriter, cast<TypedValue<VectorType>>(distValue),
1016 distValueTy1D);
1017
1018 // Flatten offsets and mask to 1D to match the 1D value type.
1019 Value distOffsets = adaptor.getOffsets();
1020 auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
1021 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
1022 distOffsetsTy.getElementType());
1023 distOffsets = castValueTo(
1024 rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
1025
1026 Value distMask = adaptor.getMask();
1027 auto distMaskTy = cast<VectorType>(distMask.getType());
1028 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
1029 distMaskTy.getElementType());
1030 distMask =
1031 castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
1032
1033 Value distDest = adaptor.getDest();
1034 xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
1035 distOffsets, distMask, op.getChunkSizeAttr(),
1036 op.getL1HintAttr(), op.getL2HintAttr(),
1037 op.getL3HintAttr(), /*layout=*/nullptr);
1038 rewriter.eraseOp(op);
1039 return success();
1040 }
1041};
1042
1043/// Distribute a vector::StepOp to lane-level.
1044/// The layout must have exactly 1 effective lane dimension.
1045/// We completely resolve the vector::StepOp by computing the lane_data-sized
1046/// subranges.
1047struct SgToLaneVectorStep : public OpConversionPattern<vector::StepOp> {
1048 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1049
1050 LogicalResult
1051 matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
1052 ConversionPatternRewriter &rewriter) const override {
1053 xegpu::DistributeLayoutAttr resultLayout =
1054 xegpu::getTemporaryLayout(op->getResult(0));
1055 if (!resultLayout || !resultLayout.isForSubgroup())
1056 return rewriter.notifyMatchFailure(
1057 op, "the result vector of the step op lacks subgroup layout");
1058
1059 auto loc = op.getLoc();
1060 auto stepResultVecTy = op.getResult().getType();
1061 auto laneShapeOrFailure =
1062 xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, stepResultVecTy);
1063 if (failed(laneShapeOrFailure))
1064 return rewriter.notifyMatchFailure(
1065 op, "unable to compute lane vector type from the layout");
1066 VectorType newVecTy = laneShapeOrFailure.value();
1067
1068 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
1069 /*upperBound=*/mlir::IntegerAttr());
1070 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
1071 rewriter, loc, laneId, stepResultVecTy.getShape());
1072 if (failed(laneDataBlockCoords))
1073 return rewriter.notifyMatchFailure(
1074 op, "failed to compute lane data block coordinates");
1075
1076 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
1077 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
1078 assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
1079 newVecTy.getNumElements() / laneDataBlockLength);
1080 SmallVector<Value> stepVals;
1081 // For each lane_data block, reconstruct its sub-range
1082 // from the range of SG-level vector.step.Example: vector.step
1083 // {slice<layout<lane_layout=[2,4,2], lane_data=[1,2,1]>, dims=[0,2]>} :
1084 // vector<16xindex>
1085 // Each logical lane holds 4 elements as 2 blocks of 2 elements each.
1086 // The blocks are round-robin distributed, so logical lane id 0
1087 // holds values [0,1, 8,9].
1088 for (auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
1089 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
1090 stepVals.push_back(laneDataBlockStartCoord);
1091 for (int i = 1; i < laneDataBlockLength; ++i) {
1092 auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
1093 stepVals.push_back(arith::AddIOp::create(
1094 rewriter, loc, laneDataBlockStartCoord, offset));
1095 }
1096 }
1097 assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
1098 "Expecting the number of step values to match the number of "
1099 "elements in the vector");
1100 auto stepOpVal =
1101 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
1102 rewriter.replaceOp(op, stepOpVal);
1103 return success();
1104 }
1105};
1106
1107/// Distributes a subgroup-level vector.extract op to lane-level. Only
1108/// handles sub-vector extraction (result is VectorType, not scalar).
1109struct SgToLaneVectorExtract : public OpConversionPattern<vector::ExtractOp> {
1110 using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
1111
1112 LogicalResult
1113 matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
1114 ConversionPatternRewriter &rewriter) const override {
1115 // Only handle vector results (not scalar extraction).
1116 auto resultType = dyn_cast<VectorType>(op.getType());
1117 if (!resultType)
1118 return rewriter.notifyMatchFailure(op, "scalar extract not supported");
1119
1120 xegpu::DistributeLayoutAttr layout =
1121 xegpu::getTemporaryLayout(op->getOpResult(0));
1122 if (!layout || !layout.isForSubgroup())
1123 return failure();
1124
1125 // This implementation assumes distribution only happens on the innermost
1126 // dimension. Verify that lane_layout[0...n-2] are all unit.
1127 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1128 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1129 [](int64_t v) { return v != 1; }))
1130 return rewriter.notifyMatchFailure(
1131 op, "only innermost dimension distribution is supported for "
1132 "vector.extract");
1133
1134 auto newOp = vector::ExtractOp::create(
1135 rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
1136 rewriter.replaceOp(op, newOp.getResult());
1137 return success();
1138 }
1139};
1140
1141/// This pattern distributes a subgroup-level ShapeCast op to lane-level.
1142struct SgToLaneVectorShapeCast
1143 : public OpConversionPattern<vector::ShapeCastOp> {
1144 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1145
1146 LogicalResult
1147 matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
1148 ConversionPatternRewriter &rewriter) const override {
1149 xegpu::DistributeLayoutAttr resultLayout =
1150 xegpu::getTemporaryLayout(op->getOpResult(0));
1151 if (!resultLayout || !resultLayout.isForSubgroup())
1152 return rewriter.notifyMatchFailure(
1153 op, "the result vector of the shape_cast op lacks subgroup layout");
1154
1155 auto resultDistTypeOrFailure = xegpu::getDistVecTypeBasedOnLaneLayout(
1156 resultLayout, op.getResultVectorType());
1157 if (failed(resultDistTypeOrFailure))
1158 return rewriter.notifyMatchFailure(
1159 op, "failed to get distributed vector type for result");
1160
1161 Value source = adaptor.getSource();
1162 auto newShapeCast = vector::ShapeCastOp::create(
1163 rewriter, op.getLoc(), resultDistTypeOrFailure.value(), source);
1164 rewriter.replaceOp(op, newShapeCast);
1165 return success();
1166 }
1167};
1168
1169/// Distributes a subgroup-level vector.extract_strided_slice op to
1170/// lane-level. If the result is distributed, the offsets and sizes are
1171/// adjusted to match the distributed types.
1172struct SgToLaneVectorExtractStridedSlice
1173 : public OpConversionPattern<vector::ExtractStridedSliceOp> {
1174 using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;
1175
1176 LogicalResult
1177 matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
1178 ConversionPatternRewriter &rewriter) const override {
1179 xegpu::DistributeLayoutAttr resultLayout =
1180 xegpu::getTemporaryLayout(op->getOpResult(0));
1181 if (!resultLayout || !resultLayout.isForSubgroup())
1182 return failure();
1183
1184 VectorType resultType = op.getType();
1185 auto distResultTyOrFailure =
1186 xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, resultType);
1187 if (failed(distResultTyOrFailure))
1188 return rewriter.notifyMatchFailure(
1189 op, "unable to compute distributed vector type from lane layout");
1190 VectorType distResultTy = *distResultTyOrFailure;
1191
1192 SmallVector<int64_t> distributedDims =
1193 getDistributedDims(resultType, distResultTy);
1194
1195 // Collect updated sizes, offsets, strides. Pad to full source rank.
1196 int64_t sourceRank = op.getSourceVectorType().getRank();
1197 SmallVector<Attribute> updatedSizes =
1198 llvm::map_to_vector(op.getSizes(), [](Attribute attr) { return attr; });
1199 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1200 op.getOffsets(), [](Attribute attr) { return attr; });
1201 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1202 op.getStrides(), [](Attribute attr) { return attr; });
1203 for (int64_t i = op.getSizes().size(); i < sourceRank; ++i) {
1204 updatedSizes.push_back(
1205 rewriter.getI64IntegerAttr(op.getSourceVectorType().getDimSize(i)));
1206 updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
1207 updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
1208 }
1209
1210 // If the result is distributed, adjust offsets and sizes in the
1211 // distributed dimension.
1212 if (!distributedDims.empty()) {
1213 if (distributedDims.size() != 1)
1214 return rewriter.notifyMatchFailure(
1215 op, "only single dimension distribution is supported");
1216 int64_t distDim = distributedDims[0];
1217 const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
1218 if (!uArch)
1219 return rewriter.notifyMatchFailure(
1220 op, "target attribute required to determine subgroup size");
1221 int subgroupSize = uArch->getSubgroupSize();
1222 auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
1223 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1224 return rewriter.notifyMatchFailure(
1225 op, "source of extract_strided_slice lacks distribution layout");
1226 int sourceDistrDimSize = op.getSourceVectorType().getShape()[distDim];
1227 auto laneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
1228 // Effective subgroup size needs to be adjusted if laneLayout along
1229 // the distributed dimension is smaller than subgroup size.
1230 if (laneLayout[distDim] < subgroupSize &&
1231 subgroupSize % laneLayout[distDim] == 0)
1232 subgroupSize = laneLayout[distDim];
1233 if (sourceDistrDimSize % subgroupSize != 0)
1234 return rewriter.notifyMatchFailure(
1235 op, "source size along distributed dim is not a multiple of "
1236 "subgroup size");
1237 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1238 // Only check lane_data for the distributed dimension. Non-distributed
1239 // dimensions may have non-unit lane_data (e.g., packed layouts).
1240 if (distDim < static_cast<int64_t>(sourceLaneData.size()) &&
1241 sourceLaneData[distDim] != 1)
1242 return rewriter.notifyMatchFailure(
1243 op, "expecting unit lane data along the distributed dimension");
1244 int64_t distrDimOffset =
1245 cast<IntegerAttr>(updatedOffsets[distDim]).getInt();
1246 if (distrDimOffset % subgroupSize != 0)
1247 return rewriter.notifyMatchFailure(
1248 op, "offset along distributed dim is not a multiple of "
1249 "subgroup size");
1250 // Adjust sizes and offsets for the distributed dimension.
1251 updatedSizes[distDim] =
1252 rewriter.getI64IntegerAttr(distResultTy.getDimSize(distDim));
1253 updatedOffsets[distDim] =
1254 rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
1255 }
1256
1257 auto newOp = vector::ExtractStridedSliceOp::create(
1258 rewriter, op.getLoc(), distResultTy, adaptor.getSource(),
1259 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1260 ArrayAttr::get(rewriter.getContext(), updatedSizes),
1261 ArrayAttr::get(rewriter.getContext(), updatedStrides));
1262 rewriter.replaceOp(op, newOp.getResult());
1263 return success();
1264 }
1265};
1266
1267/// This pattern distributes a subgroup-level `vector.broadcast` op to
1268/// lane-level. The pattern supports three cases:
1269///
1270/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
1271/// vector must have a slice layout of the result. If the distributed source
1272/// and target vector types are identical, this lowers to a no-op; otherwise,
1273/// it remains a broadcast but operates on distributed vectors.
1274///
1275/// 2) Broadcast a same-rank vector with identical layouts for source and
1276/// target: The source vector must have unit dimensions, and lane_data must
1277/// be unit size for those unit dims. This always lowers to a no-op.
1278///
1279/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast
1280/// from scalar to distributed result type.
1281///
1282/// Example 1 (low-rank to high-rank broadcast):
1283/// ```
1284/// %0 = "some_op"() {layout_result_0 =
1285/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
1286/// dims = [0]>} : () -> vector<16xf16>
1287/// %1 = vector.broadcast %0 {layout_result_0 =
1288/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
1289/// : vector<16xf16> to vector<16x16xf16>
1290/// ```
1291/// is distributed to:
1292/// ```
1293/// %0 = "some_op"() : () -> vector<1xf16>
1294/// %1 = vector.broadcast %0 : vector<1xf16> to vector<16x1xf16>
1295/// ```
1296///
1297/// Example 2 (same-rank broadcast, no-op):
1298/// ```
1299/// %0 = "some_op"() {layout_result_0 =
1300/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
1301/// : () -> vector<16x1xf16>
1302/// %1 = vector.broadcast %0 {layout_result_0 =
1303/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
1304/// : vector<16x1xf16> to vector<16x16xf16>
1305/// ```
1306/// is distributed to (no-op, source already matches distributed result type):
1307/// ```
1308/// %0 = "some_op"() : () -> vector<16x1xf16>
1309/// // broadcast is eliminated, %0 is used directly
1310/// ```
1311///
1312/// Example 3 (scalar to vector broadcast):
1313/// ```
1314/// %0 = "some_op"() : () -> f16
1315/// %1 = vector.broadcast %0 {layout_result_0 =
1316/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
1317/// : f16 to vector<16x16xf16>
1318/// ```
1319/// is distributed to:
1320/// ```
1321/// %0 = "some_op"() : () -> f16
1322/// %1 = vector.broadcast %0 : f16 to vector<16x1xf16>
1323/// ```
1324struct SgToLaneBroadcast : public OpConversionPattern<vector::BroadcastOp> {
1325 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
1326
1327 LogicalResult
1328 matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
1329 ConversionPatternRewriter &rewriter) const override {
1330 xegpu::DistributeLayoutAttr resultLayout =
1331 xegpu::getTemporaryLayout(cast<OpResult>(op.getResult()));
1332 if (!resultLayout || !resultLayout.isForSubgroup())
1333 return rewriter.notifyMatchFailure(
1334 op, "result does not have subgroup distribute layout");
1335
1336 VectorType destType = op.getResultVectorType();
1337 VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
1338
1339 xegpu::DistributeLayoutAttr sourceLayout =
1340 xegpu::getTemporaryLayout(op->getOpOperand(0));
1341
1342 if (sourceType) {
1343 int64_t rankDiff = destType.getRank() - sourceType.getRank();
1344 if (rankDiff > 0) {
1345 // Case 1: Low-rank to high-rank broadcast.
1346 if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
1347 op.emitWarning(
1348 "broadcast source layout must be a slice of result layout");
1349 } else if (rankDiff == 0) {
1350 // Case 2: Same-rank broadcast.
1351 auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
1352 SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
1353 broadcastUnitDimsSet.end());
1354 assert(sourceLayout.isEqualTo(
1355 sourceLayout.setUnitDimData(broadcastUnitDims)) &&
1356 "The sg_data for unit dimensions should be set as 1");
1357 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1358 }
1359 } else {
1360 // Case 3: Scalar to vector broadcast.
1361 if (sourceLayout)
1362 return rewriter.notifyMatchFailure(
1363 op, "broadcast from scalar must not have a layout attribute");
1364 }
1365
1366 auto destDistType =
1367 xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
1368 if (failed(destDistType))
1369 return rewriter.notifyMatchFailure(
1370 op, "failed to distribute the result vector type");
1371
1372 Value source = adaptor.getSource();
1373 // If the adapted source already matches the dest dist type, it's a no-op.
1374 if (source.getType() == destDistType.value()) {
1375 rewriter.replaceOp(op, source);
1376 return success();
1377 }
1378
1379 auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
1380 destDistType.value(), source);
1381 rewriter.replaceOp(op, newOp);
1382 return success();
1383 }
1384};
1385
1386/// Distributes a subgroup-level vector.insert_strided_slice op to
1387/// lane-level. If the dest is distributed, the offsets are adjusted to
1388/// match the distributed types.
1389struct SgToLaneVectorInsertStridedSlice
1390 : public OpConversionPattern<vector::InsertStridedSliceOp> {
1391 using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
1392
1393 LogicalResult
1394 matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
1395 ConversionPatternRewriter &rewriter) const override {
1396 xegpu::DistributeLayoutAttr resultLayout =
1397 xegpu::getTemporaryLayout(op->getOpResult(0));
1398 if (!resultLayout || !resultLayout.isForSubgroup())
1399 return failure();
1400
1401 VectorType destType = op.getDestVectorType();
1402 auto distDestTyOrFailure =
1403 xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
1404 if (failed(distDestTyOrFailure))
1405 return rewriter.notifyMatchFailure(
1406 op, "unable to compute distributed vector type from lane layout");
1407 VectorType distDestTy = *distDestTyOrFailure;
1408
1409 SmallVector<int64_t> destDistributedDims =
1410 getDistributedDims(destType, distDestTy);
1411
1412 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1413 op.getOffsets(), [](Attribute attr) { return attr; });
1414
1415 if (!destDistributedDims.empty()) {
1416 if (destDistributedDims.size() != 1)
1417 return rewriter.notifyMatchFailure(
1418 op, "only single dimension distribution is supported");
1419 int64_t destDistDim = destDistributedDims[0];
1420
1421 const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
1422 if (!uArch)
1423 return rewriter.notifyMatchFailure(
1424 op, "target attribute required to determine subgroup size");
1425 int subgroupSize = uArch->getSubgroupSize();
1426
1427 VectorType srcType = op.getSourceVectorType();
1428 // The distributed dim must be in the last k (source rank) dims of dest.
1429 int64_t sourceDistDim =
1430 destDistDim - (destType.getRank() - srcType.getRank());
1431 if (sourceDistDim < 0)
1432 return rewriter.notifyMatchFailure(
1433 op, "distributed dimension must be in the last k dims of dest");
1434
1435 auto destLayout = xegpu::getTemporaryLayout(op->getOpOperand(1));
1436 auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
1437 if (!destLayout || !sourceLayout ||
1438 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1439 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1440 return rewriter.notifyMatchFailure(
1441 op, "source or dest of insert_strided_slice lacks distribution "
1442 "layout");
1443
1444 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1445 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1446 // Only check lane_data for the distributed dimension. Non-distributed
1447 // dimensions may have non-unit lane_data (e.g., packed layouts).
1448 if ((destDistDim < static_cast<int64_t>(destLaneData.size()) &&
1449 destLaneData[destDistDim] != 1) ||
1450 (sourceDistDim < static_cast<int64_t>(sourceLaneData.size()) &&
1451 sourceLaneData[sourceDistDim] != 1))
1452 return rewriter.notifyMatchFailure(
1453 op, "expecting unit lane data along the distributed dimension");
1454
1455 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistDim);
1456 if (srcDistrDimSize % subgroupSize != 0)
1457 return rewriter.notifyMatchFailure(
1458 op, "source distributed dim size is not a multiple of "
1459 "subgroup size");
1460
1461 int64_t destDistrDimOffset =
1462 cast<IntegerAttr>(op.getOffsets()[destDistDim]).getInt();
1463 if (destDistrDimOffset % subgroupSize != 0)
1464 return rewriter.notifyMatchFailure(
1465 op, "offset along distributed dim is not a multiple of "
1466 "subgroup size");
1467 // Adjust offset for the distributed dimension.
1468 updatedOffsets[destDistDim] =
1469 rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1470 }
1471
1472 auto newOp = vector::InsertStridedSliceOp::create(
1473 rewriter, op.getLoc(), distDestTy, adaptor.getValueToStore(),
1474 adaptor.getDest(),
1475 ArrayAttr::get(rewriter.getContext(), updatedOffsets), op.getStrides());
1476 rewriter.replaceOp(op, newOp.getResult());
1477 return success();
1478 }
1479};
1480
1481/// Distributes a subgroup-level vector.insert op to lane-level. Only
1482/// handles sub-vector insertion (value to store is VectorType, not scalar).
1483struct SgToLaneVectorInsert : public OpConversionPattern<vector::InsertOp> {
1484 using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
1485
1486 LogicalResult
1487 matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
1488 ConversionPatternRewriter &rewriter) const override {
1489 // Only handle vector value-to-store (not scalar insertion).
1490 auto valueType = dyn_cast<VectorType>(op.getValueToStoreType());
1491 if (!valueType)
1492 return rewriter.notifyMatchFailure(op, "scalar insert not supported");
1493
1494 xegpu::DistributeLayoutAttr layout =
1495 xegpu::getTemporaryLayout(op->getOpResult(0));
1496 if (!layout || !layout.isForSubgroup())
1497 return failure();
1498
1499 // verify that the outer k dimensions (for offsets)
1500 // don't have non-unit lane_layout.
1501 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1502 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1503 [](int64_t v) { return v != 1; }))
1504 return rewriter.notifyMatchFailure(
1505 op, "only innermost dimension distribution is supported for "
1506 "vector.insert");
1507
1508 auto newOp = vector::InsertOp::create(
1509 rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
1510 op.getMixedPosition());
1511 rewriter.replaceOp(op, newOp.getResult());
1512 return success();
1513 }
1514};
1515
1516/// Redistributes `src` for a `convert_layout` that changes only the
1517/// `lane_layout` along the outer (distributed) dimension, shrinking it from
1518/// `currentLaneNum` to `targetLaneNum` lanes (a partial-subgroup
1519/// distribution). Because the data is no longer replicated across all lanes,
1520/// each surviving lane must gather the values that previously lived in the
1521/// lanes that are dropped. The values are gathered with `gpu.shuffle` and
1522/// concatenated with the lane-local data using `vector.shuffle`, which doubles
1523/// the distributed outer dimension when the lane count is halved.
1524///
1525/// Only halving the lane count (a factor of two) is currently supported.
1526/// Returns the redistributed value on success, or failure if `src` cannot be
1527/// shuffled (e.g. it is not a rank-2 vector or its bit width is not a multiple
1528/// of 32).
1529static FailureOr<Value>
1530shuffleDataAsLaneLayoutChange(ConversionPatternRewriter &rewriter, Location loc,
1531 Value src, int64_t currentLaneNum,
1532 int64_t targetLaneNum) {
1533 VectorType srcTy = dyn_cast<VectorType>(src.getType());
1534 if (!srcTy || srcTy.getRank() != 2)
1535 return failure();
1536 // Only halving the lane count (factor of two) is supported for now.
1537 if (targetLaneNum <= 0 || currentLaneNum != targetLaneNum * 2)
1538 return failure();
1539 // gpu.shuffle operates on i32, so the data must be a multiple of 32 bits.
1540 int64_t vectorBitWidth =
1541 srcTy.getNumElements() * srcTy.getElementTypeBitWidth();
1542 if (vectorBitWidth % 32 != 0)
1543 return failure();
1544
1545 // A vector cannot be shuffled across lanes directly:
1546 // -- cast the source to a 1D vector of i32
1547 // -- create a temp 1D vector of i32 initialized to zero
1548 // -- for each i32 element:
1549 // ---- extract it from the source bundle
1550 // ---- gpu.shuffle to gather the value from the partner lane
1551 // ---- insert it into the temp bundle
1552 // -- cast the temp back to the source vector type
1553 // -- vector.shuffle the source and temp to concatenate along the outer dim
1554 Type shuffleElemTy = rewriter.getI32Type();
1555 int64_t numShuffles = vectorBitWidth / 32;
1556 VectorType shuffleBundleTy = VectorType::get({numShuffles}, shuffleElemTy);
1557 // Initialize temp to zero.
1558 Value temp = arith::ConstantOp::create(
1559 rewriter, loc,
1560 DenseElementsAttr::get(shuffleBundleTy,
1561 IntegerAttr::get(shuffleElemTy, 0)));
1562 VectorType flatSrcTy =
1563 VectorType::get({srcTy.getNumElements()}, srcTy.getElementType());
1564 Value flatSrc = vector::ShapeCastOp::create(rewriter, loc, flatSrcTy, src);
1565 Value shuffleBundle =
1566 vector::BitCastOp::create(rewriter, loc, shuffleBundleTy, flatSrc);
1567 for (int64_t i = 0; i < numShuffles; i++) {
1568 Value shuffleElem =
1569 vector::ExtractOp::create(rewriter, loc, shuffleBundle, i);
1570 shuffleElem = gpu::ShuffleOp::create(rewriter, loc, shuffleElem, 0,
1571 targetLaneNum, gpu::ShuffleMode::UP)
1572 .getResult(0);
1573 temp = vector::InsertOp::create(rewriter, loc, shuffleElem, temp, i);
1574 }
1575 temp = vector::BitCastOp::create(rewriter, loc, flatSrcTy, temp);
1576 temp = vector::ShapeCastOp::create(rewriter, loc, srcTy, temp);
1577
1578 // Concatenate the lane-local and gathered data along the outer dimension.
1579 SmallVector<int64_t> indices(srcTy.getShape()[0] * 2);
1580 std::iota(indices.begin(), indices.end(), 0);
1581 Value res = vector::ShuffleOp::create(rewriter, loc, src, temp, indices);
1582 return res;
1583}
1584
1585/// Folds a subgroup-level ConvertLayout op with compatible lane layouts.
1586struct SgToLaneConvertLayout
1587 : public OpConversionPattern<xegpu::ConvertLayoutOp> {
1588 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
1589
1590 LogicalResult
1591 matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
1592 ConversionPatternRewriter &rewriter) const override {
1593 auto inputLayout = op.getInputLayoutAttr();
1594 auto targetLayout = op.getTargetLayoutAttr();
1595 Type valType = op.getResult().getType();
1596
1597 if (valType.isIntOrFloat()) {
1598 rewriter.replaceOp(op, op.getSource());
1599 return success();
1600 }
1601
1602 auto resShape = cast<VectorType>(valType).getShape();
1603 SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
1604
1605 // Equivalent layouts: the convert_layout is a no-op and folds to its
1606 // source.
1607 if (inputLayout.isCompatibleWith(targetLayout, resShapeVec,
1608 xegpu::LayoutKind::Lane)) {
1609 rewriter.replaceOp(op, adaptor.getSource());
1610 return success();
1611 }
1612
1613 // Handle the special case where the conversion redistributes a value
1614 // across a fraction of the subgroup: the lane_layout shrinks along the
1615 // outer (distributed) dimension while lane_data stays the same. Only a
1616 // pure outer-dimension lane_layout change is supported, so the inner
1617 // lane_layout must be unit (making the outer dim the only distributed one)
1618 // and the outer lane_layout must be genuinely distributed (> 1), which
1619 // also rules out the degenerate [1, 1] layout.
1620 if (inputLayout.getEffectiveOrderAsInt() ==
1621 targetLayout.getEffectiveOrderAsInt() &&
1622 inputLayout.getRank() == 2 && targetLayout.getRank() == 2) {
1623 auto laneLayout = inputLayout.getEffectiveLaneLayoutAsInt();
1624 auto targetLaneLayout = targetLayout.getEffectiveLaneLayoutAsInt();
1625 auto laneData = inputLayout.getEffectiveLaneDataAsInt();
1626 auto targetLaneData = targetLayout.getEffectiveLaneDataAsInt();
1627 if (laneLayout.size() == 2 && targetLaneLayout.size() == 2 &&
1628 laneData == targetLaneData && laneLayout[1] == 1 &&
1629 targetLaneLayout[1] == 1 && laneLayout[0] > 1 &&
1630 laneLayout[0] != targetLaneLayout[0]) {
1631 FailureOr<Value> res = shuffleDataAsLaneLayoutChange(
1632 rewriter, op.getLoc(), adaptor.getSource(), laneLayout[0],
1633 targetLaneLayout[0]);
1634 if (succeeded(res)) {
1635 rewriter.replaceOp(op, *res);
1636 return success();
1637 }
1638 }
1639 }
1640
1641 return rewriter.notifyMatchFailure(
1642 op, "lowering incompatible convert_layout not yet supported");
1643 }
1644};
1645
1646// Trivially distribute `vector.interleave`
1647struct SgToLaneVectorInterleave
1648 : public OpConversionPattern<vector::InterleaveOp> {
1649 using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;
1650
1651 LogicalResult
1652 matchAndRewrite(vector::InterleaveOp op, OpAdaptor adaptor,
1653 ConversionPatternRewriter &rewriter) const override {
1654
1655 auto newOp = vector::InterleaveOp::create(
1656 rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
1657 rewriter.replaceOp(op, newOp.getResult());
1658 return success();
1659 }
1660};
1661
1662// Trivially distribute `vector.deinterleave`
1663struct SgToLaneVectorDeinterleave
1664 : public OpConversionPattern<vector::DeinterleaveOp> {
1665 using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;
1666
1667 LogicalResult
1668 matchAndRewrite(vector::DeinterleaveOp op, OpAdaptor adaptor,
1669 ConversionPatternRewriter &rewriter) const override {
1670
1671 auto newOp = vector::DeinterleaveOp::create(rewriter, op.getLoc(),
1672 adaptor.getSource());
1673 rewriter.replaceOp(op, newOp.getResults());
1674 return success();
1675 }
1676};
1677
1678struct SgToLaneDpasMx : public OpConversionPattern<xegpu::DpasMxOp> {
1679 using OpConversionPattern<xegpu::DpasMxOp>::OpConversionPattern;
1680
1681 LogicalResult
1682 matchAndRewrite(xegpu::DpasMxOp op, OpAdaptor adaptor,
1683 ConversionPatternRewriter &rewriter) const override {
1684 const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
1685 if (!uArch)
1686 return failure();
1687 if (!uArch->isSupportedInstruction(
1688 xegpu::uArch::InstructionKind::SubgroupScaledMatrixMultiplyAcc))
1689 return rewriter.notifyMatchFailure(
1690 op, "target uArch does not support scaled subgroup mma");
1691 // Check if the op has A, B and CD layouts attached.
1692 auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
1693 auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
1694 auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
1695 if (!layoutA || !layoutB || !layoutCd)
1696 return rewriter.notifyMatchFailure(
1697 op, "missing required layout attributes for DpasMxOp distribution");
1698
1699 // Retrieve expected types, according to anchor layouts.
1700 auto expected1DTypeResult =
1701 xegpu::getDistributedVectorType(op.getType(), layoutCd);
1702 auto expected1DTypeA =
1703 xegpu::getDistributedVectorType(op.getA().getType(), layoutA);
1704 auto expected1DTypeB =
1705 xegpu::getDistributedVectorType(op.getB().getType(), layoutB);
1706
1707 VectorType expected1DTypeScaleA, expected1DTypeScaleB;
1708 if (op.getScaleA()) {
1709 auto layoutScaleA = cast<xegpu::LayoutAttr>(op.getLayoutAScaleAttr());
1710 auto expected1DTypeScaleAOrFailure = xegpu::getDistributedVectorType(
1711 cast<VectorType>(op.getScaleA().getType()), layoutScaleA);
1712 if (failed(expected1DTypeScaleAOrFailure))
1713 return rewriter.notifyMatchFailure(
1714 op, "failed to calculate expected 1D vector type for scale A");
1715 expected1DTypeScaleA = expected1DTypeScaleAOrFailure.value();
1716 }
1717 if (op.getScaleB()) {
1718 auto layoutScaleB = cast<xegpu::LayoutAttr>(op.getLayoutBScaleAttr());
1719 auto expected1DTypeScaleBOrFailure = xegpu::getDistributedVectorType(
1720 cast<VectorType>(op.getScaleB().getType()), layoutScaleB);
1721 if (failed(expected1DTypeScaleBOrFailure))
1722 return rewriter.notifyMatchFailure(
1723 op, "failed to calculate expected 1D vector type for scale B");
1724 expected1DTypeScaleB = expected1DTypeScaleBOrFailure.value();
1725 }
1726
1727 auto expectedNDTypeResult =
1728 xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
1729 if (failed(expected1DTypeResult) || failed(expected1DTypeA) ||
1730 failed(expected1DTypeB))
1731 return rewriter.notifyMatchFailure(
1732 op,
1733 "failed to calculate supported workitem 1D vector types for DpasOp "
1734 "from layouts");
1735 if (failed(expectedNDTypeResult))
1736 return rewriter.notifyMatchFailure(
1737 op, "unable to compute expected workitem vector type for DpasOp from "
1738 "lane layout");
1739
1740 // Validate bit widths match uArch packed format requirements
1741 const auto *uArchInstruction = dyn_cast<
1743 xegpu::uArch::InstructionKind::SubgroupScaledMatrixMultiplyAcc));
1744 assert(uArchInstruction);
1745 auto wiAType = expected1DTypeA.value();
1746 auto wiBType = expected1DTypeB.value();
1747 // Calculate total packed bit width = element bit width * vector size
1748 unsigned aPackedBitWidth =
1749 wiAType.getElementTypeBitWidth() * wiAType.getNumElements();
1750 unsigned bPackedBitWidth =
1751 wiBType.getElementTypeBitWidth() * wiBType.getNumElements();
1752 if (aPackedBitWidth % uArchInstruction->getPackedFormatBitSizeA())
1753 return rewriter.notifyMatchFailure(
1754 op, "A operand packed bit width must be a multiple of uArch packed "
1755 "format requirement");
1756 if (bPackedBitWidth % uArchInstruction->getPackedFormatBitSizeB())
1757 return rewriter.notifyMatchFailure(
1758 op, "B operand packed bit width must be a multiple of uArch packed "
1759 "format requirement");
1760
1761 auto newOp = xegpu::DpasMxOp::create(
1762 rewriter, op->getLoc(), expected1DTypeResult.value(),
1763 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getA()),
1764 expected1DTypeA.value()),
1765 castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getB()),
1766 expected1DTypeB.value()),
1767 op.getAcc()
1768 ? castValueTo(rewriter,
1769 cast<TypedValue<VectorType>>(adaptor.getAcc()),
1770 expected1DTypeResult.value())
1771 : nullptr,
1772
1773 op.getScaleA()
1774 ? castValueTo(rewriter,
1775 cast<TypedValue<VectorType>>(adaptor.getScaleA()),
1776 expected1DTypeScaleA)
1777 : nullptr,
1778 op.getScaleB()
1779 ? castValueTo(rewriter,
1780 cast<TypedValue<VectorType>>(adaptor.getScaleB()),
1781 expected1DTypeScaleB)
1782 : nullptr,
1783 /** layoutA**/ nullptr,
1784 /** layoutB**/ nullptr, /** layoutCd**/ nullptr,
1785 /** layoutAScale**/ nullptr, /** layoutBScale**/ nullptr);
1786 // Explicitly set the new types to enable correct type materializations.
1787 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
1788 expectedNDTypeResult.value()));
1789 return success();
1790 }
1791};
1792
1793struct XeGPUSgToLaneDistributePass
1795 XeGPUSgToLaneDistributePass> {
1796 void runOnOperation() override;
1797};
1798
1799} // namespace
1800
1801void XeGPUSgToLaneDistributePass::runOnOperation() {
1802
1803 // Recover temporary operand layouts for usage in patterns.
1804 Operation *root = getOperation();
1805 if (!xegpu::recoverTemporaryLayouts(root)) {
1806 signalPassFailure();
1807 return;
1808 }
1809
1810 // Collect existing UnrealizedConversionCastOps. These must be preserved.
1811 llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
1812 root->walk(
1813 [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
1814 // Perform a structural type conversion to convert structural ops to have WI
1815 // types. This will insert UnrealizedConversionCastOps to make the IR
1816 // valid.
1817 {
1818 ConversionTarget target(getContext());
1819 TypeConverter typeConverter;
1820 RewritePatternSet patterns(&getContext());
1821 // Source (N:1) and target (1:1) materializations using
1822 // UnrealizedConversionCastOp.
1823 auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
1824 Location loc) -> Value {
1825 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
1826 .getResult(0);
1827 };
1828 typeConverter.addSourceMaterialization(materializeCast);
1829 typeConverter.addTargetMaterialization(materializeCast);
1832 patterns, target);
1834 typeConverter, patterns, target, root);
1835 target.addLegalOp<UnrealizedConversionCastOp>();
1836 (void)applyPartialConversion(root, target, std::move(patterns));
1837 }
1838 // Fold cancelling cast chains and erase dead casts.
1839 xegpu::cleanupUnrealizedConversionCasts(root, existingCasts);
1840 xegpu::removeTemporaryLayoutAttrs(getOperation());
1841}
1842
1844 TypeConverter &typeConverter, Operation *topLevelOp) {
1845 // Pass through any type by default; more specific conversions registered
1846 // below override this for TensorDescType and (distributing) VectorType.
1847 typeConverter.addConversion([](Type type) -> Type { return type; });
1848 // For TensorDescType, drop the layout attribute if any.
1849 typeConverter.addConversion([](TensorDescType type) -> Type {
1850 if (type.getLayoutAttr()) {
1851 return type.dropLayouts();
1852 }
1853 return type;
1854 });
1855 // For VectorType, distribute based on the lane layout (1:1 shape-changing
1856 // conversion). Uses xegpu::addVectorTypeConversion with a pre-computed
1857 // map for SCF loop block args (see precomputeLoopBlockArgTypes for the
1858 // rationale).
1859 auto getSubShapeAndCount = [](VectorType vecTy,
1860 xegpu::DistributeLayoutAttr layout)
1861 -> std::pair<SmallVector<int64_t>, int> {
1862 auto distTyOrFailure = getDistVecTypeBasedOnLaneLayout(layout, vecTy);
1863 if (failed(distTyOrFailure))
1864 return {{}, 0};
1865 return {SmallVector<int64_t>(distTyOrFailure->getShape()), 1};
1866 };
1867 auto loopArgTypes =
1868 xegpu::precomputeLoopBlockArgTypes(topLevelOp, getSubShapeAndCount);
1869 xegpu::addVectorTypeConversion(typeConverter, getSubShapeAndCount,
1870 std::move(loopArgTypes));
1871}
1872
1874 TypeConverter &typeConverter, RewritePatternSet &patterns,
1875 ConversionTarget &target, Operation *topLevelOp) {
1876 populateXeGPUSgToLaneDistributeTypeConversions(typeConverter, topLevelOp);
1877 // CreateNdDescOp is legal only if its result type has no layout attribute.
1878 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
1879 [&](xegpu::CreateNdDescOp op) { return !op.getType().getLayoutAttr(); });
1880 // Any anchor XeGPU op is legal only if it has no anchor layout.
1881 target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
1882 if (isa<xegpu::ConvertLayoutOp>(op))
1883 return false;
1884 auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
1885 if (!anchorOp)
1886 return true;
1887 return !anchorOp.getAnchorLayout();
1888 });
1889 // Arith constants are legal only if they have no temporary layout attribute.
1890 target.addDynamicallyLegalOp<arith::ConstantOp>(
1891 [=](arith::ConstantOp op) -> bool {
1892 // If the result type is not a vector, it's legal.
1893 if (!isa<VectorType>(op.getResult().getType()))
1894 return true;
1895 return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1896 });
1897 // In math and arith dialects, only handle elementwise ops with a single
1898 // result and with a result layout attribute.
1899 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1900 [=](Operation *op) -> std::optional<bool> {
1901 // Only handle elementwise mappable ops
1903 return true;
1904 // Only handle ops with single vector result
1905 if (op->getNumResults() != 1)
1906 return true;
1907
1908 VectorType resultType =
1909 dyn_cast<VectorType>(op->getResult(0).getType());
1910 if (!resultType)
1911 return true;
1912
1913 // Check if all operands are vectors of the same shape
1914 for (Value operand : op->getOperands()) {
1915 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1916 if (!operandType || operandType.getShape() != resultType.getShape()) {
1917 return true;
1918 }
1919 }
1920 return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1921 });
1922 // vector::ReductionOp is legal only if its source has no distribute layout
1923 // attribute.
1924 target.addDynamicallyLegalOp<vector::ReductionOp>(
1925 [=](vector::ReductionOp op) -> bool {
1926 auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
1927 return !layout;
1928 });
1929 // vector::MultiDimReductionOp op legality.
1930 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
1931 [=](vector::MultiDimReductionOp op) -> bool {
1932 return !isValidSubgroupMultiReductionOp(op);
1933 });
1934 target.addDynamicallyLegalOp<vector::CreateMaskOp, vector::ConstantMaskOp,
1935 vector::TransposeOp, vector::BitCastOp,
1936 vector::ShapeCastOp, vector::StepOp,
1937 vector::BroadcastOp>([=](Operation *op) -> bool {
1938 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1939 });
1940 target.addDynamicallyLegalOp<vector::ExtractOp>(
1941 [=](vector::ExtractOp op) -> bool {
1942 if (!isa<VectorType>(op.getType()))
1943 return true;
1944 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1945 });
1946 target.addDynamicallyLegalOp<vector::InsertOp>(
1947 [=](vector::InsertOp op) -> bool {
1948 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1949 });
1950 target.addDynamicallyLegalOp<vector::ExtractStridedSliceOp>(
1951 [=](vector::ExtractStridedSliceOp op) -> bool {
1952 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1953 });
1954 target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
1955 [=](vector::InsertStridedSliceOp op) -> bool {
1956 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1957 });
1958 target.addDynamicallyLegalOp<vector::InterleaveOp, vector::DeinterleaveOp>(
1959 [=](Operation *op) -> bool {
1960 return !xegpu::getTemporaryLayout(op->getOpResult(0));
1961 });
1962 target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
1963 patterns.add<
1964 SgToLaneCreateNdDesc, SgToLaneLoadNd, SgToLaneStoreNd, SgToLaneDpas,
1965 SgToLaneElementWise, SgToLaneArithConstant, SgToLanePrefetchNd,
1966 SgToLaneLoadGather, SgToLaneStoreScatter, SgToLaneVectorReduction,
1967 SgToLaneMultiDimReduction, SgToLaneVectorExtract, SgToLaneVectorInsert,
1968 SgToLaneVectorExtractStridedSlice, SgToLaneVectorInsertStridedSlice,
1969 SgToLaneLoadMatrix, SgToLaneStoreMatrix, SgToLaneConvertLayout,
1970 SgToLaneVectorTranspose, SgToLaneVectorBitcast, SgToLaneVectorStep,
1971 SgToLaneVectorShapeCast, SgToLaneBroadcast,
1972 SgToLaneCreateMask<vector::CreateMaskOp>,
1973 SgToLaneCreateMask<vector::ConstantMaskOp>, SgToLaneVectorDeinterleave,
1974 SgToLaneVectorInterleave, SgToLaneDpasMx>(typeConverter,
1975 patterns.getContext());
1976}
return success()
b getContext())
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
result_range getResults()
Definition Operation.h:440
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:283
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
const uArch * getUArch(llvm::StringRef archName)
void populateXeGPUSgToLaneDistributeTypeConversionAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, Operation *topLevelOp)
Defines type conversions and legality for XeGPU subgroup to lane distribution and appends the require...
bool requirePacked(const DistributeLayoutAttr layout)
Helper function to check if the layout is packed.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc arguments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requireTranspose(const DistributeLayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
DenseMap< Value, SmallVector< Type > > precomputeLoopBlockArgTypes(Operation *topLevelOp, SubShapeAndCountFn getSubShapeAndCount)
Pre-computes distributed VectorType mappings for every value carried through an SCF loop under topLev...
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
void addVectorTypeConversion(TypeConverter &converter, SubShapeAndCountFn getSubShapeAndCount, DenseMap< Value, SmallVector< Type > > loopArgTypes)
Adds a context-aware VectorType conversion to converter (1:1 shape-changing or 1:N,...
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void populateXeGPUSgToLaneDistributeTypeConversions(TypeConverter &typeConverter, Operation *topLevelOp)
Define only the type conversions needed for XeGPU subgroup to lane distribution.
Value lowerCrossLaneReductionToShuffles(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter)
Lowers cross-lane reductions to shuffle operations on a 2D vector.
void cleanupUnrealizedConversionCasts(Operation *root, const llvm::SmallSetVector< UnrealizedConversionCastOp, 8 > &existingCasts)
Cleans up UnrealizedConversionCastOps inserted during SCF structural type conversion and/or XeGPU unr...
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
bool isSupportedInstruction(InstructionKind instr) const
Definition uArchBase.h:175
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:168