#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
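/// Walks the enclosing ops of `op` and returns the first "sg_id_range"
/// attribute found, or a null attribute if none is present.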
static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
  Operation *parent = op->getParentOfType<scf::IfOp>();
  while (parent) {
    if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
            parent->getAttr("sg_id_range")))
      return attr;
    parent = parent->getParentOfType<scf::IfOp>();
  }
  return {};
}
static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape,
                   xegpu::DistributeLayoutAttr layout) {
  int count = 1;
  SmallVector<int64_t> sgShape(shape);
  if (layout && layout.isForWorkgroup()) {
    SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
    if (!layout.getEffectiveSgDataAsInt().empty())
      sgShape = layout.getEffectiveSgDataAsInt();
    else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
      sgShape = *maybeDerivedSgData;
    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
    // Clamp each distribution-unit dim to the tensor shape; data shared
    // among subgroups can make distUnit exceed the shape.
    for (size_t i = 0; i < distUnit.size(); ++i)
      distUnit[i] = std::min(shape[i], distUnit[i]);
    count = computeProduct(shape) / computeProduct(distUnit);
  }
  return std::make_pair(sgShape, count);
}
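/// Builds the list of global offsets (one entry per sub-tensor accessed by
/// the current subgroup) for ops that carry explicit offsets and a
/// workgroup-level layout attribute.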
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
              xegpu::PrefetchNdOp, xegpu::LoadMatrixOp,
              xegpu::StoreMatrixOp>::value>>
static LogicalResult
genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
               SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
  Location loc = op.getLoc();
  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
  // Not applicable to ops without offset operands.
  if (origOffsets.empty())
    return failure();
  // Not applicable to ops without a workgroup layout attribute.
  xegpu::DistributeLayoutAttr layout;
  if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
                std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
    layout = op.getLayoutAttr();
  } else {
    layout = op.getDescLayoutAttr();
  }
  if (!layout || !layout.isForWorkgroup())
    return failure();

  Value sgId =
      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);

  // Verify and adjust the subgroup id if a range specifier is present.
  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
  if (sgIdRange) {
    int64_t startOfRange = sgIdRange.getStart().getInt();
    int64_t endOfRange = sgIdRange.getEnd().getInt();
    if (layout.getNumSubgroups() != endOfRange - startOfRange)
      return rewriter.notifyMatchFailure(
          op, "sg_layout size must match the sg_id_range");
    if (startOfRange > 0) {
      Value startOfRangeVal =
          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
      sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
    }
  }

  ArrayRef<int64_t> wgShape = op.getDataShape();
  auto maybeDescOffsets =
      layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
  if (failed(maybeDescOffsets))
    return failure();

  // Turn the subgroup-relative coordinates into global offsets by adding the
  // original op offsets.
  for (const auto &sgOffsets : *maybeDescOffsets) {
    SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
        rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
    offsetsList.push_back(std::move(newOffsets));
  }

  return success();
}
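/// Rewrites xegpu.create_nd_tdesc with explicit offsets: the workgroup-level
/// descriptor becomes one subgroup-level descriptor per distribution unit.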
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();
    Type elemTy = tdescTy.getElementType();
    xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    auto newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

    SmallVector<Value> newOps;
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
          op.getMixedSizes(), op.getMixedStrides());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
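/// Variant for xegpu.create_nd_tdesc without explicit offsets: the descriptor
/// is simply replicated `count` times with the subgroup shape.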
struct WgToSgCreateNdOpNoOffset
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only applies when no offsets are specified.
    if (!op.getMixedOffsets().empty())
      return failure();

    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Type elemTy = tdescTy.getElementType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();

    int count;
    SmallVector<int64_t> sgShape;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

    SmallVector<Value> newCreateNdOps(count);
    std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
      return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
                                           op.getSource(), op.getMixedSizes(),
                                           op.getMixedStrides());
    });
    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
    return success();
  }
};
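/// Rewrites xegpu.load_nd (without offsets) into one load per distributed
/// tensor descriptor.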
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();

    SmallVector<Value> newLoadOps;
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
      ArrayRef<int64_t> srcShape = tdescTy.getShape();
      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
      auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy,
                                               src, op->getAttrs());
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return mlir::success();
  }
};
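/// Rewrites xegpu.store_nd (without offsets) into one store per distributed
/// value/descriptor pair.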
struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();

    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
                               op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};
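/// Rewrites xegpu.load_nd with explicit offsets, using genOffsetsList to
/// derive the per-subgroup offsets.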
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();
    SmallVector<Value> newOps;
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
      VectorType newResTy =
          VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
      auto newOp = xegpu::LoadNdOp::create(
          rewriter, op.getLoc(), newResTy, tdesc, offsets,
          /*packed=*/nullptr, /*transpose=*/nullptr, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr(), layout);
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
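/// Rewrites xegpu.store_nd with explicit offsets, mirroring the load pattern
/// above.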
struct WgToSgStoreNdOpWithOffset
    : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();
    for (auto [v, tdesc, offsets] :
         llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
                               op.getL1HintAttr(), op.getL2HintAttr(),
                               op.getL3HintAttr(), layout);
    }
    rewriter.eraseOp(op);
    return success();
  }
};
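/// Rewrites xegpu.prefetch_nd with explicit offsets.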
struct WgToSgPrefetchNdOpWithOffset
    : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
                                  op.getL1HintAttr(), op.getL2HintAttr(),
                                  op.getL3HintAttr(), layout);
    }
    rewriter.eraseOp(op);
    return success();
  }
};
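/// Rewrites xegpu.update_nd_offset by applying the same offset update to
/// every distributed descriptor.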
struct WgToSgUpdateNdOffsetOp
    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Value> newUpdateTileOffsetOps;
    for (auto tDesc : adaptor.getTensorDesc()) {
      auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
          op.getConstOffsets());
      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
    }
    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
    return success();
  }
};
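/// Rewrites xegpu.dpas: forms the cartesian product of the distributed A and
/// B fragments and emits one subgroup-level dpas per pair.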
struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType resultTy = op.getResult().getType();
    if (resultTy.getRank() != 2)
      return failure();

    auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult());
    if (!originalLayout)
      return failure();

    size_t i = 0;
    SmallVector<Value> newDpasOps;
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
        llvm::SmallVector<Value> operands({aVec, bVec});
        Value tmpC;
        if (op.getAcc()) {
          tmpC = adaptor.getAcc()[i++];
          operands.push_back(tmpC);
        }

        ArrayRef<int64_t> aVecShape =
            llvm::cast<VectorType>(aVec.getType()).getShape();
        ArrayRef<int64_t> bVecShape =
            llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
        tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
        xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
                                       originalLayout.dropSgLayoutAndData());
        newDpasOps.push_back(tmpC);
      }
    }
    rewriter.replaceOpWithMultiple(op, {newDpasOps});
    return success();
  }
};
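/// Rewrites xegpu.prefetch_nd without explicit offsets.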
struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    for (auto src : adaptor.getTensorDesc())
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
                                  op->getAttrs());
    rewriter.eraseOp(op);
    return success();
  }
};
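/// Distributes vector.broadcast into subgroup-sized broadcasts.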
struct WgToSgVectorBroadcastOp
    : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResult().getType();
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
      return failure();

    SmallVector<Value> newBroadcastOps;
    for (auto operand : adaptor.getOperands().front()) {
      auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                      newResultType, operand);
      xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
                                     layout.dropSgLayoutAndData());
      newBroadcastOps.push_back(newBroadcast.getResult());
    }
    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
    return success();
  }
};
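/// Distributes elementwise math/arith ops over the N distributed operand
/// variants, rebuilding each op with the subgroup result type.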
struct WgToSgElementwiseOp : public ConversionPattern {
  WgToSgElementwiseOp(MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only elementwise ops with a single vector result are handled.
    if (!OpTrait::hasElementwiseMappableTraits(op) ||
        op->getNumResults() != 1)
      return failure();
    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    assert(resultType && "Expected result to be a VectorType");
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op->getResult(0));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    size_t numVariants = operands.empty() ? 0 : operands.front().size();
    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
          return operandVec.size() != numVariants;
        }))
      return failure();

    SmallVector<Value> newResults;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    for (size_t i = 0; i < numVariants; ++i) {
      SmallVector<Value> opOperands;
      for (auto &operandVec : operands)
        opOperands.push_back(operandVec[i]);

      OperationState state(op->getLoc(), op->getName());
      state.addOperands(opOperands);
      state.addTypes(newResultType);
      // Copy attributes, dropping sg_layout/sg_data from layout attributes.
      for (NamedAttribute attr : op->getAttrs()) {
        if (auto layout =
                dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
          if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
              !layout.getEffectiveInstDataAsInt().empty())
            state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
        } else {
          state.addAttribute(attr.getName(), attr.getValue());
        }
      }
      Operation *newOp = rewriter.create(state);
      newResults.push_back(newOp->getResult(0));
    }

    rewriter.replaceOpWithMultiple(op, {newResults});
    return success();
  }
};
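/// Rewrites xegpu.convert_layout. Only the case where input and target agree
/// on sg_layout/sg_data is supported, so no cross-subgroup data movement is
/// required.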
struct WgToSgConvertLayoutOp
    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
    auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());

    if (!input || !target || !input.isForWorkgroup() ||
        !target.isForWorkgroup())
      return rewriter.notifyMatchFailure(
          op, "Input and target layouts must have subgroup layout");

    DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
    DenseI32ArrayAttr inputSgData = input.getSgData();
    DenseI32ArrayAttr inputOrder = input.getOrder();
    DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
    DenseI32ArrayAttr targetSgData = target.getSgData();
    DenseI32ArrayAttr targetOrder = target.getOrder();

    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
        inputOrder != targetOrder)
      return failure();

    input = input.dropSgLayoutAndData();
    target = target.dropSgLayoutAndData();

    SmallVector<Value> newOps(adaptor.getSource());
    if (input && target) {
      // Keep the convert op around for the remaining fields (e.g. inst_data).
      for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
        auto newOp = xegpu::ConvertLayoutOp::create(
            rewriter, op.getLoc(), src.getType(), src, input, target);
        newOps[i] = newOp;
      }
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
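/// Folds unrealized_conversion_casts introduced while materializing the 1:N
/// type conversion, when the input and result types line up exactly.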
struct UnrealizedConversionCastOpPattern
    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
  using OpConversionPattern<
      mlir::UnrealizedConversionCastOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());

    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());

    if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
        !llvm::all_equal(ValueRange(inputs).getTypes()))
      return failure();

    // Case 1: 1:N cast whose remapped inputs already match the result types.
    if (op.getNumOperands() == 1 &&
        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
      rewriter.replaceOp(op, inputs);
      return success();
    }

    // Case 2: N:1 cast whose remapped inputs match the original operand
    // types; forward the distributed values directly.
    if (op.getNumResults() == 1 &&
        llvm::equal(ValueRange(inputs).getTypes(), op->getOperandTypes())) {
      rewriter.replaceOpWithMultiple(op, {inputs});
      return success();
    }

    return mlir::failure();
  }
};
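/// Distributes arith.constant: splats are simply retyped; non-splat index
/// vectors with constant row/column strides are rebuilt from a base tile plus
/// a per-subgroup offset.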
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
    auto vecType = dyn_cast<VectorType>(op.getType());
    if (!vecAttr || !vecType)
      return failure();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    ArrayRef<int64_t> wgShape = vecType.getShape();
    SmallVector<int64_t> sgShape;
    int count;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);

    auto newType = VectorType::get(sgShape, vecType.getElementType());
    Location loc = op.getLoc();
    auto eltType = vecType.getElementType();

    auto setLayout = [&](Value val) {
      xegpu::setDistributeLayoutAttr(cast<OpResult>(val),
                                     layout.dropSgLayoutAndData());
    };

    if (vecAttr.isSplat()) {
      // Splat: one value is valid for every subgroup.
      Attribute singleVal = vecAttr.getSplatValue<Attribute>();
      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
      setLayout(cstOp->getResult(0));
      rewriter.replaceOp(op, cstOp);
      return success();
    } else if (sgShape == wgShape) {
      // The entire vector is shared by all subgroups.
      auto newConstOp =
          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
      setLayout(newConstOp->getResult(0));
      rewriter.replaceOp(op, newConstOp);
      return success();
    } else {
      // Non-splat constant: only linear 1D/2D index vectors are supported.
      if (!eltType.isIndex())
        return rewriter.notifyMatchFailure(
            op, "Unsupported element type for non-splat constant op.");
      if (wgShape.size() > 2)
        return rewriter.notifyMatchFailure(
            op, "Only 1D & 2D vector constant supported");

      SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
      int64_t rowStride = 0, colStride = 0;
      int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
      int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];

      if (cols > 1) {
        colStride = cast<IntegerAttr>(values[1]).getInt() -
                    cast<IntegerAttr>(values[0]).getInt();
      }
      if (rows > 1) {
        rowStride = cast<IntegerAttr>(values[cols]).getInt() -
                    cast<IntegerAttr>(values[0]).getInt();
      }

      // Verify the strides are constant across all rows and columns.
      for (int64_t r = 0; r < rows; ++r) {
        for (int64_t c = 0; c < cols; ++c) {
          int64_t idx = r * cols + c;
          if (c > 0 && cols > 1) {
            int64_t prevIdx = r * cols + (c - 1);
            int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                           cast<IntegerAttr>(values[prevIdx]).getInt();
            if (diff != colStride)
              return rewriter.notifyMatchFailure(
                  op, "Non-constant column stride in constant op.");
          }
          if (r > 0 && rows > 1) {
            int64_t prevIdx = (r - 1) * cols + c;
            int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                           cast<IntegerAttr>(values[prevIdx]).getInt();
            if (diff != rowStride)
              return rewriter.notifyMatchFailure(
                  op, "Non-constant row stride in constant op.");
          }
        }
      }

      // Build a constant holding the base tile values.
      SmallVector<Attribute> baseTileValues;
      int baseTileCols = sgShape[sgShape.size() - 1];
      int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
      for (int64_t r = 0; r < baseTileRows; ++r) {
        for (int64_t c = 0; c < baseTileCols; ++c) {
          baseTileValues.push_back(values[r * cols + c]);
        }
      }
      auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
                                             baseTileValues);
      auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);

      Value sgId =
          gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
      auto sgOffsets =
          layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
      if (failed(sgOffsets))
        return failure();

      SmallVector<Value, 2> strideConsts;
      strideConsts.push_back(
          arith::ConstantIndexOp::create(rewriter, loc, colStride));
      if (rows > 1)
        strideConsts.insert(
            strideConsts.begin(),
            arith::ConstantIndexOp::create(rewriter, loc, rowStride));

      SmallVector<Value> newConstOps;
      for (auto offsets : *sgOffsets) {
        // Multiply each offset by its stride, sum them, then broadcast and
        // add the result to the base tile.
        Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
        for (size_t i = 0; i < strideConsts.size(); ++i) {
          Value mul =
              arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
                                    offsets[i], strideConsts[i]);
          mulOffset = arith::AddIOp::create(
              rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
        }
        auto bcastOffset = vector::BroadcastOp::create(
            rewriter, loc, baseConstVec.getType(), mulOffset);
        auto finalConst =
            arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
        setLayout(baseConstVec);
        setLayout(bcastOffset);
        setLayout(finalConst);
        newConstOps.push_back(finalConst);
      }
      rewriter.replaceOpWithMultiple(op, {newConstOps});
      return success();
    }
  }
};
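/// Distributes xegpu.load_gather with offsets; the offsets and mask operands
/// must already have been distributed to the same shape.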
struct WgToSgLoadGatherOpWithOffset
    : public OpConversionPattern<xegpu::LoadGatherOp> {
  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getOffsets())
      return failure();

    Location loc = op.getLoc();
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }

    SmallVector<Value> newLoadOps;
    auto chunkSizeAttr =
        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
    for (auto [offsets, mask] :
         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
      auto newLayout = layout.dropSgLayoutAndData();
      auto newLoadOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
          newLayout);
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return success();
  }
};
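/// Distributes xegpu.store_scatter with offsets, analogous to the gather
/// pattern above.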
struct WgToSgStoreScatterOpWithOffset
    : public OpConversionPattern<xegpu::StoreScatterOp> {
  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getOffsets())
      return failure();

    Location loc = op.getLoc();
    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
    if (!valueType)
      return failure();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getValue());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }

    auto chunkSizeOpt = op.getChunkSize();
    int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
    for (auto [val, offs, mask] : llvm::zip(
             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
      auto store = xegpu::StoreScatterOp::create(
          rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
          layout.dropSgLayoutAndData());
      // Update the layout attribute on each operand, skipping the dest.
      for (OpOperand &operand : store->getOpOperands()) {
        if (operand.getOperandNumber() == 1)
          continue;
        xegpu::setDistributeLayoutAttr(operand,
                                       layout.dropSgLayoutAndData());
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
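/// Distributes xegpu.load_matrix using genOffsetsList.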
struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    ArrayRef<int64_t> wgShape = op.getDataShape();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
    assert(valueTy && "the value type must be vector type!");
    Type elemTy = valueTy.getElementType();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResTy = VectorType::get(sgShape, elemTy);
    SmallVector<Value> newOps;
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
                                               op.getMemDesc(), offsets,
                                               layout.dropSgLayoutAndData());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
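/// Distributes xegpu.store_matrix using genOffsetsList.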
struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
                                   offsets, layout.dropSgLayoutAndData());
    rewriter.eraseOp(op);
    return success();
  }
};
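/// Distributes vector.step: a subgroup-sized step vector is shifted by the
/// subgroup's starting offset.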
struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
  using OpConversionPattern<vector::StepOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();
    std::optional<SmallVector<int64_t>> sgShape =
        getSgShapeAndCount(wgShape, layout).first;
    if (!sgShape)
      return failure();

    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
    auto steps = vector::StepOp::create(rewriter, loc, newTy);
    SmallVector<Value> newOps;
    for (auto offsets : *sgOffsets) {
      auto bcastOffset =
          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
      auto finalSteps =
          arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
      xegpu::setDistributeLayoutAttr(steps->getResult(0),
                                     layout.dropSgLayoutAndData());
      xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
                                     layout.dropSgLayoutAndData());
      xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
                                     layout.dropSgLayoutAndData());
      newOps.push_back(finalSteps);
    }

    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
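/// Distributes vector.shape_cast, restricted to casts that only add or drop
/// unit dimensions and whose layouts are slices of one another.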
struct WgToSgVectorShapeCastOp
    : public OpConversionPattern<vector::ShapeCastOp> {
  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();

    ArrayRef<int64_t> wgShape = resultType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
    if (!srcType)
      return failure();

    // Accept only casts that add or remove unit dimensions: stripping all
    // unit dims from both shapes must leave identical shapes.
    auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
      SmallVector<int64_t> srcNonUnit, dstNonUnit;
      for (int64_t d : src)
        if (d != 1)
          srcNonUnit.push_back(d);
      for (int64_t d : dst)
        if (d != 1)
          dstNonUnit.push_back(d);
      return srcNonUnit == dstNonUnit;
    };

    if (!onlyUnitDims(srcType.getShape(), sgShape))
      return failure();

    // For rank-changing casts, one layout must be a slice of the other.
    int64_t sourceRank = srcType.getRank();
    int64_t resultRank = sgShape.size();
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(op.getSource());
    if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
      return failure();
    if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
      return failure();

    SmallVector<Value> newShapeCastOps;
    for (auto src : adaptor.getSource()) {
      auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
                                                      newResultType, src);
      xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
                                     layout.dropSgLayoutAndData());
      newShapeCastOps.push_back(newShapeCast.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
    return success();
  }
};
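/// Distributes vector.multi_reduction when each reduced dimension is owned
/// entirely by a single subgroup.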
struct WgToSgMultiDimReductionOp
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
    if (!dstType)
      return failure();

    auto srcShape = srcType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    auto reductionDims = llvm::to_vector(op.getReductionDims());

    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
                                        .getParent()
                                        .getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
                                      .getParent()
                                      .getEffectiveSgDataAsInt();

    // Each reduced dimension must be fully owned by one subgroup.
    for (int64_t dim : reductionDims) {
      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
        return rewriter.notifyMatchFailure(
            op,
            "sgLayout in each reduced dimension must be 1 and sgData in the "
            "reduced dim must match srcShape in that dim");
    }

    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;

    VectorType newDstType =
        VectorType::get({sgShape}, dstType.getElementType());

    SmallVector<Value> newReductions;
    for (auto sgSrc : adaptor.getSource()) {
      auto newOp = vector::MultiDimReductionOp::create(
          rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
          adaptor.getAcc()[0], op.getReductionDims());
      xegpu::setDistributeLayoutAttr(newOp->getResult(0),
                                     layout.dropSgLayoutAndData());
      newReductions.push_back(newOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newReductions});
    return success();
  }
};
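/// Distributes vector.transpose when the result layout is the transpose of
/// the source layout under the op's permutation.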
struct WgToSgVectorTransposeOp
    : public OpConversionPattern<vector::TransposeOp> {
  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResultVectorType();

    ArrayRef<int64_t> wgShape = resultType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(op.getVector());
    if (!sourceLayout || !sourceLayout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sourceSgLayout =
        sourceLayout.getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
    DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
    DenseI32ArrayAttr resultOrder = layout.getOrder();

    if (!sourceOrder || !resultOrder) {
      return rewriter.notifyMatchFailure(
          op, "Both source and result must have order attributes");
    }

    ArrayRef<int64_t> permutation = op.getPermutation();
    size_t permutationSize = permutation.size();
    if (sourceSgLayout.size() != permutationSize ||
        resultSgLayout.size() != permutationSize) {
      return rewriter.notifyMatchFailure(
          op, "Layouts and permutation must have the same rank");
    }

    if (!layout.isTransposeOf(sourceLayout, permutation))
      return rewriter.notifyMatchFailure(
          op, "Result layout is not a valid transpose of source layout "
              "according to permutation");

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());
    SmallVector<Value> newTransposeOps;
    for (auto src : adaptor.getVector()) {
      auto newTranspose = vector::TransposeOp::create(
          rewriter, op.getLoc(), newResultType, src, permutation);
      xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
                                     layout.dropSgLayoutAndData());
      newTransposeOps.push_back(newTranspose.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newTransposeOps});
    return success();
  }
};
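/// Distributes vector.constant_mask / vector.create_mask by clamping the
/// workgroup mask sizes to each subgroup's local range.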
template <typename MaskOpType>
struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
  using OpConversionPattern<MaskOpType>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      MaskOpType op,
      typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();

    // Collect the workgroup-level mask sizes, one value per dimension.
    SmallVector<Value> wgMaskDimSizes;
    if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
      for (int64_t maskSize : op.getMaskDimSizes()) {
        wgMaskDimSizes.push_back(
            arith::ConstantIndexOp::create(rewriter, loc, maskSize));
      }
    } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
      wgMaskDimSizes = llvm::to_vector(op.getOperands());
    }

    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType resultType = VectorType::get(sgShape, type.getElementType());

    SmallVector<Value> newCreateMaskOps;
    for (auto offsetSet : *sgOffsets) {
      SmallVector<Value> maskOperands;

      for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
        // Per-subgroup mask size in dim i:
        //   clamp(wgMaskDimSize - offset, 0, sgShape[i]).
        Value dimSizeVal =
            arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
        Value offset = offsetSet[i];
        Value adjustedMaskSize =
            arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
        Value nonNegative =
            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
        Value sgMaskSize =
            arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
        maskOperands.push_back(sgMaskSize);
      }

      auto newCreateMaskOp =
          vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
      xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
                                     layout.dropSgLayoutAndData());
      newCreateMaskOps.push_back(newCreateMaskOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
    return success();
  }
};

using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
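/// Appends patterns for XeGPU workgroup-to-subgroup distribution into
/// `patterns`.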
void xegpu::populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns
      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
           WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
           WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
           WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
          patterns.getContext());
}
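/// The pass driver: runs an SCF structural type conversion first, then the
/// distribution patterns, and finally cleans up the layout attributes.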
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
void XeGPUWgToSgDistributePass::runOnOperation() {
  // Track existing UnrealizedConversionCastOps so they are not mistaken for
  // casts materialized by this pass.
  SmallVector<Operation *> existingCastOps;
  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
    existingCastOps.push_back(castOp.getOperation());
  });

  {
    // Step 1: Run an SCF structural type conversion so region-carrying ops
    // can carry the distributed (1:N) vector types.
    TypeConverter converter;
    converter.addConversion([&](Type type) -> Type { return type; });
    converter.addConversion(
        [&](RankedTensorType type,
            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
          Type elemTy = type.getElementType();
          ArrayRef<int64_t> shape = type.getShape();

          int count;
          SmallVector<int64_t> subShape;
          std::tie(subShape, count) = getSgShapeAndCount(
              shape,
              dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));

          auto newTy = VectorType::get(subShape, elemTy);
          result.append(count, newTy);
          return success();
        });

    xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
                                                       converter);
  }

  // Step 2: Run the workgroup-to-subgroup distribution itself.
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);
  TypeConverter converter;
  converter.addConversion([&](Type type) -> Type { return type; });
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        int count;
        SmallVector<int64_t> subShape;
        xegpu::LayoutAttr layout = type.getLayoutAttr();
        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);

        if (layout)
          layout = layout.dropSgLayoutAndData();

        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });

  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
      return updateOp.getTensorDescType();
    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };

  // An op is legal once it no longer carries a workgroup-level layout.
  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
    return !layout || !layout.isForWorkgroup();
  };

  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
    auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
      [=](xegpu::LoadMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
      [=](xegpu::StoreMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [=](arith::ConstantOp op) -> bool {
        auto vecType = dyn_cast<VectorType>(op.getType());
        if (!vecType)
          return true;
        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
                               vector::TransposeOp, vector::BroadcastOp,
                               vector::MultiDimReductionOp,
                               vector::ConstantMaskOp, vector::CreateMaskOp>(
      [=](Operation *op) -> bool {
        auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
      [=](xegpu::LoadGatherOp op) -> bool {
        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
      [=](xegpu::StoreScatterOp op) -> bool {
        auto layout = xegpu::getDistributeLayoutAttr(op.getValue());
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
      [=](xegpu::ConvertLayoutOp op) -> bool {
        return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
      });

  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        // Only elementwise ops with a single vector result are candidates.
        if (!OpTrait::hasElementwiseMappableTraits(op) ||
            op->getNumResults() != 1)
          return true;

        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultType)
          return true;

        // All operands must be vectors with the same shape as the result.
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType || operandType.getShape() != resultType.getShape()) {
            return true;
          }
        }

        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(op->getResult(0));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
      [=](UnrealizedConversionCastOp op) {
        return llvm::is_contained(existingCastOps, op.getOperation());
      });

  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                       target);
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();

  // Finally, drop sg_layout/sg_data from the layout attribute attached to
  // each op result. For SCF ops the attribute is removed entirely, since in
  // the 1:N case the layouts for the new results are missing and the layout
  // propagation pass will recompute them.
  getOperation()->walk([](Operation *op) {
    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
          if (auto newLayout = layout.dropSgLayoutAndData())
            op->setAttr(name, newLayout);
        }
      }
    }
  });
}