MLIR 23.0.0git
XeGPUSubgroupDistribute.cpp
Go to the documentation of this file.
1//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
21#include "mlir/IR/AffineMap.h"
22#include "mlir/IR/Attributes.h"
23#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinOps.h"
27#include "mlir/IR/Operation.h"
29#include "mlir/IR/TypeRange.h"
30#include "mlir/IR/Value.h"
31#include "mlir/IR/Visitors.h"
33#include "mlir/Support/LLVM.h"
37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/SmallVectorExtras.h"
41
42namespace mlir {
43namespace xegpu {
44#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
46} // namespace xegpu
47} // namespace mlir
48
49#define DEBUG_TYPE "xegpu-subgroup-distribute"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
51
52using namespace mlir;
53
static const char *const resolveSIMTTypeMismatch =
    "resolve_simt_type_mismatch"; // Attribute name for identifying
                                  // UnrealizedConversionCastOp added to
                                  // resolve SIMT type mismatches.
58
59namespace {
60
61//===----------------------------------------------------------------------===//
62// SIMT Distribution Patterns
63//===----------------------------------------------------------------------===//
64
/// In certain cases, we may need to favor XeGPU specific distribution patterns
/// over generic vector distribution patterns. In such cases, we can assign
/// priorities to patterns.
// NOTE(review): presumably consumed as the pattern-benefit value at the
// registration sites (higher value == tried first) — the registration code is
// outside this chunk; confirm there.
enum PatternHierarchy : unsigned { Regular = 1, AboveRegular = 2 };
69
70/// Helper function to resolve types if the distributed type out of
71/// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
72/// Example 1:
73/// distributed type: vector<8x1xf32>
74/// expected type: vector<8xf32>
75/// resolved using,
76/// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
77/// Example 2:
78/// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
79/// expected type: xegpu.tensor_desc<8x16xf32>
80/// resolved using,
81/// %0 = unrealized_conversion_cast %1 :
82/// xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
83/// xegpu.tensor_desc<8x16xf32>
84template <typename T>
85static Value resolveDistributedTy(Value orig, T expected,
86 PatternRewriter &rewriter) {
87 // If orig and expected types are the same, return orig.
88 if (orig.getType() == expected)
89 return orig;
90 // If orig is a vector type, create a shape cast op to reconcile the types.
91 if (isa<VectorType>(orig.getType())) {
92 auto castOp =
93 vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
94 return castOp.getResult();
95 }
96 // If orig is a tensor descriptor type, create an unrealized conversion cast
97 // op to reconcile the types.
98 if (isa<xegpu::TensorDescType>(orig.getType())) {
99 auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
100 expected, orig);
101 castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
102 return castOp.getResult(0);
103 }
104 llvm_unreachable("Unsupported type for reconciliation");
105 return orig;
106}
107
108/// Given a vector type and its distributed vector type, return the list of
109/// dimensions that are distributed.
110static SmallVector<int64_t> getDistributedDims(VectorType originalType,
111 VectorType distributedType) {
112 assert(originalType.getRank() == distributedType.getRank() &&
113 "sequential and distributed vector types must have the same rank");
114 SmallVector<int64_t> distributedDims;
115 for (int64_t i = 0; i < originalType.getRank(); ++i) {
116 if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
117 distributedDims.push_back(i);
118 }
119 }
120 return distributedDims;
121}
122
/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
/// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
/// contained within a WarpExecuteOnLane0Op.
/// Example:
///
/// ```
/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
///   ...
///   ...
///   gpu.return %result: vector<8x16xf32>
/// }
/// ```
/// To
/// ```
/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
///   %laneid = gpu.lane_id : index
///   %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
///     ...
///     ...
///     gpu.yield %result: vector<8x16xf32>
///   }
///   return %0
/// }
/// ```
struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    // The warp (subgroup) size comes from the target uArch; without a chip
    // string attached we cannot know it.
    auto uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          gpuFuncOp, "Subgroup distribution requires target attribute attached "
                     "to set the warp size");
    // If the function only contains a single void return, skip.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // If the function already moved inside a warp_execute_on_lane0, skip.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    // Create a new function with the same signature and same attributes.
    SmallVector<Type> workgroupAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    SmallVector<Type> privateAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    auto newGpuFunc = gpu::GPUFuncOp::create(
        rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
        gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
        privateAttributionsTypes);
    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
    // Create a WarpExecuteOnLane0Op with same arguments and results as the
    // original gpuFuncOp.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = gpu::LaneIdOp::create(
        rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(),
        /** upperBound = **/ mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = gpu::WarpExecuteOnLane0Op::create(
        rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
        uArch->getSubgroupSize(), newGpuFunc.getArguments(),
        newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the ReturnOp of the original gpu function with a YieldOp.
    auto origRetunOp =
        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
    rewriter.setInsertionPointAfter(origRetunOp);
    gpu::YieldOp::create(rewriter, origRetunOp.getLoc(),
                         origRetunOp.getOperands());
    rewriter.eraseOp(origRetunOp);
    // Move the original function body to the WarpExecuteOnLane0Op body.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    // The block the warp op was created with is now redundant (the inlined
    // region supplies the body) — erase it.
    rewriter.eraseBlock(&warpBodyBlock);
    // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};
207
/// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing
/// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
/// still contain the original op that will not be used by the yield op (and
/// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
/// arguments. Tensor descriptor shape is not distributed because it is a
/// uniform value across all work items within the subgroup. However, the
/// layout information is dropped in the new tensor descriptor type.
///
/// Example:
///
/// ```
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (!xegpu.tensor_desc<4x8xf32, #layout0>) {
///   ...
///   %td = xegpu.create_nd_tdesc %arg0
///     : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
///   vector.yield %td
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
///   ...
///   %dead = xegpu.create_nd_tdesc %arg0
///     : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
///   vector.yield %arg0, %dead
/// }
/// %td = xegpu.create_nd_tdesc %r#0: memref<4x8xf32>
///                                 -> !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a warp-op result that is produced by a CreateNdDescOp.
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();

    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");
    // CreateNdOp must not have offsets.
    if (descOp.getMixedOffsets().size())
      return rewriter.notifyMatchFailure(
          descOp, "xegpu::CreateNdDescOp must not have offsets");

    // Yield all of the desc op's operands out of the warp op so the op can be
    // rebuilt after it.
    SmallVector<size_t> newRetIndices;
    rewriter.setInsertionPoint(warpOp);
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, /* new yielded values = */ descOp->getOperands(),
        /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);

    SmallVector<Value> newDescOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts(); // Distributed tensor descriptor type
                                        // does not contain layout info.
    Value newDescOp = xegpu::CreateNdDescOp::create(
        rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());

    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the distributed type to the expected type.
    newDescOp =
        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, newDescOp);
    return success();
  }
};
285
286/// Distribute a store_nd op at the end of enclosing
287/// `gpu.warp_execute_on_lane_0`. In case arguments for the store are passed
288/// through the warp op interface they would be propagated as returned values.
289/// Source vector is distributed based on lane layout. Appropriate cast ops are
290/// inserted if the distributed types does not match expected xegpu SIMT types.
291///
292/// Example:
293///
294/// ```
295/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
296/// gpu.warp_execute_on_lane_0(%laneid) -> () {
297/// ...
298/// xegpu.store_nd %arg0, %arg1 [%x, %y]: vector<4x8xf32>,
299/// !xegpu.tensor_desc<4x8xf32, #layout0>
300/// }
301/// ```
302/// To
303/// ```
304/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
305/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
306/// ...
307/// gpu.yield %arg0, %arg1, %x, %y: vector<4x8xf32>,
308/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index
309/// }
310/// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
311/// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
312/// #layout0>
313/// -> !xegpu.tensor_desc<4x8xf32>
314/// xegpu.store_nd %0, %1 [%r#2, %r#3]: vector<4xf32>,
315/// !xegpu.tensor_desc<4x8xf32>
316///
317/// ```
318struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
319 using gpu::WarpDistributionPattern::WarpDistributionPattern;
320 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
321 PatternRewriter &rewriter) const override {
322 gpu::YieldOp yield = warpOp.getTerminator();
323 Operation *lastNode = yield->getPrevNode();
324 auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
325 if (!storeOp)
326 return failure();
327
328 SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
329 // Expecting offsets to be present.
330 if (offsets.empty())
331 return rewriter.notifyMatchFailure(storeOp,
332 "the store op must have offsets");
333 SmallVector<Value> offsetsAsValues =
334 vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
335 SmallVector<Type> offsetTypes = llvm::map_to_vector(
336 offsetsAsValues, [](Value v) { return v.getType(); });
337 xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
338 xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
339 if (!layout)
340 return rewriter.notifyMatchFailure(
341 storeOp, "the source tensor descriptor lacks layout attribute");
342
343 FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
344 xegpu::getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
345 if (failed(distributedTypeByWarpOpOrFailure))
346 return rewriter.notifyMatchFailure(storeOp,
347 "Failed to distribute the type");
348 VectorType distributedTypeByWarpOp =
349 distributedTypeByWarpOpOrFailure.value();
350
351 SmallVector<size_t> newRetIndices;
352 SmallVector<Value> newYieldedValues = {storeOp.getValue(),
353 storeOp.getTensorDesc()};
354 SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
355 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
356 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
357 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
358 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
359 // Create a new store op outside the warp op with the distributed vector
360 // type. Tensor descriptor is not distributed.
361 rewriter.setInsertionPointAfter(newWarpOp);
362 SmallVector<Value> newStoreOperands;
363
364 // For the value operand, there can be a mismatch between the vector type
365 // distributed by the warp op and (xegpu-specific) distributed type
366 // supported by the store op. Type mismatch must be resolved using
367 // appropriate cast op.
368 FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
369 xegpu::getDistributedVectorType(storeOp.getTensorDescType());
370 if (failed(storeNdDistributedValueTyOrFailure))
371 return rewriter.notifyMatchFailure(
372 storeOp, "Failed to get distributed vector type for the store op");
373 newStoreOperands.push_back(resolveDistributedTy(
374 newWarpOp.getResult(newRetIndices[0]),
375 storeNdDistributedValueTyOrFailure.value(), rewriter));
376 // For the tensor descriptor operand, the layout attribute is dropped after
377 // distribution. Types needs to be resolved in this case also.
378 xegpu::TensorDescType distributedTensorDescTy =
379 storeOp.getTensorDescType().dropLayouts();
380 newStoreOperands.push_back(
381 resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
382 distributedTensorDescTy, rewriter));
383 // Collect offsets.
384 for (size_t i = 2; i < newRetIndices.size(); ++i)
385 newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
386
387 auto newStoreOp =
388 xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
389 newStoreOperands, storeOp->getAttrs());
391 rewriter.eraseOp(storeOp);
392 return success();
393 }
394};
/// Distribute a load_nd op feeding into vector.yield op for the enclosing
/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
/// The warp op will still contain the original op that will not be used by
/// the yield op (and should be cleaned up later). The yield op will
/// bypass the load's arguments. Only the loaded vector is distributed
/// according to lane layout; the tensor descriptor type is not
/// distributed. Appropriate cast ops are inserted if the distributed types do
/// not match expected xegpu SIMT types.
///
/// Example:
///
/// ```
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (vector<4x1xf32>) {
///   ...
///   %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
///     -> vector<4x8xf32>
///   gpu.yield %ld
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
///   !xegpu.tensor_desc<4x8xf32, #layout0>) {
///   ...
///   %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
///     vector<4x8xf32>
///   gpu.yield %dead, %arg0
/// }
/// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
///      #layout0> -> !xegpu.tensor_desc<4x8xf32>
/// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
/// %2 = vector.shape_cast %r#0: vector<4xf32> to vector<4x1xf32>
///
/// ```
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
      if (!isa<xegpu::LoadNdOp>(op))
        return false;
      // Make sure the same load op is the last operation in the warp op body.
      // This ensures that the load op is not sunk earlier, violating any
      // barrier synchronizations.
      gpu::YieldOp yield = warpOp.getTerminator();
      return yield->getPrevNode() == op;
    });

    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::LoadNd op");

    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    // Chip information is required to decide if the layout requires transpose
    // effect.
    auto uArch = getUArch(xegpu::getChipStr(loadOp).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          loadOp, "xegpu::LoadNdOp require target attribute attached to "
                  "determine transpose "
                  "requirement");
    // Expecting offsets to be present.
    SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(loadOp,
                                         "the load op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::map_to_vector(
        offsetsAsValues, [](Value v) { return v.getType(); });

    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");

    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());

    // Yield the tensor descriptor and the offsets out of the warp op; only
    // the loaded value is distributed.
    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);

    // Create a new load op outside the warp op with the distributed vector
    // type.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
                                                  // descriptor type does not
                                                  // contain layout info.
    SmallVector<Value> newLoadOperands{
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter)};
    // Collect offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    auto newLoadOp = xegpu::LoadNdOp::create(
        rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        newLoadOperands, loadOp->getAttrs());
    xegpu::removeLayoutAttrs(newLoadOp);
    // Set the packed attribute if the layout requires it.
    newLoadOp.setPacked(xegpu::requirePacked(layout));
    // Set the transpose attribute if the layout requires it.
    if (xegpu::requireTranspose(layout, uArch))
      newLoadOp.setTranspose(
          DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // There can be a conflict between the vector type distributed by the
    // warp op and (xegpu-specific) distributed type supported by the load
    // op. Resolve these mismatches by inserting a cast.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};
525
/// Distribute a dpas op feeding into vector.yield op for the enclosing
/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
/// The warp op will still contain the original op that will not be used by
/// the yield op (and should be cleaned up later). The yield op will
/// bypass the dpas's arguments. Appropriate cast ops are inserted if the
/// distributed types does not match expected xegpu SIMT types.
/// Example:
/// ```
/// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
/// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
/// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (vector<8x1xf32>) {
///   ...
///   %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
///     vector<8x16xf32>
///   gpu.yield %dpas
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
///   vector<8x1xf16>, vector<16x1xf16>) {
///   ...
///   %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
///     -> vector<8x16xf32>
///   gpu.yield %dead, %arg0, %arg1
/// }
/// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
/// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
/// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
///   vector<8xf32>
/// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
/// ```
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(warpOp,
                                         "warp result is not a xegpu::Dpas op");

    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();

    // All three operand layouts (A, B and the result/accumulator) are
    // required to compute the distributed types.
    xegpu::LayoutAttr layoutA =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutAAttr());
    xegpu::LayoutAttr layoutB =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutBAttr());
    xegpu::LayoutAttr layoutOut =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutCdAttr());

    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");

    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());

    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");

    llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
                                               dpasOp.getRhs()};
    llvm::SmallVector<Type, 3> newYieldTypes{
        distLhsTypeByWarpOpOrFailure.value(),
        distRhsTypeByWarpOpOrFailure.value()};
    // Dpas acc operand is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }
    // Create a new warp op without the dpas.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);

    // Compute the xegpu SIMT-level types expected by the new dpas op.
    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);

    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");
    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;

    // Resolve the distributed types with the original types.
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);

    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
                                           distributedResultTy, newDpasOperands,
                                           dpasOp->getAttrs());
    xegpu::removeLayoutAttrs(newDpasOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type.
    Value typeResolved =
        resolveDistributedTy(newDpasOp.getResult(),
                             distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};
656
/// Distribute a prefetch_nd op at the end of enclosing
/// `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed
/// through the warp op interface they would be propagated as returned values.
/// Tensor descriptor shape is not distributed because it is a uniform value
/// across all work items within the subgroup. Appropriate cast ops are inserted
/// if the distributed types does not match expected xegpu SIMT types.
///
/// Example:
///
/// ```
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
///   ...
///   xegpu.prefetch_nd %arg0 [%x, %y] : !xegpu.tensor_desc<4x8xf32, #layout0>
/// }
/// ```
/// To
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> (
///   !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
///   gpu.yield %arg0, %x, %y: !xegpu.tensor_desc<4x8xf32, #layout0>, index,
///     index
/// }
/// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
///      #layout0> -> !xegpu.tensor_desc<4x8xf32>
/// xegpu.prefetch_nd %1 [%r#1, %r#2] : !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Only fire when the prefetch is the last op before the terminator.
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
    if (!prefetchOp)
      return failure();

    SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
    // PrefetchNdOp must have offsets.
    if (offsets.empty())
      return rewriter.notifyMatchFailure(prefetchOp,
                                         "the prefetch op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::map_to_vector(
        offsetsAsValues, [](Value v) { return v.getType(); });

    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");

    // Yield the tensor descriptor and the offsets out of the warp op; none of
    // them is distributed.
    SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
    SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
    newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
    // Create a new prefetch op outside the warp op with updated tensor
    // descriptor type. Source tensor descriptor require type resolution.
    xegpu::TensorDescType newTensorDescTy =
        prefetchOp.getTensorDescType().dropLayouts();
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
    // Collect offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    Operation *newPrefetchOp = xegpu::PrefetchNdOp::create(
        rewriter, newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
        prefetchOp->getAttrs());
    xegpu::removeLayoutAttrs(newPrefetchOp);
    rewriter.eraseOp(prefetchOp);
    return success();
  }
};
735
736/// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
737/// region. This will simply move the barrier op outside of the warp op.
738struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
739 using gpu::WarpDistributionPattern::WarpDistributionPattern;
740 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
741 PatternRewriter &rewriter) const override {
742 gpu::YieldOp yield = warpOp.getTerminator();
743 Operation *lastNode = yield->getPrevNode();
744 // The last node must be a gpu::BarrierOp.
745 auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
746 if (!barrierOp)
747 return failure();
748 // Move the barrier op outside of the warp op.
749 rewriter.setInsertionPointAfter(warpOp);
750 gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
751 barrierOp->getResultTypes(),
752 barrierOp->getOperands(), barrierOp->getAttrs());
753 rewriter.eraseOp(barrierOp);
754 return success();
755 }
756};
757
758/// Distribute a scattered store op. The offsets argument is required.
759/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
760/// The layouts are fixed and implicit: one offset/mask per lane.
761/// The pass changes the offset/mask vector shapes to a
762/// single-element vector, **it is assumed that their producer will also be
763/// distributed**. The payload vector also has a fixed distribution:
764/// no chunk size -> vector of one element.
765/// chunk size -> vector of the innermost dimension of the SG-payload.
766/// Example 1 (no chunk size):
767/// %mask = producer_op : vector<16xi1>
768/// %offset = producer_op : vector<16xindex>
769/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
770/// memref<256xf16>, vector<16xindex>, vector<16xi1>
771/// To
772/// %mask = producer_op : vector<1xi1>
773/// %offset = producer_op : vector<1xindex>
774/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
775/// memref<256xf16>, vector<1xindex>, vector<1xi1>
776/// Example 2 (chunk size, same mask and offsets):
777/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
778/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
779/// To
780/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
781/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
782///
783/// Note that the store distribution pattern also handles leading unit
784/// dimensions in the payload, mask and offsets vectors. In this case the store
785/// distribution will only change the dimensions corresponding to the SG
786/// distribution and keep the leading unit dimensions unchanged.
/// For example, a store with payload vector<1x16xf16> with lane layout [1, 16]
/// will be distributed as vector<1x1xf16>. Shapecast ops are inserted for the
/// offset/mask/payload when necessary so that the distributed store is working
790/// on 1D shape vector to match the HW capability.
struct StoreDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Match only when the scattered store is the last op before the
    // terminator, so sinking it cannot reorder it across barriers.
    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
    if (!storeScatterOp)
      return failure();
    // The offsets argument must be an explicit vector (one offset per lane).
    auto offsets = storeScatterOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()))
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Store op must have a vector of offsets argument");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());

    // Add handling for leading unit dimensions support. With a chunk size the
    // payload is logically 2D (lanes x chunk), otherwise 1D.
    int chunkSize = storeScatterOp.getChunkSize().value_or(1);
    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;

    // Check that all leading dimensions are unit dimensions.
    for (int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
      if (storeVecTy.getShape()[i] != 1) {
        return rewriter.notifyMatchFailure(
            storeScatterOp, "Only unit dimensions allowed for the leading "
                            "dimensions of the store vector!");
      }
    }

    // Layouts attached to the payload (operand 0), offsets (operand 2) and
    // mask (operand 3) drive the per-lane distribution.
    auto layoutPayload =
        xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(0));
    auto layoutOffsets =
        xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(2));
    auto layoutMask =
        xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(3));

    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distStoreVecByWarpOpOrFailure) ||
        failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          storeScatterOp,
          "Some vector operands have no layouts, using defaults instead.");
    }

    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
    VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
    VectorType distMaskTy = distMaskByWarpOpOrFailure.value();

    // Yield payload, destination, offsets and mask from the warp region with
    // their distributed types.
    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands = storeScatterOp->getOperands();
    SmallVector<Type> operandTypesToYield = {
        distPayloadTy, operands[1].getType(), distOffsetsTy, distMaskTy};

    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);

    rewriter.setInsertionPointAfter(newWarpOp);

    // Distributed store payload type is always 1D without leading unit dims.
    VectorType payloadTy1D = VectorType::get({distPayloadTy.getNumElements()},
                                             distPayloadTy.getElementType());

    VectorType distOffsetsTy1D = VectorType::get(
        {distOffsetsTy.getNumElements()}, distOffsetsTy.getElementType());
    VectorType distMaskTy1D = VectorType::get({distMaskTy.getNumElements()},
                                              distMaskTy.getElementType());

    // Resolve distributed types to 1D for SIMT execution (inserts shape casts
    // where the yielded type differs from the 1D type).
    Value distPayloadVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), payloadTy1D, rewriter);
    Value distOffsetVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[2]), distOffsetsTy1D, rewriter);
    Value distMaskVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[3]), distMaskTy1D, rewriter);

    SmallVector<Value> newStoreScatterOpOperands = {
        distPayloadVal, newWarpOp.getResult(newRetIndices[1]), distOffsetVal,
        distMaskVal};

    // Recreate the store outside the warp op; attributes (chunk_size, cache
    // hints, ...) are carried over unchanged.
    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
        storeScatterOp->getAttrs());
    // NOTE(review): `newOp` appears unused here; sibling patterns call
    // xegpu::removeLayoutAttrs(newOp) at this point -- confirm intent.
    rewriter.eraseOp(storeScatterOp);
    return success();
  }
};
884
/// Compute the lane-local coordinates for a matrix load/store by combining
/// the layout-derived per-lane offsets with the subgroup-level offsets.
/// Returns an empty vector when the layout cannot produce distributed
/// coordinates.
static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
    PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
    Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
  SmallVector<Value> newCoods;
  auto maybeCoords =
      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
  if (failed(maybeCoords))
    return {};
  // Exactly one coordinate set is expected for a single matrix access.
  assert(maybeCoords.value().size() == 1 &&
         "Expected one set of distributed offsets");
  // NOTE(review): a statement producing `ofrVec` appears elided in this
  // rendering (likely a SmallVector<OpFoldResult> built from an add of the
  // two offset sets) -- confirm against the original source.
      rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
      getAsOpFoldResult(origOffsets));
  newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
  return newCoods;
}
901
902/// Pattern for distributing xegpu::LoadMatrixOp.
903struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
904 using gpu::WarpDistributionPattern::WarpDistributionPattern;
905 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
906 PatternRewriter &rewriter) const override {
907 gpu::YieldOp yield = warpOp.getTerminator();
908 Operation *lastNode = yield->getPrevNode();
909 auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
910 if (!matrixOp)
911 return failure();
912
913 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
914 return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
915 });
916 if (!producedByLastLoad)
917 return rewriter.notifyMatchFailure(
918 warpOp, "The last op is not xegpu::LoadMatrixOp");
919 const int operandIdx = producedByLastLoad->getOperandNumber();
920
921 VectorType sgPayloadTy =
922 dyn_cast<VectorType>(matrixOp.getResult().getType());
923 VectorType warpResultTy =
924 cast<VectorType>(warpOp.getResult(operandIdx).getType());
925 if (!sgPayloadTy)
926 return rewriter.notifyMatchFailure(
927 matrixOp, "the matrix op payload must be a vector type");
928
929 auto loc = matrixOp.getLoc();
930 auto offsets = matrixOp.getMixedOffsets();
931 if (offsets.empty())
932 return rewriter.notifyMatchFailure(matrixOp,
933 "the load op must have offsets");
934 SmallVector<Value> offsetsAsValues =
935 vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
936
937 auto layout = matrixOp.getLayoutAttr();
938 if (!layout)
939 return rewriter.notifyMatchFailure(
940 matrixOp, "the matrix operation lacks layout attribute");
941
942 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
943 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
944 if (failed(distPayloadByWarpOpOrFailure))
945 return rewriter.notifyMatchFailure(
946 matrixOp, "Failed to distribute matrix op payload based on layout.");
947
948 SmallVector<Value> operands = {matrixOp.getMemDesc()};
949 const unsigned offsetsStartIdx = operands.size();
950 operands.append(offsetsAsValues);
951
952 SmallVector<Type> operandTypes =
953 llvm::map_to_vector(operands, [](Value v) { return v.getType(); });
954
955 SmallVector<size_t> newRetIndices;
956 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
957 rewriter, warpOp, operands, operandTypes, newRetIndices);
958 SmallVector<Value> newOperands = llvm::map_to_vector(
959 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
960
961 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
962 ShapedType::kDynamic);
963 DenseI64ArrayAttr newConstOffsetsAttr =
964 rewriter.getDenseI64ArrayAttr(newConstOffsets);
965 ValueRange currentOffsets =
966 ValueRange(newOperands).drop_front(offsetsStartIdx);
967
968 SmallVector<Value> newCoords = currentOffsets;
969 rewriter.setInsertionPointAfter(newWarpOp);
970
971 if (!matrixOp.getSubgroupBlockIoAttr()) {
972 newCoords = computeDistributedCoordinatesForMatrixOp(
973 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
974 currentOffsets);
975 }
976 xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
977 rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
978 newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
979 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
980 // Resolve the output type and replace all uses.
981 rewriter.replaceAllUsesWith(
982 newWarpOp.getResult(operandIdx),
983 resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
984 return success();
985 }
986};
987
988/// Pattern for distributing xegpu::StoreMatrixOp.
989struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
990 using gpu::WarpDistributionPattern::WarpDistributionPattern;
991 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
992 PatternRewriter &rewriter) const override {
993 gpu::YieldOp yield = warpOp.getTerminator();
994 Operation *lastNode = yield->getPrevNode();
995 auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
996 if (!matrixOp)
997 return failure();
998
999 VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
1000 if (!sgPayloadTy)
1001 return rewriter.notifyMatchFailure(
1002 matrixOp, "the matrix op payload must be a vector type");
1003
1004 auto loc = matrixOp.getLoc();
1005 auto offsets = matrixOp.getMixedOffsets();
1006 if (offsets.empty())
1007 return rewriter.notifyMatchFailure(matrixOp,
1008 "the store op must have offsets");
1009 SmallVector<Value> offsetsAsValues =
1010 vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
1011
1012 auto layout = matrixOp.getLayoutAttr();
1013 if (!layout)
1014 return rewriter.notifyMatchFailure(
1015 matrixOp, "the matrix operation lacks layout attribute");
1016
1017 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
1018 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
1019 if (failed(distPayloadByWarpOpOrFailure))
1020 return rewriter.notifyMatchFailure(
1021 matrixOp, "Failed to distribute matrix op payload based on layout.");
1022
1023 SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
1024 const unsigned offsetsStartIdx = operands.size();
1025 operands.append(offsetsAsValues);
1026
1027 SmallVector<Type> operandTypes =
1028 llvm::map_to_vector(operands, [](Value v) { return v.getType(); });
1029 operandTypes[0] = *distPayloadByWarpOpOrFailure;
1030
1031 SmallVector<size_t> newRetIndices;
1032 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1033 rewriter, warpOp, operands, operandTypes, newRetIndices);
1034 SmallVector<Value> newOperands = llvm::map_to_vector(
1035 newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
1036
1037 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
1038 ShapedType::kDynamic);
1039 DenseI64ArrayAttr newConstOffsetsAttr =
1040 rewriter.getDenseI64ArrayAttr(newConstOffsets);
1041 ValueRange currentOffsets =
1042 ValueRange(newOperands).drop_front(offsetsStartIdx);
1043
1044 SmallVector<Value> newCoords = currentOffsets;
1045 rewriter.setInsertionPointAfter(newWarpOp);
1046
1047 if (!matrixOp.getSubgroupBlockIoAttr()) {
1048 newCoords = computeDistributedCoordinatesForMatrixOp(
1049 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
1050 currentOffsets);
1051 }
1052
1053 xegpu::StoreMatrixOp::create(
1054 rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
1055 ValueRange(newCoords), newConstOffsetsAttr,
1056 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
1057 rewriter.eraseOp(matrixOp);
1058 return success();
1059 }
1060};
1061
1062/// Distribute a scattered load op. The logic and requirements are the same as
1063/// for the scattered store distribution. The warpOp's payload vector is
1064/// expected to be distributed by the load's result consumer.
1065/// Example 1 (no chunk size):
1066/// %mask = producer_op : vector<16xi1>
1067/// %offset = producer_op : vector<16xindex>
1068/// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
1069/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
1070/// To
1071/// %mask = producer_op : vector<1xi1>
1072/// %offset = producer_op : vector<1xindex>
1073/// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
1074/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
1075/// Example 2 (chunk size, same mask and offsets):
1076/// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
1077/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
1078/// To
1079/// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
1080/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
1081///
1082/// Note that the load distribution pattern also handles leading unit dimensions
/// in the payload, mask, and offsets vector. The load distribution will only
/// change the dimensions corresponding to the SG distribution and keep the
/// leading unit dimensions unchanged. For example, a load with result type
/// vector<1x16xf16> with lane layout [1, 16] will be distributed
/// as result type vector<1x1xf16>. Shapecast ops are inserted for the
/// offset/mask/payload when necessary so that the distributed load is working
1089/// on 1D shape vector to match the HW capability.
struct LoadDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
      // Check if the yield operand that was produced by the *last* scattered
      // load op to avoid sinking it before barriers (maintain memory order).
      return isa<xegpu::LoadGatherOp>(op) &&
             warpOp.getTerminator()->getPrevNode() == op;
    });
    if (!producedByLastLoad)
      return rewriter.notifyMatchFailure(
          warpOp, "The last op is not xegpu::LoadGatherOp");

    auto loadGatherOp =
        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
    // Offsets and mask must both be explicit vectors (one element per lane).
    auto offsets = loadGatherOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()) ||
        !isa<VectorType>(loadGatherOp.getMask().getType()))
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Load op must have a vector arguments for offsets and mask");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
    VectorType resultVecTy =
        cast<VectorType>(loadGatherOp.getResult().getType());
    // add handling leading unit dimensions support: with a chunk size the
    // result is logically 2D (lanes x chunk), otherwise 1D.
    int chunkSize = loadGatherOp.getChunkSize().value_or(1);
    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
    for (int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
      if (resultVecTy.getShape()[i] != 1) {
        return rewriter.notifyMatchFailure(
            loadGatherOp, "Only unit dimensions allowed for the leading "
                          "dimensions of the load vector!");
      }
    }

    // Per-lane distribution of offsets (operand 1) and mask (operand 2) is
    // driven by their layout attributes.
    auto layoutOffsets =
        xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
    auto layoutMask = xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));

    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Some vector operands have no layouts, using defaults instead.");
    }

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands = loadGatherOp->getOperands();

    // The warp result consuming the load dictates the distributed result type.
    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
    VectorType distResultTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
    VectorType distMaskTy = distMaskByWarpOpOrFailure.value();

    SmallVector<Type> operandTypesToYield = {operands[0].getType(),
                                             distOffsetsTy, distMaskTy};

    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);

    rewriter.setInsertionPointAfter(newWarpOp);

    // Distributed load op will always be 1D.
    VectorType loadVecTy1D = VectorType::get({distResultTy.getNumElements()},
                                             distResultTy.getElementType());

    VectorType distOffsetsTy1D =
        VectorType::get({distOffsetsByWarpOpOrFailure.value().getNumElements()},
                        distOffsetsByWarpOpOrFailure.value().getElementType());
    VectorType distMaskTy1D =
        VectorType::get({distMaskByWarpOpOrFailure.value().getNumElements()},
                        distMaskByWarpOpOrFailure.value().getElementType());

    // Insert shape casts so the sunk load operates on 1D vectors.
    Value distOffsetVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[1]), distOffsetsTy1D, rewriter);
    Value distmaskVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[2]), distMaskTy1D, rewriter);

    SmallVector<Value> newLoadGatherOperands = {
        newWarpOp.getResult(newRetIndices[0]), distOffsetVal, distmaskVal};

    // Recreate the load outside the warp op; attributes (chunk_size, cache
    // hints, ...) are carried over unchanged.
    xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
        rewriter, newWarpOp.getLoc(), loadVecTy1D, newLoadGatherOperands,
        loadGatherOp->getAttrs());
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type and replace all uses.
    rewriter.replaceAllUsesWith(
        distributedVal,
        resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
    return success();
  }
};
1190
1191// Sink SG-uniform ops. An op is uniform if none
1192// of its operands/results has a distribution layout attribute.
1193// Non-uniform vectors are handled by dedicated patterns.
1194// This pattern must have a higher priority than vector dialect distribution
1195// patterns, because a distributable shape may be logically intended as
1196// uniform (i.e., no layout), so we want to omit its distribution.
1197struct SinkUniformOps final : public gpu::WarpDistributionPattern {
1198 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1199 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1200 PatternRewriter &rewriter) const override {
1201 // Take the last op
1202 Operation *warpRegionPreYieldOp = warpOp.getTerminator()->getPrevNode();
1203 // Any ops with nested regions must be handled carefully in dedicated
1204 // patterns.
1205 if (!warpRegionPreYieldOp || warpRegionPreYieldOp->getNumRegions())
1206 return failure();
1207 int operandIdx = -1;
1208 if (warpRegionPreYieldOp->getNumResults()) {
1209 OpOperand *operand = getWarpResult(
1210 warpOp, [&](Operation *op) { return warpRegionPreYieldOp == op; });
1211 if (!operand)
1212 return failure();
1213 operandIdx = operand->getOperandNumber();
1214 if (warpRegionPreYieldOp->getResult(0).getType() !=
1215 warpOp.getResult(operandIdx).getType())
1216 return rewriter.notifyMatchFailure(warpOp,
1217 "The op result is not uniform.");
1218 }
1219
1220 // The op must have no layout-based operands or results.
1221 bool uniformValuesOnly =
1222 llvm::all_of(warpRegionPreYieldOp->getResults(), [](Value v) {
1223 return !xegpu::getDistributeLayoutAttr(v);
1224 });
1225 uniformValuesOnly &=
1226 llvm::all_of(warpRegionPreYieldOp->getOpOperands(), [](OpOperand &opr) {
1227 return !xegpu::getDistributeLayoutAttr(opr);
1228 });
1229 if (!uniformValuesOnly)
1230 return rewriter.notifyMatchFailure(warpOp,
1231 "Some values are not uniform.");
1232 SmallVector<size_t> newRetIndices;
1233 SmallVector<Value> operands =
1234 llvm::to_vector_of<Value>(warpRegionPreYieldOp->getOperands());
1235 SmallVector<Type> operandTypes =
1236 llvm::to_vector_of<Type>(warpRegionPreYieldOp->getOperandTypes());
1237 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1238 rewriter, warpOp, operands, operandTypes, newRetIndices);
1239
1240 rewriter.setInsertionPointAfter(newWarpOp);
1241 IRMapping operandMapper;
1242 for (auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
1243 operandMapper.map(warpRegionPreYieldOp->getOperand(oldOperandIdx),
1244 newWarpOp->getResult(newOperandIdx));
1245 Operation *clonedOp = rewriter.clone(*warpRegionPreYieldOp, operandMapper);
1246 if (!clonedOp->getNumResults())
1247 rewriter.eraseOp(warpRegionPreYieldOp);
1248 else {
1249 assert(operandIdx != -1 && "Expected a warp result for the operation");
1250 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx),
1251 clonedOp->getResult(0));
1252 }
1253 return success();
1254 }
1255};
1256
1257/// This patterns distribute the `vector.multi_reduction` operation across
1258/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
1259/// layouts for the source and accumulator vectors,
1260/// * If the reduction dimension is distributed across lanes, the reduction is
1261/// non-lane-local and the reduction is done using warp shuffles. Here we
1262/// simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in
1263/// the warp op body.
1264/// * If the reduction dimension is not distributed across lanes, the reduction
1265/// is lane-local. In this case, we yield the source and accumulator vectors
1266/// from the warp op and perform the lane-local reduction outside the warp op
1267/// using a sequence of ReductionOps.
1268/// Example 1 (Reduction is lane-local):
1269/// ```
1270/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1271/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1272/// %acc = "some_def"() : () -> (vector<32xf32>)
1273/// %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
1274/// vector<32xf32> gpu.yield %1 : vector<32xf32>
1275/// }
1276/// ```
1277/// is lowered to:
1278/// ```
1279/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
1280/// vector<1xf32>) {
1281/// %0 = "some_def"() : () -> (vector<16x32xf32>)
1282/// %acc = "some_def"() : () -> (vector<32xf32>)
1283/// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
1284/// }
1285/// %c = arith.constant dense<0.0> : vector<1xf32>
1286/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
1287/// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
1288/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
1289/// ```
1290/// Example 2 (Reduction is non-lane-local):
1291/// ```
1292/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1293/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1294/// %acc = "some_def"() : () -> (vector<2xf32>)
1295/// %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
1296/// vector<2xf32>
1297/// gpu.yield %1 : vector<2xf32>
1298/// }
1299/// ```
1300/// is lowered to:
1301/// ```
1302/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1303/// %0 = "some_def"() : () -> (vector<2x32xf32>)
1304/// %acc = "some_def"() : () -> (vector<2xf32>)
1305/// %1 = arith.constant dense<0.0> : vector<2xf32>
1306/// %2 = vector.extract %0[0] : vector<32xf32> from <vector<2x32xf32>>
1307/// %3 = ("warp.reduction %2") : f32
1308/// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
1309/// ... repeat for row 1
1310/// gpu.yield %1 : vector<2xf32>
1311/// }
struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a warp result fed by a vector.multi_reduction inside the region.
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
    if (!yieldOperand)
      return failure();
    auto reductionOp =
        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
    unsigned operandIdx = yieldOperand->getOperandNumber();
    VectorType sourceType = reductionOp.getSourceVectorType();
    // Only 2D vectors are supported.
    if (sourceType.getRank() != 2)
      return rewriter.notifyMatchFailure(warpOp,
                                         "Only 2D reductions are supported.");
    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
    // Only 1 reduction dimension supported. This also ensures that the result
    // is vector type.
    if (reductionDims.size() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only 1 reduction dimension is supported.");
    int64_t reductionDim = reductionDims[0];
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    VectorType resultType = cast<VectorType>(reductionOp.getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getTemporaryLayout(reductionOp->getOpOperand(0));

    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "Failed to distribute the source vector type.");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // Only single dimension distribution is supported. A dimension counts as
    // distributed when its distributed extent differs from the original one.
    bool dim0Distributed =
        sourceDistType.getShape()[0] != sourceType.getShape()[0];
    bool dim1Distributed =
        sourceDistType.getShape()[1] != sourceType.getShape()[1];
    if (dim0Distributed && dim1Distributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting source to be distributed in a single dimension.");
    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
    if (sourceDistDim == -1)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed source vector.");
    bool resultDistributed =
        distributedResultType.getNumElements() < resultType.getNumElements();
    // If the lane owns all the data required for reduction (i.e. reduction is
    // fully parallel across lanes), then each lane owns part of the result
    // (i.e. result is distributed). If the reduction requires cross-lane
    // shuffling, then the result is shared among all lanes (broadcasted).
    // Therefore we expect following cases:
    //
    // | Source vector        | Reduction dim  | Result vector  |
    // |----------------------|----------------|----------------|
    // | dim-0 distributed    | 0              | broadcasted    |
    // | dim-0 distributed    | 1              | distributed    |
    // | dim-1 distributed    | 0              | distributed    |
    // | dim-1 distributed    | 1              | broadcasted    |

    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
                                (sourceDistDim == 1 && reductionDim == 0);
    if (isReductionLaneLocal && !resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed result for lane-local reduction.");

    if (!isReductionLaneLocal && resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp,
          "Expecting a broadcasted result for non-lane-local reduction.");

    // Handle lane-local reduction case. In this case we fully distribute the
    // reduction result.
    if (isReductionLaneLocal) {
      // Yield the source and acc vectors from the WarpOp.
      SmallVector<size_t> newRetIndices;
      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
          {sourceDistType, distributedResultType}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      // NOTE(review): the statement producing `result` (likely a call to a
      // lowerToVectorReductions-style helper) appears elided in this
      // rendering -- confirm against the original source.
          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
          reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
      // Replace the warp op result with the final result.
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
      return success();
    }
    // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
    // of multiple ReductionOps. Actual distribution is done by the
    // WarpOpReduction pattern.
    rewriter.setInsertionPointAfter(reductionOp);
    // NOTE(review): the statement producing `result` appears elided here as
    // well -- confirm against the original source.
        cast<TypedValue<VectorType>>(reductionOp.getSource()),
        cast<TypedValue<VectorType>>(reductionOp.getAcc()),
        reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
    // Replace the warp op result with the final result.
    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
    return success();
  }
};
1415
1416/// This pattern distributes the `vector.broadcast` operation across lanes in a
1417/// warp. The pattern supports three use cases:
1418///
1419/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
1420/// vector
1421/// must have a slice layout of the result. If the distributed source and
1422/// target vector types are identical, this lowers to a no-op; otherwise, it
1423/// remains a broadcast but operates on distributed vectors.
1424///
1425/// 2) Broadcast a same-rank vector with identical layouts for source and
1426/// target:
1427/// The source vector must have unit dimensions, and lane_data must be unit
1428/// size for those unit dims. This always lowers to a no-op.
1429///
1430/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from
1431/// scalar to distributed result type.
1432///
1433/// Example 1 (lowering to a broadcast with distributed types):
1434/// ```
1435/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1436/// %0 = "some_def"() {layout_result_0 =
1437/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1438/// dims = [0]> } : () -> (vector<32xf32>)
1439/// %2 = vector.broadcast %0 {layout_result_0 =
1440/// #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
1441/// : vector<32xf32> to vector<8x32xf32>
1442 /// gpu.yield %2 : vector<8x32xf32>
1443/// }
1444/// ```
1445/// is lowered to:
1446/// ```
1447/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1448/// %0 = "some_def"() {layout_result_0 =
1449/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1450/// dims = [0]> } : () -> (vector<32xf32>)
1451/// gpu.yield %0 : vector<32xf32>
1452/// }
1453/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
1454///
1455/// Example 2 (no-op):
1456/// ```
1457/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) {
1458/// %0 = "some_def"() {layout_result_0 =
1459/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1460/// dims = [1]> } : () -> (vector<8xf32>)
1461/// %1 = vector.shape_cast %0
1462/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1463/// 1]>}: vector<8xf32> to vector<8x1xf32>
1464/// %2 = vector.broadcast %1
1465/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1466/// 1]>}: vector<8x1xf32> to vector<8x32xf32>
1467 /// gpu.yield %2 : vector<8x32xf32>
1468/// }
1469/// ```
1470/// is lowered to:
1471/// ```
1472/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
1473/// %0 = "some_def"() {layout_result_0 =
1474/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
1475/// dims = [1]> } : () -> (vector<8xf32>)
1476/// %1 = vector.shape_cast %0
1477/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
1478/// 1]>}: vector<8xf32> to vector<8x1xf32>
1479/// gpu.yield %1 : vector<8x1xf32>
1480/// }
1481/// // The broadcast is implicit through layout transformation (no-op)
1482/// "some_use"(%r#0)
1483/// ```
1484 struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
1485 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1486 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1487 PatternRewriter &rewriter) const override {
    // Match only if one of the warp op's yielded values is produced by a
    // vector.broadcast defined inside the warp region.
1488 OpOperand *yieldOperand =
1489 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1490 if (!yieldOperand)
1491 return failure();
1492 auto broadcastOp =
1493 cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
1494 unsigned operandIdx = yieldOperand->getOperandNumber();
1495
    // sourceType is null when the broadcast source is a scalar (case 3 below).
1496 VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1497 VectorType destType =
1498 dyn_cast<VectorType>(broadcastOp.getResult().getType());
1499
1500 xegpu::DistributeLayoutAttr sourceLayout =
1501 xegpu::getTemporaryLayout(broadcastOp->getOpOperand(0));
1502 xegpu::DistributeLayoutAttr resultLayout =
1503 xegpu::getTemporaryLayout(dyn_cast<OpResult>(broadcastOp.getResult()));
1504
    // `sourceElemOrDistType` holds either the distributed source vector type
    // (cases 1 and 2) or the plain scalar element type (case 3).
1505 FailureOr<VectorType> sourceDistType;
1506 Type sourceElemOrDistType;
1507 if (sourceType) {
1508
1509 // Case 1 and 2: source is a vector type.
1510 int64_t rankDiff = destType.getRank() - sourceType.getRank();
1511 if (rankDiff > 0) {
1512 // Case 1: source is lower-rank than result.
      // NOTE: a non-slice source layout only produces a diagnostic warning;
      // the rewrite still proceeds.
1513 bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
1514 if (!isSliceOf)
1515 broadcastOp.emitWarning()
1516 << "Broadcast input layout must be a slice of result layout.";
1517 }
1518 // case 2: source and result have same rank
1519 if (rankDiff == 0) {
1520 auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
1521 SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
1522 broadcastUnitDimsSet.end());
1523 bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
1524 if (!isEqualTo)
1525 return rewriter.notifyMatchFailure(
1526 warpOp, "For same-rank broadcast, source must be identical to "
1527 "adjusted result layouts with unit dims.");
      // Force unit lane data/layout on the broadcasted unit dims so the
      // same-rank broadcast distributes to a no-op.
1528 resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
1529 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1530 }
1531
1532 sourceDistType =
1533 getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1534 if (failed(sourceDistType)) {
1535 return rewriter.notifyMatchFailure(
1536 warpOp, "Failed to distribute the source vector type.");
1537 }
1538 sourceElemOrDistType = sourceDistType.value();
1539
1540 } else {
1541 // Case 3: source is a scalar type.
1542 if (sourceLayout) {
1543 return rewriter.notifyMatchFailure(
1544 warpOp, "Broadcast from scalar must not have a layout attribute.");
1545 }
1546 sourceElemOrDistType = broadcastOp.getSourceType();
1547 }
1548 FailureOr<VectorType> destDistType =
1549 getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
1550 if (failed(destDistType)) {
1551 return rewriter.notifyMatchFailure(
1552 warpOp, "Failed to distribute the dest vector type.");
1553 }
1554
    // Yield the broadcast source (distributed vector or scalar) from a new
    // warp op; newRetIndices[0] is its position among the new results.
