1//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
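//
// This pass distributes XeGPU and vector operations from the subgroup (warp)
// level to individual work items (SIMT): gpu.func bodies are first wrapped in
// a gpu.warp_execute_on_lane_0 region, and the patterns defined below then
// sink the XeGPU ops out of that region with their types distributed across
// the lanes.
//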
21#include "mlir/IR/AffineMap.h"
22#include "mlir/IR/Attributes.h"
23#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinOps.h"
27#include "mlir/IR/Operation.h"
29#include "mlir/IR/TypeRange.h"
30#include "mlir/IR/Value.h"
31#include "mlir/IR/Visitors.h"
33#include "mlir/Support/LLVM.h"
37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/SmallVectorExtras.h"
41
42namespace mlir {
43namespace xegpu {
44#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
46} // namespace xegpu
47} // namespace mlir
48
49#define DEBUG_TYPE "xegpu-subgroup-distribute"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
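// Typical usage: LLVM_DEBUG(DBGS() << "some debug message\n");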
51
52using namespace mlir;
53
54static const char *const resolveSIMTTypeMismatch =
55 "resolve_simt_type_mismatch"; // Attribute name for identifying
56 // UnrealizedConversionCastOp added to resolve
57 // SIMT type mismatches.
58
59namespace {
60
61//===----------------------------------------------------------------------===//
62// SIMT Distribution Patterns
63//===----------------------------------------------------------------------===//
64
65/// In certain cases, we may need to favor XeGPU specific distribution patterns
66/// over generic vector distribution patterns. In such cases, we can assign
67/// priorities to patterns.
68enum PatternHierarchy : unsigned { Regular = 1, AboveRegular = 2 };
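// For example, an XeGPU-specific pattern that must take precedence over the
// generic vector distribution patterns could be registered with the higher
// benefit (illustrative sketch, not the actual registration code):
//   patterns.add<LoadDistribution>(ctx, /*benefit=*/AboveRegular);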
69
70/// Helper function to resolve types if the distributed type out of
71/// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
72/// Example 1:
73/// distributed type: vector<8x1xf32>
74/// expected type: vector<8xf32>
75/// resolved using,
76/// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
77/// Example 2:
78/// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
79/// expected type: xegpu.tensor_desc<8x16xf32>
80/// resolved using,
81/// %0 = unrealized_conversion_cast %1 :
82/// xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
83/// xegpu.tensor_desc<8x16xf32>
84template <typename T>
85static Value resolveDistributedTy(Value orig, T expected,
86 PatternRewriter &rewriter) {
87 // If orig and expected types are the same, return orig.
88 if (orig.getType() == expected)
89 return orig;
90 // If orig is a vector type, create a shape cast op to reconcile the types.
91 if (isa<VectorType>(orig.getType())) {
92 auto castOp =
93 vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
94 return castOp.getResult();
95 }
96 // If orig is a tensor descriptor type, create an unrealized conversion cast
97 // op to reconcile the types.
98 if (isa<xegpu::TensorDescType>(orig.getType())) {
99 auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
100 expected, orig);
101 castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
102 return castOp.getResult(0);
103 }
104 llvm_unreachable("Unsupported type for reconciliation");
105 return orig;
106}
107
108/// Given a vector type and its distributed vector type, return the list of
109/// dimensions that are distributed.
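/// For example, if the original type is vector<16x16xf32> and the distributed
/// type is vector<16x1xf32>, only dimension 1 differs, so the result is [1].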
110static SmallVector<int64_t> getDistributedDims(VectorType originalType,
111 VectorType distributedType) {
112 assert(originalType.getRank() == distributedType.getRank() &&
113 "sequential and distributed vector types must have the same rank");
114 SmallVector<int64_t> distributedDims;
115 for (int64_t i = 0; i < originalType.getRank(); ++i) {
116 if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
117 distributedDims.push_back(i);
118 }
119 }
120 return distributedDims;
121}
122
123/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
124/// of the original GPUFuncOp to the new GPUFuncOp such that the entire body is
125/// contained within a WarpExecuteOnLane0Op.
126/// Example:
127///
128/// ```
129/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
130/// ...
131/// ...
132/// gpu.return %result: vector<8x16xf32>
133/// }
134/// ```
135/// To
136/// ```
137/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
138/// %laneid = gpu.lane_id : index
139/// %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
140/// ...
141/// ...
142/// gpu.yield %result: vector<8x16xf32>
143/// }
144/// gpu.return %0
145/// }
/// ```
146struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
147 using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
148 LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
149 PatternRewriter &rewriter) const override {
150 auto uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or(""));
151 if (!uArch)
152 return rewriter.notifyMatchFailure(
153 gpuFuncOp, "Subgroup distribution requires target attribute attached "
154 "to set the warp size");
155 if (!gpuFuncOp.getBody().hasOneBlock())
156 return rewriter.notifyMatchFailure(
157 gpuFuncOp, "expected gpu.func to have a single block");
158
159 // If the function only contains a single void return, skip.
160 if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
161 return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
162 }))
163 return failure();
164 // If the function body was already moved inside a warp_execute_on_lane0, skip.
165 if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
166 return isa<gpu::WarpExecuteOnLane0Op>(op);
167 }))
168 return failure();
169 gpu::ReturnOp origReturnOp = dyn_cast_if_present<gpu::ReturnOp>(
170 gpuFuncOp.getBlocks().back().getTerminator());
171 if (!origReturnOp)
172 return rewriter.notifyMatchFailure(
173 gpuFuncOp, "expected gpu.func terminator to be gpu.return");
174 // Create a new function with the same signature and same attributes.
175 SmallVector<Type> workgroupAttributionsTypes =
176 llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributionBBArgs(),
177 [](BlockArgument arg) { return arg.getType(); });
178 SmallVector<Type> privateAttributionsTypes =
179 llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
180 [](BlockArgument arg) { return arg.getType(); });
181 auto newGpuFunc = gpu::GPUFuncOp::create(
182 rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
183 gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
184 privateAttributionsTypes);
185 newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
186 // Create a WarpExecuteOnLane0Op with same arguments and results as the
187 // original gpuFuncOp.
188 rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
189 auto laneId = gpu::LaneIdOp::create(
190 rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(),
191 /** upperBound = **/ mlir::IntegerAttr());
192 ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
193 auto warpOp = gpu::WarpExecuteOnLane0Op::create(
194 rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
195 uArch->getSubgroupSize(), newGpuFunc.getArguments(),
196 newGpuFunc.getArgumentTypes());
197 Block &warpBodyBlock = warpOp.getBodyRegion().front();
198 // Replace the ReturnOp of the original gpu function with a YieldOp.
199 rewriter.setInsertionPointAfter(origReturnOp);
200 gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
201 origReturnOp.getOperands());
202 rewriter.eraseOp(origReturnOp);
203 // Move the original function body to the WarpExecuteOnLane0Op body.
204 rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
205 warpOp.getBodyRegion().begin());
206 rewriter.eraseBlock(&warpBodyBlock);
207 // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
208 rewriter.setInsertionPointAfter(warpOp);
209 gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
210 rewriter.replaceOp(gpuFuncOp, newGpuFunc);
211 return success();
212 }
213};
214
215/// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing
216/// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
217/// still contain the original op that will not be used by the yield op (and
218/// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
219/// arguments. Tensor descriptor shape is not distributed because it is a
220/// uniform value across all work items within the subgroup. However, the
221/// layout information is dropped in the new tensor descriptor type.
222///
223/// Example:
224///
225/// ```
226/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
227/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
228/// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
229/// ...
230/// %td = xegpu.create_nd_tdesc %arg0
231/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
232/// vector.yield %td
233/// }
234/// ```
235/// To
236/// ```
237/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
238/// ...
239/// %dead = xegpu.create_nd_tdesc %arg0
240/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
241/// vector.yield %arg0, %dead
242/// }
243/// %td = xegpu.create_nd_tdesc %r#0: memref<4x8xf32>
244/// -> !xegpu.tensor_desc<4x8xf32>
245///
246/// ```
247struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
248 using gpu::WarpDistributionPattern::WarpDistributionPattern;
249 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
250 PatternRewriter &rewriter) const override {
251 OpOperand *operand =
252 getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
253 if (!operand)
254 return rewriter.notifyMatchFailure(
255 warpOp, "warp result is not a xegpu::CreateNdDesc op");
256 auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
257 unsigned operandIdx = operand->getOperandNumber();
258
259 xegpu::DistributeLayoutAttr layout = descOp.getType().getLayoutAttr();
260 if (!layout)
261 return rewriter.notifyMatchFailure(
262 descOp, "the tensor descriptor lacks layout attribute");
263 SmallVector<size_t> newRetIndices;
264 rewriter.setInsertionPoint(warpOp);
265 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
266 rewriter, warpOp, /* new yielded values = */ descOp->getOperands(),
267 /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);
268
269 SmallVector<Value> newDescOperands = llvm::map_to_vector(
270 newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
271 rewriter.setInsertionPointAfter(newWarpOp);
272 xegpu::TensorDescType distributedTensorDescTy =
273 descOp.getType().dropLayouts(); // Distributed tensor descriptor type
274 // does not contain layout info.
275 Value newDescOp = xegpu::CreateNdDescOp::create(
276 rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
277 descOp->getAttrs());
278
279 Value distributedVal = newWarpOp.getResult(operandIdx);
280 // Resolve the distributed type to the expected type.
281 newDescOp =
282 resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
283 rewriter.replaceAllUsesWith(distributedVal, newDescOp);
284 return success();
285 }
286};
287
288/// Distribute a store_nd op at the end of enclosing
289/// `gpu.warp_execute_on_lane_0`. Arguments that the store needs from inside the
290/// warp op region are forwarded as additional yielded (returned) values.
291/// The source vector is distributed based on the lane layout. Appropriate cast
292/// ops are inserted if the distributed types do not match the expected xegpu SIMT types.
293///
294/// Example:
295///
296/// ```
297/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
298/// gpu.warp_execute_on_lane_0(%laneid) -> () {
299/// ...
300/// xegpu.store_nd %arg0, %arg1 [%x, %y]: vector<4x8xf32>,
301/// !xegpu.tensor_desc<4x8xf32, #layout0>
302/// }
303/// ```
304/// To
305/// ```
306/// %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
307/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
308/// ...
309/// gpu.yield %arg0, %arg1, %x, %y: vector<4x8xf32>,
310/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index
311/// }
312/// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
313/// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
314/// #layout0>
315/// -> !xegpu.tensor_desc<4x8xf32>
316/// xegpu.store_nd %0, %1 [%r#2, %r#3]: vector<4xf32>,
317/// !xegpu.tensor_desc<4x8xf32>
318///
319/// ```
320struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
321 using gpu::WarpDistributionPattern::WarpDistributionPattern;
322 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
323 PatternRewriter &rewriter) const override {
324 gpu::YieldOp yield = warpOp.getTerminator();
325 Operation *lastNode = yield->getPrevNode();
326 auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
327 if (!storeOp)
328 return failure();
329
330 SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
331 // Expecting offsets to be present.
332 if (offsets.empty())
333 return rewriter.notifyMatchFailure(storeOp,
334 "the store op must have offsets");
335 SmallVector<Value> offsetsAsValues =
336 vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
337 SmallVector<Type> offsetTypes = llvm::map_to_vector(
338 offsetsAsValues, [](Value v) { return v.getType(); });
339 xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
340 xegpu::DistributeLayoutAttr layout = tensorDescTy.getLayoutAttr();
341 if (!layout)
342 return rewriter.notifyMatchFailure(
343 storeOp, "the source tensor descriptor lacks layout attribute");
344
345 FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
346 xegpu::getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
347 if (failed(distributedTypeByWarpOpOrFailure))
348 return rewriter.notifyMatchFailure(storeOp,
349 "Failed to distribute the type");
350 VectorType distributedTypeByWarpOp =
351 distributedTypeByWarpOpOrFailure.value();
352
353 SmallVector<size_t> newRetIndices;
354 SmallVector<Value> newYieldedValues = {storeOp.getValue(),
355 storeOp.getTensorDesc()};
356 SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
357 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
358 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
359 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
360 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
361 // Create a new store op outside the warp op with the distributed vector
362 // type. Tensor descriptor is not distributed.
363 rewriter.setInsertionPointAfter(newWarpOp);
364 SmallVector<Value> newStoreOperands;
365
366 // For the value operand, there can be a mismatch between the vector type
367 // distributed by the warp op and (xegpu-specific) distributed type
368 // supported by the store op. Type mismatch must be resolved using
369 // appropriate cast op.
370 FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
371 xegpu::getDistributedVectorType(storeOp.getTensorDescType());
372 if (failed(storeNdDistributedValueTyOrFailure))
373 return rewriter.notifyMatchFailure(
374 storeOp, "Failed to get distributed vector type for the store op");
375 newStoreOperands.push_back(resolveDistributedTy(
376 newWarpOp.getResult(newRetIndices[0]),
377 storeNdDistributedValueTyOrFailure.value(), rewriter));
378 // For the tensor descriptor operand, the layout attribute is dropped after
379 // distribution. Types need to be resolved in this case as well.
380 xegpu::TensorDescType distributedTensorDescTy =
381 storeOp.getTensorDescType().dropLayouts();
382 newStoreOperands.push_back(
383 resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
384 distributedTensorDescTy, rewriter));
385 // Collect offsets.
386 for (size_t i = 2; i < newRetIndices.size(); ++i)
387 newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
388
389 auto newStoreOp =
390 xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
391 newStoreOperands, storeOp->getAttrs());
392 xegpu::removeLayoutAttrs(newStoreOp);
393 rewriter.eraseOp(storeOp);
394 return success();
395 }
396};
397
398/// Distribute a load_nd op feeding into vector.yield op for the enclosing
399/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
400/// The warp op will still contain the original op that will not be used by
401/// the yield op (and should be cleaned up later). The yield op will
402/// bypass the load's arguments. Only the loaded vector is distributed
403/// according to the lane layout; the tensor descriptor type is not
404/// distributed. Appropriate cast ops are inserted if the distributed types do
405/// not match the expected xegpu SIMT types.
406///
407/// Example:
408///
409/// ```
410/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
411/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
412/// (vector<4x1xf32>) {
413/// ...
414/// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
415/// ->
416/// vector<4x8xf32>
417/// gpu.yield %ld
418/// }
419/// ```
420/// To
421/// ```
422/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
423/// !xegpu.tensor_desc<4x8xf32, #layout0>) {
424/// ...
425/// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
426/// vector<4x8xf32>
/// gpu.yield %dead, %arg0
427/// }
428/// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
429/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
430/// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
431/// %2 = vector.shape_cast %1: vector<4xf32> to vector<4x1xf32>
432///
433/// ```
434struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
435 using gpu::WarpDistributionPattern::WarpDistributionPattern;
436 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
437 PatternRewriter &rewriter) const override {
438 OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
439 if (!isa<xegpu::LoadNdOp>(op))
440 return false;
441 // Make sure the same load op is the last operation in the warp op body.
442 // This ensures that the load op is not sunk earlier, violating any barrier
443 // synchronizations.
444 gpu::YieldOp yield = warpOp.getTerminator();
445 return yield->getPrevNode() == op;
446 });
447
448 if (!operand)
449 return rewriter.notifyMatchFailure(
450 warpOp, "warp result is not a xegpu::LoadNd op");
451
452 auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
453 auto uArch = getUArch(xegpu::getChipStr(loadOp).value_or(""));
454 if (!uArch)
455 return rewriter.notifyMatchFailure(
456 loadOp, "xegpu::LoadNdOp require target attribute attached to "
457 "determine transpose "
458 "requirement");
459 // Chip information is required to decide if the layout requires transpose
460 // effect.
461 // Expecting offsets to be present.
462 SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
463 if (offsets.empty())
464 return rewriter.notifyMatchFailure(loadOp,
465 "the load op must have offsets");
466 SmallVector<Value> offsetsAsValues =
467 vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
468 SmallVector<Type> offsetTypes = llvm::map_to_vector(
469 offsetsAsValues, [](Value v) { return v.getType(); });
470
471 xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
472 xegpu::DistributeLayoutAttr layout = tensorDescTy.getLayoutAttr();
473 if (!layout)
474 return rewriter.notifyMatchFailure(
475 loadOp, "the source tensor descriptor lacks layout attribute");
476
477 unsigned operandIdx = operand->getOperandNumber();
478 VectorType distributedTypeByWarpOp =
479 cast<VectorType>(warpOp.getResult(operandIdx).getType());
480
481 SmallVector<size_t> newRetIndices;
482 SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
483 SmallVector<Type> newYieldedTypes = {tensorDescTy};
484 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
485 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
486 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
487 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
488
489 // Create a new load op outside the warp op with the distributed vector
490 // type.
491 rewriter.setInsertionPointAfter(newWarpOp);
492 FailureOr<VectorType> loadNdDistValueTyOrFailure =
493 xegpu::getDistributedVectorType(loadOp.getTensorDescType());
494 if (failed(loadNdDistValueTyOrFailure))
495 return rewriter.notifyMatchFailure(
496 loadOp, "Failed to get distributed vector type for the load op");
497 xegpu::TensorDescType distributedTensorDescTy =
498 loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
499 // descriptor type does not
500 // contain layout info.
501 SmallVector<Value> newLoadOperands{
502 resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
503 distributedTensorDescTy, rewriter)};
504 // Collect offsets.
505 for (size_t i = 1; i < newRetIndices.size(); ++i)
506 newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
507 auto newLoadOp = xegpu::LoadNdOp::create(
508 rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
509 newLoadOperands, loadOp->getAttrs());
510 xegpu::removeLayoutAttrs(newLoadOp);
511 // Set the packed attribute if the layout requires it.
512 newLoadOp.setPacked(xegpu::requirePacked(layout));
513 // Set the transpose attribute if the layout requires it.
514 if (xegpu::requireTranspose(layout, uArch))
515 newLoadOp.setTranspose(
516 DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
517 Value distributedVal = newWarpOp.getResult(operandIdx);
518 // There can be a conflict between the vector type distributed by the
519 // warp op and (xegpu-specific) distributed type supported by the load
520 // op. Resolve these mismatches by inserting a cast.
521 Value tyResolvedVal = resolveDistributedTy(
522 newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
523 rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
524 return success();
525 }
526};
527
528/// Distribute a dpas op feeding into vector.yield op for the enclosing
529/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
530/// The warp op will still contain the original op that will not be used by
531/// the yield op (and should be cleaned up later). The yield op will
532/// bypass the dpas's arguments. Appropriate cast ops are inserted if the
533/// distributed types do not match the expected xegpu SIMT types.
534/// Example:
535/// ```
536/// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
537/// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
538/// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
539/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
540/// (vector<8x1xf32>) {
541/// ...
542/// %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
543/// vector<8x16xf32>
544/// gpu.yield %dpas
545/// }
546/// ```
547/// To
548/// ```
549/// %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
550/// vector<8x1xf16>, vector<16x1xf16>) {
551/// ...
552/// %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
553/// -> vector<8x16xf32>
554/// gpu.yield %dead, %arg0, %arg1
555/// }
556/// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
557/// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
558/// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
559/// vector<8xf32>
560/// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
561/// ```
562struct DpasDistribution final : public gpu::WarpDistributionPattern {
563 using gpu::WarpDistributionPattern::WarpDistributionPattern;
564 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
565 PatternRewriter &rewriter) const override {
566 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
567 if (!operand)
568 return rewriter.notifyMatchFailure(warpOp,
569 "warp result is not a xegpu::Dpas op");
570
571 auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
572 unsigned operandIdx = operand->getOperandNumber();
573
574 xegpu::LayoutAttr layoutA =
575 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutAAttr());
576 xegpu::LayoutAttr layoutB =
577 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutBAttr());
578 xegpu::LayoutAttr layoutOut =
579 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutCdAttr());
580
581 if (!layoutA || !layoutB || !layoutOut)
582 return rewriter.notifyMatchFailure(
583 dpasOp,
584 "the xegpu::Dpas op lacks layout attribute for A, B or output");
585
586 FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
587 getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
588 FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
589 getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
590 FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
591 getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
592
593 if (failed(distLhsTypeByWarpOpOrFailure) ||
594 failed(distRhsTypeByWarpOpOrFailure) ||
595 failed(distResultTypeByWarpOpOrFailure))
596 return rewriter.notifyMatchFailure(
597 dpasOp,
598 "Failed to distribute the A, B or output types in xegpu::Dpas op");
599
600 llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
601 dpasOp.getRhs()};
602 llvm::SmallVector<Type, 3> newYieldTypes{
603 distLhsTypeByWarpOpOrFailure.value(),
604 distRhsTypeByWarpOpOrFailure.value()};
605 // Dpas acc operand is optional.
606 if (dpasOp.getAcc()) {
607 newYieldValues.push_back(dpasOp.getAcc());
608 newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
609 }
610 // Create a new warp op without the dpas.
611 SmallVector<size_t> newRetIndices;
612 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
613 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
614
615 FailureOr<VectorType> expectedDistLhsTyOrFailure =
616 xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
617 FailureOr<VectorType> expectedDistRhsTyOrFailure =
618 xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
619 FailureOr<VectorType> expectedDistResultTyOrFailure =
620 xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
621
622 if (failed(expectedDistLhsTyOrFailure) ||
623 failed(expectedDistRhsTyOrFailure) ||
624 failed(expectedDistResultTyOrFailure))
625 return rewriter.notifyMatchFailure(
626 dpasOp,
627 "Failed to get distributed vector type for the dpas operands.");
628 // Create a new dpas op outside the warp op.
629 rewriter.setInsertionPointAfter(newWarpOp);
630 SmallVector<Value> newDpasOperands;
631 SmallVector<VectorType> newDpasOperandExpectedTypes;
632
633 // Resolve the distributed types with the original types.
634 newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
635 newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
636 VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
637 if (dpasOp.getAcc())
638 newDpasOperandExpectedTypes.push_back(distributedResultTy);
639
640 for (unsigned i = 0; i < newRetIndices.size(); i++) {
641 newDpasOperands.push_back(
642 resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
643 newDpasOperandExpectedTypes[i], rewriter));
644 }
645 auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
646 distributedResultTy, newDpasOperands,
647 dpasOp->getAttrs());
648 xegpu::removeLayoutAttrs(newDpasOp);
649 Value distributedVal = newWarpOp.getResult(operandIdx);
650 // Resolve the output type.
651 Value typeResolved =
652 resolveDistributedTy(newDpasOp.getResult(),
653 distResultTypeByWarpOpOrFailure.value(), rewriter);
654 rewriter.replaceAllUsesWith(distributedVal, typeResolved);
655 return success();
656 }
657};
658
659/// Distribute a prefetch_nd op at the end of enclosing
660/// `gpu.warp_execute_on_lane_0`. Arguments that the prefetch needs from inside
661/// the warp op region are forwarded as additional yielded (returned) values.
662/// Tensor descriptor shape is not distributed because it is a uniform value
663/// across all work items within the subgroup. Appropriate cast ops are inserted
664/// if the distributed types do not match the expected xegpu SIMT types.
665///
666/// Example:
667///
668/// ```
669/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
670/// gpu.warp_execute_on_lane_0(%laneid) -> () {
671/// ...
672/// xegpu.prefetch_nd %arg0 [%x, %y] : !xegpu.tensor_desc<4x8xf32, #layout0>
673/// }
674/// ```
675/// To
676/// ```
677/// %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (
678/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
679/// gpu.yield %arg0, %x, %y: !xegpu.tensor_desc<4x8xf32, #layout0>, index,
680/// index
681/// }
682/// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
683/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
684/// xegpu.prefetch_nd %1 [%r#1, %r#2] : !xegpu.tensor_desc<4x8xf32>
685///
686/// ```
687struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
688 using gpu::WarpDistributionPattern::WarpDistributionPattern;
689 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
690 PatternRewriter &rewriter) const override {
691 gpu::YieldOp yield = warpOp.getTerminator();
692 Operation *lastNode = yield->getPrevNode();
693 auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
694 if (!prefetchOp)
695 return failure();
696
697 SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
698 // PrefetchNdOp must have offsets.
699 if (offsets.empty())
700 return rewriter.notifyMatchFailure(prefetchOp,
701 "the prefetch op must have offsets");
702 SmallVector<Value> offsetsAsValues =
703 vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
704 SmallVector<Type> offsetTypes = llvm::map_to_vector(
705 offsetsAsValues, [](Value v) { return v.getType(); });
706
707 xegpu::DistributeLayoutAttr layout =
708 prefetchOp.getTensorDescType().getLayoutAttr();
709 if (!layout)
710 return rewriter.notifyMatchFailure(
711 prefetchOp, "the source tensor descriptor lacks layout attribute");
712
713 SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
714 SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
715 newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
716 newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
717 SmallVector<size_t> newRetIndices;
718 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
719 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
720 // Create a new prefetch op outside the warp op with updated tensor
721 // descriptor type. The source tensor descriptor requires type resolution.
722 xegpu::TensorDescType newTensorDescTy =
723 prefetchOp.getTensorDescType().dropLayouts();
724 rewriter.setInsertionPointAfter(newWarpOp);
725 SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
726 newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
727 // Collect offsets.
728 for (size_t i = 1; i < newRetIndices.size(); ++i)
729 newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
730 Operation *newPrefetchOp = xegpu::PrefetchNdOp::create(
731 rewriter, newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
732 prefetchOp->getAttrs());
733 xegpu::removeLayoutAttrs(newPrefetchOp);
734 rewriter.eraseOp(prefetchOp);
735 return success();
736 }
737};
738
739/// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
740/// region. This will simply move the barrier op outside of the warp op.
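/// Example:
/// ```
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
/// ...
/// gpu.barrier
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
/// ...
/// }
/// gpu.barrier
/// ```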
741struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
742 using gpu::WarpDistributionPattern::WarpDistributionPattern;
743 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
744 PatternRewriter &rewriter) const override {
745 gpu::YieldOp yield = warpOp.getTerminator();
746 Operation *lastNode = yield->getPrevNode();
747 // The last node must be a gpu::BarrierOp.
748 auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
749 if (!barrierOp)
750 return failure();
751 // Move the barrier op outside of the warp op.
752 rewriter.setInsertionPointAfter(warpOp);
753 gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
754 barrierOp->getResultTypes(),
755 barrierOp->getOperands(), barrierOp->getAttrs());
756 rewriter.eraseOp(barrierOp);
757 return success();
758 }
759};
760
761/// Distribute a scattered store op. The offsets argument is required.
762/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
763/// The layouts are fixed and implicit: one offset/mask per lane.
764/// The pass changes the offset/mask vector shapes to a
765/// single-element vector; **it is assumed that their producers will also be
766/// distributed**. The payload vector also has a fixed distribution:
767/// no chunk size -> vector of one element.
768/// chunk size -> vector of the innermost dimension of the SG-payload.
769/// Example 1 (no chunk size):
770/// %mask = producer_op : vector<16xi1>
771/// %offset = producer_op : vector<16xindex>
772/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
773/// memref<256xf16>, vector<16xindex>, vector<16xi1>
774/// To
775/// %mask = producer_op : vector<1xi1>
776/// %offset = producer_op : vector<1xindex>
777/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
778/// memref<256xf16>, vector<1xindex>, vector<1xi1>
779/// Example 2 (chunk size, same mask and offsets):
780/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
781/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
782/// To
783/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
784/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
785///
786/// Note that the store distribution pattern also handles leading unit
787/// dimensions in the payload, mask and offsets vectors. In this case the store
788/// distribution will only change the dimensions corresponding to the SG
789/// distribution and keep the leading unit dimensions unchanged.
790/// For example, a store with payload vector<1x16xf16> and lane layout [1, 16]
791/// will be distributed as vector<1x1xf16>. Shape cast ops are inserted for the
792/// offset/mask/payload when necessary so that the distributed store works
793/// on a 1D vector to match the HW capability.
794struct StoreDistribution final : public gpu::WarpDistributionPattern {
795 using gpu::WarpDistributionPattern::WarpDistributionPattern;
796 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
797 PatternRewriter &rewriter) const override {
798 Operation *lastNode = warpOp.getTerminator()->getPrevNode();
799 auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
800 if (!storeScatterOp)
801 return failure();
802 Value offsets = storeScatterOp.getOffsets();
803 if (!isa<VectorType>(offsets.getType()))
804 return rewriter.notifyMatchFailure(
805 storeScatterOp, "Store op must have a vector of offsets argument");
806 VectorType offsetsTy = cast<VectorType>(offsets.getType());
807 VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
808 VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
809
810 // Handle leading unit dimensions in the payload vector.
811 int chunkSize = storeScatterOp.getChunkSize().value_or(1);
812 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
813
814 // Check that all leading dimensions are unit dimensions
815 for (int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
816 if (storeVecTy.getShape()[i] != 1) {
817 return rewriter.notifyMatchFailure(
818 storeScatterOp, "Only unit dimensions allowed for the leading "
819 "dimensions of the store vector!");
820 }
821 }
822
823 auto layoutPayload = storeScatterOp.getLayoutAttr();
824 auto layoutOffsets =
825 xegpu::inferMaskOffsetLayoutForScatterIO(layoutPayload, chunkSize);
826 auto layoutMask = layoutOffsets;
827
828 FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
829 getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
830 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
831 getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
832 FailureOr<VectorType> distMaskByWarpOpOrFailure =
833 getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
834 if (failed(distStoreVecByWarpOpOrFailure) ||
835 failed(distOffsetsByWarpOpOrFailure) ||
836 failed(distMaskByWarpOpOrFailure)) {
837 return rewriter.notifyMatchFailure(
838 storeScatterOp,
839 "Some vector operands have no layouts, using defaults instead.");
840 }
841
842 VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
843 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
844 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
845
846 SmallVector<size_t> newRetIndices;
847 SmallVector<Value> operands = storeScatterOp->getOperands();
848 SmallVector<Type> operandTypesToYield = {
849 distPayloadTy, operands[1].getType(), distOffsetsTy, distMaskTy};
850
851 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
852 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
853
854 rewriter.setInsertionPointAfter(newWarpOp);
855
856 // Distributed store payload type is always 1D without leading unit dims
857 VectorType payloadTy1D = VectorType::get({distPayloadTy.getNumElements()},
858 distPayloadTy.getElementType());
859
860 VectorType distOffsetsTy1D = VectorType::get(
861 {distOffsetsTy.getNumElements()}, distOffsetsTy.getElementType());
862 VectorType distMaskTy1D = VectorType::get({distMaskTy.getNumElements()},
863 distMaskTy.getElementType());
864
865 // Resolve distributed types to 1D for SIMT execution
866 Value distPayloadVal = resolveDistributedTy(
867 newWarpOp.getResult(newRetIndices[0]), payloadTy1D, rewriter);
868 Value distOffsetVal = resolveDistributedTy(
869 newWarpOp.getResult(newRetIndices[2]), distOffsetsTy1D, rewriter);
870 Value distMaskVal = resolveDistributedTy(
871 newWarpOp.getResult(newRetIndices[3]), distMaskTy1D, rewriter);
872
873 SmallVector<Value> newStoreScatterOpOperands = {
874 distPayloadVal, newWarpOp.getResult(newRetIndices[1]), distOffsetVal,
875 distMaskVal};
876
877 xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
878 rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
879 storeScatterOp->getAttrs());
880 xegpu::removeLayoutAttrs(newOp);
881 rewriter.eraseOp(storeScatterOp);
882 return success();
883 }
884};
885
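/// Compute the per-lane coordinates used by the load_matrix/store_matrix
/// distribution patterns below: the layout's distributed coordinates for the
/// given lane are added to the subgroup-level offsets (right-aligned). Returns
/// an empty vector if the layout cannot provide distributed coordinates.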
886static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
887 PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
888 Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
889 SmallVector<Value> newCoords;
890 auto maybeCoords =
891 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
892 if (failed(maybeCoords))
893 return {};
894 assert(maybeCoords.value().size() == 1 &&
895 "Expected one set of distributed offsets");
896 SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
897 rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
898 getAsOpFoldResult(origOffsets));
899 newCoords = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
900 return newCoords;
901}
902
903/// Pattern for distributing xegpu::LoadMatrixOp.
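/// The payload (result) vector is distributed according to the op's lane
/// layout. Unless the subgroup_block_io attribute is present, the
/// subgroup-level offsets are shifted to per-lane coordinates using the helper
/// above; with subgroup_block_io, all lanes keep the original offsets. The
/// layout attribute is dropped from the distributed op.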
904struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
905 using gpu::WarpDistributionPattern::WarpDistributionPattern;
906 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
907 PatternRewriter &rewriter) const override {
908 gpu::YieldOp yield = warpOp.getTerminator();
909 Operation *lastNode = yield->getPrevNode();
910 auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
911 if (!matrixOp)
912 return failure();
913
914 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
915 return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
916 });
917 if (!producedByLastLoad)
918 return rewriter.notifyMatchFailure(
919 warpOp, "The last op is not xegpu::LoadMatrixOp");
920 const int operandIdx = producedByLastLoad->getOperandNumber();
921
922 VectorType sgPayloadTy =
923 dyn_cast<VectorType>(matrixOp.getResult().getType());
924 VectorType warpResultTy =
925 cast<VectorType>(warpOp.getResult(operandIdx).getType());
926 if (!sgPayloadTy)
927 return rewriter.notifyMatchFailure(
928 matrixOp, "the matrix op payload must be a vector type");
929
930 auto loc = matrixOp.getLoc();
931 auto offsets = matrixOp.getMixedOffsets();
932 if (offsets.empty())
933 return rewriter.notifyMatchFailure(matrixOp,
934 "the load op must have offsets");
935 SmallVector<Value> offsetsAsValues =
936 vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
937
938 auto layout = matrixOp.getLayoutAttr();
939 if (!layout)
940 return rewriter.notifyMatchFailure(
941 matrixOp, "the matrix operation lacks layout attribute");
942
943 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
944 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
945 if (failed(distPayloadByWarpOpOrFailure))
946 return rewriter.notifyMatchFailure(
947 matrixOp, "Failed to distribute matrix op payload based on layout.");
948
949 SmallVector<Value> operands = {matrixOp.getMemDesc()};
950 const unsigned offsetsStartIdx = operands.size();
951 operands.append(offsetsAsValues);
952
953 SmallVector<Type> operandTypes =
954 llvm::map_to_vector(operands, [](Value v) { return v.getType(); });
955
956 SmallVector<size_t> newRetIndices;
957 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
958 rewriter, warpOp, operands, operandTypes, newRetIndices);
959 SmallVector<Value> newOperands = llvm::map_to_vector(
960 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
961
962 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
963 ShapedType::kDynamic);
964 DenseI64ArrayAttr newConstOffsetsAttr =
965 rewriter.getDenseI64ArrayAttr(newConstOffsets);
966 ValueRange currentOffsets =
967 ValueRange(newOperands).drop_front(offsetsStartIdx);
968
969 SmallVector<Value> newCoords = currentOffsets;
970 rewriter.setInsertionPointAfter(newWarpOp);
971
972 if (!matrixOp.getSubgroupBlockIoAttr()) {
973 newCoords = computeDistributedCoordinatesForMatrixOp(
974 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
975 currentOffsets);
976 }
977 xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
978 rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
979 newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
980 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
981 // Resolve the output type and replace all uses.
982 rewriter.replaceAllUsesWith(
983 newWarpOp.getResult(operandIdx),
984 resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
985 return success();
986 }
987};
988
989/// Pattern for distributing xegpu::StoreMatrixOp.
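/// The distribution mirrors LoadMatrixDistribution: the stored payload is
/// distributed by the lane layout, per-lane coordinates are computed unless
/// the subgroup_block_io attribute is present, and the layout attribute is
/// dropped from the distributed op.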
990struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
991 using gpu::WarpDistributionPattern::WarpDistributionPattern;
992 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
993 PatternRewriter &rewriter) const override {
994 gpu::YieldOp yield = warpOp.getTerminator();
995 Operation *lastNode = yield->getPrevNode();
996 auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
997 if (!matrixOp)
998 return failure();
999
1000 VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
1001 if (!sgPayloadTy)
1002 return rewriter.notifyMatchFailure(
1003 matrixOp, "the matrix op payload must be a vector type");
1004
1005 auto loc = matrixOp.getLoc();
1006 auto offsets = matrixOp.getMixedOffsets();
1007 if (offsets.empty())
1008 return rewriter.notifyMatchFailure(matrixOp,
1009 "the store op must have offsets");
1010 SmallVector<Value> offsetsAsValues =
1011 vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
1012
1013 auto layout = matrixOp.getLayoutAttr();
1014 if (!layout)
1015 return rewriter.notifyMatchFailure(
1016 matrixOp, "the matrix operation lacks layout attribute");
1017
1018 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
1019 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
1020 if (failed(distPayloadByWarpOpOrFailure))
1021 return rewriter.notifyMatchFailure(
1022 matrixOp, "Failed to distribute matrix op payload based on layout.");
1023
1024 SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
1025 const unsigned offsetsStartIdx = operands.size();
1026 operands.append(offsetsAsValues);
1027
1028 SmallVector<Type> operandTypes =
1029 llvm::map_to_vector(operands, [](Value v) { return v.getType(); });
1030 operandTypes[0] = *distPayloadByWarpOpOrFailure;
1031
1032 SmallVector<size_t> newRetIndices;
1033 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1034 rewriter, warpOp, operands, operandTypes, newRetIndices);
1035 SmallVector<Value> newOperands = llvm::map_to_vector(
1036 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
1037
1038 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
1039 ShapedType::kDynamic);
1040 DenseI64ArrayAttr newConstOffsetsAttr =
1041 rewriter.getDenseI64ArrayAttr(newConstOffsets);
1042 ValueRange currentOffsets =
1043 ValueRange(newOperands).drop_front(offsetsStartIdx);
1044
1045 SmallVector<Value> newCoords = currentOffsets;
1046 rewriter.setInsertionPointAfter(newWarpOp);
1047
1048 if (!matrixOp.getSubgroupBlockIoAttr()) {
1049 newCoords = computeDistributedCoordinatesForMatrixOp(
1050 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
1051 currentOffsets);
1052 }
1053
1054 xegpu::StoreMatrixOp::create(
1055 rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
1056 ValueRange(newCoords), newConstOffsetsAttr,
1057 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
1058 rewriter.eraseOp(matrixOp);
1059 return success();
1060 }
1061};
1062
1063/// Distribute a scattered load op. The logic and requirements are the same as
1064/// for the scattered store distribution. The warpOp's payload vector is
1065/// expected to be distributed by the load's result consumer.
1066/// Example 1 (no chunk size):
1067/// %mask = producer_op : vector<16xi1>
1068/// %offset = producer_op : vector<16xindex>
1069/// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
1070/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
1071/// To
1072/// %mask = producer_op : vector<1xi1>
1073/// %offset = producer_op : vector<1xindex>
1074/// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
1075/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
1076/// Example 2 (chunk size, same mask and offsets):
1077/// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
1078/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
1079/// To
1080/// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
1081/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
1082///
1083/// Note that the load distribution pattern also handles leading unit dimensions
1084/// in the payload, mask, and offsets vectors. The load distribution will only
1085/// change the dimensions corresponding to the SG distribution and keep the
1086/// leading unit dimensions unchanged. For example, a load with result type
1087/// vector<1x16xf16> and lane layout [1, 16] will be distributed
1088/// as result type vector<1x1xf16>. Shape cast ops are inserted for the
1089/// offset/mask/payload when necessary so that the distributed load works
1090/// on a 1D vector to match the HW capability.
1091struct LoadDistribution final : public gpu::WarpDistributionPattern {
1092 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1093 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1094 PatternRewriter &rewriter) const override {
1095 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
1096 // Check that the yield operand was produced by the *last* scattered
1097 // load op, to avoid sinking it before barriers (maintain memory order).
1098 return isa<xegpu::LoadGatherOp>(op) &&
1099 warpOp.getTerminator()->getPrevNode() == op;
1100 });
1101 if (!producedByLastLoad)
1102 return rewriter.notifyMatchFailure(
1103 warpOp, "The last op is not xegpu::LoadGatherOp");
1104
1105 auto loadGatherOp =
1106 producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
1107 Value offsets = loadGatherOp.getOffsets();
1108 if (!isa<VectorType>(offsets.getType()) ||
1109 !isa<VectorType>(loadGatherOp.getMask().getType()))
1110 return rewriter.notifyMatchFailure(
1111 loadGatherOp,
1112 "Load op must have vector arguments for offsets and mask");
1113 VectorType offsetsTy = cast<VectorType>(offsets.getType());
1114 VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
1115 VectorType resultVecTy =
1116 cast<VectorType>(loadGatherOp.getResult().getType());
1117 // Handle leading unit dimensions in the result vector.
1118 int chunkSize = loadGatherOp.getChunkSize().value_or(1);
1119 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
1120 for (int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
1121 if (resultVecTy.getShape()[i] != 1) {
1122 return rewriter.notifyMatchFailure(
1123 loadGatherOp, "Only unit dimensions allowed for the leading "
1124 "dimensions of the load vector!");
1125 }
1126 }
1127
1128 auto layoutPayload = loadGatherOp.getLayoutAttr();
1129 auto layoutOffsets =
1130 xegpu::inferMaskOffsetLayoutForScatterIO(layoutPayload, chunkSize);
1131 auto layoutMask = layoutOffsets;
1132
1133 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
1134 getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
1135 FailureOr<VectorType> distMaskByWarpOpOrFailure =
1136 getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
1137 if (failed(distOffsetsByWarpOpOrFailure) ||
1138 failed(distMaskByWarpOpOrFailure)) {
1139 return rewriter.notifyMatchFailure(
1140 loadGatherOp,
1141 "Some vector operands have no layouts, using defaults instead.");
1142 }
1143
1144 SmallVector<size_t> newRetIndices;
1145 SmallVector<Value> operands = loadGatherOp->getOperands();
1146
1147 const unsigned operandIdx = producedByLastLoad->getOperandNumber();
1148 VectorType distResultTy =
1149 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1150 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
1151 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
1152
1153 SmallVector<Type> operandTypesToYield = {operands[0].getType(),
1154 distOffsetsTy, distMaskTy};
1155
1156 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1157 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
1158
1159 rewriter.setInsertionPointAfter(newWarpOp);
1160
1161 // Distributed load op will always be 1D.
1162 VectorType loadVecTy1D = VectorType::get({distResultTy.getNumElements()},
1163 distResultTy.getElementType());
1164
1165 VectorType distOffsetsTy1D =
1166 VectorType::get({distOffsetsByWarpOpOrFailure.value().getNumElements()},
1167 distOffsetsByWarpOpOrFailure.value().getElementType());
1168 VectorType distMaskTy1D =
1169 VectorType::get({distMaskByWarpOpOrFailure.value().getNumElements()},
1170 distMaskByWarpOpOrFailure.value().getElementType());
1171
1172 Value distOffsetVal = resolveDistributedTy(
1173 newWarpOp.getResult(newRetIndices[1]), distOffsetsTy1D, rewriter);
1174 Value distmaskVal = resolveDistributedTy(
1175 newWarpOp.getResult(newRetIndices[2]), distMaskTy1D, rewriter);
1176
1177 SmallVector<Value> newLoadGatherOperands = {
1178 newWarpOp.getResult(newRetIndices[0]), distOffsetVal, distmaskVal};
1179
1180 xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
1181 rewriter, newWarpOp.getLoc(), loadVecTy1D, newLoadGatherOperands,
1182 loadGatherOp->getAttrs());
1183 xegpu::removeLayoutAttrs(newOp);
1184 Value distributedVal = newWarpOp.getResult(operandIdx);
1185 // Resolve the output type and replace all uses.
1186 rewriter.replaceAllUsesWith(
1187 distributedVal,
1188 resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
1189 return success();
1190 }
1191};
1192
1193// Sink SG-uniform ops. An op is uniform if none
1194// of its operands/results has a distribution layout attribute.
1195// Non-uniform vectors are handled by dedicated patterns.
1196// This pattern must have a higher priority than vector dialect distribution
1197// patterns, because a distributable shape may be logically intended as
1198// uniform (i.e., no layout), so we want to omit its distribution.
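// Illustrative example: a layout-free index computation is cloned outside the
// warp op and fed by the (uniform) yielded operands; the original op remains
// inside as dead code to be cleaned up later.
// ```
// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
// %0 = arith.addi %a, %b : index
// gpu.yield %0 : index
// }
// ```
// To
// ```
// %r:3 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index, index, index) {
// %0 = arith.addi %a, %b : index
// gpu.yield %0, %a, %b : index, index, index
// }
// %1 = arith.addi %r#1, %r#2 : index // replaces uses of %r#0
// ```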
1199struct SinkUniformOps final : public gpu::WarpDistributionPattern {
1200 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1201 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1202 PatternRewriter &rewriter) const override {
1203 // Take the last op
1204 Operation *warpRegionPreYieldOp = warpOp.getTerminator()->getPrevNode();
1205 // Any ops with nested regions must be handled carefully in dedicated
1206 // patterns.
1207 if (!warpRegionPreYieldOp || warpRegionPreYieldOp->getNumRegions())
1208 return failure();
1209 int operandIdx = -1;
1210 if (warpRegionPreYieldOp->getNumResults()) {
1211 OpOperand *operand = getWarpResult(
1212 warpOp, [&](Operation *op) { return warpRegionPreYieldOp == op; });
1213 if (!operand)
1214 return failure();
1215 operandIdx = operand->getOperandNumber();
1216 if (warpRegionPreYieldOp->getResult(0).getType() !=
1217 warpOp.getResult(operandIdx).getType())
1218 return rewriter.notifyMatchFailure(warpOp,
1219 "The op result is not uniform.");
1220 }
1221
1222 // The op must have no layout-based operands or results.
1223 bool uniformValuesOnly =
1224 llvm::all_of(warpRegionPreYieldOp->getResults(), [](Value v) {
1225 return !xegpu::getDistributeLayoutAttr(v);
1226 });
1227 uniformValuesOnly &=
1228 llvm::all_of(warpRegionPreYieldOp->getOpOperands(), [](OpOperand &opr) {
1229 return !xegpu::getDistributeLayoutAttr(opr);
1230 });
1231 if (!uniformValuesOnly)
1232 return rewriter.notifyMatchFailure(warpOp,
1233 "Some values are not uniform.");
1234 SmallVector<size_t> newRetIndices;
1235 SmallVector<Value> operands =
1236 llvm::to_vector_of<Value>(warpRegionPreYieldOp->getOperands());
1237 SmallVector<Type> operandTypes =
1238 llvm::to_vector_of<Type>(warpRegionPreYieldOp->getOperandTypes());
1239 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1240 rewriter, warpOp, operands, operandTypes, newRetIndices);
1241
1242 rewriter.setInsertionPointAfter(newWarpOp);
1243 IRMapping operandMapper;
1244 for (auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
1245 operandMapper.map(warpRegionPreYieldOp->getOperand(oldOperandIdx),
1246 newWarpOp->getResult(newOperandIdx));
1247 Operation *clonedOp = rewriter.clone(*warpRegionPreYieldOp, operandMapper);
1248 if (!clonedOp->getNumResults())
1249 rewriter.eraseOp(warpRegionPreYieldOp);
1250 else {
1251 assert(operandIdx != -1 && "Expected a warp result for the operation");
1252 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx),
1253 clonedOp->getResult(0));
1254 }
1255 return success();
1256 }
1257};
1258
1259/// This pattern distributes the `vector.multi_reduction` operation across
1260/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
1261/// layouts for the source and accumulator vectors,
1262/// * If the reduction dimension is distributed across lanes, the reduction is
1263/// non-lane-local and the reduction is done using warp shuffles. Here we
1264/// simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in
1265/// the warp op body.
1266/// * If the reduction dimension is not distributed across lanes, the reduction
1267/// is lane-local. In this case, we yield the source and accumulator vectors
1268/// from the warp op and perform the lane-local reduction outside the warp op
1269/// using a sequence of ReductionOps.
1270/// Example 1 (Reduction is lane-local):
1271/// ```
1272/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1273/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1274/// %acc = "some_def"() : () -> (vector<32xf32>)
1275/// %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
1276/// vector<32xf32>
/// gpu.yield %1 : vector<32xf32>
1277/// }
1278/// ```
1279/// is lowered to:
1280/// ```
1281/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
1282/// vector<1xf32>) {
1283/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1284/// %acc = "some_def"() : () -> (vector<32xf32>)
1285/// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
1286/// }
1287/// %c = arith.constant dense<0.0> : vector<1xf32>
1288/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
1289/// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
1290/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
1291/// ```
1292/// Example 2 (Reduction is non-lane-local):
1293/// ```
1294/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1295/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1296/// %acc = "some_def"() : () -> (vector<2xf32>)
1297/// %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
1298/// vector<2xf32>
1299/// gpu.yield %1 : vector<2xf32>
1300/// }
1301/// ```
1302/// is lowered to:
1303/// ```
1304/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1305/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1306/// %acc = "some_def"() : () -> (vector<2xf32>)
1307/// %1 = arith.constant dense<0.0> : vector<2xf32>
1308/// %2 = vector.extract %0[0] : vector<32xf32> from <vector<2x32xf32>>
1309/// %3 = ("warp.reduction %2") : f32
1310/// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
1311/// ... repeat for row 1
1312/// gpu.yield %1 : vector<2xf32>
1313/// }
1314struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
1315 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1316 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1317 PatternRewriter &rewriter) const override {
1318 OpOperand *yieldOperand =
1319 getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
1320 if (!yieldOperand)
1321 return failure();
1322 auto reductionOp =
1323 cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
1324 unsigned operandIdx = yieldOperand->getOperandNumber();
1325 VectorType sourceType = reductionOp.getSourceVectorType();
1326 int64_t sourceRank = sourceType.getRank();
1327 // Need at least a 2D source vector.
1328 if (sourceRank < 2)
1329 return rewriter.notifyMatchFailure(warpOp,
1330 "Only 2D+ reductions are supported.");
1331 // Leading dimensions (first rank-2) must be unit (size 1).
1332 for (int64_t i = 0; i < sourceRank - 2; ++i) {
1333 if (sourceType.getShape()[i] != 1)
1334 return rewriter.notifyMatchFailure(
1335 warpOp, "Only unit dimensions allowed for the leading dimensions.");
1336 }
1337 // Effective dimension indices (last 2 dims of the source).
1338 int64_t rowIdx = sourceRank - 2;
1339 int64_t columnIdx = sourceRank - 1;
1340 ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
1341 if (reductionDims.size() != 1)
1342 return rewriter.notifyMatchFailure(warpOp,
1343 "Only 1 reduction dim is supported.");
1344 int64_t reductionDim = reductionDims[0];
1345 // The reduction dim must be among the last 2 dims.
1346 if (reductionDim != rowIdx && reductionDim != columnIdx)
1347 return rewriter.notifyMatchFailure(
1348 warpOp, "Reduction dim must be among the last 2 dimensions.");
1349 VectorType distributedResultType =
1350 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1351 VectorType resultType = cast<VectorType>(reductionOp.getType());
1352 xegpu::DistributeLayoutAttr sourceLayout =
1353 xegpu::getTemporaryLayout(reductionOp->getOpOperand(0));
1354
1355 FailureOr<VectorType> sourceDistTypeOrFailure =
1356 getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1357 if (failed(sourceDistTypeOrFailure))
1358 return rewriter.notifyMatchFailure(
1359 warpOp, "Failed to distribute the source vector type.");
1360 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1361 // Only single dimension distribution among the last 2 dims is supported.
1362 bool rowDistributed =
1363 sourceDistType.getShape()[rowIdx] != sourceType.getShape()[rowIdx];
1364 bool columnDistributed = sourceDistType.getShape()[columnIdx] !=
1365 sourceType.getShape()[columnIdx];
1366 if (rowDistributed && columnDistributed)
1367 return rewriter.notifyMatchFailure(
1368 warpOp, "Expecting source to be distributed in a single dimension.");
1369 int64_t sourceDistDim =
1370 rowDistributed ? rowIdx : (columnDistributed ? columnIdx : -1);
1371 if (sourceDistDim == -1)
1372 return rewriter.notifyMatchFailure(
1373 warpOp, "Expecting a distributed source vector.");
1374 bool resultDistributed =
1375 distributedResultType.getNumElements() < resultType.getNumElements();
1376 // If the lane owns all the data required for reduction (i.e. reduction is
1377 // fully parallel across lanes), then each lane owns part of the result
1378 // (i.e. result is distributed). If the reduction requires cross-lane
1379 // shuffling, then the result is shared among all lanes (broadcasted).
1380 // Therefore we expect the following cases:
1381 //
1382 // | Source vector | Reduction dim | Result vector |
1383 // |----------------------|----------------|----------------|
1384 // | dim-0 distributed | 0 | broadcasted |
1385 // | dim-0 distributed | 1 | distributed |
1386 // | dim-1 distributed | 0 | distributed |
1387 // | dim-1 distributed | 1 | broadcasted |
1388
1389 bool isReductionLaneLocal =
1390 (sourceDistDim == rowIdx && reductionDim == columnIdx) ||
1391 (sourceDistDim == columnIdx && reductionDim == rowIdx);
1392 if (isReductionLaneLocal && !resultDistributed)
1393 return rewriter.notifyMatchFailure(
1394 warpOp, "Expecting a distributed result for lane-local reduction.");
1395
1396 if (!isReductionLaneLocal && resultDistributed)
1397 return rewriter.notifyMatchFailure(
1398 warpOp,
1399 "Expecting a broadcasted result for non-lane-local reduction.");
1400
1401 // Handle lane-local reduction case. In this case we fully distribute the
1402 // reduction result.
1403 if (isReductionLaneLocal) {
1404 // Yield the source and acc vectors from the WarpOp.
1405 SmallVector<size_t> newRetIndices;
1406 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1407 rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
1408 {sourceDistType, distributedResultType}, newRetIndices);
1409 rewriter.setInsertionPointAfter(newWarpOp);
1410 Value result = lowerToVectorReductions(
1411 cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
1412 cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
1413 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1414 // Replace the warp op result with the final result.
1415 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
1416 return success();
1417 }
1418 // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
1419 // of multiple ReductionOps. Actual distribution is done by the
1420 // WarpOpReduction pattern.
1421 rewriter.setInsertionPointAfter(reductionOp);
1422 Value result = lowerToVectorReductions(
1423 cast<TypedValue<VectorType>>(reductionOp.getSource()),
1424 cast<TypedValue<VectorType>>(reductionOp.getAcc()),
1425 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1426 // Replace the warp op result with the final result.
1427 rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
1428 return success();
1429 }
1430};
1431
1432/// This pattern distributes the `vector.broadcast` operation across lanes in a
1433/// warp. The pattern supports three use cases:
1434///
1435/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
1436/// vector
1437/// must have a slice layout of the result. If the distributed source and
1438/// target vector types are identical, this lowers to a no-op; otherwise, it
1439/// remains a broadcast but operates on distributed vectors.
1440///
1441/// 2) Broadcast a same-rank vector with identical layouts for source and
1442/// target:
1443/// The source vector must have unit dimensions, and lane_data must be unit
1444/// size for those unit dims. This always lowers to a no-op.
1445///
1446/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from
1447/// scalar to distributed result type.
1448///
1449/// Example 1 (lowering to a broadcast with distributed types):
1450/// ```
1451/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1452/// %0 = "some_def"() {layout_result_0 =
1453/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1454/// dims = [0]> } : () -> (vector<32xf32>)
1455/// %2 = vector.broadcast %0 {layout_result_0 =
1456/// #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
1457/// : vector<32xf32> to vector<8x32xf32>
1458 /// gpu.yield %2 : vector<8x32xf32>
1459/// }
1460/// ```
1461/// is lowered to:
1462/// ```
1463/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1464/// %0 = "some_def"() {layout_result_0 =
1465/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1466/// dims = [0]> } : () -> (vector<32xf32>)
1467/// gpu.yield %0 : vector<32xf32>
1468/// }
1469/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
1470///
1471/// Example 2 (no-op):
1472/// ```
1473 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1474/// %0 = "some_def"() {layout_result_0 =
1475/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1476/// dims = [1]> } : () -> (vector<8xf32>)
1477/// %1 = vector.shape_cast %0
1478/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1479/// 1]>}: vector<8xf32> to vector<8x1xf32>
1480/// %2 = vector.broadcast %1
1481/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1482/// 1]>}: vector<8x1xf32> to vector<8x32xf32>
1483 /// gpu.yield %2 : vector<8x32xf32>
1484/// }
1485/// ```
1486/// is lowered to:
1487/// ```
1488/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1489/// %0 = "some_def"() {layout_result_0 =
1490/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1491/// dims = [1]> } : () -> (vector<8xf32>)
1492/// %1 = vector.shape_cast %0
1493/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1494/// 1]>}: vector<8xf32> to vector<8x1xf32>
1495/// gpu.yield %1 : vector<8x1xf32>
1496/// }
1497/// // The broadcast is implicit through layout transformation (no-op)
1498/// "some_use"(%r#0)
1499/// ```
1500struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
1501 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1502 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1503 PatternRewriter &rewriter) const override {
1504 OpOperand *yieldOperand =
1505 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1506 if (!yieldOperand)
1507 return failure();
1508 auto broadcastOp =
1509 cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
1510 unsigned operandIdx = yieldOperand->getOperandNumber();
1511
1512 VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1513 VectorType destType =
1514 dyn_cast<VectorType>(broadcastOp.getResult().getType());
1515
1516 xegpu::DistributeLayoutAttr sourceLayout =
1517 xegpu::getTemporaryLayout(broadcastOp->getOpOperand(0));
1518 xegpu::DistributeLayoutAttr resultLayout =
1519 xegpu::getTemporaryLayout(dyn_cast<OpResult>(broadcastOp.getResult()));
1520
1521 FailureOr<VectorType> sourceDistType;
1522 Type sourceElemOrDistType;
1523 if (sourceType) {
1524
1525 // Case 1 and 2: source is a vector type.
1526 int64_t rankDiff = destType.getRank() - sourceType.getRank();
1527 if (rankDiff > 0) {
1528 // Case 1: source is lower-rank than result.
1529 bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
1530 if (!isSliceOf)
1531 broadcastOp.emitWarning()
1532 << "Broadcast input layout must be a slice of result layout.";
1533 }
1534 // case 2: source and result have same rank
1535 if (rankDiff == 0) {
1536 auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
1537 SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
1538 broadcastUnitDimsSet.end());
1539 assert(sourceLayout.isEqualTo(
1540 sourceLayout.setUnitDimData(broadcastUnitDims)) &&
1541 "The sg_data for unit dimensions should be set as 1");
1542 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1543 }
1544
1545 sourceDistType =
1546 getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1547 if (failed(sourceDistType)) {
1548 return rewriter.notifyMatchFailure(
1549 warpOp, "Failed to distribute the source vector type.");
1550 }
1551 sourceElemOrDistType = sourceDistType.value();
1552
1553 } else {
1554 // Case 3: source is a scalar type.
1555 if (sourceLayout) {
1556 return rewriter.notifyMatchFailure(
1557 warpOp, "Broadcast from scalar must not have a layout attribute.");
1558 }
1559 sourceElemOrDistType = broadcastOp.getSourceType();
1560 }
1561 FailureOr<VectorType> destDistType =
1562 getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
1563 if (failed(destDistType)) {
1564 return rewriter.notifyMatchFailure(
1565 warpOp, "Failed to distribute the dest vector type.");
1566 }
1567
1568 SmallVector<size_t> newRetIndices;
1569 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1570 rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
1571 newRetIndices);
1572
1573 Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
1574
1575 Value newBroadcast = distributedSource;
1576
1577 if (sourceElemOrDistType != destDistType.value()) {
1578 rewriter.setInsertionPointAfter(newWarpOp);
1579 newBroadcast =
1580 vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
1581 destDistType.value(), distributedSource);
1582 }
1583
1584 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
1585 return success();
1586 }
1587};
1588
1589/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
1590/// `gpu.warp_execute_on_lane_0` region.
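/// A minimal illustrative example (not taken from the original source),
/// assuming lane_layout = [16, 1] on the shape_cast source, lane_layout = [16]
/// on its result, and a subgroup size of 16:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
///   %0 = "some_def"() : () -> (vector<16x1xf32>)
///   %1 = vector.shape_cast %0 : vector<16x1xf32> to vector<16xf32>
///   gpu.yield %1 : vector<16xf32>
/// }
/// ```
/// is lowered to:
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
///   %0 = "some_def"() : () -> (vector<16x1xf32>)
///   gpu.yield %0 : vector<16x1xf32>
/// }
/// %1 = vector.shape_cast %r#0 : vector<1x1xf32> to vector<1xf32>
/// ```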
1591struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
1592 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1593 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1594 PatternRewriter &rewriter) const override {
1595 OpOperand *yieldOperand =
1596 getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1597 if (!yieldOperand)
1598 return failure();
1599 auto shapeCastOp =
1600 cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
1601 unsigned operandNumber = yieldOperand->getOperandNumber();
1602 auto resultDistTy =
1603 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1604 xegpu::DistributeLayoutAttr sourceLayout =
1605 xegpu::getTemporaryLayout(shapeCastOp->getOpOperand(0));
1606 xegpu::DistributeLayoutAttr resultLayout =
1607 xegpu::getTemporaryLayout(dyn_cast<OpResult>(shapeCastOp.getResult()));
1608 if (!sourceLayout || !resultLayout)
1609 return rewriter.notifyMatchFailure(
1610 warpOp,
1611 "the source or result of shape_cast op lacks distribution layout");
1612
1613 FailureOr<VectorType> sourceDistTypeOrFailure =
1614 getDistVecTypeBasedOnLaneLayout(sourceLayout,
1615 shapeCastOp.getSourceVectorType());
1616 if (failed(sourceDistTypeOrFailure))
1617 return rewriter.notifyMatchFailure(
1618 warpOp, "failed to get distributed vector type for source");
1619 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1620 // Create a new warp op that yields the source of the shape_cast op.
1621 SmallVector<size_t> newRetIndices;
1622 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1623 rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
1624 newRetIndices);
1625 rewriter.setInsertionPointAfter(newWarpOp);
1626 Value source = newWarpOp.getResult(newRetIndices[0]);
1627 // Create a new shape_cast op outside the warp op.
1628 Value newShapeCast = vector::ShapeCastOp::create(
1629 rewriter, shapeCastOp.getLoc(), resultDistTy, source);
1630 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1631 newShapeCast);
1632 return success();
1633 }
1634};
1635
1636// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
1637// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1638// advanced cases where the distributed dimension is partially extracted and
1639// currently not supported by the generic vector distribution patterns.
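// An illustrative example (not taken from the original source), assuming the
// source carries lane_layout = [16, 1], lane_data = [1, 1] and the subgroup
// size is 16, so rows are distributed round-robin across the 16 lanes:
// ```
// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16xf32>) {
//   %0 = "some_def"() : () -> (vector<64x16xf32>)
//   %1 = vector.extract_strided_slice %0 {offsets = [32, 0], sizes = [32, 16],
//     strides = [1, 1]} : vector<64x16xf32> to vector<32x16xf32>
//   gpu.yield %1 : vector<32x16xf32>
// }
// ```
// is lowered to:
// ```
// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x16xf32>) {
//   %0 = "some_def"() : () -> (vector<64x16xf32>)
//   gpu.yield %0 : vector<64x16xf32>
// }
// %1 = vector.extract_strided_slice %r#0 {offsets = [2, 0], sizes = [2, 16],
//   strides = [1, 1]} : vector<4x16xf32> to vector<2x16xf32>
// ```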
1640 struct VectorExtractStridedSliceDistribution
1641 : public gpu::WarpDistributionPattern {
1642 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1643 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1644 PatternRewriter &rewriter) const override {
1645 OpOperand *operand =
1646 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1647 if (!operand)
1648 return failure();
1649 auto extractOp =
1650 cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
1651 unsigned operandIdx = operand->getOperandNumber();
1652 auto distributedType =
1653 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1654 // Find the distributed dimensions.
1655 auto extractResultType = cast<VectorType>(operand->get().getType());
1656 auto distributedDims =
1657 getDistributedDims(extractResultType, distributedType);
1658 // Collect updated source type, sizes and offsets. They may be adjusted
1659 // later if the data is distributed to lanes (as opposed to being owned by
1660 // all lanes uniformly).
1661 VectorType updatedSourceType = extractOp.getSourceVectorType();
1662 SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
1663 extractOp.getSizes(), [](Attribute attr) { return attr; });
1664 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1665 extractOp.getOffsets(), [](Attribute attr) { return attr; });
1666 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1667 extractOp.getStrides(), [](Attribute attr) { return attr; });
1668 // If the provided sizes, offsets, strides are less than the rank, pad them
1669 // with full sizes, zero offsets, and unit strides. This makes it easier to
1670 // adjust them later.
1671 int64_t sourceRank = extractOp.getSourceVectorType().getRank();
1672 for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
1673 updatedSizes.push_back(rewriter.getI64IntegerAttr(
1674 extractOp.getSourceVectorType().getDimSize(i)));
1675 updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
1676 updatedStrides.push_back(
1677 rewriter.getI64IntegerAttr(1)); // stride is always 1.
1678 }
1679 // If the result is distributed, it must be distributed in exactly one
1680 // dimension. In this case, we adjust the updated source type, sizes and
1681 // offsets accordingly.
1682 if (distributedDims.size() > 0) {
1683 if (distributedDims.size() != 1)
1684 return rewriter.notifyMatchFailure(
1685 warpOp, "Source can not be distributed in multiple dimensions.");
1686 int64_t distributedDim = distributedDims[0];
1687 int sourceDistrDimSize =
1688 extractOp.getSourceVectorType().getShape()[distributedDim];
1689 auto sourceLayout = xegpu::getTemporaryLayout(extractOp->getOpOperand(0));
1690 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1691 return rewriter.notifyMatchFailure(
1692 warpOp, "the source of extract_strided_slice op lacks distribution "
1693 "layout");
1694 auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
1695 // Because only single dimension distribution is supported, lane layout
1696 // size at the distributed dim must be the subgroup size.
1697 int subgroupSize = sourceLaneLayout[distributedDim];
1698 // Check if the source size in the distributed dimension is a multiple of
1699 // subgroup size.
1700 if (sourceDistrDimSize % subgroupSize != 0)
1701 return rewriter.notifyMatchFailure(
1702 warpOp,
1703 "Source size along distributed dimension is not a multiple of "
1704 "subgroup size.");
1705 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1706 // We expect lane data to be all ones in this case.
1707 if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1708 return rewriter.notifyMatchFailure(
1709 warpOp, "Expecting unit lane data in source layout");
1710 // The offsets in the distributed dimension must be a multiple of subgroup
1711 // size.
1712 int64_t distrDimOffset =
1713 cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
1714 if (distrDimOffset % subgroupSize != 0)
1715 return rewriter.notifyMatchFailure(
1716 warpOp, "Offset along distributed dimension "
1717 "is not a multiple of subgroup size.");
1718 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1719 sourceLayout, extractOp.getSourceVectorType())
1720 .value();
1721 // Update the distributed sizes to match the distributed type.
1722 updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1723 distributedType.getDimSize(distributedDim));
1724 // Update the distributed offsets to match round robin distribution (i.e.
1725 // each lane owns data at `subgroupSize` stride given unit lane data).
1726 updatedOffsets[distributedDim] =
1727 rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
1728 }
1729 // Do the distribution by yielding the source of the extract op from
1730 // the warp op and creating a new extract op outside the warp op.
1731 SmallVector<size_t> newRetIndices;
1732 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1733 rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
1734 newRetIndices);
1735 rewriter.setInsertionPointAfter(newWarpOp);
1736 Value source = newWarpOp.getResult(newRetIndices[0]);
1737 // Create a new extract op outside the warp op.
1738 Value newExtractOp = vector::ExtractStridedSliceOp::create(
1739 rewriter, extractOp.getLoc(), distributedType, source,
1740 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1741 ArrayAttr::get(rewriter.getContext(), updatedSizes),
1742 ArrayAttr::get(rewriter.getContext(), updatedStrides));
1743 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
1744 return success();
1745 }
1746};
1747
1748/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
1749/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1750/// advanced cases where the distributed dimension is partially inserted and
1751/// currently not supported by the generic vector distribution patterns.
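/// An illustrative example (not taken from the original source), assuming both
/// the source and dest carry lane_layout = [16, 1], lane_data = [1, 1] and the
/// subgroup size is 16:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x16xf32>) {
///   %0 = "some_def"() : () -> (vector<32x16xf32>)
///   %1 = "some_def"() : () -> (vector<64x16xf32>)
///   %2 = vector.insert_strided_slice %0, %1 {offsets = [16, 0],
///     strides = [1, 1]} : vector<32x16xf32> into vector<64x16xf32>
///   gpu.yield %2 : vector<64x16xf32>
/// }
/// ```
/// is lowered to:
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16xf32>,
///   vector<4x16xf32>) {
///   %0 = "some_def"() : () -> (vector<32x16xf32>)
///   %1 = "some_def"() : () -> (vector<64x16xf32>)
///   gpu.yield %0, %1 : vector<32x16xf32>, vector<64x16xf32>
/// }
/// %2 = vector.insert_strided_slice %r#0, %r#1 {offsets = [1, 0],
///   strides = [1, 1]} : vector<2x16xf32> into vector<4x16xf32>
/// ```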
1752 struct VectorInsertStridedSliceDistribution
1753 : public gpu::WarpDistributionPattern {
1754 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1755 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1756 PatternRewriter &rewriter) const override {
1757 OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
1758 // Check if the InsertStridedSliceOp is the last op before yield op
1759 return llvm::IsaPred<vector::InsertStridedSliceOp>(op) &&
1760 warpOp.getTerminator()->getPrevNode() == op;
1761 });
1762 if (!operand)
1763 return failure();
1764 unsigned int operandNumber = operand->getOperandNumber();
1765 auto insertOp =
1766 operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
1767 auto distributedType =
1768 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1769 // Find the distributed dimensions of the dest vector.
1770 auto insertResultType = cast<VectorType>(operand->get().getType());
1771 auto destDistributedDims =
1772 getDistributedDims(insertResultType, distributedType);
1773 // Collect updated offsets, source type and dest type. They may be adjusted
1774 // later if the data is distributed to lanes (as opposed to being owned by
1775 // all lanes uniformly).
1776 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1777 insertOp.getOffsets(), [](Attribute attr) { return attr; });
1778 VectorType updatedSourceType = insertOp.getSourceVectorType();
1779 VectorType updatedDestType = insertOp.getDestVectorType();
1780 if (destDistributedDims.size() > 0) {
1781 // Only single dimension distribution is supported.
1782 if (destDistributedDims.size() != 1)
1783 return rewriter.notifyMatchFailure(
1784 warpOp,
1785 "Expecting source to be distributed in a single dimension.");
1786 int64_t destDistributedDim = destDistributedDims[0];
1787
1788 VectorType srcType = insertOp.getSourceVectorType();
1789 VectorType destType = insertOp.getDestVectorType();
1790 // Currently we require that both source (kD) and dest (nD) vectors are
1791 // distributed. This requires that distributedDim (d) is contained in the
1792 // last k dims of the dest vector (d >= n - k).
1793 int64_t sourceDistributedDim =
1794 destDistributedDim - (destType.getRank() - srcType.getRank());
1795 if (sourceDistributedDim < 0)
1796 return rewriter.notifyMatchFailure(
1797 insertOp,
1798 "distributed dimension must be in the last k (i.e. source "
1799 "rank) dims of dest vector");
1800 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
1801 // Obtain the source and dest layouts.
1802 auto destLayout = xegpu::getTemporaryLayout(insertOp->getOpOperand(1));
1803 auto sourceLayout = xegpu::getTemporaryLayout(insertOp->getOpOperand(0));
1804 if (!destLayout || !sourceLayout ||
1805 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1806 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1807 return rewriter.notifyMatchFailure(
1808 warpOp, "the source or dest of insert_strided_slice op lacks "
1809 "distribution layout");
1810 // Because only single dimension distribution is supported, lane layout
1811 // size at the distributed dim must be the subgroup size.
1812 int subgroupSize =
1813 destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
1814 // We require that source and dest lane data are all ones to ensure
1815 // uniform round robin distribution.
1816 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1817 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1818 if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
1819 !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1820 return rewriter.notifyMatchFailure(
1821 warpOp, "Expecting unit lane data in source and dest layouts");
1822 // Source distributed dim size must be a multiple of the subgroup size.
1823 if (srcDistrDimSize % subgroupSize != 0)
1824 return rewriter.notifyMatchFailure(
1825 warpOp, "Distributed dimension size in source is not a multiple of "
1826 "subgroup size.");
1827 // Offsets in the distributed dimension must be multiples of subgroup
1828 // size.
1829 int64_t destDistrDimOffset =
1830 cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
1831 if (destDistrDimOffset % subgroupSize != 0)
1832 return rewriter.notifyMatchFailure(
1833 warpOp,
1834 "Offset along distributed dimension in dest is not a multiple of "
1835 "subgroup size.");
1836 // Update the source and dest types based on their layouts.
1837 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1838 sourceLayout, insertOp.getSourceVectorType())
1839 .value();
1840 updatedDestType = getDistVecTypeBasedOnLaneLayout(
1841 destLayout, insertOp.getDestVectorType())
1842 .value();
1843 // Update the distributed offsets to match round robin distribution (i.e.
1844 // each lane owns data at `subgroupSize` stride given unit lane data).
1845 updatedOffsets[destDistributedDim] =
1846 rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1847 }
1848 // Do the distribution by yielding the source and dest of the insert op
1849 // from the warp op and creating a new insert op outside the warp op.
1850 SmallVector<size_t> newRetIndices;
1851 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1852 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1853 {updatedSourceType, updatedDestType}, newRetIndices);
1854 rewriter.setInsertionPointAfter(newWarpOp);
1855
1856 Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
1857 Value dest = newWarpOp.getResult(newRetIndices[1]);
1858 // Create a new insert op outside the warp op.
1859 Value newInsertOp = vector::InsertStridedSliceOp::create(
1860 rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
1861 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1862 insertOp.getStrides());
1863 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1864 newInsertOp);
1865 return success();
1866 }
1867};
1868
1869/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
1870/// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
1871/// outside of the warp op.
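/// An illustrative example (not taken from the original source); %src is a
/// uniform memref value defined above the warp op:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
///   %0 = memref.extract_aligned_pointer_as_index %src
///     : memref<256xf32> -> index
///   gpu.yield %0 : index
/// }
/// ```
/// is lowered to:
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (memref<256xf32>) {
///   gpu.yield %src : memref<256xf32>
/// }
/// %0 = memref.extract_aligned_pointer_as_index %r#0
///   : memref<256xf32> -> index
/// ```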
1872 struct MemrefExtractAlignedPointerAsIndexDistribution final
1873 : public gpu::WarpDistributionPattern {
1874 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1875 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1876 PatternRewriter &rewriter) const override {
1877 OpOperand *operand = getWarpResult(
1878 warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
1879 if (!operand)
1880 return rewriter.notifyMatchFailure(
1881 warpOp,
1882 "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
1883 auto extractOp =
1884 operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
1885 unsigned operandIdx = operand->getOperandNumber();
1886 SmallVector<size_t> newRetIndices;
1887 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1888 rewriter, warpOp, extractOp.getSource(),
1889 TypeRange{extractOp.getSource().getType()}, newRetIndices);
1890 rewriter.setInsertionPointAfter(newWarpOp);
1891 auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
1892 rewriter, newWarpOp.getLoc(), extractOp.getType(),
1893 newWarpOp.getResult(newRetIndices[0]));
1894 Value resultVal = newWarpOp.getResult(operandIdx);
1895 rewriter.replaceAllUsesWith(resultVal, newExtractOp.getResult());
1896 return success();
1897 }
1898};
1899
1900/// Distribute a vector::BitCastOp feeding into yield op of an enclosing
1901/// `gpu.warp_execute_on_lane_0` region. Bitcast only impacts the innermost
1902 /// dimension of the source/result vectors. An equivalent vector::BitCastOp
1903 /// is created outside of the warp op with the distributed source vector type
1904 /// (computed using the assigned layout).
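/// An illustrative example (not taken from the original source), assuming
/// lane_layout = [1, 16], lane_data = [1, 1] on both source and result, and a
/// subgroup size of 16:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi32>) {
///   %0 = "some_def"() : () -> (vector<4x32xi16>)
///   %1 = vector.bitcast %0 : vector<4x32xi16> to vector<4x16xi32>
///   gpu.yield %1 : vector<4x16xi32>
/// }
/// ```
/// is lowered to:
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x2xi16>) {
///   %0 = "some_def"() : () -> (vector<4x32xi16>)
///   gpu.yield %0 : vector<4x32xi16>
/// }
/// %1 = vector.bitcast %r#0 : vector<4x2xi16> to vector<4x1xi32>
/// ```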
1905struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
1906 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1907 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1908 PatternRewriter &rewriter) const override {
1909 OpOperand *operand =
1910 getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
1911 if (!operand)
1912 return rewriter.notifyMatchFailure(
1913 warpOp, "warp result is not a vector::BitCast op");
1914 auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
1915 unsigned operandIdx = operand->getOperandNumber();
1916 VectorType distributedSourceType =
1917 getDistVecTypeBasedOnLaneLayout(
1918 xegpu::getTemporaryLayout(bitcastOp->getOpOperand(0)),
1919 bitcastOp.getSourceVectorType())
1920 .value_or(VectorType());
1921 if (!distributedSourceType)
1922 return rewriter.notifyMatchFailure(
1923 bitcastOp, "Failed to distribute the source vector type in "
1924 "vector::BitCast op");
1925 VectorType distributedResultType =
1926 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1927 SmallVector<size_t> newRetIndices;
1928 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1929 rewriter, warpOp, bitcastOp.getSource(),
1930 TypeRange{distributedSourceType}, newRetIndices);
1931 rewriter.setInsertionPointAfter(newWarpOp);
1932 auto newBitcastOp = vector::BitCastOp::create(
1933 rewriter, newWarpOp.getLoc(), distributedResultType,
1934 newWarpOp.getResult(newRetIndices[0]));
1935 Value distributedVal = newWarpOp.getResult(operandIdx);
1936 rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
1937 return success();
1938 }
1939};
1940
1941/// Distribute a vector::TransposeOp feeding into yield op of an enclosing
1942/// `gpu.warp_execute_on_lane_0` region. Currently only 2D transposes are
1943 /// supported. In most cases, the transpose is a no-op because it is entirely
1944/// handled using the layouts (e.g. 16x1 -> 1x16). However, if each lane owns
1945/// multiple slices of data after distribution (e.g. 16x2 -> 2x16), a lane-local
1946/// transpose (i.e. shuffle) is needed. Therefore, we create an equivalent
1947/// vector::TransposeOp outside of the warp op with distributed source vector
1948/// type (computed using assigned layout).
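/// An illustrative example (not taken from the original source), assuming the
/// source carries lane_layout = [16, 1], lane_data = [1, 2], the result
/// carries lane_layout = [1, 16], lane_data = [2, 1], and the subgroup size is
/// 16 (each lane owns a 1x2 slice, so a lane-local transpose remains):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
///   %0 = "some_def"() : () -> (vector<16x2xf32>)
///   %1 = vector.transpose %0, [1, 0] : vector<16x2xf32> to vector<2x16xf32>
///   gpu.yield %1 : vector<2x16xf32>
/// }
/// ```
/// is lowered to:
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x2xf32>) {
///   %0 = "some_def"() : () -> (vector<16x2xf32>)
///   gpu.yield %0 : vector<16x2xf32>
/// }
/// %1 = vector.transpose %r#0, [1, 0] : vector<1x2xf32> to vector<2x1xf32>
/// ```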
1949struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
1950 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1951 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1952 PatternRewriter &rewriter) const override {
1953 OpOperand *operand =
1954 getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
1955 if (!operand)
1956 return rewriter.notifyMatchFailure(
1957 warpOp, "warp result is not a vector::Transpose op");
1958 auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
1959 unsigned operandIdx = operand->getOperandNumber();
1960 xegpu::DistributeLayoutAttr sourceLayout =
1961 xegpu::getTemporaryLayout(transposeOp->getOpOperand(0));
1962 xegpu::DistributeLayoutAttr resultLayout =
1963 xegpu::getTemporaryLayout(transposeOp->getOpResult(0));
1964 if (!sourceLayout || !resultLayout)
1965 return rewriter.notifyMatchFailure(
1966 transposeOp,
1967 "the source or result vector of the transpose op lacks layout "
1968 "attribute");
1969 int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
1970 int64_t resultRank = transposeOp.getResultVectorType().getRank();
1971 // Only 2D transposes are supported for now.
1972 // TODO: Support nD transposes.
1973 if (sourceRank != 2 || resultRank != 2)
1974 return rewriter.notifyMatchFailure(
1975 transposeOp, "the source or result vector of the transpose op "
1976 "does not have 2D layout");
1977 ArrayRef<int64_t> perm = transposeOp.getPermutation();
1978 // Result layout must be a transpose of source layout.
1979 if (!resultLayout.isTransposeOf(sourceLayout, perm,
1980 xegpu::LayoutKind::Lane))
1981 return rewriter.notifyMatchFailure(
1982 transposeOp,
1983 "the source or result vector layouts must be 2D transposes of each "
1984 "other");
1985 FailureOr<VectorType> distributedSourceTypeOrFailure =
1986 getDistVecTypeBasedOnLaneLayout(sourceLayout,
1987 transposeOp.getSourceVectorType());
1988 if (failed(distributedSourceTypeOrFailure))
1989 return rewriter.notifyMatchFailure(
1990 transposeOp, "Failed to distribute the source vector type in "
1991 "vector::Transpose op");
1992 SmallVector<size_t> newRetIndices;
1993 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1994 rewriter, warpOp, transposeOp.getVector(),
1995 TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
1996 rewriter.setInsertionPointAfter(newWarpOp);
1997 auto newTransposeOp = vector::TransposeOp::create(
1998 rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
1999 perm);
2000 Value distributedVal = newWarpOp.getResult(operandIdx);
2001 rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
2002 return success();
2003 }
2004};
2005
2006/// Distribute a vector::StepOp with the sliced result layout.
2007/// The sliced layout must have exactly 1 effective lane dimension.
2008/// We completely resolve the vector::StepOp by computing the lane_data-sized
2009/// subranges.
2010struct VectorStepSliceDistribution final : public gpu::WarpDistributionPattern {
2011 using gpu::WarpDistributionPattern::WarpDistributionPattern;
2012 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
2013 PatternRewriter &rewriter) const override {
2014 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
2015 if (!operand)
2016 return rewriter.notifyMatchFailure(
2017 warpOp, "warp result is not a vector::StepOp op");
2018 auto stepOp = operand->get().getDefiningOp<vector::StepOp>();
2019 unsigned operandIdx = operand->getOperandNumber();
2020 xegpu::DistributeLayoutAttr resultLayout =
2021 xegpu::getTemporaryLayout(stepOp->getResult(0));
2022 if (!resultLayout)
2023 return rewriter.notifyMatchFailure(
2024 stepOp, "the result vector of the step op lacks layout "
2025 "attribute");
2026 auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resultLayout);
2027 if (!sliceLayout)
2028 return rewriter.notifyMatchFailure(
2029 stepOp, "the result layout must be a slice layout");
2030 if (sliceLayout.getEffectiveLaneLayoutAsInt().size() != 1)
2031 return rewriter.notifyMatchFailure(
2032 stepOp, "expecting 1 dim in the effective result layout");
2033
2034 rewriter.setInsertionPointAfter(warpOp);
2035 auto loc = stepOp.getLoc();
2036 auto stepResultVecTy = stepOp.getResult().getType();
2037 Value distributedVal = warpOp.getResult(operandIdx);
2038 VectorType newVecTy = cast<VectorType>(distributedVal.getType());
2039
2040 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
2041 rewriter, loc, warpOp.getLaneid(), stepResultVecTy.getShape());
2042 if (failed(laneDataBlockCoords))
2043 return rewriter.notifyMatchFailure(
2044 stepOp, "failed to compute lane data block coordinates");
2045
2046 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
2047 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
2048 assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
2049 newVecTy.getNumElements() / laneDataBlockLength);
2050 SmallVector<Value> stepVals;
2051 // For each lane_data block, reconstruct its sub-range
2052 // from the range of SG-level vector.step. Example: vector.step
2053 // {slice<layout<lane_layout=[2,4,2], lane_data=[1,2,1]>, dims=[0,2]>} :
2054 // vector<16xindex>
2055 // Each logical lane holds 4 elements as 2 blocks of 2 elements each.
2056 // The blocks are round-robin distributed, so logical lane id 0
2057 // holds values [0,1, 8,9].
2058 for (auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
2059 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
2060 stepVals.push_back(laneDataBlockStartCoord);
2061 for (int i = 1; i < laneDataBlockLength; ++i) {
2062 auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
2063 stepVals.push_back(arith::AddIOp::create(
2064 rewriter, loc, laneDataBlockStartCoord, offset));
2065 }
2066 }
2067 assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
2068 "Expecting the number of step values to match the number of "
2069 "elements in the vector");
2070 auto stepOpVal =
2071 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
2072 rewriter.replaceAllUsesWith(distributedVal, stepOpVal);
2073 return success();
2074 }
2075};
2076
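/// Lower `xegpu.convert_layout` at the lane level. Scalar values, and vector
/// values whose input and target layouts are lane-compatible for the result
/// shape, need no data movement, so the op is folded away by forwarding its
/// source. Lane-incompatible conversions are not yet supported and fail to
/// match.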
2077struct ConvertLayoutDistribution
2078 : public OpRewritePattern<xegpu::ConvertLayoutOp> {
2079 using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
2080
2081 LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
2082 PatternRewriter &rewriter) const override {
2083 auto inputLayout = op.getInputLayoutAttr();
2084 auto targetLayout = op.getTargetLayoutAttr();
2085 Type valType = op.getResult().getType();
2086
2087 if (!inputLayout || !targetLayout)
2088 return rewriter.notifyMatchFailure(op, "missing layout attributes");
2089
2090 if (valType.isIntOrFloat()) {
2091 rewriter.replaceOp(op, op.getSource());
2092 return success();
2093 }
2094 auto resShape = cast<VectorType>(valType).getShape();
2095 SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
2096 if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
2097 xegpu::LayoutKind::Lane)) {
2098 return rewriter.notifyMatchFailure(
2099 op, "lowering incompatible convert_layout not yet supported");
2100 }
2101 rewriter.replaceOp(op, op.getSource());
2102 return success();
2103 }
2104};
2105
2106} // namespace
2107
2108namespace {
2109 struct XeGPUSubgroupDistributePass final
2110 : public xegpu::impl::XeGPUSubgroupDistributeBase<
2111 XeGPUSubgroupDistributePass> {
2112 void runOnOperation() override;
2113};
2114} // namespace
2115
2117 RewritePatternSet &patterns) {
2118 patterns.add<CreateNdDescDistribution, StoreNdDistribution,
2119 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
2120 GpuBarrierDistribution, VectorMultiReductionDistribution,
2121 LoadDistribution, StoreDistribution, VectorTransposeDistribution,
2122 VectorBitcastDistribution, LoadMatrixDistribution,
2123 StoreMatrixDistribution, ConvertLayoutDistribution,
2124 MemrefExtractAlignedPointerAsIndexDistribution>(
2125 patterns.getContext(),
2126 /*pattern benefit=*/PatternHierarchy::Regular);
2127 // For the following patterns, we need to override the regular vector
2128 // distribution patterns, so we assign them a higher benefit.
2129 patterns
2130 .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
2131 VectorInsertStridedSliceDistribution, VectorBroadcastDistribution,
2132 VectorStepSliceDistribution, SinkUniformOps>(
2133 patterns.getContext(),
2134 /*pattern benefit=*/PatternHierarchy::AboveRegular);
2135}
2136
2137 void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
2138 RewritePatternSet &patterns) {
2139 patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
2140}
2141
2142void XeGPUSubgroupDistributePass::runOnOperation() {
2143 // Step 1: Attach layouts to op operands.
2144 // TODO: The following assumptions are made:
2145 // 1) It is assumed that there are no layout conflicts.
2146 // 2) Any existing layout attributes attached to the operands are ignored.
2147 Operation *op = getOperation();
2148 if (!xegpu::recoverTemporaryLayouts(op)) {
2149 signalPassFailure();
2150 return;
2151 }
2152
2153 // Step 2: Move all operations of a GPU function inside
2154 // gpu.warp_execute_on_lane_0 operation.
2155 {
2156 RewritePatternSet patterns(&getContext());
2157 xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
2158
2159 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
2160 signalPassFailure();
2161 return;
2162 }
2163 // At this point, we have moved the entire function body inside the
2164 // warpOp. Now move any scalar uniform code outside of the warpOp (like
2165 // GPU index ops, scalar constants, etc.). This will simplify the
2166 // later lowering and avoid custom patterns for these ops.
2167 getOperation()->walk([&](Operation *op) {
2168 if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
2169 vector::moveScalarUniformCode(warpOp);
2170 });
2171 }
2172 // Step 3: Apply subgroup to workitem distribution patterns.
2173 RewritePatternSet patterns(&getContext());
2174 xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
2175 // distributionFn is used by vector distribution patterns to determine the
2176 // distributed vector type for a given vector value. In XeGPU subgroup
2177 // distribution context, we compute this based on lane layout.
2178 auto distributionFn = [](Value val) {
2179 VectorType vecType = dyn_cast<VectorType>(val.getType());
2180 int64_t vecRank = vecType ? vecType.getRank() : 0;
2181 if (vecRank == 0)
2182 return AffineMap::get(val.getContext());
2183 // Get the layout of the vector type.
2184 xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
2185 // If no layout is specified, assume uniform case (no distribution).
2186 if (!layout)
2187 return AffineMap::get(val.getContext());
2188 // Expecting vector and layout rank to match.
2189 assert(layout.getRank() == vecRank &&
2190 "Expecting vector and layout rank to match");
2191 // A dimension is distributed only if layout suggests there are
2192 // multiple lanes assigned for this dimension and the shape can be evenly
2193 // distributed to those lanes.
2194 SmallVector<unsigned int> distributedDims;
2195 for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
2196 if (v > 1 && vecType.getShape()[i] % v == 0)
2197 distributedDims.push_back(i);
2198 }
2199 return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
2200 val.getContext());
2201 };
2202 // TODO: shuffleFn is not used.
2203 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
2204 int64_t warpSz) { return Value(); };
2205
2206 vector::populateDistributeReduction(
2207 patterns, xegpu::subgroupReduction,
2208 /*pattern benefit=*/PatternHierarchy::Regular);
2209
2210 vector::populatePropagateWarpVectorDistributionPatterns(
2211 patterns, distributionFn, shuffleFn,
2212 /*pattern benefit=*/PatternHierarchy::Regular);
2213 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
2214 signalPassFailure();
2215 return;
2216 }
2217
2218 // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
2219 // due to tensor desc type mismatches created by using upstream distribution
2220 // patterns (scf.for). This cleanup should only be done if all ops were
2221 // distributed successfully; if some ops are still not distributed and remain
2222 // inside any WarpExecuteOnLane0Op, we skip this simplification step to avoid
2223 // breaking the IR.
2224 bool foundWarpOp = false;
2225 getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
2226 // Look for WarpOps that are not trivially dead.
2227 if (isOpTriviallyDead(warpOp))
2228 return WalkResult::advance();
2229 foundWarpOp = true;
2230 return WalkResult::interrupt();
2231 });
2232 if (foundWarpOp)
2233 return;
2234
2235 getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
2236 // We are only interested in UnrealizedConversionCastOps that were added
2237 // for resolving SIMT type mismatches.
2238 if (!op->getAttr(resolveSIMTTypeMismatch))
2239 return WalkResult::skip();
2240
2241 Value input = op.getOperand(0);
2242 Value output = op.getResult(0);
2243
2244 // Both input and output must have tensor descriptor types.
2245 xegpu::TensorDescType inputDescType =
2246 mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
2247 xegpu::TensorDescType outputDescType =
2248 mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
2249 assert(inputDescType && outputDescType &&
2250 "Unrealized conversion cast must have tensor descriptor types");
2251
2252 // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
2253 // This occurs inside scf.for body to resolve the block argument type to
2254 // SIMT type.
2255 if (inputDescType.getLayout()) {
2256 auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
2257 if (argument) {
2258 argument.setType(output.getType());
2259 output.replaceAllUsesWith(argument);
2260 if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
2261 argument.getOwner()->getParentOp())) {
2262 auto result = loopOp.getTiedLoopResult(argument);
2263 result.setType(output.getType());
2264 }
2265 }
2266 }
2267
2268 // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
2269 // conversions. This occurs at the yield op of scf.for body to go back
2270 // from SIMT type to original type.
2271 if (outputDescType.getLayout())
2272 output.replaceAllUsesWith(input);
2273
2274 if (op->use_empty())
2275 op->erase();
2276 return WalkResult::advance();
2277 });
2278
2279 xegpu::removeTemporaryLayoutAttrs(getOperation());
2280}