//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;

static const char *const resolveSIMTTypeMismatch =
    "resolve_simt_type_mismatch"; // Attribute name for identifying
                                  // UnrealizedConversionCastOp added to
                                  // resolve SIMT type mismatches.

namespace {

//===----------------------------------------------------------------------===//
// SIMT Distribution Patterns
//===----------------------------------------------------------------------===//

/// In certain cases, we may need to favor XeGPU specific distribution patterns
/// over generic vector distribution patterns. In such cases, we can assign
/// priorities to patterns.
enum PatternHierarchy : unsigned { Regular = 1, AboveRegular = 2 };
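// Illustrative sketch (not part of this file): these values are intended to be
// used as pattern benefits when registering patterns, e.g.:
//   patterns.add<LoadNdDistribution>(ctx, /*benefit=*/AboveRegular);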

/// Helper function to resolve types if the distributed type out of
/// gpu.warp_execute_on_lane_0 is different from the expected xegpu SIMT type.
/// Example 1:
/// distributed type: vector<8x1xf32>
/// expected type: vector<8xf32>
/// resolved using,
/// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
/// Example 2:
/// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
/// expected type: xegpu.tensor_desc<8x16xf32>
/// resolved using,
/// %0 = unrealized_conversion_cast %1 :
///      xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
///      xegpu.tensor_desc<8x16xf32>
template <typename T>
static Value resolveDistributedTy(Value orig, T expected,
                                  PatternRewriter &rewriter) {
  // If orig and expected types are the same, return orig.
  if (orig.getType() == expected)
    return orig;
  // If orig is a vector type, create a shape cast op to reconcile the types.
  if (isa<VectorType>(orig.getType())) {
    auto castOp =
        vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
    return castOp.getResult();
  }
  // If orig is a tensor descriptor type, create an unrealized conversion cast
  // op to reconcile the types.
  if (isa<xegpu::TensorDescType>(orig.getType())) {
    auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
                                                     expected, orig);
    castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
    return castOp.getResult(0);
  }
  llvm_unreachable("Unsupported type for reconciliation");
  return orig;
}

/// Given a vector type and its distributed vector type, return the list of
/// dimensions that are distributed.
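/// For example (illustrative), with originalType vector<8x16xf32> and
/// distributedType vector<8x1xf32>, the result is [1]: any dimension whose
/// size differs between the two types is reported as distributed.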
static SmallVector<int64_t> getDistributedDims(VectorType originalType,
                                               VectorType distributedType) {
  assert(originalType.getRank() == distributedType.getRank() &&
         "sequential and distributed vector types must have the same rank");
  SmallVector<int64_t> distributedDims;
  for (int64_t i = 0; i < originalType.getRank(); ++i) {
    if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
      distributedDims.push_back(i);
    }
  }
  return distributedDims;
}

/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
/// of the original GPUFuncOp to the new GPUFuncOp such that the entire body is
/// contained within a WarpExecuteOnLane0Op.
/// Example:
///
/// ```
/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
///   ...
///   ...
///   gpu.return %result: vector<8x16xf32>
/// }
/// ```
/// To
/// ```
/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
///   %laneid = gpu.lane_id : index
///   %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
///     ...
///     ...
///     gpu.yield %result: vector<8x16xf32>
///   }
///   gpu.return %0
/// }
/// ```
struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    auto uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          gpuFuncOp, "Subgroup distribution requires target attribute attached "
                     "to set the warp size");
    if (!gpuFuncOp.getBody().hasOneBlock())
      return rewriter.notifyMatchFailure(
          gpuFuncOp, "expected gpu.func to have a single block");

    // If the function only contains a single void return, skip.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // If the function body has already been moved inside a
    // warp_execute_on_lane_0, skip.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    gpu::ReturnOp origReturnOp = dyn_cast_if_present<gpu::ReturnOp>(
        gpuFuncOp.getBlocks().back().getTerminator());
    if (!origReturnOp)
      return rewriter.notifyMatchFailure(
          gpuFuncOp, "expected gpu.func terminator to be gpu.return");
    // Create a new function with the same signature and same attributes.
    SmallVector<Type> workgroupAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    SmallVector<Type> privateAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    auto newGpuFunc = gpu::GPUFuncOp::create(
        rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
        gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
        privateAttributionsTypes);
    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
    // Create a WarpExecuteOnLane0Op with same arguments and results as the
    // original gpuFuncOp.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = gpu::LaneIdOp::create(
        rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(),
        /** upperBound = **/ mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = gpu::WarpExecuteOnLane0Op::create(
        rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
        uArch->getSubgroupSize(), newGpuFunc.getArguments(),
        newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the ReturnOp of the original gpu function with a YieldOp.
    rewriter.setInsertionPointAfter(origReturnOp);
    gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
                         origReturnOp.getOperands());
    rewriter.eraseOp(origReturnOp);
    // Move the original function body to the WarpExecuteOnLane0Op body.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    rewriter.eraseBlock(&warpBodyBlock);
    // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};

/// Distribute a create_nd_tdesc feeding into the gpu.yield op of the enclosing
/// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
/// still contain the original op that will not be used by the yield op (and
/// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
/// arguments. The tensor descriptor shape is not distributed because it is a
/// uniform value across all work items within the subgroup. However, the
/// layout information is dropped in the new tensor descriptor type.
///
/// Example:
///
/// ```
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (!xegpu.tensor_desc<4x8xf32, #layout0>) {
///   ...
///   %td = xegpu.create_nd_tdesc %arg0
///     : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
///   gpu.yield %td
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
///   ...
///   %dead = xegpu.create_nd_tdesc %arg0
///     : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
///   gpu.yield %arg0, %dead
/// }
/// %td = xegpu.create_nd_tdesc %r#0: memref<4x8xf32>
///                                 -> !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();

    xegpu::DistributeLayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");
    // CreateNdOp must not have offsets.
    if (descOp.getMixedOffsets().size())
      return rewriter.notifyMatchFailure(
          descOp, "xegpu::CreateNdDescOp must not have offsets");

    SmallVector<size_t> newRetIndices;
    rewriter.setInsertionPoint(warpOp);
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, /* new yielded values = */ descOp->getOperands(),
        /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);

    SmallVector<Value> newDescOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts(); // Distributed tensor descriptor type
                                        // does not contain layout info.
    Value newDescOp = xegpu::CreateNdDescOp::create(
        rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());

    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the distributed type to the expected type.
    newDescOp =
        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, newDescOp);
    return success();
  }
};

/// Distribute a store_nd op at the end of enclosing
/// `gpu.warp_execute_on_lane_0`. In case arguments for the store are passed
/// through the warp op interface, they are propagated as returned values.
/// The source vector is distributed based on lane layout. Appropriate cast ops
/// are inserted if the distributed types do not match the expected xegpu SIMT
/// types.
///
/// Example:
///
/// ```
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
///   ...
///   xegpu.store_nd %arg0, %arg1 [%x, %y]: vector<4x8xf32>,
///     !xegpu.tensor_desc<4x8xf32, #layout0>
/// }
/// ```
/// To
/// ```
/// %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
///   !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
///   ...
///   gpu.yield %arg0, %arg1, %x, %y: vector<4x8xf32>,
///     !xegpu.tensor_desc<4x8xf32, #layout0>, index, index
/// }
/// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
/// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
///   #layout0> -> !xegpu.tensor_desc<4x8xf32>
/// xegpu.store_nd %0, %1 [%r#2, %r#3]: vector<4xf32>,
///   !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();

    SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
    // Expecting offsets to be present.
    if (offsets.empty())
      return rewriter.notifyMatchFailure(storeOp,
                                         "the store op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::map_to_vector(
        offsetsAsValues, [](Value v) { return v.getType(); });
    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::DistributeLayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");

    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldedValues = {storeOp.getValue(),
                                           storeOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
    // Create a new store op outside the warp op with the distributed vector
    // type. Tensor descriptor is not distributed.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;

    // For the value operand, there can be a mismatch between the vector type
    // distributed by the warp op and the (xegpu-specific) distributed type
    // supported by the store op. Type mismatches must be resolved using an
    // appropriate cast op.
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // For the tensor descriptor operand, the layout attribute is dropped after
    // distribution. Types need to be resolved in this case as well.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));
    // Collect offsets.
    for (size_t i = 2; i < newRetIndices.size(); ++i)
      newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));

    auto newStoreOp =
        xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                 newStoreOperands, storeOp->getAttrs());
    xegpu::removeLayoutAttrs(newStoreOp);
    rewriter.eraseOp(storeOp);
    return success();
  }
};

/// Distribute a load_nd op feeding into the gpu.yield op of the enclosing
/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
/// The warp op will still contain the original op that will not be used by
/// the yield op (and should be cleaned up later). The yield op will
/// bypass the load's arguments. Only the loaded vector is distributed
/// according to the lane layout; the tensor descriptor type is not
/// distributed. Appropriate cast ops are inserted if the distributed types do
/// not match the expected xegpu SIMT types.
///
/// Example:
///
/// ```
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (vector<4x1xf32>) {
///   ...
///   %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
///     -> vector<4x8xf32>
///   gpu.yield %ld
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
///   !xegpu.tensor_desc<4x8xf32, #layout0>) {
///   ...
///   %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
///     vector<4x8xf32>
///   gpu.yield %dead, %arg0
/// }
/// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
///   #layout0> -> !xegpu.tensor_desc<4x8xf32>
/// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
/// %2 = vector.shape_cast %1: vector<4xf32> to vector<4x1xf32>
///
/// ```
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
      if (!isa<xegpu::LoadNdOp>(op))
        return false;
      // Make sure the same load op is the last operation in the warp op body.
      // This ensures that the load op is not sunk earlier, which could violate
      // barrier synchronization.
      gpu::YieldOp yield = warpOp.getTerminator();
      return yield->getPrevNode() == op;
    });

    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::LoadNd op");

    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    // Chip information is required to decide if the layout requires the
    // transpose effect.
    auto uArch = getUArch(xegpu::getChipStr(loadOp).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          loadOp, "xegpu::LoadNdOp requires target attribute attached to "
                  "determine transpose requirement");
    // Expecting offsets to be present.
    SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(loadOp,
                                         "the load op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::map_to_vector(
        offsetsAsValues, [](Value v) { return v.getType(); });

    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::DistributeLayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");

    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);

    // Create a new load op outside the warp op with the distributed vector
    // type.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
                                                  // descriptor type does not
                                                  // contain layout info.
    SmallVector<Value> newLoadOperands{
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter)};
    // Collect offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    auto newLoadOp = xegpu::LoadNdOp::create(
        rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        newLoadOperands, loadOp->getAttrs());
    xegpu::removeLayoutAttrs(newLoadOp);
    // Set the packed attribute if the layout requires it.
    newLoadOp.setPacked(xegpu::requirePacked(layout));
    // Set the transpose attribute if the layout requires it.
    if (xegpu::requireTranspose(layout, uArch))
      newLoadOp.setTranspose(
          DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // There can be a conflict between the vector type distributed by the
    // warp op and the (xegpu-specific) distributed type supported by the load
    // op. Resolve these mismatches by inserting a cast.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};

/// Distribute a dpas op feeding into the gpu.yield op of the enclosing
/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
/// The warp op will still contain the original op that will not be used by
/// the yield op (and should be cleaned up later). The yield op will
/// bypass the dpas's arguments. Appropriate cast ops are inserted if the
/// distributed types do not match the expected xegpu SIMT types.
/// Example:
/// ```
/// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
/// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
/// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (vector<8x1xf32>) {
///   ...
///   %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
///     vector<8x16xf32>
///   gpu.yield %dpas
/// }
/// ```
/// To
/// ```
/// %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
///   vector<8x1xf16>, vector<16x1xf16>) {
///   ...
///   %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
///     -> vector<8x16xf32>
///   gpu.yield %dead, %arg0, %arg1
/// }
/// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
/// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
/// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
///   vector<8xf32>
/// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
/// ```
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(warpOp,
                                         "warp result is not a xegpu::Dpas op");

    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();

    xegpu::LayoutAttr layoutA =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutAAttr());
    xegpu::LayoutAttr layoutB =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutBAttr());
    xegpu::LayoutAttr layoutOut =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutCdAttr());

    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");

    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());

    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");

    llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
                                               dpasOp.getRhs()};
    llvm::SmallVector<Type, 3> newYieldTypes{
        distLhsTypeByWarpOpOrFailure.value(),
        distRhsTypeByWarpOpOrFailure.value()};
    // Dpas acc operand is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }
    // Create a new warp op without the dpas.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);

    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);

    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");
    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;

    // Resolve the distributed types with the original types.
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);

    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
                                           distributedResultTy, newDpasOperands,
                                           dpasOp->getAttrs());
    xegpu::removeLayoutAttrs(newDpasOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type.
    Value typeResolved =
        resolveDistributedTy(newDpasOp.getResult(),
                             distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};

/// Distribute a prefetch_nd op at the end of enclosing
/// `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed
/// through the warp op interface, they are propagated as returned values.
/// The tensor descriptor shape is not distributed because it is a uniform
/// value across all work items within the subgroup. Appropriate cast ops are
/// inserted if the distributed types do not match the expected xegpu SIMT
/// types.
///
/// Example:
///
/// ```
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
///   ...
///   xegpu.prefetch_nd %arg0 [%x, %y] : !xegpu.tensor_desc<4x8xf32, #layout0>
/// }
/// ```
/// To
/// ```
/// %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (
///   !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
///   gpu.yield %arg0, %x, %y: !xegpu.tensor_desc<4x8xf32, #layout0>, index,
///     index
/// }
/// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
///   #layout0> -> !xegpu.tensor_desc<4x8xf32>
/// xegpu.prefetch_nd %1 [%r#1, %r#2] : !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
    if (!prefetchOp)
      return failure();

    SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
    // PrefetchNdOp must have offsets.
    if (offsets.empty())
      return rewriter.notifyMatchFailure(prefetchOp,
                                         "the prefetch op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::map_to_vector(
        offsetsAsValues, [](Value v) { return v.getType(); });

    xegpu::DistributeLayoutAttr layout =
        prefetchOp.getTensorDescType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");

    SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
    SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
    newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
    // Create a new prefetch op outside the warp op with the updated tensor
    // descriptor type. The source tensor descriptor requires type resolution.
    xegpu::TensorDescType newTensorDescTy =
        prefetchOp.getTensorDescType().dropLayouts();
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
    // Collect offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    Operation *newPrefetchOp = xegpu::PrefetchNdOp::create(
        rewriter, newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
        prefetchOp->getAttrs());
    xegpu::removeLayoutAttrs(newPrefetchOp);
    rewriter.eraseOp(prefetchOp);
    return success();
  }
};

/// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
/// region. This will simply move the barrier op outside of the warp op.
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    // The last node must be a gpu::BarrierOp.
    auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
    if (!barrierOp)
      return failure();
    // Move the barrier op outside of the warp op.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
                           barrierOp->getResultTypes(),
                           barrierOp->getOperands(), barrierOp->getAttrs());
    rewriter.eraseOp(barrierOp);
    return success();
  }
};

/// Distribute a scattered store op. The offsets argument is required.
/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
/// The layouts are fixed and implicit: one offset/mask per lane.
/// The pass changes the offset/mask vector shapes to a single-element vector;
/// **it is assumed that their producer will also be distributed**. The payload
/// vector also has a fixed distribution:
///   no chunk size -> vector of one element.
///   chunk size    -> vector of the innermost dimension of the SG-payload.
/// Example 1 (no chunk size):
///   %mask = producer_op : vector<16xi1>
///   %offset = producer_op : vector<16xindex>
///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
///     memref<256xf16>, vector<16xindex>, vector<16xi1>
/// To
///   %mask = producer_op : vector<1xi1>
///   %offset = producer_op : vector<1xindex>
///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
///     memref<256xf16>, vector<1xindex>, vector<1xi1>
/// Example 2 (chunk size, same mask and offsets):
///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
/// To
///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
///
/// Note that the store distribution pattern also handles leading unit
/// dimensions in the payload, mask and offsets vectors. In this case the store
/// distribution will only change the dimensions corresponding to the SG
/// distribution and keep the leading unit dimensions unchanged. For example, a
/// store with payload vector<1x16xf16> and lane layout [1, 16] will be
/// distributed as vector<1x1xf16>. Shape-cast ops are inserted for the
/// offset/mask/payload when necessary so that the distributed store works on a
/// 1D vector to match the HW capability.
struct StoreDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
    if (!storeScatterOp)
      return failure();
    Value offsets = storeScatterOp.getOffsets();
    if (!isa<VectorType>(offsets.getType()))
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Store op must have a vector of offsets argument");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());

    // Handle leading unit dimensions in the payload vector.
    int chunkSize = storeScatterOp.getChunkSize().value_or(1);
    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;

    // Check that all leading dimensions are unit dimensions.
    for (int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
      if (storeVecTy.getShape()[i] != 1) {
        return rewriter.notifyMatchFailure(
            storeScatterOp, "Only unit dimensions allowed for the leading "
                            "dimensions of the store vector!");
      }
    }

    auto layoutPayload = storeScatterOp.getLayoutAttr();
    auto layoutOffsets =
        xegpu::inferMaskOffsetLayoutForScatterIO(layoutPayload, chunkSize);
    auto layoutMask = layoutOffsets;

    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distStoreVecByWarpOpOrFailure) ||
        failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          storeScatterOp,
          "Some vector operands have no layouts, using defaults instead.");
    }

    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
    VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
    VectorType distMaskTy = distMaskByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
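    // Note: StoreScatterOp operands are ordered as payload value, destination,
    // offsets, mask; the yielded types below follow that order.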
    SmallVector<Value> operands = storeScatterOp->getOperands();
    SmallVector<Type> operandTypesToYield = {
        distPayloadTy, operands[1].getType(), distOffsetsTy, distMaskTy};

    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);

    rewriter.setInsertionPointAfter(newWarpOp);

    // The distributed store payload type is always 1D, without leading unit
    // dims.
    VectorType payloadTy1D = VectorType::get({distPayloadTy.getNumElements()},
                                             distPayloadTy.getElementType());

    VectorType distOffsetsTy1D = VectorType::get(
        {distOffsetsTy.getNumElements()}, distOffsetsTy.getElementType());
    VectorType distMaskTy1D = VectorType::get({distMaskTy.getNumElements()},
                                              distMaskTy.getElementType());

    // Resolve distributed types to 1D for SIMT execution.
    Value distPayloadVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), payloadTy1D, rewriter);
    Value distOffsetVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[2]), distOffsetsTy1D, rewriter);
    Value distMaskVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[3]), distMaskTy1D, rewriter);

    SmallVector<Value> newStoreScatterOpOperands = {
        distPayloadVal, newWarpOp.getResult(newRetIndices[1]), distOffsetVal,
        distMaskVal};

    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
        storeScatterOp->getAttrs());
    xegpu::removeLayoutAttrs(newOp);
    rewriter.eraseOp(storeScatterOp);
    return success();
  }
};

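/// Compute the per-lane coordinates for a matrix op by adding the
/// layout-derived distributed offsets (based on the lane id and the
/// subgroup-level payload shape) to the original offsets. Returns an empty
/// vector on failure.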
static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
    PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
    Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
  SmallVector<Value> newCoords;
  auto maybeCoords =
      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
  if (failed(maybeCoords))
    return {};
  assert(maybeCoords.value().size() == 1 &&
         "Expected one set of distributed offsets");
  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
      rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
      getAsOpFoldResult(origOffsets));
  newCoords = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
  return newCoords;
}

/// Pattern for distributing xegpu::LoadMatrixOp.
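/// The payload (result) vector is distributed according to the op's layout
/// attribute. Unless the op carries the `subgroup_block_io` attribute, the
/// offsets are rewritten to per-lane coordinates derived from the layout.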
struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
    if (!matrixOp)
      return failure();

    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
      return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
    });
    if (!producedByLastLoad)
      return rewriter.notifyMatchFailure(
          warpOp, "The last op is not xegpu::LoadMatrixOp");
    const int operandIdx = producedByLastLoad->getOperandNumber();

    VectorType sgPayloadTy =
        dyn_cast<VectorType>(matrixOp.getResult().getType());
    VectorType warpResultTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    if (!sgPayloadTy)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix op payload must be a vector type");

    auto loc = matrixOp.getLoc();
    auto offsets = matrixOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(matrixOp,
                                         "the load op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);

    auto layout = matrixOp.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix operation lacks layout attribute");

    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
    if (failed(distPayloadByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          matrixOp, "Failed to distribute matrix op payload based on layout.");

    SmallVector<Value> operands = {matrixOp.getMemDesc()};
    const unsigned offsetsStartIdx = operands.size();
    operands.append(offsetsAsValues);

    SmallVector<Type> operandTypes =
        llvm::map_to_vector(operands, [](Value v) { return v.getType(); });

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypes, newRetIndices);
    SmallVector<Value> newOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
                                         ShapedType::kDynamic);
    DenseI64ArrayAttr newConstOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(newConstOffsets);
    ValueRange currentOffsets =
        ValueRange(newOperands).drop_front(offsetsStartIdx);

    SmallVector<Value> newCoords = currentOffsets;
    rewriter.setInsertionPointAfter(newWarpOp);

    if (!matrixOp.getSubgroupBlockIoAttr()) {
      newCoords = computeDistributedCoordinatesForMatrixOp(
          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
          currentOffsets);
    }
    xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
        rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
        newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
    // Resolve the output type and replace all uses.
    rewriter.replaceAllUsesWith(
        newWarpOp.getResult(operandIdx),
        resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
    return success();
  }
};

/// Pattern for distributing xegpu::StoreMatrixOp.
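/// As with the load counterpart, the payload (data) vector is distributed
/// according to the op's layout attribute, and offsets are rewritten to
/// per-lane coordinates unless the `subgroup_block_io` attribute is present.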
struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
    if (!matrixOp)
      return failure();

    VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
    if (!sgPayloadTy)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix op payload must be a vector type");

    auto loc = matrixOp.getLoc();
    auto offsets = matrixOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(matrixOp,
                                         "the store op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);

    auto layout = matrixOp.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix operation lacks layout attribute");

    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
    if (failed(distPayloadByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          matrixOp, "Failed to distribute matrix op payload based on layout.");

    SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
    const unsigned offsetsStartIdx = operands.size();
    operands.append(offsetsAsValues);

    SmallVector<Type> operandTypes =
        llvm::map_to_vector(operands, [](Value v) { return v.getType(); });
    operandTypes[0] = *distPayloadByWarpOpOrFailure;

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypes, newRetIndices);
    SmallVector<Value> newOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
                                         ShapedType::kDynamic);
    DenseI64ArrayAttr newConstOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(newConstOffsets);
    ValueRange currentOffsets =
        ValueRange(newOperands).drop_front(offsetsStartIdx);

    SmallVector<Value> newCoords = currentOffsets;
    rewriter.setInsertionPointAfter(newWarpOp);

    if (!matrixOp.getSubgroupBlockIoAttr()) {
      newCoords = computeDistributedCoordinatesForMatrixOp(
          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
          currentOffsets);
    }

    xegpu::StoreMatrixOp::create(
        rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
        ValueRange(newCoords), newConstOffsetsAttr,
        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
    rewriter.eraseOp(matrixOp);
    return success();
  }
};

/// Distribute a scattered load op. The logic and requirements are the same as
/// for the scattered store distribution. The warp op's payload vector is
/// expected to be distributed by the load's result consumer.
/// Example 1 (no chunk size):
///   %mask = producer_op : vector<16xi1>
///   %offset = producer_op : vector<16xindex>
///   %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
/// To
///   %mask = producer_op : vector<1xi1>
///   %offset = producer_op : vector<1xindex>
///   %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
/// Example 2 (chunk size, same mask and offsets):
///   %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
/// To
///   %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
///
/// Note that the load distribution pattern also handles leading unit
/// dimensions in the payload, mask, and offsets vectors. The load distribution
/// will only change the dimensions corresponding to the SG distribution and
/// keep the leading unit dimensions unchanged. For example, a load with result
/// type vector<1x16xf16> and lane layout [1, 16] will be distributed to result
/// type vector<1x1xf16>. Shape-cast ops are inserted for the
/// offset/mask/payload when necessary so that the distributed load works on a
/// 1D vector to match the HW capability.
struct LoadDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
      // Check that the yield operand was produced by the *last* scattered
      // load op, to avoid sinking it before barriers (maintain memory order).
      return isa<xegpu::LoadGatherOp>(op) &&
             warpOp.getTerminator()->getPrevNode() == op;
    });
    if (!producedByLastLoad)
      return rewriter.notifyMatchFailure(
          warpOp, "The last op is not xegpu::LoadGatherOp");

    auto loadGatherOp =
        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
    Value offsets = loadGatherOp.getOffsets();
    if (!isa<VectorType>(offsets.getType()) ||
        !isa<VectorType>(loadGatherOp.getMask().getType()))
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Load op must have vector arguments for offsets and mask");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
    VectorType resultVecTy =
        cast<VectorType>(loadGatherOp.getResult().getType());
    // Handle leading unit dimensions in the result vector.
    int chunkSize = loadGatherOp.getChunkSize().value_or(1);
    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
    for (int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
      if (resultVecTy.getShape()[i] != 1) {
        return rewriter.notifyMatchFailure(
            loadGatherOp, "Only unit dimensions allowed for the leading "
                          "dimensions of the load vector!");
      }
    }

    auto layoutPayload = loadGatherOp.getLayoutAttr();
    auto layoutOffsets =
        xegpu::inferMaskOffsetLayoutForScatterIO(layoutPayload, chunkSize);
    auto layoutMask = layoutOffsets;

    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Some vector operands have no layouts, using defaults instead.");
    }

    SmallVector<size_t> newRetIndices;
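    // Note: LoadGatherOp operands are ordered as source, offsets, mask; the
    // yielded types below follow that order.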
    SmallVector<Value> operands = loadGatherOp->getOperands();

    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
    VectorType distResultTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
    VectorType distMaskTy = distMaskByWarpOpOrFailure.value();

    SmallVector<Type> operandTypesToYield = {operands[0].getType(),
                                             distOffsetsTy, distMaskTy};

    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);

    rewriter.setInsertionPointAfter(newWarpOp);

    // The distributed load op will always be 1D.
    VectorType loadVecTy1D = VectorType::get({distResultTy.getNumElements()},
                                             distResultTy.getElementType());

    VectorType distOffsetsTy1D =
        VectorType::get({distOffsetsByWarpOpOrFailure.value().getNumElements()},
                        distOffsetsByWarpOpOrFailure.value().getElementType());
    VectorType distMaskTy1D =
        VectorType::get({distMaskByWarpOpOrFailure.value().getNumElements()},
                        distMaskByWarpOpOrFailure.value().getElementType());

    Value distOffsetVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[1]), distOffsetsTy1D, rewriter);
    Value distMaskVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[2]), distMaskTy1D, rewriter);

    SmallVector<Value> newLoadGatherOperands = {
        newWarpOp.getResult(newRetIndices[0]), distOffsetVal, distMaskVal};

    xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
        rewriter, newWarpOp.getLoc(), loadVecTy1D, newLoadGatherOperands,
        loadGatherOp->getAttrs());
    xegpu::removeLayoutAttrs(newOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type and replace all uses.
    rewriter.replaceAllUsesWith(
        distributedVal,
        resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
    return success();
  }
};

// Sink SG-uniform ops. An op is uniform if none
// of its operands/results has a distribution layout attribute.
// Non-uniform vectors are handled by dedicated patterns.
// This pattern must have a higher priority than vector dialect distribution
// patterns, because a distributable shape may be logically intended as
// uniform (i.e., no layout), so we want to omit its distribution.
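// Illustrative example (not taken from the tests): a layout-free scalar
// computation is cloned outside the warp region, with its operands yielded
// instead of its result:
//   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
//     %0 = arith.addi %a, %b : index
//     gpu.yield %0 : index
//   }
// becomes
//   %r:2 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index, index) {
//     gpu.yield %a, %b : index, index
//   }
//   %0 = arith.addi %r#0, %r#1 : index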
struct SinkUniformOps final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Take the last op before the terminator.
    Operation *warpRegionPreYieldOp = warpOp.getTerminator()->getPrevNode();
    // Any ops with nested regions must be handled carefully in dedicated
    // patterns.
    if (!warpRegionPreYieldOp || warpRegionPreYieldOp->getNumRegions())
      return failure();
    int operandIdx = -1;
    if (warpRegionPreYieldOp->getNumResults()) {
      OpOperand *operand = getWarpResult(
          warpOp, [&](Operation *op) { return warpRegionPreYieldOp == op; });
      if (!operand)
        return failure();
      operandIdx = operand->getOperandNumber();
      if (warpRegionPreYieldOp->getResult(0).getType() !=
          warpOp.getResult(operandIdx).getType())
        return rewriter.notifyMatchFailure(warpOp,
                                           "The op result is not uniform.");
    }

    // The op must have no layout-based operands or results.
    bool uniformValuesOnly =
        llvm::all_of(warpRegionPreYieldOp->getResults(), [](Value v) {
          return !xegpu::getDistributeLayoutAttr(v);
        });
    uniformValuesOnly &=
        llvm::all_of(warpRegionPreYieldOp->getOpOperands(), [](OpOperand &opr) {
          return !xegpu::getDistributeLayoutAttr(opr);
        });
    if (!uniformValuesOnly)
      return rewriter.notifyMatchFailure(warpOp,
                                         "Some values are not uniform.");
    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands =
        llvm::to_vector_of<Value>(warpRegionPreYieldOp->getOperands());
    SmallVector<Type> operandTypes =
        llvm::to_vector_of<Type>(warpRegionPreYieldOp->getOperandTypes());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypes, newRetIndices);

    rewriter.setInsertionPointAfter(newWarpOp);
    IRMapping operandMapper;
    for (auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
      operandMapper.map(warpRegionPreYieldOp->getOperand(oldOperandIdx),
                        newWarpOp->getResult(newOperandIdx));
    Operation *clonedOp = rewriter.clone(*warpRegionPreYieldOp, operandMapper);
    if (!clonedOp->getNumResults())
      rewriter.eraseOp(warpRegionPreYieldOp);
    else {
      assert(operandIdx != -1 && "Expected a warp result for the operation");
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx),
                                  clonedOp->getResult(0));
    }
    return success();
  }
};
1263
1264/// This patterns distribute the `vector.multi_reduction` operation across
1265/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
1266/// layouts for the source and accumulator vectors,
1267/// * If the reduction dimension is distributed across lanes, the reduction is
1268/// non-lane-local and the reduction is done using warp shuffles. Here we
1269/// simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in
1270/// the warp op body.
1271/// * If the reduction dimension is not distributed across lanes, the reduction
1272/// is lane-local. In this case, we yield the source and accumulator vectors
1273/// from the warp op and perform the lane-local reduction outside the warp op
1274/// using a sequence of ReductionOps.
1275/// Example 1 (Reduction is lane-local):
1276/// ```
1277/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1278/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1279/// %acc = "some_def"() : () -> (vector<32xf32>)
1280/// %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
1281/// vector<32xf32> gpu.yield %1 : vector<32xf32>
1282/// }
1283/// ```
1284/// is lowered to:
1285/// ```
1286/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
1287/// vector<1xf32>) {
1288/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1289/// %acc = "some_def"() : () -> (vector<32xf32>)
1290/// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
1291/// }
1292/// %c = arith.constant dense<0.0> : vector<1xf32>
1293/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
1294/// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
1295/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
1296/// ```
1297/// Example 2 (Reduction is non-lane-local):
1298/// ```
1299/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1300/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1301/// %acc = "some_def"() : () -> (vector<2xf32>)
1302/// %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
1303/// vector<2xf32>
1304/// gpu.yield %1 : vector<2xf32>
1305/// }
1306/// ```
1307/// is lowered to:
1308/// ```
1309/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1310/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1311/// %acc = "some_def"() : () -> (vector<2xf32>)
1312/// %1 = arith.constant dense<0.0> : vector<2xf32>
1313/// %2 = vector.extract %0[0] : vector<32xf32> from vector<2x32xf32>
1314/// %3 = ("warp.reduction %2") : f32
1315/// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
1316/// ... repeat for row 1
1317/// gpu.yield %1 : vector<2xf32>
1318/// }
/// ```
1319struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
1320 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1321 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1322 PatternRewriter &rewriter) const override {
1323 OpOperand *yieldOperand =
1324 getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
1325 if (!yieldOperand)
1326 return failure();
1327 auto reductionOp =
1328 cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
1329 unsigned operandIdx = yieldOperand->getOperandNumber();
1330 VectorType sourceType = reductionOp.getSourceVectorType();
1331 int64_t sourceRank = sourceType.getRank();
1332 // Need at least a 2D source vector.
1333 if (sourceRank < 2)
1334 return rewriter.notifyMatchFailure(warpOp,
1335 "Only 2D+ reductions are supported.");
1336 // The leading (first rank - 2) dimensions must be unit (size 1).
1337 for (int64_t i = 0; i < sourceRank - 2; ++i) {
1338 if (sourceType.getShape()[i] != 1)
1339 return rewriter.notifyMatchFailure(
1340 warpOp, "Only unit dimensions allowed for the leading dimensions.");
1341 }
1342 // Effective dimension indices (last 2 dims of the source).
1343 int64_t rowIdx = sourceRank - 2;
1344 int64_t columnIdx = sourceRank - 1;
1345 ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
1346 if (reductionDims.size() != 1)
1347 return rewriter.notifyMatchFailure(warpOp,
1348 "Only 1 reduction dim is supported.");
1349 int64_t reductionDim = reductionDims[0];
1350 // The reduction dim must be among the last 2 dims.
1351 if (reductionDim != rowIdx && reductionDim != columnIdx)
1352 return rewriter.notifyMatchFailure(
1353 warpOp, "Reduction dim must be among the last 2 dimensions.");
1354 VectorType distributedResultType =
1355 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1356 VectorType resultType = cast<VectorType>(reductionOp.getType());
1357 xegpu::DistributeLayoutAttr sourceLayout =
1358 xegpu::getTemporaryLayout(reductionOp->getOpOperand(0));
1359
1360 FailureOr<VectorType> sourceDistTypeOrFailure =
1361 getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1362 if (failed(sourceDistTypeOrFailure))
1363 return rewriter.notifyMatchFailure(
1364 warpOp, "Failed to distribute the source vector type.");
1365 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1366 // Only single dimension distribution among the last 2 dims is supported.
1367 bool rowDistributed =
1368 sourceDistType.getShape()[rowIdx] != sourceType.getShape()[rowIdx];
1369 bool columnDistributed = sourceDistType.getShape()[columnIdx] !=
1370 sourceType.getShape()[columnIdx];
1371 if (rowDistributed && columnDistributed)
1372 return rewriter.notifyMatchFailure(
1373 warpOp, "Expecting source to be distributed in a single dimension.");
1374 int64_t sourceDistDim =
1375 rowDistributed ? rowIdx : (columnDistributed ? columnIdx : -1);
1376 if (sourceDistDim == -1)
1377 return rewriter.notifyMatchFailure(
1378 warpOp, "Expecting a distributed source vector.");
1379 bool resultDistributed =
1380 distributedResultType.getNumElements() < resultType.getNumElements();
1381 // If the lane owns all the data required for reduction (i.e. the reduction
1382 // is fully parallel across lanes), then each lane owns part of the result
1383 // (i.e. the result is distributed). If the reduction requires cross-lane
1384 // shuffling, then the result is shared among all lanes (broadcasted).
1385 // Therefore we expect the following cases:
1386 //
1387 // | Source vector | Reduction dim | Result vector |
1388 // |----------------------|----------------|----------------|
1389 // | dim-0 distributed | 0 | broadcasted |
1390 // | dim-0 distributed | 1 | distributed |
1391 // | dim-1 distributed | 0 | distributed |
1392 // | dim-1 distributed | 1 | broadcasted |
1393
1394 bool isReductionLaneLocal =
1395 (sourceDistDim == rowIdx && reductionDim == columnIdx) ||
1396 (sourceDistDim == columnIdx && reductionDim == rowIdx);
1397 if (isReductionLaneLocal && !resultDistributed)
1398 return rewriter.notifyMatchFailure(
1399 warpOp, "Expecting a distributed result for lane-local reduction.");
1400
1401 if (!isReductionLaneLocal && resultDistributed)
1402 return rewriter.notifyMatchFailure(
1403 warpOp,
1404 "Expecting a broadcasted result for non-lane-local reduction.");
1405
1406 // Handle lane-local reduction case. In this case we fully distribute the
1407 // reduction result.
1408 if (isReductionLaneLocal) {
1409 // Yield the source and acc vectors from the WarpOp.
1410 SmallVector<size_t> newRetIndices;
1411 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1412 rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
1413 {sourceDistType, distributedResultType}, newRetIndices);
1414 rewriter.setInsertionPointAfter(newWarpOp);
1415 Value result = lowerToVectorReductions(
1416 cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
1417 cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
1418 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1419 // Replace the warp op result with the final result.
1420 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
1421 return success();
1422 }
1423 // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
1424 // of multiple ReductionOps. Actual distribution is done by the
1425 // WarpOpReduction pattern.
1426 rewriter.setInsertionPointAfter(reductionOp);
1427 Value result = lowerToVectorReductions(
1428 cast<TypedValue<VectorType>>(reductionOp.getSource()),
1429 cast<TypedValue<VectorType>>(reductionOp.getAcc()),
1430 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1431 // Replace the warp op result with the final result.
1432 rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
1433 return success();
1434 }
1435};
1436
1437/// This pattern distributes the `vector.broadcast` operation across lanes in a
1438/// warp. The pattern supports three use cases:
1439///
1440/// 1) Broadcast a low-rank vector to a high-rank vector: the low-rank input
1441/// vector must have a layout that is a slice of the result layout.
1442/// If the distributed source and target vector types are identical, this
1443/// lowers to a no-op; otherwise, it remains a broadcast but operates on
1444/// distributed vectors.
1445///
1446/// 2) Broadcast a same-rank vector with identical layouts for source and
1447/// target: the source vector must have unit dimensions, and lane_data must
1448/// be unit size for those unit dims.
1449/// This always lowers to a no-op.
1450///
1451/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from
1452/// scalar to distributed result type.
1453///
1454/// Example 1 (lowering to a broadcast with distributed types):
1455/// ```
1456/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1457/// %0 = "some_def"() {layout_result_0 =
1458/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1459/// dims = [0]> } : () -> (vector<32xf32>)
1460/// %2 = vector.broadcast %0 {layout_result_0 =
1461/// #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
1462/// : vector<32xf32> to vector<8x32xf32>
1463/// gpu.yield %2 : vector<8x32xf32>
1464/// }
1465/// ```
1466/// is lowered to:
1467/// ```
1468/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1469/// %0 = "some_def"() {layout_result_0 =
1470/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1471/// dims = [0]> } : () -> (vector<32xf32>)
1472/// gpu.yield %0 : vector<32xf32>
1473/// }
1474/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
1475/// ```
1476/// Example 2 (no-op):
1477/// ```
1478/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) {
1479/// %0 = "some_def"() {layout_result_0 =
1480/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1481/// dims = [1]> } : () -> (vector<8xf32>)
1482/// %1 = vector.shape_cast %0
1483/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1484/// 1]>}: vector<8xf32> to vector<8x1xf32>
1485/// %2 = vector.broadcast %1
1486/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1487/// 1]>}: vector<8x1xf32> to vector<8x32xf32>
1488/// gpu.yield %2 : vector<8x32xf32>
1489/// }
1490/// ```
1491/// is lowered to:
1492/// ```
1493/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1494/// %0 = "some_def"() {layout_result_0 =
1495/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1496/// dims = [1]> } : () -> (vector<8xf32>)
1497/// %1 = vector.shape_cast %0
1498/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1499/// 1]>}: vector<8xf32> to vector<8x1xf32>
1500/// gpu.yield %1 : vector<8x1xf32>
1501/// }
1502/// // The broadcast is implicit through layout transformation (no-op)
1503/// "some_use"(%r#0)
1504/// ```
1505struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
1506 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1507 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1508 PatternRewriter &rewriter) const override {
1509 OpOperand *yieldOperand =
1510 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1511 if (!yieldOperand)
1512 return failure();
1513 auto broadcastOp =
1514 cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
1515 unsigned operandIdx = yieldOperand->getOperandNumber();
1516
1517 VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1518 VectorType destType =
1519 dyn_cast<VectorType>(broadcastOp.getResult().getType());
1520
1521 xegpu::DistributeLayoutAttr sourceLayout =
1522 xegpu::getTemporaryLayout(broadcastOp->getOpOperand(0));
1523 xegpu::DistributeLayoutAttr resultLayout =
1524 xegpu::getTemporaryLayout(dyn_cast<OpResult>(broadcastOp.getResult()));
1525
1526 FailureOr<VectorType> sourceDistType;
1527 Type sourceElemOrDistType;
1528 if (sourceType) {
1529
1530 // Case 1 and 2: source is a vector type.
1531 int64_t rankDiff = destType.getRank() - sourceType.getRank();
1532 if (rankDiff > 0) {
1533 // Case 1: source is lower-rank than result.
1534 bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
1535 if (!isSliceOf)
1536 broadcastOp.emitWarning()
1537 << "Broadcast input layout must be a slice of result layout.";
1538 }
1539 // Case 2: source and result have the same rank.
1540 if (rankDiff == 0) {
1541 auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
1542 SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
1543 broadcastUnitDimsSet.end());
1544 assert(sourceLayout.isEqualTo(
1545 sourceLayout.setUnitDimData(broadcastUnitDims)) &&
1546 "The sg_data for unit dimensions should be set as 1");
1547 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1548 }
1549
1550 sourceDistType =
1551 getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1552 if (failed(sourceDistType)) {
1553 return rewriter.notifyMatchFailure(
1554 warpOp, "Failed to distribute the source vector type.");
1555 }
1556 sourceElemOrDistType = sourceDistType.value();
1557
1558 } else {
1559 // Case 3: source is a scalar type.
1560 if (sourceLayout) {
1561 return rewriter.notifyMatchFailure(
1562 warpOp, "Broadcast from scalar must not have a layout attribute.");
1563 }
1564 sourceElemOrDistType = broadcastOp.getSourceType();
1565 }
1566 FailureOr<VectorType> destDistType =
1567 getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
1568 if (failed(destDistType)) {
1569 return rewriter.notifyMatchFailure(
1570 warpOp, "Failed to distribute the dest vector type.");
1571 }
1572
1573 SmallVector<size_t> newRetIndices;
1574 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1575 rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
1576 newRetIndices);
1577
1578 Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
1579
1580 Value newBroadcast = distributedSource;
1581
1582 if (sourceElemOrDistType != destDistType.value()) {
1583 rewriter.setInsertionPointAfter(newWarpOp);
1584 newBroadcast =
1585 vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
1586 destDistType.value(), distributedSource);
1587 }
1588
1589 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
1590 return success();
1591 }
1592};
1593
1594/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
1595/// `gpu.warp_execute_on_lane_0` region.
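/// For example (a hypothetical case, assuming lane_layout = [16, 1] for the
/// source and an effective lane_layout = [16] for the result):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
///   %0 = "some_def"() : () -> (vector<16x1xf32>)
///   %1 = vector.shape_cast %0 : vector<16x1xf32> to vector<16xf32>
///   gpu.yield %1 : vector<16xf32>
/// }
/// ```
/// is lowered to:
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
///   %0 = "some_def"() : () -> (vector<16x1xf32>)
///   gpu.yield %0 : vector<16x1xf32>
/// }
/// %1 = vector.shape_cast %r#0 : vector<1x1xf32> to vector<1xf32>
/// ```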
1596struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
1597 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1598 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1599 PatternRewriter &rewriter) const override {
1600 OpOperand *yieldOperand =
1601 getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1602 if (!yieldOperand)
1603 return failure();
1604 auto shapeCastOp =
1605 cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
1606 unsigned operandNumber = yieldOperand->getOperandNumber();
1607 auto resultDistTy =
1608 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1609 xegpu::DistributeLayoutAttr sourceLayout =
1610 xegpu::getTemporaryLayout(shapeCastOp->getOpOperand(0));
1611 xegpu::DistributeLayoutAttr resultLayout =
1612 xegpu::getTemporaryLayout(dyn_cast<OpResult>(shapeCastOp.getResult()));
1613 if (!sourceLayout || !resultLayout)
1614 return rewriter.notifyMatchFailure(
1615 warpOp,
1616 "the source or result of shape_cast op lacks distribution layout");
1617
1618 FailureOr<VectorType> sourceDistTypeOrFailure =
1619 getDistVecTypeBasedOnLaneLayout(sourceLayout,
1620 shapeCastOp.getSourceVectorType());
1621 if (failed(sourceDistTypeOrFailure))
1622 return rewriter.notifyMatchFailure(
1623 warpOp, "failed to get distributed vector type for source");
1624 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1625 // Create a new warp op that yields the source of the shape_cast op.
1626 SmallVector<size_t> newRetIndices;
1627 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1628 rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
1629 newRetIndices);
1630 rewriter.setInsertionPointAfter(newWarpOp);
1631 Value source = newWarpOp.getResult(newRetIndices[0]);
1632 // Create a new shape_cast op outside the warp op.
1633 Value newShapeCast = vector::ShapeCastOp::create(
1634 rewriter, shapeCastOp.getLoc(), resultDistTy, source);
1635 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1636 newShapeCast);
1637 return success();
1638 }
1639};
1640
1641// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
1642// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1643// advanced cases where the distributed dimension is partially extracted and
1644// currently not supported by the generic vector distribution patterns.
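// For example (a hypothetical case, assuming lane_layout = [16] and
// lane_data = [1], i.e. a subgroup size of 16), extracting elements [16, 32)
// of a vector<32xf32>:
// ```
// %1 = vector.extract_strided_slice %0 {offsets = [16], sizes = [16],
//   strides = [1]} : vector<32xf32> to vector<16xf32>
// ```
// is recreated outside the warp op on the round-robin-distributed types,
// with the offset rescaled by the subgroup size (16 / 16 = 1):
// ```
// %1 = vector.extract_strided_slice %r#0 {offsets = [1], sizes = [1],
//   strides = [1]} : vector<2xf32> to vector<1xf32>
// ```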
1645struct VectorExtractStridedSliceDistribution
1646 : public gpu::WarpDistributionPattern {
1647 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1648 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1649 PatternRewriter &rewriter) const override {
1650 OpOperand *operand =
1651 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1652 if (!operand)
1653 return failure();
1654 auto extractOp =
1655 cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
1656 unsigned operandIdx = operand->getOperandNumber();
1657 auto distributedType =
1658 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1659 // Find the distributed dimensions.
1660 auto extractResultType = cast<VectorType>(operand->get().getType());
1661 auto distributedDims =
1662 getDistributedDims(extractResultType, distributedType);
1663 // Collect updated source type, sizes and offsets. They may be adjusted
1664 // later if the data is distributed to lanes (as opposed to being owned by
1665 // all lanes uniformly).
1666 VectorType updatedSourceType = extractOp.getSourceVectorType();
1667 SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
1668 extractOp.getSizes(), [](Attribute attr) { return attr; });
1669 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1670 extractOp.getOffsets(), [](Attribute attr) { return attr; });
1671 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1672 extractOp.getStrides(), [](Attribute attr) { return attr; });
1673 // If the provided sizes, offsets, and strides cover fewer dimensions than
1674 // the source rank, pad them with full sizes, zero offsets, and unit
1675 // strides. This makes it easier to adjust them later.
1676 int64_t sourceRank = extractOp.getSourceVectorType().getRank();
1677 for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
1678 updatedSizes.push_back(rewriter.getI64IntegerAttr(
1679 extractOp.getSourceVectorType().getDimSize(i)));
1680 updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
1681 updatedStrides.push_back(
1682 rewriter.getI64IntegerAttr(1)); // stride is always 1.
1683 }
1684 // If the result is distributed, it must be distributed in exactly one
1685 // dimension. In this case, we adjust the updated source type, sizes
1686 // and offsets accordingly.
1687 if (distributedDims.size() > 0) {
1688 if (distributedDims.size() != 1)
1689 return rewriter.notifyMatchFailure(
1690 warpOp, "Source can not be distributed in multiple dimensions.");
1691 int64_t distributedDim = distributedDims[0];
1692 int sourceDistrDimSize =
1693 extractOp.getSourceVectorType().getShape()[distributedDim];
1694 auto sourceLayout = xegpu::getTemporaryLayout(extractOp->getOpOperand(0));
1695 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1696 return rewriter.notifyMatchFailure(
1697 warpOp, "the source of extract_strided_slice op lacks distribution "
1698 "layout");
1699 auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
1700 // Because only single dimension distribution is supported, lane layout
1701 // size at the distributed dim must be the subgroup size.
1702 int subgroupSize = sourceLaneLayout[distributedDim];
1703 // Check if the source size in the distributed dimension is a multiple of
1704 // subgroup size.
1705 if (sourceDistrDimSize % subgroupSize != 0)
1706 return rewriter.notifyMatchFailure(
1707 warpOp,
1708 "Source size along distributed dimension is not a multiple of "
1709 "subgroup size.");
1710 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1711 // We expect lane data to be all ones in this case.
1712 if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1713 return rewriter.notifyMatchFailure(
1714 warpOp, "Expecting unit lane data in source layout");
1715 // The offset in the distributed dimension must be a multiple of the
1716 // subgroup size.
1717 int64_t distrDimOffset =
1718 cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
1719 if (distrDimOffset % subgroupSize != 0)
1720 return rewriter.notifyMatchFailure(
1721 warpOp, "Offset along distributed dimension "
1722 "is not a multiple of subgroup size.");
1723 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1724 sourceLayout, extractOp.getSourceVectorType())
1725 .value();
1726 // Update the distributed sizes to match the distributed type.
1727 updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1728 distributedType.getDimSize(distributedDim));
1729 // Update the distributed offsets to match round robin distribution (i.e.
1730 // each lane owns data at `subgroupSize` stride given unit lane data).
1731 updatedOffsets[distributedDim] =
1732 rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
1733 }
1734 // Do the distribution by yielding the source of the extract op from
1735 // the warp op and creating a new extract op outside the warp op.
1736 SmallVector<size_t> newRetIndices;
1737 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1738 rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
1739 newRetIndices);
1740 rewriter.setInsertionPointAfter(newWarpOp);
1741 Value source = newWarpOp.getResult(newRetIndices[0]);
1742 // Create a new extract op outside the warp op.
1743 Value newExtractOp = vector::ExtractStridedSliceOp::create(
1744 rewriter, extractOp.getLoc(), distributedType, source,
1745 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1746 ArrayAttr::get(rewriter.getContext(), updatedSizes),
1747 ArrayAttr::get(rewriter.getContext(), updatedStrides));
1748 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
1749 return success();
1750 }
1751};
1752
1753/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
1754/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1755/// advanced cases where the distributed dimension is partially inserted and
1756/// currently not supported by the generic vector distribution patterns.
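/// For example (a hypothetical case, assuming lane_layout = [16] and
/// lane_data = [1] for both source and dest), inserting a vector<16xf32>
/// into a vector<32xf32> at offset 16:
/// ```
/// %1 = vector.insert_strided_slice %src, %dst {offsets = [16],
///   strides = [1]} : vector<16xf32> into vector<32xf32>
/// ```
/// is recreated outside the warp op on the distributed types, with the
/// offset rescaled by the subgroup size (16 / 16 = 1):
/// ```
/// %1 = vector.insert_strided_slice %src, %dst {offsets = [1],
///   strides = [1]} : vector<1xf32> into vector<2xf32>
/// ```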
1757struct VectorInsertStridedSliceDistribution
1758 : public gpu::WarpDistributionPattern {
1759 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1760 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1761 PatternRewriter &rewriter) const override {
1762 OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
1763 // Check if the InsertStridedSliceOp is the last op before yield op
1764 return llvm::IsaPred<vector::InsertStridedSliceOp>(op) &&
1765 warpOp.getTerminator()->getPrevNode() == op;
1766 });
1767 if (!operand)
1768 return failure();
1769 unsigned int operandNumber = operand->getOperandNumber();
1770 auto insertOp =
1771 operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
1772 auto distributedType =
1773 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1774 // Find the distributed dimensions of the dest vector.
1775 auto insertResultType = cast<VectorType>(operand->get().getType());
1776 auto destDistributedDims =
1777 getDistributedDims(insertResultType, distributedType);
1778 // Collect updated offsets, source type and dest type. They may be adjusted
1779 // later if the data is distributed to lanes (as opposed to being owned by
1780 // all lanes uniformly).
1781 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1782 insertOp.getOffsets(), [](Attribute attr) { return attr; });
1783 VectorType updatedSourceType = insertOp.getSourceVectorType();
1784 VectorType updatedDestType = insertOp.getDestVectorType();
1785 if (destDistributedDims.size() > 0) {
1786 // Only single dimension distribution is supported.
1787 if (destDistributedDims.size() != 1)
1788 return rewriter.notifyMatchFailure(
1789 warpOp,
1790 "Expecting dest to be distributed in a single dimension.");
1791 int64_t destDistributedDim = destDistributedDims[0];
1792
1793 VectorType srcType = insertOp.getSourceVectorType();
1794 VectorType destType = insertOp.getDestVectorType();
1795 // Currently we require that both source (kD) and dest (nD) vectors are
1796 // distributed. This requires that distributedDim (d) is contained in the
1797 // last k dims of the dest vector (d >= n - k).
1798 int64_t sourceDistributedDim =
1799 destDistributedDim - (destType.getRank() - srcType.getRank());
1800 if (sourceDistributedDim < 0)
1801 return rewriter.notifyMatchFailure(
1802 insertOp,
1803 "distributed dimension must be in the last k (i.e. source "
1804 "rank) dims of dest vector");
1805 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
1806 // Obtain the source and dest layouts.
1807 auto destLayout = xegpu::getTemporaryLayout(insertOp->getOpOperand(1));
1808 auto sourceLayout = xegpu::getTemporaryLayout(insertOp->getOpOperand(0));
1809 if (!destLayout || !sourceLayout ||
1810 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1811 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1812 return rewriter.notifyMatchFailure(
1813 warpOp, "the source or dest of insert_strided_slice op lacks "
1814 "distribution layout");
1815 // Because only single dimension distribution is supported, lane layout
1816 // size at the distributed dim must be the subgroup size.
1817 int subgroupSize =
1818 destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
1819 // We require that source and dest lane data are all ones to ensure
1820 // uniform round robin distribution.
1821 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1822 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1823 if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
1824 !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1825 return rewriter.notifyMatchFailure(
1826 warpOp, "Expecting unit lane data in source and dest layouts");
1827 // Source distributed dim size must be a multiple of the subgroup size.
1828 if (srcDistrDimSize % subgroupSize != 0)
1829 return rewriter.notifyMatchFailure(
1830 warpOp, "Distributed dimension size in source is not a multiple of "
1831 "subgroup size.");
1832 // Offsets in the distributed dimension must be multiples of subgroup
1833 // size.
1834 int64_t destDistrDimOffset =
1835 cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
1836 if (destDistrDimOffset % subgroupSize != 0)
1837 return rewriter.notifyMatchFailure(
1838 warpOp,
1839 "Offset along distributed dimension in dest is not a multiple of "
1840 "subgroup size.");
1841 // Update the source and dest types based on their layouts.
1842 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1843 sourceLayout, insertOp.getSourceVectorType())
1844 .value();
1845 updatedDestType = getDistVecTypeBasedOnLaneLayout(
1846 destLayout, insertOp.getDestVectorType())
1847 .value();
1848 // Update the distributed offsets to match round robin distribution (i.e.
1849 // each lane owns data at `subgroupSize` stride given unit lane data).
1850 updatedOffsets[destDistributedDim] =
1851 rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1852 }
1853 // Do the distribution by yielding the source and dest of the insert op
1854 // from the warp op and creating a new insert op outside the warp op.
1855 SmallVector<size_t> newRetIndices;
1856 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1857 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1858 {updatedSourceType, updatedDestType}, newRetIndices);
1859 rewriter.setInsertionPointAfter(newWarpOp);
1860
1861 Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
1862 Value dest = newWarpOp.getResult(newRetIndices[1]);
1863 // Create a new insert op outside the warp op.
1864 Value newInsertOp = vector::InsertStridedSliceOp::create(
1865 rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
1866 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1867 insertOp.getStrides());
1868 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1869 newInsertOp);
1870 return success();
1871 }
1872};
1873
1874/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
1875/// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
1876/// outside of the warp op.
1877struct MemrefExtractAlignedPointerAsIndexDistribution final
1878 : public gpu::WarpDistributionPattern {
1879 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1880 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1881 PatternRewriter &rewriter) const override {
1882 OpOperand *operand = getWarpResult(
1883 warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
1884 if (!operand)
1885 return rewriter.notifyMatchFailure(
1886 warpOp,
1887 "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
1888 auto extractOp =
1889 operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
1890 unsigned operandIdx = operand->getOperandNumber();
1891 SmallVector<size_t> newRetIndices;
1892 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1893 rewriter, warpOp, extractOp.getSource(),
1894 TypeRange{extractOp.getSource().getType()}, newRetIndices);
1895 rewriter.setInsertionPointAfter(newWarpOp);
1896 auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
1897 rewriter, newWarpOp.getLoc(), extractOp.getType(),
1898 newWarpOp.getResult(newRetIndices[0]));
1899 Value resultVal = newWarpOp.getResult(operandIdx);
1900 rewriter.replaceAllUsesWith(resultVal, newExtractOp.getResult());
1901 return success();
1902 }
1903};
1904
1905/// Distribute a vector::BitCastOp feeding into yield op of an enclosing
1906/// `gpu.warp_execute_on_lane_0` region. Bitcast only impacts the innermost
1907/// dimension of the source/result vectors. An equivalent vector::BitCastOp is
1908/// created outside of the warp op with distributed source vector type (computed
1909/// using assigned layout).
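/// For example (a hypothetical case, assuming lane_layout = [1, 16] and unit
/// lane_data for both source and result):
/// ```
/// %1 = vector.bitcast %0 : vector<8x16xf32> to vector<8x32xf16>
/// ```
/// is recreated outside the warp op on the distributed types as:
/// ```
/// %1 = vector.bitcast %r#0 : vector<8x1xf32> to vector<8x2xf16>
/// ```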
1910struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
1911 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1912 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1913 PatternRewriter &rewriter) const override {
1914 OpOperand *operand =
1915 getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
1916 if (!operand)
1917 return rewriter.notifyMatchFailure(
1918 warpOp, "warp result is not a vector::BitCast op");
1919 auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
1920 unsigned operandIdx = operand->getOperandNumber();
1921 VectorType distributedSourceType =
1922 getDistVecTypeBasedOnLaneLayout(
1923 xegpu::getTemporaryLayout(bitcastOp->getOpOperand(0)),
1924 bitcastOp.getSourceVectorType())
1925 .value_or(VectorType());
1926 if (!distributedSourceType)
1927 return rewriter.notifyMatchFailure(
1928 bitcastOp, "Failed to distribute the source vector type in "
1929 "vector::BitCast op");
1930 VectorType distributedResultType =
1931 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1932 SmallVector<size_t> newRetIndices;
1933 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1934 rewriter, warpOp, bitcastOp.getSource(),
1935 TypeRange{distributedSourceType}, newRetIndices);
1936 rewriter.setInsertionPointAfter(newWarpOp);
1937 auto newBitcastOp = vector::BitCastOp::create(
1938 rewriter, newWarpOp.getLoc(), distributedResultType,
1939 newWarpOp.getResult(newRetIndices[0]));
1940 Value distributedVal = newWarpOp.getResult(operandIdx);
1941 rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
1942 return success();
1943 }
1944};
1945
1946/// Distribute a vector::TransposeOp feeding into yield op of an enclosing
1947/// `gpu.warp_execute_on_lane_0` region. Currently only 2D transposes are
1948/// supported. In most cases, transpose is a no-op because it is entirely
1949/// handled using the layouts (e.g. 16x1 -> 1x16). However, if each lane owns
1950/// multiple slices of data after distribution (e.g. 16x2 -> 2x16), a lane-local
1951/// transpose (i.e. shuffle) is needed. Therefore, we create an equivalent
1952/// vector::TransposeOp outside of the warp op with distributed source vector
1953/// type (computed using assigned layout).
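/// For example (a hypothetical case, assuming lane_layout = [16, 1] with
/// lane_data = [1, 2] for the source, and lane_layout = [1, 16] with
/// lane_data = [2, 1] for the result):
/// ```
/// %1 = vector.transpose %0, [1, 0] : vector<16x2xf32> to vector<2x16xf32>
/// ```
/// is recreated outside the warp op as a lane-local transpose on the
/// distributed types:
/// ```
/// %1 = vector.transpose %r#0, [1, 0] : vector<1x2xf32> to vector<2x1xf32>
/// ```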
1954struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
1955 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1956 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1957 PatternRewriter &rewriter) const override {
1958 OpOperand *operand =
1959 getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
1960 if (!operand)
1961 return rewriter.notifyMatchFailure(
1962 warpOp, "warp result is not a vector::Transpose op");
1963 auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
1964 unsigned operandIdx = operand->getOperandNumber();
1965 xegpu::DistributeLayoutAttr sourceLayout =
1966 xegpu::getTemporaryLayout(transposeOp->getOpOperand(0));
1967 xegpu::DistributeLayoutAttr resultLayout =
1968 xegpu::getTemporaryLayout(transposeOp->getOpResult(0));
1969 if (!sourceLayout || !resultLayout)
1970 return rewriter.notifyMatchFailure(
1971 transposeOp,
1972 "the source or result vector of the transpose op lacks layout "
1973 "attribute");
1974 int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
1975 int64_t resultRank = transposeOp.getResultVectorType().getRank();
1976 // Only 2D transposes are supported for now.
1977 // TODO: Support nD transposes.
1978 if (sourceRank != 2 || resultRank != 2)
1979 return rewriter.notifyMatchFailure(
1980 transposeOp, "the source or result vector of the transpose op "
1981 "does not have 2D layout");
1982 ArrayRef<int64_t> perm = transposeOp.getPermutation();
1983 // Result layout must be a transpose of source layout.
1984 if (!resultLayout.isTransposeOf(sourceLayout, perm,
1986 return rewriter.notifyMatchFailure(
1987 transposeOp,
1988 "the source or result vector layouts must be 2D transposes of each "
1989 "other");
1990 FailureOr<VectorType> distributedSourceTypeOrFailure =
1991 getDistVecTypeBasedOnLaneLayout(sourceLayout,
1992 transposeOp.getSourceVectorType());
1993 if (failed(distributedSourceTypeOrFailure))
1994 return rewriter.notifyMatchFailure(
1995 transposeOp, "Failed to distribute the source vector type in "
1996 "vector::Transpose op");
1997 SmallVector<size_t> newRetIndices;
1998 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1999 rewriter, warpOp, transposeOp.getVector(),
2000 TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
2001 rewriter.setInsertionPointAfter(newWarpOp);
2002 auto newTransposeOp = vector::TransposeOp::create(
2003 rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
2004 perm);
2005 Value distributedVal = newWarpOp.getResult(operandIdx);
2006 rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
2007 return success();
2008 }
2009};
2010
2011/// Distribute a vector::StepOp with the sliced result layout.
2012/// The sliced layout must have exactly 1 effective lane dimension.
2013/// We completely resolve the vector::StepOp by computing the lane_data-sized
2014/// subranges.
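/// For example (a hypothetical case, assuming a slice layout with effective
/// lane_layout = [16] and lane_data = [1] over 16 lanes):
/// ```
/// %s = vector.step : vector<16xindex>
/// ```
/// distributes so that lane `l` owns exactly element `l`; the per-lane result
/// is reconstructed outside the warp op, conceptually as:
/// ```
/// %v = vector.from_elements %laneid : vector<1xindex>
/// ```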
2015struct VectorStepSliceDistribution final : public gpu::WarpDistributionPattern {
2016 using gpu::WarpDistributionPattern::WarpDistributionPattern;
2017 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
2018 PatternRewriter &rewriter) const override {
2019 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
2020 if (!operand)
2021 return rewriter.notifyMatchFailure(
2022 warpOp, "warp result is not a vector::StepOp op");
2023 auto stepOp = operand->get().getDefiningOp<vector::StepOp>();
2024 unsigned operandIdx = operand->getOperandNumber();
2025 xegpu::DistributeLayoutAttr resultLayout =
2026 xegpu::getTemporaryLayout(stepOp->getResult(0));
2027 if (!resultLayout)
2028 return rewriter.notifyMatchFailure(
2029 stepOp, "the result vector of the step op lacks layout "
2030 "attribute");
2031 auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resultLayout);
2032 if (!sliceLayout)
2033 return rewriter.notifyMatchFailure(
2034 stepOp, "the result layout must be a slice layout");
2035 if (sliceLayout.getEffectiveLaneLayoutAsInt().size() != 1)
2036 return rewriter.notifyMatchFailure(
2037 stepOp, "expecting 1 dim in the effective result layout");
2038
2039 rewriter.setInsertionPointAfter(warpOp);
2040 auto loc = stepOp.getLoc();
2041 auto stepResultVecTy = stepOp.getResult().getType();
2042 Value distributedVal = warpOp.getResult(operandIdx);
2043 VectorType newVecTy = cast<VectorType>(distributedVal.getType());
2044
2045 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
2046 rewriter, loc, warpOp.getLaneid(), stepResultVecTy.getShape());
2047 if (failed(laneDataBlockCoords))
2048 return rewriter.notifyMatchFailure(
2049 stepOp, "failed to compute lane data block coordinates");
2050
2051 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
2052 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
2053 assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
2054 newVecTy.getNumElements() / laneDataBlockLength);
2055 SmallVector<Value> stepVals;
2056 // For each lane_data block, reconstruct its sub-range
2057 // from the range of SG-level vector.step. Example: vector.step
2058 // {slice<layout<lane_layout=[2,4,2], lane_data=[1,2,1]>, dims=[0,2]>} :
2059 // vector<16xindex>
2060 // Each logical lane holds 4 elements as 2 blocks of 2 elements each.
2061 // The blocks are round-robin distributed, so logical lane id 0
2062 // holds values [0,1, 8,9].
2063 for (auto &blockCoords : laneDataBlockCoordsVec) {
2064 auto laneDataBlockStartCoord = blockCoords[0];
2065 stepVals.push_back(laneDataBlockStartCoord);
2066 for (int i = 1; i < laneDataBlockLength; ++i) {
2067 auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
2068 stepVals.push_back(arith::AddIOp::create(
2069 rewriter, loc, laneDataBlockStartCoord, offset));
2070 }
2071 }
2072 assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
2073 "Expecting the number of step values to match the number of "
2074 "elements in the vector");
2075 auto stepOpVal =
2076 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
2077 rewriter.replaceAllUsesWith(distributedVal, stepOpVal);
2078 return success();
2079 }
2080};
2081
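/// Lower an `xegpu.convert_layout` op whose input and target layouts are
/// compatible for the value shape. At this level such a conversion is a
/// metadata-only change, so the op is folded to its source. A hypothetical
/// example (syntax sketched):
/// ```
/// %1 = xegpu.convert_layout %0 <{input_layout = #a, target_layout = #b}>
///   : vector<8x16xf32>
/// ```
/// folds to `%0` when `#a` and `#b` are compatible for the 8x16 shape.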
2082struct ConvertLayoutDistribution
2083 : public OpRewritePattern<xegpu::ConvertLayoutOp> {
2084 using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
2085
2086 LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
2087 PatternRewriter &rewriter) const override {
2088 auto inputLayout = op.getInputLayoutAttr();
2089 auto targetLayout = op.getTargetLayoutAttr();
2090 Type valType = op.getResult().getType();
2091
2092 if (!inputLayout || !targetLayout)
2093 return rewriter.notifyMatchFailure(op, "missing layout attributes");
2094
2095 if (valType.isIntOrFloat()) {
2096 rewriter.replaceOp(op, op.getSource());
2097 return success();
2098 }
2099 auto resShape = cast<VectorType>(valType).getShape();
2100 SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
2101 if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
2103 return rewriter.notifyMatchFailure(
2104 op, "lowering incompatible convert_layout not yet supported");
2105 }
2106 rewriter.replaceOp(op, op.getSource());
2107 return success();
2108 }
2109};
2110
2111} // namespace
2112
2113namespace {
2114struct XeGPUSubgroupDistributePass final
2115 : public xegpu::impl::XeGPUSubgroupDistributeBase<
2116 XeGPUSubgroupDistributePass> {
2117 void runOnOperation() override;
2118};
2119} // namespace
2120
2121void xegpu::populateXeGPUSubgroupDistributePatterns(
2122 RewritePatternSet &patterns) {
2123 patterns.add<CreateNdDescDistribution, StoreNdDistribution,
2124 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
2125 GpuBarrierDistribution, VectorMultiReductionDistribution,
2126 LoadDistribution, StoreDistribution, VectorTransposeDistribution,
2127 VectorBitcastDistribution, LoadMatrixDistribution,
2128 StoreMatrixDistribution, ConvertLayoutDistribution,
2129 MemrefExtractAlignedPointerAsIndexDistribution>(
2130 patterns.getContext(),
2131 /*pattern benefit=*/PatternHierarchy::Regular);
2132 // For the following patterns, we need to override the regular vector
2133 // distribution patterns. Therefore, assign a higher benefit.
2134 patterns
2135 .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
2136 VectorInsertStridedSliceDistribution, VectorBroadcastDistribution,
2137 VectorStepSliceDistribution, SinkUniformOps>(
2138 patterns.getContext(),
2139 /*pattern benefit=*/PatternHierarchy::AboveRegular);
2140}
2141
2142void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
2143 RewritePatternSet &patterns) {
2144 patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
2145}
2146
2147void XeGPUSubgroupDistributePass::runOnOperation() {
2148 // Step 1: Attach layouts to op operands.
2149 // TODO: The following assumptions are made:
2150 // 1) It is assumed that there are no layout conflicts.
2151 // 2) Any existing layout attributes attached to the operands are ignored.
2152 Operation *op = getOperation();
2153 if (!xegpu::recoverTemporaryLayouts(op)) {
2154 signalPassFailure();
2155 return;
2156 }
2157
2158 // Step 2: Move all operations of a GPU function inside
2159 // gpu.warp_execute_on_lane_0 operation.
2160 {
2161 RewritePatternSet patterns(&getContext());
2162 xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
2163
2164 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
2165 signalPassFailure();
2166 return;
2167 }
2168 // At this point, we have moved the entire function body inside the
2169 // warpOp. Now move any scalar uniform code outside of the warpOp (like
2170 // GPU index ops, scalar constants, etc.). This will simplify the
2171 // later lowering and avoid custom patterns for these ops.
2172 getOperation()->walk([&](Operation *op) {
2173 if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
2174 vector::moveScalarUniformCode(warpOp);
2175 });
2176 }
2177 // Step 3: Apply subgroup to workitem distribution patterns.
2178 RewritePatternSet patterns(&getContext());
2180 // distributionFn is used by vector distribution patterns to determine the
2181 // distributed vector type for a given vector value. In XeGPU subgroup
2182 // distribution context, we compute this based on lane layout.
2183 auto distributionFn = [](Value val) {
2184 VectorType vecType = dyn_cast<VectorType>(val.getType());
2185 int64_t vecRank = vecType ? vecType.getRank() : 0;
2186 if (vecRank == 0)
2187 return AffineMap::get(val.getContext());
2188 // Get the layout of the vector type.
2189 xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
2190 // If no layout is specified, assume uniform case (no distribution).
2191 if (!layout)
2192 return AffineMap::get(val.getContext());
2193 // Expecting vector and layout rank to match.
2194 assert(layout.getRank() == vecRank &&
2195 "Expecting vector and layout rank to match");
2196 // A dimension is distributed only if layout suggests there are
2197 // multiple lanes assigned for this dimension and the shape can be evenly
2198 // distributed to those lanes.
2199 SmallVector<unsigned int> distributedDims;
2200 for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
2201 if (v > 1 && vecType.getShape()[i] % v == 0)
2202 distributedDims.push_back(i);
2203 }
2204 return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
2205 val.getContext());
2206 };
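// For example, the distributionFn above maps a vector<8x16xf32> value with
// lane_layout = [1, 16] to the map (d0, d1) -> (d1): only the innermost
// dimension is distributed, since its size (16) is evenly divided by the 16
// lanes, while the outer dimension is assigned to a single lane.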
2207 // TODO: shuffleFn is not used.
2208 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
2209 int64_t warpSz) { return Value(); };
2210
2211 vector::populateDistributeReduction(
2212 patterns, xegpu::subgroupReduction,
2213 /*pattern benefit=*/PatternHierarchy::Regular);
2214
2215 vector::populatePropagateWarpVectorDistributionPatterns(
2216 patterns, distributionFn, shuffleFn,
2217 /*pattern benefit=*/PatternHierarchy::Regular);
2218 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
2219 signalPassFailure();
2220 return;
2221 }
2222
2223 // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
2224 // due to tensor desc type mismatches created by using upstream distribution
2225 // patterns (scf.for). This cleanup should only be done if all the ops were
2226 // distributed successfully; if some ops are still not distributed and remain
2227 // inside any WarpExecuteOnLane0Op, we skip this simplification step to
2228 // avoid breaking the IR.
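// A hypothetical example of what this cleanup resolves: inside an scf.for,
// a loop-carried tensor_desc block argument keeps its layout while the SIMT
// body expects the layout-free type, so distribution inserted
//   %0 = builtin.unrealized_conversion_cast %arg {resolve_simt_type_mismatch}
//     : !xegpu.tensor_desc<8x16xf32, #layout> to !xegpu.tensor_desc<8x16xf32>
// The walk below retypes the block argument (and the tied loop result) to
// the layout-free type and forwards all uses, leaving the cast dead.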
2229 bool foundWarpOp = false;
2230 getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
2231 // Look for WarpOps that are not trivially dead.
2232 if (isOpTriviallyDead(warpOp))
2233 return WalkResult::advance();
2234 foundWarpOp = true;
2235 return WalkResult::interrupt();
2236 });
2237 if (foundWarpOp)
2238 return;
2239
2240 getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
2241 // We are only interested in UnrealizedConversionCastOps that were added
2242 // for resolving SIMT type mismatches.
2243 if (!op->getAttr(resolveSIMTTypeMismatch))
2244 return WalkResult::skip();
2245
2246 Value input = op.getOperand(0);
2247 Value output = op.getResult(0);
2248
2249 // Both input and output must have tensor descriptor types.
2250 xegpu::TensorDescType inputDescType =
2251 mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
2252 xegpu::TensorDescType outputDescType =
2253 mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
2254 assert(inputDescType && outputDescType &&
2255 "Unrealized conversion cast must have tensor descriptor types");
2256
2257 // Conversions of the form tensor_desc<shape, layout> -> tensor_desc<shape>.
2258 // This occurs inside scf.for body to resolve the block argument type to
2259 // SIMT type.
2260 if (inputDescType.getLayout()) {
2261 auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
2262 if (argument) {
2263 argument.setType(output.getType());
2264 output.replaceAllUsesWith(argument);
2265 if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
2266 argument.getOwner()->getParentOp())) {
2267 auto result = loopOp.getTiedLoopResult(argument);
2268 result.setType(output.getType());
2269 }
2270 }
2271 }
2272
2273 // Conversions of the form tensor_desc<shape> -> tensor_desc<shape, layout>.
2274 // This occurs at the yield op of the scf.for body to go back
2275 // from SIMT type to original type.
2276 if (outputDescType.getLayout())
2277 output.replaceAllUsesWith(input);
2278
2279 if (op->use_empty())
2280 op->erase();
2281 return WalkResult::advance();
2282 });
2283}