#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "shard-ops"

#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
  DimensionSize(int64_t val) : val(val) {}
  int64_t value() const { return val; }
  operator int64_t() const { return val; }
  bool isDynamic() const { return ShapedType::isDynamic(val); }
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() / rhs.value();
}
static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() * rhs.value();
}
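// Worked examples (illustrative, hypothetical values): these helpers make
// ShapedType::kDynamic absorbing under shard-size arithmetic:
//   DimensionSize(8) / DimensionSize(2)           -> DimensionSize(4)
//   DimensionSize(8) * DimensionSize::dynamic()   -> DimensionSize::dynamic()
// A dynamic tensor dimension or device-group size thus propagates to a
// dynamic shard size.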
SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                    llvm::ArrayRef<int64_t> statics,
                                    ValueRange dynamics, Type type) {
  SmallVector<Value> values;
  auto dyn = dynamics.begin();
  Type i64 = b.getI64Type();
  if (!type)
    type = i64;
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an index type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      values.emplace_back(*(dyn++));
    } else {
      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
      values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
    }
  }
  return values;
}
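// Usage sketch (hypothetical values): with statics = {2, ShapedType::kDynamic,
// 4}, dynamics = {%d0} and type = i64, the result is {c2, %d0, c4}: dynamic
// entries are drawn from `dynamics` in order, while static entries become
// arith.constant ops of the requested i64 or index type.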
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
void ShardDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
      >();
  addInterfaces<ShardInlinerinterface>();
}
  return arith::ConstantOp::materialize(builder, value, type, loc);
  shard::GridOp grid = getGridOrNull(op, gridSymbol, symbolTable);
  if (!grid) {
    return op->emitError() << "Undefined required grid symbol \""
                           << gridSymbol.getValue() << "\".";
  }
template <typename It>
static bool isUnique(It begin, It end) {
  if (begin == end)
    return true;
  It next = std::next(begin);
  for (; next != end; ++next, ++begin) {
    if (*begin == *next)
      return false;
  }
  return true;
}
static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes,
                                    GridOp grid) {
  SmallVector<GridAxis> sorted = llvm::to_vector(axes);
  llvm::sort(sorted);
  if (!isUnique(sorted.begin(), sorted.end())) {
    return emitError(loc) << "Grid axes contains duplicate elements.";
  }

  GridAxis rank = grid.getRank();
  for (auto axis : axes) {
    if (axis >= rank || axis < 0) {
      return emitError(loc)
             << "0-based grid axis index " << axis
             << " is out of bounds. The referenced grid \"" << grid.getSymName()
             << "\" is of rank " << rank << ".";
    }
  }
  return success();
}
template <typename Op>
static FailureOr<GridOp>
getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
template <typename InShape, typename GridShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const GridShape &gridShape,
                       const SplitAxes &splitAxes, OutShape &outShape,
                       ArrayRef<int64_t> shardedDimsOffsets = {},
                       ArrayRef<int64_t> haloSizes = {}) {
  // 0d tensors cannot be sharded and must get replicated.
  if (inShape.empty()) {
    assert(outShape.empty());
    return;
  }

  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));
  if (!shardedDimsOffsets.empty()) {
    auto isDynShape = ShapedType::isDynamicShape(gridShape);
    uint64_t pos = 1;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (!innerSplitAxes.empty()) {
        auto sz = shardedDimsOffsets[pos];
        bool same = !isDynShape;
        if (same) {
          // The offsets define a uniform static shard size only if they are
          // equidistant; otherwise the dimension becomes dynamic.
          uint64_t numShards = 0;
          for (auto i : innerSplitAxes.asArrayRef()) {
            numShards += gridShape[i];
          }
          for (size_t i = 1; i < numShards; ++i) {
            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
                sz) {
              same = false;
              break;
            }
          }
          pos += numShards + 1;
        }
        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
      }
    }
  } else {
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      outShape[tensorAxis] = shardDimension(
          inShape[tensorAxis],
          collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), gridShape));
    }

    if (!haloSizes.empty()) {
      // Halos extend a statically sized shard on both sides; a halo pair that
      // is not non-negative makes the dimension dynamic.
      int haloAxis = 0;
      for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
        if (ShapedType::isStatic(outShape[tensorAxis]) &&
            !innerSplitAxes.empty()) {
          if (haloSizes[haloAxis * 2] >= 0 &&
              haloSizes[haloAxis * 2 + 1] >= 0) {
            outShape[tensorAxis] +=
                haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
          } else {
            outShape[tensorAxis] = ShapedType::kDynamic;
          }
          ++haloAxis;
        }
      }
    }
  }
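// Worked example (illustrative): a dimension of size 12 split over grid axes
// of total size 4 yields a shard size of 3; halo sizes {1, 2} on that axis
// then give a local size of 3 + 1 + 2 = 6. Any halo entry that is not a
// non-negative constant makes the dimension ShapedType::kDynamic instead.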
  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
  SmallVector<Dim> resShapeArr(shape.getShape().size());
  shardShape(shape.getShape(), grid.getShape(), sharding.getSplitAxes(),
             resShapeArr, sharding.getStaticShardedDimsOffsets(),
             sharding.getStaticHaloSizes());
  return shape.clone(resShapeArr);
  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding,
                                                    Value &operandValue,
                                                    Operation *operandOp,
                                                    OpBuilder &builder,
                                                    ShardOp &newShardOp) {
  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      !shardOp.getAnnotateForUsers()) {
    // The correct sharding is already annotated; remember the existing op.
    newShardOp = shardOp;
    return;
  }

  if (!newShardOp) {
    auto shardingOp =
        ShardingOp::create(builder, operandValue.getLoc(), sharding);
    newShardOp = ShardOp::create(builder, operandValue.getLoc(), operandValue,
                                 shardingOp, /*annotate_for_users=*/false);
  }
  operandValue.replaceUsesWithIf(
      newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (!shardOp || shardOp.getAnnotateForUsers()) {
    return;
  }

  auto newShardOp2 = ShardOp::create(builder, operandValue.getLoc(), newShardOp,
                                     newShardOp.getSharding(),
                                     /*annotate_for_users=*/true);
  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
}
  for (auto &use : result.getUses()) {
    uses.emplace_back(use.get(), use.getOwner());
  }
  for (auto &[operandValue, operandOp] : uses) {
    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
                                            builder, newShardOp);
  }
  bool isBlockArg = !operandSrcOp;
  [[maybe_unused]] auto opType =
      dyn_cast<mlir::RankedTensorType>(operandValue.getType());
  assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));

  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
      operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
    return;
  }

  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      shardOp.getAnnotateForUsers()) {
    // The correct sharding is already annotated for users; nothing to do.
    return;
  }

  auto shardingOp =
      ShardingOp::create(builder, operand.get().getLoc(), sharding);
  auto newShardOp =
      ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
                      /*annotate_for_users=*/true);
  IRRewriter rewriter(builder);
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
    // No resharding is needed.
    return;
  }

  auto newPreceedingShardOp =
      ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
                      /*annotate_for_users=*/false);
  rewriter.replaceUsesWithIf(
      newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
        return use.getOwner() == newShardOp.getOperation();
      });
LogicalResult GridOp::verify() {
  if (getRank() <= 0)
    return emitOpError("rank of grid is expected to be a positive integer");

  for (int64_t dimSize : getShape()) {
    if (dimSize < 0 && ShapedType::isStatic(dimSize))
      return emitOpError("dimension size of a grid is expected to be "
                         "non-negative or dynamic");
  }
  size_t expectedResultsCount =
      getAxes().empty() ? grid->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }
  build(odsBuilder, odsState,

  assert(!axes.empty());
  build(odsBuilder, odsState,

void GridShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "grid_shape");
}
        b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),

        GridAxesArrayAttr::get(b.getContext(), split_axes),

void ShardingOp::build(

        b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),

  build(b, odsState, ShardingType::get(b.getContext()), from.getGridAttr(),
LogicalResult ShardingOp::verify() {
  llvm::SmallSet<GridAxis, 4> visitedAxes;

  auto checkGridAxis = [&](ArrayRef<GridAxis> axesArray) -> LogicalResult {
    for (GridAxis axis : axesArray) {
      if (axis < 0)
        return emitError() << "grid axis is expected to be non-negative";
      if (!visitedAxes.insert(axis).second)
        return emitError() << "grid axis duplicated";
    }
    return success();
  };

  for (auto subAxes : getSplitAxes().getAxes()) {
    ArrayRef<GridAxis> subAxesArray = subAxes.asArrayRef();
    if (failed(checkGridAxis(subAxesArray)))
      return failure();
  }

  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
    return emitOpError("halo sizes and shard offsets are mutually exclusive");
  }

  if (!getStaticHaloSizes().empty()) {
    auto numSplitAxes = getSplitAxes().getAxes().size();
    for (auto splitAxis : getSplitAxes().getAxes()) {
      if (splitAxis.empty()) {
        --numSplitAxes;
      }
    }
    if (getStaticHaloSizes().size() != numSplitAxes * 2) {
      return emitError() << "halo sizes must be specified for all split axes.";
    }
  }
void ShardingOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding");
}
  if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
      !getStaticShardedDimsOffsets().empty()) {
    return emitError() << "sharded dims offsets are not allowed for "
                          "device grids with dynamic shape.";
  }

  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
  if (!shardedDimsOffsets.empty()) {
    auto gridShape = grid.value().getShape();
    assert(ShapedType::isStaticShape(gridShape));
    uint64_t pos = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
      if (!innerSplitAxes.empty()) {
        int64_t numShards = 0, off = 0;
        for (auto i : innerSplitAxes.asArrayRef()) {
          numShards += gridShape[i];
        }
        for (int64_t i = 0; i <= numShards; ++i) {
          if (shardedDimsOffsets.size() <= pos + i) {
            return emitError() << "sharded dims offsets has wrong size.";
          }
          if (ShapedType::isStatic(shardedDimsOffsets[pos + i])) {
            if (shardedDimsOffsets[pos + i] < off) {
              return emitError()
                     << "sharded dims offsets must be non-decreasing.";
            }
            off = shardedDimsOffsets[pos + i];
          }
        }
        pos += numShards + 1;
      }
    }
  }
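// Layout of sharded dims offsets, with a worked example (illustrative): each
// tensor axis with a non-empty split contributes numShards + 1 monotone
// entries, e.g. [0, 3, 6, 9, 12] for a size-12 dimension split into 4 shards.
// The checks above enforce presence and a non-decreasing order.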
  using OpRewritePattern<ShardingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardingOp op,
                                PatternRewriter &b) const override {
    auto [staticHalos, dynamicHalos] = decomposeMixedValues(getMixedValues(
        op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b));
    auto [staticOffs, dynamicOffs] = decomposeMixedValues(
        getMixedValues(op.getStaticShardedDimsOffsets(),
                       op.getDynamicShardedDimsOffsets(), b));

    // All-zero static halos carry no information; drop them.
    if (dynamicHalos.empty() && !staticHalos.empty()) {
      if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
        staticHalos.clear();
      }
    }

    // Equidistant static offsets describe the default sharding; drop them too.
    if (dynamicOffs.empty() && !staticOffs.empty()) {
      assert(staticOffs.size() >= 2);
      auto diff = staticOffs[1] - staticOffs[0];
      bool all_same = staticOffs.size() > 2;
      for (auto i = 2u; i < staticOffs.size(); ++i) {
        if (staticOffs[i] - staticOffs[i - 1] != diff) {
          all_same = false;
          break;
        }
      }
      if (all_same) {
        staticOffs.clear();
      }
    }

    op.setStaticHaloSizes(staticHalos);
    op.getDynamicHaloSizesMutable().assign(dynamicHalos);
    op.setStaticShardedDimsOffsets(staticOffs);
    op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
  results.add<NormalizeSharding>(context);
  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
                                    getSplitAxes().begin() + minSize),
                   llvm::make_range(rhs.getSplitAxes().begin(),
                                    rhs.getSplitAxes().begin() + minSize))) {
    return false;
  }

  return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
                      std::mem_fn(&GridAxesAttr::empty)) &&
         llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
                      std::mem_fn(&GridAxesAttr::empty));
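// Equality example (illustrative): split axes [[0], []] and [[0]] compare
// equal here, because trailing entries that split along no grid axis do not
// change the sharding; [[0]] vs. [[1]] compares unequal.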
  if (rhs.getStaticShardedDimsOffsets().size() !=
          getStaticShardedDimsOffsets().size() ||
      !llvm::equal(getStaticShardedDimsOffsets(),
                   rhs.getStaticShardedDimsOffsets())) {
    return false;
  }
  if (rhs.getDynamicShardedDimsOffsets().size() !=
          getDynamicShardedDimsOffsets().size() ||
      !llvm::equal(getDynamicShardedDimsOffsets(),
                   rhs.getDynamicShardedDimsOffsets())) {
    return false;
  }
  auto shardingOp = rhs.getDefiningOp<ShardingOp>();
  assert(shardingOp && "expected sharding op");
  auto splitAxes = shardingOp.getSplitAxes().getAxes();
  // If the split axes are empty, fall back to the replicated sharding.
  if (splitAxes.empty()) {
    *this = Sharding(shardingOp.getGridAttr());
    return;
  }
  *this =
      get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
          shardingOp.getStaticShardedDimsOffsets(),
          shardingOp.getDynamicHaloSizes(),
          shardingOp.getDynamicShardedDimsOffsets());
  if (split_axes_.empty()) {
    return Sharding(grid_);
  }

  Sharding res;
  res.split_axes.resize(split_axes_.size());
  for (auto [i, axis] : llvm::enumerate(split_axes_)) {

  auto clone = [](const auto src, auto &dst) {
    dst.resize(src.size());
    llvm::copy(src, dst.begin());
  };

  clone(static_halo_sizes_, res.static_halo_sizes);
  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
  clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
void ShardShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult()[0], "shard_shape");
}

  build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
void ShardOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding_annotated");
}
class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
public:
  using OpRewritePattern<ShardOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
    // Get the value being sharded.
    Value value = op.getSrc();

    // Check whether another ShardOp earlier in the block annotates the same
    // value with the same sharding.
    for (auto &use : value.getUses()) {
      if (use.getOwner() != op.getOperation()) {
        auto otherOp = dyn_cast<ShardOp>(use.getOwner());
        if (!otherOp || !otherOp->isBeforeInBlock(op)) {
          return failure();
        }
        Sharding currentSharding(op.getSharding());
        Sharding otherSharding(otherOp.getSharding());
        if (currentSharding == otherSharding) {
          b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
          b.eraseOp(op.getOperation());
        } else {
          // Chain the current op behind the other op's result.
          op.getSrcMutable().assign(otherOp.getResult());
        }
        return success();
      }
    }
    return failure();
  }
};
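// Effect of this fold (illustrative): two shard.shard ops annotating the same
// value with equal shardings collapse into one; with different shardings the
// later op is rewired to consume the earlier op's result, serializing the
// annotations instead of branching them off the same value.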
void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                          mlir::MLIRContext *context) {
  results.add<FoldDuplicateShardOp>(context);
}
LogicalResult
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

  size_t expectedResultsCount =
      getAxes().empty() ? grid->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                GridOp grid) {
  build(odsBuilder, odsState,
        SmallVector<Type>(grid.getRank(), odsBuilder.getIndexType()),
        grid.getSymName(), ArrayRef<GridAxis>());
}

void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                StringRef grid, ArrayRef<GridAxis> axes) {
  build(odsBuilder, odsState,
        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
        GridAxesAttr::get(odsBuilder.getContext(), axes));
}

void ProcessMultiIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "proc_linear_idx");
}
LogicalResult
ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, GridOp grid) {
  build(odsBuilder, odsState, grid.getSymName());
}

void ProcessLinearIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "proc_linear_idx");
}
LogicalResult
NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

void NeighborsLinearIndicesOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getNeighborDown(), "down_linear_idx");
  setNameFn(getNeighborUp(), "up_linear_idx");
}
template <typename Op>
struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    auto gridAxes = op.getGridAxes();
    if (!gridAxes.empty()) {
      return failure();
    }
    if (op.getInput().getType() != op.getResult().getType()) {
      return failure();
    }
    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
    rewriter.eraseOp(op.getOperation());
    return success();
  }
};
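// Why this pattern is sound: an empty grid-axes list addresses a process
// group of size 1, so the collective is an identity and can be replaced by
// its input whenever operand and result types match exactly.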
  if (device.size() != gridAxes.size()) {
    return emitError(loc) << "In-group device \"" << deviceName
                          << "\" has unexpected multi-index size "
                          << device.size() << ". Expected " << gridAxes.size()
                          << ".";
  }

  for (size_t i = 0; i < device.size(); ++i) {
    if (ShapedType::isStatic(device[i]) &&
        ShapedType::isStatic(gridShape[gridAxes[i]]) &&
        gridShape[gridAxes[i]] <= device[i]) {
      return emitError(loc)
             << "Out of bounds coordinate " << i << " for in-group device \""
             << deviceName << "\"."
             << " Got " << device[i] << ", but expected value in the range [0, "
             << (gridShape[gridAxes[i]] - 1) << "].";
    }
  }
  return success();
}
  if (ShapedType::isStatic(resultDimSize) && expectedDimSize != resultDimSize) {
    return emitError(loc) << "Dimension size mismatch for result axis "
                          << resultAxis << ". Expected "
                          << (ShapedType::isDynamic(expectedDimSize)
                                  ? Twine("dynamic")
                                  : Twine(expectedDimSize))
                          << ", but got " << resultDimSize << ".";
  }

  return success();
}
  auto resultRank = cast<ShapedType>(result.getType()).getRank();
  if (gatherAxis < 0 || gatherAxis >= resultRank) {
    return emitError(result.getLoc())
           << "Gather axis " << gatherAxis << " is out of bounds [0, "
           << resultRank << ").";
  }

  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
    auto expectedResultDimSize =
        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
    if (failed(verifyDimensionCompatibility(
            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
      return failure();
    }
  }
  return success();
}
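// Worked example (illustrative): gathering tensor<3x4xf32> along
// gatherAxis = 1 over a device group of size 2 must produce tensor<3x8xf32>;
// only the gather axis is scaled by the group size, all other axes must
// match exactly.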
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  if (splitAxis == concatAxis) {
    return success();
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
  DimensionSize expectedResultConcatDimSize =
      operandConcatDimSize * deviceGroupSize;
  DimensionSize expectedResultSplitDimSize =
      operandSplitDimSize / deviceGroupSize;
  if (!expectedResultSplitDimSize.isDynamic() &&
      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
    expectedResultSplitDimSize = DimensionSize::dynamic();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultConcatDimSize.value(),
          resultType.getDimSize(concatAxis), concatAxis))) {
    return failure();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultSplitDimSize.value(),
          resultType.getDimSize(splitAxis), splitAxis))) {
    return failure();
  }
  return success();
}
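// Worked example (illustrative): all_to_all with splitAxis = 0 and
// concatAxis = 1 over a device group of size 2 maps tensor<4x6xf32> to
// tensor<2x12xf32>. A split-axis size that is not divisible by the group
// size is only accepted if the corresponding result dimension is dynamic.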
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if (axis != tensorAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
  auto operandScatterDimSize =
      DimensionSize(operandType.getDimSize(tensorAxis));
  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
      int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
    return emitError(result.getLoc())
           << "Operand dimension size " << int64_t(operandScatterDimSize)
           << " is not divisible by collective device group size "
           << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
           << ".";
  }
  DimensionSize expectedResultTensorDimSize =
      operandScatterDimSize / deviceGroupSize;
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultTensorDimSize.value(),
          resultType.getDimSize(tensorAxis), tensorAxis))) {
    return failure();
  }
  return success();
}
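// Worked example (illustrative): slicing or scattering tensor<8x16xf32>
// along tensorAxis = 0 over a device group of size 4 must produce
// tensor<2x16xf32>; a static size of 10 on that axis would be rejected as
// not divisible by the group size.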
  RankedTensorType operandRankedTensorType =
      cast<RankedTensorType>(operandType);
  DimensionSize operandSliceAxisSize =
      operandRankedTensorType.getShape()[sliceAxis];
  SmallVector<int64_t> resultShape =
      llvm::to_vector(operandRankedTensorType.getShape());

  resultShape[sliceAxis] =
      operandSliceAxisSize /
      DimensionSize(collectiveProcessGroupSize(gridAxes, grid.getShape()));
  return operandRankedTensorType.clone(resultShape);
}
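// Example (illustrative): for a static device group of size 4,
// sliceResultType(tensor<8x16xf32>, grid, gridAxes, /*sliceAxis=*/0) returns
// tensor<2x16xf32>; the DimensionSize arithmetic above keeps the axis
// dynamic when either the dimension or the group size is dynamic.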
LogicalResult
AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
                                           gatherAxis, getGridAxes(),
                                           grid.value().getShape());
}

void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
}

void AllGatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_gather");
}
LogicalResult
AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
}

void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        Value input, StringRef grid,
                        ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
  build(odsBuilder, odsState, input.getType(), grid, gridAxes, input,
        reduction);
}

void AllReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_reduce");
}
LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
      grid.value().getShape());
}

void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
                       int64_t sliceAxis) {
  Type resultType = sliceResultType(input.getType(), grid, gridAxes, sliceAxis);
  build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
        sliceAxis);
}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Type resultType, Value input, StringRef grid,
                       ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
  build(odsBuilder, odsState, resultType, grid, gridAxes, input,
        APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}

void AllSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_slice");
}
LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

  return verifyAllToAllOperandAndResultShape(
      getOperand(), getResult(), getSplitAxis().getSExtValue(),
      getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
}

void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
}

void AllToAllOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_to_all");
}
LogicalResult
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getGridAxes(),
                                 grid.value().getShape()))) {
    return failure();
  }

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
}

void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "broadcast");
}
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getGridAxes(),
                                 grid.value().getShape()))) {
    return failure();
  }

  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
                                           getGridAxes(),
                                           grid.value().getShape());
}

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
}

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

  if (getSource() &&
      failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
                                 getSource().value(), getSourceDynamic(),
                                 getGridAxes(), grid.value().getShape()))) {
    return failure();
  }

void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
}

void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "recv");
}
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getGridAxes(),
                                 grid.value().getShape()))) {
    return failure();
  }

void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce");
}
LogicalResult
ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
      grid.value().getShape());
}

void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}

void ReduceScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce_scatter");
}
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getGridAxes(),
                                 grid.value().getShape()))) {
    return failure();
  }

  auto scatterAxis = getScatterAxis().getSExtValue();
  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
                                                   scatterAxis, getGridAxes(),
                                                   grid.value().getShape());
}

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
}

void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
                                 getDestination(), getDestinationDynamic(),
                                 getGridAxes(), grid.value().getShape()))) {
    return failure();
  }

void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);
}

void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "send");
}
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

  auto gridAxes = getGridAxes();
  auto shiftAxis = getShiftAxis().getZExtValue();
  if (!llvm::is_contained(gridAxes, shiftAxis)) {
    return emitError() << "Invalid shift axis " << shiftAxis
                       << ". It must be one of the grouping grid axes.";
  }

void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                          MLIRContext *context) {
}

void ShiftOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "shift");
}
LogicalResult
UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

#define GET_OP_CLASSES
#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"

#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static FailureOr< GridOp > getGridAndVerify(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable)
static LogicalResult verifyGridAxes(Location loc, ArrayRef< GridAxis > axes, GridOp grid)
static FailureOr< GridOp > getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
static bool isUnique(It begin, It end)
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Location getLoc()
The source location the operation was defined or derived from.
This class provides the API for a sub-set of ops that are known to be constant-like.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
Operation * getOwner() const
Return the owner of this operand.
ArrayRef< Value > getDynamicShardedDimsOffsets() const
bool operator!=(Value rhs) const
bool equalShardSizes(const Sharding &rhs) const
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalSplitAxes(const Sharding &rhs) const
ArrayRef< int64_t > getStaticHaloSizes() const
::mlir::FlatSymbolRefAttr getGridAttr() const
::llvm::StringRef getGrid() const
bool equalHaloAndShardSizes(const Sharding &rhs) const
bool operator==(Value rhs) const
ArrayRef< Value > getDynamicHaloSizes() const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
ArrayRef< GridAxesAttr > getSplitAxes() const
bool equalHaloSizes(const Sharding &rhs) const
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
bool isFullReplication(Sharding sharding)
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Type shardType(Type type, GridOp grid, Sharding sharding)
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
AffineExpr operator*(int64_t val, AffineExpr expr)
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
llvm::function_ref< Fn > function_ref
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.