#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "shard-ops"

#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
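// Thin wrapper around an int64_t dimension size that treats
// ShapedType::kDynamic as an absorbing value: any arithmetic involving a
// dynamic size yields a dynamic size (see operator/ and operator* below).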
struct DimensionSize {
  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
  DimensionSize(int64_t val) : val(val) {}
  int64_t value() const { return val; }
  operator int64_t() const { return val; }
  bool isDynamic() const { return ShapedType::isDynamic(val); }

private:
  int64_t val;
};
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() / rhs.value();
}

static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() * rhs.value();
}
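// For example, splitting a static dimension of size 8 over 2 shards gives
// DimensionSize(8) / DimensionSize(2) == DimensionSize(4), whereas
// DimensionSize::dynamic() / DimensionSize(2) stays dynamic.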
// Converts a vector of mixed static sizes and dynamic values into a vector of
// Values of the provided type, materializing arith constants for the static
// entries.
SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                    llvm::ArrayRef<int64_t> statics,
                                    ValueRange dynamics, Type type) {
  SmallVector<Value> values;
  auto dyn = dynamics.begin();
  if (!type)
    type = b.getI64Type();
  assert((type == b.getI64Type() || type == b.getIndexType()) &&
         "expected an i64 or an index type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      values.emplace_back(*(dyn++));
    } else {
      TypedAttr val =
          type.isIndex() ? b.getIndexAttr(s) : b.getI64IntegerAttr(s);
      values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
    }
  }
  return values;
}
void ShardDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
      >();
  addInterface<ShardInlinerInterface>();
}

Operation *ShardDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
    return op->emitError() << "Undefined required grid symbol \""
                           << gridSymbol.getValue() << "\".";
template <typename It>
bool isUnique(It begin, It end) {
  if (begin == end)
    return true;
  It next = std::next(begin);
  for (; next != end; ++next, ++begin) {
    if (*begin == *next)
      return false;
  }
  return true;
}
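// Verifies that every axis in `axes` is unique and within [0, rank) of the
// referenced grid.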
static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes,
                                    GridOp grid) {
  SmallVector<GridAxis> sorted = llvm::to_vector(axes);
  llvm::sort(sorted);
  if (!isUnique(sorted.begin(), sorted.end())) {
    return emitError(loc) << "Grid axes contains duplicate elements.";
  }

  GridAxis rank = grid.getRank();
  for (auto axis : axes) {
    if (axis >= rank || axis < 0) {
      return emitError(loc)
             << "0-based grid axis index " << axis
             << " is out of bounds. The referenced grid \"" << grid.getSymName()
             << "\" is of rank " << rank << ".";
    }
  }
  return success();
}
template <typename Op>
static FailureOr<GridOp>
getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
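// Computes the shape of one shard: each tensor axis that is split is divided
// by the collective size of its grid axes; the result is then adjusted for
// per-shard offsets (shardedDimsOffsets) and halo sizes, if provided.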
template <typename InShape, typename GridShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const GridShape &gridShape,
                       const SplitAxes &splitAxes, OutShape &outShape,
                       ArrayRef<int64_t> shardedDimsOffsets = {},
                       ArrayRef<int64_t> haloSizes = {}) {
  // 0d tensors cannot be sharded and must get replicated.
  if (inShape.empty()) {
    assert(outShape.empty());
    return;
  }

  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));

  if (!shardedDimsOffsets.empty()) {
    auto isDynShape = ShapedType::isDynamicShape(gridShape);
    uint64_t pos = 1;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (!innerSplitAxes.empty()) {
        auto sz = shardedDimsOffsets[pos];
        bool same = !isDynShape;
        if (same) {
          // The shard size is uniform only if the difference between
          // consecutive offsets is the same for all shards in the group.
          uint64_t numShards = 0;
          for (auto i : innerSplitAxes.asArrayRef()) {
            numShards += gridShape[i];
          }
          for (size_t i = 1; i < numShards; ++i) {
            if (shardedDimsOffsets[pos + i] -
                    shardedDimsOffsets[pos + i - 1] != sz) {
              same = false;
              break;
            }
          }
          pos += numShards + 1;
        }
        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
      }
    }
  } else {
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      outShape[tensorAxis] =
          shardDimension(inShape[tensorAxis],
                         collectiveProcessGroupSize(
                             innerSplitAxes.asArrayRef(), gridShape));
    }
  }

  if (!haloSizes.empty()) {
    // Halo sizes add to the low and high side of each sharded, static
    // dimension.
    size_t haloAxis = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (ShapedType::isStatic(outShape[tensorAxis]) &&
          !innerSplitAxes.empty()) {
        if (haloSizes[haloAxis * 2] >= 0 &&
            haloSizes[haloAxis * 2 + 1] >= 0) {
          outShape[tensorAxis] +=
              haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
        } else {
          // Dynamic halo sizes make the shard dimension dynamic.
          outShape[tensorAxis] = ShapedType::kDynamic;
        }
        ++haloAxis;
      }
    }
  }
}
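// Example: an 8x3 tensor split on tensor axis 0 across a grid axis of size 2
// shards to 4x3; with halo sizes {1, 1} on that axis the local shard becomes
// 6x3.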
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding) {
  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
  SmallVector<Dim> resShapeArr(shape.getShape().size());
  shardShape(shape.getShape(), grid.getShape(), sharding.getSplitAxes(),
             resShapeArr, sharding.getStaticShardedDimsOffsets(),
             sharding.getStaticHaloSizes());
  return shape.clone(resShapeArr);
}

Type shardType(Type type, GridOp grid, Sharding sharding) {
  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
    return shardShapedType(rankedTensorType, grid, sharding);
  }
  return type;
}
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding,
                                                    Value &operandValue,
                                                    Operation *operandOp,
                                                    OpBuilder &builder,
                                                    ShardOp &newShardOp) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointAfterValue(operandValue);
  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      !shardOp.getAnnotateForUsers()) {
    // Nothing to do if the correct sharding is already set.
    if (!newShardOp) {
      newShardOp = shardOp;
    }
    return;
  }

  if (!newShardOp) {
    auto shardingOp =
        ShardingOp::create(builder, operandValue.getLoc(), sharding);
    newShardOp = ShardOp::create(builder, operandValue.getLoc(), operandValue,
                                 shardingOp, /*annotate_for_users=*/false);
  }
  operandValue.replaceUsesWithIf(
      newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (!shardOp || shardOp.getAnnotateForUsers()) {
    return;
  }

  auto newShardOp2 =
      ShardOp::create(builder, operandValue.getLoc(), newShardOp,
                      newShardOp.getSharding(), /*annotate_for_users=*/true);
  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
}
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result,
                                         OpBuilder &builder) {
  ShardOp newShardOp;
  SmallVector<std::pair<Value, Operation *>> uses;
  for (auto &use : result.getUses()) {
    uses.emplace_back(use.get(), use.getOwner());
  }
  for (auto &[operandValue, operandOp] : uses) {
    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
                                            builder, newShardOp);
  }
}
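// Inserts a sharding annotation on the producer side of `operand`, reusing an
// existing shard op when it already carries the same sharding.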
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand,
                                         OpBuilder &builder) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Value operandValue = operand.get();
  Operation *operandOp = operand.getOwner();
  Operation *operandSrcOp = operandValue.getDefiningOp();
  bool isBlockArg = !operandSrcOp;
  [[maybe_unused]] auto opType =
      dyn_cast<mlir::RankedTensorType>(operandValue.getType());
  assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));

  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
      operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
    return;
  }

  builder.setInsertionPointAfterValue(operandValue);
  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      shardOp.getAnnotateForUsers()) {
    // The correct sharding is already annotated for users.
    return;
  }

  auto shardingOp =
      ShardingOp::create(builder, operand.get().getLoc(), sharding);
  auto newShardOp =
      ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
                      /*annotate_for_users=*/true);
  IRRewriter rewriter(builder);
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
    // No resharding is needed.
    return;
  }

  builder.setInsertionPoint(newShardOp);
  auto newPrecedingShardOp =
      ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
                      /*annotate_for_users=*/false);
  rewriter.replaceUsesWithIf(
      newShardOp.getSrc(), newPrecedingShardOp, [&newShardOp](OpOperand &use) {
        return use.getOwner() == newShardOp.getOperation();
      });
}
LogicalResult GridOp::verify() {
  int64_t rank = getRank();

  if (rank <= 0)
    return emitOpError("rank of grid is expected to be a positive integer");

  for (int64_t dimSize : getShape()) {
    if (dimSize < 0 && ShapedType::isStatic(dimSize))
      return emitOpError("dimension size of a grid is expected to be "
                         "non-negative or dynamic");
  }
  return success();
}
  size_t expectedResultsCount =
      getAxes().empty() ? grid->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }
  build(odsBuilder, odsState,

  assert(!axes.empty());
  build(odsBuilder, odsState,

void GridShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "grid_shape");
}
void ShardingOp::build(
LogicalResult ShardingOp::verify() {
  llvm::SmallSet<GridAxis, 4> visitedAxes;

  auto checkGridAxis = [&](ArrayRef<GridAxis> axesArray) -> LogicalResult {
    for (GridAxis axis : axesArray) {
      if (axis < 0)
        return emitError() << "grid axis is expected to be non-negative";
      if (!visitedAxes.insert(axis).second)
        return emitError() << "grid axis duplicated";
    }
    return success();
  };

  for (auto subAxes : getSplitAxes().getAxes()) {
    ArrayRef<GridAxis> subAxesArray = subAxes.asArrayRef();
    if (failed(checkGridAxis(subAxesArray)))
      return failure();
  }
  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
    return emitOpError("halo sizes and shard offsets are mutually exclusive");
  }

  if (!getStaticHaloSizes().empty()) {
    auto numSplitAxes = getSplitAxes().getAxes().size();
    for (auto splitAxis : getSplitAxes().getAxes()) {
      if (splitAxis.empty()) {
        --numSplitAxes;
      }
    }
    if (getStaticHaloSizes().size() != numSplitAxes * 2) {
      return emitError() << "halo sizes must be specified for all split axes.";
    }
  }
void ShardingOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding");
}
  if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
      getStaticShardedDimsOffsets().size() > 0) {
    return emitError() << "sharded dims offsets are not allowed for "
                          "device grids with dynamic shape.";
  }

  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
  if (!shardedDimsOffsets.empty()) {
    auto gridShape = grid.value().getShape();
    assert(ShapedType::isStaticShape(gridShape));
    uint64_t pos = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
      if (!innerSplitAxes.empty()) {
        int64_t numShards = 0, off = 0;
        for (auto i : innerSplitAxes.asArrayRef()) {
          numShards += gridShape[i];
        }
        for (int64_t i = 0; i <= numShards; ++i) {
          if (shardedDimsOffsets.size() <= pos + i) {
            return emitError() << "sharded dims offsets has wrong size.";
          }
          if (ShapedType::isStatic(shardedDimsOffsets[pos + i])) {
            if (shardedDimsOffsets[pos + i] < off) {
              return emitError()
                     << "sharded dims offsets must be non-decreasing.";
            }
            off = shardedDimsOffsets[pos + i];
          }
        }
        pos += numShards + 1;
      }
    }
  }
  LogicalResult matchAndRewrite(ShardingOp op,
                                PatternRewriter &b) const override {
    auto mixedHalos =
        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
                                    op.getDynamicShardedDimsOffsets(), b);
    auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
    auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);

    // All-zero static halo sizes are the default and can be dropped.
    if (dynamicHalos.empty() && !staticHalos.empty()) {
      if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
        staticHalos.clear();
      }
    }

    // Equi-distant static offsets are the default and can be dropped.
    if (dynamicOffs.empty() && !staticOffs.empty()) {
      assert(staticOffs.size() >= 2);
      auto diff = staticOffs[1] - staticOffs[0];
      bool all_same = staticOffs.size() > 2;
      for (auto i = 2u; i < staticOffs.size(); ++i) {
        if (staticOffs[i] - staticOffs[i - 1] != diff) {
          all_same = false;
          break;
        }
      }
      if (all_same) {
        staticOffs.clear();
      }
    }

    op.setStaticHaloSizes(staticHalos);
    op.getDynamicHaloSizesMutable().assign(dynamicHalos);
    op.setStaticShardedDimsOffsets(staticOffs);
    op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
  results.add<NormalizeSharding>(context);
bool Sharding::equalSplitAxes(const Sharding &rhs) const {
  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
                                    getSplitAxes().begin() + minSize),
                   llvm::make_range(rhs.getSplitAxes().begin(),
                                    rhs.getSplitAxes().begin() + minSize))) {
    return false;
  }

  return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
                      std::mem_fn(&GridAxesAttr::empty)) &&
         llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
                      std::mem_fn(&GridAxesAttr::empty));
}
  assert(shardingOp && "expected sharding op");
  auto splitAxes = shardingOp.getSplitAxes().getAxes();
  // If splitAxes are empty, this is a full replication.
  if (splitAxes.empty()) {
    *this = Sharding(shardingOp.getGridAttr());
    return;
  }
  *this =
      get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
          shardingOp.getStaticShardedDimsOffsets(),
          shardingOp.getDynamicHaloSizes(),
          shardingOp.getDynamicShardedDimsOffsets());
  if (split_axes_.empty()) {
    return Sharding(grid_);
  }

  Sharding res(grid_);
  res.split_axes.resize(split_axes_.size());

  auto clone = [](const auto src, auto &dst) {
    dst.resize(src.size());
    llvm::copy(src, dst.begin());
  };
  clone(static_halo_sizes_, res.static_halo_sizes);
  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
  clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
void ShardShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult()[0], "shard_shape");
}
  build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
void ShardOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding_annotated");
}
  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
    // Look for another shard op on the same value, defined earlier in the
    // same block.
    Value value = op.getSrc();
    for (auto &use : value.getUses()) {
      if (use.getOwner() != op.getOperation()) {
        auto otherOp = dyn_cast<ShardOp>(use.getOwner());
        if (!otherOp || !otherOp->isBeforeInBlock(op)) {
          continue;
        }
        Sharding currentSharding(op.getSharding());
        Sharding otherSharding(otherOp.getSharding());
        if (currentSharding == otherSharding) {
          b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
          b.eraseOp(op.getOperation());
        } else {
          // Use the other sharding annotation as the source of this one.
          op.getSrcMutable().assign(otherOp.getResult());
        }
        return success();
      }
    }
    return failure();
  }
  results.add<FoldDuplicateShardOp>(context);
  size_t expectedResultsCount =
      getAxes().empty() ? grid->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }
  build(odsBuilder, odsState,

  build(odsBuilder, odsState,

void ProcessMultiIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "proc_linear_idx");
}
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, GridOp grid) {
  build(odsBuilder, odsState, grid.getSymName());
}
void ProcessLinearIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "proc_linear_idx");
}
void NeighborsLinearIndicesOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getNeighborDown(), "down_linear_idx");
  setNameFn(getNeighborUp(), "up_linear_idx");
}
template <typename Op>
struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    auto gridAxes = op.getGridAxes();
    if (!gridAxes.empty()) {
      return failure();
    }
    if (op.getInput().getType() != op.getResult().getType()) {
      return failure();
    }
    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
    rewriter.eraseOp(op.getOperation());
    return success();
  }
};
  if (device.size() != gridAxes.size()) {
    return emitError(loc) << "In-group device \"" << deviceName
                          << "\" has unexpected multi-index size "
                          << device.size() << ". Expected " << gridAxes.size()
                          << ".";
  }

  for (size_t i = 0; i < device.size(); ++i) {
    if (ShapedType::isStatic(device[i]) &&
        ShapedType::isStatic(gridShape[gridAxes[i]]) &&
        gridShape[gridAxes[i]] <= device[i]) {
      return emitError(loc)
             << "Out of bounds coordinate " << i << " for in-group device \""
             << deviceName << "\"."
             << " Got " << device[i] << ", but expected value in the range [0, "
             << (gridShape[gridAxes[i]] - 1) << "].";
    }
  }
  return success();
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  return std::accumulate(begin, end, static_cast<ElementType>(1),
                         std::multiplies<ElementType>());
}

template <typename R>
static auto product(R &&range) {
  return product(adl_begin(range), adl_end(range));
}
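// Used below to compute the number of devices in a collective group:
// collectiveProcessGroupSize multiplies the grid dimensions selected by an
// op's grid axes.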
static LogicalResult verifyDimensionCompatibility(Location loc,
                                                  int64_t expectedDimSize,
                                                  int64_t resultDimSize,
                                                  int64_t resultAxis) {
  if (ShapedType::isStatic(resultDimSize) && expectedDimSize != resultDimSize) {
    return emitError(loc) << "Dimension size mismatch for result axis "
                          << resultAxis << ". Expected "
                          << (ShapedType::isDynamic(expectedDimSize)
                                  ? Twine("dynamic")
                                  : Twine(expectedDimSize))
                          << ", but got " << resultDimSize << ".";
  }
  return success();
}
static LogicalResult verifyGatherOperandAndResultShape(
    Value operand, Value result, int64_t gatherAxis,
    ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
  auto resultRank = cast<ShapedType>(result.getType()).getRank();
  if (gatherAxis < 0 || gatherAxis >= resultRank) {
    return emitError(result.getLoc())
           << "Gather axis " << gatherAxis << " is out of bounds [0, "
           << resultRank << ").";
  }

  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
    auto expectedResultDimSize =
        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
    if (failed(verifyDimensionCompatibility(
            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
      return failure();
    }
  }
  return success();
}
static LogicalResult verifyAllToAllOperandAndResultShape(
    Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
    ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  if (splitAxis == concatAxis) {
    return success();
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
  DimensionSize expectedResultConcatDimSize =
      operandConcatDimSize * deviceGroupSize;
  DimensionSize expectedResultSplitDimSize =
      operandSplitDimSize / deviceGroupSize;
  if (!expectedResultSplitDimSize.isDynamic() &&
      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
    expectedResultSplitDimSize = DimensionSize::dynamic();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultConcatDimSize.value(),
          resultType.getDimSize(concatAxis), concatAxis))) {
    return failure();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultSplitDimSize.value(),
          resultType.getDimSize(splitAxis), splitAxis))) {
    return failure();
  }
  return success();
}
static LogicalResult verifyScatterOrSliceOperandAndResultShape(
    Value operand, Value result, int64_t tensorAxis,
    ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if (axis != tensorAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
  auto operandScatterDimSize =
      DimensionSize(operandType.getDimSize(tensorAxis));
  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
      int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
    return emitError(result.getLoc())
           << "Operand dimension size " << int64_t(operandScatterDimSize)
           << " is not divisible by collective device group size "
           << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
           << ".";
  }
  DimensionSize expectedResultTensorDimSize =
      operandScatterDimSize / deviceGroupSize;
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultTensorDimSize.value(),
          resultType.getDimSize(tensorAxis), tensorAxis))) {
    return failure();
  }
  return success();
}
static RankedTensorType sliceResultType(Type operandType, GridOp grid,
                                        ArrayRef<GridAxis> gridAxes,
                                        int64_t sliceAxis) {
  RankedTensorType operandRankedTensorType =
      cast<RankedTensorType>(operandType);
  DimensionSize operandSliceAxisSize =
      operandRankedTensorType.getShape()[sliceAxis];
  SmallVector<int64_t> resultShape =
      llvm::to_vector(operandRankedTensorType.getShape());

  resultShape[sliceAxis] =
      operandSliceAxisSize /
      DimensionSize(collectiveProcessGroupSize(gridAxes, grid.getShape()));
  return operandRankedTensorType.clone(resultShape);
}
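// Example: slicing a tensor<8x3xf32> on slice axis 0 across grid axes with a
// total of 4 devices yields tensor<2x3xf32> per device.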
  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
                                           gatherAxis, getGridAxes(),
                                           grid.value().getShape());

  patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);

void AllGatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_gather");
}
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);

void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        Value input, StringRef grid,
                        ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
  build(odsBuilder, odsState, input.getType(), grid, gridAxes, input,
        reduction);
}

void AllReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_reduce");
}
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
      grid.value().getShape());

  patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
                       int64_t sliceAxis) {
  Type resultType = sliceResultType(input.getType(), grid, gridAxes, sliceAxis);
  build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
        sliceAxis);
}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Type resultType, Value input, StringRef grid,
                       ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
  build(odsBuilder, odsState, resultType, grid, gridAxes, input,
        APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}

void AllSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_slice");
}
  return verifyAllToAllOperandAndResultShape(
      getOperand(), getResult(), getSplitAxis().getSExtValue(),
      getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());

  patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);

void AllToAllOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_to_all");
}
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getGridAxes(),
                                 grid.value().getShape()))) {
    return failure();
  }

  patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);

void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "broadcast");
}
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getGridAxes(),
                                 grid.value().getShape()))) {
    return failure();
  }
  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
                                           getGridAxes(),
                                           grid.value().getShape());

  patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
  if (getSource() &&
      failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
                                 getSource().value(), getSourceDynamic(),
                                 getGridAxes(), grid.value().getShape()))) {
    return failure();
  }

  patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);

void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "recv");
}
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getGridAxes(),
                                 grid.value().getShape()))) {
    return failure();
  }

  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce");
}
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
      grid.value().getShape());

  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);

void ReduceScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce_scatter");
}
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getGridAxes(),
                                 grid.value().getShape()))) {
    return failure();
  }
  auto scatterAxis = getScatterAxis().getSExtValue();
  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
                                                   scatterAxis, getGridAxes(),
                                                   grid.value().getShape());

  patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);

void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
                                 getDestination(), getDestinationDynamic(),
                                 getGridAxes(), grid.value().getShape()))) {
    return failure();
  }

  patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);

void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "send");
}
  auto gridAxes = getGridAxes();
  auto shiftAxis = getShiftAxis().getZExtValue();
  if (!llvm::is_contained(gridAxes, shiftAxis)) {
    return emitError() << "Invalid shift axis " << shiftAxis
                       << ". It must be one of the grouping grid axes.";
  }

void ShiftOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "shift");
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"

#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"