1555 SmallVector<size_t> newRetIndices;
1557 rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
1558 newRetIndices);
1559
1560 Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
1561
1562 Value newBroadcast = distributedSource;
1563
    // No-op when the distributed source and dest types already agree;
    // otherwise materialize a broadcast on the distributed types outside the
    // warp op.
1564 if (sourceElemOrDistType != destDistType.value()) {
1565 rewriter.setInsertionPointAfter(newWarpOp);
1566 newBroadcast =
1567 vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
1568 destDistType.value(), distributedSource);
1569 }
1570
1571 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
1572 return success();
1573 }
1574 };
1575
1576/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
1577/// `gpu.warp_execute_on_lane_0` region.
1578 struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
1579 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1580 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1581 PatternRewriter &rewriter) const override {
    // Match only if one of the warp op's yielded values is produced by a
    // vector.shape_cast defined inside the warp region.
1582 OpOperand *yieldOperand =
1583 getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1584 if (!yieldOperand)
1585 return failure();
1586 auto shapeCastOp =
1587 cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
1588 unsigned operandNumber = yieldOperand->getOperandNumber();
    // The warp op result type is already the distributed result type of the
    // shape_cast.
1589 auto resultDistTy =
1590 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1591 xegpu::DistributeLayoutAttr sourceLayout =
1592 xegpu::getTemporaryLayout(shapeCastOp->getOpOperand(0));
1593 xegpu::DistributeLayoutAttr resultLayout =
1594 xegpu::getTemporaryLayout(dyn_cast<OpResult>(shapeCastOp.getResult()));
1595 if (!sourceLayout || !resultLayout)
1596 return rewriter.notifyMatchFailure(
1597 warpOp,
1598 "the source or result of shape_cast op lacks distribution layout");
1599
1600 FailureOr<VectorType> sourceDistTypeOrFailure =
1602 shapeCastOp.getSourceVectorType());
1603 if (failed(sourceDistTypeOrFailure))
1604 return rewriter.notifyMatchFailure(
1605 warpOp, "failed to get distributed vector type for source");
1606 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1607 // Create a new warp op that yields the source of the shape_cast op.
1608 SmallVector<size_t> newRetIndices;
1610 rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
1611 newRetIndices);
1612 rewriter.setInsertionPointAfter(newWarpOp);
1613 Value source = newWarpOp.getResult(newRetIndices[0]);
1614 // Create a new shape_cast op outside the warp op.
1615 Value newShapeCast = vector::ShapeCastOp::create(
1616 rewriter, shapeCastOp.getLoc(), resultDistTy, source);
    // Route all users of the old warp result to the hoisted shape_cast.
