MLIR 23.0.0git
XeGPUUtils.cpp
Go to the documentation of this file.
1//===---- XeGPUUtils.cpp - MLIR Utilities for XeGPUOps ------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utility methods for working with the XeGPU dialect.
10//
11//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/Operation.h"
22#include "mlir/IR/ValueRange.h"
25#include "llvm/Support/Casting.h"
26#include "llvm/Support/FormatVariadic.h"
27#include <cstdint>
28#include <numeric>
29
30using namespace mlir;
31
32/// convert ArrayRef<ValueRange> into SmallVector<Value>
35 for (const auto &vals : values)
36 llvm::append_range(result, vals);
37 return result;
38}
39
40FailureOr<VectorType>
41mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
42 auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
43 // It only works for subgroup level layout, which only has lane_layout
44 // and lane_data, and is to distribute a SIMD code into SIMT code.
45 if (!layout || !layout.isForSubgroup())
46 return failure();
47
48 SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
49 SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef());
50 auto tdescShape = tdescTy.getShape();
51 auto elementType = tdescTy.getElementType();
52
53 // compute sgSize by multiply elements of laneLayout
54 // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1]
55 // e.g. for 1D layout, sgSize = laneLayout[0]
56 int64_t sgSize = llvm::product_of(laneLayout);
57
58 // Check if the tensor descriptor shape is distributable.
59 int64_t tensorSize = 1;
60 for (auto [tdescDim, laneDim, laneDataDim] :
61 llvm::zip_equal(tdescShape, laneLayout, laneData)) {
62 assert((tdescDim % (laneDim * laneDataDim) == 0) &&
63 "tensor descriptor shape is not distributable");
64 tensorSize *= tdescDim;
65 }
66 // tensorSize must be adjusted for array_length.
67 tensorSize *= tdescTy.getArrayLength();
68
69 return VectorType::get({tensorSize / sgSize}, elementType);
70}
71
72FailureOr<VectorType>
73mlir::xegpu::getDistributedVectorType(VectorType originalType,
74 xegpu::LayoutAttr layout) {
75 int64_t rank = originalType.getRank();
76 // Distributed vector type is only supported for 1D, 2D and 3D vectors.
77 if (rank < 1 || rank > 3)
78 return failure();
79 ArrayRef<int64_t> shape = originalType.getShape();
80 // arrayLength is 1 for 1D and 2D vectors, and equal to the first dimension
81 // of the 3D vector.
82 int arrayLength = 1;
83 if (rank == 3) {
84 arrayLength = shape[0];
85 shape = shape.drop_front();
86 }
87 auto helperTdescTy = xegpu::TensorDescType::get(
88 shape, originalType.getElementType(), arrayLength,
89 /*boundary_check=*/true,
90 /*memory_space=*/xegpu::MemorySpace::Global, layout);
91 return xegpu::getDistributedVectorType(helperTdescTy);
92}
93
94FailureOr<VectorType>
95xegpu::getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
96 VectorType originalType) {
97 if (!layout)
98 return failure();
99 assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
100 "Expecting a valid layout.");
101
102 int64_t vectorRank = originalType.getRank();
103 int64_t layoutRank = layout.getRank();
104 assert(vectorRank >= layoutRank && "Vector rank must be >= layout rank.");
105
106 // When the vector has more dimensions than the layout, only the trailing
107 // dimensions are distributed. Leading dimensions are preserved as-is.
108 int64_t offset = vectorRank - layoutRank;
109 ArrayRef<int64_t> fullShape = originalType.getShape();
110 SmallVector<int64_t> trailingShape(fullShape.begin() + offset,
111 fullShape.end());
112 auto distributedShapeOrFailure =
113 layout.computeDistributedShape(trailingShape);
114 if (failed(distributedShapeOrFailure))
115 return failure();
116
117 SmallVector<int64_t> resultShape(fullShape.begin(),
118 fullShape.begin() + offset);
119 resultShape.append(distributedShapeOrFailure->begin(),
120 distributedShapeOrFailure->end());
121 return VectorType::get(resultShape, originalType.getElementType());
122}
123
124std::string xegpu::getTemporaryLayoutName(const OpOperand &operand) {
125 const StringRef prefix("layout_operand_");
126 unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
127 return llvm::formatv("{0}{1}", prefix, idx).str();
128}
129
131 const StringRef prefix = "layout_result_";
132 return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
133}
134
135xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
136 if (!value)
137 return nullptr;
138
139 if (auto tdescTy =
140 dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
141 return tdescTy.getLayoutAttr();
142
143 if (auto result = dyn_cast<OpResult>(value)) {
144 Operation *defOp = result.getDefiningOp();
145 assert(defOp && "result must have a defining op");
146
147 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
148 auto layout = anchorOp.getAnchorLayout();
149 return layout;
150 }
151
152 std::string layoutName = getTemporaryLayoutName(result);
153 if (defOp->hasAttr(layoutName)) {
154 auto layout =
155 defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
156 return layout;
157 }
158 }
159
160 if (auto arg = dyn_cast<BlockArgument>(value)) {
161 auto *parentOp = arg.getOwner()->getParentOp();
162 if (auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
163 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
164 if (tiedInit)
165 return getDistributeLayoutAttr(tiedInit->get());
166 }
167 }
168
169 return nullptr;
170}
171xegpu::DistributeLayoutAttr
173 Operation *op = opr.getOwner();
174 unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
175
176 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
177 if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
178 if (idx == 0) {
179 return dpasOp.getLayoutAAttr();
180 } else if (idx == 1) {
181 return dpasOp.getLayoutBAttr();
182 } else if (idx == 2) {
183 return dpasOp.getLayoutCdAttr();
184 }
185 }
186 if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
187 return convertOp.getInputLayoutAttr();
188 }
189 auto layout = anchorOp.getAnchorLayout();
190
191 if (idx == 0)
192 return layout;
193
194 // For StoreNdOp and StoreMatrixOp,
195 // the layout is valid for the first two operands: value and memref/tdesc.
196 if (isa<xegpu::StoreNdOp, xegpu::StoreMatrixOp>(op) && (idx < 2))
197 return layout;
198
199 if (isa<xegpu::StoreScatterOp>(op)) {
200 xegpu::StoreScatterOp store(op);
201 int chunkSize = store.getChunkSize().value_or(1);
202 if (layout && idx >= 2 && chunkSize > 1)
203 return layout.dropDims(llvm::to_vector(
204 llvm::seq<int64_t>(layout.getRank() - 1, layout.getRank())));
205 return layout;
206 }
207 if (isa<xegpu::LoadGatherOp>(op)) {
208 xegpu::LoadGatherOp load(op);
209 int chunkSize = load.getChunkSize().value_or(1);
210 if (layout && idx >= 1 && chunkSize > 1)
211 return layout.dropDims(llvm::to_vector(
212 llvm::seq<int64_t>(layout.getRank() - 1, layout.getRank())));
213 return layout;
214 }
215 }
216
217 std::string layoutName = xegpu::getTemporaryLayoutName(opr);
218 if (op->hasAttr(layoutName)) {
219 auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
220 return layout;
221 }
222
223 return nullptr;
224}
225
226// Returns the permanent layout attribute for the given result if it's
227// available on the defining op. Otherwise returns the provided layout.
228xegpu::DistributeLayoutAttr
229maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
230 const OpResult &result, mlir::Operation *owner,
231 const std::string &name) {
232 xegpu::DistributeLayoutAttr candidate = layout;
233
234 if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) {
235 if (auto perm = loadOp.getLayoutAttr())
236 candidate = perm;
237 }
238
239 return candidate;
240}
241
242// Returns the permanent layout attribute for the given operand if it's
243// available on the defining op. Otherwise returns the provided layout.
244xegpu::DistributeLayoutAttr
245maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
246 const OpOperand &operand, mlir::Operation *owner,
247 const std::string &name) {
248 xegpu::DistributeLayoutAttr candidate = layout;
249 unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
250
251 if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) {
252 if (idx == 0) {
253 if (auto perm = storeOp.getLayoutAttr())
254 candidate = perm;
255 }
256 }
257
258 return candidate;
259}
260
261// TODO-LayoutRefactor: Remove this function after replacing use
262// with setTemporaryLayout or setAnchorLayout
264 const mlir::OpResult &result,
265 const mlir::xegpu::DistributeLayoutAttr layout) {
266 Operation *owner = result.getOwner();
267
268 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
269 if (anchorOp.getAnchorLayout() == layout)
270 return;
271 anchorOp.setAnchorLayout(layout);
272 return;
273 }
274
275 std::string name = xegpu::getTemporaryLayoutName(result);
276 if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
277 return;
278 }
279 if (layout) {
280 owner->setAttr(name, layout);
281 }
282}
283
284// TODO-LayoutRefactor: Remove this function after replacing use
285// with setTemporaryLayout or setAnchorLayout
287 const DistributeLayoutAttr layout) {
288 Operation *owner = operand.getOwner();
289 unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
290
291 if (!layout) {
292 return;
293 }
294 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
295 if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
296 if (idx == 0) {
297 return dpasOp.setLayoutAAttr(layout);
298 } else if (idx == 1) {
299 return dpasOp.setLayoutBAttr(layout);
300 } else if (idx == 2) {
301 return dpasOp.setLayoutCdAttr(layout);
302 }
303 }
304 if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(owner)) {
305 return convertOp.setInputLayoutAttr(layout);
306 }
307
308 // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
309 // the layout is valid for the first two operands: value and memref/tdesc.
310 // For other operations, the layout applies to the first operand only.
311 if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
312 owner)) {
313 if (idx < 2) {
314 anchorOp.setAnchorLayout(layout);
315 }
316 } else {
317 if (idx == 0) {
318 anchorOp.setAnchorLayout(layout);
319 }
320 }
321 }
322
323 std::string name = xegpu::getTemporaryLayoutName(operand);
324 if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
325 return;
326 }
327 if (layout) {
328 owner->setAttr(name, layout);
329 }
330}
331
332template <typename T, typename>
333xegpu::DistributeLayoutAttr
334xegpu::getTemporaryLayout(const T &operandOrResult) {
335 Operation *op = operandOrResult.getOwner();
336
337 std::string layoutName = xegpu::getTemporaryLayoutName(operandOrResult);
338 if (op->hasAttr(layoutName)) {
339 auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
340 return layout;
341 }
342
343 return nullptr;
344}
345
346template xegpu::DistributeLayoutAttr
348template xegpu::DistributeLayoutAttr
350
351template <typename T, typename>
352void xegpu::setTemporaryLayout(const T &operandOrResult,
353 const xegpu::DistributeLayoutAttr layout) {
354 Operation *owner = operandOrResult.getOwner();
355 std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
356 if (owner->hasAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
357 return;
358 }
359 if (layout) {
360 owner->setAttr(name, layout);
361 }
362}
363
365 const mlir::OpResult &result,
366 const mlir::xegpu::DistributeLayoutAttr layout);
367
369 const mlir::OpOperand &operand,
370 const mlir::xegpu::DistributeLayoutAttr layout);
371
375 auto vecTy = dyn_cast<VectorType>(value.getType());
376 if (!vecTy)
377 return {value};
378
379 ArrayRef<int64_t> srcShape = vecTy.getShape();
380 if (!computeShapeRatio(srcShape, shape))
381 return {value};
382
383 int64_t srcShapeRank = srcShape.size();
384 int64_t targetShapeRank = shape.size();
385
386 SmallVector<int64_t> adjustedTargetShape(srcShape.size());
387 int64_t rankDiff = srcShapeRank - targetShapeRank;
388 std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
389 1);
390 llvm::copy(shape, adjustedTargetShape.begin() + rankDiff);
391
393 for (SmallVector<int64_t> offsets :
394 StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
395 SmallVector<int64_t> staticStrides(offsets.size(), 1);
396 Value slice = vector::ExtractStridedSliceOp::create(
397 builder, loc, value, offsets, adjustedTargetShape, staticStrides);
398
399 // Reshape to remove leading unit dims if needed
400 if (srcShapeRank > targetShapeRank) {
401 auto targetTy = VectorType::get(shape, vecTy.getElementType());
402 slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
403 }
404 result.push_back(slice);
405 }
406
407 return result;
408}
409
411 ValueRange values,
413 VectorType inputTy = dyn_cast<VectorType>(values[0].getType());
414 assert(llvm::all_of(values.getTypes(),
415 [&](Type type) { return type == inputTy; }) &&
416 "values must be of the same VectorType");
417
418 Type elemTy = inputTy.getElementType();
419 ArrayRef<int64_t> tileShape = inputTy.getShape();
420
421 VectorType resultTy = VectorType::get(shape, elemTy);
422 auto zeroAttr = builder.getZeroAttr(elemTy);
423 Value result = arith::ConstantOp::create(
424 builder, loc, resultTy, DenseElementsAttr::get(resultTy, zeroAttr));
425
426 for (auto [src, offsets] :
427 llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
428 SmallVector<int64_t> staticStrides(tileShape.size(), 1);
429 result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
430 offsets, staticStrides);
431 }
432 return result;
433}
434
436 Operation *op, TypeConverter converter) {
437 MLIRContext *context = op->getContext();
438
439 auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
440 Location loc) -> Value {
441 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
442 .getResult(0);
443 };
444
445 { // convert VectorType to RankedTensorType for SCF Structural ops
446 TypeConverter converter;
447 converter.addConversion([](Type type) -> Type { return type; });
448 converter.addConversion([](VectorType type) -> Type {
449 return RankedTensorType::get(type.getShape(), type.getElementType());
450 });
451 converter.addSourceMaterialization(materializeCast);
452 converter.addTargetMaterialization(materializeCast);
453
454 mlir::ConversionTarget target(*context);
455 target.addLegalOp<UnrealizedConversionCastOp>();
456
457 mlir::RewritePatternSet patterns(context);
459 target);
460 (void)mlir::applyPartialConversion(op, target, std::move(patterns));
461 }
462
463 { // propagate the layout attribute to RankedTensorType by checking
464 // BuiltInUnrealizedCastOps
465 // for VectorType to RankedTensorType cast.
466 op->walk([](UnrealizedConversionCastOp castOp) {
467 if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
468 return WalkResult::skip();
469
470 Value input = castOp.getInputs()[0];
471 Value result = castOp.getResults()[0];
472 auto inputTy = dyn_cast<VectorType>(input.getType());
473 auto resultTy = dyn_cast<RankedTensorType>(result.getType());
474
475 // Only look at ops casting from VectorType to RankedTensorType
476 if (!inputTy || !resultTy)
477 return WalkResult::skip();
478
479 xegpu::DistributeLayoutAttr layout =
481 if (!layout)
482 return WalkResult::skip();
483
484 RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
485 result.setType(newTy);
486
487 // update the arguments if user is a LoopLike op.
488 for (OpOperand &use : result.getUses()) {
489 if (auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
490 BlockArgument arg = loop.getTiedLoopRegionIterArg(&use);
491 arg.setType(newTy);
492 }
493 // whileOp has two regions, the BlockArgument of the after region
494 // is not exposed by LoopLikeOpInterface
495 if (auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
496 unsigned idx = use.getOperandNumber();
497 BlockArgument arg = whileOp.getAfterArguments()[idx];
498 arg.setType(newTy);
499 }
500 }
501 return WalkResult::advance();
502 });
503
504 // using yieldOp as anchor to update the result type of its ParentOp
505 op->walk([](scf::YieldOp yieldOp) {
506 Operation *parentOp = yieldOp->getParentOp();
507 for (OpResult r : parentOp->getOpResults()) {
508 unsigned idx = r.getResultNumber();
509 Type resultTy = r.getType();
510 Type yieldTy = yieldOp.getResults()[idx].getType();
511 if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
512 r.setType(yieldTy);
513 }
514 });
515 }
516
517 { // perform the conversion from RankedTensorType to VectorType based on the
518 // DistributeLayoutAttr
519
520 // Handle the UnrealizedConversionCastOp introduced by the first step.
521 // For vector->RankedTensorType, it will simply forward the inputs.
522 // For RankedTensorType->vector, it will update the inputs with the
523 // one from the adaptor.
524 class UnrealizedConversionCastOpPattern
525 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
526 using OpConversionPattern<
527 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
528
529 mlir::LogicalResult
530 matchAndRewrite(mlir::UnrealizedConversionCastOp op,
531 OneToNOpAdaptor adaptor,
532 ConversionPatternRewriter &rewriter) const override {
533 auto inputs = op.getOperands();
534 auto outputs = op.getOutputs();
535
536 if (inputs.size() != 1 || outputs.size() != 1)
537 return failure();
538
539 auto inputTy = inputs[0].getType();
540 auto outputTy = outputs[0].getType();
541
542 if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) {
543 rewriter.replaceOpWithMultiple(op, adaptor.getInputs());
544 return success();
545 }
546
547 if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
548 SmallVector<Value> values = xegpu::flattenValues(adaptor.getInputs());
549 auto newOp = UnrealizedConversionCastOp::create(rewriter, op.getLoc(),
550 outputTy, values);
551 rewriter.replaceOp(op, newOp);
552 return success();
553 }
554 return failure();
555 }
556 };
557
558 converter.addSourceMaterialization(materializeCast);
559 converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
560 ValueRange inputs, Location loc) {
561 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
562 .getResults();
563 });
564
565 mlir::ConversionTarget target(*context);
566 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
567 [](UnrealizedConversionCastOp op) {
568 auto isTensorTy = [](Type type) {
569 return isa<RankedTensorType>(type);
570 };
571 return llvm::none_of(op->getOperandTypes(), isTensorTy) &&
572 llvm::none_of(op->getResultTypes(), isTensorTy);
573 });
574 mlir::RewritePatternSet patterns(context);
575 patterns.insert<UnrealizedConversionCastOpPattern>(context);
577 target);
578 (void)mlir::applyPartialConversion(op, target, std::move(patterns));
579 }
580}
581
582std::optional<std::string> xegpu::getChipStr(Operation *op) {
583 auto gpuModuleOp = op->getParentOfType<gpu::GPUModuleOp>();
584
585 if (!gpuModuleOp)
586 return std::nullopt;
587
588 auto targetAttrs = gpuModuleOp.getTargets();
589 if (targetAttrs) {
590 for (auto &attr : *targetAttrs) {
591 auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr);
592 if (xevmAttr)
593 return xevmAttr.getChip().str();
594 }
595 }
596
597 return std::nullopt;
598}
599
600/// Generates element-wise addition ops of two arrays with same length.
602 Location loc,
605 assert(lhs.size() == rhs.size() && "lhs and rhs must have the same size");
607 for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
608 auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
609 auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
610 results.push_back(builder.createOrFold<arith::AddIOp>(loc, lval, rval));
611 }
612 return results;
613}
614
615/// Generates element-wise addition ops of two arrays with automatic alignment.
616/// When the input arrays have different sizes, the shorter array is
617/// right-aligned with the longer array, and the unmatched leading elements from
618/// the longer array are preserved unchanged. This is commonly used for offset
619/// computation where higher-dimensional offsets need to be added to
620/// lower-dimensional adjustments.
621///
622/// Example:
623/// lhs = [l1, l2, l3], rhs = [r1, r2]
/// Result: [l1, l2+r1, l3+r2]
629 // ensure a is longer than b
630 ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs;
631 ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
632 SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
633 a = a.slice(a.size() - b.size());
634 results.append(addElementwise(builder, loc, a, b));
635 return results;
636}
637
638template <typename T>
640 ArrayRef<T> candidateMultiples) {
641 static_assert(std::is_integral<T>::value, "T must be an integer type");
642 int largest = -1;
643 SmallVector<T> multiples = {1};
644 if (!candidateMultiples.empty())
645 multiples =
646 SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
647 for (T candidate : candidates) {
648 for (T multiple : multiples) {
649 int value = static_cast<int>(candidate * multiple);
650 if (value != 0 && dim % value == 0 && value > largest)
651 largest = value;
652 }
653 }
654 return largest;
655}
656
658 vector::CombiningKind kind, uint32_t size) {
659 // First reduce on a single thread to get per lane reduction value.
660 Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
661 // Parallel reduction using butterfly shuffles.
662 for (uint64_t i = 1; i < size; i <<= 1) {
663 Value shuffled =
664 gpu::ShuffleOp::create(builder, loc, laneVal, i, /** width = **/ size,
665 /** mode = **/ gpu::ShuffleMode::XOR)
666 .getShuffleResult();
667 laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
668 }
669 return laneVal;
670}
671
674 vector::CombiningKind kind,
675 int64_t reductionDim, Location loc,
676 PatternRewriter &rewriter) {
677 VectorType sourceType = src.getType();
678 int64_t sourceRank = sourceType.getRank();
679 // Expecting at least a 2D source vector. Leading dimensions (all except the
680 // last two) must be unit.
681 assert(sourceRank >= 2 && "expected at least a 2D source vector");
682 for (int64_t i = 0; i < sourceRank - 2; ++i)
683 assert(sourceType.getShape()[i] == 1 &&
684 "expected leading dimensions to be unit");
685 int64_t rowIdx = sourceRank - 2;
686 int64_t columnIdx = sourceRank - 1;
687 int64_t sourceH = sourceType.getShape()[rowIdx];
688 int64_t sourceW = sourceType.getShape()[columnIdx];
689 int nSlices = (reductionDim == rowIdx) ? sourceW : sourceH;
690 // Create a constant vector to hold the result of the reduction.
691 TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
692 Value reductionResult = arith::ConstantOp::create(
693 rewriter, loc, acc.getType(),
694 DenseElementsAttr::get(acc.getType(), zeroAttr));
695 // TODO: Remove these get/setTemporaryLayout calls after we deprecate the old
696 // XeGPUSubgroupDistribute pass.
697 auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
698 auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
699 // Reduction result should have the same layout as the accumulator.
700 xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
701 // For each slice of the source, extract the slice vector, do a reduction
702 // and, insert the reduced value back to the result vector.
703 int64_t accRank = acc.getType().getRank();
704 for (int i = 0; i < nSlices; ++i) {
705 // Build nD offsets, sizes, and strides. Leading unit dims get
706 // offset=0, size=1. The last two dims are set based on reductionDim.
707 SmallVector<int64_t> sliceOffsets(sourceRank, 0);
708 SmallVector<int64_t> sliceSizes(sourceRank, 1);
709 SmallVector<int64_t> strides(sourceRank, 1);
710 if (reductionDim == columnIdx) {
711 sliceOffsets[rowIdx] = i;
712 sliceSizes[columnIdx] = sourceW;
713 } else {
714 sliceOffsets[columnIdx] = i;
715 sliceSizes[rowIdx] = sourceH;
716 }
717
718 vector::ExtractStridedSliceOp extractOp =
719 vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
720 sliceSizes, strides);
721 // Extract strided slice has the same layout as src.
722 xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
723
724 int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
725
726 vector::ShapeCastOp slice = vector::ShapeCastOp::create(
727 rewriter, loc,
728 VectorType::get({nSliceElements}, sourceType.getElementType()),
729 extractOp.getResult());
730
731 // Shape cast output has the same layout as the accumulator. Shape cast
732 // source has the same layout as the original reduction source.
733 xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
734 xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
735 // Extract and reduction results in scalars, so no result layout is needed.
736 // Build multi-dim index into acc (sourceRank-1 dims, i.e. source shape with
737 // the reduction dim removed). Leading unit dims get index 0.
738 SmallVector<int64_t> accIdx(accRank, 0);
739 accIdx[accRank - 1] = i;
740 Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, accIdx);
741 Value reduction = vector::ReductionOp::create(
742 rewriter, loc, kind, slice.getResult(), accExtract);
743 reductionResult = vector::InsertOp::create(rewriter, loc, reduction,
744 reductionResult, accIdx);
745 // Insert op should have the same layout as the accumulator.
746 xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
747 }
748 return reductionResult;
749}
750
753 vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize,
754 Location loc, PatternRewriter &rewriter) {
755 VectorType sourceType = src.getType();
756 int64_t sourceRank = sourceType.getRank();
757 // Expecting at least a 2D source vector. Leading dimensions (all except the
758 // last two) must be unit.
759 assert(sourceRank >= 2 && "expected at least a 2D source vector");
760 for (int64_t i = 0; i < sourceRank - 2; ++i)
761 assert(sourceType.getShape()[i] == 1 &&
762 "expected leading dimensions to be unit");
763 int64_t rowIdx = sourceRank - 2;
764 int64_t columnIdx = sourceRank - 1;
765 int64_t sourceH = sourceType.getShape()[rowIdx];
766 int64_t sourceW = sourceType.getShape()[columnIdx];
767
768 // Create a constant vector to hold the result of the reduction.
769 TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
770 Value reductionResult = arith::ConstantOp::create(
771 rewriter, loc, acc.getType(),
772 DenseElementsAttr::get(acc.getType(), zeroAttr));
773
774 // nSlices is the number of reduction operations needed to reduce the entire
775 // source vector. For example, if reductionDim is the row dim, we are
776 // reducing across rows, and each slice is a column. So the number of slices
777 // is the number of columns, which is sourceW.
778 int nSlices = (reductionDim == rowIdx) ? sourceW : sourceH;
779
780 // For each slice of the source, extract the slice vector, do a reduction
781 // and, insert the reduced value back to the result vector.
782 int64_t accRank = acc.getType().getRank();
783 for (int i = 0; i < nSlices; ++i) {
784 // Build nD offsets, sizes, and strides. Leading unit dims get
785 // offset=0, size=1. The last two dims are set based on reductionDim.
786 SmallVector<int64_t> sliceOffsets(sourceRank, 0);
787 SmallVector<int64_t> sliceSizes(sourceRank, 1);
788 SmallVector<int64_t> strides(sourceRank, 1);
789 if (reductionDim == columnIdx) {
790 sliceOffsets[rowIdx] = i;
791 sliceSizes[columnIdx] = sourceW;
792 } else {
793 sliceOffsets[columnIdx] = i;
794 sliceSizes[rowIdx] = sourceH;
795 }
796
797 vector::ExtractStridedSliceOp extractOp =
798 vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
799 sliceSizes, strides);
800 int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
801 vector::ShapeCastOp slice = vector::ShapeCastOp::create(
802 rewriter, loc,
803 VectorType::get({nSliceElements}, sourceType.getElementType()),
804 extractOp.getResult());
805
806 SmallVector<int64_t> accIdx(accRank, 0);
807 accIdx[accRank - 1] = i;
808 Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, accIdx);
809 Value fullReduce =
810 xegpu::subgroupReduction(loc, rewriter, slice, kind, reductionSize);
811 fullReduce =
812 vector::makeArithReduction(rewriter, loc, kind, fullReduce, accExtract);
813 reductionResult = vector::InsertOp::create(rewriter, loc, fullReduce,
814 reductionResult, accIdx);
815 }
816 return reductionResult;
817}
818
820 Type type,
821 vector::CombiningKind kind) {
822 auto vecTy = dyn_cast<VectorType>(type);
823 Type elemTy = vecTy ? vecTy.getElementType() : type;
824
825 // Helper to create either a splat vector or scalar constant from an attr.
826 auto makeConst = [&](Attribute scalarAttr) -> Value {
827 if (vecTy)
828 return arith::ConstantOp::create(
829 builder, loc, vecTy, DenseElementsAttr::get(vecTy, scalarAttr));
830 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(scalarAttr));
831 };
832
833 switch (kind) {
834 case vector::CombiningKind::ADD:
835 case vector::CombiningKind::XOR:
836 case vector::CombiningKind::OR:
837 case vector::CombiningKind::MAXUI:
838 return makeConst(builder.getZeroAttr(elemTy));
839
840 case vector::CombiningKind::MUL:
841 case vector::CombiningKind::AND:
842 return makeConst(builder.getOneAttr(elemTy));
843
844 case vector::CombiningKind::MINSI:
845 if (auto intTy = dyn_cast<IntegerType>(elemTy))
846 return makeConst(builder.getIntegerAttr(
847 elemTy, APInt::getSignedMaxValue(intTy.getWidth())));
848 return nullptr;
849
850 case vector::CombiningKind::MINUI:
851 if (auto intTy = dyn_cast<IntegerType>(elemTy))
852 return makeConst(
853 builder.getIntegerAttr(elemTy, APInt::getMaxValue(intTy.getWidth())));
854 return nullptr;
855
856 case vector::CombiningKind::MAXSI:
857 if (auto intTy = dyn_cast<IntegerType>(elemTy))
858 return makeConst(builder.getIntegerAttr(
859 elemTy, APInt::getSignedMinValue(intTy.getWidth())));
860 return nullptr;
861
862 case vector::CombiningKind::MINNUMF:
863 case vector::CombiningKind::MINIMUMF:
864 if (auto floatTy = dyn_cast<FloatType>(elemTy))
865 return makeConst(builder.getFloatAttr(
866 elemTy, APFloat::getInf(floatTy.getFloatSemantics())));
867 return nullptr;
868
869 case vector::CombiningKind::MAXNUMF:
870 case vector::CombiningKind::MAXIMUMF:
871 if (auto floatTy = dyn_cast<FloatType>(elemTy))
872 return makeConst(builder.getFloatAttr(
873 elemTy, APFloat::getInf(floatTy.getFloatSemantics(), true)));
874 return nullptr;
875 }
876 return nullptr;
877}
878
879/// Explicit instantiations
880template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
881 ArrayRef<int> candidateMultiples);
882template int
884 ArrayRef<unsigned> candidateMultiples);
885
886bool xegpu::requirePacked(const xegpu::DistributeLayoutAttr layout) {
887 if (!layout)
888 return false;
889 auto laneData = layout.getEffectiveLaneDataAsInt();
890 if (laneData.size() != 2)
891 return false;
892 return laneData[0] != 1;
893}
894
895bool xegpu::requireTranspose(const xegpu::DistributeLayoutAttr layout,
896 const xegpu::uArch::uArch *uArch) {
897 // Return false for unsupported targets.
898 // TODO: Add more support or move to target info.
899 if (uArch->getName().equals_insensitive("pvc") &&
900 uArch->getName().equals_insensitive("bmg"))
901 return false;
902 if (!layout)
903 return false;
904 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
905 if (laneLayout.size() != 2)
906 return false;
907 return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
908}
909
910// Check if dst shape is an expansion of src shape by inserting unit dimensions.
911// Returns true if all dimensions in src match corresponding dimensions in dst
912// (after skipping unit dimensions), and populates expandedUnitDims with the
913// indices of the unit dimensions in dst that were added (not present in src).
914// Example: src=[2,3], dst=[1,2,3,1] -> true, expandedUnitDims=[0,3]
916 SmallVector<int64_t> &expandedUnitDims) {
917 // All unit dimensions in dst that don't appear in src are the expanded
918 // unit dimensions
919 size_t srcIdx = 0;
920 for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
921 if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
922 srcIdx++;
923 else if (dst[dstIdx] == 1)
924 expandedUnitDims.push_back(dstIdx);
925 else
926 return false;
927 return srcIdx == src.size();
928}
929
930// Checks if dst shape is an expansion of src shape where each dimension in src
931// is split into one or more consecutive dimensions in dst whose product equals
932// the original dimension. Populates splitDimGroups with groups of dst indices
933// that correspond to each src dimension. Example: src=[6,4], dst=[2,3,2,2] ->
934// true
937 SmallVector<SmallVector<int64_t>> &splitDimGroups) {
938 // each dim in src can be mapped to one or more dims in dst whose product
939 // equals to the src dim
940 size_t srcIdx = 0;
941 int64_t accumulatedSize = 1;
942 SmallVector<int64_t> currentDstDims;
943
944 splitDimGroups.clear();
945 for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
946 if (srcIdx >= src.size())
947 return false;
948 accumulatedSize *= dst[dstIdx];
949 currentDstDims.push_back(dstIdx);
950
951 if (accumulatedSize == src[srcIdx]) {
952 // Record the mapping: srcIdx -> currentDstDims
953 splitDimGroups.push_back(currentDstDims);
954 // move to next src dim
955 srcIdx++;
956 accumulatedSize = 1;
957 currentDstDims.clear();
958 } else if (accumulatedSize > src[srcIdx]) {
959 return false;
960 }
961 }
962 return srcIdx == src.size();
963}
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
auto load
xegpu::DistributeLayoutAttr maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout, const OpResult &result, mlir::Operation *owner, const std::string &name)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
TypedAttr getOneAttr(Type type)
Definition Builders.cpp:346
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:576
bool hasAttrOfType(NameT &&name)
Definition Operation.h:601
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:586
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:608
operand_type_range getOperandTypes()
Definition Operation.h:423
result_type_range getResultTypes()
Definition Operation.h:454
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
result_range getOpResults()
Definition Operation.h:446
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
bool requirePacked(const DistributeLayoutAttr layout)
Helper function to check if the layout is packed.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc argumments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requireTranspose(const DistributeLayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
bool matchSplitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< SmallVector< int64_t > > &splitDimGroups)
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type conversion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
Value lowerCrossLaneReductionToShuffles(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter)
Lowers cross-lane reductions to shuffle operations on a 2D vector.
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
SmallVector< OpFoldResult > addElementwise(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with same length.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
virtual int getSubgroupSize() const =0
StringRef getName() const
Definition uArchBase.h:158