1//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
20#include "mlir/IR/AffineMap.h"
21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/Builders.h"
24#include "mlir/IR/BuiltinOps.h"
26#include "mlir/IR/Operation.h"
28#include "mlir/IR/TypeRange.h"
29#include "mlir/IR/Value.h"
30#include "mlir/IR/Visitors.h"
32#include "mlir/Support/LLVM.h"
36#include "llvm/ADT/ArrayRef.h"
37#include "llvm/ADT/STLExtras.h"
38#include "llvm/ADT/SmallVector.h"
39#include "llvm/ADT/SmallVectorExtras.h"
40
41namespace mlir {
42namespace xegpu {
43#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
44#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
45} // namespace xegpu
46} // namespace mlir
47
48#define DEBUG_TYPE "xegpu-subgroup-distribute"
49#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
50
51using namespace mlir;
52
53static const char *const resolveSIMTTypeMismatch =
54 "resolve_simt_type_mismatch"; // Attribute name for identifying
55 // UnrealizedConversionCastOp added to resolve
56 // SIMT type mismatches.
57
58namespace {
59
60//===----------------------------------------------------------------------===//
61// SIMT Distribution Patterns
62//===----------------------------------------------------------------------===//
63
64/// In certain cases, we may need to favor XeGPU-specific distribution patterns
65/// over generic vector distribution patterns. In such cases, we can assign
66/// priorities to patterns.
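/// For example, an XeGPU-specific pattern registered with the `AboveRegular`
/// benefit is tried before generic vector distribution patterns registered
/// with the `Regular` benefit, since the pattern driver prefers patterns with
/// higher benefit.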
67enum PatternHierarchy : unsigned { Regular = 1, AboveRegular = 2 };
68
69/// Helper function to resolve types if the distributed type out of
70/// gpu.warp_execute_on_lane_0 is different from the expected XeGPU SIMT type.
71/// Example 1:
72/// distributed type: vector<8x1xf32>
73/// expected type: vector<8xf32>
74/// resolved using,
75/// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
76/// Example 2:
77/// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
78/// expected type: xegpu.tensor_desc<8x16xf32>
79/// resolved using,
80/// %0 = unrealized_conversion_cast %1 :
81/// xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
82/// xegpu.tensor_desc<8x16xf32>
83template <typename T>
84static Value resolveDistributedTy(Value orig, T expected,
85 PatternRewriter &rewriter) {
86 // If orig and expected types are the same, return orig.
87 if (orig.getType() == expected)
88 return orig;
89 // If orig is a vector type, create a shape cast op to reconcile the types.
90 if (isa<VectorType>(orig.getType())) {
91 auto castOp =
92 vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
93 return castOp.getResult();
94 }
95 // If orig is a tensor descriptor type, create an unrealized conversion cast
96 // op to reconcile the types.
97 if (isa<xegpu::TensorDescType>(orig.getType())) {
98 auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
99 expected, orig);
100 castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
101 return castOp.getResult(0);
102 }
103 llvm_unreachable("Unsupported type for reconciliation");
104 return orig;
105}
106
107/// Given a vector type and its distributed vector type, return the list of
108/// dimensions that are distributed.
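/// For example, given an original type vector<16x16xf32> and a distributed
/// type vector<16x1xf32>, the returned list of distributed dimensions is [1].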
109static SmallVector<int64_t> getDistributedDims(VectorType originalType,
110 VectorType distributedType) {
111 assert(originalType.getRank() == distributedType.getRank() &&
112 "sequential and distributed vector types must have the same rank");
113 SmallVector<int64_t> distributedDims;
114 for (int64_t i = 0; i < originalType.getRank(); ++i) {
115 if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
116 distributedDims.push_back(i);
117 }
118 }
119 return distributedDims;
120}
121
122/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
123/// of the original GPUFuncOp to the new GPUFuncOp such that the entire body is
124/// contained within a WarpExecuteOnLane0Op.
125/// Example:
126///
127/// ```
128/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
129/// ...
130/// ...
131/// gpu.return %result: vector<8x16xf32>
132/// }
133/// ```
134/// To
135/// ```
136/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
137/// %laneid = gpu.lane_id : index
138/// %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
139/// ...
140/// ...
141/// gpu.yield %result: vector<8x16xf32>
142/// }
143/// gpu.return %0 : vector<8x16xf32>
144/// }
/// ```
145struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
146 using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
147 LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
148 PatternRewriter &rewriter) const override {
149 auto uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or(""));
150 if (!uArch)
151 return rewriter.notifyMatchFailure(
152 gpuFuncOp, "Subgroup distribution requires a target attribute attached "
153 "to set the warp size");
154 // If the function only contains a single void return, skip.
155 if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
156 return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
157 }))
158 return failure();
159 // If the function body was already moved inside a warp_execute_on_lane_0, skip.
160 if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
161 return isa<gpu::WarpExecuteOnLane0Op>(op);
162 }))
163 return failure();
164 // Create a new function with the same signature and same attributes.
165 SmallVector<Type> workgroupAttributionsTypes =
166 llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
167 [](BlockArgument arg) { return arg.getType(); });
168 SmallVector<Type> privateAttributionsTypes =
169 llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
170 [](BlockArgument arg) { return arg.getType(); });
171 auto newGpuFunc = gpu::GPUFuncOp::create(
172 rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
173 gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
174 privateAttributionsTypes);
175 newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
176 // Create a WarpExecuteOnLane0Op with same arguments and results as the
177 // original gpuFuncOp.
178 rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
179 auto laneId = gpu::LaneIdOp::create(
180 rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(),
181 /** upperBound = **/ mlir::IntegerAttr());
182 ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
183 auto warpOp = gpu::WarpExecuteOnLane0Op::create(
184 rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
185 uArch->getSubgroupSize(), newGpuFunc.getArguments(),
186 newGpuFunc.getArgumentTypes());
187 Block &warpBodyBlock = warpOp.getBodyRegion().front();
188 // Replace the ReturnOp of the original gpu function with a YieldOp.
189 auto origReturnOp =
190 cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
191 rewriter.setInsertionPointAfter(origReturnOp);
192 gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
193 origReturnOp.getOperands());
194 rewriter.eraseOp(origReturnOp);
195 // Move the original function body to the WarpExecuteOnLane0Op body.
196 rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
197 warpOp.getBodyRegion().begin());
198 rewriter.eraseBlock(&warpBodyBlock);
199 // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
200 rewriter.setInsertionPointAfter(warpOp);
201 gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
202 rewriter.replaceOp(gpuFuncOp, newGpuFunc);
203 return success();
204 }
205};
206
207/// Distribute a create_nd_tdesc feeding into the gpu.yield op of the enclosing
208/// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
209/// still contain the original op that will not be used by the yield op (and
210/// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
211/// arguments. Tensor descriptor shape is not distributed because it is a
212/// uniform value across all work items within the subgroup. However, the
213/// layout information is dropped in the new tensor descriptor type.
214///
215/// Example:
216///
217/// ```
218/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
219/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
220/// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
221/// ...
222/// %td = xegpu.create_nd_tdesc %arg0
223/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
224/// gpu.yield %td
225/// }
226/// ```
227/// To
228/// ```
229/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
230/// ...
231/// %dead = xegpu.create_nd_tdesc %arg0
232/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
233/// gpu.yield %arg0, %dead
234/// }
235/// %td = xegpu.create_nd_tdesc %r#0: memref<4x8xf32>
236/// -> !xegpu.tensor_desc<4x8xf32>
237///
238/// ```
239struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
240 using gpu::WarpDistributionPattern::WarpDistributionPattern;
241 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
242 PatternRewriter &rewriter) const override {
243 OpOperand *operand =
244 getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
245 if (!operand)
246 return rewriter.notifyMatchFailure(
247 warpOp, "warp result is not a xegpu::CreateNdDesc op");
248 auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
249 unsigned operandIdx = operand->getOperandNumber();
250
251 xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
252 if (!layout)
253 return rewriter.notifyMatchFailure(
254 descOp, "the tensor descriptor lacks layout attribute");
255 // CreateNdDescOp must not have offsets.
256 if (descOp.getMixedOffsets().size())
257 return rewriter.notifyMatchFailure(
258 descOp, "xegpu::CreateNdDescOp must not have offsets");
259
260 SmallVector<size_t> newRetIndices;
261 rewriter.setInsertionPoint(warpOp);
262 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
263 rewriter, warpOp, /* new yielded values = */ descOp->getOperands(),
264 /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);
265
266 SmallVector<Value> newDescOperands = llvm::map_to_vector(
267 newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
268 rewriter.setInsertionPointAfter(newWarpOp);
269 xegpu::TensorDescType distributedTensorDescTy =
270 descOp.getType().dropLayouts(); // Distributed tensor descriptor type
271 // does not contain layout info.
272 Value newDescOp = xegpu::CreateNdDescOp::create(
273 rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
274 descOp->getAttrs());
275
276 Value distributedVal = newWarpOp.getResult(operandIdx);
277 // Resolve the distributed type to the expected type.
278 newDescOp =
279 resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
280 rewriter.replaceAllUsesWith(distributedVal, newDescOp);
281 return success();
282 }
283};
284
285/// Distribute a store_nd op at the end of enclosing
286/// `gpu.warp_execute_on_lane_0`. Arguments of the store that are defined inside
287/// the warp op are propagated to the outside as additional warp op results.
288/// The source vector is distributed based on the lane layout. Appropriate cast ops
289/// are inserted if the distributed types do not match the expected XeGPU SIMT types.
290///
291/// Example:
292///
293/// ```
294/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
295/// gpu.warp_execute_on_lane_0(%laneid) -> () {
296/// ...
297/// xegpu.store_nd %arg0, %arg1 [%x, %y]: vector<4x8xf32>,
298/// !xegpu.tensor_desc<4x8xf32, #layout0>
299/// }
300/// ```
301/// To
302/// ```
303/// %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
304/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
305/// ...
306/// gpu.yield %arg0, %arg1, %x, %y: vector<4x8xf32>,
307/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index
308/// }
309/// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
310/// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
311/// #layout0>
312/// -> !xegpu.tensor_desc<4x8xf32>
313/// xegpu.store_nd %0, %1 [%r#2, %r#3]: vector<4xf32>,
314/// !xegpu.tensor_desc<4x8xf32>
315///
316/// ```
317struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
318 using gpu::WarpDistributionPattern::WarpDistributionPattern;
319 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
320 PatternRewriter &rewriter) const override {
321 gpu::YieldOp yield = warpOp.getTerminator();
322 Operation *lastNode = yield->getPrevNode();
323 auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
324 if (!storeOp)
325 return failure();
326
327 SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
328 // Expecting offsets to be present.
329 if (offsets.empty())
330 return rewriter.notifyMatchFailure(storeOp,
331 "the store op must have offsets");
332 SmallVector<Value> offsetsAsValues =
333 vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
334 SmallVector<Type> offsetTypes = llvm::map_to_vector(
335 offsetsAsValues, [](Value v) { return v.getType(); });
336 xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
337 xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
338 if (!layout)
339 return rewriter.notifyMatchFailure(
340 storeOp, "the source tensor descriptor lacks layout attribute");
341
342 FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
343 xegpu::getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
344 if (failed(distributedTypeByWarpOpOrFailure))
345 return rewriter.notifyMatchFailure(storeOp,
346 "Failed to distribute the type");
347 VectorType distributedTypeByWarpOp =
348 distributedTypeByWarpOpOrFailure.value();
349
350 SmallVector<size_t> newRetIndices;
351 SmallVector<Value> newYieldedValues = {storeOp.getValue(),
352 storeOp.getTensorDesc()};
353 SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
354 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
355 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
356 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
357 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
358 // Create a new store op outside the warp op with the distributed vector
359 // type. Tensor descriptor is not distributed.
360 rewriter.setInsertionPointAfter(newWarpOp);
361 SmallVector<Value> newStoreOperands;
362
363 // For the value operand, there can be a mismatch between the vector type
364 // distributed by the warp op and (xegpu-specific) distributed type
365 // supported by the store op. Type mismatch must be resolved using
366 // appropriate cast op.
367 FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
368 xegpu::getDistributedVectorType(storeOp.getTensorDescType());
369 if (failed(storeNdDistributedValueTyOrFailure))
370 return rewriter.notifyMatchFailure(
371 storeOp, "Failed to get distributed vector type for the store op");
372 newStoreOperands.push_back(resolveDistributedTy(
373 newWarpOp.getResult(newRetIndices[0]),
374 storeNdDistributedValueTyOrFailure.value(), rewriter));
375 // For the tensor descriptor operand, the layout attribute is dropped after
376 // distribution. Types need to be resolved in this case as well.
377 xegpu::TensorDescType distributedTensorDescTy =
378 storeOp.getTensorDescType().dropLayouts();
379 newStoreOperands.push_back(
380 resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
381 distributedTensorDescTy, rewriter));
382 // Collect offsets.
383 for (size_t i = 2; i < newRetIndices.size(); ++i)
384 newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
385
386 auto newStoreOp =
387 xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
388 newStoreOperands, storeOp->getAttrs());
389 xegpu::removeLayoutAttrs(newStoreOp);
390 rewriter.eraseOp(storeOp);
391 return success();
392 }
393};
394
395/// Distribute a load_nd op feeding into the gpu.yield op of the enclosing
396/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
397/// The warp op will still contain the original op that will not be used by
398/// the yield op (and should be cleaned up later). The yield op will
399/// bypass the load's arguments. Only the loaded vector is distributed
400/// according to the lane layout; the tensor descriptor type is not
401/// distributed. Appropriate cast ops are inserted if the distributed types do
402/// not match the expected XeGPU SIMT types.
403///
404/// Example:
405///
406/// ```
407/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
408/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
409/// (vector<4x1xf32>) {
410/// ...
411/// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
412/// ->
413/// vector<4x8xf32>
414/// gpu.yield %ld
415/// }
416/// ```
417/// To
418/// ```
419/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
420/// !xegpu.tensor_desc<4x8xf32, #layout0>) {
421/// ...
422/// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
423/// vector<4x8xf32>
/// gpu.yield %dead, %arg0
424/// }
425/// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
426/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
427/// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
428/// %2 = vector.shape_cast %1: vector<4xf32> to vector<4x1xf32>
429///
430/// ```
431struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
432 using gpu::WarpDistributionPattern::WarpDistributionPattern;
433 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
434 PatternRewriter &rewriter) const override {
435 OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
436 if (!isa<xegpu::LoadNdOp>(op))
437 return false;
438 // Make sure the same load op is the last operation in the warp op body.
439 // This ensures that the load op is not sunk earlier, which could violate
440 // barrier synchronization.
441 gpu::YieldOp yield = warpOp.getTerminator();
442 return yield->getPrevNode() == op;
443 });
444
445 if (!operand)
446 return rewriter.notifyMatchFailure(
447 warpOp, "warp result is not a xegpu::LoadNd op");
449 auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
450 auto uArch = getUArch(xegpu::getChipStr(loadOp).value_or(""));
451 if (!uArch)
452 return rewriter.notifyMatchFailure(
453 loadOp, "xegpu::LoadNdOp requires a target attribute attached to "
454 "determine the transpose "
455 "requirement");
456 // Chip information is required to decide if the layout requires transpose
457 // effect.
458 // Expecting offsets to be present.
459 SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
460 if (offsets.empty())
461 return rewriter.notifyMatchFailure(loadOp,
462 "the load op must have offsets");
463 SmallVector<Value> offsetsAsValues =
464 vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
465 SmallVector<Type> offsetTypes = llvm::map_to_vector(
466 offsetsAsValues, [](Value v) { return v.getType(); });
468 xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
469 xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
470 if (!layout)
471 return rewriter.notifyMatchFailure(
472 loadOp, "the source tensor descriptor lacks layout attribute");
473
474 unsigned operandIdx = operand->getOperandNumber();
475 VectorType distributedTypeByWarpOp =
476 cast<VectorType>(warpOp.getResult(operandIdx).getType());
478 SmallVector<size_t> newRetIndices;
479 SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
480 SmallVector<Type> newYieldedTypes = {tensorDescTy};
481 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
482 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
483 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
484 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
485
486 // Create a new load op outside the warp op with the distributed vector
487 // type.
488 rewriter.setInsertionPointAfter(newWarpOp);
489 FailureOr<VectorType> loadNdDistValueTyOrFailure =
490 xegpu::getDistributedVectorType(loadOp.getTensorDescType());
491 if (failed(loadNdDistValueTyOrFailure))
492 return rewriter.notifyMatchFailure(
493 loadOp, "Failed to get distributed vector type for the load op");
494 xegpu::TensorDescType distributedTensorDescTy =
495 loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
496 // descriptor type does not
497 // contain layout info.
498 SmallVector<Value> newLoadOperands{
499 resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
500 distributedTensorDescTy, rewriter)};
501 // Collect offsets.
502 for (size_t i = 1; i < newRetIndices.size(); ++i)
503 newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
504 auto newLoadOp = xegpu::LoadNdOp::create(
505 rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
506 newLoadOperands, loadOp->getAttrs());
507 xegpu::removeLayoutAttrs(newLoadOp);
508 // Set the packed attribute if the layout requires it.
509 newLoadOp.setPacked(xegpu::requirePacked(layout));
510 // Set the transpose attribute if the layout requires it.
511 if (xegpu::requireTranspose(layout, uArch))
512 newLoadOp.setTranspose(
513 DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
514 Value distributedVal = newWarpOp.getResult(operandIdx);
515 // There can be a conflict between the vector type distributed by the
516 // warp op and (xegpu-specific) distributed type supported by the load
517 // op. Resolve these mismatches by inserting a cast.
518 Value tyResolvedVal = resolveDistributedTy(
519 newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
520 rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
521 return success();
522 }
523};
524
525/// Distribute a dpas op feeding into the gpu.yield op of the enclosing
526/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
527/// The warp op will still contain the original op that will not be used by
528/// the yield op (and should be cleaned up later). The yield op will
529/// bypass the dpas's arguments. Appropriate cast ops are inserted if the
530/// distributed types do not match the expected XeGPU SIMT types.
531/// Example:
532/// ```
533/// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
534/// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
535/// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
536/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
537/// (vector<8x1xf32>) {
538/// ...
539/// %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
540/// vector<8x16xf32>
541/// gpu.yield %dpas
542/// }
543/// ```
544/// To
545/// ```
546/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
547/// vector<8x1xf16>, vector<16x1xf16>) {
548/// ...
549/// %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
550/// -> vector<8x16xf32>
551/// gpu.yield %dead, %arg0, %arg1
552/// }
553/// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
554/// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
555/// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
556/// vector<8xf32>
557/// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
558/// ```
559struct DpasDistribution final : public gpu::WarpDistributionPattern {
560 using gpu::WarpDistributionPattern::WarpDistributionPattern;
561 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
562 PatternRewriter &rewriter) const override {
563 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
564 if (!operand)
565 return rewriter.notifyMatchFailure(warpOp,
566 "warp result is not a xegpu::Dpas op");
567
568 auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
569 unsigned operandIdx = operand->getOperandNumber();
570
571 xegpu::LayoutAttr layoutA =
572 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutAAttr());
573 xegpu::LayoutAttr layoutB =
574 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutBAttr());
575 xegpu::LayoutAttr layoutOut =
576 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutCdAttr());
577
578 if (!layoutA || !layoutB || !layoutOut)
579 return rewriter.notifyMatchFailure(
580 dpasOp,
581 "the xegpu::Dpas op lacks layout attribute for A, B or output");
582
583 FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
584 getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
585 FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
586 getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
587 FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
588 getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
589
590 if (failed(distLhsTypeByWarpOpOrFailure) ||
591 failed(distRhsTypeByWarpOpOrFailure) ||
592 failed(distResultTypeByWarpOpOrFailure))
593 return rewriter.notifyMatchFailure(
594 dpasOp,
595 "Failed to distribute the A, B or output types in xegpu::Dpas op");
596
597 llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
598 dpasOp.getRhs()};
599 llvm::SmallVector<Type, 3> newYieldTypes{
600 distLhsTypeByWarpOpOrFailure.value(),
601 distRhsTypeByWarpOpOrFailure.value()};
602 // Dpas acc operand is optional.
603 if (dpasOp.getAcc()) {
604 newYieldValues.push_back(dpasOp.getAcc());
605 newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
606 }
607 // Create a new warp op without the dpas.
608 SmallVector<size_t> newRetIndices;
609 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
610 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
611
612 FailureOr<VectorType> expectedDistLhsTyOrFailure =
613 xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
614 FailureOr<VectorType> expectedDistRhsTyOrFailure =
615 xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
616 FailureOr<VectorType> expectedDistResultTyOrFailure =
617 xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
618
619 if (failed(expectedDistLhsTyOrFailure) ||
620 failed(expectedDistRhsTyOrFailure) ||
621 failed(expectedDistResultTyOrFailure))
622 return rewriter.notifyMatchFailure(
623 dpasOp,
624 "Failed to get distributed vector type for the dpas operands.");
625 // Create a new dpas op outside the warp op.
626 rewriter.setInsertionPointAfter(newWarpOp);
627 SmallVector<Value> newDpasOperands;
628 SmallVector<VectorType> newDpasOperandExpectedTypes;
629
630 // Resolve the distributed types with the original types.
631 newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
632 newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
633 VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
634 if (dpasOp.getAcc())
635 newDpasOperandExpectedTypes.push_back(distributedResultTy);
636
637 for (unsigned i = 0; i < newRetIndices.size(); i++) {
638 newDpasOperands.push_back(
639 resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
640 newDpasOperandExpectedTypes[i], rewriter));
641 }
642 auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
643 distributedResultTy, newDpasOperands,
644 dpasOp->getAttrs());
645 xegpu::removeLayoutAttrs(newDpasOp);
646 Value distributedVal = newWarpOp.getResult(operandIdx);
647 // Resolve the output type.
648 Value typeResolved =
649 resolveDistributedTy(newDpasOp.getResult(),
650 distResultTypeByWarpOpOrFailure.value(), rewriter);
651 rewriter.replaceAllUsesWith(distributedVal, typeResolved);
652 return success();
653 }
654};
655
656/// Distribute a prefetch_nd op at the end of enclosing
657/// `gpu.warp_execute_on_lane_0`. Arguments of the prefetch that are defined
658/// inside the warp op are propagated to the outside as additional warp op results.
659/// The tensor descriptor shape is not distributed because it is a uniform value
660/// across all work items within the subgroup. Appropriate cast ops are inserted
661/// if the distributed types do not match the expected XeGPU SIMT types.
662///
663/// Example:
664///
665/// ```
666/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
667/// gpu.warp_execute_on_lane_0(%laneid) -> () {
668/// ...
669/// xegpu.prefetch_nd %arg0 [%x, %y] : !xegpu.tensor_desc<4x8xf32, #layout0>
670/// }
671/// ```
672/// To
673/// ```
674/// %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (
675/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
676/// gpu.yield %arg0, %x, %y: !xegpu.tensor_desc<4x8xf32, #layout0>, index,
677/// index
678/// }
679/// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
680/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
681/// xegpu.prefetch_nd %1 [%r#1, %r#2] : !xegpu.tensor_desc<4x8xf32>
682///
683/// ```
684struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
685 using gpu::WarpDistributionPattern::WarpDistributionPattern;
686 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
687 PatternRewriter &rewriter) const override {
688 gpu::YieldOp yield = warpOp.getTerminator();
689 Operation *lastNode = yield->getPrevNode();
690 auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
691 if (!prefetchOp)
692 return failure();
693
694 SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
695 // PrefetchNdOp must have offsets.
696 if (offsets.empty())
697 return rewriter.notifyMatchFailure(prefetchOp,
698 "the prefetch op must have offsets");
699 SmallVector<Value> offsetsAsValues =
700 vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
701 SmallVector<Type> offsetTypes = llvm::map_to_vector(
702 offsetsAsValues, [](Value v) { return v.getType(); });
703
704 xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
705 if (!layout)
706 return rewriter.notifyMatchFailure(
707 prefetchOp, "the source tensor descriptor lacks layout attribute");
708
709 SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
710 SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
711 newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
712 newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
713 SmallVector<size_t> newRetIndices;
714 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
715 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
716 // Create a new prefetch op outside the warp op with updated tensor
717 // descriptor type. The source tensor descriptor requires type resolution.
718 xegpu::TensorDescType newTensorDescTy =
719 prefetchOp.getTensorDescType().dropLayouts();
720 rewriter.setInsertionPointAfter(newWarpOp);
721 SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
722 newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
723 // Collect offsets.
724 for (size_t i = 1; i < newRetIndices.size(); ++i)
725 newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
726 Operation *newPrefetchOp = xegpu::PrefetchNdOp::create(
727 rewriter, newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
728 prefetchOp->getAttrs());
729 xegpu::removeLayoutAttrs(newPrefetchOp);
730 rewriter.eraseOp(prefetchOp);
731 return success();
732 }
733};
734
735/// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
736/// region. This will simply move the barrier op outside of the warp op.
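/// Example (illustrative):
/// ```
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
/// ...
/// gpu.barrier
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
/// ...
/// }
/// gpu.barrier
/// ```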
737struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
738 using gpu::WarpDistributionPattern::WarpDistributionPattern;
739 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
740 PatternRewriter &rewriter) const override {
741 gpu::YieldOp yield = warpOp.getTerminator();
742 Operation *lastNode = yield->getPrevNode();
743 // The last node must be a gpu::BarrierOp.
744 auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
745 if (!barrierOp)
746 return failure();
747 // Move the barrier op outside of the warp op.
748 rewriter.setInsertionPointAfter(warpOp);
749 gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
750 barrierOp->getResultTypes(),
751 barrierOp->getOperands(), barrierOp->getAttrs());
752 rewriter.eraseOp(barrierOp);
753 return success();
754 }
755};
756
757/// Distribute a scattered store op. The offsets argument is required.
758/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
759/// The layouts are fixed and implicit: one offset/mask per lane.
760/// The pass changes the offset/mask vector shapes to a
761/// single-element vector; **it is assumed that their producers will also be
762/// distributed**. The payload vector also has a fixed distribution:
763/// no chunk size -> vector of one element.
764/// chunk size -> vector of the innermost dimension of the SG-payload.
765/// Example 1 (no chunk size):
766/// %mask = producer_op : vector<16xi1>
767/// %offset = producer_op : vector<16xindex>
768/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
769/// memref<256xf16>, vector<16xindex>, vector<16xi1>
770/// To
771/// %mask = producer_op : vector<1xi1>
772/// %offset = producer_op : vector<1xindex>
773/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
774/// memref<256xf16>, vector<1xindex>, vector<1xi1>
775/// Example 2 (chunk size, same mask and offsets):
776/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
777/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
778/// To
779/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
780/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
781///
782/// Note that the store distribution pattern also handles leading unit
783/// dimensions in the payload, mask and offsets vectors. In this case the store
784/// distribution will only change the dimensions corresponding to the SG
785/// distribution and keep the leading unit dimensions unchanged.
786/// For example, a store with payload vector<1x16xf16> with lane layout [1, 16]
787/// will be distributed as vector<1x1xf16>. Shape cast ops are inserted for the
788/// offset/mask/payload when necessary so that the distributed store works
789/// on a 1D vector to match the HW capability.
790struct StoreDistribution final : public gpu::WarpDistributionPattern {
791 using gpu::WarpDistributionPattern::WarpDistributionPattern;
792 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
793 PatternRewriter &rewriter) const override {
794 Operation *lastNode = warpOp.getTerminator()->getPrevNode();
795 auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
796 if (!storeScatterOp)
797 return failure();
798 auto offsets = storeScatterOp.getOffsets();
799 if (!offsets || !isa<VectorType>(offsets.getType()))
800 return rewriter.notifyMatchFailure(
801 storeScatterOp, "Store op must have a vector of offsets argument");
802 VectorType offsetsTy = cast<VectorType>(offsets.getType());
803 VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
804 VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
805
806 // Handle leading unit dimensions in the store payload.
807 int chunkSize = storeScatterOp.getChunkSize().value_or(1);
808 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
809
810 // Check that all leading dimensions are unit dimensions
811 for (int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
812 if (storeVecTy.getShape()[i] != 1) {
813 return rewriter.notifyMatchFailure(
814 storeScatterOp, "Only unit dimensions allowed for the leading "
815 "dimensions of the store vector!");
816 }
817 }
818
819 auto layoutPayload =
820 xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(0));
821 auto layoutOffsets =
822 xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(2));
823 auto layoutMask =
824 xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(3));
825
826 FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
827 getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
828 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
829 getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
830 FailureOr<VectorType> distMaskByWarpOpOrFailure =
831 getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
832 if (failed(distStoreVecByWarpOpOrFailure) ||
833 failed(distOffsetsByWarpOpOrFailure) ||
834 failed(distMaskByWarpOpOrFailure)) {
835 return rewriter.notifyMatchFailure(
836 storeScatterOp,
837 "Some vector operands have no layouts, using defaults instead.");
838 }
839
840 VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
841 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
842 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
843
844 SmallVector<size_t> newRetIndices;
845 SmallVector<Value> operands = storeScatterOp->getOperands();
846 SmallVector<Type> operandTypesToYield = {
847 distPayloadTy, operands[1].getType(), distOffsetsTy, distMaskTy};
848
849 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
850 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
851
852 rewriter.setInsertionPointAfter(newWarpOp);
853
854 // Distributed store payload type is always 1D without leading unit dims
855 VectorType payloadTy1D = VectorType::get({distPayloadTy.getNumElements()},
856 distPayloadTy.getElementType());
857
858 VectorType distOffsetsTy1D = VectorType::get(
859 {distOffsetsTy.getNumElements()}, distOffsetsTy.getElementType());
860 VectorType distMaskTy1D = VectorType::get({distMaskTy.getNumElements()},
861 distMaskTy.getElementType());
862
863 // Resolve distributed types to 1D for SIMT execution
864 Value distPayloadVal = resolveDistributedTy(
865 newWarpOp.getResult(newRetIndices[0]), payloadTy1D, rewriter);
866 Value distOffsetVal = resolveDistributedTy(
867 newWarpOp.getResult(newRetIndices[2]), distOffsetsTy1D, rewriter);
868 Value distMaskVal = resolveDistributedTy(
869 newWarpOp.getResult(newRetIndices[3]), distMaskTy1D, rewriter);
870
871 SmallVector<Value> newStoreScatterOpOperands = {
872 distPayloadVal, newWarpOp.getResult(newRetIndices[1]), distOffsetVal,
873 distMaskVal};
874
875 xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
876 rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
877 storeScatterOp->getAttrs());
878 xegpu::removeLayoutAttrs(newOp);
879 rewriter.eraseOp(storeScatterOp);
880 return success();
881 }
882};
883
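/// Helper to compute the per-lane coordinates for a load_matrix/store_matrix
/// op. The coordinates distributed by `layout` for `laneId` (based on the
/// subgroup-level payload shape) are combined with the original offsets of the
/// op. An empty vector is returned if distributed coordinates cannot be
/// computed.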
884static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
885 PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
886 Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
887 SmallVector<Value> newCoords;
888 auto maybeCoords =
889 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
890 if (failed(maybeCoords))
891 return {};
892 assert(maybeCoords.value().size() == 1 &&
893 "Expected one set of distributed offsets");
895 rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
896 getAsOpFoldResult(origOffsets));
897 newCoords = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
898 return newCoords;
899}
900
901/// Pattern for distributing xegpu::LoadMatrixOp.
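/// The loaded payload is distributed across lanes according to the op's layout
/// attribute, and the offsets are adjusted with the per-lane coordinates
/// computed from the layout, unless `subgroup_block_io` is set, in which case
/// the original offsets are used as-is.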
902struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
903 using gpu::WarpDistributionPattern::WarpDistributionPattern;
904 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
905 PatternRewriter &rewriter) const override {
906 gpu::YieldOp yield = warpOp.getTerminator();
907 Operation *lastNode = yield->getPrevNode();
908 auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
909 if (!matrixOp)
910 return failure();
911
912 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
913 return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
914 });
915 if (!producedByLastLoad)
916 return rewriter.notifyMatchFailure(
917 warpOp, "The last op is not xegpu::LoadMatrixOp");
918 const int operandIdx = producedByLastLoad->getOperandNumber();
919
920 VectorType sgPayloadTy =
921 dyn_cast<VectorType>(matrixOp.getResult().getType());
922 VectorType warpResultTy =
923 cast<VectorType>(warpOp.getResult(operandIdx).getType());
924 if (!sgPayloadTy)
925 return rewriter.notifyMatchFailure(
926 matrixOp, "the matrix op payload must be a vector type");
927
928 auto loc = matrixOp.getLoc();
929 auto offsets = matrixOp.getMixedOffsets();
930 if (offsets.empty())
931 return rewriter.notifyMatchFailure(matrixOp,
932 "the load op must have offsets");
933 SmallVector<Value> offsetsAsValues =
934 vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
935
936 auto layout = matrixOp.getLayoutAttr();
937 if (!layout)
938 return rewriter.notifyMatchFailure(
939 matrixOp, "the matrix operation lacks layout attribute");
940
941 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
942 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
943 if (failed(distPayloadByWarpOpOrFailure))
944 return rewriter.notifyMatchFailure(
945 matrixOp, "Failed to distribute matrix op payload based on layout.");
946
947 SmallVector<Value> operands = {matrixOp.getMemDesc()};
948 const unsigned offsetsStartIdx = operands.size();
949 operands.append(offsetsAsValues);
950
951 SmallVector<Type> operandTypes =
952 llvm::map_to_vector(operands, [](Value v) { return v.getType(); });
953
954 SmallVector<size_t> newRetIndices;
955 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
956 rewriter, warpOp, operands, operandTypes, newRetIndices);
957 SmallVector<Value> newOperands = llvm::map_to_vector(
958 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
959
960 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
961 ShapedType::kDynamic);
962 DenseI64ArrayAttr newConstOffsetsAttr =
963 rewriter.getDenseI64ArrayAttr(newConstOffsets);
964 ValueRange currentOffsets =
965 ValueRange(newOperands).drop_front(offsetsStartIdx);
966
967 SmallVector<Value> newCoords = currentOffsets;
968 rewriter.setInsertionPointAfter(newWarpOp);
969
970 if (!matrixOp.getSubgroupBlockIoAttr()) {
971 newCoords = computeDistributedCoordinatesForMatrixOp(
972 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
973 currentOffsets);
974 }
975 xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
976 rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
977 newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
978 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
979 // Resolve the output type and replace all uses.
980 rewriter.replaceAllUsesWith(
981 newWarpOp.getResult(operandIdx),
982 resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
983 return success();
984 }
985};
986
987/// Pattern for distributing xegpu::StoreMatrixOp.
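/// Distribution follows the same scheme as LoadMatrixDistribution above: the
/// data payload is distributed per the op's layout attribute, and the offsets
/// are adjusted with the per-lane coordinates computed from the layout unless
/// `subgroup_block_io` is set.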
988struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
989 using gpu::WarpDistributionPattern::WarpDistributionPattern;
990 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
991 PatternRewriter &rewriter) const override {
992 gpu::YieldOp yield = warpOp.getTerminator();
993 Operation *lastNode = yield->getPrevNode();
994 auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
995 if (!matrixOp)
996 return failure();
997
998 VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
999 if (!sgPayloadTy)
1000 return rewriter.notifyMatchFailure(
1001 matrixOp, "the matrix op payload must be a vector type");
1002
1003 auto loc = matrixOp.getLoc();
1004 auto offsets = matrixOp.getMixedOffsets();
1005 if (offsets.empty())
1006 return rewriter.notifyMatchFailure(matrixOp,
1007 "the store op must have offsets");
1008 SmallVector<Value> offsetsAsValues =
1009 vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
1010
1011 auto layout = matrixOp.getLayoutAttr();
1012 if (!layout)
1013 return rewriter.notifyMatchFailure(
1014 matrixOp, "the matrix operation lacks layout attribute");
1015
1016 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
1017 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
1018 if (failed(distPayloadByWarpOpOrFailure))
1019 return rewriter.notifyMatchFailure(
1020 matrixOp, "Failed to distribute matrix op payload based on layout.");
1021
1022 SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
1023 const unsigned offsetsStartIdx = operands.size();
1024 operands.append(offsetsAsValues);
1025
1026 SmallVector<Type> operandTypes =
1027 llvm::map_to_vector(operands, [](Value v) { return v.getType(); });
1028 operandTypes[0] = *distPayloadByWarpOpOrFailure;
1029
1030 SmallVector<size_t> newRetIndices;
1031 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1032 rewriter, warpOp, operands, operandTypes, newRetIndices);
1033 SmallVector<Value> newOperands = llvm::map_to_vector(
1034 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
1035
1036 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
1037 ShapedType::kDynamic);
1038 DenseI64ArrayAttr newConstOffsetsAttr =
1039 rewriter.getDenseI64ArrayAttr(newConstOffsets);
1040 ValueRange currentOffsets =
1041 ValueRange(newOperands).drop_front(offsetsStartIdx);
1042
1043 SmallVector<Value> newCoords = currentOffsets;
1044 rewriter.setInsertionPointAfter(newWarpOp);
1045
1046 if (!matrixOp.getSubgroupBlockIoAttr()) {
1047 newCoords = computeDistributedCoordinatesForMatrixOp(
1048 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
1049 currentOffsets);
1050 }
1051
1052 xegpu::StoreMatrixOp::create(
1053 rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
1054 ValueRange(newCoords), newConstOffsetsAttr,
1055 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
1056 rewriter.eraseOp(matrixOp);
1057 return success();
1058 }
1059};
1060
1061/// Distribute a scattered load op. The logic and requirements are the same as
1062/// for the scattered store distribution. The warpOp's payload vector is
1063/// expected to be distributed by the load's result consumer.
1064/// Example 1 (no chunk size):
1065/// %mask = producer_op : vector<16xi1>
1066/// %offset = producer_op : vector<16xindex>
1067/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
1068/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
1069/// To
1070/// %mask = producer_op : vector<1xi1>
1071/// %offset = producer_op : vector<1xindex>
1072/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
1073/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
1074/// Example 2 (chunk size, same mask and offsets):
1075/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
1076/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
1077/// To
1078/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
1079/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
1080///
1081/// Note that the load distribution pattern also handles leading unit dimensions
1082/// in the payload, mask, and offsets vectors. The load distribution will only
1083/// change the dimensions corresponding to the SG distribution and keep the
1084/// leading unit dimensions unchanged. For example, a load with result type
1085/// vector<1x16xf16> with lane layout [1, 16] will be distributed
1086/// as result type vector<1x1xf16>. Shape cast ops are inserted for the
1087/// offset/mask/payload when necessary so that the distributed load works
1088/// on a 1D vector to match the HW capability.
1089struct LoadDistribution final : public gpu::WarpDistributionPattern {
1090 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1091 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1092 PatternRewriter &rewriter) const override {
1093 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
1094 // Check that the yield operand was produced by the *last* scattered
1095 // load op to avoid sinking it before barriers (maintain memory order).
1096 return isa<xegpu::LoadGatherOp>(op) &&
1097 warpOp.getTerminator()->getPrevNode() == op;
1098 });
1099 if (!producedByLastLoad)
1100 return rewriter.notifyMatchFailure(
1101 warpOp, "The last op is not xegpu::LoadGatherOp");
1102
1103 auto loadGatherOp =
1104 producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
1105 auto offsets = loadGatherOp.getOffsets();
1106 if (!offsets || !isa<VectorType>(offsets.getType()) ||
1107 !isa<VectorType>(loadGatherOp.getMask().getType()))
1108 return rewriter.notifyMatchFailure(
1109 loadGatherOp,
1110 "Load op must have a vector arguments for offsets and mask");
1111 VectorType offsetsTy = cast<VectorType>(offsets.getType());
1112 VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
1113 VectorType resultVecTy =
1114 cast<VectorType>(loadGatherOp.getResult().getType());
1115 // Handle leading unit dimensions in the load result.
1116 int chunkSize = loadGatherOp.getChunkSize().value_or(1);
1117 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
1118 for (int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
1119 if (resultVecTy.getShape()[i] != 1) {
1120 return rewriter.notifyMatchFailure(
1121 loadGatherOp, "Only unit dimensions allowed for the leading "
1122 "dimensions of the load vector!");
1123 }
1124 }
1125
1126 auto layoutOffsets =
1127 xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
1128 auto layoutMask = xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));
1129
1130 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
1131 getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
1132 FailureOr<VectorType> distMaskByWarpOpOrFailure =
1133 getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
1134 if (failed(distOffsetsByWarpOpOrFailure) ||
1135 failed(distMaskByWarpOpOrFailure)) {
1136 return rewriter.notifyMatchFailure(
1137 loadGatherOp,
1138 "Some vector operands have no layouts, using defaults instead.");
1139 }
1140
1141 SmallVector<size_t> newRetIndices;
1142 SmallVector<Value> operands = loadGatherOp->getOperands();
1143
1144 const unsigned operandIdx = producedByLastLoad->getOperandNumber();
1145 VectorType distResultTy =
1146 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1147 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
1148 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
1149
1150 SmallVector<Type> operandTypesToYield = {operands[0].getType(),
1151 distOffsetsTy, distMaskTy};
1152
1153 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1154 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
1155
1156 rewriter.setInsertionPointAfter(newWarpOp);
1157
1158 // Distributed load op will always be 1D.
1159 VectorType loadVecTy1D = VectorType::get({distResultTy.getNumElements()},
1160 distResultTy.getElementType());
1161
1162 VectorType distOffsetsTy1D =
1163 VectorType::get({distOffsetsByWarpOpOrFailure.value().getNumElements()},
1164 distOffsetsByWarpOpOrFailure.value().getElementType());
1165 VectorType distMaskTy1D =
1166 VectorType::get({distMaskByWarpOpOrFailure.value().getNumElements()},
1167 distMaskByWarpOpOrFailure.value().getElementType());
1168
1169 Value distOffsetVal = resolveDistributedTy(
1170 newWarpOp.getResult(newRetIndices[1]), distOffsetsTy1D, rewriter);
1171 Value distmaskVal = resolveDistributedTy(
1172 newWarpOp.getResult(newRetIndices[2]), distMaskTy1D, rewriter);
1173
1174 SmallVector<Value> newLoadGatherOperands = {
1175 newWarpOp.getResult(newRetIndices[0]), distOffsetVal, distmaskVal};
1176
1177 xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
1178 rewriter, newWarpOp.getLoc(), loadVecTy1D, newLoadGatherOperands,
1179 loadGatherOp->getAttrs());
1181 Value distributedVal = newWarpOp.getResult(operandIdx);
1182 // Resolve the output type and replace all uses.
1183 rewriter.replaceAllUsesWith(
1184 distributedVal,
1185 resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
1186 return success();
1187 }
1188};
1189
1190// Sink SG-uniform ops. An op is uniform if none
1191// of its operands/results has a distribution layout attribute.
1192// Non-uniform vectors are handled by dedicated patterns.
1193// This pattern must have a higher priority than vector dialect distribution
1194// patterns, because a distributable shape may be logically intended as
1195// uniform (i.e., no layout), so we want to omit its distribution.
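/// Example (illustrative, assuming a subgroup-uniform `arith.addi` on index
/// scalars with no layout attributes):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
/// %a = "some_def"() : () -> (index)
/// %b = "some_def"() : () -> (index)
/// %0 = arith.addi %a, %b : index
/// gpu.yield %0 : index
/// }
/// ```
/// is rewritten to (the original op stays inside as dead code):
/// ```
/// %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (index, index, index) {
/// ...
/// %dead = arith.addi %a, %b : index
/// gpu.yield %dead, %a, %b : index, index, index
/// }
/// %0 = arith.addi %r#1, %r#2 : index
/// ```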
1196struct SinkUniformOps final : public gpu::WarpDistributionPattern {
1197 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1198 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1199 PatternRewriter &rewriter) const override {
1200 // Take the last op
1201 Operation *warpRegionPreYieldOp = warpOp.getTerminator()->getPrevNode();
1202 // Any ops with nested regions must be handled carefully in dedicated
1203 // patterns.
1204 if (!warpRegionPreYieldOp || warpRegionPreYieldOp->getNumRegions())
1205 return failure();
1206 int operandIdx = -1;
1207 if (warpRegionPreYieldOp->getNumResults()) {
1208 OpOperand *operand = getWarpResult(
1209 warpOp, [&](Operation *op) { return warpRegionPreYieldOp == op; });
1210 if (!operand)
1211 return failure();
1212 operandIdx = operand->getOperandNumber();
1213 if (warpRegionPreYieldOp->getResult(0).getType() !=
1214 warpOp.getResult(operandIdx).getType())
1215 return rewriter.notifyMatchFailure(warpOp,
1216 "The op result is not uniform.");
1217 }
1218
1219 // The op must have no layout-based operands or results.
1220 bool uniformValuesOnly =
1221 llvm::all_of(warpRegionPreYieldOp->getResults(), [](Value v) {
1222 return !xegpu::getDistributeLayoutAttr(v);
1223 });
1224 uniformValuesOnly &=
1225 llvm::all_of(warpRegionPreYieldOp->getOpOperands(), [](OpOperand &opr) {
1226 return !xegpu::getDistributeLayoutAttr(opr);
1227 });
1228 if (!uniformValuesOnly)
1229 return rewriter.notifyMatchFailure(warpOp,
1230 "Some values are not uniform.");
1231 SmallVector<size_t> newRetIndices;
1232 SmallVector<Value> operands =
1233 llvm::to_vector_of<Value>(warpRegionPreYieldOp->getOperands());
1234 SmallVector<Type> operandTypes =
1235 llvm::to_vector_of<Type>(warpRegionPreYieldOp->getOperandTypes());
1236 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1237 rewriter, warpOp, operands, operandTypes, newRetIndices);
1238
1239 rewriter.setInsertionPointAfter(newWarpOp);
1240 IRMapping operandMapper;
1241 for (auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
1242 operandMapper.map(warpRegionPreYieldOp->getOperand(oldOperandIdx),
1243 newWarpOp->getResult(newOperandIdx));
1244 Operation *clonedOp = rewriter.clone(*warpRegionPreYieldOp, operandMapper);
1245 if (!clonedOp->getNumResults())
1246 rewriter.eraseOp(warpRegionPreYieldOp);
1247 else {
1248 assert(operandIdx != -1 && "Expected a warp result for the operation");
1249 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx),
1250 clonedOp->getResult(0));
1251 }
1252 return success();
1253 }
1254};
1255
1256/// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
1257/// VectorReductionOps. We also insert layouts for the newly created ops.
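/// For example, reducing a vector<16x32xf32> source along dimension 0 with a
/// vector<32xf32> accumulator yields 32 slices: each slice is extracted with
/// vector.extract_strided_slice, shape-cast to a 1D vector, reduced with
/// vector.reduction together with the corresponding accumulator element, and
/// the scalar result is inserted back into the vector<32xf32> result.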
1258static Value lowerToVectorReductions(TypedValue<VectorType> src,
1259 TypedValue<VectorType> acc,
1260 vector::CombiningKind kind,
1261 int64_t reductionDim, Location loc,
1262 PatternRewriter &rewriter) {
1263 // Expecting a 2D source vector.
1264 assert(src.getType().getRank() == 2 && "expected a 2D source vector");
1265 VectorType sourceType = src.getType();
1266 int64_t sourceH = sourceType.getShape()[0];
1267 int64_t sourceW = sourceType.getShape()[1];
1268 int nSlices = (reductionDim == 0) ? sourceW : sourceH;
1269 // Create a constant vector to hold the result of the reduction.
1270 TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
1271 Value reductionResult = arith::ConstantOp::create(
1272 rewriter, loc, acc.getType(),
1273 DenseElementsAttr::get(acc.getType(), zeroAttr));
1274 // Reduction result should have the same layout as the accumulator.
1275 xegpu::setTemporaryLayout(cast<OpResult>(reductionResult),
1276 xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc)));
1277 // For each slice of the source, extract the slice vector, do a reduction,
1278 // and insert the reduced value back into the result vector.
1279 for (int i = 0; i < nSlices; ++i) {
1280 SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
1281 if (reductionDim == 1) {
1282 sliceOffsets = {i, 0};
1283 sliceSizes = {1, sourceW};
1284 } else {
1285 sliceOffsets = {0, i};
1286 sliceSizes = {sourceH, 1};
1287 }
1288 vector::ExtractStridedSliceOp extractOp =
1289 vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
1290 sliceSizes, {1, 1});
1291
1292 int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
1293
1294 vector::ShapeCastOp slice = vector::ShapeCastOp::create(
1295 rewriter, loc,
1296 VectorType::get({nSliceElements}, sourceType.getElementType()),
1297 extractOp.getResult());
1298
1299 // Shape cast is currently handled on the XeGPU side, so layouts must be
1300 // retained during lowering. The shape cast output has the same layout as
1301 // the accumulator, and the shape cast source has the same layout as the
1302 // original reduction source.
1303 // TODO: other ops generated here may also need layout attributes.
1304 auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
1305 auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
1306
1307 xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
1308 xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
1309 // The extract and reduction ops produce scalars, so no result layout is needed.
1310 Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
1311 Value reduction = vector::ReductionOp::create(
1312 rewriter, loc, kind, slice.getResult(), accExtract);
1313 reductionResult =
1314 vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
1315 }
1316 return reductionResult;
1317}
1318
1319/// This pattern distributes the `vector.multi_reduction` operation across
1320/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
1321/// layouts for the source and accumulator vectors,
1322/// * If the reduction dimension is distributed across lanes, the reduction is
1323/// non-lane-local and the reduction is done using warp shuffles. Here we
1324/// simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in
1325/// the warp op body.
1326/// * If the reduction dimension is not distributed across lanes, the reduction
1327/// is lane-local. In this case, we yield the source and accumulator vectors
1328/// from the warp op and perform the lane-local reduction outside the warp op
1329/// using a sequence of ReductionOps.
1330/// Example 1 (Reduction is lane-local):
1331/// ```
1332/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1333/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1334/// %acc = "some_def"() : () -> (vector<32xf32>)
1335/// %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
1336/// vector<32xf32>
/// gpu.yield %1 : vector<32xf32>
1337/// }
1338/// ```
1339/// is lowered to:
1340/// ```
1341/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
1342/// vector<1xf32>) {
1343/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1344/// %acc = "some_def"() : () -> (vector<32xf32>)
1345/// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
1346/// }
1347/// %c = arith.constant dense<0.0> : vector<1xf32>
1348/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
1349/// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
1350/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
1351/// ```
1352/// Example 2 (Reduction is non-lane-local):
1353/// ```
1354/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1355/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1356/// %acc = "some_def"() : () -> (vector<2xf32>)
1357/// %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
1358/// vector<2xf32>
1359/// gpu.yield %1 : vector<2xf32>
1360/// }
1361/// ```
1362/// is lowered to:
1363/// ```
1364/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1365/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1366/// %acc = "some_def"() : () -> (vector<2xf32>)
1367/// %1 = arith.constant dense<0.0> : vector<2xf32>
1368/// %2 = vector.extract %0[0] : vector<32xf32> from vector<2x32xf32>
1369/// %3 = ("warp.reduction %2") : f32
1370/// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
1371/// ... repeat for row 1
1372/// gpu.yield %1 : vector<2xf32>
1373/// }
/// ```
1374struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
1375 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1376 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1377 PatternRewriter &rewriter) const override {
1378 OpOperand *yieldOperand =
1379 getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
1380 if (!yieldOperand)
1381 return failure();
1382 auto reductionOp =
1383 cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
1384 unsigned operandIdx = yieldOperand->getOperandNumber();
1385 VectorType sourceType = reductionOp.getSourceVectorType();
1386 // Only 2D vectors are supported.
1387 if (sourceType.getRank() != 2)
1388 return rewriter.notifyMatchFailure(warpOp,
1389 "Only 2D reductions are supported.");
1390 ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
1391 // Only 1 reduction dimension supported. This also ensures that the result
1392 // is vector type.
1393 if (reductionDims.size() != 1)
1394 return rewriter.notifyMatchFailure(
1395 warpOp, "Only 1 reduction dimension is supported.");
1396 int64_t reductionDim = reductionDims[0];
1397 VectorType distributedResultType =
1398 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1399 VectorType resultType = cast<VectorType>(reductionOp.getType());
1400 xegpu::DistributeLayoutAttr sourceLayout =
1401 xegpu::getTemporaryLayout(reductionOp->getOpOperand(0));
1402
1403 FailureOr<VectorType> sourceDistTypeOrFailure =
1404 getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1405 if (failed(sourceDistTypeOrFailure))
1406 return rewriter.notifyMatchFailure(
1407 warpOp, "Failed to distribute the source vector type.");
1408 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1409 // Only single dimension distribution is supported.
1410 bool dim0Distributed =
1411 sourceDistType.getShape()[0] != sourceType.getShape()[0];
1412 bool dim1Distributed =
1413 sourceDistType.getShape()[1] != sourceType.getShape()[1];
1414 if (dim0Distributed && dim1Distributed)
1415 return rewriter.notifyMatchFailure(
1416 warpOp, "Expecting source to be distributed in a single dimension.");
1417 int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
1418 if (sourceDistDim == -1)
1419 return rewriter.notifyMatchFailure(
1420 warpOp, "Expecting a distributed source vector.");
1421 bool resultDistributed =
1422 distributedResultType.getNumElements() < resultType.getNumElements();
1423 // If the lane owns all the data required for reduction (i.e. the reduction
1424 // is fully parallel across lanes), then each lane owns part of the result
1425 // (i.e. the result is distributed). If the reduction requires cross-lane
1426 // shuffling, then the result is shared among all lanes (broadcasted).
1427 // Therefore we expect the following cases:
1428 //
1429 // | Source vector | Reduction dim | Result vector |
1430 // |----------------------|----------------|----------------|
1431 // | dim-0 distributed | 0 | broadcasted |
1432 // | dim-0 distributed | 1 | distributed |
1433 // | dim-1 distributed | 0 | distributed |
1434 // | dim-1 distributed | 1 | broadcasted |
1435
1436 bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
1437 (sourceDistDim == 1 && reductionDim == 0);
1438 if (isReductionLaneLocal && !resultDistributed)
1439 return rewriter.notifyMatchFailure(
1440 warpOp, "Expecting a distributed result for lane-local reduction.");
1441
1442 if (!isReductionLaneLocal && resultDistributed)
1443 return rewriter.notifyMatchFailure(
1444 warpOp,
1445 "Expecting a broadcasted result for non-lane-local reduction.");
1446
1447 // Handle lane-local reduction case. In this case we fully distribute the
1448 // reduction result.
1449 if (isReductionLaneLocal) {
1450 // Yield the source and acc vectors from the WarpOp.
1451 SmallVector<size_t> newRetIndices;
1452 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1453 rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
1454 {sourceDistType, distributedResultType}, newRetIndices);
1455 rewriter.setInsertionPointAfter(newWarpOp);
1456 Value result = lowerToVectorReductions(
1457 cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
1458 cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
1459 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1460 // Replace the warp op result with the final result.
1461 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
1462 return success();
1463 }
1464 // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
1465 // of multiple ReductionOps. Actual distribution is done by the
1466 // WarpOpReduction pattern.
1467 rewriter.setInsertionPointAfter(reductionOp);
1468 Value result = lowerToVectorReductions(
1469 cast<TypedValue<VectorType>>(reductionOp.getSource()),
1470 cast<TypedValue<VectorType>>(reductionOp.getAcc()),
1471 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1472 // Replace the warp op result with the final result.
1473 rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
1474 return success();
1475 }
1476};
1477
1478/// This pattern distributes the `vector.broadcast` operation across lanes in a
1479/// warp. The pattern supports three use cases:
1480///
1481/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
1482/// vector
1483/// must have a slice layout of the result. If the distributed source and
1484/// target vector types are identical, this lowers to a no-op; otherwise, it
1485/// remains a broadcast but operates on distributed vectors.
1486///
1487/// 2) Broadcast a same-rank vector with identical layouts for source and
1488/// target:
1489/// The source vector must have unit dimensions, and lane_data must be unit
1490/// size for those unit dims. This always lowers to a no-op.
1491///
1492/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from
1493/// scalar to distributed result type.
1494///
1495/// Example 1 (lowering to a broadcast with distributed types):
1496/// ```
1497/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1498/// %0 = "some_def"() {layout_result_0 =
1499/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1500/// dims = [0]> } : () -> (vector<32xf32>)
1501/// %2 = vector.broadcast %0 {layout_result_0 =
1502/// #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
1503/// : vector<32xf32> to vector<8x32xf32>
1504/// gpu.yield %2 : vector<8x32xf32>
1505/// }
1506/// ```
1507/// is lowered to:
1508/// ```
1509/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1510/// %0 = "some_def"() {layout_result_0 =
1511/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1512/// dims = [0]> } : () -> (vector<32xf32>)
1513/// gpu.yield %0 : vector<32xf32>
1514/// }
1515/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
1516/// ```
///
1517/// Example 2 (no-op):
1518/// ```
1519/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) {
1520/// %0 = "some_def"() {layout_result_0 =
1521/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1522/// dims = [1]> } : () -> (vector<8xf32>)
1523/// %1 = vector.shape_cast %0
1524/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1525/// 1]>}: vector<8xf32> to vector<8x1xf32>
1526/// %2 = vector.broadcast %1
1527/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1528/// 1]>}: vector<8x1xf32> to vector<8x32xf32>
1529/// gpu.yield %2 : vector<8x32xf32>
1530/// }
1531/// ```
1532/// is lowered to:
1533/// ```
1534/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1535/// %0 = "some_def"() {layout_result_0 =
1536/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1537/// dims = [1]> } : () -> (vector<8xf32>)
1538/// %1 = vector.shape_cast %0
1539/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1540/// 1]>}: vector<8xf32> to vector<8x1xf32>
1541/// gpu.yield %1 : vector<8x1xf32>
1542/// }
1543/// // The broadcast is implicit through layout transformation (no-op)
1544/// "some_use"(%r#0)
1545/// ```
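/// Example 3 (broadcast from a scalar; an illustrative sketch, the shapes and
/// the "some_def" placeholder are assumptions of this comment):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
///   %0 = "some_def"() : () -> (f32)
///   %1 = vector.broadcast %0 {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
///     : f32 to vector<8x32xf32>
///   gpu.yield %1 : vector<8x32xf32>
/// }
/// ```
/// is expected to lower to a broadcast from the scalar to the distributed
/// result type:
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (f32)
///   gpu.yield %0 : f32
/// }
/// %1 = vector.broadcast %r#0 : f32 to vector<8x1xf32>
/// ```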
1546struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
1547 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1548 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1549 PatternRewriter &rewriter) const override {
1550 OpOperand *yieldOperand =
1551 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1552 if (!yieldOperand)
1553 return failure();
1554 auto broadcastOp =
1555 cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
1556 unsigned operandIdx = yieldOperand->getOperandNumber();
1557
1558 VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1559 VectorType destType =
1560 dyn_cast<VectorType>(broadcastOp.getResult().getType());
1561
1562 xegpu::DistributeLayoutAttr sourceLayout =
1563 xegpu::getTemporaryLayout(broadcastOp->getOpOperand(0));
1564 xegpu::DistributeLayoutAttr resultLayout =
1565 xegpu::getTemporaryLayout(dyn_cast<OpResult>(broadcastOp.getResult()));
1566
1567 FailureOr<VectorType> sourceDistType;
1568 Type sourceElemOrDistType;
1569 if (sourceType) {
1570
1571 // Case 1 and 2: source is a vector type.
1572 int64_t rankDiff = destType.getRank() - sourceType.getRank();
1573 if (rankDiff > 0) {
1574 // Case 1: source is lower-rank than result.
1575 bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
1576 if (!isSliceOf)
1577 return rewriter.notifyMatchFailure(
1578 warpOp,
1579 "Broadcast input layout must be a slice of result layout.");
1580 }
1581 // Case 2: source and result have the same rank.
1582 if (rankDiff == 0) {
1583 auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
1584 SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
1585 broadcastUnitDimsSet.end());
1586 bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
1587 if (!isEqualTo)
1588 return rewriter.notifyMatchFailure(
1589 warpOp, "For same-rank broadcast, source must be identical to "
1590 "adjusted result layouts with unit dims.");
1591 resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
1592 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1593 }
1594
1595 sourceDistType =
1596 getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1597 if (failed(sourceDistType)) {
1598 return rewriter.notifyMatchFailure(
1599 warpOp, "Failed to distribute the source vector type.");
1600 }
1601 sourceElemOrDistType = sourceDistType.value();
1602
1603 } else {
1604 // Case 3: source is a scalar type.
1605 if (sourceLayout) {
1606 return rewriter.notifyMatchFailure(
1607 warpOp, "Broadcast from scalar must not have a layout attribute.");
1608 }
1609 sourceElemOrDistType = broadcastOp.getSourceType();
1610 }
1611 FailureOr<VectorType> destDistType =
1612 getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
1613 if (failed(destDistType)) {
1614 return rewriter.notifyMatchFailure(
1615 warpOp, "Failed to distribute the dest vector type.");
1616 }
1617
1618 SmallVector<size_t> newRetIndices;
1619 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1620 rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
1621 newRetIndices);
1622
1623 Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
1624
1625 Value newBroadcast = distributedSource;
1626
1627 if (sourceElemOrDistType != destDistType.value()) {
1628 rewriter.setInsertionPointAfter(newWarpOp);
1629 newBroadcast =
1630 vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
1631 destDistType.value(), distributedSource);
1632 }
1633
1634 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
1635 return success();
1636 }
1637};
1638
1639/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
1640/// `gpu.warp_execute_on_lane_0` region.
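/// The source of the shape_cast is yielded from the warp op with its
/// distributed type and the shape_cast is recreated outside on distributed
/// vectors. An illustrative sketch (shapes and layouts are assumptions of
/// this comment; temporary layout attributes are omitted for brevity):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
///   %0 = "some_def"() : () -> (vector<16xf32>)
///   %1 = vector.shape_cast %0 : vector<16xf32> to vector<16x1xf32>
///   gpu.yield %1 : vector<16x1xf32>
/// }
/// ```
/// is expected to become:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
///   %0 = "some_def"() : () -> (vector<16xf32>)
///   gpu.yield %0 : vector<16xf32>
/// }
/// %1 = vector.shape_cast %r : vector<1xf32> to vector<1x1xf32>
/// ```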
1641struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
1642 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1643 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1644 PatternRewriter &rewriter) const override {
1645 OpOperand *yieldOperand =
1646 getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1647 if (!yieldOperand)
1648 return failure();
1649 auto shapeCastOp =
1650 cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
1651 unsigned operandNumber = yieldOperand->getOperandNumber();
1652 auto resultDistTy =
1653 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1654 xegpu::DistributeLayoutAttr sourceLayout =
1655 xegpu::getTemporaryLayout(shapeCastOp->getOpOperand(0));
1656 xegpu::DistributeLayoutAttr resultLayout =
1657 xegpu::getTemporaryLayout(dyn_cast<OpResult>(shapeCastOp.getResult()));
1658 if (!sourceLayout || !resultLayout)
1659 return rewriter.notifyMatchFailure(
1660 warpOp,
1661 "the source or result of shape_cast op lacks distribution layout");
1662
1663 FailureOr<VectorType> sourceDistTypeOrFailure =
1664 getDistVecTypeBasedOnLaneLayout(sourceLayout,
1665 shapeCastOp.getSourceVectorType());
1666 if (failed(sourceDistTypeOrFailure))
1667 return rewriter.notifyMatchFailure(
1668 warpOp, "failed to get distributed vector type for source");
1669 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1670 // Create a new warp op that yields the source of the shape_cast op.
1671 SmallVector<size_t> newRetIndices;
1672 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1673 rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
1674 newRetIndices);
1675 rewriter.setInsertionPointAfter(newWarpOp);
1676 Value source = newWarpOp.getResult(newRetIndices[0]);
1677 // Create a new shape_cast op outside the warp op.
1678 Value newShapeCast = vector::ShapeCastOp::create(
1679 rewriter, shapeCastOp.getLoc(), resultDistTy, source);
1680 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1681 newShapeCast);
1682 return success();
1683 }
1684};
1685
1686// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
1687// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1688// advanced cases where the distributed dimension is partially extracted and
1689// currently not supported by the generic vector distribution patterns.
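// An illustrative sketch (shapes, offsets and the [16, 1] lane layout are
// assumptions of this comment, not taken from a test): with a subgroup size
// of 16, extracting a vector<16x16xf32> slice at offsets [16, 0] from a
// vector<32x16xf32> source inside the warp op is expected to become an
// extract of a vector<1x16xf32> slice at offsets [1, 0] from the distributed
// vector<2x16xf32> source, created outside the warp op:
// ```
//   %1 = vector.extract_strided_slice %r {offsets = [1, 0], sizes = [1, 16],
//     strides = [1, 1]} : vector<2x16xf32> to vector<1x16xf32>
// ```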
1690struct VectorExtractStridedSliceDistribution
1691 : public gpu::WarpDistributionPattern {
1692 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1693 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1694 PatternRewriter &rewriter) const override {
1695 OpOperand *operand =
1696 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1697 if (!operand)
1698 return failure();
1699 auto extractOp =
1700 cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
1701 unsigned operandIdx = operand->getOperandNumber();
1702 auto distributedType =
1703 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1704 // Find the distributed dimensions.
1705 auto extractResultType = cast<VectorType>(operand->get().getType());
1706 auto distributedDims =
1707 getDistributedDims(extractResultType, distributedType);
1708 // Collect updated source type, sizes and offsets. They may be adjusted
1709 // later if the data is distributed to lanes (as opposed to being owned by
1710 // all lanes uniformly).
1711 VectorType updatedSourceType = extractOp.getSourceVectorType();
1712 SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
1713 extractOp.getSizes(), [](Attribute attr) { return attr; });
1714 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1715 extractOp.getOffsets(), [](Attribute attr) { return attr; });
1716 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1717 extractOp.getStrides(), [](Attribute attr) { return attr; });
1718 // If the provided sizes, offsets, strides are less than the rank, pad them
1719 // with full sizes, zero offsets, and unit strides. This makes it easier to
1720 // adjust them later.
1721 int64_t sourceRank = extractOp.getSourceVectorType().getRank();
1722 for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
1723 updatedSizes.push_back(rewriter.getI64IntegerAttr(
1724 extractOp.getSourceVectorType().getDimSize(i)));
1725 updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
1726 updatedStrides.push_back(
1727 rewriter.getI64IntegerAttr(1)); // stride is always 1.
1728 }
1729 // If the result is distributed, it must be distributed in exactly one
1730 // dimension. In this case, we adjust the updated source type, sizes and
1731 // offsets accordingly.
1732 if (distributedDims.size() > 0) {
1733 if (distributedDims.size() != 1)
1734 return rewriter.notifyMatchFailure(
1735 warpOp, "Source can not be distributed in multiple dimensions.");
1736 int64_t distributedDim = distributedDims[0];
1737 int sourceDistrDimSize =
1738 extractOp.getSourceVectorType().getShape()[distributedDim];
1739 auto sourceLayout = xegpu::getTemporaryLayout(extractOp->getOpOperand(0));
1740 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1741 return rewriter.notifyMatchFailure(
1742 warpOp, "the source of extract_strided_slice op lacks distribution "
1743 "layout");
1744 auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
1745 // Because only single dimension distribution is supported, lane layout
1746 // size at the distributed dim must be the subgroup size.
1747 int subgroupSize = sourceLaneLayout[distributedDim];
1748 // Check if the source size in the distributed dimension is a multiple of
1749 // subgroup size.
1750 if (sourceDistrDimSize % subgroupSize != 0)
1751 return rewriter.notifyMatchFailure(
1752 warpOp,
1753 "Source size along distributed dimension is not a multiple of "
1754 "subgroup size.");
1755 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1756 // We expect lane data to be all ones in this case.
1757 if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1758 return rewriter.notifyMatchFailure(
1759 warpOp, "Expecting unit lane data in source layout");
1760 // The offsets in the distributed dimension must be a multiple of subgroup
1761 // size.
1762 int64_t distrDimOffset =
1763 cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
1764 if (distrDimOffset % subgroupSize != 0)
1765 return rewriter.notifyMatchFailure(
1766 warpOp, "Offset along distributed dimension "
1767 "is not a multiple of subgroup size.");
1768 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1769 sourceLayout, extractOp.getSourceVectorType())
1770 .value();
1771 // Update the distributed sizes to match the distributed type.
1772 updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1773 distributedType.getDimSize(distributedDim));
1774 // Update the distributed offsets to match round robin distribution (i.e.
1775 // each lane owns data at `subgroupSize` stride given unit lane data).
1776 updatedOffsets[distributedDim] =
1777 rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
1778 }
1779 // Do the distribution by yielding the source of the extract op from
1780 // the warp op and creating a new extract op outside the warp op.
1781 SmallVector<size_t> newRetIndices;
1782 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1783 rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
1784 newRetIndices);
1785 rewriter.setInsertionPointAfter(newWarpOp);
1786 Value source = newWarpOp.getResult(newRetIndices[0]);
1787 // Create a new extract op outside the warp op.
1788 Value newExtractOp = vector::ExtractStridedSliceOp::create(
1789 rewriter, extractOp.getLoc(), distributedType, source,
1790 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1791 ArrayAttr::get(rewriter.getContext(), updatedSizes),
1792 ArrayAttr::get(rewriter.getContext(), updatedStrides));
1793 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
1794 return success();
1795 }
1796};
1797
1798/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
1799/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1800/// advanced cases where the distributed dimension is partially inserted and
1801/// currently not supported by the generic vector distribution patterns.
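/// An illustrative sketch (shapes, offsets and the [16, 1] lane layout are
/// assumptions of this comment): with a subgroup size of 16, inserting a
/// vector<16x16xf32> slice at offsets [16, 0] into a vector<32x16xf32> dest
/// inside the warp op is expected to become an insert of the distributed
/// vector<1x16xf32> slice at offsets [1, 0] into the distributed
/// vector<2x16xf32> dest, created outside the warp op:
/// ```
///   %2 = vector.insert_strided_slice %r#0, %r#1 {offsets = [1, 0],
///     strides = [1, 1]} : vector<1x16xf32> into vector<2x16xf32>
/// ```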
1802struct VectorInsertStridedSliceDistribution
1803 : public gpu::WarpDistributionPattern {
1804 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1805 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1806 PatternRewriter &rewriter) const override {
1807 OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
1808 // Check if the InsertStridedSliceOp is the last op before yield op
1809 return llvm::IsaPred<vector::InsertStridedSliceOp>(op) &&
1810 warpOp.getTerminator()->getPrevNode() == op;
1811 });
1812 if (!operand)
1813 return failure();
1814 unsigned int operandNumber = operand->getOperandNumber();
1815 auto insertOp =
1816 operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
1817 auto distributedType =
1818 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1819 // Find the distributed dimensions of the dest vector.
1820 auto insertResultType = cast<VectorType>(operand->get().getType());
1821 auto destDistributedDims =
1822 getDistributedDims(insertResultType, distributedType);
1823 // Collect updated offsets, source type and dest type. They may be adjusted
1824 // later if the data is distributed to lanes (as opposed to being owned by
1825 // all lanes uniformly).
1826 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1827 insertOp.getOffsets(), [](Attribute attr) { return attr; });
1828 VectorType updatedSourceType = insertOp.getSourceVectorType();
1829 VectorType updatedDestType = insertOp.getDestVectorType();
1830 if (destDistributedDims.size() > 0) {
1831 // Only single dimension distribution is supported.
1832 if (destDistributedDims.size() != 1)
1833 return rewriter.notifyMatchFailure(
1834 warpOp,
1835 "Expecting source to be distributed in a single dimension.");
1836 int64_t destDistributedDim = destDistributedDims[0];
1837
1838 VectorType srcType = insertOp.getSourceVectorType();
1839 VectorType destType = insertOp.getDestVectorType();
1840 // Currently we require that both source (kD) and dest (nD) vectors are
1841 // distributed. This requires that distributedDim (d) is contained in the
1842 // last k dims of the dest vector (d >= n - k).
1843 int64_t sourceDistributedDim =
1844 destDistributedDim - (destType.getRank() - srcType.getRank());
1845 if (sourceDistributedDim < 0)
1846 return rewriter.notifyMatchFailure(
1847 insertOp,
1848 "distributed dimension must be in the last k (i.e. source "
1849 "rank) dims of dest vector");
1850 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
1851 // Obtain the source and dest layouts.
1852 auto destLayout = xegpu::getTemporaryLayout(insertOp->getOpOperand(1));
1853 auto sourceLayout = xegpu::getTemporaryLayout(insertOp->getOpOperand(0));
1854 if (!destLayout || !sourceLayout ||
1855 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1856 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1857 return rewriter.notifyMatchFailure(
1858 warpOp, "the source or dest of insert_strided_slice op lacks "
1859 "distribution layout");
1860 // Because only single dimension distribution is supported, lane layout
1861 // size at the distributed dim must be the subgroup size.
1862 int subgroupSize =
1863 destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
1864 // We require that source and dest lane data are all ones to ensure
1865 // uniform round robin distribution.
1866 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1867 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1868 if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
1869 !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1870 return rewriter.notifyMatchFailure(
1871 warpOp, "Expecting unit lane data in source and dest layouts");
1872 // Source distributed dim size must be multiples of subgroup size.
1873 if (srcDistrDimSize % subgroupSize != 0)
1874 return rewriter.notifyMatchFailure(
1875 warpOp, "Distributed dimension size in source is not a multiple of "
1876 "subgroup size.");
1877 // Offsets in the distributed dimension must be multiples of subgroup
1878 // size.
1879 int64_t destDistrDimOffset =
1880 cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
1881 if (destDistrDimOffset % subgroupSize != 0)
1882 return rewriter.notifyMatchFailure(
1883 warpOp,
1884 "Offset along distributed dimension in dest is not a multiple of "
1885 "subgroup size.");
1886 // Update the source and dest types based on their layouts.
1887 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1888 sourceLayout, insertOp.getSourceVectorType())
1889 .value();
1890 updatedDestType = getDistVecTypeBasedOnLaneLayout(
1891 destLayout, insertOp.getDestVectorType())
1892 .value();
1893 // Update the distributed offsets to match round robin distribution (i.e.
1894 // each lane owns data at `subgroupSize` stride given unit lane data).
1895 updatedOffsets[destDistributedDim] =
1896 rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1897 }
1898 // Do the distribution by yielding the source and dest of the insert op
1899 // from the warp op and creating a new insert op outside the warp op.
1900 SmallVector<size_t> newRetIndices;
1901 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1902 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1903 {updatedSourceType, updatedDestType}, newRetIndices);
1904 rewriter.setInsertionPointAfter(newWarpOp);
1905
1906 Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
1907 Value dest = newWarpOp.getResult(newRetIndices[1]);
1908 // Create a new insert op outside the warp op.
1909 Value newInsertOp = vector::InsertStridedSliceOp::create(
1910 rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
1911 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1912 insertOp.getStrides());
1913 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1914 newInsertOp);
1915 return success();
1916 }
1917};
1918
1919/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
1920/// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
1921/// outside of the warp op.
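/// An illustrative sketch (the memref type and the "some_def" placeholder are
/// assumptions of this comment):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
///   %0 = "some_def"() : () -> (memref<256xf32>)
///   %ptr = memref.extract_aligned_pointer_as_index %0
///     : memref<256xf32> -> index
///   gpu.yield %ptr : index
/// }
/// ```
/// is expected to become:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (memref<256xf32>) {
///   %0 = "some_def"() : () -> (memref<256xf32>)
///   gpu.yield %0 : memref<256xf32>
/// }
/// %ptr = memref.extract_aligned_pointer_as_index %r
///   : memref<256xf32> -> index
/// ```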
1922struct MemrefExtractAlignedPointerAsIndexDistribution final
1923 : public gpu::WarpDistributionPattern {
1924 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1925 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1926 PatternRewriter &rewriter) const override {
1927 OpOperand *operand = getWarpResult(
1928 warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
1929 if (!operand)
1930 return rewriter.notifyMatchFailure(
1931 warpOp,
1932 "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
1933 auto extractOp =
1934 operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
1935 unsigned operandIdx = operand->getOperandNumber();
1936 SmallVector<size_t> newRetIndices;
1937 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1938 rewriter, warpOp, extractOp.getSource(),
1939 TypeRange{extractOp.getSource().getType()}, newRetIndices);
1940 rewriter.setInsertionPointAfter(newWarpOp);
1941 auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
1942 rewriter, newWarpOp.getLoc(), extractOp.getType(),
1943 newWarpOp.getResult(newRetIndices[0]));
1944 Value resultVal = newWarpOp.getResult(operandIdx);
1945 rewriter.replaceAllUsesWith(resultVal, newExtractOp.getResult());
1946 return success();
1947 }
1948};
1949
1950/// Distribute a vector::BitCastOp feeding into the yield op of an enclosing
1951/// `gpu.warp_execute_on_lane_0` region. Bitcast only impacts the innermost
1952/// dimension of the source/result vectors. An equivalent vector::BitCastOp is
1953/// created outside of the warp op with the distributed source vector type
1954/// (computed using the assigned layout).
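/// An illustrative sketch (shapes and the [1, 16] lane layout are assumptions
/// of this comment): a bitcast of vector<4x32xf16> to vector<4x16xf32> inside
/// the warp op is expected to become a bitcast on the distributed types
/// outside of it:
/// ```
/// %1 = vector.bitcast %r : vector<4x2xf16> to vector<4x1xf32>
/// ```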
1955struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
1956 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1957 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1958 PatternRewriter &rewriter) const override {
1959 OpOperand *operand =
1960 getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
1961 if (!operand)
1962 return rewriter.notifyMatchFailure(
1963 warpOp, "warp result is not a vector::BitCast op");
1964 auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
1965 unsigned operandIdx = operand->getOperandNumber();
1966 VectorType distributedSourceType =
1967 getDistVecTypeBasedOnLaneLayout(
1968 xegpu::getTemporaryLayout(bitcastOp->getOpOperand(0)),
1969 bitcastOp.getSourceVectorType())
1970 .value_or(VectorType());
1971 if (!distributedSourceType)
1972 return rewriter.notifyMatchFailure(
1973 bitcastOp, "Failed to distribute the source vector type in "
1974 "vector::BitCast op");
1975 VectorType distributedResultType =
1976 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1977 SmallVector<size_t> newRetIndices;
1978 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1979 rewriter, warpOp, bitcastOp.getSource(),
1980 TypeRange{distributedSourceType}, newRetIndices);
1981 rewriter.setInsertionPointAfter(newWarpOp);
1982 auto newBitcastOp = vector::BitCastOp::create(
1983 rewriter, newWarpOp.getLoc(), distributedResultType,
1984 newWarpOp.getResult(newRetIndices[0]));
1985 Value distributedVal = newWarpOp.getResult(operandIdx);
1986 rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
1987 return success();
1988 }
1989};
1990
1991/// Distribute a vector::TransposeOp feeding into yield op of an enclosing
1992/// `gpu.warp_execute_on_lane_0` region. Currently only 2D transposes are
1993/// supported. In most cases, transpose is a no op because it is entirely
1994/// handled using the layouts (e.g. 16x1 -> 1x16). However, if each lane owns
1995/// multiple slices of data after distribution (e.g. 16x2 -> 2x16), a lane-local
1996/// transpose (i.e. shuffle) is needed. Therefore, we create an equivalent
1997/// vector::TransposeOp outside of the warp op with distributed source vector
1998/// type (computed using assigned layout).
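/// An illustrative sketch (shapes and layouts are assumptions of this
/// comment): transposing vector<16x32xf32> (lane_layout [1, 16]) to
/// vector<32x16xf32> (lane_layout [16, 1]) inside the warp op is expected to
/// become a lane-local transpose on the distributed types outside of it:
/// ```
/// %1 = vector.transpose %r, [1, 0] : vector<16x2xf32> to vector<2x16xf32>
/// ```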
1999struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
2000 using gpu::WarpDistributionPattern::WarpDistributionPattern;
2001 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
2002 PatternRewriter &rewriter) const override {
2003 OpOperand *operand =
2004 getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
2005 if (!operand)
2006 return rewriter.notifyMatchFailure(
2007 warpOp, "warp result is not a vector::Transpose op");
2008 auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
2009 unsigned operandIdx = operand->getOperandNumber();
2010 xegpu::DistributeLayoutAttr sourceLayout =
2011 xegpu::getTemporaryLayout(transposeOp->getOpOperand(0));
2012 xegpu::DistributeLayoutAttr resultLayout =
2013 xegpu::getTemporaryLayout(transposeOp->getOpResult(0));
2014 if (!sourceLayout || !resultLayout)
2015 return rewriter.notifyMatchFailure(
2016 transposeOp,
2017 "the source or result vector of the transpose op lacks layout "
2018 "attribute");
2019 int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
2020 int64_t resultRank = transposeOp.getResultVectorType().getRank();
2021 // Only 2D transposes are supported for now.
2022 // TODO: Support nD transposes.
2023 if (sourceRank != 2 || resultRank != 2)
2024 return rewriter.notifyMatchFailure(
2025 transposeOp, "the source or result vector of the transpose op "
2026 "does not have 2D layout");
2027 ArrayRef<int64_t> perm = transposeOp.getPermutation();
2028 // Result layout must be a transpose of source layout.
2029 if (!resultLayout.isTransposeOf(sourceLayout, perm))
2030 return rewriter.notifyMatchFailure(
2031 transposeOp,
2032 "the source or result vector layouts must be 2D transposes of each "
2033 "other");
2034 FailureOr<VectorType> distributedSourceTypeOrFailure =
2035 getDistVecTypeBasedOnLaneLayout(sourceLayout,
2036 transposeOp.getSourceVectorType());
2037 if (failed(distributedSourceTypeOrFailure))
2038 return rewriter.notifyMatchFailure(
2039 transposeOp, "Failed to distribute the source vector type in "
2040 "vector::Transpose op");
2041 SmallVector<size_t> newRetIndices;
2042 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2043 rewriter, warpOp, transposeOp.getVector(),
2044 TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
2045 rewriter.setInsertionPointAfter(newWarpOp);
2046 auto newTransposeOp = vector::TransposeOp::create(
2047 rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
2048 perm);
2049 Value distributedVal = newWarpOp.getResult(operandIdx);
2050 rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
2051 return success();
2052 }
2053};
2054
2055} // namespace
2056
2057namespace {
2058struct XeGPUSubgroupDistributePass final
2059 : public xegpu::impl::XeGPUSubgroupDistributeBase<
2060 XeGPUSubgroupDistributePass> {
2061 void runOnOperation() override;
2062};
2063} // namespace
2064
2065void xegpu::populateXeGPUSubgroupDistributePatterns(
2066 RewritePatternSet &patterns) {
2067 patterns.add<CreateNdDescDistribution, StoreNdDistribution,
2068 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
2069 GpuBarrierDistribution, VectorMultiReductionDistribution,
2070 LoadDistribution, StoreDistribution, VectorTransposeDistribution,
2071 VectorBitcastDistribution, LoadMatrixDistribution,
2072 StoreMatrixDistribution,
2073 MemrefExtractAlignedPointerAsIndexDistribution>(
2074 patterns.getContext(),
2075 /*pattern benefit=*/PatternHierarchy::Regular);
2076 // For the following patterns, we need to override the regular vector
2077 // distribution patterns. Therefore, assign a higher benefit.
2078 patterns
2079 .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
2080 VectorInsertStridedSliceDistribution, VectorBroadcastDistribution,
2081 SinkUniformOps>(patterns.getContext(),
2082 /*pattern benefit=*/PatternHierarchy::AboveRegular);
2083}
2084
2085void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
2086 RewritePatternSet &patterns) {
2087 patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
2088}
2089
2090void XeGPUSubgroupDistributePass::runOnOperation() {
2091 // Step 1: Attach layouts to op operands.
2092 // TODO: Following assumptions are made:
2093 // 1) It is assumed that there are no layout conflicts.
2094 // 2) Any existing layout attributes attached to the operands are ignored.
2095 Operation *op = getOperation();
2096 if (!xegpu::recoverTemporaryLayouts(op)) {
2097 signalPassFailure();
2098 return;
2099 }
2100
2101 // Step 2: Move all operations of a GPU function inside
2102 // gpu.warp_execute_on_lane_0 operation.
2103 {
2104 RewritePatternSet patterns(&getContext());
2105 xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
2106
2107 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
2108 signalPassFailure();
2109 return;
2110 }
2111 // At this point, we have moved the entire function body inside the
2112 // warpOp. Now move any scalar uniform code outside of the warpOp (like
2113 // GPU index ops, scalar constants, etc.). This will simplify the
2114 // later lowering and avoid custom patterns for these ops.
2115 getOperation()->walk([&](Operation *op) {
2116 if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
2117 vector::moveScalarUniformCode(warpOp);
2118 });
2119 }
2120 // Step 3: Apply subgroup to workitem distribution patterns.
2121 RewritePatternSet patterns(&getContext());
2122 xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
2123 // distributionFn is used by vector distribution patterns to determine the
2124 // distributed vector type for a given vector value. In XeGPU subgroup
2125 // distribution context, we compute this based on lane layout.
2126 auto distributionFn = [](Value val) {
2127 VectorType vecType = dyn_cast<VectorType>(val.getType());
2128 int64_t vecRank = vecType ? vecType.getRank() : 0;
2129 if (vecRank == 0)
2130 return AffineMap::get(val.getContext());
2131 // Get the layout of the vector type.
2132 xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
2133 // If no layout is specified, assume uniform case (no distribution).
2134 if (!layout)
2135 return AffineMap::get(val.getContext());
2136 // Expecting vector and layout rank to match.
2137 assert(layout.getRank() == vecRank &&
2138 "Expecting vector and layout rank to match");
2139 // A dimension is distributed only if layout suggests there are
2140 // multiple lanes assigned for this dimension and the shape can be evenly
2141 // distributed to those lanes.
2142 SmallVector<unsigned int> distributedDims;
2143 for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
2144 if (v > 1 && vecType.getShape()[i] % v == 0)
2145 distributedDims.push_back(i);
2146 }
2147 return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
2148 val.getContext());
2149 };
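 // As an illustrative example (not tied to a specific test): a
 // vector<16x16xf32> value carrying lane_layout = [1, 16] yields the map
 // (d0, d1) -> (d1), i.e. only dim 1 is distributed; a value with no layout
 // yields the empty map and is treated as uniform, and a dimension whose size
 // is not a multiple of its lane count is simply left undistributed.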
2150 // TODO: shuffleFn is not used.
2151 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
2152 int64_t warpSz) { return Value(); };
2153
2154 auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
2155 vector::CombiningKind kind, uint32_t size) {
2156 // First reduce on a single thread to get per lane reduction value.
2157 Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
2158 // Parallel reduction using butterfly shuffles.
2159 for (uint64_t i = 1; i < size; i <<= 1) {
2160 Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
2161 /*width=*/size,
2162 /*mode=*/gpu::ShuffleMode::XOR)
2163 .getShuffleResult();
2164 laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
2165 }
2166 return laneVal;
2167 };
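 // As an illustrative trace of the butterfly reduction above: for size = 32
 // the loop uses XOR masks 1, 2, 4, 8 and 16, i.e. log2(32) = 5 shuffle and
 // combine steps, after which every lane holds the subgroup-wide reduction.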
2168
2169 vector::populateDistributeReduction(
2170 patterns, warpReduction,
2171 /*pattern benefit=*/PatternHierarchy::Regular);
2172
2173 vector::populatePropagateWarpVectorDistributionPatterns(
2174 patterns, distributionFn, shuffleFn,
2175 /*pattern benefit=*/PatternHierarchy::Regular);
2176 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
2177 signalPassFailure();
2178 return;
2179 }
2180
2181 // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
2182 // due to tensor desc type mismatches created by using upstream distribution
2183 // patterns (scf.for). This cleanup should only be done if all the ops were
2184 // distributed successfully. If some ops are still not distributed and remain
2185 // inside any WarpExecuteOnLane0Op, we skip this simplification step to avoid
2186 // breaking the IR.
2187 bool foundWarpOp = false;
2188 getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
2189 // Look for WarpOps that are not trivially dead.
2190 if (isOpTriviallyDead(warpOp))
2191 return WalkResult::advance();
2192 foundWarpOp = true;
2193 return WalkResult::interrupt();
2194 });
2195 if (foundWarpOp)
2196 return;
2197
2198 getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
2199 // We are only interested in UnrealizedConversionCastOps that were added
2200 // to resolve SIMT type mismatches.
2201 if (!op->getAttr(resolveSIMTTypeMismatch))
2202 return WalkResult::skip();
2203
2204 Value input = op.getOperand(0);
2205 Value output = op.getResult(0);
2206
2207 // Both input and output must have tensor descriptor types.
2208 xegpu::TensorDescType inputDescType =
2209 mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
2210 xegpu::TensorDescType outputDescType =
2211 mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
2212 assert(inputDescType && outputDescType &&
2213 "Unrealized conversion cast must have tensor descriptor types");
2214
2215 // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
2216 // This occurs inside scf.for body to resolve the block argument type to
2217 // SIMT type.
2218 if (inputDescType.getLayout()) {
2219 auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
2220 if (argument) {
2221 argument.setType(output.getType());
2222 output.replaceAllUsesWith(argument);
2223 if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
2224 argument.getOwner()->getParentOp())) {
2225 auto result = loopOp.getTiedLoopResult(argument);
2226 result.setType(output.getType());
2227 }
2228 }
2229 }
2230
2231 // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
2232 // conversions. This occurs at the yield op of scf.for body to go back
2233 // from SIMT type to original type.
2234 if (outputDescType.getLayout())
2235 output.replaceAllUsesWith(input);
2236
2237 if (op->use_empty())
2238 op->erase();
2239 return WalkResult::advance();
2240 });
2241}