1617 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1618 newShapeCast);
1619 return success();
1620 }
1621 };
1622
1623// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
1624// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1625// advanced cases where the distributed dimension is partially extracted and
1626// currently not supported by the generic vector distribution patterns.
1627 struct VectorExtractStridedSliceDistribution
1629 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1630 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1631 PatternRewriter &rewriter) const override {
    // Match only if one of the warp op's yielded values comes from a
    // vector.extract_strided_slice defined inside the warp region.
1632 OpOperand *operand =
1633 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1634 if (!operand)
1635 return failure();
1636 auto extractOp =
1637 cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
1638 unsigned operandIdx = operand->getOperandNumber();
1639 auto distributedType =
1640 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1641 // Find the distributed dimensions.
1642 auto extractResultType = cast<VectorType>(operand->get().getType());
1643 auto distributedDims =
1644 getDistributedDims(extractResultType, distributedType);
1645 // Collect updated source type, sizes and offsets. They may be adjusted
1646 // later if the data is distributed to lanes (as opposed to being owned by
1647 // all lanes uniformly).
1648 VectorType updatedSourceType = extractOp.getSourceVectorType();
1649 SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
1650 extractOp.getSizes(), [](Attribute attr) { return attr; });
1651 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1652 extractOp.getOffsets(), [](Attribute attr) { return attr; });
1653 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1654 extractOp.getStrides(), [](Attribute attr) { return attr; });
1655 // If the provided sizes, offsets, strides are less than the rank, pad them
1656 // with full sizes, zero offsets, and unit strides. This makes it easier to
1657 // adjust them later.
1658 int64_t sourceRank = extractOp.getSourceVectorType().getRank();
1659 for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
1660 updatedSizes.push_back(rewriter.getI64IntegerAttr(
1661 extractOp.getSourceVectorType().getDimSize(i)));
1662 updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
1663 updatedStrides.push_back(
1664 rewriter.getI64IntegerAttr(1)); // stride is always 1.
1665 }
1666 // If the result is distributed, it must be distributed in exactly one
1667 // dimension. In this case, we adjust the sourceDistType, distributedSizes
1668 // and distributedOffsets accordingly.
1669 if (distributedDims.size() > 0) {
1670 if (distributedDims.size() != 1)
1671 return rewriter.notifyMatchFailure(
1672 warpOp, "Source can not be distributed in multiple dimensions.");
1673 int64_t distributedDim = distributedDims[0];
      // NOTE(review): dim size is narrowed from int64_t to int here — assumes
      // it fits; confirm upstream invariants.
1674 int sourceDistrDimSize =
1675 extractOp.getSourceVectorType().getShape()[distributedDim];
1676 auto sourceLayout = xegpu::getTemporaryLayout(extractOp->getOpOperand(0));
1677 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1678 return rewriter.notifyMatchFailure(
1679 warpOp, "the source of extract_strided_slice op lacks distribution "
1680 "layout");
1681 auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
1682 // Because only single dimension distribution is supported, lane layout
1683 // size at the distributed dim must be the subgroup size.
1684 int subgroupSize = sourceLaneLayout[distributedDim];
1685 // Check if the source size in the distributed dimension is a multiple of
1686 // subgroup size.
1687 if (sourceDistrDimSize % subgroupSize != 0)
1688 return rewriter.notifyMatchFailure(
1689 warpOp,
1690 "Source size along distributed dimension is not a multiple of "
1691 "subgroup size.");
1692 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt()ошиб;
1693 // We expect lane data to be all ones in this case.
1694 if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1695 return rewriter.notifyMatchFailure(
1696 warpOp, "Expecting unit lane data in source layout");
1697 // The offsets in the distributed dimension must be a multiple of subgroup
1698 // size.
1699 int64_t distrDimOffset =
1700 cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
1701 if (distrDimOffset % subgroupSize != 0)
1702 return rewriter.notifyMatchFailure(
1703 warpOp, "Offset along distributed dimension "
1704 "is not a multiple of subgroup size.");
1705 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1706 sourceLayout, extractOp.getSourceVectorType())
1707 .value();
1708 // Update the distributed sizes to match the distributed type.
1709 updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1710 distributedType.getDimSize(distributedDim));
1711 // Update the distributed offsets to match round robin distribution (i.e.
1712 // each lane owns data at `subgroupSize` stride given unit lane data).
1713 updatedOffsets[distributedDim] =
1714 rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
1715 }
1716 // Do the distribution by yielding the source of the extract op from
1717 // the warp op and creating a new extract op outside the warp op.
1718 SmallVector<size_t> newRetIndices;
1720 rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
1721 newRetIndices);
1722 rewriter.setInsertionPointAfter(newWarpOp);
1723 Value source = newWarpOp.getResult(newRetIndices[0]);
1724 // Create a new extract op outside the warp op.
1725 Value newExtractOp = vector::ExtractStridedSliceOp::create(
1726 rewriter, extractOp.getLoc(), distributedType, source,
1727 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1728 ArrayAttr::get(rewriter.getContext(), updatedSizes),
1729 ArrayAttr::get(rewriter.getContext(), updatedStrides));
1730 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
1731 return success();
1732 }
1733 };
1734
1735/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
1736/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1737/// advanced cases where the distributed dimension is partially inserted and
1738/// currently not supported by the generic vector distribution patterns.
1739 struct VectorInsertStridedSliceDistribution
1741 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1742 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1743 PatternRewriter &rewriter) const override {
1744 OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
1745 // Check if the InsertStridedSliceOp is the last op before yield op
1746 return llvm::IsaPred<vector::InsertStridedSliceOp>(op) &&
1747 warpOp.getTerminator()->getPrevNode() == op;
1748 });
1749 if (!operand)
1750 return failure();
1751 unsigned int operandNumber = operand->getOperandNumber();
    // insertOp is non-null here: the predicate above only matches an
    // InsertStridedSliceOp.
