1//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
19#include "mlir/IR/AffineMap.h"
20#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/Operation.h"
27#include "mlir/IR/TypeRange.h"
28#include "mlir/IR/Value.h"
29#include "mlir/IR/Visitors.h"
31#include "mlir/Support/LLVM.h"
35#include "llvm/ADT/ArrayRef.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/SmallVector.h"
38
39namespace mlir {
40namespace xegpu {
41#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
42#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
43} // namespace xegpu
44} // namespace mlir
45
46#define DEBUG_TYPE "xegpu-subgroup-distribute"
47#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
48
49using namespace mlir;
50
51static const char *const resolveSIMTTypeMismatch =
52 "resolve_simt_type_mismatch"; // Attribute name for identifying
53 // UnrealizedConversionCastOp added to resolve
54 // SIMT type mismatches.
55
56namespace {
57
58//===----------------------------------------------------------------------===//
59// SIMT Distribution Patterns
60//===----------------------------------------------------------------------===//
61
62/// In certain cases, we may need to favor XeGPU-specific distribution patterns
63/// over generic vector distribution patterns. In such cases, we can assign
64/// higher benefits (priorities) to those patterns.
65static constexpr unsigned regularPatternBenefit = 1;
66static constexpr unsigned highPatternBenefit = 2;
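// For instance, when both an XeGPU-specific pattern and a generic vector
// distribution pattern could match the same op, registering the XeGPU pattern
// with the higher benefit makes the driver try it first. A minimal sketch
// (the generic pattern name below is hypothetical):
//
//   patterns.add<LoadNdDistribution>(ctx, /*benefit=*/highPatternBenefit);
//   patterns.add<SomeGenericVectorPattern>(ctx,
//                                          /*benefit=*/regularPatternBenefit);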
67
68/// Helper function to get distributed vector type for a source vector type
69/// according to the lane_layout. We simply divide each dimension of tensor
70/// descriptor shape by corresponding lane_layout dimension. If
71/// array_length > 1, that is appended to the front of the distributed shape.
72/// NOTE: This is the vector type that will be returned by the
73/// gpu.warp_execute_on_lane0 op.
74///
75/// Examples:
76/// | original vector shape | lane_layout | distributed vector shape |
77/// |-----------------------|-------------|--------------------------|
78/// | 32x16 | [1, 16] | 32x1 |
79/// | 32x16 | [2, 8] | 16x2 |
80/// | 2x32x16 | [1, 16] | 2x32x1 |
81static FailureOr<VectorType>
82getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
83 VectorType originalType) {
84 if (!layout)
85 return failure();
86 assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
87 "Expecting a valid layout.");
88 SmallVector<int64_t> effectiveLaneLayout =
89 layout.getEffectiveLaneLayoutAsInt();
90 assert(static_cast<size_t>(originalType.getRank()) >=
91 effectiveLaneLayout.size() &&
92 "Rank of the original vector type should be greater or equal to the "
93 "size of the lane layout to distribute the vector type.");
94 SmallVector<int64_t> distributedShape(originalType.getShape());
95 // Only distribute the last `laneLayout.size()` dimensions. The remaining
96 // dimensions are not distributed.
97 unsigned distributionStart =
98 originalType.getRank() - effectiveLaneLayout.size();
99 for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
100 if (i < distributionStart)
101 continue;
102 // Check if the dimension can be distributed evenly.
103 if (dim % effectiveLaneLayout[i - distributionStart] != 0)
104 return failure();
105 distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
106 }
107 return VectorType::get(distributedShape, originalType.getElementType());
108}
109
110/// Helper function to resolve types if the distributed type out of
111/// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
112/// Example 1:
113/// distributed type: vector<8x1xf32>
114/// expected type: vector<8xf32>
115/// resolved using,
116/// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
117/// Example 2:
118/// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
119/// expected type: xegpu.tensor_desc<8x16xf32>
120/// resolved using,
121/// %0 = unrealized_conversion_cast %1 :
122/// xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
123/// xegpu.tensor_desc<8x16xf32>
124template <typename T>
125static Value resolveDistributedTy(Value orig, T expected,
126 PatternRewriter &rewriter) {
127 // If orig and expected types are the same, return orig.
128 if (orig.getType() == expected)
129 return orig;
130 // If orig is a vector type, create a shape cast op to reconcile the types.
131 if (isa<VectorType>(orig.getType())) {
132 auto castOp =
133 vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
134 return castOp.getResult();
135 }
136 // If orig is a tensor descriptor type, create an unrealized conversion cast
137 // op to reconcile the types.
138 if (isa<xegpu::TensorDescType>(orig.getType())) {
139 auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
140 expected, orig);
141 castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
142 return castOp.getResult(0);
143 }
144 llvm_unreachable("Unsupported type for reconciliation");
145 return orig;
146}
147
148/// Helper function to check if the layout is packed. Layout is packed if it is
149/// 2D and lane_data[0] != 1 (data packed from col dimension).
150/// TODO: Move to target info.
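/// For example, a layout with lane_data = [2, 1] on 16-bit elements means each
/// lane holds two consecutive rows packed into one 32-bit slot (VNNI-style),
/// so the corresponding load must carry the packed attribute.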
151static bool requirePacked(const xegpu::LayoutAttr layout) {
152 if (!layout)
153 return false;
154 auto laneData = layout.getEffectiveLaneDataAsInt();
155 if (laneData.size() != 2)
156 return false;
157 return laneData[0] != 1;
158}
159
160/// Helper function to check if the layout requires a transpose effect.
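/// For example, with a 16-lane subgroup, a lane_layout of [16, 1] distributes
/// rows of the tile across lanes instead of the usual column-wise
/// distribution, which is realized by issuing the block load with the
/// transpose attribute.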
161static bool requireTranspose(const xegpu::LayoutAttr layout,
162 const xegpu::uArch::uArch *uArch) {
163 // Return false for unsupported targets.
164 // TODO: Add more support or move to target info.
165 if (!uArch->getName().equals_insensitive("pvc") &&
166 !uArch->getName().equals_insensitive("bmg"))
167 return false;
168 if (!layout)
169 return false;
170 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
171 if (laneLayout.size() != 2)
172 return false;
173 return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
174}
175
176/// Given a vector type and its distributed vector type, return the list of
177/// dimensions that are distributed.
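/// For example, for an original type vector<32x16xf32> distributed to
/// vector<32x1xf32>, only dimension 1 changes in size, so the result is {1}.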
178static SmallVector<int64_t> getDistributedDims(VectorType originalType,
179 VectorType distributedType) {
180 assert(originalType.getRank() == distributedType.getRank() &&
181 "sequential and distributed vector types must have the same rank");
182 SmallVector<int64_t> distributedDims;
183 for (int64_t i = 0; i < originalType.getRank(); ++i) {
184 if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
185 distributedDims.push_back(i);
186 }
187 }
188 return distributedDims;
189}
190
191/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
192/// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
193/// contained within a WarpExecuteOnLane0Op.
194/// Example:
195///
196/// ```
197/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
198/// ...
199/// ...
200/// gpu.return %result: vector<8x16xf32>
201/// }
202/// ```
203/// To
204/// ```
205/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
206/// %laneid = gpu.lane_id : index
207/// %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
208/// ...
209/// ...
210/// gpu.yield %result: vector<8x16xf32>
211/// }
212/// return %0
213/// }
214struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
215 using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
216 LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
217 PatternRewriter &rewriter) const override {
218 auto uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or(""));
219 if (!uArch)
220 return rewriter.notifyMatchFailure(
221 gpuFuncOp, "Subgroup distribution requires a target attribute attached "
222 "to set the warp size");
223 // If the function only contains a single void return, skip.
224 if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
225 return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
226 }))
227 return failure();
228 // If the function body has already been moved inside a warp_execute_on_lane0, skip.
229 if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
230 return isa<gpu::WarpExecuteOnLane0Op>(op);
231 }))
232 return failure();
233 // Create a new function with the same signature and same attributes.
234 SmallVector<Type> workgroupAttributionsTypes =
235 llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
236 [](BlockArgument arg) { return arg.getType(); });
237 SmallVector<Type> privateAttributionsTypes =
238 llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
239 [](BlockArgument arg) { return arg.getType(); });
240 auto newGpuFunc = gpu::GPUFuncOp::create(
241 rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
242 gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
243 privateAttributionsTypes);
244 newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
245 // Create a WarpExecuteOnLane0Op with same arguments and results as the
246 // original gpuFuncOp.
247 rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
248 auto laneId = gpu::LaneIdOp::create(
249 rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(),
250 /** upperBound = **/ mlir::IntegerAttr());
251 ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
252 auto warpOp = gpu::WarpExecuteOnLane0Op::create(
253 rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
254 uArch->getSubgroupSize(), newGpuFunc.getArguments(),
255 newGpuFunc.getArgumentTypes());
256 Block &warpBodyBlock = warpOp.getBodyRegion().front();
257 // Replace the ReturnOp of the original gpu function with a YieldOp.
258 auto origReturnOp =
259 cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
260 rewriter.setInsertionPointAfter(origReturnOp);
261 gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
262 origReturnOp.getOperands());
263 rewriter.eraseOp(origReturnOp);
264 // Move the original function body to the WarpExecuteOnLane0Op body.
265 rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
266 warpOp.getBodyRegion().begin());
267 rewriter.eraseBlock(&warpBodyBlock);
268 // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
269 rewriter.setInsertionPointAfter(warpOp);
270 gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
271 rewriter.replaceOp(gpuFuncOp, newGpuFunc);
272 return success();
273 }
274};
275
276/// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing
277/// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
278/// still contain the original op that will not be used by the yield op (and
279/// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
280/// arguments. Tensor descriptor shape is not distributed because it is a
281/// uniform value across all work items within the subgroup. However, the
282/// layout information is dropped in the new tensor descriptor type.
283///
284/// Example:
285///
286/// ```
287/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
288/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
289/// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
290/// ...
291/// %td = xegpu.create_nd_tdesc %arg0
292/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
293/// vector.yield %td
294/// }
295/// ```
296/// To
297/// ```
298/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
299/// ...
300/// %dead = xegpu.create_nd_tdesc %arg0
301/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
302/// vector.yield %arg0, %dead
303/// }
304/// %td = xegpu.create_nd_tdesc %r#0: memref<4x8xf32>
305/// -> !xegpu.tensor_desc<4x8xf32>
306///
307/// ```
308struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
309 using gpu::WarpDistributionPattern::WarpDistributionPattern;
310 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
311 PatternRewriter &rewriter) const override {
312 OpOperand *operand =
313 getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
314 if (!operand)
315 return rewriter.notifyMatchFailure(
316 warpOp, "warp result is not a xegpu::CreateNdDesc op");
317 auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
318 unsigned operandIdx = operand->getOperandNumber();
319
320 xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
321 if (!layout)
322 return rewriter.notifyMatchFailure(
323 descOp, "the tensor descriptor lacks layout attribute");
324 // CreateNdOp must not have offsets.
325 if (descOp.getMixedOffsets().size())
326 return rewriter.notifyMatchFailure(
327 descOp, "xegpu::CreateNdDescOp must not have offsets");
328
329 SmallVector<size_t> newRetIndices;
330 rewriter.setInsertionPoint(warpOp);
331 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
332 rewriter, warpOp, /* new yielded values = */ descOp->getOperands(),
333 /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);
334
335 SmallVector<Value> newDescOperands = llvm::map_to_vector(
336 newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
337 rewriter.setInsertionPointAfter(newWarpOp);
338 xegpu::TensorDescType distributedTensorDescTy =
339 descOp.getType().dropLayouts(); // Distributed tensor descriptor type
340 // does not contain layout info.
341 Value newDescOp = xegpu::CreateNdDescOp::create(
342 rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
343 descOp->getAttrs());
344
345 Value distributedVal = newWarpOp.getResult(operandIdx);
346 // Resolve the distributed type to the expected type.
347 newDescOp =
348 resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
349 rewriter.replaceAllUsesWith(distributedVal, newDescOp);
350 return success();
351 }
352};
353
354/// Distribute a store_nd op at the end of enclosing
355/// `gpu.warp_execute_on_lane_0`. If arguments of the store are passed through
356/// the warp op interface, they are propagated as returned values. The source
357/// vector is distributed based on the lane layout. Appropriate cast ops are
358/// inserted if the distributed types do not match the expected xegpu SIMT types.
359///
360/// Example:
361///
362/// ```
363/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
364/// gpu.warp_execute_on_lane_0(%laneid) -> () {
365/// ...
366/// xegpu.store_nd %arg0, %arg1 [%x, %y]: vector<4x8xf32>,
367/// !xegpu.tensor_desc<4x8xf32, #layout0>
368/// }
369/// ```
370/// To
371/// ```
372/// %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
373/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
374/// ...
375/// gpu.yield %arg0, %arg1, %x, %y: vector<4x8xf32>,
376/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index
377/// }
378/// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
379/// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
380/// #layout0>
381/// -> !xegpu.tensor_desc<4x8xf32>
382/// xegpu.store_nd %0, %1 [%r#2, %r#3]: vector<4xf32>,
383/// !xegpu.tensor_desc<4x8xf32>
384///
385/// ```
386struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
387 using gpu::WarpDistributionPattern::WarpDistributionPattern;
388 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
389 PatternRewriter &rewriter) const override {
390 gpu::YieldOp yield = warpOp.getTerminator();
391 Operation *lastNode = yield->getPrevNode();
392 auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
393 if (!storeOp)
394 return failure();
395
396 SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
397 // Expecting offsets to be present.
398 if (offsets.empty())
399 return rewriter.notifyMatchFailure(storeOp,
400 "the store op must have offsets");
401 SmallVector<Value> offsetsAsValues =
402 vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
403 SmallVector<Type> offsetTypes = llvm::to_vector(
404 llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
405 xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
406 xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
407 if (!layout)
408 return rewriter.notifyMatchFailure(
409 storeOp, "the source tensor descriptor lacks layout attribute");
410
411 FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
412 getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
413 if (failed(distributedTypeByWarpOpOrFailure))
414 return rewriter.notifyMatchFailure(storeOp,
415 "Failed to distribute the type");
416 VectorType distributedTypeByWarpOp =
417 distributedTypeByWarpOpOrFailure.value();
418
419 SmallVector<size_t> newRetIndices;
420 SmallVector<Value> newYieldedValues = {storeOp.getValue(),
421 storeOp.getTensorDesc()};
422 SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
423 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
424 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
425 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
426 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
427 // Create a new store op outside the warp op with the distributed vector
428 // type. Tensor descriptor is not distributed.
429 rewriter.setInsertionPointAfter(newWarpOp);
430 SmallVector<Value> newStoreOperands;
431
432 // For the value operand, there can be a mismatch between the vector type
433 // distributed by the warp op and (xegpu-specific) distributed type
434 // supported by the store op. Type mismatch must be resolved using
435 // appropriate cast op.
436 FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
437 xegpu::getDistributedVectorType(storeOp.getTensorDescType());
438 if (failed(storeNdDistributedValueTyOrFailure))
439 return rewriter.notifyMatchFailure(
440 storeOp, "Failed to get distributed vector type for the store op");
441 newStoreOperands.push_back(resolveDistributedTy(
442 newWarpOp.getResult(newRetIndices[0]),
443 storeNdDistributedValueTyOrFailure.value(), rewriter));
444 // For the tensor descriptor operand, the layout attribute is dropped after
445 // distribution. Types need to be resolved in this case as well.
446 xegpu::TensorDescType distributedTensorDescTy =
447 storeOp.getTensorDescType().dropLayouts();
448 newStoreOperands.push_back(
449 resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
450 distributedTensorDescTy, rewriter));
451 // Collect offsets.
452 for (size_t i = 2; i < newRetIndices.size(); ++i)
453 newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
454
455 auto newStoreOp =
456 xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
457 newStoreOperands, storeOp->getAttrs());
458 xegpu::removeLayoutAttrs(newStoreOp);
459 rewriter.eraseOp(storeOp);
460 return success();
461 }
462};
463
464/// Distribute a load_nd op feeding into vector.yield op for the enclosing
465/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
466/// The warp op will still contain the original op that will not be used by
467/// the yield op (and should be cleaned up later). The yield op will
468/// bypass the load's arguments. Only the loaded vector is distributed
469/// according to the lane layout; the tensor descriptor type is not
470/// distributed. Appropriate cast ops are inserted if the distributed types do
471/// not match the expected xegpu SIMT types.
472///
473/// Example:
474///
475/// ```
476/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
477/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
478/// (vector<4x1xf32>) {
479/// ...
480/// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
481/// ->
482/// vector<4x8xf32>
483/// gpu.yield %ld
484/// }
485/// ```
486/// To
487/// ```
488/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
489/// !xegpu.tensor_desc<4x8xf32, #layout0>) {
490/// ...
491/// %dead = xegpu.load_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0> -> vector<4x8xf32>
492/// gpu.yield %dead, %arg0
493/// }
494/// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
495/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
496/// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
497/// %2 = vector.shape_cast %1 : vector<4xf32> to vector<4x1xf32>
498///
499/// ```
500struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
501 using gpu::WarpDistributionPattern::WarpDistributionPattern;
502 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
503 PatternRewriter &rewriter) const override {
504 OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
505 if (!isa<xegpu::LoadNdOp>(op))
506 return false;
507 // Make sure the same load op is the last operation in the warp op body.
508 // This ensures that the load op is not sunk earlier, which could violate
509 // barrier synchronization.
510 gpu::YieldOp yield = warpOp.getTerminator();
511 return yield->getPrevNode() == op;
512 });
513
514 if (!operand)
515 return rewriter.notifyMatchFailure(
516 warpOp, "warp result is not a xegpu::LoadNd op");
517
518 auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
519 auto uArch = getUArch(xegpu::getChipStr(loadOp).value_or(""));
520 if (!uArch)
521 return rewriter.notifyMatchFailure(
522 loadOp, "xegpu::LoadNdOp requires a target attribute attached to "
523 "determine the transpose "
524 "requirement");
525 // Chip information is required to decide if the layout requires transpose
526 // effect.
527 // Expecting offsets to be present.
528 SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
529 if (offsets.empty())
530 return rewriter.notifyMatchFailure(loadOp,
531 "the load op must have offsets");
532 SmallVector<Value> offsetsAsValues =
533 vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
534 SmallVector<Type> offsetTypes = llvm::to_vector(
535 llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
536
537 xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
538 xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
539 if (!layout)
540 return rewriter.notifyMatchFailure(
541 loadOp, "the source tensor descriptor lacks layout attribute");
542
543 unsigned operandIdx = operand->getOperandNumber();
544 VectorType distributedTypeByWarpOp =
545 cast<VectorType>(warpOp.getResult(operandIdx).getType());
546
547 SmallVector<size_t> newRetIndices;
548 SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
549 SmallVector<Type> newYieldedTypes = {tensorDescTy};
550 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
551 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
552 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
553 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
554
555 // Create a new load op outside the warp op with the distributed vector
556 // type.
557 rewriter.setInsertionPointAfter(newWarpOp);
558 FailureOr<VectorType> loadNdDistValueTyOrFailure =
559 xegpu::getDistributedVectorType(loadOp.getTensorDescType());
560 if (failed(loadNdDistValueTyOrFailure))
561 return rewriter.notifyMatchFailure(
562 loadOp, "Failed to get distributed vector type for the load op");
563 xegpu::TensorDescType distributedTensorDescTy =
564 loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
565 // descriptor type does not
566 // contain layout info.
567 SmallVector<Value> newLoadOperands{
568 resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
569 distributedTensorDescTy, rewriter)};
570 // Collect offsets.
571 for (size_t i = 1; i < newRetIndices.size(); ++i)
572 newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
573 auto newLoadOp = xegpu::LoadNdOp::create(
574 rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
575 newLoadOperands, loadOp->getAttrs());
576 xegpu::removeLayoutAttrs(newLoadOp);
577 // Set the packed attribute if the layout requires it.
578 newLoadOp.setPacked(requirePacked(layout));
579 // Set the transpose attribute if the layout requires it.
580 if (requireTranspose(layout, uArch))
581 newLoadOp.setTranspose(
582 DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
583 Value distributedVal = newWarpOp.getResult(operandIdx);
584 // There can be a conflict between the vector type distributed by the
585 // warp op and (xegpu-specific) distributed type supported by the load
586 // op. Resolve these mismatches by inserting a cast.
587 Value tyResolvedVal = resolveDistributedTy(
588 newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
589 rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
590 return success();
591 }
592};
593
594/// Distribute a dpas op feeding into vector.yield op for the enclosing
595/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
596/// The warp op will still contain the original op that will not be used by
597/// the yield op (and should be cleaned up later). The yield op will
598/// bypass the dpas's arguments. Appropriate cast ops are inserted if the
599/// distributed types do not match the expected xegpu SIMT types.
600/// Example:
601/// ```
602/// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
603/// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
604/// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
605/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
606/// (vector<8x1xf32>) {
607/// ...
608/// %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
609/// vector<8x16xf32>
610/// gpu.yield %dpas
611/// }
612/// ```
613/// To
614/// ```
615/// %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
616/// vector<8x1xf16>, vector<16x1xf16>) {
617/// ...
618/// %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
619/// -> vector<8x16xf32>
620/// gpu.yield %dead, %arg0, %arg1
621/// }
622/// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
623/// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
624/// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
625/// vector<8xf32>
626/// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
627/// ```
628struct DpasDistribution final : public gpu::WarpDistributionPattern {
629 using gpu::WarpDistributionPattern::WarpDistributionPattern;
630 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
631 PatternRewriter &rewriter) const override {
632 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
633 if (!operand)
634 return rewriter.notifyMatchFailure(warpOp,
635 "warp result is not a xegpu::Dpas op");
636
637 auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
638 unsigned operandIdx = operand->getOperandNumber();
639 std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
640 std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
641 std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));
642
643 xegpu::LayoutAttr layoutA =
644 dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
645 xegpu::LayoutAttr layoutB =
646 dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
647 xegpu::LayoutAttr layoutOut =
648 dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
649 if (!layoutA || !layoutB || !layoutOut)
650 return rewriter.notifyMatchFailure(
651 dpasOp,
652 "the xegpu::Dpas op lacks layout attribute for A, B or output");
653
654 FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
655 getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
656 FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
657 getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
658 FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
659 getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
660 if (failed(distLhsTypeByWarpOpOrFailure) ||
661 failed(distRhsTypeByWarpOpOrFailure) ||
662 failed(distResultTypeByWarpOpOrFailure))
663 return rewriter.notifyMatchFailure(
664 dpasOp,
665 "Failed to distribute the A, B or output types in xegpu::Dpas op");
666
667 llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
668 dpasOp.getRhs()};
669 llvm::SmallVector<Type, 3> newYieldTypes{
670 distLhsTypeByWarpOpOrFailure.value(),
671 distRhsTypeByWarpOpOrFailure.value()};
672 // Dpas acc operand is optional.
673 if (dpasOp.getAcc()) {
674 newYieldValues.push_back(dpasOp.getAcc());
675 newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
676 }
677 // Create a new warp op without the dpas.
678 SmallVector<size_t> newRetIndices;
679 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
680 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
681
682 FailureOr<VectorType> expectedDistLhsTyOrFailure =
683 xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
684 FailureOr<VectorType> expectedDistRhsTyOrFailure =
685 xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
686 FailureOr<VectorType> expectedDistResultTyOrFailure =
687 xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
688 if (failed(expectedDistLhsTyOrFailure) ||
689 failed(expectedDistRhsTyOrFailure) ||
690 failed(expectedDistResultTyOrFailure))
691 return rewriter.notifyMatchFailure(
692 dpasOp,
693 "Failed to get distributed vector type for the dpas operands.");
694 // Create a new dpas op outside the warp op.
695 rewriter.setInsertionPointAfter(newWarpOp);
696 SmallVector<Value> newDpasOperands;
697 SmallVector<VectorType> newDpasOperandExpectedTypes;
698
699 // Resolve the distributed types with the original types.
700 newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
701 newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
702 VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
703 if (dpasOp.getAcc())
704 newDpasOperandExpectedTypes.push_back(distributedResultTy);
705
706 for (unsigned i = 0; i < newRetIndices.size(); i++) {
707 newDpasOperands.push_back(
708 resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
709 newDpasOperandExpectedTypes[i], rewriter));
710 }
711 auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
712 distributedResultTy, newDpasOperands,
713 dpasOp->getAttrs());
714 xegpu::removeLayoutAttrs(newDpasOp);
715 Value distributedVal = newWarpOp.getResult(operandIdx);
716 // Resolve the output type.
717 Value typeResolved =
718 resolveDistributedTy(newDpasOp.getResult(),
719 distResultTypeByWarpOpOrFailure.value(), rewriter);
720 rewriter.replaceAllUsesWith(distributedVal, typeResolved);
721 return success();
722 }
723};
724
725/// Distribute a prefetch_nd op at the end of enclosing
726/// `gpu.warp_execute_on_lane_0`. If arguments of the prefetch are passed
727/// through the warp op interface, they are propagated as returned values.
728/// Tensor descriptor shape is not distributed because it is a uniform value
729/// across all work items within the subgroup. Appropriate cast ops are inserted
730/// if the distributed types do not match the expected xegpu SIMT types.
731///
732/// Example:
733///
734/// ```
735/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
736/// gpu.warp_execute_on_lane_0(%laneid) -> () {
737/// ...
738/// xegpu.prefetch_nd %arg0 [%x, %y] : !xegpu.tensor_desc<4x8xf32, #layout0>
739/// }
740/// ```
741/// To
742/// ```
743/// %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (
744/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
745/// gpu.yield %arg0, %x, %y: !xegpu.tensor_desc<4x8xf32, #layout0>, index,
746/// index
747/// }
748/// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
749/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
750/// xegpu.prefetch_nd %1 [%r#1, %r#2] : !xegpu.tensor_desc<4x8xf32>
751///
752/// ```
753struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
754 using gpu::WarpDistributionPattern::WarpDistributionPattern;
755 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
756 PatternRewriter &rewriter) const override {
757 gpu::YieldOp yield = warpOp.getTerminator();
758 Operation *lastNode = yield->getPrevNode();
759 auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
760 if (!prefetchOp)
761 return failure();
762
763 SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
764 // PrefetchNdOp must have offsets.
765 if (offsets.empty())
766 return rewriter.notifyMatchFailure(prefetchOp,
767 "the prefetch op must have offsets");
768 SmallVector<Value> offsetsAsValues =
769 vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
770 SmallVector<Type> offsetTypes = llvm::to_vector(
771 llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
772
773 xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
774 if (!layout)
775 return rewriter.notifyMatchFailure(
776 prefetchOp, "the source tensor descriptor lacks layout attribute");
777
778 SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
779 SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
780 newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
781 newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
782 SmallVector<size_t> newRetIndices;
783 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
784 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
785 // Create a new prefetch op outside the warp op with updated tensor
786 // descriptor type. The source tensor descriptor requires type resolution.
787 xegpu::TensorDescType newTensorDescTy =
788 prefetchOp.getTensorDescType().dropLayouts();
789 rewriter.setInsertionPointAfter(newWarpOp);
790 SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
791 newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
792 // Collect offsets.
793 for (size_t i = 1; i < newRetIndices.size(); ++i)
794 newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
795 xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
796 newPrefetchOperands, prefetchOp->getAttrs());
797 xegpu::removeLayoutAttrs(prefetchOp);
798 rewriter.eraseOp(prefetchOp);
799 return success();
800 }
801};
802
803/// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
804/// region. This will simply move the barrier op outside of the warp op.
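/// Example (schematic):
/// ```
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
///   ...
///   gpu.barrier
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
///   ...
/// }
/// gpu.barrier
/// ```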
805struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
806 using gpu::WarpDistributionPattern::WarpDistributionPattern;
807 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
808 PatternRewriter &rewriter) const override {
809 gpu::YieldOp yield = warpOp.getTerminator();
810 Operation *lastNode = yield->getPrevNode();
811 // The last node must be a gpu::BarrierOp.
812 auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
813 if (!barrierOp)
814 return failure();
815 // Move the barrier op outside of the warp op.
816 rewriter.setInsertionPointAfter(warpOp);
817 gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
818 barrierOp->getResultTypes(),
819 barrierOp->getOperands(), barrierOp->getAttrs());
820 rewriter.eraseOp(barrierOp);
821 return success();
822 }
823};
824
825/// Distribute a scattered store op. The offsets argument is required.
826/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
827/// The layouts are fixed and implicit: one offset/mask per lane.
828/// The pass changes the offset/mask vector shapes to a
829/// single-element vector; **it is assumed that their producers will also be
830/// distributed**. The payload vector also has a fixed distribution:
831/// no chunk size -> vector of one element.
832/// chunk size -> vector of the innermost dimension of the SG-payload.
833/// Example 1 (no chunk size):
834/// %mask = producer_op : vector<16xi1>
835/// %offset = producer_op : vector<16xindex>
836/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
837/// memref<256xf16>, vector<16xindex>, vector<16xi1>
838/// To
839/// %mask = producer_op : vector<1xi1>
840/// %offset = producer_op : vector<1xindex>
841/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
842/// memref<256xf16>, vector<1xindex>, vector<1xi1>
843/// Example 2 (chunk size, same mask and offsets):
844/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
845/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
846/// To
847/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
848/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
849struct StoreDistribution final : public gpu::WarpDistributionPattern {
850 using gpu::WarpDistributionPattern::WarpDistributionPattern;
851 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
852 PatternRewriter &rewriter) const override {
853 Operation *lastNode = warpOp.getTerminator()->getPrevNode();
854 auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
855 if (!storeScatterOp)
856 return failure();
857 auto offsets = storeScatterOp.getOffsets();
858 if (!offsets || !isa<VectorType>(offsets.getType()))
859 return rewriter.notifyMatchFailure(
860 storeScatterOp, "Store op must have a vector of offsets argument");
861 VectorType offsetsTy = cast<VectorType>(offsets.getType());
862 VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
863 if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
864 return rewriter.notifyMatchFailure(storeScatterOp,
865 "Expected 1D offsets and mask vector");
866 VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
867 if (storeVecTy.getRank() > 2)
868 return rewriter.notifyMatchFailure(
869 storeScatterOp, "Expected at most 2D result at SG level");
870
871 std::string layoutPayloadName =
872 xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
873 std::string layoutOffsetsName =
874 xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
875 std::string layoutMaskName =
876 xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
877
878 xegpu::LayoutAttr layoutPayload =
879 storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
880 xegpu::LayoutAttr layoutOffsets =
881 storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
882 xegpu::LayoutAttr layoutMask =
883 storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
884
885 FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
886 getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
887 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
888 getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
889 FailureOr<VectorType> distMaskByWarpOpOrFailure =
890 getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
891 if (failed(distStoreVecByWarpOpOrFailure) ||
892 failed(distOffsetsByWarpOpOrFailure) ||
893 failed(distMaskByWarpOpOrFailure)) {
894 return rewriter.notifyMatchFailure(
895 storeScatterOp,
896 "Some vector operands have no layouts, using defaults instead.");
897 }
898 // Distributed store payload type according to the lane layout.
899 VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value();
900 // Expected distributed payload type is always 1D.
901 VectorType expectedPayloadTy =
902 VectorType::get({distPayloadTyByWarpOp.getNumElements()},
903 distPayloadTyByWarpOp.getElementType());
904
905 SmallVector<size_t> newRetIndices;
906 SmallVector<Value> operands = storeScatterOp->getOperands();
907 SmallVector<Type> operandTypesToYield = {
908 distPayloadTyByWarpOp, operands[1].getType(),
909 distOffsetsByWarpOpOrFailure.value(),
910 distMaskByWarpOpOrFailure.value()};
911
912 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
913 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
914 SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
915 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
916 // The payload operand may need type adjustment due to a mismatch between
917 // the warp-distributed type and the expected SIMT type.
918 rewriter.setInsertionPointAfter(newWarpOp);
919 newStoreScatterOpOperands[0] = resolveDistributedTy(
920 newStoreScatterOpOperands[0], expectedPayloadTy, rewriter);
921 xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
922 rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
923 storeScatterOp->getAttrs());
924 xegpu::removeLayoutAttrs(newOp);
925 rewriter.eraseOp(storeScatterOp);
926 return success();
927 }
928};
929
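/// Compute the per-lane coordinates for a load_matrix/store_matrix payload by
/// adding the layout-derived lane offsets to the original subgroup offsets.
/// For example, for a payload of shape 8x16 and lane_layout = [1, 16], the
/// layout-computed per-lane coordinate is [0, %laneid], so original offsets
/// [%x, %y] become [%x + 0, %y + %laneid] (schematic; the exact offsets depend
/// on the layout attribute).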
930static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
931 PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
932 Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
933 SmallVector<Value> newCoords;
934 auto maybeCoords =
935 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
936 if (failed(maybeCoords))
937 return {};
938 assert(maybeCoords.value().size() == 1 &&
939 "Expected one set of distributed offsets");
940 SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
941 rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
942 getAsOpFoldResult(origOffsets));
943 newCoords = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
944 return newCoords;
945}
946
947/// Pattern for distributing xegpu::LoadMatrixOp.
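/// The payload vector is distributed according to the op's layout attribute,
/// the mem_desc operand is passed through unchanged, and, unless
/// subgroup_block_io is set, the per-lane offsets are recomputed outside the
/// warp op via computeDistributedCoordinatesForMatrixOp. Schematically (types
/// elided):
///   %ld = xegpu.load_matrix %mdesc[%x, %y] ... -> vector<8x16xf32>
/// becomes, after the warp op,
///   %ld = xegpu.load_matrix %mdesc[%x', %y'] ... -> vector<8x1xf32>
/// where [%x', %y'] are the lane-adjusted coordinates.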
948struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
949 using gpu::WarpDistributionPattern::WarpDistributionPattern;
950 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
951 PatternRewriter &rewriter) const override {
952 gpu::YieldOp yield = warpOp.getTerminator();
953 Operation *lastNode = yield->getPrevNode();
954 auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
955 if (!matrixOp)
956 return failure();
957
958 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
959 return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
960 });
961 if (!producedByLastLoad)
962 return rewriter.notifyMatchFailure(
963 warpOp, "The last op is not xegpu::LoadMatrixOp");
964 const int operandIdx = producedByLastLoad->getOperandNumber();
965
966 VectorType sgPayloadTy =
967 dyn_cast<VectorType>(matrixOp.getResult().getType());
968 VectorType warpResultTy =
969 cast<VectorType>(warpOp.getResult(operandIdx).getType());
970 if (!sgPayloadTy)
971 return rewriter.notifyMatchFailure(
972 matrixOp, "the matrix op payload must be a vector type");
973
974 auto loc = matrixOp.getLoc();
975 auto offsets = matrixOp.getMixedOffsets();
976 if (offsets.empty())
977 return rewriter.notifyMatchFailure(matrixOp,
978 "the load op must have offsets");
979 SmallVector<Value> offsetsAsValues =
980 vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
981
982 auto layout = matrixOp.getLayoutAttr();
983 if (!layout)
984 return rewriter.notifyMatchFailure(
985 matrixOp, "the matrix operation lacks layout attribute");
986
987 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
988 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
989 if (failed(distPayloadByWarpOpOrFailure))
990 return rewriter.notifyMatchFailure(
991 matrixOp, "Failed to distribute matrix op payload based on layout.");
992
993 SmallVector<Value> operands = {matrixOp.getMemDesc()};
994 const unsigned offsetsStartIdx = operands.size();
995 operands.append(offsetsAsValues);
996
997 SmallVector<Type> operandTypes = llvm::to_vector(
998 llvm::map_range(operands, [](Value v) { return v.getType(); }));
999
1000 SmallVector<size_t> newRetIndices;
1001 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1002 rewriter, warpOp, operands, operandTypes, newRetIndices);
1003 SmallVector<Value> newOperands = llvm::map_to_vector(
1004 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
1005
1006 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
1007 ShapedType::kDynamic);
1008 DenseI64ArrayAttr newConstOffsetsAttr =
1009 rewriter.getDenseI64ArrayAttr(newConstOffsets);
1010 ValueRange currentOffsets =
1011 ValueRange(newOperands).drop_front(offsetsStartIdx);
1012
1013 SmallVector<Value> newCoords = currentOffsets;
1014 rewriter.setInsertionPointAfter(newWarpOp);
1015
1016 if (!matrixOp.getSubgroupBlockIoAttr()) {
1017 newCoords = computeDistributedCoordinatesForMatrixOp(
1018 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
1019 currentOffsets);
1020 }
1021 xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
1022 rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
1023 newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
1024 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
1025 // Resolve the output type and replace all uses.
1026 rewriter.replaceAllUsesWith(
1027 newWarpOp.getResult(operandIdx),
1028 resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
1029 return success();
1030 }
1031};
1032
1033/// Pattern for distributing xegpu::StoreMatrixOp.
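/// This mirrors LoadMatrixDistribution: the payload operand is yielded from
/// the warp op with its lane-distributed type, the mem_desc operand is passed
/// through unchanged, and the per-lane coordinates are recomputed outside the
/// warp op unless subgroup_block_io is set.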
1034struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
1035 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1036 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1037 PatternRewriter &rewriter) const override {
1038 gpu::YieldOp yield = warpOp.getTerminator();
1039 Operation *lastNode = yield->getPrevNode();
1040 auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
1041 if (!matrixOp)
1042 return failure();
1043
1044 VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
1045 if (!sgPayloadTy)
1046 return rewriter.notifyMatchFailure(
1047 matrixOp, "the matrix op payload must be a vector type");
1048
1049 auto loc = matrixOp.getLoc();
1050 auto offsets = matrixOp.getMixedOffsets();
1051 if (offsets.empty())
1052 return rewriter.notifyMatchFailure(matrixOp,
1053 "the store op must have offsets");
1054 SmallVector<Value> offsetsAsValues =
1055 vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
1056
1057 auto layout = matrixOp.getLayoutAttr();
1058 if (!layout)
1059 return rewriter.notifyMatchFailure(
1060 matrixOp, "the matrix operation lacks layout attribute");
1061
1062 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
1063 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
1064 if (failed(distPayloadByWarpOpOrFailure))
1065 return rewriter.notifyMatchFailure(
1066 matrixOp, "Failed to distribute matrix op payload based on layout.");
1067
1068 SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
1069 const unsigned offsetsStartIdx = operands.size();
1070 operands.append(offsetsAsValues);
1071
1072 SmallVector<Type> operandTypes = llvm::to_vector(
1073 llvm::map_range(operands, [](Value v) { return v.getType(); }));
1074 operandTypes[0] = *distPayloadByWarpOpOrFailure;
1075
1076 SmallVector<size_t> newRetIndices;
1077 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1078 rewriter, warpOp, operands, operandTypes, newRetIndices);
1079 SmallVector<Value> newOperands = llvm::map_to_vector(
1080 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
1081
1082 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
1083 ShapedType::kDynamic);
1084 DenseI64ArrayAttr newConstOffsetsAttr =
1085 rewriter.getDenseI64ArrayAttr(newConstOffsets);
1086 ValueRange currentOffsets =
1087 ValueRange(newOperands).drop_front(offsetsStartIdx);
1088
1089 SmallVector<Value> newCoords = currentOffsets;
1090 rewriter.setInsertionPointAfter(newWarpOp);
1091
1092 if (!matrixOp.getSubgroupBlockIoAttr()) {
1093 newCoords = computeDistributedCoordinatesForMatrixOp(
1094 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
1095 currentOffsets);
1096 }
1097
1098 xegpu::StoreMatrixOp::create(
1099 rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
1100 ValueRange(newCoords), newConstOffsetsAttr,
1101 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
1102 rewriter.eraseOp(matrixOp);
1103 return success();
1104 }
1105};
1106
1107/// Distribute a scattered load op. The logic and requirements are the same as
1108/// for the scattered store distribution. The warpOp's payload vector is
1109/// expected to be distributed by the load's result consumer.
1110/// Example 1 (no chunk size):
1111/// %mask = producer_op : vector<16xi1>
1112/// %offset = producer_op : vector<16xindex>
1113/// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
1114/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
1115/// To
1116/// %mask = producer_op : vector<1xi1>
1117/// %offset = producer_op : vector<1xindex>
1118/// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
1119/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
1120/// Example 2 (chunk size, same mask and offsets):
1121/// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
1122/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
1123/// To
1124/// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
1125/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
1126struct LoadDistribution final : public gpu::WarpDistributionPattern {
1127 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1128 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1129 PatternRewriter &rewriter) const override {
1130 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
1131 // Check whether the yield operand was produced by the *last* scattered
1132 // load op, to avoid sinking it before barriers (maintain memory order).
1133 return isa<xegpu::LoadGatherOp>(op) &&
1134 warpOp.getTerminator()->getPrevNode() == op;
1135 });
1136 if (!producedByLastLoad)
1137 return rewriter.notifyMatchFailure(
1138 warpOp, "The last op is not xegpu::LoadGatherOp");
1139
1140 auto loadGatherOp =
1141 producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
1142 auto offsets = loadGatherOp.getOffsets();
1143 if (!offsets || !isa<VectorType>(offsets.getType()) ||
1144 !isa<VectorType>(loadGatherOp.getMask().getType()))
1145 return rewriter.notifyMatchFailure(
1146 loadGatherOp,
1147 "Load op must have vector arguments for offsets and mask");
1148 VectorType offsetsTy = cast<VectorType>(offsets.getType());
1149 VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
1150 if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
1151 return rewriter.notifyMatchFailure(loadGatherOp,
1152 "Expected 1D offsets and mask vector");
1153 // Assume offset and mask producers will be distributed as well.
1154 std::string layoutOffsetsName =
1155 xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
1156 std::string layoutMaskName =
1157 xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
1158
1159 xegpu::LayoutAttr layoutOffsets =
1160 loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
1161 xegpu::LayoutAttr layoutMask =
1162 loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
1163
1164 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
1165 getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
1166 FailureOr<VectorType> distMaskByWarpOpOrFailure =
1167 getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
1168 if (failed(distOffsetsByWarpOpOrFailure) ||
1169 failed(distMaskByWarpOpOrFailure)) {
1170 return rewriter.notifyMatchFailure(
1171 loadGatherOp,
1172 "Some vector operands have no layouts, using defaults instead.");
1173 }
1174
1175 SmallVector<size_t> newRetIndices;
1176 SmallVector<Value> operands = loadGatherOp->getOperands();
1177 SmallVector<Type> operandTypesToYield = {
1178 operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
1179 distMaskByWarpOpOrFailure.value()};
1180
1181 const unsigned operandIdx = producedByLastLoad->getOperandNumber();
1182 VectorType distResultTy =
1183 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1184 // Distributed load op will always be 1D.
1185 VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
1186 distResultTy.getElementType());
1187
1188 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1189 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
1190
1191 SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
1192 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
1193
1194 rewriter.setInsertionPointAfter(newWarpOp);
1195 xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
1196 rewriter, newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
1197 loadGatherOp->getAttrs());
1198 xegpu::removeLayoutAttrs(newOp);
1199 Value distributedVal = newWarpOp.getResult(operandIdx);
1200 // Resolve the output type and replace all uses.
1201 rewriter.replaceAllUsesWith(
1202 distributedVal,
1203 resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
1204 return success();
1205 }
1206};
1207
1208/// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
1209/// VectorReductionOps. We also insert layouts for the newly created ops.
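/// For example, reducing a vector<16x2xf32> along dimension 0 produces two
/// vector<16x1xf32> slices; each is shape-cast to vector<16xf32>, reduced with
/// vector.reduction against the matching accumulator element, and the scalar
/// result is inserted into the final vector<2xf32>.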
1210static Value lowerToVectorReductions(TypedValue<VectorType> src,
1211 TypedValue<VectorType> acc,
1212 vector::CombiningKind kind,
1213 int64_t reductionDim, Location loc,
1214 PatternRewriter &rewriter) {
1215 // Expecting a 2D source vector.
1216 assert(src.getType().getRank() == 2 && "expected a 2D source vector");
1217 VectorType sourceType = src.getType();
1218 int64_t sourceH = sourceType.getShape()[0];
1219 int64_t sourceW = sourceType.getShape()[1];
1220 int nSlices = (reductionDim == 0) ? sourceW : sourceH;
1221 // Create a constant vector to hold the result of the reduction.
1222 TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
1223 Value reductionResult = arith::ConstantOp::create(
1224 rewriter, loc, acc.getType(),
1225 DenseElementsAttr::get(acc.getType(), zeroAttr));
1226 // Reduction result should have the same layout as the accumulator.
1227 xegpu::setDistributeLayoutAttr(cast<OpResult>(reductionResult),
1228 xegpu::getDistributeLayoutAttr(acc));
1229 // For each slice of the source, extract the slice vector, do a reduction,
1230 // and insert the reduced value back into the result vector.
1231 for (int i = 0; i < nSlices; ++i) {
1232 SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
1233 if (reductionDim == 1) {
1234 sliceOffsets = {i, 0};
1235 sliceSizes = {1, sourceW};
1236 } else {
1237 sliceOffsets = {0, i};
1238 sliceSizes = {sourceH, 1};
1239 }
1240 vector::ExtractStridedSliceOp extractOp =
1241 vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
1242 sliceSizes, {1, 1});
1243 int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
1244 vector::ShapeCastOp slice = vector::ShapeCastOp::create(
1245 rewriter, loc,
1246 VectorType::get({nSliceElements}, sourceType.getElementType()),
1247 extractOp.getResult());
1248 // Shape cast is currently handled on the xegpu side, so layouts must be
1249 // retained during lowering. Shape cast output has the same layout as the
1250 // accumulator. Shape cast source has the same layout as the original
1251 // reduction source.
1252 // TODO: other ops generated here may also need layout attributes.
1253 xegpu::setDistributeLayoutAttr(slice->getOpOperand(0),
1254 xegpu::getDistributeLayoutAttr(src));
1255 xegpu::setDistributeLayoutAttr(slice->getOpResult(0),
1256 xegpu::getDistributeLayoutAttr(acc));
1257 // The extract and reduction ops produce scalars, so no result layout is needed.
1258 Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
1259 Value reduction = vector::ReductionOp::create(
1260 rewriter, loc, kind, slice.getResult(), accExtract);
1261 reductionResult =
1262 vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
1263 }
1264 return reductionResult;
1265}
1266
1267/// This pattern distributes the `vector.multi_reduction` operation across
1268/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
1269/// layouts for the source and accumulator vectors,
1270/// * If the reduction dimension is distributed across lanes, the reduction is
1271/// non-lane-local and the reduction is done using warp shuffles. Here we
1272/// simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in
1273/// the warp op body.
1274/// * If the reduction dimension is not distributed across lanes, the reduction
1275/// is lane-local. In this case, we yield the source and accumulator vectors
1276/// from the warp op and perform the lane-local reduction outside the warp op
1277/// using a sequence of ReductionOps.
1278/// Example 1 (Reduction is lane-local):
1279/// ```
1280/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1281/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1282/// %acc = "some_def"() : () -> (vector<32xf32>)
1283/// %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
1284/// vector<32xf32>
///     gpu.yield %1 : vector<32xf32>
1285/// }
1286/// ```
1287/// is lowered to:
1288/// ```
1289/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
1290/// vector<1xf32>) {
1291/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1292/// %acc = "some_def"() : () -> (vector<32xf32>)
1293/// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
1294/// }
1295/// %c = arith.constant dense<0.0> : vector<1xf32>
1296/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
1297/// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
1298/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
1299/// ```
1300/// Example 2 (Reduction is non-lane-local):
1301/// ```
1302/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1303/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1304/// %acc = "some_def"() : () -> (vector<2xf32>)
1305/// %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
1306/// vector<2xf32>
1307/// gpu.yield %1 : vector<2xf32>
1308/// }
1309/// ```
1310/// is lowered to:
1311/// ```
1312/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1313/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1314/// %acc = "some_def"() : () -> (vector<2xf32>)
1315/// %1 = arith.constant dense<0.0> : vector<2xf32>
1316/// %2 = vector.extract %0[0] : vector<32xf32> from vector<2x32xf32>
1317/// %3 = ("warp.reduction %2") : f32
1318/// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
1319/// ... repeat for row 1
1320/// gpu.yield %1 : vector<2xf32>
1321/// }
/// ```
1322struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
1323 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1324 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1325 PatternRewriter &rewriter) const override {
1326 OpOperand *yieldOperand =
1327 getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
1328 if (!yieldOperand)
1329 return failure();
1330 auto reductionOp =
1331 cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
1332 unsigned operandIdx = yieldOperand->getOperandNumber();
1333 VectorType sourceType = reductionOp.getSourceVectorType();
1334 // Only 2D vectors are supported.
1335 if (sourceType.getRank() != 2)
1336 return rewriter.notifyMatchFailure(warpOp,
1337 "Only 2D reductions are supported.");
1338 ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
1339 // Only 1 reduction dimension supported. This also ensures that the result
1340 // is vector type.
1341 if (reductionDims.size() != 1)
1342 return rewriter.notifyMatchFailure(
1343 warpOp, "Only 1 reduction dimension is supported.");
1344 int64_t reductionDim = reductionDims[0];
1345 VectorType distributedResultType =
1346 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1347 VectorType resultType = cast<VectorType>(reductionOp.getType());
1348 xegpu::DistributeLayoutAttr sourceLayout =
1349 xegpu::getDistributeLayoutAttr(reductionOp.getSource());
1350
1351 FailureOr<VectorType> sourceDistTypeOrFailure =
1352 getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1353 if (failed(sourceDistTypeOrFailure))
1354 return rewriter.notifyMatchFailure(
1355 warpOp, "Failed to distribute the source vector type.");
1356 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1357 // Only single dimension distribution is supported.
1358 bool dim0Distributed =
1359 sourceDistType.getShape()[0] != sourceType.getShape()[0];
1360 bool dim1Distributed =
1361 sourceDistType.getShape()[1] != sourceType.getShape()[1];
1362 if (dim0Distributed && dim1Distributed)
1363 return rewriter.notifyMatchFailure(
1364 warpOp, "Expecting source to be distributed in a single dimension.");
1365 int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
1366 if (sourceDistDim == -1)
1367 return rewriter.notifyMatchFailure(
1368 warpOp, "Expecting a distributed source vector.");
1369 bool resultDistributed =
1370 distributedResultType.getNumElements() < resultType.getNumElements();
1371 // If the lane owns all the data required for reduction (i.e. reduction is
1372 // fully parallel across lanes), then each lane owns part of the result
1373 // (i.e. result is distributed). If the reduction requires cross-lane
1374 // shuffling, then the result is shared among all lanes (broadcasted).
1375 // Therefore we expect the following cases:
1376 //
1377 // | Source vector | Reduction dim | Result vector |
1378 // |----------------------|----------------|----------------|
1379 // | dim-0 distributed | 0 | broadcasted |
1380 // | dim-0 distributed | 1 | distributed |
1381 // | dim-1 distributed | 0 | distributed |
1382 // | dim-1 distributed | 1 | broadcasted |
1383
1384 bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
1385 (sourceDistDim == 1 && reductionDim == 0);
1386 if (isReductionLaneLocal && !resultDistributed)
1387 return rewriter.notifyMatchFailure(
1388 warpOp, "Expecting a distributed result for lane-local reduction.");
1389
1390 if (!isReductionLaneLocal && resultDistributed)
1391 return rewriter.notifyMatchFailure(
1392 warpOp,
1393 "Expecting a broadcasted result for non-lane-local reduction.");
1394
1395 // Handle lane-local reduction case. In this case we fully distribute the
1396 // reduction result.
1397 if (isReductionLaneLocal) {
1398 // Yield the source and acc vectors from the WarpOp.
1399 SmallVector<size_t> newRetIndices;
1400 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1401 rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
1402 {sourceDistType, distributedResultType}, newRetIndices);
1403 rewriter.setInsertionPointAfter(newWarpOp);
1404 Value result = lowerToVectorReductions(
1405 cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
1406 cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
1407 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1408 // Replace the warp op result with the final result.
1409 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
1410 return success();
1411 }
1412 // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
1413 // of multiple ReductionOps. Actual distribution is done by the
1414 // WarpOpReduction pattern.
1415 rewriter.setInsertionPointAfter(reductionOp);
1416 Value result = lowerToVectorReductions(
1417 cast<TypedValue<VectorType>>(reductionOp.getSource()),
1418 cast<TypedValue<VectorType>>(reductionOp.getAcc()),
1419 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1420 // Replace the warp op result with the final result.
1421 rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
1422 return success();
1423 }
1424};
1425
1426/// This pattern distributes the `vector.broadcast` operation across lanes in a
1427/// warp. The pattern supports three use cases:
1428///
1429/// 1) Broadcast a low-rank vector to a high-rank vector: The layout of the
1430///    low-rank input vector must be a slice of the result layout. If the
1431///    distributed source and target vector types are identical, this lowers
1432///    to a no-op; otherwise, it remains a broadcast but operates on
1433///    distributed vectors.
1434///
1435/// 2) Broadcast a same-rank vector with identical layouts for the source
1436///    and target:
1437///    The source vector must have unit dimensions, and lane_data must be unit
1438///    size for those unit dims. This always lowers to a no-op.
1439///
1440/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from
1441/// scalar to distributed result type.
1442///
1443/// Example 1 (lowering to a broadcast with distributed types):
1444/// ```
1445/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1446/// %0 = "some_def"() {layout_result_0 =
1447/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1448/// dims = [0]> } : () -> (vector<32xf32>)
1449/// %2 = vector.broadcast %0 {layout_result_0 =
1450/// #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
1451/// : vector<32xf32> to vector<8x32xf32>
1452/// gpu.yield %2 : vector<8x32xf32>
1453/// }
1454/// ```
1455/// is lowered to:
1456/// ```
1457/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1458/// %0 = "some_def"() {layout_result_0 =
1459/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1460/// dims = [0]> } : () -> (vector<32xf32>)
1461/// gpu.yield %0 : vector<32xf32>
1462/// }
1463/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
1464/// ```
///
1465/// Example 2 (no-op):
1466/// ```
1467/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) {
1468/// %0 = "some_def"() {layout_result_0 =
1469/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1470/// dims = [1]> } : () -> (vector<8xf32>)
1471/// %1 = vector.shape_cast %0
1472/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1473/// 1]>}: vector<8xf32> to vector<8x1xf32>
1474/// %2 = vector.broadcast %1
1475/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1476/// 1]>}: vector<8x1xf32> to vector<8x32xf32>
1477/// gpu.yield %2 : vector<8x32xf32>
1478/// }
1479/// ```
1480/// is lowered to:
1481/// ```
1482/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1483/// %0 = "some_def"() {layout_result_0 =
1484/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1485/// dims = [1]> } : () -> (vector<8xf32>)
1486/// %1 = vector.shape_cast %0
1487/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1488/// 1]>}: vector<8xf32> to vector<8x1xf32>
1489/// gpu.yield %1 : vector<8x1xf32>
1490/// }
1491/// // The broadcast is implicit through layout transformation (no-op)
1492/// "some_use"(%r#0)
1493/// ```
1494struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
1495 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1496 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1497 PatternRewriter &rewriter) const override {
1498 OpOperand *yieldOperand =
1499 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1500 if (!yieldOperand)
1501 return failure();
1502 auto broadcastOp =
1503 cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
1504 unsigned operandIdx = yieldOperand->getOperandNumber();
1505
1506 VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1507 VectorType destType =
1508 dyn_cast<VectorType>(broadcastOp.getResult().getType());
1509
1510 xegpu::DistributeLayoutAttr sourceLayout =
1511 xegpu::getDistributeLayoutAttr(broadcastOp->getOpOperand(0));
1512 xegpu::DistributeLayoutAttr resultLayout =
1513 xegpu::getDistributeLayoutAttr(broadcastOp.getResult());
1514
1515 FailureOr<VectorType> sourceDistType;
1516 Type sourceElemOrDistType;
1517 if (sourceType) {
1518
1519 // Case 1 and 2: source is a vector type.
1520 int64_t rankDiff = destType.getRank() - sourceType.getRank();
1521 if (rankDiff > 0) {
1522 // Case 1: source is lower-rank than result.
1523 bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
1524 if (!isSliceOf)
1525 return rewriter.notifyMatchFailure(
1526 warpOp,
1527 "Broadcast input layout must be a slice of result layout.");
1528 }
1529 // Case 2: source and result have the same rank.
1530 if (rankDiff == 0) {
1531 SetVector<int64_t> broadcastUnitDims =
1532 broadcastOp.computeBroadcastedUnitDims();
1533 resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
1534 bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
1535 if (!isEqualTo)
1536 return rewriter.notifyMatchFailure(
1537 warpOp, "For same-rank broadcast, source must be identical to "
1538 "adjusted result layouts with unit dims.");
1539 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1540 }
1541
1542 sourceDistType =
1543 getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1544 if (failed(sourceDistType)) {
1545 return rewriter.notifyMatchFailure(
1546 warpOp, "Failed to distribute the source vector type.");
1547 }
1548 sourceElemOrDistType = sourceDistType.value();
1549
1550 } else {
1551 // Case 3: source is a scalar type.
1552 if (sourceLayout) {
1553 return rewriter.notifyMatchFailure(
1554 warpOp, "Broadcast from scalar must not have a layout attribute.");
1555 }
1556 sourceElemOrDistType = broadcastOp.getSourceType();
1557 }
1558 FailureOr<VectorType> destDistType =
1559 getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
1560 if (failed(destDistType)) {
1561 return rewriter.notifyMatchFailure(
1562 warpOp, "Failed to distribute the dest vector type.");
1563 }
1564
1565 SmallVector<size_t> newRetIndices;
1566 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1567 rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
1568 newRetIndices);
1569
1570 Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
1571
1572 Value newBroadcast = distributedSource;
1573
1574 if (sourceElemOrDistType != destDistType.value()) {
1575 rewriter.setInsertionPointAfter(newWarpOp);
1576 newBroadcast =
1577 vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
1578 destDistType.value(), distributedSource);
1579 }
1580
1581 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
1582 return success();
1583 }
1584};
1585
1586/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
1587/// `gpu.warp_execute_on_lane_0` region.
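/// For illustration (assuming a subgroup size of 16 and the layouts shown), a
/// rank-reducing shape_cast whose result layout is a slice of the source
/// layout:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
///   %0 = "some_def"() {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///     : () -> (vector<1x16xf32>)
///   %1 = vector.shape_cast %0 {layout_result_0 =
///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
///     dims = [0]>} : vector<1x16xf32> to vector<16xf32>
///   gpu.yield %1 : vector<16xf32>
/// }
/// ```
/// is expected to lower roughly to:
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
///   %0 = "some_def"() {...} : () -> (vector<1x16xf32>)
///   gpu.yield %0 : vector<1x16xf32>
/// }
/// %1 = vector.shape_cast %r#0 : vector<1x1xf32> to vector<1xf32>
/// ```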
1588struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
1589 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1590 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1591 PatternRewriter &rewriter) const override {
1592 OpOperand *yieldOperand =
1593 getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1594 if (!yieldOperand)
1595 return failure();
1596 auto shapeCastOp =
1597 cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
1598 unsigned operandNumber = yieldOperand->getOperandNumber();
1599 auto resultDistTy =
1600 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1601 xegpu::DistributeLayoutAttr sourceLayout =
1602 xegpu::getDistributeLayoutAttr(shapeCastOp->getOpOperand(0));
1603 xegpu::DistributeLayoutAttr resultLayout =
1604 xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
1605 if (!sourceLayout || !resultLayout)
1606 return rewriter.notifyMatchFailure(
1607 warpOp,
1608 "the source or result of shape_cast op lacks distribution layout");
1609
1610 // For rank reducing or increasing shape_cast ops, the lower rank layout
1611 // must be a slice of higher rank layout.
1612 int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
1613 int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
1614 if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
1615 return rewriter.notifyMatchFailure(
1616 warpOp, "shape_cast is rank reducing but source layout is not a "
1617 "slice of result layout");
1618 if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
1619 return rewriter.notifyMatchFailure(
1620 warpOp, "shape_cast is rank increasing but result layout is not a "
1621 "slice of source layout");
1622
1623 FailureOr<VectorType> sourceDistTypeOrFailure =
1624 getDistVecTypeBasedOnLaneLayout(sourceLayout,
1625 shapeCastOp.getSourceVectorType());
1626 if (failed(sourceDistTypeOrFailure))
1627 return rewriter.notifyMatchFailure(
1628 warpOp, "failed to get distributed vector type for source");
1629 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1630 // Create a new warp op that yields the source of the shape_cast op.
1631 SmallVector<size_t> newRetIndices;
1632 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1633 rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
1634 newRetIndices);
1635 rewriter.setInsertionPointAfter(newWarpOp);
1636 Value source = newWarpOp.getResult(newRetIndices[0]);
1637 // Create a new shape_cast op outside the warp op.
1638 Value newShapeCast = vector::ShapeCastOp::create(
1639 rewriter, shapeCastOp.getLoc(), resultDistTy, source);
1640 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1641 newShapeCast);
1642 return success();
1643 }
1644};
1645
1646// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
1647// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1648// advanced cases where the distributed dimension is partially extracted,
1649// which are currently not supported by the generic vector distribution patterns.
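// For illustration (assuming a subgroup size of 16 and unit lane data), a
// partial extraction along the distributed dimension:
// ```
// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
//   %0 = "some_def"() {layout_result_0 =
//     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
//     : () -> (vector<2x32xf32>)
//   %1 = vector.extract_strided_slice %0 {offsets = [0, 16], sizes = [2, 16],
//     strides = [1, 1]} : vector<2x32xf32> to vector<2x16xf32>
//   gpu.yield %1 : vector<2x16xf32>
// }
// ```
// is expected to lower roughly to (the offset and size along the distributed
// dimension are rescaled by the subgroup size):
// ```
// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x2xf32>) {
//   %0 = "some_def"() {...} : () -> (vector<2x32xf32>)
//   gpu.yield %0 : vector<2x32xf32>
// }
// %1 = vector.extract_strided_slice %r#0 {offsets = [0, 1], sizes = [2, 1],
//   strides = [1, 1]} : vector<2x2xf32> to vector<2x1xf32>
// ```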
1650struct VectorExtractStridedSliceDistribution
1651 : public gpu::WarpDistributionPattern {
1652 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1653 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1654 PatternRewriter &rewriter) const override {
1655 OpOperand *operand =
1656 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1657 if (!operand)
1658 return failure();
1659 auto extractOp =
1660 cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
1661 unsigned operandIdx = operand->getOperandNumber();
1662 auto distributedType =
1663 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1664 // Find the distributed dimensions.
1665 auto extractResultType = cast<VectorType>(operand->get().getType());
1666 auto distributedDims =
1667 getDistributedDims(extractResultType, distributedType);
1668 // Collect updated source type, sizes and offsets. They may be adjusted
1669 // later if the data is distributed to lanes (as opposed to being owned by
1670 // all lanes uniformly).
1671 VectorType updatedSourceType = extractOp.getSourceVectorType();
1672 SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
1673 extractOp.getSizes(), [](Attribute attr) { return attr; });
1674 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1675 extractOp.getOffsets(), [](Attribute attr) { return attr; });
1676 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1677 extractOp.getStrides(), [](Attribute attr) { return attr; });
1678 // If fewer sizes, offsets, and strides than the rank are provided, pad them
1679 // with full sizes, zero offsets, and unit strides. This makes it easier to
1680 // adjust them later.
1681 int64_t sourceRank = extractOp.getSourceVectorType().getRank();
1682 for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
1683 updatedSizes.push_back(rewriter.getI64IntegerAttr(
1684 extractOp.getSourceVectorType().getDimSize(i)));
1685 updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
1686 updatedStrides.push_back(
1687 rewriter.getI64IntegerAttr(1)); // stride is always 1.
1688 }
1689 // If the result is distributed, it must be distributed in exactly one
1690 // dimension. In this case, we adjust the sourceDistType, distributedSizes
1691 // and distributedOffsets accordingly.
1692 if (distributedDims.size() > 0) {
1693 if (distributedDims.size() != 1)
1694 return rewriter.notifyMatchFailure(
1695 warpOp, "Source can not be distributed in multiple dimensions.");
1696 int64_t distributedDim = distributedDims[0];
1697 int sourceDistrDimSize =
1698 extractOp.getSourceVectorType().getShape()[distributedDim];
1699 auto sourceLayout =
1700 xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
1701 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1702 return rewriter.notifyMatchFailure(
1703 warpOp, "the source of extract_strided_slice op lacks distribution "
1704 "layout");
1705 auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
1706 // Because only single dimension distribution is supported, lane layout
1707 // size at the distributed dim must be the subgroup size.
1708 int subgroupSize = sourceLaneLayout[distributedDim];
1709 // Check if the source size in the distributed dimension is a multiple of
1710 // subgroup size.
1711 if (sourceDistrDimSize % subgroupSize != 0)
1712 return rewriter.notifyMatchFailure(
1713 warpOp,
1714 "Source size along distributed dimension is not a multiple of "
1715 "subgroup size.");
1716 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1717 // We expect lane data to be all ones in this case.
1718 if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1719 return rewriter.notifyMatchFailure(
1720 warpOp, "Expecting unit lane data in source layout");
1721 // The offset in the distributed dimension must be a multiple of the
1722 // subgroup size.
1723 int64_t distrDimOffset =
1724 cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
1725 if (distrDimOffset % subgroupSize != 0)
1726 return rewriter.notifyMatchFailure(
1727 warpOp, "Offset along distributed dimension "
1728 "is not a multiple of subgroup size.");
1729 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1730 sourceLayout, extractOp.getSourceVectorType())
1731 .value();
1732 // Update the distributed sizes to match the distributed type.
1733 updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1734 distributedType.getDimSize(distributedDim));
1735 // Update the distributed offsets to match round robin distribution (i.e.
1736 // each lane owns data at `subgroupSize` stride given unit lane data).
1737 updatedOffsets[distributedDim] =
1738 rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
1739 }
1740 // Do the distribution by yielding the source of the extract op from
1741 // the warp op and creating a new extract op outside the warp op.
1742 SmallVector<size_t> newRetIndices;
1743 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1744 rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
1745 newRetIndices);
1746 rewriter.setInsertionPointAfter(newWarpOp);
1747 Value source = newWarpOp.getResult(newRetIndices[0]);
1748 // Create a new extract op outside the warp op.
1749 Value newExtractOp = vector::ExtractStridedSliceOp::create(
1750 rewriter, extractOp.getLoc(), distributedType, source,
1751 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1752 ArrayAttr::get(rewriter.getContext(), updatedSizes),
1753 ArrayAttr::get(rewriter.getContext(), updatedStrides));
1754 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
1755 return success();
1756 }
1757};
1758
1759/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
1760/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1761/// advanced cases where the distributed dimension is partially inserted,
1762/// which are currently not supported by the generic vector distribution patterns.
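/// For illustration (assuming a subgroup size of 16 and unit lane data), a
/// partial insertion along the distributed dimension:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x2xf32>) {
///   %0 = "some_def"() {layout_result_0 = #xegpu.slice<
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
///     : () -> (vector<16xf32>)
///   %1 = "some_def"() {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///     : () -> (vector<2x32xf32>)
///   %2 = vector.insert_strided_slice %0, %1 {offsets = [1, 16], strides = [1]}
///     : vector<16xf32> into vector<2x32xf32>
///   gpu.yield %2 : vector<2x32xf32>
/// }
/// ```
/// is expected to lower roughly to (only the appended results are shown; the
/// offset along the distributed dimension is rescaled by the subgroup size):
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[16]
///     -> (vector<1xf32>, vector<2x2xf32>) {
///   ...
///   gpu.yield %0, %1 : vector<16xf32>, vector<2x32xf32>
/// }
/// %2 = vector.insert_strided_slice %r#0, %r#1 {offsets = [1, 1], strides = [1]}
///   : vector<1xf32> into vector<2x2xf32>
/// ```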
1763struct VectorInsertStridedSliceDistribution
1764 : public gpu::WarpDistributionPattern {
1765 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1766 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1767 PatternRewriter &rewriter) const override {
1768 OpOperand *operand =
1769 getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
1770 if (!operand)
1771 return failure();
1772 unsigned int operandNumber = operand->getOperandNumber();
1773 auto insertOp =
1774 operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
1775 auto distributedType =
1776 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1777 // Find the distributed dimensions of the dest vector.
1778 auto insertResultType = cast<VectorType>(operand->get().getType());
1779 auto destDistributedDims =
1780 getDistributedDims(insertResultType, distributedType);
1781 // Collect updated offsets, source type and dest type. They may be adjusted
1782 // later if the data is distributed to lanes (as opposed to being owned by
1783 // all lanes uniformly).
1784 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1785 insertOp.getOffsets(), [](Attribute attr) { return attr; });
1786 VectorType updatedSourceType = insertOp.getSourceVectorType();
1787 VectorType updatedDestType = insertOp.getDestVectorType();
1788 if (destDistributedDims.size() > 0) {
1789 // Only single dimension distribution is supported.
1790 if (destDistributedDims.size() != 1)
1791 return rewriter.notifyMatchFailure(
1792 warpOp,
1793 "Expecting source to be distributed in a single dimension.");
1794 int64_t destDistributedDim = destDistributedDims[0];
1795
1796 VectorType srcType = insertOp.getSourceVectorType();
1797 VectorType destType = insertOp.getDestVectorType();
1798 // Currently we require that both source (kD) and dest (nD) vectors are
1799 // distributed. This requires that distributedDim (d) is contained in the
1800 // last k dims of the dest vector (d >= n - k).
1801 int64_t sourceDistributedDim =
1802 destDistributedDim - (destType.getRank() - srcType.getRank());
1803 if (sourceDistributedDim < 0)
1804 return rewriter.notifyMatchFailure(
1805 insertOp,
1806 "distributed dimension must be in the last k (i.e. source "
1807 "rank) dims of dest vector");
1808 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
1809 // Obtain the source and dest layouts.
1810 auto destLayout =
1811 xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
1812 auto sourceLayout =
1813 xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
1814 if (!destLayout || !sourceLayout ||
1815 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1816 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1817 return rewriter.notifyMatchFailure(
1818 warpOp, "the source or dest of insert_strided_slice op lacks "
1819 "distribution layout");
1820 // Because only single dimension distribution is supported, lane layout
1821 // size at the distributed dim must be the subgroup size.
1822 int subgroupSize =
1823 destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
1824 // We require that source and dest lane data are all ones to ensure
1825 // uniform round robin distribution.
1826 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1827 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1828 if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
1829 !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1830 return rewriter.notifyMatchFailure(
1831 warpOp, "Expecting unit lane data in source and dest layouts");
1832 // Source distributed dim size must be a multiple of the subgroup size.
1833 if (srcDistrDimSize % subgroupSize != 0)
1834 return rewriter.notifyMatchFailure(
1835 warpOp, "Distributed dimension size in source is not a multiple of "
1836 "subgroup size.");
1837 // Offsets in the distributed dimension must be multiples of subgroup
1838 // size.
1839 int64_t destDistrDimOffset =
1840 cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
1841 if (destDistrDimOffset % subgroupSize != 0)
1842 return rewriter.notifyMatchFailure(
1843 warpOp,
1844 "Offset along distributed dimension in dest is not a multiple of "
1845 "subgroup size.");
1846 // Update the source and dest types based on their layouts.
1847 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1848 sourceLayout, insertOp.getSourceVectorType())
1849 .value();
1850 updatedDestType = getDistVecTypeBasedOnLaneLayout(
1851 destLayout, insertOp.getDestVectorType())
1852 .value();
1853 // Update the distributed offsets to match round robin distribution (i.e.
1854 // each lane owns data at `subgroupSize` stride given unit lane data).
1855 updatedOffsets[destDistributedDim] =
1856 rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1857 }
1858 // Do the distribution by yielding the source and dest of the insert op
1859 // from the warp op and creating a new insert op outside the warp op.
1860 SmallVector<size_t> newRetIndices;
1861 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1862 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1863 {updatedSourceType, updatedDestType}, newRetIndices);
1864 rewriter.setInsertionPointAfter(newWarpOp);
1865
1866 Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
1867 Value dest = newWarpOp.getResult(newRetIndices[1]);
1868 // Create a new insert op outside the warp op.
1869 Value newInsertOp = vector::InsertStridedSliceOp::create(
1870 rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
1871 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1872 insertOp.getStrides());
1873 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1874 newInsertOp);
1875 return success();
1876 }
1877};
1878
1879/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
1880/// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
1881/// outside of the warp op.
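/// For example (illustrative only):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
///   %ptr = memref.extract_aligned_pointer_as_index %m : memref<256xf32> -> index
///   gpu.yield %ptr : index
/// }
/// ```
/// becomes a warp op that yields the memref, with the op recreated outside:
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (memref<256xf32>) {
///   gpu.yield %m : memref<256xf32>
/// }
/// %ptr = memref.extract_aligned_pointer_as_index %r#0 : memref<256xf32> -> index
/// ```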
1882struct MemrefExtractAlignedPointerAsIndexDistribution final
1883 : public gpu::WarpDistributionPattern {
1884 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1885 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1886 PatternRewriter &rewriter) const override {
1887 OpOperand *operand = getWarpResult(
1888 warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
1889 if (!operand)
1890 return rewriter.notifyMatchFailure(
1891 warpOp,
1892 "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
1893 auto extractOp =
1894 operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
1895 unsigned operandIdx = operand->getOperandNumber();
1896 SmallVector<size_t> newRetIndices;
1897 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1898 rewriter, warpOp, extractOp.getSource(),
1899 TypeRange{extractOp.getSource().getType()}, newRetIndices);
1900 rewriter.setInsertionPointAfter(newWarpOp);
1901 auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
1902 rewriter, newWarpOp.getLoc(), extractOp.getType(),
1903 newWarpOp.getResult(newRetIndices[0]));
1904 Value distributedVal = newWarpOp.getResult(operandIdx);
1905 rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
1906 return success();
1907 }
1908};
1909
1910/// Distribute a vector::BitCastOp feeding into yield op of an enclosing
1911/// `gpu.warp_execute_on_lane_0` region. Bitcast only impacts the innermost
1912/// dimension of the source/result vectors. An equivalent vector::BitCastOp is
1913/// created outside of the warp op with distributed source vector type (computed
1914/// using assigned layout).
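/// For illustration (assuming a subgroup size of 16 and the layouts shown):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xi32>) {
///   %0 = "some_def"() {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
///     : () -> (vector<2x32xi16>)
///   %1 = vector.bitcast %0 {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///     : vector<2x32xi16> to vector<2x16xi32>
///   gpu.yield %1 : vector<2x16xi32>
/// }
/// ```
/// is expected to lower roughly to:
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x2xi16>) {
///   %0 = "some_def"() {...} : () -> (vector<2x32xi16>)
///   gpu.yield %0 : vector<2x32xi16>
/// }
/// %1 = vector.bitcast %r#0 : vector<2x2xi16> to vector<2x1xi32>
/// ```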
1915struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
1916 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1917 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1918 PatternRewriter &rewriter) const override {
1919 OpOperand *operand =
1920 getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
1921 if (!operand)
1922 return rewriter.notifyMatchFailure(
1923 warpOp, "warp result is not a vector::BitCast op");
1924 auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
1925 unsigned operandIdx = operand->getOperandNumber();
1926 VectorType distributedSourceType =
1927 getDistVecTypeBasedOnLaneLayout(
1928 xegpu::getDistributeLayoutAttr(bitcastOp.getSource()),
1929 bitcastOp.getSourceVectorType())
1930 .value_or(VectorType());
1931 if (!distributedSourceType)
1932 return rewriter.notifyMatchFailure(
1933 bitcastOp, "Failed to distribute the source vector type in "
1934 "vector::BitCast op");
1935 VectorType distributedResultType =
1936 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1937 SmallVector<size_t> newRetIndices;
1938 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1939 rewriter, warpOp, bitcastOp.getSource(),
1940 TypeRange{distributedSourceType}, newRetIndices);
1941 rewriter.setInsertionPointAfter(newWarpOp);
1942 auto newBitcastOp = vector::BitCastOp::create(
1943 rewriter, newWarpOp.getLoc(), distributedResultType,
1944 newWarpOp.getResult(newRetIndices[0]));
1945 Value distributedVal = newWarpOp.getResult(operandIdx);
1946 rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
1947 return success();
1948 }
1949};
1950
1951/// Distribute a vector::TransposeOp feeding into yield op of an enclosing
1952/// `gpu.warp_execute_on_lane_0` region. Currently only 2D transposes are
1953/// supported. In most cases, the transpose is a no-op because it is entirely
1954/// handled using the layouts (e.g. 16x1 -> 1x16). However, if each lane owns
1955/// multiple slices of data after distribution (e.g. 16x2 -> 2x16), a lane-local
1956/// transpose (i.e. shuffle) is needed. Therefore, we create an equivalent
1957/// vector::TransposeOp outside of the warp op with distributed source vector
1958/// type (computed using assigned layout).
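/// For illustration (assuming a subgroup size of 16), each lane owns two
/// elements after distribution, so a lane-local transpose remains:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
///   %0 = "some_def"() {layout_result_0 =
///     #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
///     : () -> (vector<16x2xf32>)
///   %1 = vector.transpose %0, [1, 0] {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///     : vector<16x2xf32> to vector<2x16xf32>
///   gpu.yield %1 : vector<2x16xf32>
/// }
/// ```
/// is expected to lower roughly to:
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x2xf32>) {
///   %0 = "some_def"() {...} : () -> (vector<16x2xf32>)
///   gpu.yield %0 : vector<16x2xf32>
/// }
/// %1 = vector.transpose %r#0, [1, 0] : vector<1x2xf32> to vector<2x1xf32>
/// ```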
1959struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
1960 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1961 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1962 PatternRewriter &rewriter) const override {
1963 OpOperand *operand =
1964 getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
1965 if (!operand)
1966 return rewriter.notifyMatchFailure(
1967 warpOp, "warp result is not a vector::Transpose op");
1968 auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
1969 unsigned operandIdx = operand->getOperandNumber();
1970 xegpu::DistributeLayoutAttr sourceLayout =
1971 xegpu::getDistributeLayoutAttr(transposeOp.getVector());
1972 xegpu::DistributeLayoutAttr resultLayout =
1973 xegpu::getDistributeLayoutAttr(transposeOp.getResult());
1974 if (!sourceLayout || !resultLayout)
1975 return rewriter.notifyMatchFailure(
1976 transposeOp,
1977 "the source or result vector of the transpose op lacks layout "
1978 "attribute");
1979 int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
1980 int64_t resultRank = transposeOp.getResultVectorType().getRank();
1981 // Only 2D transposes are supported for now.
1982 // TODO: Support nD transposes.
1983 if (sourceRank != 2 || resultRank != 2)
1984 return rewriter.notifyMatchFailure(
1985 transposeOp, "the source or result vector of the transpose op "
1986 "does not have 2D layout");
1987 ArrayRef<int64_t> perm = transposeOp.getPermutation();
1988 // Result layout must be a transpose of source layout.
1989 if (!resultLayout.isTransposeOf(sourceLayout, perm))
1990 return rewriter.notifyMatchFailure(
1991 transposeOp,
1992 "the source or result vector layouts must be 2D transposes of each "
1993 "other");
1994 FailureOr<VectorType> distributedSourceTypeOrFailure =
1995 getDistVecTypeBasedOnLaneLayout(sourceLayout,
1996 transposeOp.getSourceVectorType());
1997 if (failed(distributedSourceTypeOrFailure))
1998 return rewriter.notifyMatchFailure(
1999 transposeOp, "Failed to distribute the source vector type in "
2000 "vector::Transpose op");
2001 SmallVector<size_t> newRetIndices;
2002 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2003 rewriter, warpOp, transposeOp.getVector(),
2004 TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
2005 rewriter.setInsertionPointAfter(newWarpOp);
2006 auto newTransposeOp = vector::TransposeOp::create(
2007 rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
2008 perm);
2009 Value distributedVal = newWarpOp.getResult(operandIdx);
2010 rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
2011 return success();
2012 }
2013};
2014
2015} // namespace
2016
2017namespace {
2018struct XeGPUSubgroupDistributePass final
2019 : public xegpu::impl::XeGPUSubgroupDistributeBase<
2020 XeGPUSubgroupDistributePass> {
2021 void runOnOperation() override;
2022};
2023} // namespace
2024
2025void xegpu::populateXeGPUSubgroupDistributePatterns(
2026 RewritePatternSet &patterns) {
2027 patterns.add<CreateNdDescDistribution, StoreNdDistribution,
2028 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
2029 GpuBarrierDistribution, VectorMultiReductionDistribution,
2030 LoadDistribution, StoreDistribution, VectorTransposeDistribution,
2031 VectorBitcastDistribution, LoadMatrixDistribution,
2032 StoreMatrixDistribution,
2033 MemrefExtractAlignedPointerAsIndexDistribution>(
2034 patterns.getContext(),
2035 /*pattern benefit=*/regularPatternBenefit);
2036 // For the following patterns, we need to override the regular vector
2037 // distribution patterns. Therefore, we assign them a higher benefit.
2038 patterns
2039 .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
2040 VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>(
2041 patterns.getContext(),
2042 /*pattern benefit=*/highPatternBenefit);
2043}
2044
2045void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
2046 RewritePatternSet &patterns) {
2047 patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
2048}
2049
2050void XeGPUSubgroupDistributePass::runOnOperation() {
2051 // Step 1: Attach layouts to op operands.
2052 // TODO: Following assumptions are made:
2053 // 1) It is assumed that there are no layout conflicts.
2054 // 2) Any existing layout attributes attached to the operands are ignored.
2055 Operation *op = getOperation();
2056 op->walk([&](Operation *op) {
2057 for (OpOperand &operand : op->getOpOperands()) {
2058 // Layouts are needed for vector type only.
2059 if (!isa<VectorType>(operand.get().getType()))
2060 continue;
2061 if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
2062 continue;
2063
2064 auto layout = xegpu::getDistributeLayoutAttr(operand.get());
2065 if (!layout) {
2066 op->emitError("Could not find layout attribute for operand ")
2067 << operand.getOperandNumber() << " of operation " << op->getName();
2068 signalPassFailure();
2069 return;
2070 }
2071 xegpu::setDistributeLayoutAttr(operand, layout);
2072 }
2073 });
2074 // Step 2: Move all operations of a GPU function inside
2075 // gpu.warp_execute_on_lane_0 operation.
2076 {
2077 RewritePatternSet patterns(&getContext());
2078 xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
2079
2080 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
2081 signalPassFailure();
2082 return;
2083 }
2084 // At this point, we have moved the entire function body inside the
2085 // warpOp. Now move any scalar uniform code outside of the warpOp (like
2086 // GPU index ops, scalar constants, etc.). This will simplify the
2087 // later lowering and avoid custom patterns for these ops.
2088 getOperation()->walk([&](Operation *op) {
2089 if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
2090 vector::moveScalarUniformCode(warpOp);
2091 });
2092 }
2093 // Step 3: Apply subgroup to workitem distribution patterns.
2094 RewritePatternSet patterns(&getContext());
2095 xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
2096 // distributionFn is used by vector distribution patterns to determine the
2097 // distributed vector type for a given vector value. In XeGPU subgroup
2098 // distribution context, we compute this based on lane layout.
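  // For example (illustrative), a value of type vector<16x32xf32> with
  // lane_layout = [1, 16] is distributed only along dim 1, giving the map
  // (d0, d1) -> (d1); with lane_layout = [2, 16] both dims qualify, giving
  // (d0, d1) -> (d0, d1).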
2099 auto distributionFn = [](Value val) {
2100 VectorType vecType = dyn_cast<VectorType>(val.getType());
2101 int64_t vecRank = vecType ? vecType.getRank() : 0;
2102 if (vecRank == 0)
2103 return AffineMap::get(val.getContext());
2104 // Get the layout of the vector type.
2105 xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
2106 // If no layout is specified, that means no distribution.
2107 if (!layout)
2108 return AffineMap::getMultiDimMapWithTargets(vecRank, {},
2109 val.getContext());
2110 // Expecting vector and layout rank to match.
2111 assert(layout.getRank() == vecRank &&
2112 "Expecting vector and layout rank to match");
2113 // A dimension is distributed only if the layout suggests there are
2114 // multiple lanes assigned to this dimension and the shape can be evenly
2115 // distributed to those lanes.
2116 SmallVector<unsigned int> distributedDims;
2117 for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
2118 if (v > 1 && vecType.getShape()[i] % v == 0)
2119 distributedDims.push_back(i);
2120 }
2121 return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
2122 val.getContext());
2123 };
2124 // TODO: shuffleFn is not used.
2125 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
2126 int64_t warpSz) { return Value(); };
2127
2128 auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
2129 vector::CombiningKind kind, uint32_t size) {
2130 // First reduce on a single thread to get per lane reduction value.
2131 Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
2132 // Parallel reduction using butterfly shuffles.
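  // For example (illustrative), with size = 4 the XOR butterfly exchanges
  // lanes {0,1} and {2,3} in the first step and {0,2} and {1,3} in the
  // second, so every lane ends up with the full reduction value.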
2133 for (uint64_t i = 1; i < size; i <<= 1) {
2134 Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
2135 /*width=*/size,
2136 /*mode=*/gpu::ShuffleMode::XOR)
2137 .getShuffleResult();
2138 laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
2139 }
2140 return laneVal;
2141 };
2142
2143 vector::populateDistributeReduction(
2144 patterns, warpReduction,
2145 /*pattern benefit=*/regularPatternBenefit);
2146
2147 vector::populatePropagateWarpVectorDistributionPatterns(
2148 patterns, distributionFn, shuffleFn,
2149 /*pattern benefit=*/regularPatternBenefit);
2150 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
2151 signalPassFailure();
2152 return;
2153 }
2154
2155 // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
2156 // due to tensor desc type mismatches created by using upstream distribution
2157 // patterns (scf.for). This cleanup should only be done if all the ops are
2158 // distributed successfully; if some ops are still not distributed and remain
2159 // inside any WarpExecuteOnLane0Op, we skip this simplification step to avoid
2160 // breaking the IR.
2161 bool foundWarpOp = false;
2162 getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
2163 // Look for WarpOps that are not trivially dead.
2164 if (isOpTriviallyDead(warpOp))
2165 return WalkResult::advance();
2166 foundWarpOp = true;
2167 return WalkResult::interrupt();
2168 });
2169 if (foundWarpOp)
2170 return;
2171
2172 getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
2173 // We are only interested in UnrealizedConversionCastOps that were added
2174 // for resolving SIMT type mismatches.
2175 if (!op->getAttr(resolveSIMTTypeMismatch))
2176 return WalkResult::skip();
2177
2178 Value input = op.getOperand(0);
2179 Value output = op.getResult(0);
2180
2181 // Both input and output must have tensor descriptor types.
2182 xegpu::TensorDescType inputDescType =
2183 mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
2184 xegpu::TensorDescType outputDescType =
2185 mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
2186 assert(inputDescType && outputDescType &&
2187 "Unrealized conversion cast must have tensor descriptor types");
2188
2189 // Conversions of the form tensor_desc<shape, layout> -> tensor_desc<shape>.
2190 // These occur inside the scf.for body to resolve the block argument type to
2191 // the SIMT type.
2192 if (inputDescType.getLayout()) {
2193 auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
2194 if (argument) {
2195 argument.setType(output.getType());
2196 output.replaceAllUsesWith(argument);
2197 if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
2198 argument.getOwner()->getParentOp())) {
2199 auto result = loopOp.getTiedLoopResult(argument);
2200 result.setType(output.getType());
2201 }
2202 }
2203 }
2204
2205 // Conversions of the form tensor_desc<shape> -> tensor_desc<shape, layout>.
2206 // These occur at the yield op of the scf.for body to go back
2207 // from the SIMT type to the original type.
2208 if (outputDescType.getLayout())
2209 output.replaceAllUsesWith(input);
2210
2211 if (op->use_empty())
2212 op->erase();
2213 return WalkResult::advance();
2214 });
2215}