MLIR 23.0.0git
XeGPUSubgroupDistribute.cpp
Go to the documentation of this file.
1//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
21#include "mlir/IR/AffineMap.h"
22#include "mlir/IR/Attributes.h"
23#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinOps.h"
27#include "mlir/IR/Operation.h"
29#include "mlir/IR/TypeRange.h"
30#include "mlir/IR/Value.h"
31#include "mlir/IR/Visitors.h"
33#include "mlir/Support/LLVM.h"
37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/SmallVectorExtras.h"
41
42namespace mlir {
43namespace xegpu {
44#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
46} // namespace xegpu
47} // namespace mlir
48
49#define DEBUG_TYPE "xegpu-subgroup-distribute"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
51
52using namespace mlir;
53
/// Attribute name used to tag UnrealizedConversionCastOps that were inserted
/// by this pass to resolve SIMT type mismatches (see resolveDistributedTy),
/// so they can be recognized and folded away later.
static const char *const resolveSIMTTypeMismatch =
    "resolve_simt_type_mismatch";
58
59namespace {
60
61//===----------------------------------------------------------------------===//
62// SIMT Distribution Patterns
63//===----------------------------------------------------------------------===//
64
/// In certain cases, we may need to favor XeGPU specific distribution patterns
/// over generic vector distribution patterns. In such cases, we can assign
/// priorities to patterns. Higher values win when both kinds of patterns match
/// the same op.
enum PatternHierarchy : unsigned { Regular = 1, AboveRegular = 2 };
69
70/// Helper function to resolve types if the distributed type out of
71/// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
72/// Example 1:
73/// distributed type: vector<8x1xf32>
74/// expected type: vector<8xf32>
75/// resolved using,
76/// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
77/// Example 2:
78/// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
79/// expected type: xegpu.tensor_desc<8x16xf32>
80/// resolved using,
81/// %0 = unrealized_conversion_cast %1 :
82/// xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
83/// xegpu.tensor_desc<8x16xf32>
84template <typename T>
85static Value resolveDistributedTy(Value orig, T expected,
86 PatternRewriter &rewriter) {
87 // If orig and expected types are the same, return orig.
88 if (orig.getType() == expected)
89 return orig;
90 // If orig is a vector type, create a shape cast op to reconcile the types.
91 if (isa<VectorType>(orig.getType())) {
92 auto castOp =
93 vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
94 return castOp.getResult();
95 }
96 // If orig is a tensor descriptor type, create an unrealized conversion cast
97 // op to reconcile the types.
98 if (isa<xegpu::TensorDescType>(orig.getType())) {
99 auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
100 expected, orig);
101 castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
102 return castOp.getResult(0);
103 }
104 llvm_unreachable("Unsupported type for reconciliation");
105 return orig;
106}
107
108/// Given a vector type and its distributed vector type, return the list of
109/// dimensions that are distributed.
110static SmallVector<int64_t> getDistributedDims(VectorType originalType,
111 VectorType distributedType) {
112 assert(originalType.getRank() == distributedType.getRank() &&
113 "sequential and distributed vector types must have the same rank");
114 SmallVector<int64_t> distributedDims;
115 for (int64_t i = 0; i < originalType.getRank(); ++i) {
116 if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
117 distributedDims.push_back(i);
118 }
119 }
120 return distributedDims;
121}
122
/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
/// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
/// contained within a WarpExecuteOnLane0Op.
/// Example:
///
/// ```
/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
///   ...
///   ...
///   gpu.return %result: vector<8x16xf32>
/// }
/// ```
/// To
/// ```
/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
///   %laneid = gpu.lane_id : index
///   %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
///     ...
///     ...
///     gpu.yield %result: vector<8x16xf32>
///   }
///   return %0
/// }
/// ```
struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    // The subgroup (warp) size comes from the target uArch; without a chip
    // attribute we cannot size the warp op.
    auto uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          gpuFuncOp, "Subgroup distribution requires target attribute attached "
                     "to set the warp size");
    // If the function only contains a single void return, skip.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // If the function already moved inside a warp_execute_on_lane0, skip.
    // This prevents the pattern from firing repeatedly on its own output.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    gpu::ReturnOp origReturnOp = dyn_cast_if_present<gpu::ReturnOp>(
        gpuFuncOp.getBlocks().back().getTerminator());
    if (!origReturnOp)
      return rewriter.notifyMatchFailure(
          gpuFuncOp, "expected gpu.func terminator to be gpu.return");
    // Create a new function with the same signature and same attributes.
    SmallVector<Type> workgroupAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    SmallVector<Type> privateAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    auto newGpuFunc = gpu::GPUFuncOp::create(
        rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
        gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
        privateAttributionsTypes);
    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
    // Create a WarpExecuteOnLane0Op with same arguments and results as the
    // original gpuFuncOp.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = gpu::LaneIdOp::create(
        rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(),
        /** upperBound = **/ mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = gpu::WarpExecuteOnLane0Op::create(
        rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
        uArch->getSubgroupSize(), newGpuFunc.getArguments(),
        newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the ReturnOp of the original gpu function with a YieldOp.
    rewriter.setInsertionPointAfter(origReturnOp);
    gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
                         origReturnOp.getOperands());
    rewriter.eraseOp(origReturnOp);
    // Move the original function body to the WarpExecuteOnLane0Op body.
    // NOTE: the region is inlined *before* the auto-created body block, which
    // is then erased — order matters here.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    rewriter.eraseBlock(&warpBodyBlock);
    // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};