1752 auto insertOp =
1753 operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
1754 auto distributedType =
1755 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1756 // Find the distributed dimensions of the dest vector.
1757 auto insertResultType = cast<VectorType>(operand->get().getType());
1758 auto destDistributedDims =
1759 getDistributedDims(insertResultType, distributedType);
1760 // Collect updated offsets, source type and dest type. They may be adjusted
1761 // later if the data is distributed to lanes (as opposed to being owned by
1762 // all lanes uniformly).
1763 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1764 insertOp.getOffsets(), [](Attribute attr) { return attr; });
1765 VectorType updatedSourceType = insertOp.getSourceVectorType();
1766 VectorType updatedDestType = insertOp.getDestVectorType();
1767 if (destDistributedDims.size() > 0) {
1768 // Only single dimension distribution is supported.
1769 if (destDistributedDims.size() != 1)
1770 return rewriter.notifyMatchFailure(
1771 warpOp,
1772 "Expecting source to be distributed in a single dimension.");
1773 int64_t destDistributedDim = destDistributedDims[0];
1774
1775 VectorType srcType = insertOp.getSourceVectorType();
1776 VectorType destType = insertOp.getDestVectorType();
1777 // Currently we require that both source (kD) and dest (nD) vectors are
1778 // distributed. This requires that distributedDim (d) is contained in the
1779 // last k dims of the dest vector (d >= n - k).
1780 int64_t sourceDistributedDim =
1781 destDistributedDim - (destType.getRank() - srcType.getRank());
1782 if (sourceDistributedDim < 0)
1783 return rewriter.notifyMatchFailure(
1784 insertOp,
1785 "distributed dimension must be in the last k (i.e. source "
1786 "rank) dims of dest vector");
1787 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
1788 // Obtain the source and dest layouts.
1789 auto destLayout = xegpu::getTemporaryLayout(insertOp->getOpOperand(1));
1790 auto sourceLayout = xegpu::getTemporaryLayout(insertOp->getOpOperand(0));
1791 if (!destLayout || !sourceLayout ||
1792 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1793 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1794 return rewriter.notifyMatchFailure(
1795 warpOp, "the source or dest of insert_strided_slice op lacks "
1796 "distribution layout");
1797 // Because only single dimension distribution is supported, lane layout
1798 // size at the distributed dim must be the subgroup size.
1799 int subgroupSize =
1800 destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
1801 // We require that source and dest lane data are all ones to ensure
1802 // uniform round robin distribution.
1803 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1804 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1805 if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
1806 !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1807 return rewriter.notifyMatchFailure(
1808 warpOp, "Expecting unit lane data in source and dest layouts");
1809 // Source distributed dim size must be multiples of subgroup size.
1810 if (srcDistrDimSize % subgroupSize != 0)
1811 return rewriter.notifyMatchFailure(
1812 warpOp, "Distributed dimension size in source is not a multiple of "
1813 "subgroup size.");
1814 // Offsets in the distributed dimension must be multiples of subgroup
1815 // size.
1816 int64_t destDistrDimOffset =
1817 cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
1818 if (destDistrDimOffset % subgroupSize != 0)
1819 return rewriter.notifyMatchFailure(
1820 warpOp,
1821 "Offset along distributed dimension in dest is not a multiple of "
1822 "subgroup size.");
1823 // Update the source and dest types based on their layouts.
1824 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1825 sourceLayout, insertOp.getSourceVectorType())
1826 .value();
1827 updatedDestType = getDistVecTypeBasedOnLaneLayout(
1828 destLayout, insertOp.getDestVectorType())
1829 .value();
1830 // Update the distributed offsets to match round robin distribution (i.e.
1831 // each lane owns data at `subgroupSize` stride given unit lane data).
1832 updatedOffsets[destDistributedDim] =
1833 rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1834 }
1835 // Do the distribution by yielding the source and dest of the insert op
1836 // from the warp op and creating a new insert op outside the warp op.
1837 SmallVector<size_t> newRetIndices;
1839 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1840 {updatedSourceType, updatedDestType}, newRetIndices);
1841 rewriter.setInsertionPointAfter(newWarpOp);
1842
1843 Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
1844 Value dest = newWarpOp.getResult(newRetIndices[1]);
1845 // Create a new insert op outside the warp op.
1846 Value newInsertOp = vector::InsertStridedSliceOp::create(
1847 rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
1848 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1849 insertOp.getStrides());
    // Route all users of the old warp result to the hoisted insert op.
