28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
40#define DEBUG_TYPE "shard-ops"
45#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
50 static DimensionSize dynamic() {
return DimensionSize(ShapedType::kDynamic); }
51 DimensionSize(
int64_t val) : val(val) {}
52 int64_t value()
const {
return val; }
53 operator int64_t()
const {
return val; }
54 bool isDynamic()
const {
return ShapedType::isDynamic(val); }
63 if (
lhs.isDynamic() ||
rhs.isDynamic()) {
64 return DimensionSize::dynamic();
66 return lhs.value() /
rhs.value();
70 if (
lhs.isDynamic() ||
rhs.isDynamic()) {
71 return DimensionSize::dynamic();
73 return lhs.value() *
rhs.value();
81 auto dyn = dynamics.begin();
82 Type i64 =
b.getI64Type();
85 assert((i64 == type ||
b.getIndexType() == type) &&
86 "expected an i64 or an intex type");
87 for (
auto s : statics) {
88 if (s == ShapedType::kDynamic) {
89 values.emplace_back(*(dyn++));
91 TypedAttr val = type == i64 ?
b.getI64IntegerAttr(s) :
b.getIndexAttr(s);
92 values.emplace_back(arith::ConstantOp::create(
b, loc, type, val));
103struct ShardInlinerinterface :
public DialectInlinerInterface {
104 using DialectInlinerInterface::DialectInlinerInterface;
112 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
122void ShardDialect::initialize() {
125#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
128#define GET_ATTRDEF_LIST
129#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
132#define GET_TYPEDEF_LIST
133#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
135 addInterface<ShardInlinerinterface>();
141 return arith::ConstantOp::materialize(builder, value, type, loc);
151 shard::GridOp grid =
getGridOrNull(op, gridSymbol, symbolTable);
153 return op->
emitError() <<
"Undefined required grid symbol \""
160template <
typename It>
165 It next = std::next(begin);
169 for (; next != end; ++next, ++begin) {
170 if (*begin == *next) {
181 if (!
isUnique(sorted.begin(), sorted.end())) {
182 return emitError(loc) <<
"Grid axes contains duplicate elements.";
186 for (
auto axis : axes) {
187 if (axis >= rank || axis < 0) {
189 <<
"0-based grid axis index " << axis
190 <<
" is out of bounds. The referenced grid \"" << grid.getSymName()
191 <<
"\" is of rank " << rank <<
".";
198template <
typename Op>
199static FailureOr<GridOp>
212template <
typename InShape,
typename GridShape,
typename SplitAxes,
214static void shardShape(
const InShape &inShape,
const GridShape &gridShape,
215 const SplitAxes &splitAxes, OutShape &outShape,
219 if (inShape.empty()) {
220 assert(outShape.empty());
224 std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
225 llvm::adl_begin(outShape));
227 if (!shardedDimsOffsets.empty()) {
228 auto isDynShape = ShapedType::isDynamicShape(gridShape);
230 for (
auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
231 if (!innerSplitAxes.empty()) {
232 auto sz = shardedDimsOffsets[pos];
233 bool same = !isDynShape;
238 uint64_t numShards = 0;
239 for (
auto i : innerSplitAxes.asArrayRef()) {
240 numShards += gridShape[i];
242 for (
size_t i = 1; i < numShards; ++i) {
243 if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
249 pos += numShards + 1;
251 outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
255 for (
auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
261 if (!haloSizes.empty()) {
264 for (
auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
265 if (ShapedType::isStatic(outShape[tensorAxis]) &&
266 !innerSplitAxes.empty()) {
267 if (haloSizes[haloAxis * 2] >= 0 &&
268 haloSizes[haloAxis * 2 + 1] >= 0) {
269 outShape[tensorAxis] +=
270 haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
273 outShape[tensorAxis] = ShapedType::kDynamic;
283 using Dim = std::decay_t<
decltype(
shape.getDimSize(0))>;
288 return shape.clone(resShapeArr);
292 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
293 if (rankedTensorType && !rankedTensorType.getShape().empty()) {
303 ShardOp &newShardOp) {
306 ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
307 if (shardOp && sharding == shardOp.getSharding() &&
308 !shardOp.getAnnotateForUsers()) {
311 newShardOp = shardOp;
318 ShardingOp::create(builder, operandValue.
getLoc(), sharding);
319 newShardOp = ShardOp::create(builder, operandValue.
getLoc(), operandValue,
324 newShardOp, [operandOp, operandValue](
OpOperand &use) {
325 return use.
getOwner() == operandOp && use.
get() == operandValue;
328 if (!shardOp || shardOp.getAnnotateForUsers()) {
332 auto newShardOp2 = ShardOp::create(builder, operandValue.
getLoc(), newShardOp,
333 newShardOp.getSharding(),
335 newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
343 for (
auto &use :
result.getUses()) {
344 uses.emplace_back(use.get(), use.getOwner());
346 for (
auto &[operandValue, operandOp] : uses) {
348 builder, newShardOp);
358 bool isBlockArg = !operandSrcOp;
360 [[maybe_unused]]
auto opType =
361 dyn_cast<mlir::RankedTensorType>(operandValue.
getType());
364 if (!isa<RankedTensorType>(operandValue.
getType()) && operandSrcOp &&
370 ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
372 if (shardOp && sharding == shardOp.getSharding() &&
373 shardOp.getAnnotateForUsers()) {
380 ShardingOp::create(builder, operand.
get().
getLoc(), sharding);
382 ShardOp::create(builder, operandValue.
getLoc(), operandValue, shardingOp,
386 operandValue, newShardOp, [operandOp, operandValue](
OpOperand &use) {
387 return use.
getOwner() == operandOp && use.
get() == operandValue;
390 if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
396 auto newPreceedingShardOp =
397 ShardOp::create(builder, operandValue.
getLoc(), operandValue, shardingOp,
400 newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](
OpOperand &use) {
401 return use.getOwner() == newShardOp.getOperation();
409LogicalResult GridOp::verify() {
413 return emitOpError(
"rank of grid is expected to be a positive integer");
416 if (dimSize < 0 && ShapedType::isStatic(dimSize))
417 return emitOpError(
"dimension size of a grid is expected to be "
418 "non-negative or dynamic");
438 size_t expectedResultsCount =
439 getAxes().empty() ? grid->getRank() : getAxes().size();
440 if (getResult().size() != expectedResultsCount) {
441 return emitError() <<
"Unexpected number of results " << getResult().size()
442 <<
". Expected " << expectedResultsCount <<
".";
455 build(odsBuilder, odsState,
463 assert(!axes.empty());
464 build(odsBuilder, odsState,
469void GridShapeOp::getAsmResultNames(
471 setNameFn(getResults()[0],
"grid_shape");
483 b, odsState, grid, GridAxesArrayAttr::get(
b.getContext(), splitAxes),
493 GridAxesArrayAttr::get(
b.getContext(), splitAxes),
499void ShardingOp::build(
509 b, odsState, grid, GridAxesArrayAttr::get(
b.getContext(), splitAxes),
517 build(
b, odsState, ShardingType::get(
b.getContext()), from.
getGridAttr(),
529LogicalResult ShardingOp::verify() {
530 llvm::SmallSet<GridAxis, 4> visitedAxes;
535 return emitError() <<
"grid axis is expected to be non-negative";
536 if (!visitedAxes.insert(axis).second)
537 return emitError() <<
"grid axis duplicated";
542 for (
auto subAxes : getSplitAxes().getAxes()) {
544 if (
failed(checkGridAxis(subAxesArray)))
548 if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
549 return emitOpError(
"halo sizes and shard offsets are mutually exclusive");
552 if (!getStaticHaloSizes().empty()) {
553 auto numSplitAxes = getSplitAxes().getAxes().size();
554 for (
auto splitAxis : getSplitAxes().getAxes()) {
555 if (splitAxis.empty()) {
559 if (getStaticHaloSizes().size() != numSplitAxes * 2) {
560 return emitError() <<
"halo sizes must be specified for all split axes.";
567void ShardingOp::getAsmResultNames(
569 setNameFn(getResult(),
"sharding");
577 if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
578 !getStaticShardedDimsOffsets().empty()) {
579 return emitError() <<
"sharded dims offsets are not allowed for "
580 "device grids with dynamic shape.";
583 auto shardedDimsOffsets = getStaticShardedDimsOffsets();
584 if (!shardedDimsOffsets.empty()) {
585 auto gridShape = grid.value().getShape();
586 assert(ShapedType::isStaticShape(gridShape));
588 for (
auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
589 if (!innerSplitAxes.empty()) {
590 int64_t numShards = 0, off = 0;
591 for (
auto i : innerSplitAxes.asArrayRef()) {
592 numShards += gridShape[i];
594 for (
int64_t i = 0; i <= numShards; ++i) {
595 if (shardedDimsOffsets.size() <= pos + i) {
596 return emitError() <<
"sharded dims offsets has wrong size.";
598 if (ShapedType::isStatic(shardedDimsOffsets[pos + i])) {
599 if (shardedDimsOffsets[pos + i] < off) {
601 <<
"sharded dims offsets must be non-decreasing.";
603 off = shardedDimsOffsets[pos + i];
606 pos += numShards + 1;
621 using OpRewritePattern<ShardingOp>::OpRewritePattern;
623 LogicalResult matchAndRewrite(ShardingOp op,
624 PatternRewriter &
b)
const override {
628 op.getDynamicShardedDimsOffsets(),
b);
635 auto staticHalos = decomposedHalos.first;
636 auto dynamicHalos = decomposedHalos.second;
638 auto staticOffs = decomposedOffs.first;
639 auto dynamicOffs = decomposedOffs.second;
641 if (dynamicHalos.empty() && !staticHalos.empty()) {
642 if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
653 if (dynamicOffs.empty() && !staticOffs.empty()) {
654 assert(staticOffs.size() >= 2);
655 auto diff = staticOffs[1] - staticOffs[0];
656 bool allSame = staticOffs.size() > 2;
657 for (
auto i = 2u; i < staticOffs.size(); ++i) {
658 if (staticOffs[i] - staticOffs[i - 1] != diff) {
673 b.modifyOpInPlace(op, [&]() {
674 op.setStaticHaloSizes(staticHalos);
675 op.getDynamicHaloSizesMutable().assign(dynamicHalos);
676 op.setStaticShardedDimsOffsets(staticOffs);
677 op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
686 results.
add<NormalizeSharding>(context);
698 auto minSize = std::min(
getSplitAxes().size(),
rhs.getSplitAxes().size());
699 if (!llvm::equal(llvm::make_range(
getSplitAxes().begin(),
701 llvm::make_range(
rhs.getSplitAxes().begin(),
702 rhs.getSplitAxes().begin() + minSize))) {
706 return llvm::all_of(llvm::drop_begin(
getSplitAxes(), minSize),
707 std::mem_fn(&GridAxesAttr::empty)) &&
708 llvm::all_of(llvm::drop_begin(
rhs.getSplitAxes(), minSize),
709 std::mem_fn(&GridAxesAttr::empty));
717 if (
rhs.getStaticShardedDimsOffsets().size() !=
720 rhs.getStaticShardedDimsOffsets())) {
723 if (
rhs.getDynamicShardedDimsOffsets().size() !=
726 rhs.getDynamicShardedDimsOffsets())) {
758 os <<
"Sharding<grid=" << sharding.
getGrid() <<
", split_axes=[";
761 llvm::interleaveComma(axes.asArrayRef(), os);
766 os <<
", halo_sizes=[";
771 os <<
", sharded_dims_offsets=[";
782 auto shardingOp =
rhs.getDefiningOp<ShardingOp>();
783 assert(shardingOp &&
"expected sharding op");
784 auto splitAxes = shardingOp.getSplitAxes().getAxes();
786 if (splitAxes.empty()) {
787 *
this =
Sharding(shardingOp.getGridAttr());
791 get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
792 shardingOp.getStaticShardedDimsOffsets(),
804 if (splitAxes.empty()) {
808 res.split_axes.resize(splitAxes.size());
809 for (
auto [i, axis] : llvm::enumerate(splitAxes)) {
813 auto clone = [](
const auto src,
auto &dst) {
814 dst.resize(src.size());
815 llvm::copy(src, dst.begin());
818 clone(staticHaloSizes, res.static_halo_sizes);
819 clone(staticShardedDimsOffsets, res.static_sharded_dims_offsets);
820 clone(dynamicHaloSizes, res.dynamic_halo_sizes);
821 clone(dynamicShardedDimsOffsets, res.dynamic_sharded_dims_offsets);
830void ShardShapeOp::getAsmResultNames(
832 setNameFn(getResult()[0],
"shard_shape");
841 build(odsBuilder, odsState, resType, dims, dimsDyn, sharding,
849void ShardOp::getAsmResultNames(
851 setNameFn(getResult(),
"sharding_annotated");
857class FoldDuplicateShardOp final :
public OpRewritePattern<ShardOp> {
859 using OpRewritePattern<ShardOp>::OpRewritePattern;
861 LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &
b)
const override {
864 Value value = op.getSrc();
870 for (
auto &use : value.
getUses()) {
871 if (use.getOwner() != op.getOperation()) {
872 auto otherOp = dyn_cast<ShardOp>(use.getOwner());
873 if (!otherOp || !otherOp->isBeforeInBlock(op)) {
878 Sharding currentSharding(op.getSharding());
879 Sharding otherSharding(otherOp.getSharding());
880 if (currentSharding == otherSharding) {
881 b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
882 b.eraseOp(op.getOperation());
886 op, [&]() { op.getSrcMutable().assign(otherOp.getResult()); });
897void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
898 mlir::MLIRContext *context) {
899 results.
add<FoldDuplicateShardOp>(context);
907ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
916 size_t expectedResultsCount =
917 getAxes().empty() ? grid->getRank() : getAxes().size();
918 if (getResult().size() != expectedResultsCount) {
919 return emitError() <<
"Unexpected number of results " << getResult().size()
920 <<
". Expected " << expectedResultsCount <<
".";
926void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
928 build(odsBuilder, odsState,
929 SmallVector<Type>(grid.getRank(), odsBuilder.
getIndexType()),
930 grid.getSymName(), ArrayRef<GridAxis>());
933void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
934 StringRef grid, ArrayRef<GridAxis> axes) {
935 build(odsBuilder, odsState,
936 SmallVector<Type>(axes.size(), odsBuilder.
getIndexType()), grid,
940void ProcessMultiIndexOp::getAsmResultNames(
942 setNameFn(getResults()[0],
"proc_linear_idx");
950ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
958void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
959 OperationState &odsState, GridOp grid) {
960 build(odsBuilder, odsState, grid.getSymName());
963void ProcessLinearIndexOp::getAsmResultNames(
965 setNameFn(getResult(),
"proc_linear_idx");
973NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
981void NeighborsLinearIndicesOp::getAsmResultNames(
983 setNameFn(getNeighborDown(),
"down_linear_idx");
984 setNameFn(getNeighborUp(),
"up_linear_idx");
993template <
typename Op>
994struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
995 using OpRewritePattern<
Op>::OpRewritePattern;
996 LogicalResult matchAndRewrite(Op op,
997 PatternRewriter &rewriter)
const override {
998 auto gridAxes = op.getGridAxes();
999 if (!gridAxes.empty()) {
1002 if (op.getInput().getType() != op.getResult().getType()) {
1019 if (device.size() != gridAxes.size()) {
1020 return emitError(loc) <<
"In-group device \"" << deviceName
1021 <<
"\" has unexpected multi-index size "
1022 << device.size() <<
". Expected " << gridAxes.size()
1026 for (
size_t i = 0; i < device.size(); ++i) {
1027 if (ShapedType::isStatic(device[i]) &&
1028 ShapedType::isStatic(gridShape[gridAxes[i]]) &&
1029 gridShape[gridAxes[i]] <= device[i]) {
1031 <<
"Out of bounds coordinate " << i <<
" for in-group device \""
1032 << deviceName <<
"\"."
1033 <<
" Got " << device[i] <<
", but expected value in the range [0, "
1034 << (gridShape[gridAxes[i]] - 1) <<
"].";
1044 if (ShapedType::isStatic(resultDimSize) && expectedDimSize != resultDimSize) {
1045 return emitError(loc) <<
"Dimension size mismatch for result axis "
1046 << resultAxis <<
". Expected "
1047 << (ShapedType::isDynamic(expectedDimSize)
1049 : Twine(expectedDimSize))
1050 <<
", but got " << resultDimSize <<
".";
1059 auto resultRank = cast<ShapedType>(
result.getType()).getRank();
1060 if (gatherAxis < 0 || gatherAxis >= resultRank) {
1062 <<
"Gather axis " << gatherAxis <<
" is out of bounds [0, "
1063 << resultRank <<
").";
1066 ShapedType operandType = cast<ShapedType>(operand.
getType());
1067 ShapedType resultType = cast<ShapedType>(
result.getType());
1068 auto deviceGroupSize =
1070 for (
int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1071 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
1072 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
1073 auto expectedResultDimSize =
1074 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
1076 result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
1086 ShapedType operandType = cast<ShapedType>(operand.
getType());
1087 ShapedType resultType = cast<ShapedType>(
result.getType());
1088 for (
int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1089 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
1091 result.getLoc(), operandType.getDimSize(axis),
1092 resultType.getDimSize(axis), axis))) {
1098 if (splitAxis == concatAxis) {
1102 auto deviceGroupSize =
1104 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
1105 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
1106 DimensionSize expectedResultConcatDimSize =
1107 operandConcatDimSize * deviceGroupSize;
1108 DimensionSize expectedResultSplitDimSize =
1109 operandSplitDimSize / deviceGroupSize;
1110 if (!expectedResultSplitDimSize.isDynamic() &&
1112 expectedResultSplitDimSize = DimensionSize::dynamic();
1115 result.getLoc(), expectedResultConcatDimSize.value(),
1116 resultType.getDimSize(concatAxis), concatAxis))) {
1120 result.getLoc(), expectedResultSplitDimSize.value(),
1121 resultType.getDimSize(splitAxis), splitAxis))) {
1131 ShapedType operandType = cast<ShapedType>(operand.
getType());
1132 ShapedType resultType = cast<ShapedType>(
result.getType());
1133 for (
int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1134 if (axis != tensorAxis) {
1136 result.getLoc(), operandType.getDimSize(axis),
1137 resultType.getDimSize(axis), axis))) {
1143 auto deviceGroupSize =
1145 auto operandScatterDimSize =
1146 DimensionSize(operandType.getDimSize(tensorAxis));
1147 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1148 int64_t(operandScatterDimSize) %
int64_t(deviceGroupSize) != 0) {
1150 <<
"Operand dimension size " <<
int64_t(operandScatterDimSize)
1151 <<
" is not divisible by collective device group size "
1152 <<
int64_t(deviceGroupSize) <<
" for tensor axis " << tensorAxis
1155 DimensionSize expectedResultTensorDimSize =
1156 operandScatterDimSize / deviceGroupSize;
1158 result.getLoc(), expectedResultTensorDimSize.value(),
1159 resultType.getDimSize(tensorAxis), tensorAxis))) {
1169 RankedTensorType operandRankedTensorType =
1170 cast<RankedTensorType>(operandType);
1171 DimensionSize operandSliceAxisSize =
1172 operandRankedTensorType.getShape()[sliceAxis];
1174 llvm::to_vector(operandRankedTensorType.getShape());
1176 resultShape[sliceAxis] =
1177 operandSliceAxisSize /
1179 return operandRankedTensorType.clone(resultShape);
1187AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1192 auto gatherAxis = getGatherAxis().getSExtValue();
1194 gatherAxis, getGridAxes(),
1195 grid.value().getShape());
1198void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1199 MLIRContext *context) {
1200 patterns.
add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
1203void AllGatherOp::getAsmResultNames(
1205 setNameFn(getResult(),
"all_gather");
1213AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1217void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1218 MLIRContext *context) {
1219 patterns.
add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
1222void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1223 Value input, StringRef grid,
1224 ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
1225 build(odsBuilder, odsState, input.
getType(), grid, gridAxes, input,
1229void AllReduceOp::getAsmResultNames(
1231 setNameFn(getResult(),
"all_reduce");
1238LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1244 getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
1245 grid.value().getShape());
1248void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1249 MLIRContext *context) {
1250 patterns.
add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
1253void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1254 Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
1255 int64_t sliceAxis) {
1257 build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
1261void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1262 Type resultType, Value input, StringRef grid,
1263 ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
1264 build(odsBuilder, odsState, resultType, grid, gridAxes, input,
1265 APInt(
sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1268void AllSliceOp::getAsmResultNames(
1270 setNameFn(getResult(),
"all_slice");
1277LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1284 getOperand(), getResult(), getSplitAxis().getSExtValue(),
1285 getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
1288void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1289 MLIRContext *context) {
1290 patterns.
add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
1293void AllToAllOp::getAsmResultNames(
1295 setNameFn(getResult(),
"all_to_all");
1303BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1309 getRootDynamic(), getGridAxes(),
1310 grid.value().getShape()))) {
1317void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1318 MLIRContext *context) {
1319 patterns.
add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
1322void BroadcastOp::getAsmResultNames(
1324 setNameFn(getResult(),
"broadcast");
1331LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1337 getRootDynamic(), getGridAxes(),
1338 grid.value().getShape()))) {
1342 auto gatherAxis = getGatherAxis().getSExtValue();
1345 grid.value().getShape());
1348void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1349 MLIRContext *context) {
1350 patterns.
add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
1353void GatherOp::getAsmResultNames(
1355 setNameFn(getResult(),
"gather");
1362LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1369 getSource().value(), getSourceDynamic(),
1370 getGridAxes(), grid.value().getShape()))) {
1376void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1377 MLIRContext *context) {
1378 patterns.
add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
1381void RecvOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1382 setNameFn(getResult(),
"recv");
1389LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1395 getRootDynamic(), getGridAxes(),
1396 grid.value().getShape()))) {
1403void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1404 MLIRContext *context) {
1405 patterns.
add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
1408void ReduceOp::getAsmResultNames(
1410 setNameFn(getResult(),
"reduce");
1418ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1425 getOperand(), getResult(), getScatterDim().getSExtValue(), getGridAxes(),
1426 grid.value().getShape());
1429void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1430 MLIRContext *context) {
1431 patterns.
add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1434void ReduceScatterOp::getAsmResultNames(
1436 setNameFn(getResult(),
"reduce_scatter");
1443LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1449 getRootDynamic(), getGridAxes(),
1450 grid.value().getShape()))) {
1454 auto scatterDim = getScatterDim().getSExtValue();
1456 scatterDim, getGridAxes(),
1457 grid.value().getShape());
1460void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1461 MLIRContext *context) {
1462 patterns.
add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
1465void ScatterOp::getAsmResultNames(
1467 setNameFn(getResult(),
"scatter");
1474LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1480 getDestination(), getDestinationDynamic(),
1481 getGridAxes(), grid.value().getShape()))) {
1487void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1488 MLIRContext *context) {
1489 patterns.
add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);
1492void SendOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1493 setNameFn(getResult(),
"send");
1500LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1506 auto gridAxes = getGridAxes();
1507 auto shiftAxis = getShiftAxis().getZExtValue();
1508 if (!llvm::is_contained(gridAxes, shiftAxis)) {
1509 return emitError() <<
"Invalid shift axis " << shiftAxis
1510 <<
". It must be one of the grouping grid axes.";
1516void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1517 MLIRContext *context) {
1522void ShiftOp::getAsmResultNames(
1524 setNameFn(getResult(),
"shift");
1532UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1545#define GET_OP_CLASSES
1546#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
1548#define GET_ATTRDEF_CLASSES
1549#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
1551#define GET_TYPEDEF_CLASSES
1552#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
1554#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static FailureOr< GridOp > getGridAndVerify(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable)
static LogicalResult verifyGridAxes(Location loc, ArrayRef< GridAxis > axes, GridOp grid)
static FailureOr< GridOp > getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
static bool isUnique(It begin, It end)
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Location getLoc()
The source location the operation was defined or derived from.
This class provides the API for a sub-set of ops that are known to be constant-like.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
Operation * getOwner() const
Return the owner of this operand.
ArrayRef< Value > getDynamicShardedDimsOffsets() const
bool operator!=(Value rhs) const
bool equalShardSizes(const Sharding &rhs) const
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalSplitAxes(const Sharding &rhs) const
ArrayRef< int64_t > getStaticHaloSizes() const
::mlir::FlatSymbolRefAttr getGridAttr() const
::llvm::StringRef getGrid() const
bool equalHaloAndShardSizes(const Sharding &rhs) const
bool operator==(Value rhs) const
ArrayRef< Value > getDynamicHaloSizes() const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
ArrayRef< GridAxesAttr > getSplitAxes() const
bool equalHaloSizes(const Sharding &rhs) const
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const Sharding &sharding)
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
DenseI16ArrayAttr GridAxesAttr
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
bool isFullReplication(Sharding sharding)
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Type shardType(Type type, GridOp grid, Sharding sharding)
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type.
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineExpr operator*(int64_t val, AffineExpr expr)
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
llvm::function_ref< Fn > function_ref
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.