#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "shard-ops"

#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
class DimensionSize {
public:
  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
  DimensionSize(int64_t val) : val(val) {}
  int64_t value() const { return val; }
  operator int64_t() const { return val; }
  bool isDynamic() const { return ShapedType::isDynamic(val); }

private:
  int64_t val;
};
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() / rhs.value();
}

static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() * rhs.value();
}
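// Illustrative behavior of the helpers above: arithmetic on DimensionSize
// treats ShapedType::kDynamic as an absorbing element, so any computation
// touching a dynamic extent stays dynamic. For example (values assumed):
//   DimensionSize(8) / DimensionSize(2)          -> 4
//   DimensionSize::dynamic() * DimensionSize(2)  -> dynamic
// This lets the shape verifiers below do plain arithmetic without
// special-casing dynamic dimensions at every step.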
SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                    llvm::ArrayRef<int64_t> statics,
                                    ValueRange dynamics, Type type) {
  SmallVector<Value> values;
  auto dyn = dynamics.begin();
  Type i64 = b.getI64Type();
  if (!type)
    type = i64; // Default to i64 when no result type was requested.
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an index type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      values.emplace_back(*(dyn++));
    } else {
      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
      values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
    }
  }
  return values;
}
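// Sketch of a typical call site (names assumed, not from this file): turning a
// mixed static/dynamic shape into SSA values, where kDynamic entries are drawn
// from `dynamicSizes` in order.
//   SmallVector<Value> sizes =
//       getMixedAsValues(b, loc, /*statics=*/{4, ShapedType::kDynamic, 8},
//                        /*dynamics=*/dynamicSizes, b.getIndexType());
// yields index constants 4 and 8 plus the single dynamic size in between.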
struct ShardInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;
  // All operations within shard regions are legal to inline.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
void ShardDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
      >();
  addInterface<ShardInlinerInterface>();
}
Operation *ShardDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                             Type type, Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
static FailureOr<GridOp> getGridAndVerify(Operation *op,
                                          FlatSymbolRefAttr gridSymbol,
                                          SymbolTableCollection &symbolTable) {
  shard::GridOp grid = getGridOrNull(op, gridSymbol, symbolTable);
  if (!grid) {
    return op->emitError() << "Undefined required grid symbol \""
                           << gridSymbol.getValue() << "\".";
  }
  return grid;
}
template <typename It>
static bool isUnique(It begin, It end) {
  if (begin == end)
    return true;
  It next = std::next(begin);
  for (; next != end; ++next, ++begin) {
    if (*begin == *next)
      return false;
  }
  return true;
}
static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes,
                                    GridOp grid) {
  SmallVector<GridAxis> sorted = llvm::to_vector(axes);
  llvm::sort(sorted);
  if (!isUnique(sorted.begin(), sorted.end())) {
    return emitError(loc) << "Grid axes contains duplicate elements.";
  }

  int64_t rank = grid.getRank();
  for (auto axis : axes) {
    if (axis >= rank || axis < 0) {
      return emitError(loc)
             << "0-based grid axis index " << axis
             << " is out of bounds. The referenced grid \"" << grid.getSymName()
             << "\" is of rank " << rank << ".";
    }
  }

  return success();
}
template <typename Op>
static FailureOr<GridOp>
getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerify(op.getOperation(), op.getGridAttr(), symbolTable);
  if (failed(grid)) {
    return failure();
  }
  if (failed(verifyGridAxes(op.getLoc(), op.getGridAxes(), grid.value()))) {
    return failure();
  }
  return grid;
}
template <typename InShape, typename GridShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const GridShape &gridShape,
                       const SplitAxes &splitAxes, OutShape &outShape,
                       ArrayRef<int64_t> shardedDimsOffsets = {},
                       ArrayRef<int64_t> haloSizes = {}) {
  // 0d tensors cannot be sharded and must get replicated.
  if (inShape.empty()) {
    assert(outShape.empty());
    return;
  }

  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));
  if (!shardedDimsOffsets.empty()) {
    auto isDynShape = ShapedType::isDynamicShape(gridShape);
    uint64_t pos = 1;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (!innerSplitAxes.empty()) {
        auto sz = shardedDimsOffsets[pos];
        bool same = !isDynShape;
        if (same) {
          // Find sharded dims in the offsets with the same static size on all
          // devices. Use kDynamic for dimensions with dynamic or non-uniform
          // offsets.
          uint64_t numShards = 0;
          for (auto i : innerSplitAxes.asArrayRef()) {
            numShards += gridShape[i];
          }
          for (size_t i = 1; i < numShards; ++i) {
            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
                sz) {
              same = false;
              break;
            }
          }
          pos += numShards + 1;
        }
        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
      }
    }
  } else {
    // Derive sharded dimension sizes from the grid axes each dimension is
    // split over.
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      outShape[tensorAxis] = shardDimension(
          outShape[tensorAxis],
          collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), gridShape));
    }
  }
  if (!haloSizes.empty()) {
    // Add halo sizes if requested.
    uint64_t haloAxis = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (ShapedType::isStatic(outShape[tensorAxis]) &&
          !innerSplitAxes.empty()) {
        if (haloSizes[haloAxis * 2] >= 0 &&
            haloSizes[haloAxis * 2 + 1] >= 0) {
          outShape[tensorAxis] +=
              haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
        } else {
          outShape[tensorAxis] = ShapedType::kDynamic;
        }
        ++haloAxis;
      }
    }
  }
}
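// Worked example for shardShape (values assumed): a tensor of shape 8x16 on a
// grid of shape 2x4 with split axes [[0], [1]] (dim 0 over grid axis 0, dim 1
// over grid axis 1) yields a local shard shape of 4x4. With halo sizes
// [1, 1, 0, 0] the first dimension grows by both halos, giving 6x4. With
// sharded-dims-offsets such as [0, 3, 8, 0, 4, 8, 12, 16] the per-shard sizes
// of dim 0 are non-uniform (3 and 5), so that dimension becomes kDynamic,
// while dim 1 stays a uniform 4.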
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding) {
  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
  SmallVector<Dim> resShapeArr(shape.getShape().size());
  shardShape(shape.getShape(), grid.getShape(), sharding.getSplitAxes(),
             resShapeArr, sharding.getStaticShardedDimsOffsets(),
             sharding.getStaticHaloSizes());
  return shape.clone(resShapeArr);
}
Type shardType(Type type, GridOp grid, Sharding sharding) {
  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
    return shardShapedType(rankedTensorType, grid, sharding);
  }
  return type;
}
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding,
                                                    Value &operandValue,
                                                    Operation *operandOp,
                                                    OpBuilder &builder,
                                                    ShardOp &newShardOp) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointAfterValue(operandValue);
  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      !shardOp.getAnnotateForUsers()) {
    // No need for anything if the target sharding is already set.
    if (!newShardOp) {
      newShardOp = shardOp;
    }
    return;
  }

  if (!newShardOp) {
    auto shardingOp =
        ShardingOp::create(builder, operandValue.getLoc(), sharding);
    newShardOp = ShardOp::create(builder, operandValue.getLoc(), operandValue,
                                 shardingOp,
                                 /*annotate_for_users*/ false);
  }
  operandValue.replaceUsesWithIf(
      newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (!shardOp || shardOp.getAnnotateForUsers()) {
    return;
  }

  auto newShardOp2 = ShardOp::create(builder, operandValue.getLoc(), newShardOp,
                                     newShardOp.getSharding(),
                                     /*annotate_for_users*/ true);
  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
}
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result,
                                         OpBuilder &builder) {
  ShardOp newShardOp;
  SmallVector<std::pair<Value, Operation *>> uses;
  for (auto &use : result.getUses()) {
    uses.emplace_back(use.get(), use.getOwner());
  }
  for (auto &[operandValue, operandOp] : uses) {
    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
                                            builder, newShardOp);
  }
}
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand,
                                         OpBuilder &builder) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Value operandValue = operand.get();
  Operation *operandOp = operand.getOwner();
  Operation *operandSrcOp = operandValue.getDefiningOp();
  bool isBlockArg = !operandSrcOp;
  [[maybe_unused]] auto opType =
      dyn_cast<mlir::RankedTensorType>(operandValue.getType());
  assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));

  // Do not annotate non-tensor constants; they can be materialized wherever
  // they are needed.
  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
      operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
    return;
  }

  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      shardOp.getAnnotateForUsers()) {
    // No need for anything if the source sharding is already set.
    return;
  }

  builder.setInsertionPoint(operandOp);
  auto shardingOp =
      ShardingOp::create(builder, operand.get().getLoc(), sharding);
  ShardOp newShardOp =
      ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
                      /*annotate_for_users*/ true);
  IRRewriter rewriter(builder);
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
    return;
  }

  builder.setInsertionPoint(newShardOp);
  auto newPreceedingShardOp =
      ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
                      /*annotate_for_users*/ false);
  rewriter.replaceUsesWithIf(
      newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
        return use.getOwner() == newShardOp.getOperation();
      });
}
LogicalResult GridOp::verify() {
  int64_t rank = getRank();

  if (rank <= 0)
    return emitOpError("rank of grid is expected to be a positive integer");

  for (int64_t dimSize : getShape()) {
    if (dimSize < 0 && ShapedType::isStatic(dimSize))
      return emitOpError("dimension size of a grid is expected to be "
                         "non-negative or dynamic");
  }

  return success();
}
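// Illustrative grid declarations accepted by this verifier (syntax sketch):
//   shard.grid @grid0(shape = 2x2x4)   // rank 3, all dimensions static
//   shard.grid @grid1(shape = 4x?)     // rank 2, second dimension dynamic
// A rank of 0 or a negative static dimension size is rejected above.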
LogicalResult
GridShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
  if (failed(grid)) {
    return failure();
  }

  size_t expectedResultsCount =
      getAxes().empty() ? grid->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }

  return success();
}
void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        GridOp grid) {
  build(odsBuilder, odsState,
        SmallVector<Type>(grid.getRank(), odsBuilder.getIndexType()),
        grid.getSymName(), ArrayRef<GridAxis>());
}

void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        GridOp grid, ArrayRef<GridAxis> axes) {
  assert(!axes.empty());
  build(odsBuilder, odsState,
        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()),
        grid.getSymName(), GridAxesAttr::get(odsBuilder.getContext(), axes));
}
void GridShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "grid_shape");
}
void ShardingOp::build(OpBuilder &b, OperationState &odsState,
                       FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes,
                       ArrayRef<int64_t> staticHalos,
                       ArrayRef<int64_t> staticOffsets) {
  return build(
      b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
      DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
      DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {});
}

void ShardingOp::build(OpBuilder &b, OperationState &odsState,
                       const Sharding &from) {
  build(b, odsState, ShardingType::get(b.getContext()), from.getGridAttr(),
        GridAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
        b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
        from.getDynamicShardedDimsOffsets(),
        b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
        from.getDynamicHaloSizes());
}
LogicalResult ShardingOp::verify() {
  llvm::SmallSet<GridAxis, 4> visitedAxes;

  auto checkGridAxis = [&](ArrayRef<GridAxis> axesArray) -> LogicalResult {
    for (GridAxis axis : axesArray) {
      if (axis < 0)
        return emitError() << "grid axis is expected to be non-negative";
      if (!visitedAxes.insert(axis).second)
        return emitError() << "grid axis duplicated";
    }
    return success();
  };

  for (auto subAxes : getSplitAxes().getAxes()) {
    if (failed(checkGridAxis(subAxes.asArrayRef())))
      return failure();
  }

  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
    return emitOpError("halo sizes and shard offsets are mutually exclusive");
  }

  if (!getStaticHaloSizes().empty()) {
    // Halo sizes come in (low, high) pairs, one pair per non-empty split axis.
    auto numSplitAxes = getSplitAxes().getAxes().size();
    for (auto splitAxis : getSplitAxes().getAxes()) {
      if (splitAxis.empty()) {
        --numSplitAxes;
      }
    }
    if (getStaticHaloSizes().size() != numSplitAxes * 2) {
      return emitError() << "halo sizes must be specified for all split axes.";
    }
  }

  return success();
}
void ShardingOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding");
}
  if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
      !getStaticShardedDimsOffsets().empty()) {
    return emitError() << "sharded dims offsets are not allowed for "
                          "device grids with dynamic shape.";
  }

  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
  if (!shardedDimsOffsets.empty()) {
    auto gridShape = grid.value().getShape();
    assert(ShapedType::isStaticShape(gridShape));
    uint64_t pos = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
      if (!innerSplitAxes.empty()) {
        int64_t numShards = 0, off = 0;
        for (auto i : innerSplitAxes.asArrayRef()) {
          numShards += gridShape[i];
        }
        for (int64_t i = 0; i <= numShards; ++i) {
          if (shardedDimsOffsets.size() <= pos + i) {
            return emitError() << "sharded dims offsets has wrong size.";
          }
          if (ShapedType::isStatic(shardedDimsOffsets[pos + i])) {
            if (shardedDimsOffsets[pos + i] < off) {
              return emitError()
                     << "sharded dims offsets must be non-decreasing.";
            }
            off = shardedDimsOffsets[pos + i];
          }
        }
        pos += numShards + 1;
      }
    }
  }

  return success();
}
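// Example of the checks above (values assumed): on a static 2x4 grid with
// split axes [[0], [1]], a valid offsets attribute is
// [0, 3, 8, 0, 4, 8, 12, 16] — one run of numShards + 1 non-decreasing
// offsets per sharded tensor dimension (3 entries for the 2-way split,
// 5 entries for the 4-way split).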
namespace {
struct NormalizeSharding final : public OpRewritePattern<ShardingOp> {
  using OpRewritePattern<ShardingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardingOp op,
                                PatternRewriter &b) const override {
    auto [staticHalos, dynamicHalos] = decomposeMixedValues(getMixedValues(
        op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b));
    auto [staticOffs, dynamicOffs] = decomposeMixedValues(
        getMixedValues(op.getStaticShardedDimsOffsets(),
                       op.getDynamicShardedDimsOffsets(), b));

    // Drop halo sizes when they are all zero.
    if (dynamicHalos.empty() && !staticHalos.empty()) {
      if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
        staticHalos.clear();
      }
    }

    // Drop sharded dims offsets when they are effectively the default, i.e.
    // when all neighboring shards are equi-distant.
    if (dynamicOffs.empty() && !staticOffs.empty()) {
      assert(staticOffs.size() >= 2);
      auto diff = staticOffs[1] - staticOffs[0];
      bool allSame = staticOffs.size() > 2;
      for (auto i = 2u; i < staticOffs.size(); ++i) {
        if (staticOffs[i] - staticOffs[i - 1] != diff) {
          allSame = false;
          break;
        }
      }
      if (allSame) {
        staticOffs.clear();
      }
    }

    // Bail out when nothing was normalized, so the rewrite terminates.
    if (staticHalos.size() == op.getStaticHaloSizes().size() &&
        staticOffs.size() == op.getStaticShardedDimsOffsets().size()) {
      return failure();
    }

    op.setStaticHaloSizes(staticHalos);
    op.getDynamicHaloSizesMutable().assign(dynamicHalos);
    op.setStaticShardedDimsOffsets(staticOffs);
    op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
    return success();
  }
};
} // namespace
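// Illustrative normalization (values assumed): static sharded-dims-offsets
// [0, 4, 8, 12, 16] are equi-distant (stride 4), carry no information beyond
// the default partitioning, and are dropped; halo sizes [0, 0, 0, 0] are
// likewise dropped. Offsets [0, 3, 8] survive because the strides differ.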
void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                             mlir::MLIRContext *context) {
  results.add<NormalizeSharding>(context);
}
bool Sharding::equalSplitAxes(const Sharding &rhs) const {
  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
                                    getSplitAxes().begin() + minSize),
                   llvm::make_range(rhs.getSplitAxes().begin(),
                                    rhs.getSplitAxes().begin() + minSize))) {
    return false;
  }

  return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
                      std::mem_fn(&GridAxesAttr::empty)) &&
         llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
                      std::mem_fn(&GridAxesAttr::empty));
}
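// Example (values assumed): split axes [[0], [1]] and [[0], [1], [], []]
// compare equal here, since the common prefix matches and all trailing
// entries are empty, i.e. the extra tensor dimensions are not sharded.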
bool Sharding::equalShardSizes(const Sharding &rhs) const {
  if (rhs.getStaticShardedDimsOffsets().size() !=
          getStaticShardedDimsOffsets().size() ||
      !llvm::equal(getStaticShardedDimsOffsets(),
                   rhs.getStaticShardedDimsOffsets())) {
    return false;
  }
  if (rhs.getDynamicShardedDimsOffsets().size() !=
          getDynamicShardedDimsOffsets().size() ||
      !llvm::equal(getDynamicShardedDimsOffsets(),
                   rhs.getDynamicShardedDimsOffsets())) {
    return false;
  }
  return true;
}
Sharding::Sharding(Value rhs) {
  auto shardingOp = rhs.getDefiningOp<ShardingOp>();
  assert(shardingOp && "expected sharding op");
  auto splitAxes = shardingOp.getSplitAxes().getAxes();
  // If splitAxes are empty, use the grid-only constructor.
  if (splitAxes.empty()) {
    *this = Sharding(shardingOp.getGridAttr());
    return;
  }
  *this =
      get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
          shardingOp.getStaticShardedDimsOffsets(),
          SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
          SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
}
Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
                       ArrayRef<GridAxesAttr> splitAxes,
                       ArrayRef<int64_t> staticHaloSizes,
                       ArrayRef<int64_t> staticShardedDimsOffsets,
                       ArrayRef<Value> dynamicHaloSizes,
                       ArrayRef<Value> dynamicShardedDimsOffsets) {
  Sharding res(grid_);
  if (splitAxes.empty()) {
    return res;
  }

  res.split_axes.resize(splitAxes.size());
  for (auto [i, axis] : llvm::enumerate(splitAxes)) {
    res.split_axes[i] =
        GridAxesAttr::get(grid_.getContext(), axis.asArrayRef());
  }

  auto clone = [](const auto src, auto &dst) {
    dst.resize(src.size());
    llvm::copy(src, dst.begin());
  };

  clone(staticHaloSizes, res.static_halo_sizes);
  clone(staticShardedDimsOffsets, res.static_sharded_dims_offsets);
  clone(dynamicHaloSizes, res.dynamic_halo_sizes);
  clone(dynamicShardedDimsOffsets, res.dynamic_sharded_dims_offsets);

  return res;
}
void ShardShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult()[0], "shard_shape");
}
void ShardShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                         ArrayRef<int64_t> dims, ArrayRef<Value> dimsDyn,
                         Value sharding, ValueRange device) {
  SmallVector<Type> resType(dims.size(), odsBuilder.getIndexType());
  build(odsBuilder, odsState, resType, dims, dimsDyn, sharding,
        SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
}
void ShardOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding_annotated");
}
class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
public:
  using OpRewritePattern<ShardOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
    // Walk the other users of the shard op's source value, looking for an
    // earlier ShardOp on the same value. If this op is the only user there is
    // nothing to fold.
    Value value = op.getSrc();
    if (value.hasOneUse()) {
      return failure();
    }
    for (auto &use : value.getUses()) {
      if (use.getOwner() != op.getOperation()) {
        auto otherOp = dyn_cast<ShardOp>(use.getOwner());
        if (!otherOp || !otherOp->isBeforeInBlock(op)) {
          return failure();
        }
        // Create a Sharding for the current and the other ShardOp. If the two
        // are equal, replace and erase the current op; otherwise chain the
        // current op behind the other one.
        Sharding currentSharding(op.getSharding());
        Sharding otherSharding(otherOp.getSharding());
        if (currentSharding == otherSharding) {
          b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
          b.eraseOp(op.getOperation());
        } else {
          // Use the other shard op's result as input for this op.
          op.getSrcMutable().assign(otherOp.getResult());
        }
        return success();
      }
    }
    return failure();
  }
};
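// Illustrative effect of the pattern above (IR sketch, names assumed):
//   %s0 = shard.shard %t to %sharding : tensor<4xf32>
//   %s1 = shard.shard %t to %sharding : tensor<4xf32>   // duplicate
// folds the later annotation into the earlier one, redirecting all uses of
// %s1 to %s0; with differing shardings, %s1 is instead rewired to take %s0
// as its source.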
void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                          mlir::MLIRContext *context) {
  results.add<FoldDuplicateShardOp>(context);
}
LogicalResult
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
  if (failed(grid)) {
    return failure();
  }

  size_t expectedResultsCount =
      getAxes().empty() ? grid->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }

  return success();
}
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                GridOp grid) {
  build(odsBuilder, odsState,
        SmallVector<Type>(grid.getRank(), odsBuilder.getIndexType()),
        grid.getSymName(), ArrayRef<GridAxis>());
}

void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                StringRef grid, ArrayRef<GridAxis> axes) {
  build(odsBuilder, odsState,
        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
        GridAxesAttr::get(odsBuilder.getContext(), axes));
}

void ProcessMultiIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "proc_linear_idx");
}
LogicalResult
ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
  if (failed(grid)) {
    return failure();
  }
  return success();
}

void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, GridOp grid) {
  build(odsBuilder, odsState, grid.getSymName());
}

void ProcessLinearIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "proc_linear_idx");
}
LogicalResult
NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
  if (failed(grid)) {
    return failure();
  }
  return success();
}

void NeighborsLinearIndicesOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getNeighborDown(), "down_linear_idx");
  setNameFn(getNeighborUp(), "up_linear_idx");
}
template <typename Op>
struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    auto gridAxes = op.getGridAxes();
    if (!gridAxes.empty()) {
      return failure();
    }
    if (op.getInput().getType() != op.getResult().getType()) {
      return failure();
    }

    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
    rewriter.eraseOp(op.getOperation());
    return success();
  }
};
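// Rationale: a collective over an empty set of grid axes addresses a device
// group of size 1, so (for example) an all_reduce or all_gather with
// grid_axes = [] is the identity and can be folded to its input, provided the
// operand and result types already agree.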
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
                                         ArrayRef<int64_t> device,
                                         Operation::operand_range deviceDynamic,
                                         ArrayRef<GridAxis> gridAxes,
                                         ArrayRef<int64_t> gridShape) {
  if (device.size() != gridAxes.size()) {
    return emitError(loc) << "In-group device \"" << deviceName
                          << "\" has unexpected multi-index size "
                          << device.size() << ". Expected " << gridAxes.size()
                          << ".";
  }

  for (size_t i = 0; i < device.size(); ++i) {
    if (ShapedType::isStatic(device[i]) &&
        ShapedType::isStatic(gridShape[gridAxes[i]]) &&
        gridShape[gridAxes[i]] <= device[i]) {
      return emitError(loc)
             << "Out of bounds coordinate " << i << " for in-group device \""
             << deviceName << "\"."
             << " Got " << device[i] << ", but expected value in the range [0, "
             << (gridShape[gridAxes[i]] - 1) << "].";
    }
  }
  return success();
}
static LogicalResult verifyDimensionCompatibility(Location loc,
                                                  int64_t expectedDimSize,
                                                  int64_t resultDimSize,
                                                  int64_t resultAxis) {
  if (ShapedType::isStatic(resultDimSize) && expectedDimSize != resultDimSize) {
    return emitError(loc) << "Dimension size mismatch for result axis "
                          << resultAxis << ". Expected "
                          << (ShapedType::isDynamic(expectedDimSize)
                                  ? Twine("dynamic")
                                  : Twine(expectedDimSize))
                          << ", but got " << resultDimSize << ".";
  }

  return success();
}
static LogicalResult verifyGatherOperandAndResultShape(
    Value operand, Value result, int64_t gatherAxis,
    ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
  auto resultRank = cast<ShapedType>(result.getType()).getRank();
  if (gatherAxis < 0 || gatherAxis >= resultRank) {
    return emitError(result.getLoc())
           << "Gather axis " << gatherAxis << " is out of bounds [0, "
           << resultRank << ").";
  }

  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
    auto expectedResultDimSize =
        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
    if (failed(verifyDimensionCompatibility(
            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
      return failure();
    }
  }
  return success();
}
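// Worked example (values assumed): gathering tensor<2x3xf32> along
// gatherAxis = 1 over grid axes with process-group size 4 must produce
// tensor<2x12xf32>; every non-gather axis keeps its size, and the gather axis
// is multiplied by the device group size.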
static LogicalResult verifyAllToAllOperandAndResultShape(
    Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
    ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  if (splitAxis == concatAxis) {
    return success();
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
  DimensionSize expectedResultConcatDimSize =
      operandConcatDimSize * deviceGroupSize;
  DimensionSize expectedResultSplitDimSize =
      operandSplitDimSize / deviceGroupSize;
  if (!expectedResultSplitDimSize.isDynamic() &&
      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
    expectedResultSplitDimSize = DimensionSize::dynamic();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultConcatDimSize.value(),
          resultType.getDimSize(concatAxis), concatAxis))) {
    return failure();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultSplitDimSize.value(),
          resultType.getDimSize(splitAxis), splitAxis))) {
    return failure();
  }
  return success();
}
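// Worked example (values assumed): an all_to_all on tensor<4x8xf32> with
// splitAxis = 1 and concatAxis = 0 over a device group of size 4 must produce
// tensor<16x2xf32>: the concat axis grows 4x while the split axis shrinks 4x,
// and a non-divisible split size degrades the expectation to dynamic.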
static LogicalResult verifyScatterOrSliceOperandAndResultShape(
    Value operand, Value result, int64_t tensorAxis,
    ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if (axis != tensorAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
  auto operandScatterDimSize =
      DimensionSize(operandType.getDimSize(tensorAxis));
  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
      int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
    return emitError(result.getLoc())
           << "Operand dimension size " << int64_t(operandScatterDimSize)
           << " is not divisible by collective device group size "
           << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
           << ".";
  }

  DimensionSize expectedResultTensorDimSize =
      operandScatterDimSize / deviceGroupSize;
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultTensorDimSize.value(),
          resultType.getDimSize(tensorAxis), tensorAxis))) {
    return failure();
  }
  return success();
}
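// Example of the divisibility check above (values assumed): scattering a
// tensor with dimension size 10 across a device group of size 4 is rejected,
// since 10 % 4 != 0; size 12 would shard cleanly into 3 elements per device.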
static RankedTensorType sliceResultType(Type operandType, GridOp grid,
                                        ArrayRef<GridAxis> gridAxes,
                                        int64_t sliceAxis) {
  RankedTensorType operandRankedTensorType =
      cast<RankedTensorType>(operandType);
  DimensionSize operandSliceAxisSize =
      operandRankedTensorType.getShape()[sliceAxis];
  SmallVector<int64_t> resultShape =
      llvm::to_vector(operandRankedTensorType.getShape());

  resultShape[sliceAxis] =
      operandSliceAxisSize /
      DimensionSize(collectiveProcessGroupSize(gridAxes, grid.getShape()));
  return operandRankedTensorType.clone(resultShape);
}
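// For instance (assumed operands): sliceResultType(tensor<16x8xf32>, grid2x4,
// /*gridAxes=*/{1}, /*sliceAxis=*/0) divides dimension 0 by the group size 4,
// giving tensor<4x8xf32>; a dynamic slice dimension stays dynamic via the
// DimensionSize arithmetic.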
LogicalResult
AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerifyAxes(*this, symbolTable);
  if (failed(grid)) {
    return failure();
  }
  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
                                           gatherAxis, getGridAxes(),
                                           grid.value().getShape());
}

void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
}

void AllGatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_gather");
}
LogicalResult
AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  return getGridAndVerifyAxes(*this, symbolTable);
}

void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
}

void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        Value input, StringRef grid,
                        ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
  build(odsBuilder, odsState, input.getType(), grid, gridAxes, input,
        reduction);
}

void AllReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_reduce");
}
LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerifyAxes(*this, symbolTable);
  if (failed(grid)) {
    return failure();
  }
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
      grid.value().getShape());
}

void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
                       int64_t sliceAxis) {
  Type resultType = sliceResultType(input.getType(), grid, gridAxes, sliceAxis);
  build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
        sliceAxis);
}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Type resultType, Value input, StringRef grid,
                       ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
  build(odsBuilder, odsState, resultType, grid, gridAxes, input,
        APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}
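// Note on the builder above: the generated AllSliceOp builder takes the slice
// axis as an APInt, so the int64_t is wrapped with an explicit bit width of
// sizeof(sliceAxis) * CHAR_BIT (64 bits) to avoid any implicit truncation.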
void AllSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_slice");
}
LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerifyAxes(*this, symbolTable);
  if (failed(grid)) {
    return failure();
  }
  return verifyAllToAllOperandAndResultShape(
      getOperand(), getResult(), getSplitAxis().getSExtValue(),
      getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
}

void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
}

void AllToAllOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_to_all");
}
LogicalResult
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerifyAxes(*this, symbolTable);
  if (failed(grid)) {
    return failure();
  }
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getGridAxes(),
                                 grid.value().getShape()))) {
    return failure();
  }
  return success();
}

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
}

void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "broadcast");
}
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerifyAxes(*this, symbolTable);
  if (failed(grid)) {
    return failure();
  }
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getGridAxes(),
                                 grid.value().getShape()))) {
    return failure();
  }

  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
                                           getGridAxes(),
                                           grid.value().getShape());
}

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
}

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerifyAxes(*this, symbolTable);
  if (failed(grid)) {
    return failure();
  }
  if (getSource() &&
      failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
                                 getSource().value(), getSourceDynamic(),
                                 getGridAxes(), grid.value().getShape()))) {
    return failure();
  }
  return success();
}

void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
}

void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "recv");
}
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerifyAxes(*this, symbolTable);
  if (failed(grid)) {
    return failure();
  }
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getGridAxes(),
                                 grid.value().getShape()))) {
    return failure();
  }
  return success();
}

void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce");
}
LogicalResult
ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerifyAxes(*this, symbolTable);
  if (failed(grid)) {
    return failure();
  }
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
      grid.value().getShape());
}

void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}

void ReduceScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce_scatter");
}
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerifyAxes(*this, symbolTable);
  if (failed(grid)) {
    return failure();
  }
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getGridAxes(),
                                 grid.value().getShape()))) {
    return failure();
  }

  auto scatterAxis = getScatterAxis().getSExtValue();
  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
                                                   scatterAxis, getGridAxes(),
                                                   grid.value().getShape());
}

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
}

void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerifyAxes(*this, symbolTable);
  if (failed(grid)) {
    return failure();
  }
  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
                                 getDestination(), getDestinationDynamic(),
                                 getGridAxes(), grid.value().getShape()))) {
    return failure();
  }
  return success();
}

void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);
}

void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "send");
}
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerifyAxes(*this, symbolTable);
  if (failed(grid)) {
    return failure();
  }

  auto gridAxes = getGridAxes();
  auto shiftAxis = getShiftAxis().getZExtValue();
  if (!llvm::is_contained(gridAxes, shiftAxis)) {
    return emitError() << "Invalid shift axis " << shiftAxis
                       << ". It must be one of the grouping grid axes.";
  }

  return success();
}

void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                          MLIRContext *context) {
  // Currently no canonicalization patterns are registered for ShiftOp.
}

void ShiftOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "shift");
}
LogicalResult
UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
  if (failed(grid)) {
    return failure();
  }
  return success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"

#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"