MLIR 23.0.0git
XeGPUUtils.cpp
Go to the documentation of this file.
1//===---- XeGPUUtils.cpp - MLIR Utilities for XeGPUOps ------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utility methods for working with the XeGPU dialect.
10//
11//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/Operation.h"
22#include "mlir/IR/ValueRange.h"
25#include "llvm/Support/Casting.h"
26#include "llvm/Support/FormatVariadic.h"
27#include <cstdint>
28#include <numeric>
29
30using namespace mlir;
31
32/// convert ArrayRef<ValueRange> into SmallVector<Value>
35 for (const auto &vals : values)
36 llvm::append_range(result, vals);
37 return result;
38}
39
40FailureOr<VectorType>
41mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
42 auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
43 // It only works for subgroup level layout, which only has lane_layout
44 // and lane_data, and is to distribute a SIMD code into SIMT code.
45 if (!layout || !layout.isForSubgroup())
46 return failure();
47
48 SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
49 SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef());
50 auto tdescShape = tdescTy.getShape();
51 auto elementType = tdescTy.getElementType();
52
53 // compute sgSize by multiply elements of laneLayout
54 // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1]
55 // e.g. for 1D layout, sgSize = laneLayout[0]
56 int64_t sgSize = llvm::product_of(laneLayout);
57
58 // Check if the tensor descriptor shape is distributable.
59 int64_t tensorSize = 1;
60 for (auto [tdescDim, laneDim, laneDataDim] :
61 llvm::zip_equal(tdescShape, laneLayout, laneData)) {
62 assert((tdescDim % (laneDim * laneDataDim) == 0) &&
63 "tensor descriptor shape is not distributable");
64 tensorSize *= tdescDim;
65 }
66 // tensorSize must be adjusted for array_length.
67 tensorSize *= tdescTy.getArrayLength();
68
69 return VectorType::get({tensorSize / sgSize}, elementType);
70}
71
72FailureOr<VectorType>
73mlir::xegpu::getDistributedVectorType(VectorType originalType,
74 xegpu::LayoutAttr layout) {
75 int64_t rank = originalType.getRank();
76 // Distributed vector type is only supported for 1D, 2D and 3D vectors.
77 if (rank < 1 || rank > 3)
78 return failure();
79 ArrayRef<int64_t> shape = originalType.getShape();
80 // arrayLength is 1 for 1D and 2D vectors, and equal to the first dimension
81 // of the 3D vector.
82 int arrayLength = 1;
83 if (rank == 3) {
84 arrayLength = shape[0];
85 shape = shape.drop_front();
86 }
87 auto helperTdescTy = xegpu::TensorDescType::get(
88 shape, originalType.getElementType(), arrayLength,
89 /*boundary_check=*/true,
90 /*memory_space=*/xegpu::MemorySpace::Global, layout);
91 return xegpu::getDistributedVectorType(helperTdescTy);
92}
93
94FailureOr<VectorType>
95xegpu::getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
96 VectorType originalType) {
97 if (!layout)
98 return failure();
99 assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
100 "Expecting a valid layout.");
101
102 int64_t vectorRank = originalType.getRank();
103 int64_t layoutRank = layout.getRank();
104 assert(vectorRank >= layoutRank && "Vector rank must be >= layout rank.");
105
106 // When the vector has more dimensions than the layout, only the trailing
107 // dimensions are distributed. Leading dimensions are preserved as-is.
108 int64_t offset = vectorRank - layoutRank;
109 ArrayRef<int64_t> fullShape = originalType.getShape();
110 SmallVector<int64_t> trailingShape(fullShape.begin() + offset,
111 fullShape.end());
112 auto distributedShapeOrFailure =
113 layout.computeDistributedShape(trailingShape);
114 if (failed(distributedShapeOrFailure))
115 return failure();
116
117 SmallVector<int64_t> resultShape(fullShape.begin(),
118 fullShape.begin() + offset);
119 resultShape.append(distributedShapeOrFailure->begin(),
120 distributedShapeOrFailure->end());
121 return VectorType::get(resultShape, originalType.getElementType());
122}
123
124std::string xegpu::getTemporaryLayoutName(const OpOperand &operand) {
125 const StringRef prefix("layout_operand_");
126 unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
127 return llvm::formatv("{0}{1}", prefix, idx).str();
128}
129
131 const StringRef prefix = "layout_result_";
132 return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
133}
134
135xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
136 if (!value)
137 return nullptr;
138
139 if (auto result = dyn_cast<OpResult>(value)) {
140 Operation *defOp = result.getDefiningOp();
141 assert(defOp && "result must have a defining op");
142
143 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
144 auto layout = anchorOp.getAnchorLayout();
145 return layout;
146 }
147
148 std::string layoutName = getTemporaryLayoutName(result);
149 if (defOp->hasAttr(layoutName)) {
150 auto layout =
151 defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
152 return layout;
153 }
154 }
155
156 if (auto arg = dyn_cast<BlockArgument>(value)) {
157 auto *parentOp = arg.getOwner()->getParentOp();
158 if (auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
159 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
160 if (tiedInit)
161 return getTemporaryLayout(*tiedInit);
162 }
163 }
164
165 if (auto tdescTy =
166 dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
167 return tdescTy.getLayoutAttr();
168
169 return nullptr;
170}
171xegpu::DistributeLayoutAttr
173 Operation *op = opr.getOwner();
174 unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
175
176 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
177 if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
178 if (idx == 0) {
179 return dpasOp.getLayoutAAttr();
180 } else if (idx == 1) {
181 return dpasOp.getLayoutBAttr();
182 } else if (idx == 2) {
183 return dpasOp.getLayoutCdAttr();
184 }
185 }
186 if (auto dpasMxOp = dyn_cast<xegpu::DpasMxOp>(op)) {
187 // DpasMxOp has operands: a, b, optional acc, optional scale_a, optional
188 // scale_b
189 unsigned currentIdx = 0;
190
191 if (idx == currentIdx++)
192 return dpasMxOp.getLayoutAAttr();
193
194 if (idx == currentIdx++)
195 return dpasMxOp.getLayoutBAttr();
196
197 if (dpasMxOp.getAcc())
198 if (idx == currentIdx++)
199 return dpasMxOp.getLayoutCdAttr();
200
201 if (dpasMxOp.getScaleA())
202 if (idx == currentIdx++)
203 return dpasMxOp.getLayoutAScaleAttr();
204
205 if (dpasMxOp.getScaleB())
206 if (idx == currentIdx++)
207 return dpasMxOp.getLayoutBScaleAttr();
208
209 return nullptr;
210 }
211 if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
212 return convertOp.getInputLayoutAttr();
213 }
214 auto layout = anchorOp.getAnchorLayout();
215
216 if (idx == 0)
217 return layout;
218
219 // For StoreNdOp and StoreMatrixOp,
220 // the layout is valid for the first two operands: value and memref/tdesc.
221 if (isa<xegpu::StoreNdOp, xegpu::StoreMatrixOp>(op) && (idx < 2))
222 return layout;
223
224 if (isa<xegpu::StoreScatterOp>(op)) {
225 xegpu::StoreScatterOp store(op);
226 int chunkSize = store.getChunkSize().value_or(1);
227 if (layout && idx >= 2 && chunkSize > 1)
228 return layout.dropDims(llvm::to_vector(
229 llvm::seq<int64_t>(layout.getRank() - 1, layout.getRank())));
230 return layout;
231 }
232 if (isa<xegpu::LoadGatherOp>(op)) {
233 xegpu::LoadGatherOp load(op);
234 int chunkSize = load.getChunkSize().value_or(1);
235 if (layout && idx >= 1 && chunkSize > 1)
236 return layout.dropDims(llvm::to_vector(
237 llvm::seq<int64_t>(layout.getRank() - 1, layout.getRank())));
238 return layout;
239 }
240 }
241
242 std::string layoutName = xegpu::getTemporaryLayoutName(opr);
243 if (op->hasAttr(layoutName)) {
244 auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
245 return layout;
246 }
247
248 return nullptr;
249}
250
251// Returns the permanent layout attribute for the given result if it's
252// available on the defining op. Otherwise returns the provided layout.
253xegpu::DistributeLayoutAttr
254maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
255 const OpResult &result, mlir::Operation *owner,
256 const std::string &name) {
257 xegpu::DistributeLayoutAttr candidate = layout;
258
259 if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) {
260 if (auto perm = loadOp.getLayoutAttr())
261 candidate = perm;
262 }
263
264 return candidate;
265}
266
267// Returns the permanent layout attribute for the given operand if it's
268// available on the defining op. Otherwise returns the provided layout.
269xegpu::DistributeLayoutAttr
270maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
271 const OpOperand &operand, mlir::Operation *owner,
272 const std::string &name) {
273 xegpu::DistributeLayoutAttr candidate = layout;
274 unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
275
276 if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) {
277 if (idx == 0) {
278 if (auto perm = storeOp.getLayoutAttr())
279 candidate = perm;
280 }
281 }
282
283 return candidate;
284}
285
286// TODO-LayoutRefactor: Remove this function after replacing use
287// with setTemporaryLayout or setAnchorLayout
289 const mlir::OpResult &result,
290 const mlir::xegpu::DistributeLayoutAttr layout) {
291 Operation *owner = result.getOwner();
292
293 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
294 if (anchorOp.getAnchorLayout() == layout)
295 return;
296 anchorOp.setAnchorLayout(layout);
297 return;
298 }
299
300 std::string name = xegpu::getTemporaryLayoutName(result);
301 if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
302 return;
303 }
304 if (layout) {
305 owner->setAttr(name, layout);
306 }
307}
308
309// TODO-LayoutRefactor: Remove this function after replacing use
310// with setTemporaryLayout or setAnchorLayout
312 const DistributeLayoutAttr layout) {
313 Operation *owner = operand.getOwner();
314 unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
315
316 if (!layout) {
317 return;
318 }
319 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
320 if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
321 if (idx == 0) {
322 return dpasOp.setLayoutAAttr(layout);
323 } else if (idx == 1) {
324 return dpasOp.setLayoutBAttr(layout);
325 } else if (idx == 2) {
326 return dpasOp.setLayoutCdAttr(layout);
327 }
328 }
329 if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(owner)) {
330 return convertOp.setInputLayoutAttr(layout);
331 }
332
333 // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
334 // the layout is valid for the first two operands: value and memref/tdesc.
335 // For other operations, the layout applies to the first operand only.
336 if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
337 owner)) {
338 if (idx < 2) {
339 anchorOp.setAnchorLayout(layout);
340 }
341 } else {
342 if (idx == 0) {
343 anchorOp.setAnchorLayout(layout);
344 }
345 }
346 }
347
348 std::string name = xegpu::getTemporaryLayoutName(operand);
349 if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
350 return;
351 }
352 if (layout) {
353 owner->setAttr(name, layout);
354 }
355}
356
357template <typename T, typename>
358xegpu::DistributeLayoutAttr
359xegpu::getTemporaryLayout(const T &operandOrResult) {
360 Operation *op = operandOrResult.getOwner();
361
362 std::string layoutName = xegpu::getTemporaryLayoutName(operandOrResult);
363 if (op->hasAttr(layoutName)) {
364 auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
365 return layout;
366 }
367
368 return nullptr;
369}
370
371template xegpu::DistributeLayoutAttr
373template xegpu::DistributeLayoutAttr
375
376template <typename T, typename>
377void xegpu::setTemporaryLayout(const T &operandOrResult,
378 const xegpu::DistributeLayoutAttr layout) {
379 Operation *owner = operandOrResult.getOwner();
380 std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
381 if (owner->hasAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
382 return;
383 }
384 if (layout) {
385 owner->setAttr(name, layout);
386 }
387}
388
390 const mlir::OpResult &result,
391 const mlir::xegpu::DistributeLayoutAttr layout);
392
394 const mlir::OpOperand &operand,
395 const mlir::xegpu::DistributeLayoutAttr layout);
396
400 auto vecTy = dyn_cast<VectorType>(value.getType());
401 if (!vecTy)
402 return {value};
403
404 ArrayRef<int64_t> srcShape = vecTy.getShape();
405 if (!computeShapeRatio(srcShape, shape))
406 return {value};
407
408 int64_t srcShapeRank = srcShape.size();
409 int64_t targetShapeRank = shape.size();
410
411 SmallVector<int64_t> adjustedTargetShape(srcShape.size());
412 int64_t rankDiff = srcShapeRank - targetShapeRank;
413 std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
414 1);
415 llvm::copy(shape, adjustedTargetShape.begin() + rankDiff);
416
418 for (SmallVector<int64_t> offsets :
419 StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
420 SmallVector<int64_t> staticStrides(offsets.size(), 1);
421 Value slice = vector::ExtractStridedSliceOp::create(
422 builder, loc, value, offsets, adjustedTargetShape, staticStrides);
423
424 // Reshape to remove leading unit dims if needed
425 if (srcShapeRank > targetShapeRank) {
426 auto targetTy = VectorType::get(shape, vecTy.getElementType());
427 slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
428 }
429 result.push_back(slice);
430 }
431
432 return result;
433}
434
436 ValueRange values,
438 VectorType inputTy = dyn_cast<VectorType>(values[0].getType());
439 assert(llvm::all_of(values.getTypes(),
440 [&](Type type) { return type == inputTy; }) &&
441 "values must be of the same VectorType");
442
443 Type elemTy = inputTy.getElementType();
444 ArrayRef<int64_t> tileShape = inputTy.getShape();
445
446 VectorType resultTy = VectorType::get(shape, elemTy);
447 auto zeroAttr = builder.getZeroAttr(elemTy);
448 Value result = arith::ConstantOp::create(
449 builder, loc, resultTy, DenseElementsAttr::get(resultTy, zeroAttr));
450
451 for (auto [src, offsets] :
452 llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
453 SmallVector<int64_t> staticStrides(tileShape.size(), 1);
454 result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
455 offsets, staticStrides);
456 }
457 return result;
458}
459
461 Operation *op, TypeConverter converter) {
462 MLIRContext *context = op->getContext();
463
464 auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
465 Location loc) -> Value {
466 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
467 .getResult(0);
468 };
469
470 { // convert VectorType to RankedTensorType for SCF Structural ops
471 TypeConverter converter;
472 converter.addConversion([](Type type) -> Type { return type; });
473 converter.addConversion([](VectorType type) -> Type {
474 return RankedTensorType::get(type.getShape(), type.getElementType());
475 });
476 converter.addSourceMaterialization(materializeCast);
477 converter.addTargetMaterialization(materializeCast);
478
479 mlir::ConversionTarget target(*context);
480 target.addLegalOp<UnrealizedConversionCastOp>();
481
482 mlir::RewritePatternSet patterns(context);
484 target);
485 (void)mlir::applyPartialConversion(op, target, std::move(patterns));
486 }
487
488 { // propagate the layout attribute to RankedTensorType by checking
489 // BuiltInUnrealizedCastOps
490 // for VectorType to RankedTensorType cast.
491 op->walk([](UnrealizedConversionCastOp castOp) {
492 if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
493 return WalkResult::skip();
494
495 Value input = castOp.getInputs()[0];
496 Value result = castOp.getResults()[0];
497 auto inputTy = dyn_cast<VectorType>(input.getType());
498 auto resultTy = dyn_cast<RankedTensorType>(result.getType());
499
500 // Only look at ops casting from VectorType to RankedTensorType
501 if (!inputTy || !resultTy)
502 return WalkResult::skip();
503
504 xegpu::DistributeLayoutAttr layout =
506 if (!layout)
507 return WalkResult::skip();
508
509 RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
510 result.setType(newTy);
511
512 // update the arguments if user is a LoopLike op.
513 for (OpOperand &use : result.getUses()) {
514 if (auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
515 BlockArgument arg = loop.getTiedLoopRegionIterArg(&use);
516 arg.setType(newTy);
517 }
518 // whileOp has two regions, the BlockArgument of the after region
519 // is not exposed by LoopLikeOpInterface
520 if (auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
521 unsigned idx = use.getOperandNumber();
522 BlockArgument arg = whileOp.getAfterArguments()[idx];
523 arg.setType(newTy);
524 }
525 }
526 return WalkResult::advance();
527 });
528
529 // using yieldOp as anchor to update the result type of its ParentOp
530 op->walk([](scf::YieldOp yieldOp) {
531 Operation *parentOp = yieldOp->getParentOp();
532 for (OpResult r : parentOp->getOpResults()) {
533 unsigned idx = r.getResultNumber();
534 Type resultTy = r.getType();
535 Type yieldTy = yieldOp.getResults()[idx].getType();
536 if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
537 r.setType(yieldTy);
538 }
539 });
540 }
541
542 { // perform the conversion from RankedTensorType to VectorType based on the
543 // DistributeLayoutAttr
544
545 // Handle the UnrealizedConversionCastOp introduced by the first step.
546 // For vector->RankedTensorType, it will simply forward the inputs.
547 // For RankedTensorType->vector, it will update the inputs with the
548 // one from the adaptor.
549 class UnrealizedConversionCastOpPattern
550 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
551 using OpConversionPattern<
552 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
553
554 mlir::LogicalResult
555 matchAndRewrite(mlir::UnrealizedConversionCastOp op,
556 OneToNOpAdaptor adaptor,
557 ConversionPatternRewriter &rewriter) const override {
558 auto inputs = op.getOperands();
559 auto outputs = op.getOutputs();
560
561 if (inputs.size() != 1 || outputs.size() != 1)
562 return failure();
563
564 auto inputTy = inputs[0].getType();
565 auto outputTy = outputs[0].getType();
566
567 if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) {
568 rewriter.replaceOpWithMultiple(op, adaptor.getInputs());
569 return success();
570 }
571
572 if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
573 SmallVector<Value> values = xegpu::flattenValues(adaptor.getInputs());
574 auto newOp = UnrealizedConversionCastOp::create(rewriter, op.getLoc(),
575 outputTy, values);
576 rewriter.replaceOp(op, newOp);
577 return success();
578 }
579 return failure();
580 }
581 };
582
583 converter.addSourceMaterialization(materializeCast);
584 converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
585 ValueRange inputs, Location loc) {
586 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
587 .getResults();
588 });
589
590 mlir::ConversionTarget target(*context);
591 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
592 [](UnrealizedConversionCastOp op) {
593 auto isTensorTy = [](Type type) {
594 return isa<RankedTensorType>(type);
595 };
596 return llvm::none_of(op->getOperandTypes(), isTensorTy) &&
597 llvm::none_of(op->getResultTypes(), isTensorTy);
598 });
599 mlir::RewritePatternSet patterns(context);
600 patterns.insert<UnrealizedConversionCastOpPattern>(context);
602 target);
603 (void)mlir::applyPartialConversion(op, target, std::move(patterns));
604 }
605}
606
607std::optional<std::string> xegpu::getChipStr(Operation *op) {
608 auto gpuModuleOp = op->getParentOfType<gpu::GPUModuleOp>();
609
610 if (!gpuModuleOp)
611 return std::nullopt;
612
613 auto targetAttrs = gpuModuleOp.getTargets();
614 if (targetAttrs) {
615 for (auto &attr : *targetAttrs) {
616 auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr);
617 if (xevmAttr)
618 return xevmAttr.getChip().str();
619 }
620 }
621
622 return std::nullopt;
623}
624
625/// Generates element-wise addition ops of two arrays with same length.
627 Location loc,
630 assert(lhs.size() == rhs.size() && "lhs and rhs must have the same size");
632 for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
633 auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
634 auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
635 results.push_back(builder.createOrFold<arith::AddIOp>(loc, lval, rval));
636 }
637 return results;
638}
639
640/// Generates element-wise addition ops of two arrays with automatic alignment.
641/// When the input arrays have different sizes, the shorter array is
642/// right-aligned with the longer array, and the unmatched leading elements from
643/// the longer array are preserved unchanged. This is commonly used for offset
644/// computation where higher-dimensional offsets need to be added to
645/// lower-dimensional adjustments.
646///
647/// Example:
648/// lhs = [l1, l2, l3], rhs = [r1, r2]
649/// Result: [l1, l2+r1, l3+r2]
654 // ensure a is longer than b
655 ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs;
656 ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
657 SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
658 a = a.slice(a.size() - b.size());
659 results.append(addElementwise(builder, loc, a, b));
660 return results;
661}
662
663template <typename T>
665 ArrayRef<T> candidateMultiples) {
666 static_assert(std::is_integral<T>::value, "T must be an integer type");
667 int largest = -1;
668 SmallVector<T> multiples = {1};
669 if (!candidateMultiples.empty())
670 multiples =
671 SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
672 for (T candidate : candidates) {
673 for (T multiple : multiples) {
674 int value = static_cast<int>(candidate * multiple);
675 if (value != 0 && dim % value == 0 && value > largest)
676 largest = value;
677 }
678 }
679 return largest;
680}
681
683 vector::CombiningKind kind, uint32_t size) {
684 // First reduce on a single thread to get per lane reduction value.
685 Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
686 // Parallel reduction using butterfly shuffles.
687 for (uint64_t i = 1; i < size; i <<= 1) {
688 Value shuffled =
689 gpu::ShuffleOp::create(builder, loc, laneVal, i, /** width = **/ size,
690 /** mode = **/ gpu::ShuffleMode::XOR)
691 .getShuffleResult();
692 laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
693 }
694 return laneVal;
695}
696
699 vector::CombiningKind kind,
700 int64_t reductionDim, Location loc,
701 PatternRewriter &rewriter) {
702 VectorType sourceType = src.getType();
703 int64_t sourceRank = sourceType.getRank();
704 // Expecting at least a 2D source vector. Leading dimensions (all except the
705 // last two) must be unit.
706 assert(sourceRank >= 2 && "expected at least a 2D source vector");
707 for (int64_t i = 0; i < sourceRank - 2; ++i)
708 assert(sourceType.getShape()[i] == 1 &&
709 "expected leading dimensions to be unit");
710 int64_t rowIdx = sourceRank - 2;
711 int64_t columnIdx = sourceRank - 1;
712 int64_t sourceH = sourceType.getShape()[rowIdx];
713 int64_t sourceW = sourceType.getShape()[columnIdx];
714 int nSlices = (reductionDim == rowIdx) ? sourceW : sourceH;
715 // Create a constant vector to hold the result of the reduction.
716 TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
717 Value reductionResult = arith::ConstantOp::create(
718 rewriter, loc, acc.getType(),
719 DenseElementsAttr::get(acc.getType(), zeroAttr));
720 auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
721 auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
722 // Reduction result should have the same layout as the accumulator.
723 xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
724 // For each slice of the source, extract the slice vector, do a reduction
725 // and, insert the reduced value back to the result vector.
726 int64_t accRank = acc.getType().getRank();
727 for (int i = 0; i < nSlices; ++i) {
728 // Build nD offsets, sizes, and strides. Leading unit dims get
729 // offset=0, size=1. The last two dims are set based on reductionDim.
730 SmallVector<int64_t> sliceOffsets(sourceRank, 0);
731 SmallVector<int64_t> sliceSizes(sourceRank, 1);
732 SmallVector<int64_t> strides(sourceRank, 1);
733 if (reductionDim == columnIdx) {
734 sliceOffsets[rowIdx] = i;
735 sliceSizes[columnIdx] = sourceW;
736 } else {
737 sliceOffsets[columnIdx] = i;
738 sliceSizes[rowIdx] = sourceH;
739 }
740
741 vector::ExtractStridedSliceOp extractOp =
742 vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
743 sliceSizes, strides);
744 // Extract strided slice has the same layout as src.
745 xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
746
747 int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
748
749 vector::ShapeCastOp slice = vector::ShapeCastOp::create(
750 rewriter, loc,
751 VectorType::get({nSliceElements}, sourceType.getElementType()),
752 extractOp.getResult());
753
754 // Shape cast output has the same layout as the accumulator. Shape cast
755 // source has the same layout as the original reduction source.
756 xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
757 xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
758 // Extract and reduction results in scalars, so no result layout is needed.
759 // Build multi-dim index into acc (sourceRank-1 dims, i.e. source shape with
760 // the reduction dim removed). Leading unit dims get index 0.
761 SmallVector<int64_t> accIdx(accRank, 0);
762 accIdx[accRank - 1] = i;
763 Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, accIdx);
764 Value reduction = vector::ReductionOp::create(
765 rewriter, loc, kind, slice.getResult(), accExtract);
766 reductionResult = vector::InsertOp::create(rewriter, loc, reduction,
767 reductionResult, accIdx);
768 // Insert op should have the same layout as the accumulator.
769 xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
770 }
771 return reductionResult;
772}
773
776 vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize,
777 Location loc, PatternRewriter &rewriter) {
778 VectorType sourceType = src.getType();
779 int64_t sourceRank = sourceType.getRank();
780 // Expecting at least a 2D source vector. Leading dimensions (all except the
781 // last two) must be unit.
782 assert(sourceRank >= 2 && "expected at least a 2D source vector");
783 for (int64_t i = 0; i < sourceRank - 2; ++i)
784 assert(sourceType.getShape()[i] == 1 &&
785 "expected leading dimensions to be unit");
786 int64_t rowIdx = sourceRank - 2;
787 int64_t columnIdx = sourceRank - 1;
788 int64_t sourceH = sourceType.getShape()[rowIdx];
789 int64_t sourceW = sourceType.getShape()[columnIdx];
790
791 // Create a constant vector to hold the result of the reduction.
792 TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
793 Value reductionResult = arith::ConstantOp::create(
794 rewriter, loc, acc.getType(),
795 DenseElementsAttr::get(acc.getType(), zeroAttr));
796
797 // nSlices is the number of reduction operations needed to reduce the entire
798 // source vector. For example, if reductionDim is the row dim, we are
799 // reducing across rows, and each slice is a column. So the number of slices
800 // is the number of columns, which is sourceW.
801 int nSlices = (reductionDim == rowIdx) ? sourceW : sourceH;
802
803 // For each slice of the source, extract the slice vector, do a reduction
804 // and, insert the reduced value back to the result vector.
805 int64_t accRank = acc.getType().getRank();
806 for (int i = 0; i < nSlices; ++i) {
807 // Build nD offsets, sizes, and strides. Leading unit dims get
808 // offset=0, size=1. The last two dims are set based on reductionDim.
809 SmallVector<int64_t> sliceOffsets(sourceRank, 0);
810 SmallVector<int64_t> sliceSizes(sourceRank, 1);
811 SmallVector<int64_t> strides(sourceRank, 1);
812 if (reductionDim == columnIdx) {
813 sliceOffsets[rowIdx] = i;
814 sliceSizes[columnIdx] = sourceW;
815 } else {
816 sliceOffsets[columnIdx] = i;
817 sliceSizes[rowIdx] = sourceH;
818 }
819
820 vector::ExtractStridedSliceOp extractOp =
821 vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
822 sliceSizes, strides);
823 int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
824 vector::ShapeCastOp slice = vector::ShapeCastOp::create(
825 rewriter, loc,
826 VectorType::get({nSliceElements}, sourceType.getElementType()),
827 extractOp.getResult());
828
829 SmallVector<int64_t> accIdx(accRank, 0);
830 accIdx[accRank - 1] = i;
831 Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, accIdx);
832 Value fullReduce =
833 xegpu::subgroupReduction(loc, rewriter, slice, kind, reductionSize);
834 fullReduce =
835 vector::makeArithReduction(rewriter, loc, kind, fullReduce, accExtract);
836 reductionResult = vector::InsertOp::create(rewriter, loc, fullReduce,
837 reductionResult, accIdx);
838 }
839 return reductionResult;
840}
841
843 Type type,
844 vector::CombiningKind kind) {
845 auto vecTy = dyn_cast<VectorType>(type);
846 Type elemTy = vecTy ? vecTy.getElementType() : type;
847
848 // Helper to create either a splat vector or scalar constant from an attr.
849 auto makeConst = [&](Attribute scalarAttr) -> Value {
850 if (vecTy)
851 return arith::ConstantOp::create(
852 builder, loc, vecTy, DenseElementsAttr::get(vecTy, scalarAttr));
853 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(scalarAttr));
854 };
855
856 switch (kind) {
857 case vector::CombiningKind::ADD:
858 case vector::CombiningKind::XOR:
859 case vector::CombiningKind::OR:
860 case vector::CombiningKind::MAXUI:
861 return makeConst(builder.getZeroAttr(elemTy));
862
863 case vector::CombiningKind::MUL:
864 case vector::CombiningKind::AND:
865 return makeConst(builder.getOneAttr(elemTy));
866
867 case vector::CombiningKind::MINSI:
868 if (auto intTy = dyn_cast<IntegerType>(elemTy))
869 return makeConst(builder.getIntegerAttr(
870 elemTy, APInt::getSignedMaxValue(intTy.getWidth())));
871 return nullptr;
872
873 case vector::CombiningKind::MINUI:
874 if (auto intTy = dyn_cast<IntegerType>(elemTy))
875 return makeConst(
876 builder.getIntegerAttr(elemTy, APInt::getMaxValue(intTy.getWidth())));
877 return nullptr;
878
879 case vector::CombiningKind::MAXSI:
880 if (auto intTy = dyn_cast<IntegerType>(elemTy))
881 return makeConst(builder.getIntegerAttr(
882 elemTy, APInt::getSignedMinValue(intTy.getWidth())));
883 return nullptr;
884
885 case vector::CombiningKind::MINNUMF:
886 case vector::CombiningKind::MINIMUMF:
887 if (auto floatTy = dyn_cast<FloatType>(elemTy))
888 return makeConst(builder.getFloatAttr(
889 elemTy, APFloat::getInf(floatTy.getFloatSemantics())));
890 return nullptr;
891
892 case vector::CombiningKind::MAXNUMF:
893 case vector::CombiningKind::MAXIMUMF:
894 if (auto floatTy = dyn_cast<FloatType>(elemTy))
895 return makeConst(builder.getFloatAttr(
896 elemTy, APFloat::getInf(floatTy.getFloatSemantics(), true)));
897 return nullptr;
898 }
899 return nullptr;
900}
901
902/// Explicit instantiations
903template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
904 ArrayRef<int> candidateMultiples);
905template int
907 ArrayRef<unsigned> candidateMultiples);
908
909bool xegpu::requirePacked(const xegpu::DistributeLayoutAttr layout) {
910 if (!layout)
911 return false;
912 auto laneData = layout.getEffectiveLaneDataAsInt();
913 if (laneData.size() != 2)
914 return false;
915 return laneData[0] != 1;
916}
917
918bool xegpu::requireTranspose(const xegpu::DistributeLayoutAttr layout,
919 const xegpu::uArch::uArch *uArch) {
920 // Return false for unsupported targets.
921 // TODO: Add more support or move to target info.
922 if (uArch->getName().equals_insensitive("pvc") &&
923 uArch->getName().equals_insensitive("bmg") &&
924 uArch->getName().equals_insensitive("cri"))
925 return false;
926 if (!layout)
927 return false;
928 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
929 if (laneLayout.size() != 2)
930 return false;
931 return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
932}
933
934// Check if dst shape is an expansion of src shape by inserting unit dimensions.
935// Returns true if all dimensions in src match corresponding dimensions in dst
936// (after skipping unit dimensions), and populates expandedUnitDims with the
937// indices of the unit dimensions in dst that were added (not present in src).
938// Example: src=[2,3], dst=[1,2,3,1] -> true, expandedUnitDims=[0,3]
940 SmallVector<int64_t> &expandedUnitDims) {
941 // All unit dimensions in dst that don't appear in src are the expanded
942 // unit dimensions
943 size_t srcIdx = 0;
944 for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
945 if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
946 srcIdx++;
947 else if (dst[dstIdx] == 1)
948 expandedUnitDims.push_back(dstIdx);
949 else
950 return false;
951 return srcIdx == src.size();
952}
953
954// Checks if dst shape is an expansion of src shape where each dimension in src
955// is split into one or more consecutive dimensions in dst whose product equals
956// the original dimension. Populates splitDimGroups with groups of dst indices
957// that correspond to each src dimension. Example: src=[6,4], dst=[2,3,2,2] ->
958// true
961 SmallVector<SmallVector<int64_t>> &splitDimGroups) {
962 // each dim in src can be mapped to one or more dims in dst whose product
963 // equals to the src dim
964 size_t srcIdx = 0;
965 int64_t accumulatedSize = 1;
966 SmallVector<int64_t> currentDstDims;
967
968 splitDimGroups.clear();
969 for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
970 if (srcIdx >= src.size())
971 return false;
972 accumulatedSize *= dst[dstIdx];
973 currentDstDims.push_back(dstIdx);
974
975 if (accumulatedSize == src[srcIdx]) {
976 // Also collect trailing unit dims in destination, if any.
977 // Leading unit dims were implicitly collected.
978 if (srcIdx == src.size() - 1) {
979 while (++dstIdx < dst.size() && dst[dstIdx] == 1)
980 currentDstDims.push_back(dstIdx);
981 }
982 // Record the mapping: srcIdx -> currentDstDims
983 splitDimGroups.push_back(currentDstDims);
984 // move to next src dim
985 srcIdx++;
986 accumulatedSize = 1;
987 currentDstDims.clear();
988 } else if (accumulatedSize > src[srcIdx]) {
989 return false;
990 }
991 }
992 return srcIdx == src.size();
993}
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
auto load
xegpu::DistributeLayoutAttr maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout, const OpResult &result, mlir::Operation *owner, const std::string &name)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:259
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
TypedAttr getOneAttr(Type type)
Definition Builders.cpp:347
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:575
bool hasAttrOfType(NameT &&name)
Definition Operation.h:600
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:585
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:255
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:607
operand_type_range getOperandTypes()
Definition Operation.h:422
result_type_range getResultTypes()
Definition Operation.h:453
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
result_range getOpResults()
Definition Operation.h:445
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_strided_slice.
bool requirePacked(const DistributeLayoutAttr layout)
Helper function to check if the layout is packed.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult; users should use setAnchorLayout...
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc arguments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requireTranspose(const DistributeLayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
bool matchSplitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< SmallVector< int64_t > > &splitDimGroups)
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type conversion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_strided_slice.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
Value lowerCrossLaneReductionToShuffles(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter)
Lowers cross-lane reductions to shuffle operations on a 2D vector.
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
SmallVector< OpFoldResult > addElementwise(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with same length.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
virtual int getSubgroupSize() const =0
StringRef getName() const
Definition uArchBase.h:163