#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
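/// Walks up the enclosing scf.if ops looking for an "sg_id_range" attribute
/// that restricts which subgroup ids execute the region containing `op`.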
static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
  Operation *parent = op->getParentOfType<scf::IfOp>();
  while (parent) {
    if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
            parent->getAttr("sg_id_range")))
      return attr;
    parent = parent->getParentOfType<scf::IfOp>();
  }
  return xegpu::RangeAttr();
}
static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape,
                   xegpu::DistributeLayoutAttr layout) {
  int count = 1;
  SmallVector<int64_t> sgShape(shape);
  if (layout && layout.isForWorkgroup()) {
    SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
    if (!layout.getEffectiveSgDataAsInt().empty())
      sgShape = layout.getEffectiveSgDataAsInt();
    else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
      sgShape = *maybeDerivedSgData;
    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
    // Clamp distUnit to the original shape to handle cases where data is
    // shared among subgroups, which may cause distUnit to exceed it.
    for (size_t i = 0; i < distUnit.size(); ++i)
      distUnit[i] = std::min(shape[i], distUnit[i]);
    count = computeProduct(shape) / computeProduct(distUnit);
  }
  return std::make_pair(sgShape, count);
}
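/// Builds the list of global offsets (one entry per sub-tensor accessed by
/// the current subgroup) by combining the layout-derived subgroup coordinates
/// with the op's original offsets.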
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
              xegpu::PrefetchNdOp, xegpu::LoadMatrixOp,
              xegpu::StoreMatrixOp>::value>>
static LogicalResult
genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
               SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
  Location loc = op.getLoc();
  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
  // Not applicable to ops without offset operands.
  if (origOffsets.empty())
    return failure();

  // Not applicable to ops without workgroup layout attributes.
  xegpu::DistributeLayoutAttr layout;
  if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
                std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
    layout = op.getLayoutAttr();
  } else {
    layout = op.getDescLayoutAttr();
  }
  if (!layout || !layout.isForWorkgroup())
    return failure();

  Value sgId =
      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);

  // Verify and adjust the sgId if a range specifier is present.
  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
  if (sgIdRange) {
    int64_t startOfRange = sgIdRange.getStart().getInt();
    int64_t endOfRange = sgIdRange.getEnd().getInt();
    // Verify the RangeAttr against the layout attribute.
    if (layout.getNumSubgroups() != endOfRange - startOfRange)
      return rewriter.notifyMatchFailure(
          op, "sg_layout size must match the sg_id_range");
    if (startOfRange > 0) {
      Value startOfRangeVal =
          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
      sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
    }
  }

  // Compute the subgroup-relative coordinates of each accessed sub-tensor.
  ArrayRef<int64_t> wgShape = op.getDataShape();
  auto maybeDescOffsets =
      layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
  if (failed(maybeDescOffsets))
    return failure();

  // Add the original offsets to obtain the final global offsets.
  for (const auto &sgOffsets : *maybeDescOffsets) {
    SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
        rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
    offsetsList.push_back(std::move(newOffsets));
  }
  return success();
}
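/// Distributes a workgroup-level xegpu.create_nd_tdesc into one descriptor
/// per subgroup, each carrying the subgroup shape and its own offsets.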
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();
    Type elemTy = tdescTy.getElementType();
    xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

    SmallVector<Value> newOps;
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
          op.getMixedSizes(), op.getMixedStrides());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
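/// Variant of the pattern above for xegpu.create_nd_tdesc ops created
/// without explicit offsets; it replicates the descriptor `count` times
/// with the subgroup shape.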
struct WgToSgCreateNdOpNoOffset
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // This pattern only handles the case without explicit offsets.
    if (!op.getMixedOffsets().empty())
      return failure();

    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Type elemTy = tdescTy.getElementType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();

    int count;
    SmallVector<int64_t> sgShape;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

    SmallVector<Value> newCreateNdOps(count);
    std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
      return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
                                           op.getSource(), op.getMixedSizes(),
                                           op.getMixedStrides());
    });
    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
    return success();
  }
};
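/// Distributes a workgroup-level xegpu.load_nd (without offsets) by emitting
/// one subgroup-level load per distributed tensor descriptor.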
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();

    SmallVector<Value> newLoadOps;
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
      ArrayRef<int64_t> srcShape = tdescTy.getShape();
      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
      auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy,
                                               src, op->getAttrs());
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return mlir::success();
  }
};
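/// Distributes a workgroup-level xegpu.store_nd (without offsets) into one
/// subgroup-level store per distributed value/descriptor pair.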
struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();

    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
                               op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};
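/// Distributes xegpu.load_nd ops that carry explicit offsets, using
/// genOffsetsList to derive the per-subgroup offsets.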
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();
    SmallVector<Value> newOps;
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
      VectorType newResTy =
          VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
      auto newOp = xegpu::LoadNdOp::create(
          rewriter, op.getLoc(), newResTy, tdesc, offsets,
          /*packed=*/nullptr, /*transpose=*/nullptr, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr(), layout);
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
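/// Distributes xegpu.store_nd ops that carry explicit offsets.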
struct WgToSgStoreNdOpWithOffset
    : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();
    for (auto [v, tdesc, offsets] :
         llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
                               op.getL1HintAttr(), op.getL2HintAttr(),
                               op.getL3HintAttr(), layout);
    }
    rewriter.eraseOp(op);
    return success();
  }
};
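/// Distributes xegpu.prefetch_nd ops that carry explicit offsets.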
struct WgToSgPrefetchNdOpWithOffset
    : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
                                  op.getL1HintAttr(), op.getL2HintAttr(),
                                  op.getL3HintAttr(), layout);
    }
    rewriter.eraseOp(op);
    return success();
  }
};
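/// Distributes xegpu.update_nd_offset by applying the same offset update to
/// every subgroup-level tensor descriptor.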
struct WgToSgUpdateNdOffsetOp
    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Value> newUpdateTileOffsetOps;
    for (auto tDesc : adaptor.getTensorDesc()) {
      auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
          op.getConstOffsets());
      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
    }
    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
    return success();
  }
};
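/// Distributes xegpu.dpas by pairing each distributed A tile with each
/// distributed B tile (and the matching accumulator, if present), and
/// dropping sg_layout/sg_data from the operand layouts.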
struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType resultTy = op.getResult().getType();
    if (resultTy.getRank() != 2)
      return failure();

    auto layoutCd = op.getLayoutCdAttr();
    auto layoutA = op.getLayoutAAttr();
    auto layoutB = op.getLayoutBAttr();
    if (!layoutCd || !layoutA || !layoutB)
      return failure();

    SmallVector<Value> newDpasOps;
    size_t i = 0;
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
        llvm::SmallVector<Value> operands({aVec, bVec});
        Value tmpC;
        if (op.getAcc()) {
          tmpC = adaptor.getAcc()[i++];
          operands.push_back(tmpC);
        }

        ArrayRef<int64_t> aVecShape =
            llvm::cast<VectorType>(aVec.getType()).getShape();
        ArrayRef<int64_t> bVecShape =
            llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
        auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
        newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
        newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
        newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
        newDpasOps.push_back(newDpasOp);
      }
    }
    rewriter.replaceOpWithMultiple(op, {newDpasOps});
    return success();
  }
};
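/// Distributes xegpu.prefetch_nd ops without offsets.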
struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    for (auto src : adaptor.getTensorDesc())
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
                                  op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
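/// Distributes vector.broadcast into subgroup-shaped broadcasts, provided
/// the workgroup shape is evenly distributable over the layout.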
struct WgToSgVectorBroadcastOp
    : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResult().getType();
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
      return failure();

    SmallVector<Value> newBroadcastOps;
    for (auto operand : adaptor.getOperands().front()) {
      auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                      newResultType, operand);
      xegpu::setTemporaryLayout(newBroadcast->getResult(0),
                                layout.dropSgLayoutAndData());
      newBroadcastOps.push_back(newBroadcast.getResult());
    }
    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
    return success();
  }
};
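/// Distributes elementwise math/arith ops: each distributed operand slice is
/// recombined into a subgroup-level clone of the original operation.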
struct WgToSgElementwiseOp : public ConversionPattern {
  WgToSgElementwiseOp(MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only elementwise-mappable ops with a single vector result reach this
    // pattern (the legality callback below enforces the same contract).
    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    assert(resultType && "Expected result to be a VectorType");
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op->getResult(0));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    // All operands must have been distributed into the same number of parts.
    size_t numVariants = operands.empty() ? 0 : operands.front().size();
    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
          return operandVec.size() != numVariants;
        }))
      return failure();

    SmallVector<Value> newResults;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    for (size_t i = 0; i < numVariants; ++i) {
      SmallVector<Value> opOperands;
      for (auto &operandVec : operands)
        opOperands.push_back(operandVec[i]);

      OperationState state(op->getLoc(), op->getName());
      state.addOperands(opOperands);
      state.addTypes(newResultType);
      // Copy attributes, dropping sg_layout/sg_data from layout attributes
      // that still carry lane- or inst-level information.
      for (NamedAttribute attr : op->getAttrs()) {
        if (auto layout =
                dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
          if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
              !layout.getEffectiveInstDataAsInt().empty())
            state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
        } else {
          state.addAttribute(attr.getName(), attr.getValue());
        }
      }
      Operation *newOp = rewriter.create(state);
      newResults.push_back(newOp->getResult(0));
    }

    rewriter.replaceOpWithMultiple(op, {newResults});
    return success();
  }
};
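/// Lowers xegpu.convert_layout at the workgroup level; currently only
/// conversions that keep sg_layout, sg_data, and order unchanged are
/// supported, so the op either folds away or is kept for inst_data.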
struct WgToSgConvertLayoutOp
    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
    auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());

    if (!input || !target || !input.isForWorkgroup() ||
        !target.isForWorkgroup())
      return rewriter.notifyMatchFailure(
          op, "Input and target layouts must have subgroup layout");

    DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
    DenseI32ArrayAttr inputSgData = input.getSgData();
    DenseI32ArrayAttr inputOrder = input.getOrder();
    DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
    DenseI32ArrayAttr targetSgData = target.getSgData();
    DenseI32ArrayAttr targetOrder = target.getOrder();

    // TODO: currently we only support the case where the input and target
    // layouts have the same sg_layout, sg_data, and order.
    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
        inputOrder != targetOrder)
      return failure();

    input = input.dropSgLayoutAndData();
    target = target.dropSgLayoutAndData();

    SmallVector<Value> newOps(adaptor.getSource());
    if (input && target) {
      // Keep the ConvertLayoutOp for the remaining fields, e.g., inst_data.
      for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
        auto newOp = xegpu::ConvertLayoutOp::create(
            rewriter, op.getLoc(), src.getType(), src, input, target);
        newOps[i] = newOp;
      }
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
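/// Folds away the N:1 and 1:N unrealized_conversion_cast ops introduced
/// while distributing values, forwarding the distributed values directly
/// when the input and result types line up.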
struct UnrealizedConversionCastOpPattern
    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
  using OpConversionPattern<
      mlir::UnrealizedConversionCastOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());

    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());

    if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
        !llvm::all_equal(ValueRange(inputs).getTypes()))
      return failure();

    // Case 1: the cast simply forwards values whose types already match the
    // result types one-to-one.
    if (op.getNumOperands() == 1 &&
        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
      rewriter.replaceOp(op, inputs);
      return success();
    }

    // Case 2: a single result is backed by multiple distributed inputs of
    // the same type; replace it with the distributed values.
    if (op.getNumResults() == 1 && inputTy == outputTy) {
      rewriter.replaceOpWithMultiple(op, {inputs});
      return success();
    }

    return mlir::failure();
  }
};
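/// Distributes arith.constant ops. Splat constants are simply reshaped to
/// the subgroup shape; non-splat index vectors with constant strides are
/// rebuilt from a base tile plus a per-subgroup linear offset.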
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
    auto vecType = dyn_cast<VectorType>(op.getType());
    if (!vecAttr || !vecType)
      return failure();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    ArrayRef<int64_t> wgShape = vecType.getShape();
    SmallVector<int64_t> sgShape;
    int count;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);

    auto newType = VectorType::get(sgShape, vecType.getElementType());
    Location loc = op.getLoc();
    auto eltType = vecType.getElementType();

    auto setLayout = [&](Value val) {
      xegpu::setTemporaryLayout(llvm::cast<OpResult>(val),
                                layout.dropSgLayoutAndData());
    };

    if (vecAttr.isSplat()) {
      // Splat: every subgroup gets the same subgroup-shaped constant.
      Attribute singleVal = vecAttr.getSplatValue<Attribute>();
      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
      setLayout(cstOp->getResult(0));
      rewriter.replaceOp(op, cstOp);
      return success();
    } else if (sgShape == wgShape) {
      // The entire vector is shared by all subgroups.
      auto newConstOp =
          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
      setLayout(newConstOp->getResult(0));
      rewriter.replaceOp(op, newConstOp);
      return success();
    }

    // Non-splat constants: only strided index vectors are supported.
    if (!eltType.isIndex())
      return rewriter.notifyMatchFailure(
          op, "Unsupported element type for non-splat constant op.");

    if (wgShape.size() > 2)
      return rewriter.notifyMatchFailure(
          op, "Only 1D & 2D vector constant supported");

    SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
    int64_t rowStride = 0, colStride = 0;
    int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
    int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];

    // Derive the strides from the first row and column.
    if (cols > 1)
      colStride = cast<IntegerAttr>(values[1]).getInt() -
                  cast<IntegerAttr>(values[0]).getInt();
    if (rows > 1)
      rowStride = cast<IntegerAttr>(values[cols]).getInt() -
                  cast<IntegerAttr>(values[0]).getInt();

    // Verify that the values follow a constant stride in both dimensions.
    for (int64_t r = 0; r < rows; ++r) {
      for (int64_t c = 0; c < cols; ++c) {
        int64_t idx = r * cols + c;
        if (c > 0 && cols > 1) {
          int64_t prevIdx = r * cols + (c - 1);
          int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                         cast<IntegerAttr>(values[prevIdx]).getInt();
          if (diff != colStride)
            return rewriter.notifyMatchFailure(
                op, "Non-constant column stride in constant op.");
        }
        if (r > 0 && rows > 1) {
          int64_t prevIdx = (r - 1) * cols + c;
          int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                         cast<IntegerAttr>(values[prevIdx]).getInt();
          if (diff != rowStride)
            return rewriter.notifyMatchFailure(
                op, "Non-constant row stride in constant op.");
        }
      }
    }

    // Build the base tile (the sub-constant owned by subgroup 0).
    SmallVector<Attribute> baseTileValues;
    int baseTileCols = sgShape[sgShape.size() - 1];
    int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
    for (int64_t r = 0; r < baseTileRows; ++r) {
      for (int64_t c = 0; c < baseTileCols; ++c) {
        baseTileValues.push_back(values[r * cols + c]);
      }
    }

    auto tileAttr = DenseElementsAttr::get(newType, baseTileValues);
    auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);

    // Compute the per-subgroup coordinates.
    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    SmallVector<Value, 2> strideConsts;
    strideConsts.push_back(
        arith::ConstantIndexOp::create(rewriter, loc, colStride));
    if (rows > 1)
      strideConsts.insert(
          strideConsts.begin(),
          arith::ConstantIndexOp::create(rewriter, loc, rowStride));

    SmallVector<Value> newConstOps;
    for (auto offsets : *sgOffsets) {
      // Linearize the subgroup offset: sum(offset[i] * stride[i]).
      Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
      for (size_t i = 0; i < strideConsts.size(); ++i) {
        Value mul =
            arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
                                  offsets[i], strideConsts[i]);
        mulOffset = arith::AddIOp::create(
            rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
      }
      // Broadcast to the subgroup shape and add to the base tile.
      auto bcastOffset = vector::BroadcastOp::create(
          rewriter, loc, baseConstVec.getType(), mulOffset);
      auto finalConst =
          arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
      setLayout(baseConstVec);
      setLayout(bcastOffset);
      setLayout(finalConst);
      newConstOps.push_back(finalConst);
    }
    rewriter.replaceOpWithMultiple(op, {newConstOps});
    return success();
  }
};
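/// Distributes xegpu.load (gather) ops with explicit offsets; the offsets
/// and mask must already have been distributed to the subgroup shape.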
struct WgToSgLoadGatherOpWithOffset
    : public OpConversionPattern<xegpu::LoadGatherOp> {
  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getOffsets())
      return failure();

    Location loc = op.getLoc();
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    // The offsets need to be distributed to match the mask shape.
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }

    SmallVector<Value> newLoadOps;
    auto chunkSizeAttr =
        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
    for (auto [offsets, mask] :
         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
      auto newLayout = layout.dropSgLayoutAndData();
      auto newLoadOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
          /*layout=*/nullptr);
      newLoadOp.setAnchorLayout(newLayout);
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return success();
  }
};
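/// Distributes xegpu.store (scatter) ops with explicit offsets; mirrors the
/// gather pattern above and re-tags the distributed operands' layouts.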
struct WgToSgStoreScatterOpWithOffset
    : public OpConversionPattern<xegpu::StoreScatterOp> {
  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getOffsets())
      return failure();

    Location loc = op.getLoc();
    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
    if (!valueType)
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (!layout || !layout.isForWorkgroup())
      return failure();

    // The offsets need to be distributed to match the mask shape.
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }

    auto chunkSizeOpt = op.getChunkSize();
    int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
    for (auto [val, offs, mask] : llvm::zip(
             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
      auto store = xegpu::StoreScatterOp::create(
          rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
          layout.dropSgLayoutAndData());
      // Skip the dest operand (operand 1); tag the remaining operands with
      // the distributed layout.
      for (OpOperand &operand : store->getOpOperands()) {
        if (operand.getOperandNumber() == 1)
          continue;
        xegpu::setTemporaryLayout(operand, layout.dropSgLayoutAndData());
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
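/// Distributes xegpu.load_matrix using genOffsetsList for the per-subgroup
/// offsets into the shared memory descriptor.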
struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    ArrayRef<int64_t> wgShape = op.getDataShape();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
    assert(valueTy && "the value type must be vector type!");
    Type elemTy = valueTy.getElementType();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResTy = VectorType::get(sgShape, elemTy);
    SmallVector<Value> newOps;
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
                                               op.getMemDesc(), offsets,
                                               layout.dropSgLayoutAndData());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
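/// Distributes xegpu.store_matrix analogously.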
struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
                                   offsets, layout.dropSgLayoutAndData());
    rewriter.eraseOp(op);
    return success();
  }
};
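/// Distributes vector.step: each subgroup keeps a local step vector and adds
/// its distributed base offset to it.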
struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();
    std::optional<SmallVector<int64_t>> sgShape =
        getSgShapeAndCount(wgShape, layout).first;
    if (!sgShape)
      return failure();

    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
    auto steps = vector::StepOp::create(rewriter, loc, newTy);
    SmallVector<Value> newOps;
    for (auto offsets : *sgOffsets) {
      // Broadcast the subgroup's base offset and add it to the local steps.
      auto bcastOffset =
          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
      auto finalSteps =
          arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
      xegpu::setTemporaryLayout(steps->getResult(0),
                                layout.dropSgLayoutAndData());
      xegpu::setTemporaryLayout(bcastOffset->getResult(0),
                                layout.dropSgLayoutAndData());
      xegpu::setTemporaryLayout(finalSteps->getResult(0),
                                layout.dropSgLayoutAndData());
      newOps.push_back(finalSteps);
    }

    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
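/// Distributes vector.shape_cast. Shape casts that only expand unit
/// dimensions are handled specially: the result layout must be a slice of
/// the source layout and all users must be broadcasts.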
struct WgToSgVectorShapeCastOp
    : public OpConversionPattern<vector::ShapeCastOp> {
  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();

    ArrayRef<int64_t> wgShape = resultType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    auto srcType = dyn_cast<VectorType>(op.getSource().getType());
    if (!srcType)
      return failure();

    ArrayRef<int64_t> srcShape = srcType.getShape();
    llvm::SetVector<int64_t> expandedUnitDims;

    // Check whether the shape_cast only expands unit dimensions
    // (e.g., 32x64 -> 32x1x64, or 32x64 -> 1x32x64).
    auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
                                       ArrayRef<int64_t> dst) -> bool {
      size_t srcIdx = 0;
      for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
        if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
          ++srcIdx;
        else if (dst[dstIdx] == 1)
          expandedUnitDims.insert(dstIdx);
        else
          return false;
      return srcIdx == src.size();
    };

    xegpu::DistributeLayoutAttr layoutToDistribute = layout;

    if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
      xegpu::DistributeLayoutAttr sourceLayout =
          xegpu::getDistributeLayoutAttr(op.getSource());

      auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
        return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
          return isa<vector::BroadcastOp>(user);
        });
      };
      if (!usedByBroadcastOp(op))
        return rewriter.notifyMatchFailure(
            op, "ShapeCast ops that expand unit dimensions and are used by "
                "non-broadcast operations are not supported.");

      if (!sourceLayout.isSliceOf(layout))
        return rewriter.notifyMatchFailure(
            op, "The ShapeCast op only expands dimensions, the result layout "
                "must be a slice of the input layout, or vice versa.");
      layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims);
      layoutToDistribute =
          layoutToDistribute.setUnitDimLayout(expandedUnitDims);
    }

    SmallVector<int64_t> sgShape =
        getSgShapeAndCount(wgShape, layoutToDistribute).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    SmallVector<Value> newShapeCastOps;
    for (auto src : adaptor.getSource()) {
      auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
                                                      newResultType, src);
      xegpu::setTemporaryLayout(newShapeCast->getResult(0),
                                layout.dropSgLayoutAndData());
      newShapeCastOps.push_back(newShapeCast.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
    return success();
  }
};
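/// Distributes vector.multi_reduction, provided each reduced dimension is
/// owned entirely by a single subgroup (sg_layout == 1 in that dim).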
struct WgToSgMultiDimReductionOp
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
    if (!dstType)
      return failure();

    auto srcShape = srcType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    auto reductionDims = llvm::to_vector(op.getReductionDims());

    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
                                        .getParent()
                                        .getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
                                      .getParent()
                                      .getEffectiveSgDataAsInt();

    // Check that sgLayout is 1 in each reduced dimension, i.e., each subgroup
    // owns the entire slice it reduces.
    for (int64_t dim : reductionDims) {
      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
        return rewriter.notifyMatchFailure(
            op,
            "sgLayout in each reduced dimension must be 1 and sgData in the "
            "reduced dim must match srcShape in that dim");
    }

    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;

    VectorType newDstType =
        VectorType::get({sgShape}, dstType.getElementType());

    SmallVector<Value> newReductions;
    for (auto sgSrc : adaptor.getSource()) {
      auto newOp = vector::MultiDimReductionOp::create(
          rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
          adaptor.getAcc()[0], op.getReductionDims());
      xegpu::setTemporaryLayout(newOp->getResult(0),
                                layout.dropSgLayoutAndData());
      newReductions.push_back(newOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newReductions});
    return success();
  }
};
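/// Distributes vector.transpose; the result layout must be the source layout
/// transposed by the same permutation.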
struct WgToSgVectorTransposeOp
    : public OpConversionPattern<vector::TransposeOp> {
  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResultVectorType();

    ArrayRef<int64_t> wgShape = resultType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(op.getVector());
    if (!sourceLayout || !sourceLayout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sourceSgLayout =
        sourceLayout.getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
    DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
    DenseI32ArrayAttr resultOrder = layout.getOrder();

    if (!sourceOrder || !resultOrder) {
      return rewriter.notifyMatchFailure(
          op, "Both source and result must have order attributes");
    }

    ArrayRef<int64_t> permutation = op.getPermutation();
    size_t permutationSize = permutation.size();
    if (sourceSgLayout.size() != permutationSize ||
        resultSgLayout.size() != permutationSize) {
      return rewriter.notifyMatchFailure(
          op, "Layouts and permutation must have the same rank");
    }

    // Check that the result layout is the transpose of the source layout
    // according to the permutation.
    if (!layout.isTransposeOf(sourceLayout, permutation))
      return rewriter.notifyMatchFailure(
          op, "Result layout is not a valid transpose of source layout "
              "according to permutation");

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());
    SmallVector<Value> newTransposeOps;
    for (auto src : adaptor.getVector()) {
      auto newTranspose = vector::TransposeOp::create(
          rewriter, op.getLoc(), newResultType, src, permutation);
      xegpu::setTemporaryLayout(newTranspose->getResult(0),
                                layout.dropSgLayoutAndData());
      newTransposeOps.push_back(newTranspose.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newTransposeOps});
    return success();
  }
};
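/// Distributes vector.constant_mask and vector.create_mask: each subgroup's
/// mask sizes become clamp(wgMaskSize - sgOffset, 0, sgDimSize) per
/// dimension.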
template <typename MaskOpType>
struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
  using OpConversionPattern<MaskOpType>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      MaskOpType op,
      typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();

    SmallVector<Value> wgMaskDimSizes;
    if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
      for (int64_t maskSize : op.getMaskDimSizes()) {
        wgMaskDimSizes.push_back(
            arith::ConstantIndexOp::create(rewriter, loc, maskSize));
      }
    } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
      wgMaskDimSizes = llvm::to_vector(op.getOperands());
    }

    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType resultType = VectorType::get(sgShape, type.getElementType());

    SmallVector<Value> newCreateMaskOps;
    for (auto offsetSet : *sgOffsets) {
      SmallVector<Value> maskOperands;
      for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
        Value dimSizeVal =
            arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
        Value offset = offsetSet[i];
        Value adjustedMaskSize =
            arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
        Value nonNegative =
            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
        Value sgMaskSize =
            arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
        maskOperands.push_back(sgMaskSize);
      }

      auto newCreateMaskOp =
          vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
      xegpu::setTemporaryLayout(newCreateMaskOp->getResult(0),
                                layout.dropSgLayoutAndData());
      newCreateMaskOps.push_back(newCreateMaskOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
    return success();
  }
};

using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
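/// Appends patterns for XeGPU workgroup to subgroup distribution into
/// `patterns`.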
void xegpu::populateXeGPUWgToSgDistributePatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
           WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
           WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
           WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
          patterns.getContext());
}
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
void XeGPUWgToSgDistributePass::runOnOperation() {
  // Track existing UnrealizedConversionCastOps so that only the casts
  // introduced by this pass are folded away afterwards.
  SmallVector<Operation *> existingCastOps;
  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
    existingCastOps.push_back(castOp.getOperation());
  });

  {
    // Step 1: Apply SCF structural type conversion so that vector values
    // carried across scf ops are expanded into per-subgroup values.
    TypeConverter converter;
    converter.addConversion([&](Type type) -> Type { return type; });
    converter.addConversion(
        [&](RankedTensorType type,
            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
          Type elemTy = type.getElementType();
          ArrayRef<int64_t> shape = type.getShape();

          int count;
          SmallVector<int64_t> subShape;
          std::tie(subShape, count) = getSgShapeAndCount(
              shape,
              dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));

          auto newTy = VectorType::get(subShape, elemTy);
          result.append(count, newTy);
          return success();
        });

    xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
                                                       converter);
  }
  // Step 2: Distribute TensorDesc values and the XeGPU, Arith, and Vector
  // operations from workgroup to subgroup level.
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);
  TypeConverter converter;
  converter.addConversion([&](Type type) -> Type { return type; });
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        int count;
        SmallVector<int64_t> subShape;
        xegpu::LayoutAttr layout = type.getLayoutAttr();
        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);

        if (layout)
          layout = layout.dropSgLayoutAndData();

        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });
  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
      return updateOp.getTensorDescType();
    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };

  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
    return !layout || !layout.isForWorkgroup();
  };
  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
    auto layout = op.getLayoutCdAttr();
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
      [=](xegpu::LoadMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
      [=](xegpu::StoreMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [=](arith::ConstantOp op) -> bool {
        auto vecType = dyn_cast<VectorType>(op.getType());
        if (!vecType)
          return true;
        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
                               vector::TransposeOp, vector::BroadcastOp,
                               vector::MultiDimReductionOp,
                               vector::ConstantMaskOp, vector::CreateMaskOp>(
      [=](Operation *op) -> bool {
        auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
      [=](xegpu::LoadGatherOp op) -> bool {
        auto layout = op.getLayoutAttr();
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
      [=](xegpu::StoreScatterOp op) -> bool {
        auto layout = op.getLayoutAttr();
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
      [=](xegpu::ConvertLayoutOp op) -> bool {
        return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
      });

  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        // Only handle elementwise-mappable ops with a single vector result.
        if (!OpTrait::hasElementwiseMappableTraits(op) ||
            op->getNumResults() != 1)
          return true;
        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultType)
          return true;

        // All operands must be vectors of the same shape as the result.
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType || operandType.getShape() != resultType.getShape()) {
            return true;
          }
        }

        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(op->getResult(0));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
      [=](UnrealizedConversionCastOp op) {
        return llvm::is_contained(existingCastOps, op.getOperation());
      });

  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                       target);
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();
  // Drop sg_layout and sg_data from the temporary layout attributes attached
  // to op results during distribution; the attribute is removed entirely if
  // nothing remains after the drop.
  getOperation()->walk([](Operation *op) {
    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getTemporaryLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
          if (auto newLayout = layout.dropSgLayoutAndData())
            op->setAttr(name, newLayout);
        }
      }
    }
  });
}