#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>
40 #define DEBUG_TYPE "shard-ops"
45 #include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
49 struct DimensionSize {
50 static DimensionSize dynamic() {
return DimensionSize(ShapedType::kDynamic); }
51 DimensionSize(int64_t val) : val(val) {}
52 int64_t value()
const {
return val; }
53 operator int64_t()
const {
return val; }
54 bool isDynamic()
const {
return ShapedType::isDynamic(val); }
62 static DimensionSize
operator/(DimensionSize lhs, DimensionSize rhs) {
63 if (lhs.isDynamic() || rhs.isDynamic()) {
64 return DimensionSize::dynamic();
66 return lhs.value() / rhs.value();
69 static DimensionSize
operator*(DimensionSize lhs, DimensionSize rhs) {
70 if (lhs.isDynamic() || rhs.isDynamic()) {
71 return DimensionSize::dynamic();
73 return lhs.value() * rhs.value();
81 auto dyn = dynamics.begin();
86 "expected an i64 or an intex type");
87 for (
auto s : statics) {
88 if (s == ShapedType::kDynamic) {
89 values.emplace_back(*(dyn++));
92 values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
122 void ShardDialect::initialize() {
125 #include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
128 #define GET_ATTRDEF_LIST
129 #include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
132 #define GET_TYPEDEF_LIST
133 #include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
135 addInterface<ShardInlinerinterface>();
141 return arith::ConstantOp::materialize(builder, value, type, loc);
153 return op->
emitError() <<
"Undefined required grid symbol \""
/// Returns true if the range [begin, end) contains no two adjacent equal
/// elements. For a *sorted* range (which is how callers use it) this is
/// equivalent to "all elements are unique".
template <typename It>
bool isUnique(It begin, It end) {
  // std::adjacent_find returns `end` iff no adjacent pair compares equal,
  // which is exactly the hand-rolled scan it replaces.
  return std::adjacent_find(begin, end) == end;
}
181 if (!
isUnique(sorted.begin(), sorted.end())) {
182 return emitError(loc) <<
"Grid axes contains duplicate elements.";
186 for (
auto axis : axes) {
187 if (axis >= rank || axis < 0) {
189 <<
"0-based grid axis index " << axis
190 <<
" is out of bounds. The referenced grid \"" << grid.getSymName()
191 <<
"\" is of rank " << rank <<
".";
198 template <
typename Op>
199 static FailureOr<GridOp>
212 template <
typename InShape,
typename GridShape,
typename SplitAxes,
214 static void shardShape(
const InShape &inShape,
const GridShape &gridShape,
215 const SplitAxes &splitAxes, OutShape &outShape,
219 if (inShape.empty()) {
220 assert(outShape.empty());
224 std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
225 llvm::adl_begin(outShape));
227 if (!shardedDimsOffsets.empty()) {
228 auto isDynShape = ShapedType::isDynamicShape(gridShape);
231 if (!innerSplitAxes.empty()) {
232 auto sz = shardedDimsOffsets[pos];
233 bool same = !isDynShape;
238 uint64_t numShards = 0;
239 for (
auto i : innerSplitAxes.asArrayRef()) {
240 numShards += gridShape[i];
242 for (
size_t i = 1; i < numShards; ++i) {
243 if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
249 pos += numShards + 1;
251 outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
261 if (!haloSizes.empty()) {
265 if (ShapedType::isStatic(outShape[tensorAxis]) &&
266 !innerSplitAxes.empty()) {
267 if (haloSizes[haloAxis * 2] >= 0 &&
268 haloSizes[haloAxis * 2 + 1] >= 0) {
269 outShape[tensorAxis] +=
270 haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
273 outShape[tensorAxis] = ShapedType::kDynamic;
283 using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
288 return shape.clone(resShapeArr);
292 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
293 if (rankedTensorType && !rankedTensorType.getShape().empty()) {
303 ShardOp &newShardOp) {
306 ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
307 if (shardOp && sharding == shardOp.getSharding() &&
308 !shardOp.getAnnotateForUsers()) {
311 newShardOp = shardOp;
318 ShardingOp::create(builder, operandValue.
getLoc(), sharding);
319 newShardOp = ShardOp::create(builder, operandValue.
getLoc(), operandValue,
324 newShardOp, [operandOp, operandValue](
OpOperand &use) {
325 return use.
getOwner() == operandOp && use.
get() == operandValue;
328 if (!shardOp || shardOp.getAnnotateForUsers()) {
332 auto newShardOp2 = ShardOp::create(builder, operandValue.
getLoc(), newShardOp,
333 newShardOp.getSharding(),
335 newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
343 for (
auto &use : result.
getUses()) {
344 uses.emplace_back(use.get(), use.getOwner());
346 for (
auto &[operandValue, operandOp] : uses) {
348 builder, newShardOp);
358 bool isBlockArg = !operandSrcOp;
360 [[maybe_unused]]
auto opType =
361 dyn_cast<mlir::RankedTensorType>(operandValue.
getType());
364 if (!isa<RankedTensorType>(operandValue.
getType()) && operandSrcOp &&
370 ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
372 if (shardOp && sharding == shardOp.getSharding() &&
373 shardOp.getAnnotateForUsers()) {
380 ShardingOp::create(builder, operand.
get().
getLoc(), sharding);
382 ShardOp::create(builder, operandValue.
getLoc(), operandValue, shardingOp,
386 operandValue, newShardOp, [operandOp, operandValue](
OpOperand &use) {
387 return use.
getOwner() == operandOp && use.
get() == operandValue;
390 if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
396 auto newPreceedingShardOp =
397 ShardOp::create(builder, operandValue.
getLoc(), operandValue, shardingOp,
400 newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](
OpOperand &use) {
401 return use.getOwner() == newShardOp.getOperation();
410 int64_t rank = getRank();
413 return emitOpError(
"rank of grid is expected to be a positive integer");
415 for (int64_t dimSize :
getShape()) {
416 if (dimSize < 0 && ShapedType::isStatic(dimSize))
417 return emitOpError(
"dimension size of a grid is expected to be "
418 "non-negative or dynamic");
438 size_t expectedResultsCount =
439 getAxes().empty() ? grid->getRank() : getAxes().size();
440 if (getResult().size() != expectedResultsCount) {
441 return emitError() <<
"Unexpected number of results " << getResult().size()
442 <<
". Expected " << expectedResultsCount <<
".";
455 build(odsBuilder, odsState,
463 assert(!axes.empty());
464 build(odsBuilder, odsState,
469 void GridShapeOp::getAsmResultNames(
471 setNameFn(getResults()[0],
"grid_shape");
500 void ShardingOp::build(
531 llvm::SmallSet<GridAxis, 4> visitedAxes;
536 return emitError() <<
"grid axis is expected to be non-negative";
537 if (!visitedAxes.insert(axis).second)
538 return emitError() <<
"grid axis duplicated";
543 for (
auto subAxes : getSplitAxes().getAxes()) {
545 if (
failed(checkGridAxis(subAxesArray)))
549 if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
550 return emitOpError(
"halo sizes and shard offsets are mutually exclusive");
553 if (!getStaticHaloSizes().empty()) {
554 auto numSplitAxes = getSplitAxes().getAxes().size();
555 for (
auto splitAxis : getSplitAxes().getAxes()) {
556 if (splitAxis.empty()) {
560 if (getStaticHaloSizes().size() != numSplitAxes * 2) {
561 return emitError() <<
"halo sizes must be specified for all split axes.";
568 void ShardingOp::getAsmResultNames(
570 setNameFn(getResult(),
"sharding");
578 if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
579 getStaticShardedDimsOffsets().size() > 0) {
580 return emitError() <<
"sharded dims offsets are not allowed for "
581 "device grids with dynamic shape.";
584 auto shardedDimsOffsets = getStaticShardedDimsOffsets();
585 if (!shardedDimsOffsets.empty()) {
586 auto gridShape = grid.value().getShape();
587 assert(ShapedType::isStaticShape(gridShape));
589 for (
auto [tensorAxis, innerSplitAxes] :
llvm::enumerate(getSplitAxes())) {
590 if (!innerSplitAxes.empty()) {
591 int64_t numShards = 0, off = 0;
592 for (
auto i : innerSplitAxes.asArrayRef()) {
593 numShards += gridShape[i];
595 for (int64_t i = 0; i <= numShards; ++i) {
596 if (shardedDimsOffsets.size() <= pos + i) {
597 return emitError() <<
"sharded dims offsets has wrong size.";
599 if (ShapedType::isStatic(shardedDimsOffsets[pos + i])) {
600 if (shardedDimsOffsets[pos + i] < off) {
602 <<
"sharded dims offsets must be non-decreasing.";
604 off = shardedDimsOffsets[pos + i];
607 pos += numShards + 1;
624 LogicalResult matchAndRewrite(ShardingOp op,
627 getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
629 op.getDynamicShardedDimsOffsets(), b);
638 if (dynamicHalos.empty() && !staticHalos.empty()) {
639 if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
650 if (dynamicOffs.empty() && !staticOffs.empty()) {
651 assert(staticOffs.size() >= 2);
652 auto diff = staticOffs[1] - staticOffs[0];
653 bool all_same = staticOffs.size() > 2;
654 for (
auto i = 2u; i < staticOffs.size(); ++i) {
655 if (staticOffs[i] - staticOffs[i - 1] != diff) {
670 op.setStaticHaloSizes(staticHalos);
671 op.getDynamicHaloSizesMutable().assign(dynamicHalos);
672 op.setStaticShardedDimsOffsets(staticOffs);
673 op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
682 results.
add<NormalizeSharding>(context);
695 if (!llvm::equal(llvm::make_range(
getSplitAxes().begin(),
702 return llvm::all_of(llvm::drop_begin(
getSplitAxes(), minSize),
703 std::mem_fn(&GridAxesAttr::empty)) &&
704 llvm::all_of(llvm::drop_begin(rhs.
getSplitAxes(), minSize),
705 std::mem_fn(&GridAxesAttr::empty));
756 assert(shardingOp &&
"expected sharding op");
757 auto splitAxes = shardingOp.getSplitAxes().getAxes();
759 if (splitAxes.empty()) {
760 *
this =
Sharding(shardingOp.getGridAttr());
764 get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
765 shardingOp.getStaticShardedDimsOffsets(),
777 if (split_axes_.empty()) {
781 res.split_axes.resize(split_axes_.size());
787 auto clone = [](
const auto src,
auto &dst) {
788 dst.resize(src.size());
792 clone(static_halo_sizes_, res.static_halo_sizes);
793 clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
794 clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
795 clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
804 void ShardShapeOp::getAsmResultNames(
806 setNameFn(getResult()[0],
"shard_shape");
815 build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
823 void ShardOp::getAsmResultNames(
825 setNameFn(getResult(),
"sharding_annotated");
835 LogicalResult matchAndRewrite(ShardOp op,
PatternRewriter &b)
const override {
838 Value value = op.getSrc();
844 for (
auto &use : value.
getUses()) {
845 if (use.getOwner() != op.getOperation()) {
846 auto otherOp = dyn_cast<ShardOp>(use.getOwner());
847 if (!otherOp || !otherOp->isBeforeInBlock(op)) {
852 Sharding currentSharding(op.getSharding());
853 Sharding otherSharding(otherOp.getSharding());
854 if (currentSharding == otherSharding) {
859 op.getSrcMutable().assign(otherOp.getResult());
872 results.
add<FoldDuplicateShardOp>(context);
889 size_t expectedResultsCount =
890 getAxes().empty() ? grid->getRank() : getAxes().size();
891 if (getResult().size() != expectedResultsCount) {
892 return emitError() <<
"Unexpected number of results " << getResult().size()
893 <<
". Expected " << expectedResultsCount <<
".";
901 build(odsBuilder, odsState,
908 build(odsBuilder, odsState,
913 void ProcessMultiIndexOp::getAsmResultNames(
915 setNameFn(getResults()[0],
"proc_linear_idx");
931 void ProcessLinearIndexOp::build(
OpBuilder &odsBuilder,
933 build(odsBuilder, odsState, grid.getSymName());
936 void ProcessLinearIndexOp::getAsmResultNames(
938 setNameFn(getResult(),
"proc_linear_idx");
954 void NeighborsLinearIndicesOp::getAsmResultNames(
956 setNameFn(getNeighborDown(),
"down_linear_idx");
957 setNameFn(getNeighborUp(),
"up_linear_idx");
966 template <
typename Op>
969 LogicalResult matchAndRewrite(
Op op,
971 auto gridAxes = op.getGridAxes();
972 if (!gridAxes.empty()) {
975 if (op.getInput().getType() != op.getResult().getType()) {
992 if (device.size() != gridAxes.size()) {
993 return emitError(loc) <<
"In-group device \"" << deviceName
994 <<
"\" has unexpected multi-index size "
995 << device.size() <<
". Expected " << gridAxes.size()
999 for (
size_t i = 0; i < device.size(); ++i) {
1000 if (ShapedType::isStatic(device[i]) &&
1001 ShapedType::isStatic(gridShape[gridAxes[i]]) &&
1002 gridShape[gridAxes[i]] <= device[i]) {
1004 <<
"Out of bounds coordinate " << i <<
" for in-group device \""
1005 << deviceName <<
"\"."
1006 <<
" Got " << device[i] <<
", but expected value in the range [0, "
1007 << (gridShape[gridAxes[i]] - 1) <<
"].";
1014 int64_t expectedDimSize,
1015 int64_t resultDimSize,
1016 int64_t resultAxis) {
1017 if (ShapedType::isStatic(resultDimSize) && expectedDimSize != resultDimSize) {
1018 return emitError(loc) <<
"Dimension size mismatch for result axis "
1019 << resultAxis <<
". Expected "
1020 << (ShapedType::isDynamic(expectedDimSize)
1022 : Twine(expectedDimSize))
1023 <<
", but got " << resultDimSize <<
".";
1030 Value operand,
Value result, int64_t gatherAxis,
1032 auto resultRank = cast<ShapedType>(result.
getType()).getRank();
1033 if (gatherAxis < 0 || gatherAxis >= resultRank) {
1035 <<
"Gather axis " << gatherAxis <<
" is out of bounds [0, "
1036 << resultRank <<
").";
1039 ShapedType operandType = cast<ShapedType>(operand.
getType());
1040 ShapedType resultType = cast<ShapedType>(result.
getType());
1041 auto deviceGroupSize =
1043 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1044 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
1045 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
1046 auto expectedResultDimSize =
1047 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
1049 result.
getLoc(), expectedResultDimSize, resultDimSize, axis))) {
1057 Value operand,
Value result, int64_t splitAxis, int64_t concatAxis,
1059 ShapedType operandType = cast<ShapedType>(operand.
getType());
1060 ShapedType resultType = cast<ShapedType>(result.
getType());
1061 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1062 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
1064 result.
getLoc(), operandType.getDimSize(axis),
1065 resultType.getDimSize(axis), axis))) {
1071 if (splitAxis == concatAxis) {
1075 auto deviceGroupSize =
1077 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
1078 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
1079 DimensionSize expectedResultConcatDimSize =
1080 operandConcatDimSize * deviceGroupSize;
1081 DimensionSize expectedResultSplitDimSize =
1082 operandSplitDimSize / deviceGroupSize;
1083 if (!expectedResultSplitDimSize.isDynamic() &&
1084 int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
1085 expectedResultSplitDimSize = DimensionSize::dynamic();
1088 result.
getLoc(), expectedResultConcatDimSize.value(),
1089 resultType.getDimSize(concatAxis), concatAxis))) {
1093 result.
getLoc(), expectedResultSplitDimSize.value(),
1094 resultType.getDimSize(splitAxis), splitAxis))) {
1102 Value operand,
Value result, int64_t tensorAxis,
1104 ShapedType operandType = cast<ShapedType>(operand.
getType());
1105 ShapedType resultType = cast<ShapedType>(result.
getType());
1106 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1107 if (axis != tensorAxis) {
1109 result.
getLoc(), operandType.getDimSize(axis),
1110 resultType.getDimSize(axis), axis))) {
1116 auto deviceGroupSize =
1118 auto operandScatterDimSize =
1119 DimensionSize(operandType.getDimSize(tensorAxis));
1120 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1121 int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1123 <<
"Operand dimension size " << int64_t(operandScatterDimSize)
1124 <<
" is not divisible by collective device group size "
1125 << int64_t(deviceGroupSize) <<
" for tensor axis " << tensorAxis
1128 DimensionSize expectedResultTensorDimSize =
1129 operandScatterDimSize / deviceGroupSize;
1131 result.
getLoc(), expectedResultTensorDimSize.value(),
1132 resultType.getDimSize(tensorAxis), tensorAxis))) {
1141 int64_t sliceAxis) {
1142 RankedTensorType operandRankedTensorType =
1143 cast<RankedTensorType>(operandType);
1144 DimensionSize operandSliceAxisSize =
1145 operandRankedTensorType.getShape()[sliceAxis];
1147 llvm::to_vector(operandRankedTensorType.getShape());
1149 resultShape[sliceAxis] =
1150 operandSliceAxisSize /
1152 return operandRankedTensorType.clone(resultShape);
1165 auto gatherAxis = getGatherAxis().getSExtValue();
1167 gatherAxis, getGridAxes(),
1168 grid.value().getShape());
1173 patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
1176 void AllGatherOp::getAsmResultNames(
1178 setNameFn(getResult(),
"all_gather");
1192 patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
1196 Value input, StringRef grid,
1198 build(odsBuilder, odsState, input.
getType(), grid, gridAxes, input,
1202 void AllReduceOp::getAsmResultNames(
1204 setNameFn(getResult(),
"all_reduce");
1217 getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
1218 grid.value().getShape());
1223 patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
1228 int64_t sliceAxis) {
1230 build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
1235 Type resultType,
Value input, StringRef grid,
1237 build(odsBuilder, odsState, resultType, grid, gridAxes, input,
1238 APInt(
sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1241 void AllSliceOp::getAsmResultNames(
1243 setNameFn(getResult(),
"all_slice");
1257 getOperand(), getResult(), getSplitAxis().getSExtValue(),
1258 getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
1263 patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
1266 void AllToAllOp::getAsmResultNames(
1268 setNameFn(getResult(),
"all_to_all");
1282 getRootDynamic(), getGridAxes(),
1283 grid.value().getShape()))) {
1292 patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
1295 void BroadcastOp::getAsmResultNames(
1297 setNameFn(getResult(),
"broadcast");
1310 getRootDynamic(), getGridAxes(),
1311 grid.value().getShape()))) {
1315 auto gatherAxis = getGatherAxis().getSExtValue();
1318 grid.value().getShape());
1323 patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
1326 void GatherOp::getAsmResultNames(
1328 setNameFn(getResult(),
"gather");
1342 getSource().value(), getSourceDynamic(),
1343 getGridAxes(), grid.value().getShape()))) {
1351 patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
1354 void RecvOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
1355 setNameFn(getResult(),
"recv");
1368 getRootDynamic(), getGridAxes(),
1369 grid.value().getShape()))) {
1378 patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
1381 void ReduceOp::getAsmResultNames(
1383 setNameFn(getResult(),
"reduce");
1398 getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
1399 grid.value().getShape());
1404 patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1407 void ReduceScatterOp::getAsmResultNames(
1409 setNameFn(getResult(),
"reduce_scatter");
1422 getRootDynamic(), getGridAxes(),
1423 grid.value().getShape()))) {
1427 auto scatterAxis = getScatterAxis().getSExtValue();
1429 scatterAxis, getGridAxes(),
1430 grid.value().getShape());
1435 patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
1438 void ScatterOp::getAsmResultNames(
1440 setNameFn(getResult(),
"scatter");
1453 getDestination(), getDestinationDynamic(),
1454 getGridAxes(), grid.value().getShape()))) {
1462 patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);
1465 void SendOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
1466 setNameFn(getResult(),
"send");
1479 auto gridAxes = getGridAxes();
1480 auto shiftAxis = getShiftAxis().getZExtValue();
1481 if (!llvm::is_contained(gridAxes, shiftAxis)) {
1482 return emitError() <<
"Invalid shift axis " << shiftAxis
1483 <<
". It must be one of the grouping grid axes.";
1495 void ShiftOp::getAsmResultNames(
1497 setNameFn(getResult(),
"shift");
1518 #define GET_OP_CLASSES
1519 #include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
1521 #define GET_ATTRDEF_CLASSES
1522 #include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
1524 #define GET_TYPEDEF_CLASSES
1525 #include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
1527 #include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static LogicalResult verifyGridAxes(Location loc, ArrayRef< GridAxis > axes, GridOp grid)
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static FailureOr< GridOp > getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
static FailureOr< GridOp > getGridAndVerify(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable)
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
bool isUnique(It begin, It end)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Location getLoc()
The source location the operation was defined or derived from.
This class provides the API for a sub-set of ops that are known to be constant-like.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Operation * getOwner() const
Return the owner of this operand.
bool operator!=(Value rhs) const
bool equalShardSizes(const Sharding &rhs) const
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalSplitAxes(const Sharding &rhs) const
::mlir::FlatSymbolRefAttr getGridAttr() const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
::llvm::StringRef getGrid() const
bool equalHaloAndShardSizes(const Sharding &rhs) const
bool operator==(Value rhs) const
ArrayRef< int64_t > getStaticHaloSizes() const
ArrayRef< Value > getDynamicShardedDimsOffsets() const
ArrayRef< Value > getDynamicHaloSizes() const
ArrayRef< GridAxesAttr > getSplitAxes() const
bool equalHaloSizes(const Sharding &rhs) const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
shard::ReductionKind ReductionKind
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
bool isFullReplication(Sharding sharding)
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Type shardType(Type type, GridOp grid, Sharding sharding)
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
AffineExpr operator*(int64_t val, AffineExpr expr)
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.