1850 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1851 newInsertOp);
1852 return success();
1853 }
1854 };
1855
1856/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
1857/// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
1858/// outside of the warp op.
1859 struct MemrefExtractAlignedPointerAsIndexDistribution final
1861 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1862 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1863 PatternRewriter &rewriter) const override {
    // Match only if one of the warp op's yielded values is produced by a
    // memref.extract_aligned_pointer_as_index inside the warp region.
1864 OpOperand *operand = getWarpResult(
1865 warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
1866 if (!operand)
1867 return rewriter.notifyMatchFailure(
1868 warpOp,
1869 "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
1870 auto extractOp =
1871 operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
1872 unsigned operandIdx = operand->getOperandNumber();
    // Yield the source memref (with its original type) from a new warp op so
    // the pointer extraction can be recreated outside the region.
1873 SmallVector<size_t> newRetIndices;
1874 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1875 rewriter, warpOp, extractOp.getSource(),
1876 TypeRange{extractOp.getSource().getType()}, newRetIndices);
1877 rewriter.setInsertionPointAfter(newWarpOp);
1878 auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
1879 rewriter, newWarpOp.getLoc(), extractOp.getType(),
1880 newWarpOp.getResult(newRetIndices[0]));
    // Route all users of the old warp result to the hoisted op's result.
1881 Value resultVal = newWarpOp.getResult(operandIdx);
1882 rewriter.replaceAllUsesWith(resultVal, newExtractOp.getResult());
1883 return success();
1884 }
1885 };
1886
1887/// Distribute a vector::BitCastOp feeding into yield op of an enclosing
1888/// `gpu.warp_execute_on_lane_0` region. Bitcast only impacts the innermost
1889 /// dimension of the source/result vectors. Equivalent vector::BitCastOp is
1890/// created outside of the warp op with distributed source vector type (computed
1891/// using assigned layout).
1892 struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
1893 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1894 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1895 PatternRewriter &rewriter) const override {
    // Match only if one of the warp op's yielded values is produced by a
    // vector.bitcast inside the warp region.
1896 OpOperand *operand =
1897 getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
1898 if (!operand)
1899 return rewriter.notifyMatchFailure(
1900 warpOp, "warp result is not a vector::BitCast op");
1901 auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
1902 unsigned operandIdx = operand->getOperandNumber();
    // value_or(VectorType()) collapses a distribution failure into a null
    // VectorType, checked just below.
1903 VectorType distributedSourceType =
1905 xegpu::getTemporaryLayout(bitcastOp->getOpOperand(0)),
1906 bitcastOp.getSourceVectorType())
1907 .value_or(VectorType());
1908 if (!distributedSourceType)
1909 return rewriter.notifyMatchFailure(
1910 bitcastOp, "Failed to distribute the source vector type in "
1911 "vector::BitCast op");
    // The warp result type is already the distributed result type.