210
/// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing
/// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
/// still contain the original op that will not be used by the yield op (and
/// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
/// arguments. Tensor descriptor shape is not distributed because it is a
/// uniform value across all work items within the subgroup. However, the
/// layout information is dropped in the new tensor descriptor type.
///
/// Example:
///
/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (!xegpu.tensor_desc<4x8xf32, #layout0>) {
///     ...
///     %td = xegpu.create_nd_tdesc %arg0
///       : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
///     vector.yield %td
///   }
/// ```
/// To
/// ```
///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
///     ...
///     %dead = xegpu.create_nd_tdesc %arg0
///       : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
///     vector.yield %arg0, %dead
///   }
///   %td = xegpu.create_nd_tdesc %r#0: memref<4x8xf32>
///                                 -> !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();

    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");
    // CreateNdOp must not have offsets.
    if (descOp.getMixedOffsets().size())
      return rewriter.notifyMatchFailure(
          descOp, "xegpu::CreateNdDescOp must not have offsets");

    SmallVector<size_t> newRetIndices;
    rewriter.setInsertionPoint(warpOp);
    // Yield all of the desc op's operands out of the warp op so the op can be
    // recreated outside with the same (uniform) arguments.
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, /* new yielded values = */ descOp->getOperands(),
        /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);

    SmallVector<Value> newDescOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts(); // Distributed tensor descriptor type
                                        // does not contain layout info.
    Value newDescOp = xegpu::CreateNdDescOp::create(
        rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());

    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the distributed type to the expected type.
    newDescOp =
        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, newDescOp);
    return success();
  }
};
288
/// Distribute a store_nd op at the end of enclosing
/// `gpu.warp_execute_on_lane_0`. In case arguments for the store are passed
/// through the warp op interface they would be propagated as returned values.
/// Source vector is distributed based on lane layout. Appropriate cast ops are
/// inserted if the distributed types does not match expected xegpu SIMT types.
///
/// Example:
///
/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   gpu.warp_execute_on_lane_0(%laneid) -> () {
///     ...
///     xegpu.store_nd %arg0, %arg1 [%x, %y]: vector<4x8xf32>,
///       !xegpu.tensor_desc<4x8xf32, #layout0>
///   }
/// ```
/// To
/// ```
///   %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
///     !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
///     ...
///     gpu.yield %arg0, %arg1, %x, %y: vector<4x8xf32>,
///       !xegpu.tensor_desc<4x8xf32, #layout0>, index, index
///   }
///   %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
///   %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
///     #layout0>
///       -> !xegpu.tensor_desc<4x8xf32>
///   xegpu.store_nd %0, %1 [%r#2, %r#3]: vector<4xf32>,
///     !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Only match a store that is the op immediately preceding the terminator,
    // so it cannot be hoisted past any barrier synchronization.
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();

    SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
    // Expecting offsets to be present.
    if (offsets.empty())
      return rewriter.notifyMatchFailure(storeOp,
                                         "the store op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::map_to_vector(
        offsetsAsValues, [](Value v) { return v.getType(); });
    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");

    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
    // Yield the store's value, tensor descriptor, and offsets out of the warp
    // op; only the value is distributed, the rest stay uniform.
    SmallVector<Value> newYieldedValues = {storeOp.getValue(),
                                           storeOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
    // Create a new store op outside the warp op with the distributed vector
    // type. Tensor descriptor is not distributed.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;

    // For the value operand, there can be a mismatch between the vector type
    // distributed by the warp op and (xegpu-specific) distributed type
    // supported by the store op. Type mismatch must be resolved using
    // appropriate cast op.
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // For the tensor descriptor operand, the layout attribute is dropped after
    // distribution. Types needs to be resolved in this case also.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));
    // Collect offsets.
    for (size_t i = 2; i < newRetIndices.size(); ++i)
      newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));

    auto newStoreOp =
        xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                 newStoreOperands, storeOp->getAttrs());
    // NOTE(review): newStoreOp is created but not otherwise referenced below;
    // the sibling load/prefetch patterns call xegpu::removeLayoutAttrs on
    // their new op here — confirm against upstream that a line was not lost.
    rewriter.eraseOp(storeOp);
    return success();
  }
};
/// Distribute a load_nd op feeding into vector.yield op for the enclosing
/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
/// The warp op will still contain the original op that will not be used by
/// the yield op (and should be cleaned up later). The yield op will
/// bypass the load's arguments. Only the loaded vector is distributed
/// according to lane layout and, tensor descriptor types is not
/// distributed. Appropriate cast ops are inserted if the distributed types does
/// not match expected xegpu SIMT types.
///
/// Example:
///
/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (vector<4x1xf32>) {
///     ...
///     %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
///     ->
///       vector<4x8xf32>
///     gpu.yield %ld
///   }
/// ```
/// To
/// ```
///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
///   !xegpu.tensor_desc<4x8xf32, #layout0>) {
///     ...
///     %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
///     vector<4x8xf32> gpu.yield %dead, %arg0
///   }
///   %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
///        #layout0> -> !xegpu.tensor_desc<4x8xf32>
///   %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
///   %2 = vector.shape_cast %1: vector<4xf32> to vector<4x1xf32>
///
/// ```
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
      if (!isa<xegpu::LoadNdOp>(op))
        return false;
      // Make sure the same load op is the last operation in the warp op body.
      // This ensures that the load op is not sunk earlier, violating any
      // barrier synchronizations.
      gpu::YieldOp yield = warpOp.getTerminator();
      return yield->getPrevNode() == op;
    });

    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::LoadNd op");

    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    // Chip information is required to decide if the layout requires a
    // transpose effect.
    auto uArch = getUArch(xegpu::getChipStr(loadOp).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          loadOp, "xegpu::LoadNdOp require target attribute attached to "
                  "determine transpose "
                  "requirement");
    // Expecting offsets to be present.
    SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(loadOp,
                                         "the load op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::map_to_vector(
        offsetsAsValues, [](Value v) { return v.getType(); });

    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");

    unsigned operandIdx = operand->getOperandNumber();
    // The type the warp op already chose for this result; the new load's
    // result must be reconciled back to it.
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);

    // Create a new load op outside the warp op with the distributed vector
    // type.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
                                                  // descriptor type does not
                                                  // contain layout info.
    SmallVector<Value> newLoadOperands{
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter)};
    // Collect offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    auto newLoadOp = xegpu::LoadNdOp::create(
        rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        newLoadOperands, loadOp->getAttrs());
    xegpu::removeLayoutAttrs(newLoadOp);
    // Set the packed attribute if the layout requires it.
    newLoadOp.setPacked(xegpu::requirePacked(layout));
    // Set the transpose attribute if the layout requires it.
    if (xegpu::requireTranspose(layout, uArch))
      newLoadOp.setTranspose(
          DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // There can be a conflict between the vector type distributed by the
    // warp op and (xegpu-specific) distributed type supported by the load
    // op. Resolve these mismatches by inserting a cast.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};
528
/// Distribute a dpas op feeding into vector.yield op for the enclosing
/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
/// The warp op will still contain the original op that will not be used by
/// the yield op (and should be cleaned up later). The yield op will
/// bypass the dpas's arguments. Appropriate cast ops are inserted if the
/// distributed types does not match expected xegpu SIMT types.
/// Example:
/// ```
///   #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
///   #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
///   #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (vector<8x1xf32>) {
///     ...
///     %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
///       vector<8x16xf32>
///     gpu.yield %dpas
///   }
/// ```
/// To
/// ```
///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
///   vector<8x1xf16>, vector<16x1xf16>) {
///     ...
///     %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
///       -> vector<8x16xf32>
///     gpu.yield %dead, %arg0, %arg1
///   }
///   %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
///   %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
///   %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
///     vector<8xf32>
///   %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
/// ```
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(warpOp,
                                         "warp result is not a xegpu::Dpas op");

    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();

    // Layouts for A, B and the output are carried as attributes on the op.
    xegpu::LayoutAttr layoutA =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutAAttr());
    xegpu::LayoutAttr layoutB =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutBAttr());
    xegpu::LayoutAttr layoutOut =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutCdAttr());

    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");

    // Types the warp op will use when yielding each operand.
    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());

    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");

    llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
                                               dpasOp.getRhs()};
    llvm::SmallVector<Type, 3> newYieldTypes{
        distLhsTypeByWarpOpOrFailure.value(),
        distRhsTypeByWarpOpOrFailure.value()};
    // Dpas acc operand is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }
    // Create a new warp op without the dpas.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);

    // Types the SIMT dpas op actually expects (may differ from the warp op's
    // distributed types above; mismatches are reconciled below).
    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);

    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");
    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;

    // Resolve the distributed types with the original types.
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);

    // newRetIndices and newDpasOperandExpectedTypes are index-aligned:
    // [lhs, rhs, (acc)].
    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
                                           distributedResultTy, newDpasOperands,
                                           dpasOp->getAttrs());
    xegpu::removeLayoutAttrs(newDpasOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type.
    Value typeResolved =
        resolveDistributedTy(newDpasOp.getResult(),
                             distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};
659
/// Distribute a prefetch_nd op at the end of enclosing
/// `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed
/// through the warp op interface they would be propagated as returned values.
/// Tensor descriptor shape is not distributed because it is a uniform value
/// across all work items within the subgroup. Appropriate cast ops are inserted
/// if the distributed types does not match expected xegpu SIMT types.
///
/// Example:
///
/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   gpu.warp_execute_on_lane_0(%laneid) -> () {
///     ...
///     xegpu.prefetch_nd %arg0 [%x, %y] : !xegpu.tensor_desc<4x8xf32, #layout0>
///   }
/// ```
/// To
/// ```
///   %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (
///     !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
///     gpu.yield %arg0, %x, %y: !xegpu.tensor_desc<4x8xf32, #layout0>, index,
///     index
///   }
///   %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
///     #layout0> -> !xegpu.tensor_desc<4x8xf32>
///   xegpu.prefetch_nd %1 [%r#1, %r#2] : !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Only match a prefetch that immediately precedes the terminator, so it
    // cannot be moved past any barrier synchronization.
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
    if (!prefetchOp)
      return failure();

    SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
    // PrefetchNdOp must have offsets.
    if (offsets.empty())
      return rewriter.notifyMatchFailure(prefetchOp,
                                         "the prefetch op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::map_to_vector(
        offsetsAsValues, [](Value v) { return v.getType(); });

    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");

    SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
    SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
    newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
    // Create a new prefetch op outside the warp op with updated tensor
    // descriptor type. Source tensor descriptor require type resolution.
    xegpu::TensorDescType newTensorDescTy =
        prefetchOp.getTensorDescType().dropLayouts();
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
    // Collect offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    Operation *newPrefetchOp = xegpu::PrefetchNdOp::create(
        rewriter, newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
        prefetchOp->getAttrs());
    xegpu::removeLayoutAttrs(newPrefetchOp);
    rewriter.eraseOp(prefetchOp);
    return success();
  }
};
738
739/// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
740/// region. This will simply move the barrier op outside of the warp op.
741struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
742 using gpu::WarpDistributionPattern::WarpDistributionPattern;
743 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
744 PatternRewriter &rewriter) const override {
745 gpu::YieldOp yield = warpOp.getTerminator();
746 Operation *lastNode = yield->getPrevNode();
747 // The last node must be a gpu::BarrierOp.
748 auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
749 if (!barrierOp)
750 return failure();
751 // Move the barrier op outside of the warp op.
752 rewriter.setInsertionPointAfter(warpOp);
753 gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
754 barrierOp->getResultTypes(),
755 barrierOp->getOperands(), barrierOp->getAttrs());
756 rewriter.eraseOp(barrierOp);
757 return success();
758 }
759};
760
761/// Distribute a scattered store op. The offsets argument is required.
762/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
763/// The layouts are fixed and implicit: one offset/mask per lane.
764/// The pass changes the offset/mask vector shapes to a
765/// single-element vector, **it is assumed that their producer will also be
766/// distributed**. The payload vector also has a fixed distribution:
767/// no chunk size -> vector of one element.
768/// chunk size -> vector of the innermost dimension of the SG-payload.
769/// Example 1 (no chunk size):
770/// %mask = producer_op : vector<16xi1>
771/// %offset = producer_op : vector<16xindex>
772/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
773/// memref<256xf16>, vector<16xindex>, vector<16xi1>
774/// To
775/// %mask = producer_op : vector<1xi1>
776/// %offset = producer_op : vector<1xindex>
777/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
778/// memref<256xf16>, vector<1xindex>, vector<1xi1>
779/// Example 2 (chunk size, same mask and offsets):
780/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
781/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
782/// To
783/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
784/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
785///
786/// Note that the store distribution pattern also handles leading unit
787/// dimensions in the payload, mask and offsets vectors. In this case the store
788/// distribution will only change the dimensions corresponding to the SG
789/// distribution and keep the leading unit dimensions unchanged.
790/// For example, a store with payload vector<1x16xf16> with lane layout [1, 16 ]
791/// will be distributed as vector<1x1xf16>. Shapecast ops are inserted for the
/// offset/mask/payload when necessary so that the distributed store is working
793/// on 1D shape vector to match the HW capability.
794struct StoreDistribution final : public gpu::WarpDistributionPattern {
795 using gpu::WarpDistributionPattern::WarpDistributionPattern;
796 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
797 PatternRewriter &rewriter) const override {
798 Operation *lastNode = warpOp.getTerminator()->getPrevNode();
799 auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
800 if (!storeScatterOp)
801 return failure();
802 auto offsets = storeScatterOp.getOffsets();
803 if (!offsets || !isa<VectorType>(offsets.getType()))
804 return rewriter.notifyMatchFailure(
805 storeScatterOp, "Store op must have a vector of offsets argument");
806 VectorType offsetsTy = cast<VectorType>(offsets.getType());
807 VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
808 VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
809
810 // Add handling for leading unit dimensions support
811 int chunkSize = storeScatterOp.getChunkSize().value_or(1);
812 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
813
814 // Check that all leading dimensions are unit dimensions
815 for (int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
816 if (storeVecTy.getShape()[i] != 1) {
817 return rewriter.notifyMatchFailure(
818 storeScatterOp, "Only unit dimensions allowed for the leading "
819 "dimensions of the store vector!");
820 }
821 }
822
823 auto layoutPayload =
824 xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(0));
825 auto layoutOffsets =
826 xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(2));
827 auto layoutMask =
828 xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(3));
829
830 FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
831 getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
832 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
833 getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
834 FailureOr<VectorType> distMaskByWarpOpOrFailure =
835 getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
836 if (failed(distStoreVecByWarpOpOrFailure) ||
837 failed(distOffsetsByWarpOpOrFailure) ||
838 failed(distMaskByWarpOpOrFailure)) {
839 return rewriter.notifyMatchFailure(
840 storeScatterOp,
841 "Some vector operands have no layouts, using defaults instead.");
842 }
843
844 VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
845 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
846 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
847
848 SmallVector<size_t> newRetIndices;
849 SmallVector<Value> operands = storeScatterOp->getOperands();
850 SmallVector<Type> operandTypesToYield = {
851 distPayloadTy, operands[1].getType(), distOffsetsTy, distMaskTy};
852
853 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
854 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
855
856 rewriter.setInsertionPointAfter(newWarpOp);
857
858 // Distributed store payload type is always 1D without leading unit dims
859 VectorType payloadTy1D = VectorType::get({distPayloadTy.getNumElements()},
860 distPayloadTy.getElementType());
861
862 VectorType distOffsetsTy1D = VectorType::get(
863 {distOffsetsTy.getNumElements()}, distOffsetsTy.getElementType());
864 VectorType distMaskTy1D = VectorType::get({distMaskTy.getNumElements()},
865 distMaskTy.getElementType());
866
867 // Resolve distributed types to 1D for SIMT execution
868 Value distPayloadVal = resolveDistributedTy(
869 newWarpOp.getResult(newRetIndices[0]), payloadTy1D, rewriter);
870 Value distOffsetVal = resolveDistributedTy(
871 newWarpOp.getResult(newRetIndices[2]), distOffsetsTy1D, rewriter);
872 Value distMaskVal = resolveDistributedTy(
873 newWarpOp.getResult(newRetIndices[3]), distMaskTy1D, rewriter);
874
875 SmallVector<Value> newStoreScatterOpOperands = {
876 distPayloadVal, newWarpOp.getResult(newRetIndices[1]), distOffsetVal,
877 distMaskVal};
878
879 xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
880 rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
881 storeScatterOp->getAttrs());
883 rewriter.eraseOp(storeScatterOp);
884 return success();
885 }
886};
887
888static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
889 PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
890 Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
891 SmallVector<Value> newCoods;
892 auto maybeCoords =
893 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
894 if (failed(maybeCoords))
895 return {};
896 assert(maybeCoords.value().size() == 1 &&
897 "Expected one set of distributed offsets");
899 rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
900 getAsOpFoldResult(origOffsets));
901 newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
902 return newCoods;
903}
904
905/// Pattern for distributing xegpu::LoadMatrixOp.
906struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
907 using gpu::WarpDistributionPattern::WarpDistributionPattern;
908 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
909 PatternRewriter &rewriter) const override {
910 gpu::YieldOp yield = warpOp.getTerminator();
911 Operation *lastNode = yield->getPrevNode();
912 auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
913 if (!matrixOp)
914 return failure();
915
916 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
917 return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
918 });
919 if (!producedByLastLoad)
920 return rewriter.notifyMatchFailure(
921 warpOp, "The last op is not xegpu::LoadMatrixOp");
922 const int operandIdx = producedByLastLoad->getOperandNumber();
923
924 VectorType sgPayloadTy =
925 dyn_cast<VectorType>(matrixOp.getResult().getType());
926 VectorType warpResultTy =
927 cast<VectorType>(warpOp.getResult(operandIdx).getType());
928 if (!sgPayloadTy)
929 return rewriter.notifyMatchFailure(
930 matrixOp, "the matrix op payload must be a vector type");
931
932 auto loc = matrixOp.getLoc();
933 auto offsets = matrixOp.getMixedOffsets();
934 if (offsets.empty())
935 return rewriter.notifyMatchFailure(matrixOp,
936 "the load op must have offsets");
937 SmallVector<Value> offsetsAsValues =
938 vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
939
940 auto layout = matrixOp.getLayoutAttr();
941 if (!layout)
942 return rewriter.notifyMatchFailure(
943 matrixOp, "the matrix operation lacks layout attribute");
944
945 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
946 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
947 if (failed(distPayloadByWarpOpOrFailure))
948 return rewriter.notifyMatchFailure(
949 matrixOp, "Failed to distribute matrix op payload based on layout.");
950
951 SmallVector<Value> operands = {matrixOp.getMemDesc()};
952 const unsigned offsetsStartIdx = operands.size();
953 operands.append(offsetsAsValues);
954
955 SmallVector<Type> operandTypes =
956 llvm::map_to_vector(operands, [](Value v) { return v.getType(); });
957
958 SmallVector<size_t> newRetIndices;
959 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
960 rewriter, warpOp, operands, operandTypes, newRetIndices);
961 SmallVector<Value> newOperands = llvm::map_to_vector(
962 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
963
964 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
965 ShapedType::kDynamic);
966 DenseI64ArrayAttr newConstOffsetsAttr =
967 rewriter.getDenseI64ArrayAttr(newConstOffsets);
968 ValueRange currentOffsets =
969 ValueRange(newOperands).drop_front(offsetsStartIdx);
970
971 SmallVector<Value> newCoords = currentOffsets;
972 rewriter.setInsertionPointAfter(newWarpOp);
973
974 if (!matrixOp.getSubgroupBlockIoAttr()) {
975 newCoords = computeDistributedCoordinatesForMatrixOp(
976 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
977 currentOffsets);
978 }
979 xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
980 rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
981 newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
982 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
983 // Resolve the output type and replace all uses.
984 rewriter.replaceAllUsesWith(
985 newWarpOp.getResult(operandIdx),
986 resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
987 return success();
988 }
989};
990
991/// Pattern for distributing xegpu::StoreMatrixOp.
992struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
993 using gpu::WarpDistributionPattern::WarpDistributionPattern;
994 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
995 PatternRewriter &rewriter) const override {
996 gpu::YieldOp yield = warpOp.getTerminator();
997 Operation *lastNode = yield->getPrevNode();
998 auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
999 if (!matrixOp)
1000 return failure();
1001
1002 VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
1003 if (!sgPayloadTy)
1004 return rewriter.notifyMatchFailure(
1005 matrixOp, "the matrix op payload must be a vector type");
1006
1007 auto loc = matrixOp.getLoc();
1008 auto offsets = matrixOp.getMixedOffsets();
1009 if (offsets.empty())
1010 return rewriter.notifyMatchFailure(matrixOp,
1011 "the store op must have offsets");
1012 SmallVector<Value> offsetsAsValues =
1013 vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
1014
1015 auto layout = matrixOp.getLayoutAttr();
1016 if (!layout)
1017 return rewriter.notifyMatchFailure(
1018 matrixOp, "the matrix operation lacks layout attribute");
1019
1020 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
1021 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
1022 if (failed(distPayloadByWarpOpOrFailure))
1023 return rewriter.notifyMatchFailure(
1024 matrixOp, "Failed to distribute matrix op payload based on layout.");
1025
1026 SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
1027 const unsigned offsetsStartIdx = operands.size();
1028 operands.append(offsetsAsValues);
1029
1030 SmallVector<Type> operandTypes =
1031 llvm::map_to_vector(operands, [](Value v) { return v.getType(); });
1032 operandTypes[0] = *distPayloadByWarpOpOrFailure;
1033
1034 SmallVector<size_t> newRetIndices;
1035 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1036 rewriter, warpOp, operands, operandTypes, newRetIndices);
1037 SmallVector<Value> newOperands = llvm::map_to_vector(
1038 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
1039
1040 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
1041 ShapedType::kDynamic);
1042 DenseI64ArrayAttr newConstOffsetsAttr =
1043 rewriter.getDenseI64ArrayAttr(newConstOffsets);
1044 ValueRange currentOffsets =
1045 ValueRange(newOperands).drop_front(offsetsStartIdx);
1046
1047 SmallVector<Value> newCoords = currentOffsets;
1048 rewriter.setInsertionPointAfter(newWarpOp);
1049
1050 if (!matrixOp.getSubgroupBlockIoAttr()) {
1051 newCoords = computeDistributedCoordinatesForMatrixOp(
1052 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
1053 currentOffsets);
1054 }
1055
1056 xegpu::StoreMatrixOp::create(
1057 rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
1058 ValueRange(newCoords), newConstOffsetsAttr,
1059 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
1060 rewriter.eraseOp(matrixOp);
1061 return success();
1062 }
1063};
1064
1065/// Distribute a scattered load op. The logic and requirements are the same as
1066/// for the scattered store distribution. The warpOp's payload vector is
1067/// expected to be distributed by the load's result consumer.
1068/// Example 1 (no chunk size):
1069/// %mask = producer_op : vector<16xi1>
1070/// %offset = producer_op : vector<16xindex>
1071/// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
1072/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
1073/// To
1074/// %mask = producer_op : vector<1xi1>
1075/// %offset = producer_op : vector<1xindex>
1076/// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
1077/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
1078/// Example 2 (chunk size, same mask and offsets):
1079/// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
1080/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
1081/// To
1082/// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
1083/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
1084///
1085/// Note that the load distribution pattern also handles leading unit dimensions
/// in the payload, mask, and offsets vector. The load distribution will only
1087/// change the dimensions corresponding to the SG distribution and keep the
1088/// leading unit dimensions unchanged. For example, a load with result type
1089/// vector<1x16xf16> with lane layout [1, 16 ] will be distributed
1090/// as result type vector<1x1xf16>. Shapecast ops are inserted for the
/// offset/mask/payload when necessary so that the distributed load is working
1092/// on 1D shape vector to match the HW capability.
1093struct LoadDistribution final : public gpu::WarpDistributionPattern {
1094 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1095 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1096 PatternRewriter &rewriter) const override {
1097 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
1098 // Check if the yield operand that was produced by the *last* scattered
1099 // load op to avoid sinking it before barriers (maintain memory order).
1100 return isa<xegpu::LoadGatherOp>(op) &&
1101 warpOp.getTerminator()->getPrevNode() == op;
1102 });
1103 if (!producedByLastLoad)
1104 return rewriter.notifyMatchFailure(
1105 warpOp, "The last op is not xegpu::LoadGatherOp");
1106
1107 auto loadGatherOp =
1108 producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
1109 auto offsets = loadGatherOp.getOffsets();
1110 if (!offsets || !isa<VectorType>(offsets.getType()) ||
1111 !isa<VectorType>(loadGatherOp.getMask().getType()))
1112 return rewriter.notifyMatchFailure(
1113 loadGatherOp,
1114 "Load op must have a vector arguments for offsets and mask");
1115 VectorType offsetsTy = cast<VectorType>(offsets.getType());
1116 VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
1117 VectorType resultVecTy =
1118 cast<VectorType>(loadGatherOp.getResult().getType());
1119 // add handling leading unit dimensions support
1120 int chunkSize = loadGatherOp.getChunkSize().value_or(1);
1121 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
1122 for (int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
1123 if (resultVecTy.getShape()[i] != 1) {
1124 return rewriter.notifyMatchFailure(
1125 loadGatherOp, "Only unit dimensions allowed for the leading "
1126 "dimensions of the load vector!");
1127 }
1128 }
1129
1130 auto layoutOffsets =
1131 xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
1132 auto layoutMask = xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));
1133
1134 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
1135 getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
1136 FailureOr<VectorType> distMaskByWarpOpOrFailure =
1137 getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
1138 if (failed(distOffsetsByWarpOpOrFailure) ||
1139 failed(distMaskByWarpOpOrFailure)) {
1140 return rewriter.notifyMatchFailure(
1141 loadGatherOp,
1142 "Some vector operands have no layouts, using defaults instead.");
1143 }
1144
1145 SmallVector<size_t> newRetIndices;
1146 SmallVector<Value> operands = loadGatherOp->getOperands();
1147
1148 const unsigned operandIdx = producedByLastLoad->getOperandNumber();
1149 VectorType distResultTy =
1150 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1151 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
1152 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
1153
1154 SmallVector<Type> operandTypesToYield = {operands[0].getType(),
1155 distOffsetsTy, distMaskTy};
1156
1157 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1158 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
1159
1160 rewriter.setInsertionPointAfter(newWarpOp);
1161
1162 // Distributed load op will always be 1D.
1163 VectorType loadVecTy1D = VectorType::get({distResultTy.getNumElements()},
1164 distResultTy.getElementType());
1165
1166 VectorType distOffsetsTy1D =
1167 VectorType::get({distOffsetsByWarpOpOrFailure.value().getNumElements()},
1168 distOffsetsByWarpOpOrFailure.value().getElementType());
1169 VectorType distMaskTy1D =
1170 VectorType::get({distMaskByWarpOpOrFailure.value().getNumElements()},
1171 distMaskByWarpOpOrFailure.value().getElementType());
1172
1173 Value distOffsetVal = resolveDistributedTy(
1174 newWarpOp.getResult(newRetIndices[1]), distOffsetsTy1D, rewriter);
1175 Value distmaskVal = resolveDistributedTy(
1176 newWarpOp.getResult(newRetIndices[2]), distMaskTy1D, rewriter);
1177
1178 SmallVector<Value> newLoadGatherOperands = {
1179 newWarpOp.getResult(newRetIndices[0]), distOffsetVal, distmaskVal};
1180
1181 xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
1182 rewriter, newWarpOp.getLoc(), loadVecTy1D, newLoadGatherOperands,
1183 loadGatherOp->getAttrs());
1185 Value distributedVal = newWarpOp.getResult(operandIdx);
1186 // Resolve the output type and replace all uses.
1187 rewriter.replaceAllUsesWith(
1188 distributedVal,
1189 resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
1190 return success();
1191 }
1192};
1193
1194// Sink SG-uniform ops. An op is uniform if none
1195// of its operands/results has a distribution layout attribute.
1196// Non-uniform vectors are handled by dedicated patterns.
1197// This pattern must have a higher priority than vector dialect distribution
1198// patterns, because a distributable shape may be logically intended as
1199// uniform (i.e., no layout), so we want to omit its distribution.
1200struct SinkUniformOps final : public gpu::WarpDistributionPattern {
1201 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1202 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1203 PatternRewriter &rewriter) const override {
1204 // Take the last op
1205 Operation *warpRegionPreYieldOp = warpOp.getTerminator()->getPrevNode();
1206 // Any ops with nested regions must be handled carefully in dedicated
1207 // patterns.
1208 if (!warpRegionPreYieldOp || warpRegionPreYieldOp->getNumRegions())
1209 return failure();
1210 int operandIdx = -1;
1211 if (warpRegionPreYieldOp->getNumResults()) {
1212 OpOperand *operand = getWarpResult(
1213 warpOp, [&](Operation *op) { return warpRegionPreYieldOp == op; });
1214 if (!operand)
1215 return failure();
1216 operandIdx = operand->getOperandNumber();
1217 if (warpRegionPreYieldOp->getResult(0).getType() !=
1218 warpOp.getResult(operandIdx).getType())
1219 return rewriter.notifyMatchFailure(warpOp,
1220 "The op result is not uniform.");
1221 }
1222
1223 // The op must have no layout-based operands or results.
1224 bool uniformValuesOnly =
1225 llvm::all_of(warpRegionPreYieldOp->getResults(), [](Value v) {
1226 return !xegpu::getDistributeLayoutAttr(v);
1227 });
1228 uniformValuesOnly &=
1229 llvm::all_of(warpRegionPreYieldOp->getOpOperands(), [](OpOperand &opr) {
1230 return !xegpu::getDistributeLayoutAttr(opr);
1231 });
1232 if (!uniformValuesOnly)
1233 return rewriter.notifyMatchFailure(warpOp,
1234 "Some values are not uniform.");
1235 SmallVector<size_t> newRetIndices;
1236 SmallVector<Value> operands =
1237 llvm::to_vector_of<Value>(warpRegionPreYieldOp->getOperands());
1238 SmallVector<Type> operandTypes =
1239 llvm::to_vector_of<Type>(warpRegionPreYieldOp->getOperandTypes());
1240 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1241 rewriter, warpOp, operands, operandTypes, newRetIndices);
1242
1243 rewriter.setInsertionPointAfter(newWarpOp);
1244 IRMapping operandMapper;
1245 for (auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
1246 operandMapper.map(warpRegionPreYieldOp->getOperand(oldOperandIdx),
1247 newWarpOp->getResult(newOperandIdx));
1248 Operation *clonedOp = rewriter.clone(*warpRegionPreYieldOp, operandMapper);
1249 if (!clonedOp->getNumResults())
1250 rewriter.eraseOp(warpRegionPreYieldOp);
1251 else {
1252 assert(operandIdx != -1 && "Expected a warp result for the operation");
1253 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx),
1254 clonedOp->getResult(0));
1255 }
1256 return success();
1257 }
1258};
1259
1260/// This patterns distribute the `vector.multi_reduction` operation across
1261/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
1262/// layouts for the source and accumulator vectors,
1263/// * If the reduction dimension is distributed across lanes, the reduction is
1264/// non-lane-local and the reduction is done using warp shuffles. Here we
1265/// simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in
1266/// the warp op body.
1267/// * If the reduction dimension is not distributed across lanes, the reduction
1268/// is lane-local. In this case, we yield the source and accumulator vectors
1269/// from the warp op and perform the lane-local reduction outside the warp op
1270/// using a sequence of ReductionOps.
1271/// Example 1 (Reduction is lane-local):
1272/// ```
1273/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1274/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1275/// %acc = "some_def"() : () -> (vector<32xf32>)
1276/// %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
1277/// vector<32xf32> gpu.yield %1 : vector<32xf32>
1278/// }
1279/// ```
1280/// is lowered to:
1281/// ```
1282/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
1283/// vector<1xf32>) {
1284/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1285/// %acc = "some_def"() : () -> (vector<32xf32>)
1286/// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
1287/// }
1288/// %c = arith.constant dense<0.0> : vector<1xf32>
1289/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
1290/// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
1291/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
1292/// ```
1293/// Example 2 (Reduction is non-lane-local):
1294/// ```
1295/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1296/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1297/// %acc = "some_def"() : () -> (vector<2xf32>)
1298/// %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
1299/// vector<2xf32>
1300/// gpu.yield %1 : vector<2xf32>
1301/// }
1302/// ```
1303/// is lowered to:
1304/// ```
1305/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1306/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1307/// %acc = "some_def"() : () -> (vector<2xf32>)
1308/// %1 = arith.constant dense<0.0> : vector<2xf32>
1309/// %2 = vector.extract %0[0] : vector<32xf32> from <vector<2x32xf32>>
1310/// %3 = ("warp.reduction %2") : f32
1311/// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
1312/// ... repeat for row 1
1313/// gpu.yield %1 : vector<2xf32>
1314/// }
struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a yielded value produced by a vector.multi_reduction.
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
    if (!yieldOperand)
      return failure();
    auto reductionOp =
        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
    unsigned operandIdx = yieldOperand->getOperandNumber();
    VectorType sourceType = reductionOp.getSourceVectorType();
    int64_t sourceRank = sourceType.getRank();
    // Need at least a 2D source vector.
    if (sourceRank < 2)
      return rewriter.notifyMatchFailure(warpOp,
                                         "Only 2D+ reductions are supported.");
    // Leading dimensions (first rank-2) must be unit (size 1).
    for (int64_t i = 0; i < sourceRank - 2; ++i) {
      if (sourceType.getShape()[i] != 1)
        return rewriter.notifyMatchFailure(
            warpOp, "Only unit dimensions allowed for the leading dimensions.");
    }
    // Effective dimension indices (last 2 dims of the source).
    int64_t rowIdx = sourceRank - 2;
    int64_t columnIdx = sourceRank - 1;
    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
    if (reductionDims.size() != 1)
      return rewriter.notifyMatchFailure(warpOp,
                                         "Only 1 reduction dim is supported.");
    int64_t reductionDim = reductionDims[0];
    // The reduction dim must be among the last 2 dims.
    if (reductionDim != rowIdx && reductionDim != columnIdx)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction dim must be among the last 2 dimensions.");
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    VectorType resultType = cast<VectorType>(reductionOp.getType());
    // Layout of the reduction source determines how it is distributed.
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getTemporaryLayout(reductionOp->getOpOperand(0));

    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "Failed to distribute the source vector type.");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // Only single dimension distribution among the last 2 dims is supported.
    // A dim counts as distributed if its distributed size differs from its
    // subgroup-level size.
    bool rowDistributed =
        sourceDistType.getShape()[rowIdx] != sourceType.getShape()[rowIdx];
    bool columnDistributed = sourceDistType.getShape()[columnIdx] !=
                             sourceType.getShape()[columnIdx];
    if (rowDistributed && columnDistributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting source to be distributed in a single dimension.");
    int64_t sourceDistDim =
        rowDistributed ? rowIdx : (columnDistributed ? columnIdx : -1);
    if (sourceDistDim == -1)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed source vector.");
    bool resultDistributed =
        distributedResultType.getNumElements() < resultType.getNumElements();
    // If the lane owns all the data required for reduction (i.e. reduction is
    // fully parallel across lanes), then each lane owns part of the result
    // (i.e. result is distributed). If the reduction requires cross-lane
    // shuffling, then the result is shared among all lanes (broadcasted).
    // Therefore we expect the following cases:
    //
    // | Source vector     | Reduction dim | Result vector |
    // |-------------------|---------------|---------------|
    // | dim-0 distributed | 0             | broadcasted   |
    // | dim-0 distributed | 1             | distributed   |
    // | dim-1 distributed | 0             | distributed   |
    // | dim-1 distributed | 1             | broadcasted   |

    bool isReductionLaneLocal =
        (sourceDistDim == rowIdx && reductionDim == columnIdx) ||
        (sourceDistDim == columnIdx && reductionDim == rowIdx);
    if (isReductionLaneLocal && !resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed result for lane-local reduction.");

    if (!isReductionLaneLocal && resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp,
          "Expecting a broadcasted result for non-lane-local reduction.");

    // Handle lane-local reduction case. In this case we fully distribute the
    // reduction result.
    if (isReductionLaneLocal) {
      // Yield the source and acc vectors from the WarpOp.
      SmallVector<size_t> newRetIndices;
      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
          {sourceDistType, distributedResultType}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      // NOTE(review): `result` below should be produced by the
      // reduction-lowering helper; the statement's opening line appears to be
      // missing here — verify against upstream before building.
          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
          reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
      // Replace the warp op result with the final result.
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
      return success();
    }
    // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
    // of multiple ReductionOps. Actual distribution is done by the
    // WarpOpReduction pattern.
    rewriter.setInsertionPointAfter(reductionOp);
    // NOTE(review): same as above — the opening line of the statement that
    // defines `result` appears to be missing; verify against upstream.
        cast<TypedValue<VectorType>>(reductionOp.getSource()),
        cast<TypedValue<VectorType>>(reductionOp.getAcc()),
        reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
    // Replace the warp op result with the final result.
    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
    return success();
  }
};
1432
1433/// This pattern distributes the `vector.broadcast` operation across lanes in a
1434/// warp. The pattern supports three use cases:
1435///
1436/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
1437/// vector
1438/// must have a slice layout of the result. If the distributed source and
1439/// target vector types are identical, this lowers to a no-op; otherwise, it
1440/// remains a broadcast but operates on distributed vectors.
1441///
1442/// 2) Broadcast a same-rank vector with identical layouts for source and
1443/// target:
1444/// The source vector must have unit dimensions, and lane_data must be unit
1445/// size for those unit dims. This always lowers to a no-op.
1446///
1447/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from
1448/// scalar to distributed result type.
1449///
1450/// Example 1 (lowering to a broadcast with distributed types):
1451/// ```
1452/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1453/// %0 = "some_def"() {layout_result_0 =
1454/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1455/// dims = [0]> } : () -> (vector<32xf32>)
1456/// %2 = vector.broadcast %0 {layout_result_0 =
1457/// #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
1458/// : vector<32xf32> to vector<8x32xf32>
/// gpu.yield %2 : vector<8x32xf32>
1460/// }
1461/// ```
1462/// is lowered to:
1463/// ```
1464/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1465/// %0 = "some_def"() {layout_result_0 =
1466/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1467/// dims = [0]> } : () -> (vector<32xf32>)
1468/// gpu.yield %0 : vector<32xf32>
1469/// }
1470/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
1471///
1472/// Example 2 (no-op):
1473/// ```
1474/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) {
1475/// %0 = "some_def"() {layout_result_0 =
1476/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1477/// dims = [1]> } : () -> (vector<8xf32>)
1478/// %1 = vector.shape_cast %0
1479/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1480/// 1]>}: vector<8xf32> to vector<8x1xf32>
1481/// %2 = vector.broadcast %1
1482/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1483/// 1]>}: vector<8x1xf32> to vector<8x32xf32>
/// gpu.yield %2 : vector<8x32xf32>
1485/// }
1486/// ```
1487/// is lowered to:
1488/// ```
1489/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1490/// %0 = "some_def"() {layout_result_0 =
1491/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1492/// dims = [1]> } : () -> (vector<8xf32>)
1493/// %1 = vector.shape_cast %0
1494/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1495/// 1]>}: vector<8xf32> to vector<8x1xf32>
1496/// gpu.yield %1 : vector<8x1xf32>
1497/// }
1498/// // The broadcast is implicit through layout transformation (no-op)
1499/// "some_use"(%r#0)
1500/// ```
struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  // Matches a warp op yielding a vector.broadcast result and sinks the
  // broadcast outside the warp op so it operates on distributed types.
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a yielded value produced by a vector.broadcast inside the warp op.
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!yieldOperand)
      return failure();
    auto broadcastOp =
        cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
    unsigned operandIdx = yieldOperand->getOperandNumber();

    // Null if the broadcast source is a scalar (case 3 below).
    VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType destType =
        dyn_cast<VectorType>(broadcastOp.getResult().getType());

    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getTemporaryLayout(broadcastOp->getOpOperand(0));
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(dyn_cast<OpResult>(broadcastOp.getResult()));

    FailureOr<VectorType> sourceDistType;
    // Either the distributed source vector type (cases 1 and 2) or the scalar
    // element type (case 3) that the rebuilt warp op will yield.
    Type sourceElemOrDistType;
    if (sourceType) {

      // Case 1 and 2: source is a vector type.
      int64_t rankDiff = destType.getRank() - sourceType.getRank();
      if (rankDiff > 0) {
        // Case 1: source is lower-rank than result. The source layout must be
        // a slice of the result layout; emit a warning (not a match failure)
        // if it is not.
        bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
        if (!isSliceOf)
          broadcastOp.emitWarning()
              << "Broadcast input layout must be a slice of result layout.";
      }
      // Case 2: source and result have the same rank.
      if (rankDiff == 0) {
        auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
        SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
                                               broadcastUnitDimsSet.end());
        assert(sourceLayout.isEqualTo(
                   sourceLayout.setUnitDimData(broadcastUnitDims)) &&
               "The sg_data for unit dimensions should be set as 1");
        // Force unit lane layout on the broadcasted unit dims so the source
        // distributes to the same per-lane shape as the result.
        sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
      }

      sourceDistType =
          getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
      if (failed(sourceDistType)) {
        return rewriter.notifyMatchFailure(
            warpOp, "Failed to distribute the source vector type.");
      }
      sourceElemOrDistType = sourceDistType.value();

    } else {
      // Case 3: source is a scalar type.
      if (sourceLayout) {
        return rewriter.notifyMatchFailure(
            warpOp, "Broadcast from scalar must not have a layout attribute.");
      }
      sourceElemOrDistType = broadcastOp.getSourceType();
    }
    FailureOr<VectorType> destDistType =
        getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
    if (failed(destDistType)) {
      return rewriter.notifyMatchFailure(
          warpOp, "Failed to distribute the dest vector type.");
    }

    // Yield the broadcast source from a rebuilt warp op.
    // NOTE(review): the call line creating `newWarpOp` via
    // moveRegionToNewWarpOpAndAppendReturns is elided in this rendering —
    // confirm against the original file.
    SmallVector<size_t> newRetIndices;
        rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
        newRetIndices);

    Value distributedSource = newWarpOp.getResult(newRetIndices[0]);

    // If the distributed source type already matches the distributed dest
    // type, the broadcast is a no-op; otherwise recreate it on distributed
    // types outside the warp op.
    Value newBroadcast = distributedSource;

    if (sourceElemOrDistType != destDistType.value()) {
      rewriter.setInsertionPointAfter(newWarpOp);
      newBroadcast =
          vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
                                      destDistType.value(), distributedSource);
    }

    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
    return success();
  }
};
1589
1590/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
1591/// `gpu.warp_execute_on_lane_0` region.
struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  // Sinks a yielded vector.shape_cast outside the warp op, recreating it on
  // the distributed source/result types derived from the assigned layouts.
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!yieldOperand)
      return failure();
    auto shapeCastOp =
        cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
    unsigned operandNumber = yieldOperand->getOperandNumber();
    // The distributed result type is already fixed by the warp op signature.
    auto resultDistTy =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getTemporaryLayout(shapeCastOp->getOpOperand(0));
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(dyn_cast<OpResult>(shapeCastOp.getResult()));
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          warpOp,
          "the source or result of shape_cast op lacks distribution layout");

    // NOTE(review): the getDistVecTypeBasedOnLaneLayout(...) call line is
    // elided in this rendering — confirm against the original file.
    FailureOr<VectorType> sourceDistTypeOrFailure =
        shapeCastOp.getSourceVectorType());
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "failed to get distributed vector type for source");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // Create a new warp op that yields the source of the shape_cast op.
    // NOTE(review): the moveRegionToNewWarpOpAndAppendReturns call line
    // creating `newWarpOp` is elided in this rendering.
    SmallVector<size_t> newRetIndices;
        rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value source = newWarpOp.getResult(newRetIndices[0]);
    // Create a new shape_cast op outside the warp op.
    Value newShapeCast = vector::ShapeCastOp::create(
        rewriter, shapeCastOp.getLoc(), resultDistTy, source);
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newShapeCast);
    return success();
  }
};
1636
1637// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
1638// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1639// advanced cases where the distributed dimension is partially extracted and
1640// currently not supported by the generic vector distribution patterns.
struct VectorExtractStridedSliceDistribution
  // NOTE(review): the base-clause line (`: public
  // gpu::WarpDistributionPattern {`) is elided in this rendering — confirm
  // against the original file.
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
    if (!operand)
      return failure();
    auto extractOp =
        cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
    unsigned operandIdx = operand->getOperandNumber();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    // Find the distributed dimensions by comparing the sequential result type
    // against the distributed warp result type.
    auto extractResultType = cast<VectorType>(operand->get().getType());
    auto distributedDims =
        getDistributedDims(extractResultType, distributedType);
    // Collect updated source type, sizes and offsets. They may be adjusted
    // later if the data is distributed to lanes (as opposed to being owned by
    // all lanes uniformly).
    VectorType updatedSourceType = extractOp.getSourceVectorType();
    SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
        extractOp.getSizes(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
        extractOp.getOffsets(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
        extractOp.getStrides(), [](Attribute attr) { return attr; });
    // If the provided sizes, offsets, strides are less than the rank, pad them
    // with full sizes, zero offsets, and unit strides. This makes it easier to
    // adjust them later.
    int64_t sourceRank = extractOp.getSourceVectorType().getRank();
    for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
      updatedSizes.push_back(rewriter.getI64IntegerAttr(
          extractOp.getSourceVectorType().getDimSize(i)));
      updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
      updatedStrides.push_back(
          rewriter.getI64IntegerAttr(1)); // stride is always 1.
    }
    // If the result is distributed, it must be distributed in exactly one
    // dimension. In this case, we adjust the sourceDistType, distributedSizes
    // and distributedOffsets accordingly.
    if (distributedDims.size() > 0) {
      if (distributedDims.size() != 1)
        return rewriter.notifyMatchFailure(
            warpOp, "Source can not be distributed in multiple dimensions.");
      int64_t distributedDim = distributedDims[0];
      int sourceDistrDimSize =
          extractOp.getSourceVectorType().getShape()[distributedDim];
      auto sourceLayout = xegpu::getTemporaryLayout(extractOp->getOpOperand(0));
      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            warpOp, "the source of extract_strided_slice op lacks distribution "
                    "layout");
      auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
      // Because only single dimension distribution is supported, lane layout
      // size at the distributed dim must be the subgroup size.
      int subgroupSize = sourceLaneLayout[distributedDim];
      // Check if the source size in the distributed dimension is a multiple of
      // subgroup size.
      if (sourceDistrDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Source size along distributed dimension is not a multiple of "
            "subgroup size.");
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      // We expect lane data to be all ones in this case.
      if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
        return rewriter.notifyMatchFailure(
            warpOp, "Expecting unit lane data in source layout");
      // The offsets in the distributed dimension must be a multiple of
      // subgroup size.
      int64_t distrDimOffset =
          cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
      if (distrDimOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp, "Offset along distributed dimension "
                    "is not a multiple of subgroup size.");
      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
                              sourceLayout, extractOp.getSourceVectorType())
                              .value();
      // Update the distributed sizes to match the distributed type.
      updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
          distributedType.getDimSize(distributedDim));
      // Update the distributed offsets to match round robin distribution (i.e.
      // each lane owns data at `subgroupSize` stride given unit lane data).
      updatedOffsets[distributedDim] =
          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
    }
    // Do the distribution by yielding the source of the extract op from
    // the warp op and creating a new extract op outside the warp op.
    // NOTE(review): the moveRegionToNewWarpOpAndAppendReturns call line
    // creating `newWarpOp` is elided in this rendering.
    SmallVector<size_t> newRetIndices;
        rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value source = newWarpOp.getResult(newRetIndices[0]);
    // Create a new extract op outside the warp op.
    Value newExtractOp = vector::ExtractStridedSliceOp::create(
        rewriter, extractOp.getLoc(), distributedType, source,
        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
        ArrayAttr::get(rewriter.getContext(), updatedSizes),
        ArrayAttr::get(rewriter.getContext(), updatedStrides));
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
    return success();
  }
};
1748
1749/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
1750/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1751/// advanced cases where the distributed dimension is partially inserted and
1752/// currently not supported by the generic vector distribution patterns.
struct VectorInsertStridedSliceDistribution
  // NOTE(review): the base-clause line (`: public
  // gpu::WarpDistributionPattern {`) is elided in this rendering — confirm
  // against the original file.
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
      // Check if the InsertStridedSliceOp is the last op before yield op.
      return llvm::IsaPred<vector::InsertStridedSliceOp>(op) &&
             warpOp.getTerminator()->getPrevNode() == op;
    });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp =
        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Find the distributed dimensions of the dest vector.
    auto insertResultType = cast<VectorType>(operand->get().getType());
    auto destDistributedDims =
        getDistributedDims(insertResultType, distributedType);
    // Collect updated offsets, source type and dest type. They may be adjusted
    // later if the data is distributed to lanes (as opposed to being owned by
    // all lanes uniformly).
    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
        insertOp.getOffsets(), [](Attribute attr) { return attr; });
    VectorType updatedSourceType = insertOp.getSourceVectorType();
    VectorType updatedDestType = insertOp.getDestVectorType();
    if (destDistributedDims.size() > 0) {
      // Only single dimension distribution is supported.
      if (destDistributedDims.size() != 1)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Expecting source to be distributed in a single dimension.");
      int64_t destDistributedDim = destDistributedDims[0];

      VectorType srcType = insertOp.getSourceVectorType();
      VectorType destType = insertOp.getDestVectorType();
      // Currently we require that both source (kD) and dest (nD) vectors are
      // distributed. This requires that distributedDim (d) is contained in the
      // last k dims of the dest vector (d >= n - k).
      int64_t sourceDistributedDim =
          destDistributedDim - (destType.getRank() - srcType.getRank());
      if (sourceDistributedDim < 0)
        return rewriter.notifyMatchFailure(
            insertOp,
            "distributed dimension must be in the last k (i.e. source "
            "rank) dims of dest vector");
      int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
      // Obtain the source and dest layouts.
      auto destLayout = xegpu::getTemporaryLayout(insertOp->getOpOperand(1));
      auto sourceLayout = xegpu::getTemporaryLayout(insertOp->getOpOperand(0));
      if (!destLayout || !sourceLayout ||
          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
          sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            warpOp, "the source or dest of insert_strided_slice op lacks "
                    "distribution layout");
      // Because only single dimension distribution is supported, lane layout
      // size at the distributed dim must be the subgroup size.
      int subgroupSize =
          destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
      // We require that source and dest lane data are all ones to ensure
      // uniform round robin distribution.
      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
          !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
        return rewriter.notifyMatchFailure(
            warpOp, "Expecting unit lane data in source and dest layouts");
      // Source distributed dim size must be multiples of subgroup size.
      if (srcDistrDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp, "Distributed dimension size in source is not a multiple of "
                    "subgroup size.");
      // Offsets in the distributed dimension must be multiples of subgroup
      // size.
      int64_t destDistrDimOffset =
          cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
      if (destDistrDimOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Offset along distributed dimension in dest is not a multiple of "
            "subgroup size.");
      // Update the source and dest types based on their layouts.
      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
                              sourceLayout, insertOp.getSourceVectorType())
                              .value();
      updatedDestType = getDistVecTypeBasedOnLaneLayout(
                            destLayout, insertOp.getDestVectorType())
                            .value();
      // Update the distributed offsets to match round robin distribution (i.e.
      // each lane owns data at `subgroupSize` stride given unit lane data).
      updatedOffsets[destDistributedDim] =
          rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
    }
    // Do the distribution by yielding the source and dest of the insert op
    // from the warp op and creating a new insert op outside the warp op.
    // NOTE(review): the moveRegionToNewWarpOpAndAppendReturns call line
    // creating `newWarpOp` is elided in this rendering.
    SmallVector<size_t> newRetIndices;
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {updatedSourceType, updatedDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
    Value dest = newWarpOp.getResult(newRetIndices[1]);
    // Create a new insert op outside the warp op.
    Value newInsertOp = vector::InsertStridedSliceOp::create(
        rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
        insertOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newInsertOp);
    return success();
  }
};
1869
1870/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
1871/// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
1872/// outside of the warp op.
struct MemrefExtractAlignedPointerAsIndexDistribution final
  // NOTE(review): the base-clause line (`: public
  // gpu::WarpDistributionPattern {`) is elided in this rendering — confirm
  // against the original file.
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp,
          "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
    auto extractOp =
        operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
    unsigned operandIdx = operand->getOperandNumber();
    // Yield the memref source from the warp op; its type is passed through
    // unchanged (no per-lane distribution of the memref itself).
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, extractOp.getSource(),
        TypeRange{extractOp.getSource().getType()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    // Recreate the pointer extraction on the yielded memref, outside the
    // warp op.
    auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, newWarpOp.getLoc(), extractOp.getType(),
        newWarpOp.getResult(newRetIndices[0]));
    Value resultVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(resultVal, newExtractOp.getResult());
    return success();
  }
};
1900
1901/// Distribute a vector::BitCastOp feeding into yield op of an enclosing
1902/// `gpu.warp_execute_on_lane_0` region. Bitcast only impacts the innermost
/// dimension of the source/result vectors. Equivalent vector::BitCastOp is
1904/// created outside of the warp op with distributed source vector type (computed
1905/// using assigned layout).
struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::BitCast op");
    auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
    unsigned operandIdx = operand->getOperandNumber();
    // Compute the distributed source type from the layout assigned to the
    // bitcast operand; a default-constructed (null) VectorType signals
    // failure.
    // NOTE(review): the getDistVecTypeBasedOnLaneLayout( call line is elided
    // in this rendering — confirm against the original file.
    VectorType distributedSourceType =
            xegpu::getTemporaryLayout(bitcastOp->getOpOperand(0)),
            bitcastOp.getSourceVectorType())
            .value_or(VectorType());
    if (!distributedSourceType)
      return rewriter.notifyMatchFailure(
          bitcastOp, "Failed to distribute the source vector type in "
                     "vector::BitCast op");
    // The distributed result type is fixed by the warp op signature.
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, bitcastOp.getSource(),
        TypeRange{distributedSourceType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    // Recreate the bitcast on distributed types outside the warp op.
    auto newBitcastOp = vector::BitCastOp::create(
        rewriter, newWarpOp.getLoc(), distributedResultType,
        newWarpOp.getResult(newRetIndices[0]));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
    return success();
  }
};
1941
1942/// Distribute a vector::TransposeOp feeding into yield op of an enclosing
1943/// `gpu.warp_execute_on_lane_0` region. Currently only 2D transposes are
1944/// supported. In most cases, transpose is a no op because it is entirely
1945/// handled using the layouts (e.g. 16x1 -> 1x16). However, if each lane owns
1946/// multiple slices of data after distribution (e.g. 16x2 -> 2x16), a lane-local
1947/// transpose (i.e. shuffle) is needed. Therefore, we create an equivalent
1948/// vector::TransposeOp outside of the warp op with distributed source vector
1949/// type (computed using assigned layout).
struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::Transpose op");
    auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
    unsigned operandIdx = operand->getOperandNumber();
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getTemporaryLayout(transposeOp->getOpOperand(0));
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(transposeOp->getOpResult(0));
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector of the transpose op lacks layout "
          "attribute");
    int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
    int64_t resultRank = transposeOp.getResultVectorType().getRank();
    // Only 2D transposes are supported for now.
    // TODO: Support nD transposes.
    if (sourceRank != 2 || resultRank != 2)
      return rewriter.notifyMatchFailure(
          transposeOp, "the source or result vector of the transpose op "
                       "does not have 2D layout");
    ArrayRef<int64_t> perm = transposeOp.getPermutation();
    // Result layout must be a transpose of source layout (checked at lane
    // granularity).
    if (!resultLayout.isTransposeOf(sourceLayout, perm,
                                    xegpu::LayoutKind::Lane))
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector layouts must be 2D transposes of each "
          "other");
    // NOTE(review): the getDistVecTypeBasedOnLaneLayout( call line is elided
    // in this rendering — confirm against the original file.
    FailureOr<VectorType> distributedSourceTypeOrFailure =
        transposeOp.getSourceVectorType());
    if (failed(distributedSourceTypeOrFailure))
      return rewriter.notifyMatchFailure(
          transposeOp, "Failed to distribute the source vector type in "
                       "vector::Transpose op");
    // Yield the transpose source from the warp op with its distributed type.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, transposeOp.getVector(),
        TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    // Recreate the transpose on the distributed vector outside the warp op;
    // this performs any required lane-local shuffle.
    auto newTransposeOp = vector::TransposeOp::create(
        rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
        perm);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
    return success();
  }
};
2006
2007/// Distribute a vector::StepOp with the sliced result layout.
2008/// The sliced layout must have exactly 1 effective lane dimension.
2009/// We completely resolve the vector::StepOp by computing the lane_data-sized
2010/// subranges.
2011struct VectorStepSliceDistribution final : public gpu::WarpDistributionPattern {
2012 using gpu::WarpDistributionPattern::WarpDistributionPattern;
2013 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
2014 PatternRewriter &rewriter) const override {
2015 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
2016 if (!operand)
2017 return rewriter.notifyMatchFailure(
2018 warpOp, "warp result is not a vector::StepOp op");
2019 auto stepOp = operand->get().getDefiningOp<vector::StepOp>();
2020 unsigned operandIdx = operand->getOperandNumber();
2021 xegpu::DistributeLayoutAttr resultLayout =
2022 xegpu::getTemporaryLayout(stepOp->getResult(0));
2023 if (!resultLayout)
2024 return rewriter.notifyMatchFailure(
2025 stepOp, "the result vector of the step op lacks layout "
2026 "attribute");
2027 auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resultLayout);
2028 if (!sliceLayout)
2029 return rewriter.notifyMatchFailure(
2030 stepOp, "the result layout must be a slice layout");
2031 if (sliceLayout.getEffectiveLaneLayoutAsInt().size() != 1)
2032 return rewriter.notifyMatchFailure(
2033 stepOp, "expecting 1 dim in the effective result layout");
2034
2035 rewriter.setInsertionPointAfter(warpOp);
2036 auto loc = stepOp.getLoc();
2037 auto stepResultVecTy = stepOp.getResult().getType();
2038 Value distributedVal = warpOp.getResult(operandIdx);
2039 VectorType newVecTy = cast<VectorType>(distributedVal.getType());
2040
2041 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
2042 rewriter, loc, warpOp.getLaneid(), stepResultVecTy.getShape());
2043 if (failed(laneDataBlockCoords))
2044 return rewriter.notifyMatchFailure(
2045 stepOp, "failed to compute lane data block coordinates");
2046
2047 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
2048 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
2049 assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
2050 newVecTy.getNumElements() / laneDataBlockLength);
2051 SmallVector<Value> stepVals;
2052 // For each lane_data block, reconstruct its sub-range
2053 // from the range of SG-level vector.step. Example: vector.step
2054 // {slice<layout<lane_layout=[2,4,2], lane_data=[1,2,1]>, dims=[0,2]>} :
2055 // vector<16xindex>
2056 // Each logical lane holds 4 elements as 2 blocks of 2 elements each.
2057 // The blocks are round-robin distributed, so logical lane id 0
2058 // holds values [0,1, 8,9].
2059 for (auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
2060 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
2061 stepVals.push_back(laneDataBlockStartCoord);
2062 for (int i = 1; i < laneDataBlockLength; ++i) {
2063 auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
2064 stepVals.push_back(arith::AddIOp::create(
2065 rewriter, loc, laneDataBlockStartCoord, offset));
2066 }
2067 }
2068 assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
2069 "Expecting the number of step values to match the number of "
2070 "elements in the vector");
2071 auto stepOpVal =
2072 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
2073 rewriter.replaceAllUsesWith(distributedVal, stepOpVal);
2074 return success();
2075 }
2076};
2077
// Erases xegpu.convert_layout ops whose input and target layouts are
// lane-compatible: at SIMT level such a conversion is a no-op, so the source
// value is forwarded directly.
struct ConvertLayoutDistribution
    : public OpRewritePattern<xegpu::ConvertLayoutOp> {
  // NOTE(review): a line (likely the inherited-constructor `using`
  // declaration) is elided in this rendering — confirm against the original
  // file.

  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    auto inputLayout = op.getInputLayoutAttr();
    auto targetLayout = op.getTargetLayoutAttr();
    auto resShape = cast<VectorType>(op.getResult().getType()).getShape();

    if (!inputLayout || !targetLayout)
      return rewriter.notifyMatchFailure(op, "missing layout attributes");

    // Incompatible lane layouts would need real data movement, which is not
    // implemented yet.
    SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
    if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
                                      xegpu::LayoutKind::Lane)) {
      return rewriter.notifyMatchFailure(
          op, "lowering incompatible convert_layout not yet supported");
    }
    // Compatible layouts: replace the op with its source value.
    rewriter.replaceOp(op, op.getSource());
    return success();
  }
};
2101
2102} // namespace
2103
2104namespace {
// Pass that drives the subgroup (SIMT) distribution; the work happens in
// runOnOperation, defined below.
// NOTE(review): the base-clause line naming the tablegen-generated pass base
// class is elided in this rendering — confirm against the original file.
struct XeGPUSubgroupDistributePass final
                                       XeGPUSubgroupDistributePass> {
  void runOnOperation() override;
};
2110} // namespace
2111
// NOTE(review): the signature line of this populate function (presumably
// `void xegpu::populateXeGPUSubgroupDistributePatterns(`) is elided in this
// rendering — confirm against the original file.
    RewritePatternSet &patterns) {
  // Register the XeGPU SIMT distribution patterns at regular priority.
  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
               GpuBarrierDistribution, VectorMultiReductionDistribution,
               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
               VectorBitcastDistribution, LoadMatrixDistribution,
               StoreMatrixDistribution, ConvertLayoutDistribution,
               MemrefExtractAlignedPointerAsIndexDistribution>(
      patterns.getContext(),
      /*pattern benefit=*/PatternHierarchy::Regular);
  // For following patterns, we need to override the regular vector distribution
  // patterns. Therefore, assign higher benefit.
  patterns
      .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
           VectorInsertStridedSliceDistribution, VectorBroadcastDistribution,
           VectorStepSliceDistribution, SinkUniformOps>(
          patterns.getContext(),
          /*pattern benefit=*/PatternHierarchy::AboveRegular);
}
2132
2134 RewritePatternSet &patterns) {
 // Single pattern that moves the body of a GPU function inside a
 // gpu.warp_execute_on_lane_0 op, as a precursor to SIMT distribution.
2135 patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
2136}
2137
2138void XeGPUSubgroupDistributePass::runOnOperation() {
2139 // Step 1: Attach layouts to op operands.
2140 // TODO: Following assumptions are made:
2141 // 1) It is assumed that there are no layout conflicts.
2142 // 2) Any existing layout attributes attached to the operands are ignored.
2143 Operation *op = getOperation();
 // Layout recovery failed; the pass cannot proceed without layouts.
2145 signalPassFailure();
2146 return;
2147 }
2148
2149 // Step 2: Move all operations of a GPU function inside
2150 // gpu.warp_execute_on_lane_0 operation.
2151 {
2152 RewritePatternSet patterns(&getContext());
2154
2155 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
2156 signalPassFailure();
2157 return;
2158 }
2159 // At this point, we have moved the entire function body inside the
2160 // warpOp. Now move any scalar uniform code outside of the warpOp (like
2161 // GPU index ops, scalar constants, etc.). This will simplify the
2162 // later lowering and avoid custom patterns for these ops.
2163 getOperation()->walk([&](Operation *op) {
2164 if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
2165 vector::moveScalarUniformCode(warpOp);
2166 });
2167 }
2168 // Step 3: Apply subgroup to workitem distribution patterns.
2169 RewritePatternSet patterns(&getContext());
2171 // distributionFn is used by vector distribution patterns to determine the
2172 // distributed vector type for a given vector value. In XeGPU subgroup
2173 // distribution context, we compute this based on lane layout.
2174 auto distributionFn = [](Value val) {
2175 VectorType vecType = dyn_cast<VectorType>(val.getType());
2176 int64_t vecRank = vecType ? vecType.getRank() : 0;
2177 if (vecRank == 0)
2178 return AffineMap::get(val.getContext());
2179 // Get the layout of the vector type.
2180 xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
2181 // If no layout is specified, assume uniform case (no distribution).
2182 if (!layout)
2183 return AffineMap::get(val.getContext());
2184 // Expecting vector and layout rank to match.
2185 assert(layout.getRank() == vecRank &&
2186 "Expecting vector and layout rank to match");
2187 // A dimension is distributed only if layout suggests there are
2188 // multiple lanes assigned for this dimension and the shape can be evenly
2189 // distributed to those lanes.
2190 SmallVector<unsigned int> distributedDims;
2191 for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
2192 if (v > 1 && vecType.getShape()[i] % v == 0)
2193 distributedDims.push_back(i);
2194 }
2195 return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
2196 val.getContext());
2197 };
2198 // TODO: shuffleFn is not used.
2199 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
2200 int64_t warpSz) { return Value(); };
2201
2202 vector::populateDistributeReduction(
2203 patterns, xegpu::subgroupReduction,
2204 /*pattern benefit=*/PatternHierarchy::Regular);
2205
2206 vector::populatePropagateWarpVectorDistributionPatterns(
2207 patterns, distributionFn, shuffleFn,
2208 /*pattern benefit=*/PatternHierarchy::Regular);
 // Apply the combined distribution pattern set greedily to a fixed point.
