MLIR 23.0.0git
XeGPUUtils.cpp
Go to the documentation of this file.
1//===---- XeGPUUtils.cpp - MLIR Utilities for XeGPUOps ------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utility methods for working with the XeGPU dialect.
10//
11//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/Operation.h"
22#include "mlir/IR/ValueRange.h"
25#include "llvm/Support/Casting.h"
26#include "llvm/Support/FormatVariadic.h"
27#include <cstdint>
28#include <numeric>
29
30using namespace mlir;
31
32/// convert ArrayRef<ValueRange> into SmallVector<Value>
35 for (const auto &vals : values)
36 llvm::append_range(result, vals);
37 return result;
38}
39
40FailureOr<VectorType>
41mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
42 auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
43 // It only works for subgroup level layout, which only has lane_layout
44 // and lane_data, and is to distribute a SIMD code into SIMT code.
45 if (!layout || !layout.isForSubgroup())
46 return failure();
47
48 SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
49 SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef());
50 auto tdescShape = tdescTy.getShape();
51 auto elementType = tdescTy.getElementType();
52
53 // compute sgSize by multiply elements of laneLayout
54 // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1]
55 // e.g. for 1D layout, sgSize = laneLayout[0]
56 int64_t sgSize = llvm::product_of(laneLayout);
57
58 // Check if the tensor descriptor shape is distributable.
59 int64_t tensorSize = 1;
60 for (auto [tdescDim, laneDim, laneDataDim] :
61 llvm::zip_equal(tdescShape, laneLayout, laneData)) {
62 assert((tdescDim % (laneDim * laneDataDim) == 0) &&
63 "tensor descriptor shape is not distributable");
64 tensorSize *= tdescDim;
65 }
66 // tensorSize must be adjusted for array_length.
67 tensorSize *= tdescTy.getArrayLength();
68
69 return VectorType::get({tensorSize / sgSize}, elementType);
70}
71
72FailureOr<VectorType>
73mlir::xegpu::getDistributedVectorType(VectorType originalType,
74 xegpu::LayoutAttr layout) {
75 int64_t rank = originalType.getRank();
76 // Distributed vector type is only supported for 1D, 2D and 3D vectors.
77 if (rank < 1 || rank > 3)
78 return failure();
79 ArrayRef<int64_t> shape = originalType.getShape();
80 // arrayLength is 1 for 1D and 2D vectors, and equal to the first dimension
81 // of the 3D vector.
82 int arrayLength = 1;
83 if (rank == 3) {
84 arrayLength = shape[0];
85 shape = shape.drop_front();
86 }
87 auto helperTdescTy = xegpu::TensorDescType::get(
88 shape, originalType.getElementType(), arrayLength,
89 /*boundary_check=*/true,
90 /*memory_space=*/xegpu::MemorySpace::Global, layout);
91 return xegpu::getDistributedVectorType(helperTdescTy);
92}
93
94FailureOr<VectorType>
95xegpu::getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
96 VectorType originalType) {
97 if (!layout)
98 return failure();
99 assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
100 "Expecting a valid layout.");
101
102 int64_t vectorRank = originalType.getRank();
103 int64_t layoutRank = layout.getRank();
104 assert(vectorRank >= layoutRank && "Vector rank must be >= layout rank.");
105
106 // When the vector has more dimensions than the layout, only the trailing
107 // dimensions are distributed. Leading dimensions are preserved as-is.
108 int64_t offset = vectorRank - layoutRank;
109 ArrayRef<int64_t> fullShape = originalType.getShape();
110 SmallVector<int64_t> trailingShape(fullShape.begin() + offset,
111 fullShape.end());
112 auto distributedShapeOrFailure =
113 layout.computeDistributedShape(trailingShape);
114 if (failed(distributedShapeOrFailure))
115 return failure();
116
117 SmallVector<int64_t> resultShape(fullShape.begin(),
118 fullShape.begin() + offset);
119 resultShape.append(distributedShapeOrFailure->begin(),
120 distributedShapeOrFailure->end());
121 return VectorType::get(resultShape, originalType.getElementType());
122}
123
124std::string xegpu::getTemporaryLayoutName(const OpOperand &operand) {
125 const StringRef prefix("layout_operand_");
126 unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
127 return llvm::formatv("{0}{1}", prefix, idx).str();
128}
129
131 const StringRef prefix = "layout_result_";
132 return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
133}
134
135xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
136 if (!value)
137 return nullptr;
138
139 if (auto result = dyn_cast<OpResult>(value)) {
140 Operation *defOp = result.getDefiningOp();
141 assert(defOp && "result must have a defining op");
142
143 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
144 auto layout = anchorOp.getAnchorLayout();
145 return layout;
146 }
147
148 std::string layoutName = getTemporaryLayoutName(result);
149 if (defOp->hasAttr(layoutName)) {
150 auto layout =
151 defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
152 return layout;
153 }
154 }
155
156 if (auto arg = dyn_cast<BlockArgument>(value)) {
157 auto *parentOp = arg.getOwner()->getParentOp();
158 if (auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
159 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
160 if (tiedInit)
161 return getTemporaryLayout(*tiedInit);
162 }
163 }
164
165 if (auto tdescTy =
166 dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
167 return tdescTy.getLayoutAttr();
168
169 return nullptr;
170}
171xegpu::DistributeLayoutAttr
173 Operation *op = opr.getOwner();
174 unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
175
176 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
177 if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
178 if (idx == 0) {
179 return dpasOp.getLayoutAAttr();
180 } else if (idx == 1) {
181 return dpasOp.getLayoutBAttr();
182 } else if (idx == 2) {
183 return dpasOp.getLayoutCdAttr();
184 }
185 }
186 if (auto dpasMxOp = dyn_cast<xegpu::DpasMxOp>(op)) {
187 // DpasMxOp has operands: a, b, optional acc, optional scale_a, optional
188 // scale_b
189 unsigned currentIdx = 0;
190
191 if (idx == currentIdx++)
192 return dpasMxOp.getLayoutAAttr();
193
194 if (idx == currentIdx++)
195 return dpasMxOp.getLayoutBAttr();
196
197 if (dpasMxOp.getAcc())
198 if (idx == currentIdx++)
199 return dpasMxOp.getLayoutCdAttr();
200
201 if (dpasMxOp.getScaleA())
202 if (idx == currentIdx++)
203 return dpasMxOp.getLayoutAScaleAttr();
204
205 if (dpasMxOp.getScaleB())
206 if (idx == currentIdx++)
207 return dpasMxOp.getLayoutBScaleAttr();
208
209 return nullptr;
210 }
211 if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
212 return convertOp.getInputLayoutAttr();
213 }
214 auto layout = anchorOp.getAnchorLayout();
215
216 if (idx == 0)
217 return layout;
218
219 // For StoreNdOp and StoreMatrixOp,
220 // the layout is valid for the first two operands: value and memref/tdesc.
221 if (isa<xegpu::StoreNdOp, xegpu::StoreMatrixOp>(op) && (idx < 2))
222 return layout;
223
224 if (isa<xegpu::StoreScatterOp>(op)) {
225 xegpu::StoreScatterOp store(op);
226 int chunkSize = store.getChunkSize().value_or(1);
227 if (layout && idx >= 2 && chunkSize > 1)
228 return layout.dropDims(llvm::to_vector(
229 llvm::seq<int64_t>(layout.getRank() - 1, layout.getRank())));
230 return layout;
231 }
232 if (isa<xegpu::LoadGatherOp>(op)) {
233 xegpu::LoadGatherOp load(op);
234 int chunkSize = load.getChunkSize().value_or(1);
235 if (layout && idx >= 1 && chunkSize > 1)
236 return layout.dropDims(llvm::to_vector(
237 llvm::seq<int64_t>(layout.getRank() - 1, layout.getRank())));
238 return layout;
239 }
240 }
241
242 std::string layoutName = xegpu::getTemporaryLayoutName(opr);
243 if (op->hasAttr(layoutName)) {
244 auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
245 return layout;
246 }
247
248 return nullptr;
249}
250
251// Returns the permanent layout attribute for the given result if it's
252// available on the defining op. Otherwise returns the provided layout.
253xegpu::DistributeLayoutAttr
254maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
255 const OpResult &result, mlir::Operation *owner,
256 const std::string &name) {
257 xegpu::DistributeLayoutAttr candidate = layout;
258
259 if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) {
260 if (auto perm = loadOp.getLayoutAttr())
261 candidate = perm;
262 }
263
264 return candidate;
265}
266
267// Returns the permanent layout attribute for the given operand if it's
268// available on the defining op. Otherwise returns the provided layout.
269xegpu::DistributeLayoutAttr
270maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
271 const OpOperand &operand, mlir::Operation *owner,
272 const std::string &name) {
273 xegpu::DistributeLayoutAttr candidate = layout;
274 unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
275
276 if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) {
277 if (idx == 0) {
278 if (auto perm = storeOp.getLayoutAttr())
279 candidate = perm;
280 }
281 }
282
283 return candidate;
284}
285
286// TODO-LayoutRefactor: Remove this function after replacing use
287// with setTemporaryLayout or setAnchorLayout
289 const mlir::OpResult &result,
290 const mlir::xegpu::DistributeLayoutAttr layout) {
291 Operation *owner = result.getOwner();
292
293 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
294 if (anchorOp.getAnchorLayout() == layout)
295 return;
296 anchorOp.setAnchorLayout(layout);
297 return;
298 }
299
300 std::string name = xegpu::getTemporaryLayoutName(result);
301 if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
302 return;
303 }
304 if (layout) {
305 owner->setAttr(name, layout);
306 }
307}
308
309// TODO-LayoutRefactor: Remove this function after replacing use
310// with setTemporaryLayout or setAnchorLayout
312 const DistributeLayoutAttr layout) {
313 Operation *owner = operand.getOwner();
314 unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
315
316 if (!layout) {
317 return;
318 }
319 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
320 if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
321 if (idx == 0) {
322 return dpasOp.setLayoutAAttr(layout);
323 } else if (idx == 1) {
324 return dpasOp.setLayoutBAttr(layout);
325 } else if (idx == 2) {
326 return dpasOp.setLayoutCdAttr(layout);
327 }
328 }
329 if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(owner)) {
330 return convertOp.setInputLayoutAttr(layout);
331 }
332
333 // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
334 // the layout is valid for the first two operands: value and memref/tdesc.
335 // For other operations, the layout applies to the first operand only.
336 if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
337 owner)) {
338 if (idx < 2) {
339 anchorOp.setAnchorLayout(layout);
340 }
341 } else {
342 if (idx == 0) {
343 anchorOp.setAnchorLayout(layout);
344 }
345 }
346 }
347
348 std::string name = xegpu::getTemporaryLayoutName(operand);
349 if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
350 return;
351 }
352 if (layout) {
353 owner->setAttr(name, layout);
354 }
355}
356
357template <typename T, typename>
358xegpu::DistributeLayoutAttr
359xegpu::getTemporaryLayout(const T &operandOrResult) {
360 Operation *op = operandOrResult.getOwner();
361
362 std::string layoutName = xegpu::getTemporaryLayoutName(operandOrResult);
363 if (op->hasAttr(layoutName)) {
364 auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
365 return layout;
366 }
367
368 return nullptr;
369}
370
371template xegpu::DistributeLayoutAttr
373template xegpu::DistributeLayoutAttr
375
376template <typename T, typename>
377void xegpu::setTemporaryLayout(const T &operandOrResult,
378 const xegpu::DistributeLayoutAttr layout) {
379 Operation *owner = operandOrResult.getOwner();
380 std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
381 if (owner->hasAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
382 return;
383 }
384 if (layout) {
385 owner->setAttr(name, layout);
386 }
387}
388
390 const mlir::OpResult &result,
391 const mlir::xegpu::DistributeLayoutAttr layout);
392
394 const mlir::OpOperand &operand,
395 const mlir::xegpu::DistributeLayoutAttr layout);
396
400 auto vecTy = dyn_cast<VectorType>(value.getType());
401 if (!vecTy)
402 return {value};
403
404 ArrayRef<int64_t> srcShape = vecTy.getShape();
405 if (!computeShapeRatio(srcShape, shape))
406 return {value};
407
408 int64_t srcShapeRank = srcShape.size();
409 int64_t targetShapeRank = shape.size();
410
411 SmallVector<int64_t> adjustedTargetShape(srcShape.size());
412 int64_t rankDiff = srcShapeRank - targetShapeRank;
413 std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
414 1);
415 llvm::copy(shape, adjustedTargetShape.begin() + rankDiff);
416
418 for (SmallVector<int64_t> offsets :
419 StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
420 SmallVector<int64_t> staticStrides(offsets.size(), 1);
421 Value slice = vector::ExtractStridedSliceOp::create(
422 builder, loc, value, offsets, adjustedTargetShape, staticStrides);
423
424 // Reshape to remove leading unit dims if needed
425 if (srcShapeRank > targetShapeRank) {
426 auto targetTy = VectorType::get(shape, vecTy.getElementType());
427 slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
428 }
429 result.push_back(slice);
430 }
431
432 return result;
433}
434
436 ValueRange values,
438 VectorType inputTy = dyn_cast<VectorType>(values[0].getType());
439 assert(llvm::all_of(values.getTypes(),
440 [&](Type type) { return type == inputTy; }) &&
441 "values must be of the same VectorType");
442
443 Type elemTy = inputTy.getElementType();
444 ArrayRef<int64_t> tileShape = inputTy.getShape();
445
446 VectorType resultTy = VectorType::get(shape, elemTy);
447 auto zeroAttr = builder.getZeroAttr(elemTy);
448 Value result = arith::ConstantOp::create(
449 builder, loc, resultTy, DenseElementsAttr::get(resultTy, zeroAttr));
450
451 for (auto [src, offsets] :
452 llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
453 SmallVector<int64_t> staticStrides(tileShape.size(), 1);
454 result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
455 offsets, staticStrides);
456 }
457 return result;
458}
459
461 Operation *op, TypeConverter converter) {
462 MLIRContext *context = op->getContext();
463
464 auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
465 Location loc) -> Value {
466 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
467 .getResult(0);
468 };
469
470 { // convert VectorType to RankedTensorType for SCF Structural ops
471 TypeConverter converter;
472 converter.addConversion([](Type type) -> Type { return type; });
473 converter.addConversion([](VectorType type) -> Type {
474 return RankedTensorType::get(type.getShape(), type.getElementType());
475 });
476 converter.addSourceMaterialization(materializeCast);
477 converter.addTargetMaterialization(materializeCast);
478
479 mlir::ConversionTarget target(*context);
480 target.addLegalOp<UnrealizedConversionCastOp>();
481
482 mlir::RewritePatternSet patterns(context);
484 target);
485 (void)mlir::applyPartialConversion(op, target, std::move(patterns));
486 }
487
488 { // propagate the layout attribute to RankedTensorType by checking
489 // BuiltInUnrealizedCastOps
490 // for VectorType to RankedTensorType cast.
491 op->walk([](UnrealizedConversionCastOp castOp) {
492 if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
493 return WalkResult::skip();
494
495 Value input = castOp.getInputs()[0];
496 Value result = castOp.getResults()[0];
497 auto inputTy = dyn_cast<VectorType>(input.getType());
498 auto resultTy = dyn_cast<RankedTensorType>(result.getType());
499
500 // Only look at ops casting from VectorType to RankedTensorType
501 if (!inputTy || !resultTy)
502 return WalkResult::skip();
503
504 xegpu::DistributeLayoutAttr layout =
506 if (!layout)
507 return WalkResult::skip();
508
509 RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
510 result.setType(newTy);
511
512 // update the arguments if user is a LoopLike op.
513 for (OpOperand &use : result.getUses()) {
514 if (auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
515 BlockArgument arg = loop.getTiedLoopRegionIterArg(&use);
516 arg.setType(newTy);
517 }
518 // whileOp has two regions, the BlockArgument of the after region
519 // is not exposed by LoopLikeOpInterface
520 if (auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
521 unsigned idx = use.getOperandNumber();
522 BlockArgument arg = whileOp.getAfterArguments()[idx];
523 arg.setType(newTy);
524 }
525 }
526 return WalkResult::advance();
527 });
528
529 // using yieldOp as anchor to update the result type of its ParentOp
530 op->walk([](scf::YieldOp yieldOp) {
531 Operation *parentOp = yieldOp->getParentOp();
532 for (OpResult r : parentOp->getOpResults()) {
533 unsigned idx = r.getResultNumber();
534 Type resultTy = r.getType();
535 Type yieldTy = yieldOp.getResults()[idx].getType();
536 if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
537 r.setType(yieldTy);
538 }
539 });
540 }
541
542 { // perform the conversion from RankedTensorType to VectorType based on the
543 // DistributeLayoutAttr
544
545 // Handle the UnrealizedConversionCastOp introduced by the first step.
546 // For vector->RankedTensorType, it will simply forward the inputs.
547 // For RankedTensorType->vector, it will update the inputs with the
548 // one from the adaptor.
549 class UnrealizedConversionCastOpPattern
550 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
551 using OpConversionPattern<
552 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
553
554 mlir::LogicalResult
555 matchAndRewrite(mlir::UnrealizedConversionCastOp op,
556 OneToNOpAdaptor adaptor,
557 ConversionPatternRewriter &rewriter) const override {
558 auto inputs = op.getOperands();
559 auto outputs = op.getOutputs();
560
561 if (inputs.size() != 1 || outputs.size() != 1)
562 return failure();
563
564 auto inputTy = inputs[0].getType();
565 auto outputTy = outputs[0].getType();
566
567 if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) {
568 rewriter.replaceOpWithMultiple(op, adaptor.getInputs());
569 return success();
570 }
571
572 if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
573 SmallVector<Value> values = xegpu::flattenValues(adaptor.getInputs());
574 auto newOp = UnrealizedConversionCastOp::create(rewriter, op.getLoc(),
575 outputTy, values);
576 rewriter.replaceOp(op, newOp);
577 return success();
578 }
579 return failure();
580 }
581 };
582
583 converter.addSourceMaterialization(materializeCast);
584 converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
585 ValueRange inputs, Location loc) {
586 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
587 .getResults();
588 });
589
590 mlir::ConversionTarget target(*context);
591 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
592 [](UnrealizedConversionCastOp op) {
593 auto isTensorTy = [](Type type) {
594 return isa<RankedTensorType>(type);
595 };
596 return llvm::none_of(op->getOperandTypes(), isTensorTy) &&
597 llvm::none_of(op->getResultTypes(), isTensorTy);
598 });
599 mlir::RewritePatternSet patterns(context);
600 patterns.insert<UnrealizedConversionCastOpPattern>(context);
602 target);
603 (void)mlir::applyPartialConversion(op, target, std::move(patterns));
604 }
605}
606
607std::optional<std::string> xegpu::getChipStr(Operation *op) {
608 auto gpuModuleOp = op->getParentOfType<gpu::GPUModuleOp>();
609
610 if (!gpuModuleOp)
611 return std::nullopt;
612
613 auto targetAttrs = gpuModuleOp.getTargets();
614 if (targetAttrs) {
615 for (auto &attr : *targetAttrs) {
616 auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr);
617 if (xevmAttr)
618 return xevmAttr.getChip().str();
619 }
620 }
621
622 return std::nullopt;
623}
624
625/// Generates element-wise addition ops of two arrays with same length.
627 Location loc,
630 assert(lhs.size() == rhs.size() && "lhs and rhs must have the same size");
632 for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
633 auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
634 auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
635 results.push_back(builder.createOrFold<arith::AddIOp>(loc, lval, rval));
636 }
637 return results;
638}
639
640/// Generates element-wise addition ops of two arrays with automatic alignment.
641/// When the input arrays have different sizes, the shorter array is
642/// right-aligned with the longer array, and the unmatched leading elements from
643/// the longer array are preserved unchanged. This is commonly used for offset
644/// computation where higher-dimensional offsets need to be added to
645/// lower-dimensional adjustments.
646///
647/// Example:
648/// lhs = [l1, l2, l3], rhs = [r1, r2]
649/// Result: [l1, l2+r1, l3+r2]
654 // ensure a is longer than b
655 ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs;
656 ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
657 SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
658 a = a.slice(a.size() - b.size());
659 results.append(addElementwise(builder, loc, a, b));
660 return results;
661}
662
663template <typename T>
665 ArrayRef<T> candidateMultiples) {
666 static_assert(std::is_integral<T>::value, "T must be an integer type");
667 int largest = -1;
668 SmallVector<T> multiples = {1};
669 if (!candidateMultiples.empty())
670 multiples =
671 SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
672 for (T candidate : candidates) {
673 for (T multiple : multiples) {
674 int value = static_cast<int>(candidate * multiple);
675 if (value != 0 && dim % value == 0 && value > largest)
676 largest = value;
677 }
678 }
679 return largest;
680}
681
683 vector::CombiningKind kind, uint32_t size) {
684 // First reduce on a single thread to get per lane reduction value.
685 Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
686 // Parallel reduction using butterfly shuffles.
687 for (uint64_t i = 1; i < size; i <<= 1) {
688 Value shuffled =
689 gpu::ShuffleOp::create(builder, loc, laneVal, i, /** width = **/ size,
690 /** mode = **/ gpu::ShuffleMode::XOR)
691 .getShuffleResult();
692 laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
693 }
694 return laneVal;
695}
696
699 vector::CombiningKind kind,
700 int64_t reductionDim, Location loc,
701 PatternRewriter &rewriter) {
702 VectorType sourceType = src.getType();
703 int64_t sourceRank = sourceType.getRank();
704 // Expecting at least a 2D source vector. Leading dimensions (all except the
705 // last two) must be unit.
706 assert(sourceRank >= 2 && "expected at least a 2D source vector");
707 for (int64_t i = 0; i < sourceRank - 2; ++i)
708 assert(sourceType.getShape()[i] == 1 &&
709 "expected leading dimensions to be unit");
710 int64_t rowIdx = sourceRank - 2;
711 int64_t columnIdx = sourceRank - 1;
712 int64_t sourceH = sourceType.getShape()[rowIdx];
713 int64_t sourceW = sourceType.getShape()[columnIdx];
714 int nSlices = (reductionDim == rowIdx) ? sourceW : sourceH;
715 // Create a constant vector to hold the result of the reduction.
716 TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
717 Value reductionResult = arith::ConstantOp::create(
718 rewriter, loc, acc.getType(),
719 DenseElementsAttr::get(acc.getType(), zeroAttr));
720 // TODO: Remove these get/setTemporaryLayout calls after we deprecate the old
721 // XeGPUSubgroupDistribute pass.
722 auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
723 auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
724 // Reduction result should have the same layout as the accumulator.
725 xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
726 // For each slice of the source, extract the slice vector, do a reduction
727 // and, insert the reduced value back to the result vector.
728 int64_t accRank = acc.getType().getRank();
729 for (int i = 0; i < nSlices; ++i) {
730 // Build nD offsets, sizes, and strides. Leading unit dims get
731 // offset=0, size=1. The last two dims are set based on reductionDim.
732 SmallVector<int64_t> sliceOffsets(sourceRank, 0);
733 SmallVector<int64_t> sliceSizes(sourceRank, 1);
734 SmallVector<int64_t> strides(sourceRank, 1);
735 if (reductionDim == columnIdx) {
736 sliceOffsets[rowIdx] = i;
737 sliceSizes[columnIdx] = sourceW;
738 } else {
739 sliceOffsets[columnIdx] = i;
740 sliceSizes[rowIdx] = sourceH;
741 }
742
743 vector::ExtractStridedSliceOp extractOp =
744 vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
745 sliceSizes, strides);
746 // Extract strided slice has the same layout as src.
747 xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
748
749 int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
750
751 vector::ShapeCastOp slice = vector::ShapeCastOp::create(
752 rewriter, loc,
753 VectorType::get({nSliceElements}, sourceType.getElementType()),
754 extractOp.getResult());
755
756 // Shape cast output has the same layout as the accumulator. Shape cast
757 // source has the same layout as the original reduction source.
758 xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
759 xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
760 // Extract and reduction results in scalars, so no result layout is needed.
761 // Build multi-dim index into acc (sourceRank-1 dims, i.e. source shape with
762 // the reduction dim removed). Leading unit dims get index 0.
763 SmallVector<int64_t> accIdx(accRank, 0);
764 accIdx[accRank - 1] = i;
765 Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, accIdx);
766 Value reduction = vector::ReductionOp::create(
767 rewriter, loc, kind, slice.getResult(), accExtract);
768 reductionResult = vector::InsertOp::create(rewriter, loc, reduction,
769 reductionResult, accIdx);
770 // Insert op should have the same layout as the accumulator.
771 xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
772 }
773 return reductionResult;
774}
775
778 vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize,
779 Location loc, PatternRewriter &rewriter) {
780 VectorType sourceType = src.getType();
781 int64_t sourceRank = sourceType.getRank();
782 // Expecting at least a 2D source vector. Leading dimensions (all except the
783 // last two) must be unit.
784 assert(sourceRank >= 2 && "expected at least a 2D source vector");
785 for (int64_t i = 0; i < sourceRank - 2; ++i)
786 assert(sourceType.getShape()[i] == 1 &&
787 "expected leading dimensions to be unit");
788 int64_t rowIdx = sourceRank - 2;
789 int64_t columnIdx = sourceRank - 1;
790 int64_t sourceH = sourceType.getShape()[rowIdx];
791 int64_t sourceW = sourceType.getShape()[columnIdx];
792
793 // Create a constant vector to hold the result of the reduction.
794 TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
795 Value reductionResult = arith::ConstantOp::create(
796 rewriter, loc, acc.getType(),
797 DenseElementsAttr::get(acc.getType(), zeroAttr));
798
799 // nSlices is the number of reduction operations needed to reduce the entire
800 // source vector. For example, if reductionDim is the row dim, we are
801 // reducing across rows, and each slice is a column. So the number of slices
802 // is the number of columns, which is sourceW.
803 int nSlices = (reductionDim == rowIdx) ? sourceW : sourceH;
804
805 // For each slice of the source, extract the slice vector, do a reduction
806 // and, insert the reduced value back to the result vector.
807 int64_t accRank = acc.getType().getRank();
808 for (int i = 0; i < nSlices; ++i) {
809 // Build nD offsets, sizes, and strides. Leading unit dims get
810 // offset=0, size=1. The last two dims are set based on reductionDim.
811 SmallVector<int64_t> sliceOffsets(sourceRank, 0);
812 SmallVector<int64_t> sliceSizes(sourceRank, 1);
813 SmallVector<int64_t> strides(sourceRank, 1);
814 if (reductionDim == columnIdx) {
815 sliceOffsets[rowIdx] = i;
816 sliceSizes[columnIdx] = sourceW;
817 } else {
818 sliceOffsets[columnIdx] = i;
819 sliceSizes[rowIdx] = sourceH;
820 }
821
822 vector::ExtractStridedSliceOp extractOp =
823 vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
824 sliceSizes, strides);
825 int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
826 vector::ShapeCastOp slice = vector::ShapeCastOp::create(
827 rewriter, loc,
828 VectorType::get({nSliceElements}, sourceType.getElementType()),
829 extractOp.getResult());
830
831 SmallVector<int64_t> accIdx(accRank, 0);
832 accIdx[accRank - 1] = i;
833 Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, accIdx);
834 Value fullReduce =
835 xegpu::subgroupReduction(loc, rewriter, slice, kind, reductionSize);
836 fullReduce =
837 vector::makeArithReduction(rewriter, loc, kind, fullReduce, accExtract);
838 reductionResult = vector::InsertOp::create(rewriter, loc, fullReduce,
839 reductionResult, accIdx);
840 }
841 return reductionResult;
842}
843
845 Type type,
846 vector::CombiningKind kind) {
847 auto vecTy = dyn_cast<VectorType>(type);
848 Type elemTy = vecTy ? vecTy.getElementType() : type;
849
850 // Helper to create either a splat vector or scalar constant from an attr.
851 auto makeConst = [&](Attribute scalarAttr) -> Value {
852 if (vecTy)
853 return arith::ConstantOp::create(
854 builder, loc, vecTy, DenseElementsAttr::get(vecTy, scalarAttr));
855 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(scalarAttr));
856 };
857
858 switch (kind) {
859 case vector::CombiningKind::ADD:
860 case vector::CombiningKind::XOR:
861 case vector::CombiningKind::OR:
862 case vector::CombiningKind::MAXUI:
863 return makeConst(builder.getZeroAttr(elemTy));
864
865 case vector::CombiningKind::MUL:
866 case vector::CombiningKind::AND:
867 return makeConst(builder.getOneAttr(elemTy));
868
869 case vector::CombiningKind::MINSI:
870 if (auto intTy = dyn_cast<IntegerType>(elemTy))
871 return makeConst(builder.getIntegerAttr(
872 elemTy, APInt::getSignedMaxValue(intTy.getWidth())));
873 return nullptr;
874
875 case vector::CombiningKind::MINUI:
876 if (auto intTy = dyn_cast<IntegerType>(elemTy))
877 return makeConst(
878 builder.getIntegerAttr(elemTy, APInt::getMaxValue(intTy.getWidth())));
879 return nullptr;
880
881 case vector::CombiningKind::MAXSI:
882 if (auto intTy = dyn_cast<IntegerType>(elemTy))
883 return makeConst(builder.getIntegerAttr(
884 elemTy, APInt::getSignedMinValue(intTy.getWidth())));
885 return nullptr;
886
887 case vector::CombiningKind::MINNUMF:
888 case vector::CombiningKind::MINIMUMF:
889 if (auto floatTy = dyn_cast<FloatType>(elemTy))
890 return makeConst(builder.getFloatAttr(
891 elemTy, APFloat::getInf(floatTy.getFloatSemantics())));
892 return nullptr;
893
894 case vector::CombiningKind::MAXNUMF:
895 case vector::CombiningKind::MAXIMUMF:
896 if (auto floatTy = dyn_cast<FloatType>(elemTy))
897 return makeConst(builder.getFloatAttr(
898 elemTy, APFloat::getInf(floatTy.getFloatSemantics(), true)));
899 return nullptr;
900 }
901 return nullptr;
902}
903
904/// Explicit instantiations
905template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
906 ArrayRef<int> candidateMultiples);
907template int
909 ArrayRef<unsigned> candidateMultiples);
910
911bool xegpu::requirePacked(const xegpu::DistributeLayoutAttr layout) {
912 if (!layout)
913 return false;
914 auto laneData = layout.getEffectiveLaneDataAsInt();
915 if (laneData.size() != 2)
916 return false;
917 return laneData[0] != 1;
918}
919
920bool xegpu::requireTranspose(const xegpu::DistributeLayoutAttr layout,
921 const xegpu::uArch::uArch *uArch) {
922 // Return false for unsupported targets.
923 // TODO: Add more support or move to target info.
924 if (uArch->getName().equals_insensitive("pvc") &&
925 uArch->getName().equals_insensitive("bmg") &&
926 uArch->getName().equals_insensitive("cri"))
927 return false;
928 if (!layout)
929 return false;
930 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
931 if (laneLayout.size() != 2)
932 return false;
933 return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
934}
935
936// Check if dst shape is an expansion of src shape by inserting unit dimensions.
937// Returns true if all dimensions in src match corresponding dimensions in dst
938// (after skipping unit dimensions), and populates expandedUnitDims with the
939// indices of the unit dimensions in dst that were added (not present in src).
940// Example: src=[2,3], dst=[1,2,3,1] -> true, expandedUnitDims=[0,3]
942 SmallVector<int64_t> &expandedUnitDims) {
943 // All unit dimensions in dst that don't appear in src are the expanded
944 // unit dimensions
945 size_t srcIdx = 0;
946 for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
947 if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
948 srcIdx++;
949 else if (dst[dstIdx] == 1)
950 expandedUnitDims.push_back(dstIdx);
951 else
952 return false;
953 return srcIdx == src.size();
954}
955
956// Checks if dst shape is an expansion of src shape where each dimension in src
957// is split into one or more consecutive dimensions in dst whose product equals
958// the original dimension. Populates splitDimGroups with groups of dst indices
959// that correspond to each src dimension. Example: src=[6,4], dst=[2,3,2,2] ->
960// true
963 SmallVector<SmallVector<int64_t>> &splitDimGroups) {
964 // each dim in src can be mapped to one or more dims in dst whose product
965 // equals to the src dim
966 size_t srcIdx = 0;
967 int64_t accumulatedSize = 1;
968 SmallVector<int64_t> currentDstDims;
969
970 splitDimGroups.clear();
971 for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
972 if (srcIdx >= src.size())
973 return false;
974 accumulatedSize *= dst[dstIdx];
975 currentDstDims.push_back(dstIdx);
976
977 if (accumulatedSize == src[srcIdx]) {
978 // Record the mapping: srcIdx -> currentDstDims
979 splitDimGroups.push_back(currentDstDims);
980 // move to next src dim
981 srcIdx++;
982 accumulatedSize = 1;
983 currentDstDims.clear();
984 } else if (accumulatedSize > src[srcIdx]) {
985 return false;
986 }
987 }
988 return srcIdx == src.size();
989}
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
auto load
xegpu::DistributeLayoutAttr maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout, const OpResult &result, mlir::Operation *owner, const std::string &name)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
TypedAttr getOneAttr(Type type)
Definition Builders.cpp:346
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:576
bool hasAttrOfType(NameT &&name)
Definition Operation.h:601
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:586
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:608
operand_type_range getOperandTypes()
Definition Operation.h:423
result_type_range getResultTypes()
Definition Operation.h:454
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
result_range getOpResults()
Definition Operation.h:446
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
bool requirePacked(const DistributeLayoutAttr layout)
Helper function to check if the layout is packed.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc argumments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requireTranspose(const DistributeLayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
bool matchSplitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< SmallVector< int64_t > > &splitDimGroups)
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type conversion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
Value lowerCrossLaneReductionToShuffles(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter)
Lowers cross-lane reductions to shuffle operations on a 2D vector.
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
SmallVector< OpFoldResult > addElementwise(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with same length.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
virtual int getSubgroupSize() const =0
StringRef getName() const
Definition uArchBase.h:163