1912 VectorType distributedResultType =
1913 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1914 SmallVector<size_t> newRetIndices;
1915 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1916 rewriter, warpOp, bitcastOp.getSource(),
1917 TypeRange{distributedSourceType}, newRetIndices);
1918 rewriter.setInsertionPointAfter(newWarpOp);
1919 auto newBitcastOp = vector::BitCastOp::create(
1920 rewriter, newWarpOp.getLoc(), distributedResultType,
1921 newWarpOp.getResult(newRetIndices[0]));
1922 Value distributedVal = newWarpOp.getResult(operandIdx);
1923 rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
1924 return success();
1925 }
1926 };
1927
1928/// Distribute a vector::TransposeOp feeding into yield op of an enclosing
1929/// `gpu.warp_execute_on_lane_0` region. Currently only 2D transposes are
1930/// supported. In most cases, transpose is a no op because it is entirely
1931/// handled using the layouts (e.g. 16x1 -> 1x16). However, if each lane owns
1932/// multiple slices of data after distribution (e.g. 16x2 -> 2x16), a lane-local
1933/// transpose (i.e. shuffle) is needed. Therefore, we create an equivalent
1934/// vector::TransposeOp outside of the warp op with distributed source vector
1935/// type (computed using assigned layout).
1936 struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
1937 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1938 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1939 PatternRewriter &rewriter) const override {
    // Match only if one of the warp op's yielded values is produced by a
    // vector.transpose inside the warp region.
1940 OpOperand *operand =
1941 getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
1942 if (!operand)
1943 return rewriter.notifyMatchFailure(
1944 warpOp, "warp result is not a vector::Transpose op");
1945 auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
1946 unsigned operandIdx = operand->getOperandNumber();
1947 xegpu::DistributeLayoutAttr sourceLayout =
1948 xegpu::getTemporaryLayout(transposeOp->getOpOperand(0));
1949 xegpu::DistributeLayoutAttr resultLayout =
1950 xegpu::getTemporaryLayout(transposeOp->getOpResult(0));
1951 if (!sourceLayout || !resultLayout)
1952 return rewriter.notifyMatchFailure(
1953 transposeOp,
1954 "the source or result vector of the transpose op lacks layout "
1955 "attribute");
1956 int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
1957 int64_t resultRank = transposeOp.getResultVectorType().getRank();
1958 // Only 2D transposes are supported for now.
1959 // TODO: Support nD transposes.
1960 if (sourceRank != 2 || resultRank != 2)
1961 return rewriter.notifyMatchFailure(
1962 transposeOp, "the source or result vector of the transpose op "
1963 "does not have 2D layout");
1964 ArrayRef<int64_t> perm = transposeOp.getPermutation();
1965 // Result layout must be a transpose of source layout.
1966 if (!resultLayout.isTransposeOf(sourceLayout, perm))
1967 return rewriter.notifyMatchFailure(
1968 transposeOp,
1969 "the source or result vector layouts must be 2D transposes of each "
1970 "other");
1971 FailureOr<VectorType> distributedSourceTypeOrFailure =
1973 transposeOp.getSourceVectorType());
1974 if (failed(distributedSourceTypeOrFailure))
1975 return rewriter.notifyMatchFailure(
1976 transposeOp, "Failed to distribute the source vector type in "
1977 "vector::Transpose op");
    // Yield the transpose source from a new warp op and recreate the
    // transpose outside on the distributed vector.