2209 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
2210 signalPassFailure();
2211 return;
2212 }
2213
2214 // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
2215 // due to tensor desc type mismatches created by using upstream distribution
2216 // patterns (scf.for). This cleanup should only be done if all the ops are
2217 // distributed successfully, if some ops are still not distributed and remain
2218 // inside any WarpExecuteOnLane0Op we avoid this simplification step to avoid
2219 // breaking the IR.
2220 bool foundWarpOp = false;
2221 getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
2222 // Look for WarpOps that are not trivially dead.
2223 if (isOpTriviallyDead(warpOp))
2224 return WalkResult::advance();
2225 foundWarpOp = true;
2226 return WalkResult::interrupt();
2227 });
2228 if (foundWarpOp)
2229 return;
2230
2231 getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
2232 // We are only interested in UnrealizedConversionCastOps that were added
2233 // for resolving SIMT type mismatches.
2234 if (!op->getAttr(resolveSIMTTypeMismatch))
2235 return WalkResult::skip();
2236
2237 Value input = op.getOperand(0);
2238 Value output = op.getResult(0);
2239
2240 // Both input and output must have tensor descriptor types.
2241 xegpu::TensorDescType inputDescType =
2242 mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
2243 xegpu::TensorDescType outputDescType =
2244 mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
2245 assert(inputDescType && outputDescType &&
2246 "Unrealized conversion cast must have tensor descriptor types");
2247
2248 // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
2249 // This occurs inside scf.for body to resolve the block argument type to
2250 // SIMT type.
2251 if (inputDescType.getLayout()) {
2252 auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
2253 if (argument) {
2254 argument.setType(output.getType());
2255 output.replaceAllUsesWith(argument);
 // Keep the loop result type in sync with the rewritten iter arg.
2256 if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
2257 argument.getOwner()->getParentOp())) {
2258 auto result = loopOp.getTiedLoopResult(argument);
2259 result.setType(output.getType());
2260 }
2261 }
2262 }
2263
2264 // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
2265 // conversions. This occurs at the yield op of scf.for body to go back
2266 // from SIMT type to original type.
2267 if (outputDescType.getLayout())
2268 output.replaceAllUsesWith(input);
2269
 // Erase the cast only once it is fully dead.
2270 if (op->use_empty())
2271 op->erase();
2272 return WalkResult::advance();
2273 });
2274}
return success()
static Type getElementType(Type type)
Determine the element type of type.
b getContext())
static const char *const resolveSIMTTypeMismatch
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
UnitAttr getUnitAttr()
Definition Builders.cpp:102
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:379
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:703
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:412
operand_type_range getOperandTypes()
Definition Operation.h:426
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition Value.h:149
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
const uArch * getUArch(llvm::StringRef archName)
bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
void populateXeGPUMoveFuncBodyToWarpOpPatterns(RewritePatternSet &patterns)
Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc arguments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requirePacked(const LayoutAttr layout)
Helper function to check if the layout is packed.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU SIMT distribution into patterns.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.
virtual int getSubgroupSize() const =0