28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
40#define DEBUG_TYPE "shard-ops"
45#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
50 static DimensionSize dynamic() {
return DimensionSize(ShapedType::kDynamic); }
51 DimensionSize(
int64_t val) : val(val) {}
52 int64_t value()
const {
return val; }
53 operator int64_t()
const {
return val; }
54 bool isDynamic()
const {
return ShapedType::isDynamic(val); }
63 if (
lhs.isDynamic() ||
rhs.isDynamic()) {
64 return DimensionSize::dynamic();
66 return lhs.value() /
rhs.value();
70 if (
lhs.isDynamic() ||
rhs.isDynamic()) {
71 return DimensionSize::dynamic();
73 return lhs.value() *
rhs.value();
81 auto dyn = dynamics.begin();
82 Type i64 =
b.getI64Type();
85 assert((i64 == type ||
b.getIndexType() == type) &&
86 "expected an i64 or an intex type");
87 for (
auto s : statics) {
88 if (s == ShapedType::kDynamic) {
89 values.emplace_back(*(dyn++));
91 TypedAttr val = type == i64 ?
b.getI64IntegerAttr(s) :
b.getIndexAttr(s);
92 values.emplace_back(arith::ConstantOp::create(
b, loc, type, val));
103struct ShardInlinerinterface :
public DialectInlinerInterface {
104 using DialectInlinerInterface::DialectInlinerInterface;
112 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
122void ShardDialect::initialize() {
125#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
128#define GET_ATTRDEF_LIST
129#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
132#define GET_TYPEDEF_LIST
133#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
135 addInterface<ShardInlinerinterface>();
141 return arith::ConstantOp::materialize(builder, value, type, loc);
151 shard::GridOp grid =
getGridOrNull(op, gridSymbol, symbolTable);
153 return op->
emitError() <<
"Undefined required grid symbol \""
160template <
typename It>
165 It next = std::next(begin);
169 for (; next != end; ++next, ++begin) {
170 if (*begin == *next) {
181 if (!
isUnique(sorted.begin(), sorted.end())) {
182 return emitError(loc) <<
"Grid axes contains duplicate elements.";
186 for (
auto axis : axes) {
187 if (axis >= rank || axis < 0) {
189 <<
"0-based grid axis index " << axis
190 <<
" is out of bounds. The referenced grid \"" << grid.getSymName()
191 <<
"\" is of rank " << rank <<
".";
198template <
typename Op>
199static FailureOr<GridOp>
212template <
typename InShape,
typename GridShape,
typename SplitAxes,
214static void shardShape(
const InShape &inShape,
const GridShape &gridShape,
215 const SplitAxes &splitAxes, OutShape &outShape,
219 if (inShape.empty()) {
220 assert(outShape.empty());
224 std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
225 llvm::adl_begin(outShape));
227 if (!shardedDimsOffsets.empty()) {
228 auto isDynShape = ShapedType::isDynamicShape(gridShape);
230 for (
auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
231 if (!innerSplitAxes.empty()) {
232 auto sz = shardedDimsOffsets[pos];
233 bool same = !isDynShape;
238 uint64_t numShards = 0;
239 for (
auto i : innerSplitAxes.asArrayRef()) {
240 numShards += gridShape[i];
242 for (
size_t i = 1; i < numShards; ++i) {
243 if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
249 pos += numShards + 1;
251 outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
255 for (
auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
261 if (!haloSizes.empty()) {
264 for (
auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
265 if (ShapedType::isStatic(outShape[tensorAxis]) &&
266 !innerSplitAxes.empty()) {
267 if (haloSizes[haloAxis * 2] >= 0 &&
268 haloSizes[haloAxis * 2 + 1] >= 0) {
269 outShape[tensorAxis] +=
270 haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
273 outShape[tensorAxis] = ShapedType::kDynamic;
283 using Dim = std::decay_t<
decltype(
shape.getDimSize(0))>;
288 return shape.clone(resShapeArr);
292 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
293 if (rankedTensorType && !rankedTensorType.getShape().empty()) {
303 ShardOp &newShardOp) {
306 ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
307 if (shardOp && sharding == shardOp.getSharding() &&
308 !shardOp.getAnnotateForUsers()) {
311 newShardOp = shardOp;
318 ShardingOp::create(builder, operandValue.
getLoc(), sharding);
319 newShardOp = ShardOp::create(builder, operandValue.
getLoc(), operandValue,
324 newShardOp, [operandOp, operandValue](
OpOperand &use) {
325 return use.
getOwner() == operandOp && use.
get() == operandValue;
328 if (!shardOp || shardOp.getAnnotateForUsers()) {
332 auto newShardOp2 = ShardOp::create(builder, operandValue.
getLoc(), newShardOp,
333 newShardOp.getSharding(),
335 newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
343 for (
auto &use :
result.getUses()) {
344 uses.emplace_back(use.get(), use.getOwner());
346 for (
auto &[operandValue, operandOp] : uses) {
348 builder, newShardOp);
358 bool isBlockArg = !operandSrcOp;
360 [[maybe_unused]]
auto opType =
361 dyn_cast<mlir::RankedTensorType>(operandValue.
getType());
364 if (!isa<RankedTensorType>(operandValue.
getType()) && operandSrcOp &&
370 ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
372 if (shardOp && sharding == shardOp.getSharding() &&
373 shardOp.getAnnotateForUsers()) {
380 ShardingOp::create(builder, operand.
get().
getLoc(), sharding);
382 ShardOp::create(builder, operandValue.
getLoc(), operandValue, shardingOp,
386 operandValue, newShardOp, [operandOp, operandValue](
OpOperand &use) {
387 return use.
getOwner() == operandOp && use.
get() == operandValue;
390 if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
396 auto newPreceedingShardOp =
397 ShardOp::create(builder, operandValue.
getLoc(), operandValue, shardingOp,
400 newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](
OpOperand &use) {
401 return use.getOwner() == newShardOp.getOperation();
409LogicalResult GridOp::verify() {
413 return emitOpError(
"rank of grid is expected to be a positive integer");
416 if (dimSize < 0 && ShapedType::isStatic(dimSize))
417 return emitOpError(
"dimension size of a grid is expected to be "
418 "non-negative or dynamic");
438 size_t expectedResultsCount =
439 getAxes().empty() ? grid->getRank() : getAxes().size();
440 if (getResult().size() != expectedResultsCount) {
441 return emitError() <<
"Unexpected number of results " << getResult().size()
442 <<
". Expected " << expectedResultsCount <<
".";
455 build(odsBuilder, odsState,
463 assert(!axes.empty());
464 build(odsBuilder, odsState,
469void GridShapeOp::getAsmResultNames(
471 setNameFn(getResults()[0],
"grid_shape");
483 b, odsState, grid, GridAxesArrayAttr::get(
b.getContext(), splitAxes),
493 GridAxesArrayAttr::get(
b.getContext(), splitAxes),
499void ShardingOp::build(
509 b, odsState, grid, GridAxesArrayAttr::get(
b.getContext(), splitAxes),
517 build(
b, odsState, ShardingType::get(
b.getContext()), from.
getGridAttr(),
529LogicalResult ShardingOp::verify() {
530 llvm::SmallSet<GridAxis, 4> visitedAxes;
535 return emitError() <<
"grid axis is expected to be non-negative";
536 if (!visitedAxes.insert(axis).second)
537 return emitError() <<
"grid axis duplicated";
542 for (
auto subAxes : getSplitAxes().getAxes()) {
544 if (
failed(checkGridAxis(subAxesArray)))
548 if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
549 return emitOpError(
"halo sizes and shard offsets are mutually exclusive");
552 if (!getStaticHaloSizes().empty()) {
553 auto numSplitAxes = getSplitAxes().getAxes().size();
554 for (
auto splitAxis : getSplitAxes().getAxes()) {
555 if (splitAxis.empty()) {
559 if (getStaticHaloSizes().size() != numSplitAxes * 2) {
560 return emitError() <<
"halo sizes must be specified for all split axes.";
567void ShardingOp::getAsmResultNames(
569 setNameFn(getResult(),
"sharding");
577 if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
578 !getStaticShardedDimsOffsets().empty()) {
579 return emitError() <<
"sharded dims offsets are not allowed for "
580 "device grids with dynamic shape.";
583 auto shardedDimsOffsets = getStaticShardedDimsOffsets();
584 if (!shardedDimsOffsets.empty()) {
585 auto gridShape = grid.value().getShape();
586 assert(ShapedType::isStaticShape(gridShape));
588 for (
auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
589 if (!innerSplitAxes.empty()) {
590 int64_t numShards = 0, off = 0;
591 for (
auto i : innerSplitAxes.asArrayRef()) {
592 numShards += gridShape[i];
594 for (
int64_t i = 0; i <= numShards; ++i) {
595 if (shardedDimsOffsets.size() <= pos + i) {
596 return emitError() <<
"sharded dims offsets has wrong size.";
598 if (ShapedType::isStatic(shardedDimsOffsets[pos + i])) {
599 if (shardedDimsOffsets[pos + i] < off) {
601 <<
"sharded dims offsets must be non-decreasing.";
603 off = shardedDimsOffsets[pos + i];
606 pos += numShards + 1;
621 using OpRewritePattern<ShardingOp>::OpRewritePattern;
623 LogicalResult matchAndRewrite(ShardingOp op,
624 PatternRewriter &
b)
const override {
628 op.getDynamicShardedDimsOffsets(),
b);
637 if (dynamicHalos.empty() && !staticHalos.empty()) {
638 if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
649 if (dynamicOffs.empty() && !staticOffs.empty()) {
650 assert(staticOffs.size() >= 2);
651 auto diff = staticOffs[1] - staticOffs[0];
652 bool allSame = staticOffs.size() > 2;
653 for (
auto i = 2u; i < staticOffs.size(); ++i) {
654 if (staticOffs[i] - staticOffs[i - 1] != diff) {
669 op.setStaticHaloSizes(staticHalos);
670 op.getDynamicHaloSizesMutable().assign(dynamicHalos);
671 op.setStaticShardedDimsOffsets(staticOffs);
672 op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
681 results.
add<NormalizeSharding>(context);
693 auto minSize = std::min(
getSplitAxes().size(),
rhs.getSplitAxes().size());
694 if (!llvm::equal(llvm::make_range(
getSplitAxes().begin(),
696 llvm::make_range(
rhs.getSplitAxes().begin(),
697 rhs.getSplitAxes().begin() + minSize))) {
701 return llvm::all_of(llvm::drop_begin(
getSplitAxes(), minSize),
702 std::mem_fn(&GridAxesAttr::empty)) &&
703 llvm::all_of(llvm::drop_begin(
rhs.getSplitAxes(), minSize),
704 std::mem_fn(&GridAxesAttr::empty));
712 if (
rhs.getStaticShardedDimsOffsets().size() !=
715 rhs.getStaticShardedDimsOffsets())) {
718 if (
rhs.getDynamicShardedDimsOffsets().size() !=
721 rhs.getDynamicShardedDimsOffsets())) {
753 os <<
"Sharding<grid=" << sharding.
getGrid() <<
", split_axes=[";
756 llvm::interleaveComma(axes.asArrayRef(), os);
761 os <<
", halo_sizes=[";
766 os <<
", sharded_dims_offsets=[";
777 auto shardingOp =
rhs.getDefiningOp<ShardingOp>();
778 assert(shardingOp &&
"expected sharding op");
779 auto splitAxes = shardingOp.getSplitAxes().getAxes();
781 if (splitAxes.empty()) {
782 *
this =
Sharding(shardingOp.getGridAttr());
786 get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
787 shardingOp.getStaticShardedDimsOffsets(),
799 if (splitAxes.empty()) {
803 res.split_axes.resize(splitAxes.size());
804 for (
auto [i, axis] : llvm::enumerate(splitAxes)) {
808 auto clone = [](
const auto src,
auto &dst) {
809 dst.resize(src.size());
810 llvm::copy(src, dst.begin());
813 clone(staticHaloSizes, res.static_halo_sizes);
814 clone(staticShardedDimsOffsets, res.static_sharded_dims_offsets);
815 clone(dynamicHaloSizes, res.dynamic_halo_sizes);
816 clone(dynamicShardedDimsOffsets, res.dynamic_sharded_dims_offsets);
825void ShardShapeOp::getAsmResultNames(
827 setNameFn(getResult()[0],
"shard_shape");
836 build(odsBuilder, odsState, resType, dims, dimsDyn, sharding,
844void ShardOp::getAsmResultNames(
846 setNameFn(getResult(),
"sharding_annotated");
852class FoldDuplicateShardOp final :
public OpRewritePattern<ShardOp> {
854 using OpRewritePattern<ShardOp>::OpRewritePattern;
856 LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &
b)
const override {
859 Value value = op.getSrc();
865 for (
auto &use : value.
getUses()) {
866 if (use.getOwner() != op.getOperation()) {
867 auto otherOp = dyn_cast<ShardOp>(use.getOwner());
868 if (!otherOp || !otherOp->isBeforeInBlock(op)) {
873 Sharding currentSharding(op.getSharding());
874 Sharding otherSharding(otherOp.getSharding());
875 if (currentSharding == otherSharding) {
876 b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
877 b.eraseOp(op.getOperation());
880 op.getSrcMutable().assign(otherOp.getResult());
891void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
892 mlir::MLIRContext *context) {
893 results.
add<FoldDuplicateShardOp>(context);
901ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
910 size_t expectedResultsCount =
911 getAxes().empty() ? grid->getRank() : getAxes().size();
912 if (getResult().size() != expectedResultsCount) {
913 return emitError() <<
"Unexpected number of results " << getResult().size()
914 <<
". Expected " << expectedResultsCount <<
".";
920void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
922 build(odsBuilder, odsState,
923 SmallVector<Type>(grid.getRank(), odsBuilder.
getIndexType()),
924 grid.getSymName(), ArrayRef<GridAxis>());
927void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
928 StringRef grid, ArrayRef<GridAxis> axes) {
929 build(odsBuilder, odsState,
930 SmallVector<Type>(axes.size(), odsBuilder.
getIndexType()), grid,
934void ProcessMultiIndexOp::getAsmResultNames(
936 setNameFn(getResults()[0],
"proc_linear_idx");
944ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
952void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
953 OperationState &odsState, GridOp grid) {
954 build(odsBuilder, odsState, grid.getSymName());
957void ProcessLinearIndexOp::getAsmResultNames(
959 setNameFn(getResult(),
"proc_linear_idx");
967NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
975void NeighborsLinearIndicesOp::getAsmResultNames(
977 setNameFn(getNeighborDown(),
"down_linear_idx");
978 setNameFn(getNeighborUp(),
"up_linear_idx");
987template <
typename Op>
988struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
989 using OpRewritePattern<
Op>::OpRewritePattern;
990 LogicalResult matchAndRewrite(Op op,
991 PatternRewriter &rewriter)
const override {
992 auto gridAxes = op.getGridAxes();
993 if (!gridAxes.empty()) {
996 if (op.getInput().getType() != op.getResult().getType()) {
1013 if (device.size() != gridAxes.size()) {
1014 return emitError(loc) <<
"In-group device \"" << deviceName
1015 <<
"\" has unexpected multi-index size "
1016 << device.size() <<
". Expected " << gridAxes.size()
1020 for (
size_t i = 0; i < device.size(); ++i) {
1021 if (ShapedType::isStatic(device[i]) &&
1022 ShapedType::isStatic(gridShape[gridAxes[i]]) &&
1023 gridShape[gridAxes[i]] <= device[i]) {
1025 <<
"Out of bounds coordinate " << i <<
" for in-group device \""
1026 << deviceName <<
"\"."
1027 <<
" Got " << device[i] <<
", but expected value in the range [0, "
1028 << (gridShape[gridAxes[i]] - 1) <<
"].";
1038 if (ShapedType::isStatic(resultDimSize) && expectedDimSize != resultDimSize) {
1039 return emitError(loc) <<
"Dimension size mismatch for result axis "
1040 << resultAxis <<
". Expected "
1041 << (ShapedType::isDynamic(expectedDimSize)
1043 : Twine(expectedDimSize))
1044 <<
", but got " << resultDimSize <<
".";
1053 auto resultRank = cast<ShapedType>(
result.getType()).getRank();
1054 if (gatherAxis < 0 || gatherAxis >= resultRank) {
1056 <<
"Gather axis " << gatherAxis <<
" is out of bounds [0, "
1057 << resultRank <<
").";
1060 ShapedType operandType = cast<ShapedType>(operand.
getType());
1061 ShapedType resultType = cast<ShapedType>(
result.getType());
1062 auto deviceGroupSize =
1064 for (
int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1065 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
1066 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
1067 auto expectedResultDimSize =
1068 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
1070 result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
1080 ShapedType operandType = cast<ShapedType>(operand.
getType());
1081 ShapedType resultType = cast<ShapedType>(
result.getType());
1082 for (
int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1083 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
1085 result.getLoc(), operandType.getDimSize(axis),
1086 resultType.getDimSize(axis), axis))) {
1092 if (splitAxis == concatAxis) {
1096 auto deviceGroupSize =
1098 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
1099 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
1100 DimensionSize expectedResultConcatDimSize =
1101 operandConcatDimSize * deviceGroupSize;
1102 DimensionSize expectedResultSplitDimSize =
1103 operandSplitDimSize / deviceGroupSize;
1104 if (!expectedResultSplitDimSize.isDynamic() &&
1106 expectedResultSplitDimSize = DimensionSize::dynamic();
1109 result.getLoc(), expectedResultConcatDimSize.value(),
1110 resultType.getDimSize(concatAxis), concatAxis))) {
1114 result.getLoc(), expectedResultSplitDimSize.value(),
1115 resultType.getDimSize(splitAxis), splitAxis))) {
1125 ShapedType operandType = cast<ShapedType>(operand.
getType());
1126 ShapedType resultType = cast<ShapedType>(
result.getType());
1127 for (
int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1128 if (axis != tensorAxis) {
1130 result.getLoc(), operandType.getDimSize(axis),
1131 resultType.getDimSize(axis), axis))) {
1137 auto deviceGroupSize =
1139 auto operandScatterDimSize =
1140 DimensionSize(operandType.getDimSize(tensorAxis));
1141 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1142 int64_t(operandScatterDimSize) %
int64_t(deviceGroupSize) != 0) {
1144 <<
"Operand dimension size " <<
int64_t(operandScatterDimSize)
1145 <<
" is not divisible by collective device group size "
1146 <<
int64_t(deviceGroupSize) <<
" for tensor axis " << tensorAxis
1149 DimensionSize expectedResultTensorDimSize =
1150 operandScatterDimSize / deviceGroupSize;
1152 result.getLoc(), expectedResultTensorDimSize.value(),
1153 resultType.getDimSize(tensorAxis), tensorAxis))) {
1163 RankedTensorType operandRankedTensorType =
1164 cast<RankedTensorType>(operandType);
1165 DimensionSize operandSliceAxisSize =
1166 operandRankedTensorType.getShape()[sliceAxis];
1168 llvm::to_vector(operandRankedTensorType.getShape());
1170 resultShape[sliceAxis] =
1171 operandSliceAxisSize /
1173 return operandRankedTensorType.clone(resultShape);
1181AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1186 auto gatherAxis = getGatherAxis().getSExtValue();
1188 gatherAxis, getGridAxes(),
1189 grid.value().getShape());
1192void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1193 MLIRContext *context) {
1194 patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
1197void AllGatherOp::getAsmResultNames(
1199 setNameFn(getResult(),
"all_gather");
1207AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1211void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1212 MLIRContext *context) {
1213 patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
1216void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1217 Value input, StringRef grid,
1218 ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
1219 build(odsBuilder, odsState, input.
getType(), grid, gridAxes, input,
1223void AllReduceOp::getAsmResultNames(
1225 setNameFn(getResult(),
"all_reduce");
1232LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1238 getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
1239 grid.value().getShape());
1242void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1243 MLIRContext *context) {
1244 patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
1247void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1248 Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
1249 int64_t sliceAxis) {
1251 build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
1255void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1256 Type resultType, Value input, StringRef grid,
1257 ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
1258 build(odsBuilder, odsState, resultType, grid, gridAxes, input,
1259 APInt(
sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1262void AllSliceOp::getAsmResultNames(
1264 setNameFn(getResult(),
"all_slice");
1271LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1278 getOperand(), getResult(), getSplitAxis().getSExtValue(),
1279 getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
1282void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1283 MLIRContext *context) {
1284 patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
1287void AllToAllOp::getAsmResultNames(
1289 setNameFn(getResult(),
"all_to_all");
1297BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1303 getRootDynamic(), getGridAxes(),
1304 grid.value().getShape()))) {
1311void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1312 MLIRContext *context) {
1313 patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
1316void BroadcastOp::getAsmResultNames(
1318 setNameFn(getResult(),
"broadcast");
1325LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1331 getRootDynamic(), getGridAxes(),
1332 grid.value().getShape()))) {
1336 auto gatherAxis = getGatherAxis().getSExtValue();
1339 grid.value().getShape());
1342void GatherOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1343 MLIRContext *context) {
1344 patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
1347void GatherOp::getAsmResultNames(
1349 setNameFn(getResult(),
"gather");
1356LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1363 getSource().value(), getSourceDynamic(),
1364 getGridAxes(), grid.value().getShape()))) {
1370void RecvOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1371 MLIRContext *context) {
1372 patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
1375void RecvOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1376 setNameFn(getResult(),
"recv");
1383LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1389 getRootDynamic(), getGridAxes(),
1390 grid.value().getShape()))) {
1397void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1398 MLIRContext *context) {
1399 patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
1402void ReduceOp::getAsmResultNames(
1404 setNameFn(getResult(),
"reduce");
1412ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1419 getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
1420 grid.value().getShape());
1423void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1424 MLIRContext *context) {
1425 patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1428void ReduceScatterOp::getAsmResultNames(
1430 setNameFn(getResult(),
"reduce_scatter");
1437LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1443 getRootDynamic(), getGridAxes(),
1444 grid.value().getShape()))) {
1448 auto scatterAxis = getScatterAxis().getSExtValue();
1450 scatterAxis, getGridAxes(),
1451 grid.value().getShape());
1454void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1455 MLIRContext *context) {
1456 patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
1459void ScatterOp::getAsmResultNames(
1461 setNameFn(getResult(),
"scatter");
1468LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1474 getDestination(), getDestinationDynamic(),
1475 getGridAxes(), grid.value().getShape()))) {
1481void SendOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1482 MLIRContext *context) {
1483 patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);
1486void SendOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1487 setNameFn(getResult(),
"send");
1494LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1500 auto gridAxes = getGridAxes();
1501 auto shiftAxis = getShiftAxis().getZExtValue();
1502 if (!llvm::is_contained(gridAxes, shiftAxis)) {
1503 return emitError() <<
"Invalid shift axis " << shiftAxis
1504 <<
". It must be one of the grouping grid axes.";
1510void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1511 MLIRContext *context) {
1516void ShiftOp::getAsmResultNames(
1518 setNameFn(getResult(),
"shift");
1526UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1539#define GET_OP_CLASSES
1540#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
1542#define GET_ATTRDEF_CLASSES
1543#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
1545#define GET_TYPEDEF_CLASSES
1546#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
1548#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static FailureOr< GridOp > getGridAndVerify(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable)
static LogicalResult verifyGridAxes(Location loc, ArrayRef< GridAxis > axes, GridOp grid)
static FailureOr< GridOp > getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
static bool isUnique(It begin, It end)
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Location getLoc()
The source location the operation was defined or derived from.
This class provides the API for a sub-set of ops that are known to be constant-like.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
Operation * getOwner() const
Return the owner of this operand.
ArrayRef< Value > getDynamicShardedDimsOffsets() const
bool operator!=(Value rhs) const
bool equalShardSizes(const Sharding &rhs) const
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalSplitAxes(const Sharding &rhs) const
ArrayRef< int64_t > getStaticHaloSizes() const
::mlir::FlatSymbolRefAttr getGridAttr() const
::llvm::StringRef getGrid() const
bool equalHaloAndShardSizes(const Sharding &rhs) const
bool operator==(Value rhs) const
ArrayRef< Value > getDynamicHaloSizes() const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
ArrayRef< GridAxesAttr > getSplitAxes() const
bool equalHaloSizes(const Sharding &rhs) const
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const Sharding &sharding)
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
DenseI16ArrayAttr GridAxesAttr
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
bool isFullReplication(Sharding sharding)
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Type shardType(Type type, GridOp grid, Sharding sharding)
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type.
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
AffineExpr operator*(int64_t val, AffineExpr expr)
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
llvm::function_ref< Fn > function_ref
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.