1978 SmallVector<size_t> newRetIndices;
1979 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1980 rewriter, warpOp, transposeOp.getVector(),
1981 TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
1982 rewriter.setInsertionPointAfter(newWarpOp);
1983 auto newTransposeOp = vector::TransposeOp::create(
1984 rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
1985 perm);
1986 Value distributedVal = newWarpOp.getResult(operandIdx);
1987 rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
1988 return success();
1989 }
1990 };
1991
1992/// Distribute a vector::StepOp with the sliced result layout.
1993/// The sliced layout must have exactly 1 effective lane dimension.
1994/// We completely resolve the vector::StepOp by computing the lane_data-sized
1995/// subranges.
1996struct VectorStepSliceDistribution final : public gpu::WarpDistributionPattern {
1997 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1998 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1999 PatternRewriter &rewriter) const override {
2000 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
2001 if (!operand)
2002 return rewriter.notifyMatchFailure(
2003 warpOp, "warp result is not a vector::StepOp op");
2004 auto stepOp = operand->get().getDefiningOp<vector::StepOp>();
2005 unsigned operandIdx = operand->getOperandNumber();
2006 xegpu::DistributeLayoutAttr resultLayout =
2007 xegpu::getTemporaryLayout(stepOp->getResult(0));
2008 if (!resultLayout)
2009 return rewriter.notifyMatchFailure(
2010 stepOp, "the result vector of the step op lacks layout "
2011 "attribute");
2012 auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resultLayout);
2013 if (!sliceLayout)
2014 return rewriter.notifyMatchFailure(
2015 stepOp, "the result layout must be a slice layout");
2016 if (sliceLayout.getEffectiveLaneLayoutAsInt().size() != 1)
2017 return rewriter.notifyMatchFailure(
2018 stepOp, "expecting 1 dim in the effective result layout");
2019
2020 rewriter.setInsertionPointAfter(warpOp);
2021 auto loc = stepOp.getLoc();
2022 auto stepResultVecTy = stepOp.getResult().getType();
2023 Value distributedVal = warpOp.getResult(operandIdx);
2024 VectorType newVecTy = cast<VectorType>(distributedVal.getType());
2025
2026 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
2027 rewriter, loc, warpOp.getLaneid(), stepResultVecTy.getShape());
2028 if (failed(laneDataBlockCoords))
2029 return rewriter.notifyMatchFailure(
2030 stepOp, "failed to compute lane data block coordinates");
2031
2032 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
2033 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
2034 assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
2035 newVecTy.getNumElements() / laneDataBlockLength);
2036 SmallVector<Value> stepVals;
2037 // For each lane_data block, reconstruct its sub-range
2038 // from the range of SG-level vector.step. Example: vector.step
2039 // {slice<layout<lane_layout=[2,4,2], lane_data=[1,2,1]>, dims=[0,2]>} :
2040 // vector<16xindex>
2041 // Each logical lane holds 4 elements as 2 blocks of 2 elements each.
2042 // The blocks are round-robin distributed, so logical lane id 0
2043 // holds values [0,1, 8,9].
2044 for (auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
2045 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
2046 stepVals.push_back(laneDataBlockStartCoord);
2047 for (int i = 1; i < laneDataBlockLength; ++i) {
2048 auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
2049 stepVals.push_back(arith::AddIOp::create(
2050 rewriter, loc, laneDataBlockStartCoord, offset));
2051 }
2052 }
2053 assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
2054 "Expecting the number of step values to match the number of "
2055 "elements in the vector");
2056 auto stepOpVal =
2057 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
2058 rewriter.replaceAllUsesWith(distributedVal, stepOpVal);
2059 return success();
2060 }
2061};
2062
/// Erases xegpu.convert_layout ops whose input and target layouts are
/// lane-compatible: at SIMT level the conversion is a no-op, so the result is
/// replaced with the source.
2063 struct ConvertLayoutDistribution
2064 : public OpRewritePattern<xegpu::ConvertLayoutOp> {
2066
2067 LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
2068 PatternRewriter &rewriter) const override {
2069 auto inputLayout = op.getInputLayoutAttr();
2070 auto targetLayout = op.getTargetLayoutAttr();
2071
2072 if (!inputLayout || !targetLayout)
2073 return rewriter.notifyMatchFailure(op, "missing layout attributes");
2074
    // Incompatible lane layouts would require data movement between lanes,
    // which is not implemented here.
2075 if (!inputLayout.isCompatibleWith(targetLayout, xegpu::LayoutKind::Lane)) {
2076 return rewriter.notifyMatchFailure(
2077 op, "lowering incompatible convert_layout not yet supported");
2078 }
2079 rewriter.replaceOp(op, op.getSource());
2080 return success();
2081 }
2082 };
2083
2084} // namespace
2085
2086namespace {
// Pass driver for XeGPU subgroup (SIMT) distribution; the work happens in
// runOnOperation below.
2087 struct XeGPUSubgroupDistributePass final
2089 XeGPUSubgroupDistributePass> {
2090 void runOnOperation() override;
2091 };
2092} // namespace
2093
2095 RewritePatternSet &patterns) {
  // XeGPU-specific distribution patterns, registered at Regular benefit.
2096 patterns.add<CreateNdDescDistribution, StoreNdDistribution,
2097 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
2098 GpuBarrierDistribution, VectorMultiReductionDistribution,
2099 LoadDistribution, StoreDistribution, VectorTransposeDistribution,
2100 VectorBitcastDistribution, LoadMatrixDistribution,
2101 StoreMatrixDistribution, ConvertLayoutDistribution,
2102 MemrefExtractAlignedPointerAsIndexDistribution>(
2103 patterns.getContext(),
2104 /*pattern benefit=*/PatternHierarchy::Regular);
2105 // For following patterns, we need to override the regular vector distribution
2106 // patterns. Therefore, assign higher benefit.
2107 patterns
2108 .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
2109 VectorInsertStridedSliceDistribution, VectorBroadcastDistribution,
2110 VectorStepSliceDistribution, SinkUniformOps>(
2111 patterns.getContext(),
2112 /*pattern benefit=*/PatternHierarchy::AboveRegular);
2113 }
2114
2116 RewritePatternSet &patterns) {
  // Registers the pattern that moves a GPU function body inside a
  // gpu.warp_execute_on_lane_0 region (see Step 2 of runOnOperation).
2117 patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
2118 }
2119
void XeGPUSubgroupDistributePass::runOnOperation() {
  // Step 1: Attach layouts to op operands.
  // TODO: Following assumptions are made:
  // 1) It is assumed that there are no layout conflicts.
  // 2) Any existing layout attributes attached to the operands are ignored.
  Operation *op = getOperation();
  // NOTE(review): the guarded call was dropped by doc extraction; restored
  // as the layout-recovery entry point — confirm against upstream.
  if (!xegpu::recoverTemporaryLayouts(op)) {
    signalPassFailure();
    return;
  }

  // Step 2: Move all operations of a GPU function inside
  // gpu.warp_execute_on_lane_0 operation.
  {
    RewritePatternSet patterns(&getContext());
    // NOTE(review): call restored (dropped by doc extraction) — confirm.
    xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();
      return;
    }
    // At this point, we have moved the entire function body inside the
    // warpOp. Now move any scalar uniform code outside of the warpOp (like
    // GPU index ops, scalar constants, etc.). This will simplify the
    // later lowering and avoid custom patterns for these ops.
    getOperation()->walk([&](Operation *op) {
      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
        vector::moveScalarUniformCode(warpOp);
    });
  }
  // Step 3: Apply subgroup to workitem distribution patterns.
  RewritePatternSet patterns(&getContext());
  // NOTE(review): call restored (dropped by doc extraction) — confirm.
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // distributionFn is used by vector distribution patterns to determine the
  // distributed vector type for a given vector value. In XeGPU subgroup
  // distribution context, we compute this based on lane layout.
  auto distributionFn = [](Value val) {
    VectorType vecType = dyn_cast<VectorType>(val.getType());
    int64_t vecRank = vecType ? vecType.getRank() : 0;
    // Non-vector (scalar) values are uniform: empty map, no distribution.
    if (vecRank == 0)
      return AffineMap::get(val.getContext());
    // Get the layout of the vector type.
    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
    // If no layout is specified, assume uniform case (no distribution).
    if (!layout)
      return AffineMap::get(val.getContext());
    // Expecting vector and layout rank to match.
    assert(layout.getRank() == vecRank &&
           "Expecting vector and layout rank to match");
    // A dimension is distributed only if layout suggests there are
    // multiple lanes assigned for this dimension and the shape can be evenly
    // distributed to those lanes.
    SmallVector<unsigned int> distributedDims;
    for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
      if (v > 1 && vecType.getShape()[i] % v == 0)
        distributedDims.push_back(i);
    }
    return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
                                                val.getContext());
  };
  // TODO: shuffleFn is not used.
  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                      int64_t warpSz) { return Value(); };

  vector::populateDistributeReduction(
      patterns, xegpu::subgroupReduction,
      /*pattern benefit=*/PatternHierarchy::Regular);

  vector::populatePropagateWarpVectorDistributionPatterns(
      patterns, distributionFn, shuffleFn,
      /*pattern benefit=*/PatternHierarchy::Regular);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
    signalPassFailure();
    return;
  }

  // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
  // due to tensor desc type mismatches created by using upstream distribution
  // patterns (scf.for). This cleanup should only be done if all the ops are
  // distributed successfully; if some ops are still not distributed and remain
  // inside any WarpExecuteOnLane0Op we avoid this simplification step to avoid
  // breaking the IR.
  bool foundWarpOp = false;
  getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
    // Look for WarpOps that are not trivially dead.
    if (isOpTriviallyDead(warpOp))
      return WalkResult::advance();
    foundWarpOp = true;
    return WalkResult::interrupt();
  });
  if (foundWarpOp)
    return;

  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
    // We are only interested in UnrealizedConversionCastOps that were added
    // for resolving SIMT type mismatches.
    if (!op->getAttr(resolveSIMTTypeMismatch))
      return WalkResult::skip();

    Value input = op.getOperand(0);
    Value output = op.getResult(0);

    // Both input and output must have tensor descriptor types.
    xegpu::TensorDescType inputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
    xegpu::TensorDescType outputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
    assert(inputDescType && outputDescType &&
           "Unrealized conversion cast must have tensor descriptor types");

    // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
    // This occurs inside scf.for body to resolve the block argument type to
    // SIMT type.
    if (inputDescType.getLayout()) {
      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
      if (argument) {
        argument.setType(output.getType());
        output.replaceAllUsesWith(argument);
        // Keep the loop's tied result type in sync with the updated
        // iter_arg type.
        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
                argument.getOwner()->getParentOp())) {
          auto result = loopOp.getTiedLoopResult(argument);
          result.setType(output.getType());
        }
      }
    }

    // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
    // conversions. This occurs at the yield op of scf.for body to go back
    // from SIMT type to original type.
    if (outputDescType.getLayout())
      output.replaceAllUsesWith(input);

    if (op->use_empty())
      op->erase();
    return WalkResult::advance();
  });
}
return success()
static Type getElementType(Type type)
Determine the element type of type.
b getContext())
static const char *const resolveSIMTTypeMismatch
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
UnitAttr getUnitAttr()
Definition Builders.cpp:102
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:358
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:415
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:682
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:391
operand_type_range getOperandTypes()
Definition Operation.h:405
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:386
result_range getResults()
Definition Operation.h:423
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:412
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition Value.h:149
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
const uArch * getUArch(llvm::StringRef archName)
bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
void populateXeGPUMoveFuncBodyToWarpOpPatterns(RewritePatternSet &patterns)
Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc argumments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requirePacked(const LayoutAttr layout)
Helper function to check if the layout is packed.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU SIMT distribution into patterns.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which statifies the filter lamdba condition and is not dead.
virtual int getSubgroupSize() const =0