1//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
19#include "mlir/IR/AffineMap.h"
20#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/Operation.h"
27#include "mlir/IR/TypeRange.h"
28#include "mlir/IR/Value.h"
29#include "mlir/IR/Visitors.h"
31#include "mlir/Support/LLVM.h"
35#include "llvm/ADT/ArrayRef.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/SmallVector.h"
38
39namespace mlir {
40namespace xegpu {
41#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
42#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
43} // namespace xegpu
44} // namespace mlir
45
46#define DEBUG_TYPE "xegpu-subgroup-distribute"
47#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
48
49using namespace mlir;
50
51static const char *const resolveSIMTTypeMismatch =
52 "resolve_simt_type_mismatch"; // Attribute name for identifying
53 // UnrealizedConversionCastOp added to resolve
54 // SIMT type mismatches.
55
56namespace {
57
58//===----------------------------------------------------------------------===//
59// SIMT Distribution Patterns
60//===----------------------------------------------------------------------===//
61
62/// In certain cases, we may need to favor XeGPU specific distribution patterns
63/// over generic vector distribution patterns. In such cases, we can assign
64/// priorities to patterns.
65static constexpr unsigned regularPatternBenefit = 1;
66static constexpr unsigned highPatternBenefit = 2;
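// Illustrative sketch (not part of this pass's logic): when populating a
// RewritePatternSet, a pattern registered with `highPatternBenefit` is tried
// before one registered with `regularPatternBenefit`, e.g.
//   patterns.add<SomeXeGPUSpecificPattern>(ctx, highPatternBenefit);
//   patterns.add<SomeGenericVectorPattern>(ctx, regularPatternBenefit);
// (Pattern names above are placeholders.)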
67
68/// Helper function to get distributed vector type for a source vector type
69/// according to the lane_layout. We simply divide each dimension of the
70/// tensor descriptor shape by the corresponding lane_layout dimension. If
71/// array_length > 1, it is appended to the front of the distributed shape.
72/// NOTE: This is the vector type that will be returned by the
73/// gpu.warp_execute_on_lane0 op.
74///
75/// Examples:
76/// | original vector shape | lane_layout | distributed vector shape |
77/// |-----------------------|-------------|--------------------------|
78/// | 32x16 | [1, 16] | 32x1 |
79/// | 32x16 | [2, 8] | 16x2 |
80/// | 2x32x16 | [1, 16] | 2x32x1 |
81static FailureOr<VectorType>
82getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
83 VectorType originalType) {
84 if (!layout)
85 return failure();
86 assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
87 "Expecting a valid layout.");
88 SmallVector<int64_t> effectiveLaneLayout =
89 layout.getEffectiveLaneLayoutAsInt();
90 assert(static_cast<size_t>(originalType.getRank()) >=
91 effectiveLaneLayout.size() &&
92 "Rank of the original vector type should be greater than or equal to the "
93 "size of the lane layout to distribute the vector type.");
94 SmallVector<int64_t> distributedShape(originalType.getShape());
95 // Only distribute the last `laneLayout.size()` dimensions. The remaining
96 // dimensions are not distributed.
97 unsigned distributionStart =
98 originalType.getRank() - effectiveLaneLayout.size();
99 for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
100 if (i < distributionStart)
101 continue;
102 // Check if the dimension can be distributed evenly.
103 if (dim % effectiveLaneLayout[i - distributionStart] != 0)
104 return failure();
105 distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
106 }
107 return VectorType::get(distributedShape, originalType.getElementType());
108}
109
110/// Helper function to resolve types if the distributed type out of
111/// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
112/// Example 1:
113/// distributed type: vector<8x1xf32>
114/// expected type: vector<8xf32>
115/// resolved using,
116/// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
117/// Example 2:
118/// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
119/// expected type: xegpu.tensor_desc<8x16xf32>
120/// resolved using,
121/// %0 = unrealized_conversion_cast %1 :
122/// xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
123/// xegpu.tensor_desc<8x16xf32>
124template <typename T>
125static Value resolveDistributedTy(Value orig, T expected,
126 PatternRewriter &rewriter) {
127 // If orig and expected types are the same, return orig.
128 if (orig.getType() == expected)
129 return orig;
130 // If orig is a vector type, create a shape cast op to reconcile the types.
131 if (isa<VectorType>(orig.getType())) {
132 auto castOp =
133 vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
134 return castOp.getResult();
135 }
136 // If orig is a tensor descriptor type, create an unrealized conversion cast
137 // op to reconcile the types.
138 if (isa<xegpu::TensorDescType>(orig.getType())) {
139 auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
140 expected, orig);
141 castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
142 return castOp.getResult(0);
143 }
144 llvm_unreachable("Unsupported type for reconciliation");
145 return orig;
146}
147
148/// Helper function to check if the layout is packed. Layout is packed if it is
149/// 2D and lane_data[0] != 1 (data packed from col dimension).
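/// Example (illustrative): lane_data = [2, 1] means each lane owns two
/// consecutive elements along dim 0 (packed/VNNI form), so packing is
/// required; lane_data = [1, 1] is not packed.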
150/// TODO: Move to target info.
151static bool requirePacked(const xegpu::LayoutAttr layout) {
152 if (!layout)
153 return false;
154 auto laneData = layout.getEffectiveLaneDataAsInt();
155 if (laneData.size() != 2)
156 return false;
157 return laneData[0] != 1;
158}
159
160/// Helper function to check if the layout requires a transpose effect.
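/// Example (illustrative): with a subgroup size of 16, lane_layout = [16, 1]
/// distributes the row dimension across lanes, so the load needs a transpose;
/// lane_layout = [1, 16] does not.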
161static bool requireTranspose(const xegpu::LayoutAttr layout,
162 const xegpu::uArch::uArch *uArch) {
163 // Return false for unsupported targets.
164 // TODO: Add more support or move to target info.
165 if (!uArch->getName().equals_insensitive("pvc") &&
166 !uArch->getName().equals_insensitive("bmg"))
167 return false;
168 if (!layout)
169 return false;
170 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
171 if (laneLayout.size() != 2)
172 return false;
173 return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
174}
175
176/// Given a vector type and its distributed vector type, return the list of
177/// dimensions that are distributed.
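/// Example: original type vector<32x16xf32> and distributed type
/// vector<32x1xf32> yield distributed dims = [1].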
178static SmallVector<int64_t> getDistributedDims(VectorType originalType,
179 VectorType distributedType) {
180 assert(originalType.getRank() == distributedType.getRank() &&
181 "sequential and distributed vector types must have the same rank");
182 SmallVector<int64_t> distributedDims;
183 for (int64_t i = 0; i < originalType.getRank(); ++i) {
184 if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
185 distributedDims.push_back(i);
186 }
187 }
188 return distributedDims;
189}
190
191/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
192/// of the original GPUFuncOp to the new GPUFuncOp such that the entire body is
193/// contained within a WarpExecuteOnLane0Op.
194/// Example:
195///
196/// ```
197/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
198/// ...
199/// ...
200/// gpu.return %result: vector<8x16xf32>
201/// }
202/// ```
203/// To
204/// ```
205/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
206/// %laneid = gpu.lane_id : index
207/// %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
208/// ...
209/// ...
210/// gpu.yield %result: vector<8x16xf32>
211/// }
212/// return %0
213/// }
214struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
215 using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
216 LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
217 PatternRewriter &rewriter) const override {
218 auto uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or(""));
219 if (!uArch)
220 return rewriter.notifyMatchFailure(
221 gpuFuncOp, "Subgroup distribution requires a target attribute attached "
222 "to set the warp size");
223 // If the function only contains a single void return, skip.
224 if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
225 return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
226 }))
227 return failure();
228 // If the function body has already been moved inside a warp_execute_on_lane0, skip.
229 if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
230 return isa<gpu::WarpExecuteOnLane0Op>(op);
231 }))
232 return failure();
233 // Create a new function with the same signature and same attributes.
234 SmallVector<Type> workgroupAttributionsTypes =
235 llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
236 [](BlockArgument arg) { return arg.getType(); });
237 SmallVector<Type> privateAttributionsTypes =
238 llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
239 [](BlockArgument arg) { return arg.getType(); });
240 auto newGpuFunc = gpu::GPUFuncOp::create(
241 rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
242 gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
243 privateAttributionsTypes);
244 newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
245 // Create a WarpExecuteOnLane0Op with same arguments and results as the
246 // original gpuFuncOp.
247 rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
248 auto laneId = gpu::LaneIdOp::create(
249 rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(),
250 /** upperBound = **/ mlir::IntegerAttr());
251 ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
252 auto warpOp = gpu::WarpExecuteOnLane0Op::create(
253 rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
254 uArch->getSubgroupSize(), newGpuFunc.getArguments(),
255 newGpuFunc.getArgumentTypes());
256 Block &warpBodyBlock = warpOp.getBodyRegion().front();
257 // Replace the ReturnOp of the original gpu function with a YieldOp.
258 auto origReturnOp =
259 cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
260 rewriter.setInsertionPointAfter(origReturnOp);
261 gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
262 origReturnOp.getOperands());
263 rewriter.eraseOp(origReturnOp);
264 // Move the original function body to the WarpExecuteOnLane0Op body.
265 rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
266 warpOp.getBodyRegion().begin());
267 rewriter.eraseBlock(&warpBodyBlock);
268 // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
269 rewriter.setInsertionPointAfter(warpOp);
270 gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
271 rewriter.replaceOp(gpuFuncOp, newGpuFunc);
272 return success();
273 }
274};
275
276/// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing
277/// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
278/// still contain the original op that will not be used by the yield op (and
279/// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
280/// arguments. Tensor descriptor shape is not distributed because it is a
281/// uniform value across all work items within the subgroup. However, the
282/// layout information is dropped in the new tensor descriptor type.
283///
284/// Example:
285///
286/// ```
287/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
288/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
289/// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
290/// ...
291/// %td = xegpu.create_nd_tdesc %arg0
292/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
293/// vector.yield %td
294/// }
295/// ```
296/// To
297/// ```
298/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
299/// ...
300/// %dead = xegpu.create_nd_tdesc %arg0
301/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
302/// vector.yield %arg0, %dead
303/// }
304/// %td = xegpu.create_nd_tdesc %r#0: memref<4x8xf32>
305/// -> !xegpu.tensor_desc<4x8xf32>
306///
307/// ```
308struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
309 using gpu::WarpDistributionPattern::WarpDistributionPattern;
310 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
311 PatternRewriter &rewriter) const override {
312 OpOperand *operand =
313 getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
314 if (!operand)
315 return rewriter.notifyMatchFailure(
316 warpOp, "warp result is not a xegpu::CreateNdDesc op");
317 auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
318 unsigned operandIdx = operand->getOperandNumber();
319
320 xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
321 if (!layout)
322 return rewriter.notifyMatchFailure(
323 descOp, "the tensor descriptor lacks layout attribute");
324 // CreateNdOp must not have offsets.
325 if (descOp.getMixedOffsets().size())
326 return rewriter.notifyMatchFailure(
327 descOp, "xegpu::CreateNdDescOp must not have offsets");
328
329 SmallVector<size_t> newRetIndices;
330 rewriter.setInsertionPoint(warpOp);
331 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
332 rewriter, warpOp, /* new yielded values = */ descOp->getOperands(),
333 /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);
334
335 SmallVector<Value> newDescOperands = llvm::map_to_vector(
336 newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
337 rewriter.setInsertionPointAfter(newWarpOp);
338 xegpu::TensorDescType distributedTensorDescTy =
339 descOp.getType().dropLayouts(); // Distributed tensor descriptor type
340 // does not contain layout info.
341 Value newDescOp = xegpu::CreateNdDescOp::create(
342 rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
343 descOp->getAttrs());
344
345 Value distributedVal = newWarpOp.getResult(operandIdx);
346 // Resolve the distributed type to the expected type.
347 newDescOp =
348 resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
349 rewriter.replaceAllUsesWith(distributedVal, newDescOp);
350 return success();
351 }
352};
353
354/// Distribute a store_nd op at the end of enclosing
355/// `gpu.warp_execute_on_lane_0`. In case arguments for the store are passed
356/// through the warp op interface, they are propagated as returned values.
357/// The source vector is distributed based on the lane layout. Appropriate
358/// cast ops are inserted if the distributed types do not match the expected xegpu SIMT types.
359///
360/// Example:
361///
362/// ```
363/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
364/// gpu.warp_execute_on_lane_0(%laneid) -> () {
365/// ...
366/// xegpu.store_nd %arg0, %arg1 [%x, %y]: vector<4x8xf32>,
367/// !xegpu.tensor_desc<4x8xf32, #layout0>
368/// }
369/// ```
370/// To
371/// ```
372/// %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
373/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
374/// ...
375/// gpu.yield %arg0, %arg1, %x, %y: vector<4x8xf32>,
376/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index
377/// }
378/// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
379/// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
380/// #layout0>
381/// -> !xegpu.tensor_desc<4x8xf32>
382/// xegpu.store_nd %0, %1 [%r#2, %r#3]: vector<4xf32>,
383/// !xegpu.tensor_desc<4x8xf32>
384///
385/// ```
386struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
387 using gpu::WarpDistributionPattern::WarpDistributionPattern;
388 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
389 PatternRewriter &rewriter) const override {
390 gpu::YieldOp yield = warpOp.getTerminator();
391 Operation *lastNode = yield->getPrevNode();
392 auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
393 if (!storeOp)
394 return failure();
395
396 SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
397 // Expecting offsets to be present.
398 if (offsets.empty())
399 return rewriter.notifyMatchFailure(storeOp,
400 "the store op must have offsets");
401 SmallVector<Value> offsetsAsValues =
402 vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
403 SmallVector<Type> offsetTypes = llvm::to_vector(
404 llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
405 xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
406 xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
407 if (!layout)
408 return rewriter.notifyMatchFailure(
409 storeOp, "the source tensor descriptor lacks layout attribute");
410
411 FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
412 getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
413 if (failed(distributedTypeByWarpOpOrFailure))
414 return rewriter.notifyMatchFailure(storeOp,
415 "Failed to distribute the type");
416 VectorType distributedTypeByWarpOp =
417 distributedTypeByWarpOpOrFailure.value();
418
419 SmallVector<size_t> newRetIndices;
420 SmallVector<Value> newYieldedValues = {storeOp.getValue(),
421 storeOp.getTensorDesc()};
422 SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
423 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
424 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
425 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
426 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
427 // Create a new store op outside the warp op with the distributed vector
428 // type. Tensor descriptor is not distributed.
429 rewriter.setInsertionPointAfter(newWarpOp);
430 SmallVector<Value> newStoreOperands;
431
432 // For the value operand, there can be a mismatch between the vector type
433 // distributed by the warp op and (xegpu-specific) distributed type
434 // supported by the store op. Type mismatch must be resolved using
435 // appropriate cast op.
436 FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
437 xegpu::getDistributedVectorType(storeOp.getTensorDescType());
438 if (failed(storeNdDistributedValueTyOrFailure))
439 return rewriter.notifyMatchFailure(
440 storeOp, "Failed to get distributed vector type for the store op");
441 newStoreOperands.push_back(resolveDistributedTy(
442 newWarpOp.getResult(newRetIndices[0]),
443 storeNdDistributedValueTyOrFailure.value(), rewriter));
444 // For the tensor descriptor operand, the layout attribute is dropped after
445 // distribution. Types need to be resolved in this case as well.
446 xegpu::TensorDescType distributedTensorDescTy =
447 storeOp.getTensorDescType().dropLayouts();
448 newStoreOperands.push_back(
449 resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
450 distributedTensorDescTy, rewriter));
451 // Collect offsets.
452 for (size_t i = 2; i < newRetIndices.size(); ++i)
453 newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
454
455 auto newStoreOp =
456 xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
457 newStoreOperands, storeOp->getAttrs());
458 xegpu::removeLayoutAttrs(newStoreOp);
459 rewriter.eraseOp(storeOp);
460 return success();
461 }
462};
463
464/// Distribute a load_nd op feeding into vector.yield op for the enclosing
465/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
466/// The warp op will still contain the original op that will not be used by
467/// the yield op (and should be cleaned up later). The yield op will
468/// bypass the load's arguments. Only the loaded vector is distributed
469/// according to the lane layout; the tensor descriptor type is not
470/// distributed. Appropriate cast ops are inserted if the distributed types do
471/// not match the expected xegpu SIMT types.
472///
473/// Example:
474///
475/// ```
476/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
477/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
478/// (vector<4x1xf32>) {
479/// ...
480/// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
481/// ->
482/// vector<4x8xf32>
483/// gpu.yield %ld
484/// }
485/// ```
486/// To
487/// ```
488/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
489/// !xegpu.tensor_desc<4x8xf32, #layout0>) {
490/// ...
491/// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> -> vector<4x8xf32>
492/// gpu.yield %dead, %arg0
493/// }
494/// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
495/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
496/// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
497/// %2 = vector.shape_cast %1: vector<4xf32> to vector<4x1xf32>
498///
499/// ```
500struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
501 using gpu::WarpDistributionPattern::WarpDistributionPattern;
502 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
503 PatternRewriter &rewriter) const override {
504 OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
505 if (!isa<xegpu::LoadNdOp>(op))
506 return false;
507 // Make sure the same load op is the last operation in the warp op body.
508 // This ensures that the load op is not sunk earlier, violating any barrier
509 // synchronizations.
510 gpu::YieldOp yield = warpOp.getTerminator();
511 return yield->getPrevNode() == op;
512 });
513
514 if (!operand)
515 return rewriter.notifyMatchFailure(
516 warpOp, "warp result is not a xegpu::LoadNd op");
517
518 auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
519 auto uArch = getUArch(xegpu::getChipStr(loadOp).value_or(""));
520 if (!uArch)
521 return rewriter.notifyMatchFailure(
522 loadOp, "xegpu::LoadNdOp requires a target attribute attached to "
523 "determine the transpose "
524 "requirement");
525 // Chip information is required to decide if the layout requires transpose
526 // effect.
527 // Expecting offsets to be present.
528 SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
529 if (offsets.empty())
530 return rewriter.notifyMatchFailure(loadOp,
531 "the load op must have offsets");
532 SmallVector<Value> offsetsAsValues =
533 vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
534 SmallVector<Type> offsetTypes = llvm::to_vector(
535 llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
536
537 xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
538 xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
539 if (!layout)
540 return rewriter.notifyMatchFailure(
541 loadOp, "the source tensor descriptor lacks layout attribute");
542
543 unsigned operandIdx = operand->getOperandNumber();
544 VectorType distributedTypeByWarpOp =
545 cast<VectorType>(warpOp.getResult(operandIdx).getType());
546
547 SmallVector<size_t> newRetIndices;
548 SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
549 SmallVector<Type> newYieldedTypes = {tensorDescTy};
550 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
551 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
552 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
553 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
554
555 // Create a new load op outside the warp op with the distributed vector
556 // type.
557 rewriter.setInsertionPointAfter(newWarpOp);
558 FailureOr<VectorType> loadNdDistValueTyOrFailure =
559 xegpu::getDistributedVectorType(loadOp.getTensorDescType());
560 if (failed(loadNdDistValueTyOrFailure))
561 return rewriter.notifyMatchFailure(
562 loadOp, "Failed to get distributed vector type for the load op");
563 xegpu::TensorDescType distributedTensorDescTy =
564 loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
565 // descriptor type does not
566 // contain layout info.
567 SmallVector<Value> newLoadOperands{
568 resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
569 distributedTensorDescTy, rewriter)};
570 // Collect offsets.
571 for (size_t i = 1; i < newRetIndices.size(); ++i)
572 newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
573 auto newLoadOp = xegpu::LoadNdOp::create(
574 rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
575 newLoadOperands, loadOp->getAttrs());
576 xegpu::removeLayoutAttrs(newLoadOp);
577 // Set the packed attribute if the layout requires it.
578 newLoadOp.setPacked(requirePacked(layout));
579 // Set the transpose attribute if the layout requires it.
580 if (requireTranspose(layout, uArch))
581 newLoadOp.setTranspose(
582 DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
583 Value distributedVal = newWarpOp.getResult(operandIdx);
584 // There can be a conflict between the vector type distributed by the
585 // warp op and (xegpu-specific) distributed type supported by the load
586 // op. Resolve these mismatches by inserting a cast.
587 Value tyResolvedVal = resolveDistributedTy(
588 newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
589 rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
590 return success();
591 }
592};
593
594/// Distribute a dpas op feeding into vector.yield op for the enclosing
595/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
596/// The warp op will still contain the original op that will not be used by
597/// the yield op (and should be cleaned up later). The yield op will
598/// bypass the dpas's arguments. Appropriate cast ops are inserted if the
599/// distributed types do not match the expected xegpu SIMT types.
600/// Example:
601/// ```
602/// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
603/// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
604/// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
605/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
606/// (vector<8x1xf32>) {
607/// ...
608/// %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
609/// vector<8x16xf32>
610/// gpu.yield %dpas
611/// }
612/// ```
613/// To
614/// ```
615/// %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
616/// vector<8x1xf16>, vector<16x1xf16>) {
617/// ...
618/// %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
619/// -> vector<8x16xf32>
620/// gpu.yield %dead, %arg0, %arg1
621/// }
622/// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
623/// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
624/// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
625/// vector<8xf32>
626/// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
627/// ```
628struct DpasDistribution final : public gpu::WarpDistributionPattern {
629 using gpu::WarpDistributionPattern::WarpDistributionPattern;
630 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
631 PatternRewriter &rewriter) const override {
632 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
633 if (!operand)
634 return rewriter.notifyMatchFailure(warpOp,
635 "warp result is not a xegpu::Dpas op");
636
637 auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
638 unsigned operandIdx = operand->getOperandNumber();
639
640 xegpu::LayoutAttr layoutA =
641 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutAAttr());
642 xegpu::LayoutAttr layoutB =
643 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutBAttr());
644 xegpu::LayoutAttr layoutOut =
645 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutCdAttr());
646
647 if (!layoutA || !layoutB || !layoutOut)
648 return rewriter.notifyMatchFailure(
649 dpasOp,
650 "the xegpu::Dpas op lacks layout attribute for A, B or output");
651
652 FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
653 getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
654 FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
655 getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
656 FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
657 getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
658
659 if (failed(distLhsTypeByWarpOpOrFailure) ||
660 failed(distRhsTypeByWarpOpOrFailure) ||
661 failed(distResultTypeByWarpOpOrFailure))
662 return rewriter.notifyMatchFailure(
663 dpasOp,
664 "Failed to distribute the A, B or output types in xegpu::Dpas op");
665
666 llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
667 dpasOp.getRhs()};
668 llvm::SmallVector<Type, 3> newYieldTypes{
669 distLhsTypeByWarpOpOrFailure.value(),
670 distRhsTypeByWarpOpOrFailure.value()};
671 // Dpas acc operand is optional.
672 if (dpasOp.getAcc()) {
673 newYieldValues.push_back(dpasOp.getAcc());
674 newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
675 }
676 // Create a new warp op without the dpas.
677 SmallVector<size_t> newRetIndices;
678 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
679 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
680
681 FailureOr<VectorType> expectedDistLhsTyOrFailure =
682 xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
683 FailureOr<VectorType> expectedDistRhsTyOrFailure =
684 xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
685 FailureOr<VectorType> expectedDistResultTyOrFailure =
686 xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
687
688 if (failed(expectedDistLhsTyOrFailure) ||
689 failed(expectedDistRhsTyOrFailure) ||
690 failed(expectedDistResultTyOrFailure))
691 return rewriter.notifyMatchFailure(
692 dpasOp,
693 "Failed to get distributed vector type for the dpas operands.");
694 // Create a new dpas op outside the warp op.
695 rewriter.setInsertionPointAfter(newWarpOp);
696 SmallVector<Value> newDpasOperands;
697 SmallVector<VectorType> newDpasOperandExpectedTypes;
698
699 // Resolve the distributed types with the original types.
700 newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
701 newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
702 VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
703 if (dpasOp.getAcc())
704 newDpasOperandExpectedTypes.push_back(distributedResultTy);
705
706 for (unsigned i = 0; i < newRetIndices.size(); i++) {
707 newDpasOperands.push_back(
708 resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
709 newDpasOperandExpectedTypes[i], rewriter));
710 }
711 auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
712 distributedResultTy, newDpasOperands,
713 dpasOp->getAttrs());
714 xegpu::removeLayoutAttrs(newDpasOp);
715 Value distributedVal = newWarpOp.getResult(operandIdx);
716 // Resolve the output type.
717 Value typeResolved =
718 resolveDistributedTy(newDpasOp.getResult(),
719 distResultTypeByWarpOpOrFailure.value(), rewriter);
720 rewriter.replaceAllUsesWith(distributedVal, typeResolved);
721 return success();
722 }
723};
724
725/// Distribute a prefetch_nd op at the end of enclosing
726/// `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed
727/// through the warp op interface, they are propagated as returned values.
728/// Tensor descriptor shape is not distributed because it is a uniform value
729/// across all work items within the subgroup. Appropriate cast ops are inserted
730/// if the distributed types do not match the expected xegpu SIMT types.
731///
732/// Example:
733///
734/// ```
735/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
736/// gpu.warp_execute_on_lane_0(%laneid) -> () {
737/// ...
738/// xegpu.prefetch_nd %arg0 [%x, %y] : !xegpu.tensor_desc<4x8xf32, #layout0>
739/// }
740/// ```
741/// To
742/// ```
743/// %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (
744/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
745/// gpu.yield %arg0, %x, %y: !xegpu.tensor_desc<4x8xf32, #layout0>, index,
746/// index
747/// }
748/// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
749/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
750/// xegpu.prefetch_nd %1 [%r#1, %r#2] : !xegpu.tensor_desc<4x8xf32>
751///
752/// ```
753struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
754 using gpu::WarpDistributionPattern::WarpDistributionPattern;
755 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
756 PatternRewriter &rewriter) const override {
757 gpu::YieldOp yield = warpOp.getTerminator();
758 Operation *lastNode = yield->getPrevNode();
759 auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
760 if (!prefetchOp)
761 return failure();
762
763 SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
764 // PrefetchNdOp must have offsets.
765 if (offsets.empty())
766 return rewriter.notifyMatchFailure(prefetchOp,
767 "the prefetch op must have offsets");
768 SmallVector<Value> offsetsAsValues =
769 vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
770 SmallVector<Type> offsetTypes = llvm::to_vector(
771 llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
772
773 xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
774 if (!layout)
775 return rewriter.notifyMatchFailure(
776 prefetchOp, "the source tensor descriptor lacks layout attribute");
777
778 SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
779 SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
780 newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
781 newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
782 SmallVector<size_t> newRetIndices;
783 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
784 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
785 // Create a new prefetch op outside the warp op with updated tensor
786 // descriptor type. Source tensor descriptor requires type resolution.
787 xegpu::TensorDescType newTensorDescTy =
788 prefetchOp.getTensorDescType().dropLayouts();
789 rewriter.setInsertionPointAfter(newWarpOp);
790 SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
791 newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
792 // Collect offsets.
793 for (size_t i = 1; i < newRetIndices.size(); ++i)
794 newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
795 xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
796 newPrefetchOperands, prefetchOp->getAttrs());
797 xegpu::removeLayoutAttrs(prefetchOp);
798 rewriter.eraseOp(prefetchOp);
799 return success();
800 }
801};
802
803/// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
804/// region. This will simply move the barrier op outside of the warp op.
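///
/// Example:
/// ```
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
///   ...
///   gpu.barrier
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
///   ...
/// }
/// gpu.barrier
/// ```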
805struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
806 using gpu::WarpDistributionPattern::WarpDistributionPattern;
807 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
808 PatternRewriter &rewriter) const override {
809 gpu::YieldOp yield = warpOp.getTerminator();
810 Operation *lastNode = yield->getPrevNode();
811 // The last node must be a gpu::BarrierOp.
812 auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
813 if (!barrierOp)
814 return failure();
815 // Move the barrier op outside of the warp op.
816 rewriter.setInsertionPointAfter(warpOp);
817 gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
818 barrierOp->getResultTypes(),
819 barrierOp->getOperands(), barrierOp->getAttrs());
820 rewriter.eraseOp(barrierOp);
821 return success();
822 }
823};
824
825/// Distribute a scattered store op. The offsets argument is required.
826/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
827/// The layouts are fixed and implicit: one offset/mask per lane.
828/// The pass changes the offset/mask vector shapes to a
829/// single-element vector; **it is assumed that their producers will also be
830/// distributed**. The payload vector also has a fixed distribution:
831/// no chunk size -> vector of one element.
832/// chunk size -> vector of the innermost dimension of the SG-payload.
833/// Example 1 (no chunk size):
834/// %mask = producer_op : vector<16xi1>
835/// %offset = producer_op : vector<16xindex>
836/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
837/// memref<256xf16>, vector<16xindex>, vector<16xi1>
838/// To
839/// %mask = producer_op : vector<1xi1>
840/// %offset = producer_op : vector<1xindex>
841/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
842/// memref<256xf16>, vector<1xindex>, vector<1xi1>
843/// Example 2 (chunk size, same mask and offsets):
844/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
845/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
846/// To
847/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
848/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
849struct StoreDistribution final : public gpu::WarpDistributionPattern {
850 using gpu::WarpDistributionPattern::WarpDistributionPattern;
851 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
852 PatternRewriter &rewriter) const override {
853 Operation *lastNode = warpOp.getTerminator()->getPrevNode();
854 auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
855 if (!storeScatterOp)
856 return failure();
857 auto offsets = storeScatterOp.getOffsets();
858 if (!offsets || !isa<VectorType>(offsets.getType()))
859 return rewriter.notifyMatchFailure(
860 storeScatterOp, "Store op must have a vector of offsets argument");
861 VectorType offsetsTy = cast<VectorType>(offsets.getType());
862 VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
863 if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
864 return rewriter.notifyMatchFailure(storeScatterOp,
865 "Expected 1D offsets and mask vector");
866 VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
867 if (storeVecTy.getRank() > 2)
868 return rewriter.notifyMatchFailure(
869 storeScatterOp, "Expected at most 2D result at SG level");
870
871 std::string layoutPayloadName =
872 xegpu::getTemporaryLayoutName(storeScatterOp->getOpOperand(0));
873 std::string layoutOffsetsName =
874 xegpu::getTemporaryLayoutName(storeScatterOp->getOpOperand(2));
875 std::string layoutMaskName =
876 xegpu::getTemporaryLayoutName(storeScatterOp->getOpOperand(3));
877
878 xegpu::LayoutAttr layoutPayload =
879 storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
880 xegpu::LayoutAttr layoutOffsets =
881 storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
882 xegpu::LayoutAttr layoutMask =
883 storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
884
885 FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
886 getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
887 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
888 getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
889 FailureOr<VectorType> distMaskByWarpOpOrFailure =
890 getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
891 if (failed(distStoreVecByWarpOpOrFailure) ||
892 failed(distOffsetsByWarpOpOrFailure) ||
893 failed(distMaskByWarpOpOrFailure)) {
894 return rewriter.notifyMatchFailure(
895 storeScatterOp,
896 "Some vector operands have no layouts, using defaults instead.");
897 }
898 // Distributed store payload type according to the lane layout.
899 VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value();
900 // Expected distributed payload type is always 1D.
901 VectorType expectedPayloadTy =
902 VectorType::get({distPayloadTyByWarpOp.getNumElements()},
903 distPayloadTyByWarpOp.getElementType());
904
905 SmallVector<size_t> newRetIndices;
906 SmallVector<Value> operands = storeScatterOp->getOperands();
907 SmallVector<Type> operandTypesToYield = {
908 distPayloadTyByWarpOp, operands[1].getType(),
909 distOffsetsByWarpOpOrFailure.value(),
910 distMaskByWarpOpOrFailure.value()};
911
912 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
913 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
914 SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
915 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
916 // The payload operand may need type adjustment due to a mismatch between the
917 // warp-distributed type and the expected SIMT type.
918 rewriter.setInsertionPointAfter(newWarpOp);
919 newStoreScatterOpOperands[0] = resolveDistributedTy(
920 newStoreScatterOpOperands[0], expectedPayloadTy, rewriter);
921 xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
922 rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
923 storeScatterOp->getAttrs());
924 xegpu::removeLayoutAttrs(newOp);
925 rewriter.eraseOp(storeScatterOp);
926 return success();
927 }
928};
929
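/// Helper to compute the per-lane coordinates for a load_matrix/store_matrix
/// op. The layout attribute provides each lane's base offsets into the
/// subgroup-level payload (via computeDistributedCoords); these are added to
/// the original subgroup offsets to obtain the lane's coordinates. Returns an
/// empty vector on failure.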
930static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
931 PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
932 Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
933 SmallVector<Value> newCoords;
934 auto maybeCoords =
935 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
936 if (failed(maybeCoords))
937 return {};
938 assert(maybeCoords.value().size() == 1 &&
939 "Expected one set of distributed offsets");
940 SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
941 rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
942 getAsOpFoldResult(origOffsets));
943 newCoords = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
944 return newCoords;
945}
946
947/// Pattern for distributing xegpu::LoadMatrixOp.
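/// The loaded payload is distributed to each lane according to the op's
/// layout attribute, the mem_desc operand is left unchanged, and, unless the
/// subgroup_block_io attribute is set, the offsets are rebased to per-lane
/// coordinates using the layout.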
948struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
949 using gpu::WarpDistributionPattern::WarpDistributionPattern;
950 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
951 PatternRewriter &rewriter) const override {
952 gpu::YieldOp yield = warpOp.getTerminator();
953 Operation *lastNode = yield->getPrevNode();
954 auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
955 if (!matrixOp)
956 return failure();
957
958 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
959 return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
960 });
961 if (!producedByLastLoad)
962 return rewriter.notifyMatchFailure(
963 warpOp, "The last op is not xegpu::LoadMatrixOp");
964 const int operandIdx = producedByLastLoad->getOperandNumber();
965
966 VectorType sgPayloadTy =
967 dyn_cast<VectorType>(matrixOp.getResult().getType());
968 VectorType warpResultTy =
969 cast<VectorType>(warpOp.getResult(operandIdx).getType());
970 if (!sgPayloadTy)
971 return rewriter.notifyMatchFailure(
972 matrixOp, "the matrix op payload must be a vector type");
973
974 auto loc = matrixOp.getLoc();
975 auto offsets = matrixOp.getMixedOffsets();
976 if (offsets.empty())
977 return rewriter.notifyMatchFailure(matrixOp,
978 "the load op must have offsets");
979 SmallVector<Value> offsetsAsValues =
980 vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
981
982 auto layout = matrixOp.getLayoutAttr();
983 if (!layout)
984 return rewriter.notifyMatchFailure(
985 matrixOp, "the matrix operation lacks layout attribute");
986
987 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
988 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
989 if (failed(distPayloadByWarpOpOrFailure))
990 return rewriter.notifyMatchFailure(
991 matrixOp, "Failed to distribute matrix op payload based on layout.");
992
993 SmallVector<Value> operands = {matrixOp.getMemDesc()};
994 const unsigned offsetsStartIdx = operands.size();
995 operands.append(offsetsAsValues);
996
997 SmallVector<Type> operandTypes = llvm::to_vector(
998 llvm::map_range(operands, [](Value v) { return v.getType(); }));
999
1000 SmallVector<size_t> newRetIndices;
1001 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1002 rewriter, warpOp, operands, operandTypes, newRetIndices);
1003 SmallVector<Value> newOperands = llvm::map_to_vector(
1004 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
1005
1006 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
1007 ShapedType::kDynamic);
1008 DenseI64ArrayAttr newConstOffsetsAttr =
1009 rewriter.getDenseI64ArrayAttr(newConstOffsets);
1010 ValueRange currentOffsets =
1011 ValueRange(newOperands).drop_front(offsetsStartIdx);
1012
1013 SmallVector<Value> newCoords = currentOffsets;
1014 rewriter.setInsertionPointAfter(newWarpOp);
1015
1016 if (!matrixOp.getSubgroupBlockIoAttr()) {
1017 newCoords = computeDistributedCoordinatesForMatrixOp(
1018 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
1019 currentOffsets);
1020 }
1021 xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
1022 rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
1023 newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
1024 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
1025 // Resolve the output type and replace all uses.
1026 rewriter.replaceAllUsesWith(
1027 newWarpOp.getResult(operandIdx),
1028 resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
1029 return success();
1030 }
1031};
1032
1033/// Pattern for distributing xegpu::StoreMatrixOp.
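/// Mirrors LoadMatrixDistribution: the stored payload is distributed
/// according to the op's layout attribute, and the offsets are rebased to
/// per-lane coordinates unless the subgroup_block_io attribute is set.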
1034struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
1035 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1036 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1037 PatternRewriter &rewriter) const override {
1038 gpu::YieldOp yield = warpOp.getTerminator();
1039 Operation *lastNode = yield->getPrevNode();
1040 auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
1041 if (!matrixOp)
1042 return failure();
1043
1044 VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
1045 if (!sgPayloadTy)
1046 return rewriter.notifyMatchFailure(
1047 matrixOp, "the matrix op payload must be a vector type");
1048
1049 auto loc = matrixOp.getLoc();
1050 auto offsets = matrixOp.getMixedOffsets();
1051 if (offsets.empty())
1052 return rewriter.notifyMatchFailure(matrixOp,
1053 "the store op must have offsets");
1054 SmallVector<Value> offsetsAsValues =
1055 vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
1056
1057 auto layout = matrixOp.getLayoutAttr();
1058 if (!layout)
1059 return rewriter.notifyMatchFailure(
1060 matrixOp, "the matrix operation lacks layout attribute");
1061
1062 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
1063 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
1064 if (failed(distPayloadByWarpOpOrFailure))
1065 return rewriter.notifyMatchFailure(
1066 matrixOp, "Failed to distribute matrix op payload based on layout.");
1067
1068 SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
1069 const unsigned offsetsStartIdx = operands.size();
1070 operands.append(offsetsAsValues);
1071
1072 SmallVector<Type> operandTypes = llvm::to_vector(
1073 llvm::map_range(operands, [](Value v) { return v.getType(); }));
1074 operandTypes[0] = *distPayloadByWarpOpOrFailure;
1075
1076 SmallVector<size_t> newRetIndices;
1077 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1078 rewriter, warpOp, operands, operandTypes, newRetIndices);
1079 SmallVector<Value> newOperands = llvm::map_to_vector(
1080 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
1081
1082 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
1083 ShapedType::kDynamic);
1084 DenseI64ArrayAttr newConstOffsetsAttr =
1085 rewriter.getDenseI64ArrayAttr(newConstOffsets);
1086 ValueRange currentOffsets =
1087 ValueRange(newOperands).drop_front(offsetsStartIdx);
1088
1089 SmallVector<Value> newCoords = currentOffsets;
1090 rewriter.setInsertionPointAfter(newWarpOp);
1091
1092 if (!matrixOp.getSubgroupBlockIoAttr()) {
1093 newCoords = computeDistributedCoordinatesForMatrixOp(
1094 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
1095 currentOffsets);
1096 }
1097
1098 xegpu::StoreMatrixOp::create(
1099 rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
1100 ValueRange(newCoords), newConstOffsetsAttr,
1101 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
1102 rewriter.eraseOp(matrixOp);
1103 return success();
1104 }
1105};
1106
1107/// Distribute a scattered load op. The logic and requirements are the same as
1108/// for the scattered store distribution. The warpOp's payload vector is
1109/// expected to be distributed by the load's result consumer.
1110/// Example 1 (no chunk size):
1111/// %mask = producer_op : vector<16xi1>
1112/// %offset = producer_op : vector<16xindex>
1113/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
1114/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
1115/// To
1116/// %mask = producer_op : vector<1xi1>
1117/// %offset = producer_op : vector<1xindex>
1118/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
1119/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
1120/// Example 2 (chunk size, same mask and offsets):
1121/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
1122/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
1123/// To
1124/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
1125/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
1126struct LoadDistribution final : public gpu::WarpDistributionPattern {
1127 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1128 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1129 PatternRewriter &rewriter) const override {
1130 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
1131 // Check if the yield operand was produced by the *last* scattered
1132 // load op, to avoid sinking it before barriers (maintain memory order).
1133 return isa<xegpu::LoadGatherOp>(op) &&
1134 warpOp.getTerminator()->getPrevNode() == op;
1135 });
1136 if (!producedByLastLoad)
1137 return rewriter.notifyMatchFailure(
1138 warpOp, "The last op is not xegpu::LoadGatherOp");
1139
1140 auto loadGatherOp =
1141 producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
1142 auto offsets = loadGatherOp.getOffsets();
1143 if (!offsets || !isa<VectorType>(offsets.getType()) ||
1144 !isa<VectorType>(loadGatherOp.getMask().getType()))
1145 return rewriter.notifyMatchFailure(
1146 loadGatherOp,
1147 "Load op must have vector arguments for offsets and mask");
1148 VectorType offsetsTy = cast<VectorType>(offsets.getType());
1149 VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
1150 if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
1151 return rewriter.notifyMatchFailure(loadGatherOp,
1152 "Expected 1D offsets and mask vector");
1153 // Assume offset and mask producers will be distributed as well.
1154 std::string layoutOffsetsName =
1155 xegpu::getTemporaryLayoutName(loadGatherOp->getOpOperand(1));
1156 std::string layoutMaskName =
1157 xegpu::getTemporaryLayoutName(loadGatherOp->getOpOperand(2));
1158
1159 xegpu::LayoutAttr layoutOffsets =
1160 loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
1161 xegpu::LayoutAttr layoutMask =
1162 loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
1163
1164 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
1165 getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
1166 FailureOr<VectorType> distMaskByWarpOpOrFailure =
1167 getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
1168 if (failed(distOffsetsByWarpOpOrFailure) ||
1169 failed(distMaskByWarpOpOrFailure)) {
1170 return rewriter.notifyMatchFailure(
1171 loadGatherOp,
1172 "Some vector operands have no layouts, using defaults instead.");
1173 }
1174
1175 SmallVector<size_t> newRetIndices;
1176 SmallVector<Value> operands = loadGatherOp->getOperands();
1177 SmallVector<Type> operandTypesToYield = {
1178 operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
1179 distMaskByWarpOpOrFailure.value()};
1180
1181 const unsigned operandIdx = producedByLastLoad->getOperandNumber();
1182 VectorType distResultTy =
1183 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1184 // Distributed load op will always be 1D.
1185 VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
1186 distResultTy.getElementType());
1187
1188 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1189 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
1190
1191 SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
1192 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
1193
1194 rewriter.setInsertionPointAfter(newWarpOp);
1195 xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
1196 rewriter, newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
1197 loadGatherOp->getAttrs());
1198 xegpu::removeLayoutAttrs(newOp);
1199 Value distributedVal = newWarpOp.getResult(operandIdx);
1200 // Resolve the output type and replace all uses.
1201 rewriter.replaceAllUsesWith(
1202 distributedVal,
1203 resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
1204 return success();
1205 }
1206};
1207
1208/// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
1209/// VectorReductionOps. We also insert layouts for the newly created ops.
1210static Value lowerToVectorReductions(TypedValue<VectorType> src,
1211 TypedValue<VectorType> acc,
1212 vector::CombiningKind kind,
1213 int64_t reductionDim, Location loc,
1214 PatternRewriter &rewriter) {
1215 // Expecting a 2D source vector.
1216 assert(src.getType().getRank() == 2 && "expected a 2D source vector");
1217 VectorType sourceType = src.getType();
1218 int64_t sourceH = sourceType.getShape()[0];
1219 int64_t sourceW = sourceType.getShape()[1];
1220 int nSlices = (reductionDim == 0) ? sourceW : sourceH;
1221 // Create a constant vector to hold the result of the reduction.
1222 TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
1223 Value reductionResult = arith::ConstantOp::create(
1224 rewriter, loc, acc.getType(),
1225 DenseElementsAttr::get(acc.getType(), zeroAttr));
1226 // Reduction result should have the same layout as the accumulator.
1227 xegpu::setTemporaryLayout(cast<OpResult>(reductionResult),
1228 xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc)));
1229 // For each slice of the source, extract the slice vector, do a reduction,
1230 // and insert the reduced value back into the result vector.
1231 for (int i = 0; i < nSlices; ++i) {
1232 SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
1233 if (reductionDim == 1) {
1234 sliceOffsets = {i, 0};
1235 sliceSizes = {1, sourceW};
1236 } else {
1237 sliceOffsets = {0, i};
1238 sliceSizes = {sourceH, 1};
1239 }
1240 vector::ExtractStridedSliceOp extractOp =
1241 vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
1242 sliceSizes, {1, 1});
1243
1244 int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
1245
1246 vector::ShapeCastOp slice = vector::ShapeCastOp::create(
1247 rewriter, loc,
1248 VectorType::get({nSliceElements}, sourceType.getElementType()),
1249 extractOp.getResult());
1250
1251 // Shape cast is currently handled on the xegpu side, so layouts must be
1252 // retained during lowering. The shape cast output has the same layout as the
1253 // accumulator, and the shape cast source has the same layout as the original
1254 // reduction source.
1255 // TODO: other ops generated here may also need layout attributes.
1256 auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
1257 auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
1258
1259 xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
1260 xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
1261 // Extract and reduction produce scalars, so no result layout is needed.
1262 Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
1263 Value reduction = vector::ReductionOp::create(
1264 rewriter, loc, kind, slice.getResult(), accExtract);
1265 reductionResult =
1266 vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
1267 }
1268 return reductionResult;
1269}
1270
1271/// This pattern distributes the `vector.multi_reduction` operation across
1272/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
1273/// layouts for the source and accumulator vectors,
1274/// * If the reduction dimension is distributed across lanes, the reduction is
1275/// non-lane-local and the reduction is done using warp shuffles. Here we
1276/// simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in
1277/// the warp op body.
1278/// * If the reduction dimension is not distributed across lanes, the reduction
1279/// is lane-local. In this case, we yield the source and accumulator vectors
1280/// from the warp op and perform the lane-local reduction outside the warp op
1281/// using a sequence of ReductionOps.
1282/// Example 1 (Reduction is lane-local):
1283/// ```
1284/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1285/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1286/// %acc = "some_def"() : () -> (vector<32xf32>)
1287/// %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
1288/// vector<32xf32>
/// gpu.yield %1 : vector<32xf32>
1289/// }
1290/// ```
1291/// is lowered to:
1292/// ```
1293/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
1294/// vector<1xf32>) {
1295/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1296/// %acc = "some_def"() : () -> (vector<32xf32>)
1297/// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
1298/// }
1299/// %c = arith.constant dense<0.0> : vector<1xf32>
1300/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
1301/// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
1302/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
1303/// ```
1304/// Example 2 (Reduction is non-lane-local):
1305/// ```
1306/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1307/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1308/// %acc = "some_def"() : () -> (vector<2xf32>)
1309/// %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
1310/// vector<2xf32>
1311/// gpu.yield %1 : vector<2xf32>
1312/// }
1313/// ```
1314/// is lowered to:
1315/// ```
1316/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1317/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1318/// %acc = "some_def"() : () -> (vector<2xf32>)
1319/// %1 = arith.constant dense<0.0> : vector<2xf32>
1320/// %2 = vector.extract %0[0] : vector<32xf32> from vector<2x32xf32>
1321/// %3 = ("warp.reduction %2") : f32
1322/// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
1323/// ... repeat for row 1
1324/// gpu.yield %1 : vector<2xf32>
1325/// }
/// ```
1326struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
1327 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1328 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1329 PatternRewriter &rewriter) const override {
1330 OpOperand *yieldOperand =
1331 getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
1332 if (!yieldOperand)
1333 return failure();
1334 auto reductionOp =
1335 cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
1336 unsigned operandIdx = yieldOperand->getOperandNumber();
1337 VectorType sourceType = reductionOp.getSourceVectorType();
1338 // Only 2D vectors are supported.
1339 if (sourceType.getRank() != 2)
1340 return rewriter.notifyMatchFailure(warpOp,
1341 "Only 2D reductions are supported.");
1342 ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
1343 // Only one reduction dimension is supported. This also ensures that the
1344 // result is a vector type.
1345 if (reductionDims.size() != 1)
1346 return rewriter.notifyMatchFailure(
1347 warpOp, "Only 1 reduction dimension is supported.");
1348 int64_t reductionDim = reductionDims[0];
1349 VectorType distributedResultType =
1350 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1351 VectorType resultType = cast<VectorType>(reductionOp.getType());
1352 xegpu::DistributeLayoutAttr sourceLayout =
1353 xegpu::getTemporaryLayout(reductionOp->getOpOperand(0));
1354
1355 FailureOr<VectorType> sourceDistTypeOrFailure =
1356 getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1357 if (failed(sourceDistTypeOrFailure))
1358 return rewriter.notifyMatchFailure(
1359 warpOp, "Failed to distribute the source vector type.");
1360 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1361 // Only single dimension distribution is supported.
1362 bool dim0Distributed =
1363 sourceDistType.getShape()[0] != sourceType.getShape()[0];
1364 bool dim1Distributed =
1365 sourceDistType.getShape()[1] != sourceType.getShape()[1];
1366 if (dim0Distributed && dim1Distributed)
1367 return rewriter.notifyMatchFailure(
1368 warpOp, "Expecting source to be distributed in a single dimension.");
1369 int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
1370 if (sourceDistDim == -1)
1371 return rewriter.notifyMatchFailure(
1372 warpOp, "Expecting a distributed source vector.");
1373 bool resultDistributed =
1374 distributedResultType.getNumElements() < resultType.getNumElements();
1375 // If the lane owns all the data required for the reduction (i.e. the
1376 // reduction is fully parallel across lanes), then each lane owns part of
1377 // the result (i.e. the result is distributed). If the reduction requires
1378 // cross-lane shuffling, then the result is shared among all lanes
1379 // (broadcasted). Therefore we expect the following cases:
1380 //
1381 // | Source vector | Reduction dim | Result vector |
1382 // |----------------------|----------------|----------------|
1383 // | dim-0 distributed | 0 | broadcasted |
1384 // | dim-0 distributed | 1 | distributed |
1385 // | dim-1 distributed | 0 | distributed |
1386 // | dim-1 distributed | 1 | broadcasted |
1387
1388 bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
1389 (sourceDistDim == 1 && reductionDim == 0);
1390 if (isReductionLaneLocal && !resultDistributed)
1391 return rewriter.notifyMatchFailure(
1392 warpOp, "Expecting a distributed result for lane-local reduction.");
1393
1394 if (!isReductionLaneLocal && resultDistributed)
1395 return rewriter.notifyMatchFailure(
1396 warpOp,
1397 "Expecting a broadcasted result for non-lane-local reduction.");
1398
1399 // Handle lane-local reduction case. In this case we fully distribute the
1400 // reduction result.
1401 if (isReductionLaneLocal) {
1402 // Yield the source and acc vectors from the WarpOp.
1403 SmallVector<size_t> newRetIndices;
1404 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1405 rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
1406 {sourceDistType, distributedResultType}, newRetIndices);
1407 rewriter.setInsertionPointAfter(newWarpOp);
1408 Value result = lowerToVectorReductions(
1409 cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
1410 cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
1411 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1412 // Replace the warp op result with the final result.
1413 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
1414 return success();
1415 }
1416 // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
1417 // of multiple ReductionOps. Actual distribution is done by the
1418 // WarpOpReduction pattern.
1419 rewriter.setInsertionPointAfter(reductionOp);
1420 Value result = lowerToVectorReductions(
1421 cast<TypedValue<VectorType>>(reductionOp.getSource()),
1422 cast<TypedValue<VectorType>>(reductionOp.getAcc()),
1423 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1424 // Replace the warp op result with the final result.
1425 rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
1426 return success();
1427 }
1428};
1429
1430/// This pattern distributes the `vector.broadcast` operation across lanes in a
1431/// warp. The pattern supports three use cases:
1432///
1433/// 1) Broadcast a low-rank vector to a high-rank vector: The low-rank input
1434/// vector must have a layout that is a slice of the result layout. If the
1435/// distributed source and target vector types are identical, this lowers to a
1436/// no-op; otherwise, it remains a broadcast but operates on distributed
1437/// vectors.
1438///
1439/// 2) Broadcast a same-rank vector with identical layouts for source and
1440/// target: The source vector must have unit dimensions, and lane_data
1441/// must be unit size for those unit dims. This always lowers to
1442/// a no-op.
1443///
1444/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from
1445/// scalar to distributed result type.
1446///
1447/// Example 1 (lowering to a broadcast with distributed types):
1448/// ```
1449/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1450/// %0 = "some_def"() {layout_result_0 =
1451/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1452/// dims = [0]> } : () -> (vector<32xf32>)
1453/// %2 = vector.broadcast %0 {layout_result_0 =
1454/// #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
1455/// : vector<32xf32> to vector<8x32xf32>
1456/// gpu.yield %2 : vector<8x32xf32>
1457/// }
1458/// ```
1459/// is lowered to:
1460/// ```
1461/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1462/// %0 = "some_def"() {layout_result_0 =
1463/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1464/// dims = [0]> } : () -> (vector<32xf32>)
1465/// gpu.yield %0 : vector<32xf32>
1466/// }
1467/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
1468/// ```
///
1469/// Example 2 (no-op):
1470/// ```
1471/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1472/// %0 = "some_def"() {layout_result_0 =
1473/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1474/// dims = [1]> } : () -> (vector<8xf32>)
1475/// %1 = vector.shape_cast %0
1476/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1477/// 1]>}: vector<8xf32> to vector<8x1xf32>
1478/// %2 = vector.broadcast %1
1479/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1480/// 1]>}: vector<8x1xf32> to vector<8x32xf32>
1481/// gpu.yield %2 : vector<8x32xf32>
1482/// }
1483/// ```
1484/// is lowered to:
1485/// ```
1486/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1487/// %0 = "some_def"() {layout_result_0 =
1488/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1489/// dims = [1]> } : () -> (vector<8xf32>)
1490/// %1 = vector.shape_cast %0
1491/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1492/// 1]>}: vector<8xf32> to vector<8x1xf32>
1493/// gpu.yield %1 : vector<8x1xf32>
1494/// }
1495/// // The broadcast is implicit through layout transformation (no-op)
1496/// "some_use"(%r#0)
1497/// ```
1498struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
1499 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1500 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1501 PatternRewriter &rewriter) const override {
1502 OpOperand *yieldOperand =
1503 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1504 if (!yieldOperand)
1505 return failure();
1506 auto broadcastOp =
1507 cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
1508 unsigned operandIdx = yieldOperand->getOperandNumber();
1509
1510 VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1511 VectorType destType =
1512 dyn_cast<VectorType>(broadcastOp.getResult().getType());
1513
1514 xegpu::DistributeLayoutAttr sourceLayout =
1515 xegpu::getTemporaryLayout(broadcastOp->getOpOperand(0));
1516 xegpu::DistributeLayoutAttr resultLayout =
1517 xegpu::getTemporaryLayout(dyn_cast<OpResult>(broadcastOp.getResult()));
1518
1519 FailureOr<VectorType> sourceDistType;
1520 Type sourceElemOrDistType;
1521 if (sourceType) {
1522
1523 // Case 1 and 2: source is a vector type.
1524 int64_t rankDiff = destType.getRank() - sourceType.getRank();
1525 if (rankDiff > 0) {
1526 // Case 1: source is lower-rank than result.
1527 bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
1528 if (!isSliceOf)
1529 return rewriter.notifyMatchFailure(
1530 warpOp,
1531 "Broadcast input layout must be a slice of result layout.");
1532 }
1533 // case 2: source and result have same rank
1534 if (rankDiff == 0) {
1535 SetVector<int64_t> broadcastUnitDims =
1536 broadcastOp.computeBroadcastedUnitDims();
1537 bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
1538 if (!isEqualTo)
1539 return rewriter.notifyMatchFailure(
1540 warpOp, "For same-rank broadcast, source must be identical to "
1541 "adjusted result layouts with unit dims.");
1542 resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
1543 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1544 }
1545
1546 sourceDistType =
1547 getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1548 if (failed(sourceDistType)) {
1549 return rewriter.notifyMatchFailure(
1550 warpOp, "Failed to distribute the source vector type.");
1551 }
1552 sourceElemOrDistType = sourceDistType.value();
1553
1554 } else {
1555 // Case 3: source is a scalar type.
1556 if (sourceLayout) {
1557 return rewriter.notifyMatchFailure(
1558 warpOp, "Broadcast from scalar must not have a layout attribute.");
1559 }
1560 sourceElemOrDistType = broadcastOp.getSourceType();
1561 }
1562 FailureOr<VectorType> destDistType =
1563 getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
1564 if (failed(destDistType)) {
1565 return rewriter.notifyMatchFailure(
1566 warpOp, "Failed to distribute the dest vector type.");
1567 }
1568
1569 SmallVector<size_t> newRetIndices;
1570 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1571 rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
1572 newRetIndices);
1573
1574 Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
1575
1576 Value newBroadcast = distributedSource;
1577
1578 if (sourceElemOrDistType != destDistType.value()) {
1579 rewriter.setInsertionPointAfter(newWarpOp);
1580 newBroadcast =
1581 vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
1582 destDistType.value(), distributedSource);
1583 }
1584
1585 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
1586 return success();
1587 }
1588};
1589
1590/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
1591/// `gpu.warp_execute_on_lane_0` region.
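///
/// A hypothetical illustration (the shapes, layouts, subgroup size of 16 and
/// the "some_def" op are assumptions for illustration, not taken from a test):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
///   %0 = "some_def"() {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///     : () -> (vector<1x16xf32>)
///   %1 = vector.shape_cast %0 {layout_result_0 =
///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
///     dims = [0]>} : vector<1x16xf32> to vector<16xf32>
///   gpu.yield %1 : vector<16xf32>
/// }
/// ```
/// would be lowered to:
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
///   %0 = "some_def"() {...} : () -> (vector<1x16xf32>)
///   gpu.yield %0 : vector<1x16xf32>
/// }
/// %1 = vector.shape_cast %r#0 : vector<1x1xf32> to vector<1xf32>
/// ```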
1592struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
1593 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1594 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1595 PatternRewriter &rewriter) const override {
1596 OpOperand *yieldOperand =
1597 getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1598 if (!yieldOperand)
1599 return failure();
1600 auto shapeCastOp =
1601 cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
1602 unsigned operandNumber = yieldOperand->getOperandNumber();
1603 auto resultDistTy =
1604 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1605 xegpu::DistributeLayoutAttr sourceLayout =
1606 xegpu::getTemporaryLayout(shapeCastOp->getOpOperand(0));
1607 xegpu::DistributeLayoutAttr resultLayout =
1608 xegpu::getTemporaryLayout(dyn_cast<OpResult>(shapeCastOp.getResult()));
1609 if (!sourceLayout || !resultLayout)
1610 return rewriter.notifyMatchFailure(
1611 warpOp,
1612 "the source or result of shape_cast op lacks distribution layout");
1613
1614 // For rank reducing or increasing shape_cast ops, the lower rank layout
1615 // must be a slice of higher rank layout.
1616 int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
1617 int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
1618 if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
1619 return rewriter.notifyMatchFailure(
1620 warpOp, "shape_cast is rank reducing but source layout is not a "
1621 "slice of result layout");
1622 if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
1623 return rewriter.notifyMatchFailure(
1624 warpOp, "shape_cast is rank increasing but result layout is not a "
1625 "slice of source layout");
1626
1627 FailureOr<VectorType> sourceDistTypeOrFailure =
1628 getDistVecTypeBasedOnLaneLayout(sourceLayout,
1629 shapeCastOp.getSourceVectorType());
1630 if (failed(sourceDistTypeOrFailure))
1631 return rewriter.notifyMatchFailure(
1632 warpOp, "failed to get distributed vector type for source");
1633 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1634 // Create a new warp op that yields the source of the shape_cast op.
1635 SmallVector<size_t> newRetIndices;
1636 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1637 rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
1638 newRetIndices);
1639 rewriter.setInsertionPointAfter(newWarpOp);
1640 Value source = newWarpOp.getResult(newRetIndices[0]);
1641 // Create a new shape_cast op outside the warp op.
1642 Value newShapeCast = vector::ShapeCastOp::create(
1643 rewriter, shapeCastOp.getLoc(), resultDistTy, source);
1644 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1645 newShapeCast);
1646 return success();
1647 }
1648};
1649
1650// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
1651// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1652// advanced cases where the distributed dimension is partially extracted and
1653// currently not supported by the generic vector distribution patterns.
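//
// A hypothetical illustration (the shapes, layouts, subgroup size of 16 and
// the "some_def" op are assumptions for illustration, not taken from a test):
// ```
// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x8xf32>) {
//   %0 = "some_def"() {layout_result_0 =
//     #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
//     : () -> (vector<32x8xf32>)
//   %1 = vector.extract_strided_slice %0 {offsets = [16, 0], sizes = [16, 8],
//     strides = [1, 1]} : vector<32x8xf32> to vector<16x8xf32>
//   gpu.yield %1 : vector<16x8xf32>
// }
// ```
// would be lowered to (each lane owns rows laneid and laneid + 16, so the
// source offset 16 becomes offset 1 in the distributed source):
// ```
// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x8xf32>) {
//   %0 = "some_def"() {...} : () -> (vector<32x8xf32>)
//   gpu.yield %0 : vector<32x8xf32>
// }
// %1 = vector.extract_strided_slice %r#0 {offsets = [1, 0], sizes = [1, 8],
//   strides = [1, 1]} : vector<2x8xf32> to vector<1x8xf32>
// ```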
1654struct VectorExtractStridedSliceDistribution
1655 : public gpu::WarpDistributionPattern {
1656 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1657 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1658 PatternRewriter &rewriter) const override {
1659 OpOperand *operand =
1660 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1661 if (!operand)
1662 return failure();
1663 auto extractOp =
1664 cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
1665 unsigned operandIdx = operand->getOperandNumber();
1666 auto distributedType =
1667 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1668 // Find the distributed dimensions.
1669 auto extractResultType = cast<VectorType>(operand->get().getType());
1670 auto distributedDims =
1671 getDistributedDims(extractResultType, distributedType);
1672 // Collect updated source type, sizes and offsets. They may be adjusted
1673 // later if the data is distributed to lanes (as opposed to being owned by
1674 // all lanes uniformly).
1675 VectorType updatedSourceType = extractOp.getSourceVectorType();
1676 SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
1677 extractOp.getSizes(), [](Attribute attr) { return attr; });
1678 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1679 extractOp.getOffsets(), [](Attribute attr) { return attr; });
1680 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1681 extractOp.getStrides(), [](Attribute attr) { return attr; });
1682 // If the provided sizes, offsets, and strides have fewer entries than the
1683 // rank, pad them with full sizes, zero offsets, and unit strides. This makes
1684 // it easier to adjust them later.
1685 int64_t sourceRank = extractOp.getSourceVectorType().getRank();
1686 for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
1687 updatedSizes.push_back(rewriter.getI64IntegerAttr(
1688 extractOp.getSourceVectorType().getDimSize(i)));
1689 updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
1690 updatedStrides.push_back(
1691 rewriter.getI64IntegerAttr(1)); // stride is always 1.
1692 }
1693 // If the result is distributed, it must be distributed in exactly one
1694 // dimension. In this case, we adjust the sourceDistType, distributedSizes
1695 // and distributedOffsets accordingly.
1696 if (distributedDims.size() > 0) {
1697 if (distributedDims.size() != 1)
1698 return rewriter.notifyMatchFailure(
1699 warpOp, "Source cannot be distributed in multiple dimensions.");
1700 int64_t distributedDim = distributedDims[0];
1701 int sourceDistrDimSize =
1702 extractOp.getSourceVectorType().getShape()[distributedDim];
1703 auto sourceLayout = xegpu::getTemporaryLayout(extractOp->getOpOperand(0));
1704 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1705 return rewriter.notifyMatchFailure(
1706 warpOp, "the source of extract_strided_slice op lacks distribution "
1707 "layout");
1708 auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
1709 // Because only single dimension distribution is supported, lane layout
1710 // size at the distributed dim must be the subgroup size.
1711 int subgroupSize = sourceLaneLayout[distributedDim];
1712 // Check if the source size in the distributed dimension is a multiple of
1713 // subgroup size.
1714 if (sourceDistrDimSize % subgroupSize != 0)
1715 return rewriter.notifyMatchFailure(
1716 warpOp,
1717 "Source size along distributed dimension is not a multiple of "
1718 "subgroup size.");
1719 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1720 // We expect lane data to be all ones in this case.
1721 if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1722 return rewriter.notifyMatchFailure(
1723 warpOp, "Expecting unit lane data in source layout");
1724 // The offset in the distributed dimension must be a multiple of the
1725 // subgroup size.
1726 int64_t distrDimOffset =
1727 cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
1728 if (distrDimOffset % subgroupSize != 0)
1729 return rewriter.notifyMatchFailure(
1730 warpOp, "Offset along distributed dimension "
1731 "is not a multiple of subgroup size.");
1732 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1733 sourceLayout, extractOp.getSourceVectorType())
1734 .value();
1735 // Update the distributed sizes to match the distributed type.
1736 updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1737 distributedType.getDimSize(distributedDim));
1738 // Update the distributed offsets to match round robin distribution (i.e.
1739 // each lane owns data at `subgroupSize` stride given unit lane data).
1740 updatedOffsets[distributedDim] =
1741 rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
1742 }
1743 // Do the distribution by yielding the source of the extract op from
1744 // the warp op and creating a new extract op outside the warp op.
1745 SmallVector<size_t> newRetIndices;
1746 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1747 rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
1748 newRetIndices);
1749 rewriter.setInsertionPointAfter(newWarpOp);
1750 Value source = newWarpOp.getResult(newRetIndices[0]);
1751 // Create a new extract op outside the warp op.
1752 Value newExtractOp = vector::ExtractStridedSliceOp::create(
1753 rewriter, extractOp.getLoc(), distributedType, source,
1754 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1755 ArrayAttr::get(rewriter.getContext(), updatedSizes),
1756 ArrayAttr::get(rewriter.getContext(), updatedStrides));
1757 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
1758 return success();
1759 }
1760};
1761
1762/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
1763/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1764/// advanced cases where the distributed dimension is partially inserted and
1765/// currently not supported by the generic vector distribution patterns.
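///
/// A hypothetical illustration (the shapes, layouts, subgroup size of 16 and
/// the "source_def"/"dest_def" ops are assumptions for illustration, not taken
/// from a test):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x8xf32>) {
///   %0 = "source_def"() {layout_result_0 =
///     #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
///     : () -> (vector<16x8xf32>)
///   %1 = "dest_def"() {layout_result_0 =
///     #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
///     : () -> (vector<32x8xf32>)
///   %2 = vector.insert_strided_slice %0, %1 {offsets = [16, 0],
///     strides = [1, 1]} : vector<16x8xf32> into vector<32x8xf32>
///   gpu.yield %2 : vector<32x8xf32>
/// }
/// ```
/// would be lowered to (offset 16 becomes 1 in the distributed dest):
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[16]
///     -> (vector<1x8xf32>, vector<2x8xf32>) {
///   %0 = "source_def"() {...} : () -> (vector<16x8xf32>)
///   %1 = "dest_def"() {...} : () -> (vector<32x8xf32>)
///   gpu.yield %0, %1 : vector<16x8xf32>, vector<32x8xf32>
/// }
/// %2 = vector.insert_strided_slice %r#0, %r#1 {offsets = [1, 0],
///   strides = [1, 1]} : vector<1x8xf32> into vector<2x8xf32>
/// ```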
1766struct VectorInsertStridedSliceDistribution
1767 : public gpu::WarpDistributionPattern {
1768 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1769 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1770 PatternRewriter &rewriter) const override {
1771 OpOperand *operand =
1772 getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
1773 if (!operand)
1774 return failure();
1775 unsigned int operandNumber = operand->getOperandNumber();
1776 auto insertOp =
1777 operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
1778 auto distributedType =
1779 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1780 // Find the distributed dimensions of the dest vector.
1781 auto insertResultType = cast<VectorType>(operand->get().getType());
1782 auto destDistributedDims =
1783 getDistributedDims(insertResultType, distributedType);
1784 // Collect updated offsets, source type and dest type. They may be adjusted
1785 // later if the data is distributed to lanes (as opposed to being owned by
1786 // all lanes uniformly).
1787 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1788 insertOp.getOffsets(), [](Attribute attr) { return attr; });
1789 VectorType updatedSourceType = insertOp.getSourceVectorType();
1790 VectorType updatedDestType = insertOp.getDestVectorType();
1791 if (destDistributedDims.size() > 0) {
1792 // Only single dimension distribution is supported.
1793 if (destDistributedDims.size() != 1)
1794 return rewriter.notifyMatchFailure(
1795 warpOp,
1796 "Expecting source to be distributed in a single dimension.");
1797 int64_t destDistributedDim = destDistributedDims[0];
1798
1799 VectorType srcType = insertOp.getSourceVectorType();
1800 VectorType destType = insertOp.getDestVectorType();
1801 // Currently we require that both source (kD) and dest (nD) vectors are
1802 // distributed. This requires that distributedDim (d) is contained in the
1803 // last k dims of the dest vector (d >= n - k).
1804 int64_t sourceDistributedDim =
1805 destDistributedDim - (destType.getRank() - srcType.getRank());
1806 if (sourceDistributedDim < 0)
1807 return rewriter.notifyMatchFailure(
1808 insertOp,
1809 "distributed dimension must be in the last k (i.e. source "
1810 "rank) dims of dest vector");
1811 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
1812 // Obtain the source and dest layouts.
1813 auto destLayout = xegpu::getTemporaryLayout(insertOp->getOpOperand(1));
1814 auto sourceLayout = xegpu::getTemporaryLayout(insertOp->getOpOperand(0));
1815 if (!destLayout || !sourceLayout ||
1816 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1817 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1818 return rewriter.notifyMatchFailure(
1819 warpOp, "the source or dest of insert_strided_slice op lacks "
1820 "distribution layout");
1821 // Because only single dimension distribution is supported, lane layout
1822 // size at the distributed dim must be the subgroup size.
1823 int subgroupSize =
1824 destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
1825 // We require that source and dest lane data are all ones to ensure
1826 // uniform round robin distribution.
1827 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1828 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1829 if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
1830 !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1831 return rewriter.notifyMatchFailure(
1832 warpOp, "Expecting unit lane data in source and dest layouts");
1833 // Source distributed dim size must be a multiple of the subgroup size.
1834 if (srcDistrDimSize % subgroupSize != 0)
1835 return rewriter.notifyMatchFailure(
1836 warpOp, "Distributed dimension size in source is not a multiple of "
1837 "subgroup size.");
1838 // Offsets in the distributed dimension must be multiples of subgroup
1839 // size.
1840 int64_t destDistrDimOffset =
1841 cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
1842 if (destDistrDimOffset % subgroupSize != 0)
1843 return rewriter.notifyMatchFailure(
1844 warpOp,
1845 "Offset along distributed dimension in dest is not a multiple of "
1846 "subgroup size.");
1847 // Update the source and dest types based on their layouts.
1848 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1849 sourceLayout, insertOp.getSourceVectorType())
1850 .value();
1851 updatedDestType = getDistVecTypeBasedOnLaneLayout(
1852 destLayout, insertOp.getDestVectorType())
1853 .value();
1854 // Update the distributed offsets to match round robin distribution (i.e.
1855 // each lane owns data at `subgroupSize` stride given unit lane data).
1856 updatedOffsets[destDistributedDim] =
1857 rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1858 }
1859 // Do the distribution by yielding the source and dest of the insert op
1860 // from the warp op and creating a new insert op outside the warp op.
1861 SmallVector<size_t> newRetIndices;
1862 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1863 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1864 {updatedSourceType, updatedDestType}, newRetIndices);
1865 rewriter.setInsertionPointAfter(newWarpOp);
1866
1867 Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
1868 Value dest = newWarpOp.getResult(newRetIndices[1]);
1869 // Create a new insert op outside the warp op.
1870 Value newInsertOp = vector::InsertStridedSliceOp::create(
1871 rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
1872 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1873 insertOp.getStrides());
1874 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1875 newInsertOp);
1876 return success();
1877 }
1878};
1879
1880/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
1881/// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
1882/// outside of the warp op.
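///
/// For example (a hypothetical illustration; %m and the memref type are
/// assumptions, not taken from a test):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
///   %ptr = memref.extract_aligned_pointer_as_index %m
///     : memref<256xf32> -> index
///   gpu.yield %ptr : index
/// }
/// ```
/// would be lowered to:
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (memref<256xf32>) {
///   gpu.yield %m : memref<256xf32>
/// }
/// %ptr = memref.extract_aligned_pointer_as_index %r#0
///   : memref<256xf32> -> index
/// ```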
1883struct MemrefExtractAlignedPointerAsIndexDistribution final
1884 : public gpu::WarpDistributionPattern {
1885 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1886 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1887 PatternRewriter &rewriter) const override {
1888 OpOperand *operand = getWarpResult(
1889 warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
1890 if (!operand)
1891 return rewriter.notifyMatchFailure(
1892 warpOp,
1893 "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
1894 auto extractOp =
1895 operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
1896 unsigned operandIdx = operand->getOperandNumber();
1897 SmallVector<size_t> newRetIndices;
1898 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1899 rewriter, warpOp, extractOp.getSource(),
1900 TypeRange{extractOp.getSource().getType()}, newRetIndices);
1901 rewriter.setInsertionPointAfter(newWarpOp);
1902 auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
1903 rewriter, newWarpOp.getLoc(), extractOp.getType(),
1904 newWarpOp.getResult(newRetIndices[0]));
1905 Value distributedVal = newWarpOp.getResult(operandIdx);
1906 rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
1907 return success();
1908 }
1909};
1910
1911/// Distribute a vector::BitCastOp feeding into yield op of an enclosing
1912/// `gpu.warp_execute_on_lane_0` region. Bitcast only impacts the innermost
1913/// dimension of the source/result vectors. An equivalent vector::BitCastOp is
1914/// created outside of the warp op with the distributed source vector type
1915/// (computed using the assigned layout).
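///
/// A hypothetical illustration (the shapes, layouts, subgroup size of 16 and
/// the "some_def" op are assumptions for illustration, not taken from a test):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x2xf16>) {
///   %0 = "some_def"() {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///     : () -> (vector<4x16xi32>)
///   %1 = vector.bitcast %0 {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
///     : vector<4x16xi32> to vector<4x32xf16>
///   gpu.yield %1 : vector<4x32xf16>
/// }
/// ```
/// would be lowered to:
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi32>) {
///   %0 = "some_def"() {...} : () -> (vector<4x16xi32>)
///   gpu.yield %0 : vector<4x16xi32>
/// }
/// %1 = vector.bitcast %r#0 : vector<4x1xi32> to vector<4x2xf16>
/// ```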
1916struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
1917 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1918 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1919 PatternRewriter &rewriter) const override {
1920 OpOperand *operand =
1921 getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
1922 if (!operand)
1923 return rewriter.notifyMatchFailure(
1924 warpOp, "warp result is not a vector::BitCast op");
1925 auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
1926 unsigned operandIdx = operand->getOperandNumber();
1927 VectorType distributedSourceType =
1928 getDistVecTypeBasedOnLaneLayout(
1929 xegpu::getTemporaryLayout(bitcastOp->getOpOperand(0)),
1930 bitcastOp.getSourceVectorType())
1931 .value_or(VectorType());
1932 if (!distributedSourceType)
1933 return rewriter.notifyMatchFailure(
1934 bitcastOp, "Failed to distribute the source vector type in "
1935 "vector::BitCast op");
1936 VectorType distributedResultType =
1937 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1938 SmallVector<size_t> newRetIndices;
1939 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1940 rewriter, warpOp, bitcastOp.getSource(),
1941 TypeRange{distributedSourceType}, newRetIndices);
1942 rewriter.setInsertionPointAfter(newWarpOp);
1943 auto newBitcastOp = vector::BitCastOp::create(
1944 rewriter, newWarpOp.getLoc(), distributedResultType,
1945 newWarpOp.getResult(newRetIndices[0]));
1946 Value distributedVal = newWarpOp.getResult(operandIdx);
1947 rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
1948 return success();
1949 }
1950};
1951
1952/// Distribute a vector::TransposeOp feeding into yield op of an enclosing
1953/// `gpu.warp_execute_on_lane_0` region. Currently only 2D transposes are
1954/// supported. In most cases, transpose is a no-op because it is entirely
1955/// handled using the layouts (e.g. 16x1 -> 1x16). However, if each lane owns
1956/// multiple slices of data after distribution (e.g. 16x2 -> 2x16), a lane-local
1957/// transpose (i.e. shuffle) is needed. Therefore, we create an equivalent
1958/// vector::TransposeOp outside of the warp op with distributed source vector
1959/// type (computed using assigned layout).
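///
/// A hypothetical illustration of the shuffle case (the shapes, layouts,
/// subgroup size of 16 and the "some_def" op are assumptions for illustration,
/// not taken from a test):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
///   %0 = "some_def"() {layout_result_0 =
///     #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
///     : () -> (vector<16x2xf32>)
///   %1 = vector.transpose %0, [1, 0] {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///     : vector<16x2xf32> to vector<2x16xf32>
///   gpu.yield %1 : vector<2x16xf32>
/// }
/// ```
/// would be lowered to (the per-lane vector<1x2xf32> still needs a local
/// transpose to vector<2x1xf32>):
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x2xf32>) {
///   %0 = "some_def"() {...} : () -> (vector<16x2xf32>)
///   gpu.yield %0 : vector<16x2xf32>
/// }
/// %1 = vector.transpose %r#0, [1, 0] : vector<1x2xf32> to vector<2x1xf32>
/// ```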
1960struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
1961 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1962 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1963 PatternRewriter &rewriter) const override {
1964 OpOperand *operand =
1965 getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
1966 if (!operand)
1967 return rewriter.notifyMatchFailure(
1968 warpOp, "warp result is not a vector::Transpose op");
1969 auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
1970 unsigned operandIdx = operand->getOperandNumber();
1971 xegpu::DistributeLayoutAttr sourceLayout =
1972 xegpu::getTemporaryLayout(transposeOp->getOpOperand(0));
1973 xegpu::DistributeLayoutAttr resultLayout =
1974 xegpu::getTemporaryLayout(transposeOp->getOpResult(0));
1975 if (!sourceLayout || !resultLayout)
1976 return rewriter.notifyMatchFailure(
1977 transposeOp,
1978 "the source or result vector of the transpose op lacks layout "
1979 "attribute");
1980 int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
1981 int64_t resultRank = transposeOp.getResultVectorType().getRank();
1982 // Only 2D transposes are supported for now.
1983 // TODO: Support nD transposes.
1984 if (sourceRank != 2 || resultRank != 2)
1985 return rewriter.notifyMatchFailure(
1986 transposeOp, "the source or result vector of the transpose op "
1987 "does not have 2D layout");
1988 ArrayRef<int64_t> perm = transposeOp.getPermutation();
1989 // Result layout must be a transpose of source layout.
1990 if (!resultLayout.isTransposeOf(sourceLayout, perm))
1991 return rewriter.notifyMatchFailure(
1992 transposeOp,
1993 "the source or result vector layouts must be 2D transposes of each "
1994 "other");
1995 FailureOr<VectorType> distributedSourceTypeOrFailure =
1996 getDistVecTypeBasedOnLaneLayout(sourceLayout,
1997 transposeOp.getSourceVectorType());
1998 if (failed(distributedSourceTypeOrFailure))
1999 return rewriter.notifyMatchFailure(
2000 transposeOp, "Failed to distribute the source vector type in "
2001 "vector::Transpose op");
2002 SmallVector<size_t> newRetIndices;
2003 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2004 rewriter, warpOp, transposeOp.getVector(),
2005 TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
2006 rewriter.setInsertionPointAfter(newWarpOp);
2007 auto newTransposeOp = vector::TransposeOp::create(
2008 rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
2009 perm);
2010 Value distributedVal = newWarpOp.getResult(operandIdx);
2011 rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
2012 return success();
2013 }
2014};
2015
2016} // namespace
2017
2018namespace {
2019struct XeGPUSubgroupDistributePass final
2020 : public xegpu::impl::XeGPUSubgroupDistributeBase<
2021 XeGPUSubgroupDistributePass> {
2022 void runOnOperation() override;
2023};
2024} // namespace
2025
2026 void xegpu::populateXeGPUSubgroupDistributePatterns(
2027 RewritePatternSet &patterns) {
2028 patterns.add<CreateNdDescDistribution, StoreNdDistribution,
2029 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
2030 GpuBarrierDistribution, VectorMultiReductionDistribution,
2031 LoadDistribution, StoreDistribution, VectorTransposeDistribution,
2032 VectorBitcastDistribution, LoadMatrixDistribution,
2033 StoreMatrixDistribution,
2034 MemrefExtractAlignedPointerAsIndexDistribution>(
2035 patterns.getContext(),
2036 /*pattern benefit=*/regularPatternBenefit);
2037 // For the following patterns, we need to override the regular vector
2038 // distribution patterns. Therefore, assign a higher benefit.
2039 patterns
2040 .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
2041 VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>(
2042 patterns.getContext(),
2043 /*pattern benefit=*/highPatternBenefit);
2044}
2045
2046 void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
2047 RewritePatternSet &patterns) {
2048 patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
2049}
2050
2051void XeGPUSubgroupDistributePass::runOnOperation() {
2052 // Step 1: Attach layouts to op operands.
2053 // TODO: Following assumptions are made:
2054 // 1) It is assumed that there are no layout conflicts.
2055 // 2) Any existing layout attributes attached to the operands are ignored.
2056 Operation *op = getOperation();
2057 if (!xegpu::recoverTemporaryLayouts(op)) {
2058 signalPassFailure();
2059 return;
2060 }
2061
2062 // Step 2: Move all operations of a GPU function inside
2063 // gpu.warp_execute_on_lane_0 operation.
2064 {
2065 RewritePatternSet patterns(&getContext());
2066 xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
2067
2068 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
2069 signalPassFailure();
2070 return;
2071 }
2072 // At this point, we have moved the entire function body inside the
2073 // warpOp. Now move any scalar uniform code outside of the warpOp (like
2074 // GPU index ops, scalar constants, etc.). This will simplify the
2075 // later lowering and avoid custom patterns for these ops.
2076 getOperation()->walk([&](Operation *op) {
2077 if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
2078 vector::moveScalarUniformCode(warpOp);
2079 });
2080 }
2081 // Step 3: Apply subgroup to workitem distribution patterns.
2082 RewritePatternSet patterns(&getContext());
2083 xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
2084 // distributionFn is used by vector distribution patterns to determine the
2085 // distributed vector type for a given vector value. In XeGPU subgroup
2086 // distribution context, we compute this based on lane layout.
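  // For example (shapes and layout assumed for illustration only): a
  // vector<32x16xf32> value with lane_layout [1, 16] yields the map
  // (d0, d1) -> (d1), i.e. only dim 1 is distributed, while a value with no
  // layout yields a map with no results (no distribution).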
2087 auto distributionFn = [](Value val) {
2088 VectorType vecType = dyn_cast<VectorType>(val.getType());
2089 int64_t vecRank = vecType ? vecType.getRank() : 0;
2090 if (vecRank == 0)
2091 return AffineMap::get(val.getContext());
2092 // Get the layout of the vector type.
2093 xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
2094 // If no layout is specified, that means no distribution.
2095 if (!layout)
2096 return AffineMap::getMultiDimMapWithTargets(vecRank, {},
2097 val.getContext());
2098 // Expecting vector and layout rank to match.
2099 assert(layout.getRank() == vecRank &&
2100 "Expecting vector and layout rank to match");
2101 // A dimension is distributed only if layout suggests there are
2102 // multiple lanes assigned for this dimension and the shape can be evenly
2103 // distributed to those lanes.
2104 SmallVector<unsigned int> distributedDims;
2105 for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
2106 if (v > 1 && vecType.getShape()[i] % v == 0)
2107 distributedDims.push_back(i);
2108 }
2109 return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
2110 val.getContext());
2111 };
2112 // TODO: shuffleFn is not used.
2113 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
2114 int64_t warpSz) { return Value(); };
2115
2116 auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
2117 vector::CombiningKind kind, uint32_t size) {
2118 // First reduce on a single thread to get per lane reduction value.
2119 Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
2120 // Parallel reduction using butterfly shuffles.
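    // For example, with size = 16 the XOR offsets are 1, 2, 4, and 8; after
    // log2(size) rounds every lane holds the full subgroup-wide reduction.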
2121 for (uint64_t i = 1; i < size; i <<= 1) {
2122 Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
2123 /*width=*/size,
2124 /*mode=*/gpu::ShuffleMode::XOR)
2125 .getShuffleResult();
2126 laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
2127 }
2128 return laneVal;
2129 };
2130
2131 vector::populateDistributeReduction(
2132 patterns, warpReduction,
2133 /*pattern benefit=*/regularPatternBenefit);
2134
2135 vector::populatePropagateWarpVectorDistributionPatterns(
2136 patterns, distributionFn, shuffleFn,
2137 /*pattern benefit=*/regularPatternBenefit);
2138 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
2139 signalPassFailure();
2140 return;
2141 }
2142
2143 // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
2144 // due to tensor desc type mismatches created by using upstream distribution
2145 // patterns (scf.for). This cleanup should only be done if all the ops are
2146 // distributed successfully. If some ops are still not distributed and remain
2147 // inside any WarpExecuteOnLane0Op, we skip this simplification step to avoid
2148 // breaking the IR.
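  // For example (illustrative only), a cast that converts
  //   !xegpu.tensor_desc<16x16xf32, #xegpu.layout<...>> to
  //   !xegpu.tensor_desc<16x16xf32>
  // and carries the resolve_simt_type_mismatch attribute is removed by
  // retyping the scf.for block argument, while a cast in the opposite
  // direction is removed by forwarding its input to the users of its result.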
2149 bool foundWarpOp = false;
2150 getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
2151 // Look for WarpOps that are not trivially dead.
2152 if (isOpTriviallyDead(warpOp))
2153 return WalkResult::advance();
2154 foundWarpOp = true;
2155 return WalkResult::interrupt();
2156 });
2157 if (foundWarpOp)
2158 return;
2159
2160 getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
2161 // We are only interested in UnrealizedConversionCastOps that were added
2162 // to resolve SIMT type mismatches.
2163 if (!op->getAttr(resolveSIMTTypeMismatch))
2164 return WalkResult::skip();
2165
2166 Value input = op.getOperand(0);
2167 Value output = op.getResult(0);
2168
2169 // Both input and output must have tensor descriptor types.
2170 xegpu::TensorDescType inputDescType =
2171 mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
2172 xegpu::TensorDescType outputDescType =
2173 mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
2174 assert(inputDescType && outputDescType &&
2175 "Unrealized conversion cast must have tensor descriptor types");
2176
2177 // Conversions of the form tensor_desc<shape, layout> -> tensor_desc<shape>.
2178 // These occur inside the scf.for body to resolve the block argument type to
2179 // the SIMT type.
2180 if (inputDescType.getLayout()) {
2181 auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
2182 if (argument) {
2183 argument.setType(output.getType());
2184 output.replaceAllUsesWith(argument);
2185 if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
2186 argument.getOwner()->getParentOp())) {
2187 auto result = loopOp.getTiedLoopResult(argument);
2188 result.setType(output.getType());
2189 }
2190 }
2191 }
2192
2193 // Conversions of the form tensor_desc<shape> ->
2194 // tensor_desc<shape, layout>. These occur at the yield op of the scf.for
2195 // body to go back from the SIMT type to the original type.
2196 if (outputDescType.getLayout())
2197 output.replaceAllUsesWith(input);
2198
2199 if (op->use_empty())
2200 op->erase();
2201 return WalkResult::advance();
2202 });
2203}
static const char *const resolveSIMTTypeMismatch
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:153
UnitAttr getUnitAttr()
Definition Builders.cpp:98
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:167
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition Value.h:149
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
const uArch * getUArch(llvm::StringRef archName)
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
void populateXeGPUMoveFuncBodyToWarpOpPatterns(RewritePatternSet &patterns)
Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's regio...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU SIMT distribution into patterns.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.
virtual int getSubgroupSize() const =0
StringRef getName() const
Definition uArchBase.h:152