#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "mesh-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
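// Wrapper around a single tensor/mesh dimension size that makes the
// ShapedType::kDynamic sentinel explicit. The arithmetic operators defined
// below propagate dynamism: if either operand is dynamic, so is the result.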
struct DimensionSize {
  static DimensionSize dynamic() {
    return DimensionSize(ShapedType::kDynamic);
  }
  DimensionSize(int64_t val) : val(val) {}
  int64_t value() const { return val; }
  operator int64_t() const { return val; }
  bool isDynamic() const { return ShapedType::isDynamic(val); }

private:
  int64_t val;
};
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() / rhs.value();
}
static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() * rhs.value();
}
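// For example, DimensionSize(12) / DimensionSize(4) yields DimensionSize(3),
// while DimensionSize::dynamic() / DimensionSize(4) stays dynamic. This
// mirrors how sharding a dynamically sized dimension yields a dynamically
// sized shard.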
SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
                                                const Location &loc,
                                                llvm::ArrayRef<int64_t> statics,
                                                ValueRange dynamics, Type type) {
  SmallVector<Value> values;
  auto dyn = dynamics.begin();
  Type i64 = b.getI64Type();
  if (!type)
    type = i64;
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an index type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      values.emplace_back(*(dyn++));
    } else {
      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
      values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
    }
  }
  return values;
}
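// Illustrative use (hypothetical values): statics = {2, ShapedType::kDynamic}
// together with one dynamic Value yields {arith.constant 2, <dynamic Value>},
// i.e. dynamic entries are consumed from `dynamics` in order.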
void MeshDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
      >();
  addInterface<MeshInlinerInterface>();
}
Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
    return op->emitError() << "Undefined required mesh symbol \""
                           << meshSymbol.getValue() << "\".";
template <typename It>
bool isUnique(It begin, It end) {
  if (begin == end)
    return true;
  It next = std::next(begin);
  for (; next != end; ++next, ++begin) {
    if (*begin == *next)
      return false;
  }
  return true;
}
static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
                                    MeshOp mesh) {
  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
  llvm::sort(sorted);
  if (!isUnique(sorted.begin(), sorted.end())) {
    return emitError(loc) << "Mesh axes contains duplicate elements.";
  }

  MeshAxis rank = mesh.getRank();
  for (auto axis : axes) {
    if (axis >= rank || axis < 0) {
      return emitError(loc)
             << "0-based mesh axis index " << axis
             << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
             << "\" is of rank " << rank << ".";
    }
  }

  return success();
}
template <typename Op>
static FailureOr<MeshOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
    return failure();
  }
  return mesh;
}
template <typename InShape, typename MeshShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                       const SplitAxes &splitAxes, OutShape &outShape,
                       ArrayRef<int64_t> shardedDimsOffsets = {},
                       ArrayRef<int64_t> haloSizes = {}) {
  // 0d tensors cannot be sharded and must get replicated.
  if (inShape.empty()) {
    assert(outShape.empty());
    return;
  }

  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));
  if (!shardedDimsOffsets.empty()) {
    auto isDynShape = ShapedType::isDynamicShape(meshShape);
    uint64_t pos = 1;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (!innerSplitAxes.empty()) {
        auto sz = shardedDimsOffsets[pos];
        bool same = !isDynShape;
        if (same) {
          // Find sharded dims in shardedDimsOffsets with the same static size
          // on all devices. Use kDynamic for dimensions with dynamic or
          // non-uniform offsets in shardedDimsOffsets.
          uint64_t numShards = 0;
          for (auto i : innerSplitAxes.asArrayRef()) {
            numShards += meshShape[i];
          }
          for (size_t i = 1; i < numShards; ++i) {
            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
                sz) {
              same = false;
              break;
            }
          }
          pos += numShards + 1;
        }
        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
      }
    }
  } else {
    // Shard each dimension by the size of its collective device group.
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      outShape[tensorAxis] = shardDimension(
          outShape[tensorAxis],
          collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
    }

    if (!haloSizes.empty()) {
      // Add halo sizes if requested.
      int haloAxis = 0;
      for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
        if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
            !innerSplitAxes.empty()) {
          if (haloSizes[haloAxis * 2] >= 0 &&
              haloSizes[haloAxis * 2 + 1] >= 0) {
            outShape[tensorAxis] +=
                haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
            ++haloAxis;
          } else {
            // Dynamic halo sizes make the resulting dimension dynamic.
            outShape[tensorAxis] = ShapedType::kDynamic;
          }
        }
      }
    }
  }
}
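// Worked example (hypothetical values): inShape = {12, 8} on a mesh of shape
// {4, 2} with splitAxes = [[0], [1]] gives outShape = {3, 4}; adding
// haloSizes = {1, 1, 0, 0} (lower/upper pair per split axis) grows the first
// sharded dimension to 3 + 1 + 1 = 5.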
ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
                                 MeshSharding sharding) {
  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
  SmallVector<Dim> resShapeArr(shape.getShape().size());
  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
             resShapeArr, sharding.getStaticShardedDimsOffsets(),
             sharding.getStaticHaloSizes());
  return shape.clone(resShapeArr);
}
Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
    return shardShapedType(rankedTensorType, mesh, sharding);
  }
  return type;
}
static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
                                                    Value &operandValue,
                                                    Operation *operandOp,
                                                    OpBuilder &builder,
                                                    ShardOp &newShardOp) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointAfterValue(operandValue);
  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      !shardOp.getAnnotateForUsers()) {
    // No need to do anything if the correct sharding is already set.
    if (!newShardOp) {
      newShardOp = shardOp;
    }
    return;
  }

  if (!newShardOp) {
    auto shardingOp =
        builder.create<ShardingOp>(operandValue.getLoc(), sharding);
    newShardOp =
        builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                                /*annotate_for_users*/ false);
  }
  operandValue.replaceUsesWithIf(
      newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (!shardOp || shardOp.getAnnotateForUsers()) {
    return;
  }

  auto newShardOp2 =
      builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
                              newShardOp.getSharding(),
                              /*annotate_for_users*/ true);
  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
}
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                     OpResult result,
                                                     OpBuilder &builder) {
  ShardOp newShardOp;
  SmallVector<std::pair<Value, Operation *>> uses;
  for (auto &use : result.getUses()) {
    uses.emplace_back(use.get(), use.getOwner());
  }
  for (auto &[operandValue, operandOp] : uses) {
    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
                                            builder, newShardOp);
  }
}
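// Illustrative effect (hypothetical IR, types elided): for a producer %0 used
// by some op, this inserts
//   %s = mesh.sharding ...   (built from `sharding`)
//   %1 = mesh.shard %0 to %s
// and redirects the producer-side uses of %0 to %1.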
void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                                     OpOperand &operand,
                                                     OpBuilder &builder) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Value operandValue = operand.get();
  Operation *operandSrcOp = operandValue.getDefiningOp();
  bool isBlockArg = !operandSrcOp;
  {
    [[maybe_unused]] auto opType =
        dyn_cast<mlir::RankedTensorType>(operandValue.getType());
    assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
  }
  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
      operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
    return;
  }

  Operation *operandOp = operand.getOwner();
  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);

  if (shardOp && sharding == shardOp.getSharding() &&
      shardOp.getAnnotateForUsers()) {
    // No need to do anything if the correct sharding is already set.
    return;
  }

  builder.setInsertionPoint(operandOp);
  auto shardingOp =
      builder.create<ShardingOp>(operandValue.getLoc(), sharding);
  auto newShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users*/ true);
  IRRewriter rewriter(builder);
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
    // No need for resharding.
    return;
  }

  builder.setInsertionPoint(newShardOp);
  auto newPreceedingShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users*/ false);
  rewriter.replaceUsesWithIf(
      newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
        return use.getOwner() == newShardOp.getOperation();
      });
}
LogicalResult MeshOp::verify() {
  int64_t rank = getRank();

  if (rank <= 0)
    return emitOpError("rank of mesh is expected to be a positive integer");

  for (int64_t dimSize : getShape()) {
    if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
      return emitOpError("dimension size of a mesh is expected to be "
                         "non-negative or dynamic");
  }

  return success();
}
  size_t expectedResultsCount =
      getAxes().empty() ? mesh->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }
  build(odsBuilder, odsState,
        SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
                          odsBuilder.getIndexType()),
        mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
}

void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        StringRef mesh, ArrayRef<MeshAxis> axes) {
  assert(!axes.empty());
  build(odsBuilder, odsState,
        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
        MeshAxesAttr::get(odsBuilder.getContext(), axes));
}

void MeshShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "mesh_shape");
}
void ShardingOp::build(
  llvm::SmallSet<MeshAxis, 4> visitedAxes;

  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
    for (MeshAxis axis : axesArray) {
      if (axis < 0)
        return emitError() << "mesh axis is expected to be non-negative";
      if (!visitedAxes.insert(axis).second)
        return emitError() << "mesh axis duplicated";
    }
    return success();
  };

  for (auto subAxes : getSplitAxes().getAxes()) {
    ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
    if (failed(checkMeshAxis(subAxesArray)))
      return failure();
  }
  if (getPartialAxes().has_value() &&
      failed(checkMeshAxis(getPartialAxes().value())))
    return failure();

  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
    return emitOpError("halo sizes and shard offsets are mutually exclusive");
  }

  if (!getStaticHaloSizes().empty()) {
    auto numSplitAxes = getSplitAxes().getAxes().size();
    for (auto splitAxis : getSplitAxes().getAxes()) {
      if (splitAxis.empty()) {
        --numSplitAxes;
      }
    }
    if (getStaticHaloSizes().size() != numSplitAxes * 2) {
      return emitError() << "halo sizes must be specified for all split axes.";
    }
  }
void ShardingOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding");
}
  if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
      getStaticShardedDimsOffsets().size() > 0) {
    return emitError() << "sharded dims offsets are not allowed for "
                          "devices meshes with dynamic shape.";
  }

  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
  if (!shardedDimsOffsets.empty()) {
    auto meshShape = mesh.value().getShape();
    assert(!ShapedType::isDynamicShape(meshShape));
    uint64_t pos = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
      if (!innerSplitAxes.empty()) {
        int64_t numShards = 0, off = 0;
        for (auto i : innerSplitAxes.asArrayRef()) {
          numShards += meshShape[i];
        }
        for (int64_t i = 0; i <= numShards; ++i) {
          if (shardedDimsOffsets.size() <= pos + i) {
            return emitError() << "sharded dims offsets has wrong size.";
          }
          if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
            if (shardedDimsOffsets[pos + i] < off) {
              return emitError()
                     << "sharded dims offsets must be non-decreasing.";
            }
            off = shardedDimsOffsets[pos + i];
          }
        }
        pos += numShards + 1;
      }
    }
  }
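// Illustrative example (hypothetical values): splitting one tensor axis over
// a mesh axis of size 4 gives numShards = 4, so the offsets list must carry
// numShards + 1 = 5 non-decreasing entries for that dimension, e.g.
// sharded_dims_offsets = [0, 3, 6, 9, 12] for a dimension of size 12.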
  LogicalResult matchAndRewrite(ShardingOp op,
                                PatternRewriter &b) const override {
    auto mixedHalos =
        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
                                    op.getDynamicShardedDimsOffsets(), b);
    bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
                    succeeded(foldDynamicIndexList(mixedOffs, true));
    auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
    auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);

    // An all-zero static halo is equivalent to no halo at all.
    if (dynamicHalos.empty() && !staticHalos.empty()) {
      if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
        staticHalos.clear();
        modified = true;
      }
    }

    // Remove sharded dims offsets if they are effectively the default, i.e.
    // if all neighboring shards are equi-distant. Requires static-only
    // offsets.
    if (dynamicOffs.empty() && !staticOffs.empty()) {
      assert(staticOffs.size() >= 2);
      auto diff = staticOffs[1] - staticOffs[0];
      bool all_same = staticOffs.size() > 2;
      for (auto i = 2u; i < staticOffs.size(); ++i) {
        if (staticOffs[i] - staticOffs[i - 1] != diff) {
          all_same = false;
          break;
        }
      }
      if (all_same) {
        staticOffs.clear();
        modified = true;
      }
    }

    if (!modified)
      return failure();

    op.setStaticHaloSizes(staticHalos);
    op.getDynamicHaloSizesMutable().assign(dynamicHalos);
    op.setStaticShardedDimsOffsets(staticOffs);
    op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
    return success();
  }
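// Illustrative effect (hypothetical values): static_halo_sizes = [0, 0, 0, 0]
// is dropped as a no-op halo, and sharded_dims_offsets = [0, 4, 8, 12] is
// dropped because equi-distant offsets describe the default even split.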
  results.add<NormalizeSharding>(context);
  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
                                    getSplitAxes().begin() + minSize),
                   llvm::make_range(rhs.getSplitAxes().begin(),
                                    rhs.getSplitAxes().begin() + minSize))) {
    return false;
  }

  // Any trailing split axes beyond the common prefix must be empty on both
  // sides for the shardings to be considered equal.
  return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
                      std::mem_fn(&MeshAxesAttr::empty)) &&
         llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
                      std::mem_fn(&MeshAxesAttr::empty));
  return !(*this == rhs);
MeshSharding::MeshSharding(Value rhs) {
  auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
  assert(shardingOp && "expected sharding op");
  auto splitAxes = shardingOp.getSplitAxes().getAxes();
  auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
  // If splitAxes and partialAxes are empty, use the "empty" constructor.
  if (splitAxes.empty() && partialAxes.empty()) {
    *this = MeshSharding(shardingOp.getMeshAttr());
    return;
  }
  *this =
      get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
          shardingOp.getPartialType().value_or(ReductionKind::Sum),
          shardingOp.getStaticHaloSizes(),
          shardingOp.getStaticShardedDimsOffsets(),
          SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
          SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
}
  MeshSharding res(mesh_);
  if (split_axes_.empty() && partial_axes_.empty()) {
    return res;
  }

  res.split_axes.resize(split_axes_.size());
  for (auto [i, axis] : llvm::enumerate(split_axes_)) {
    res.split_axes[i] =
        MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
  }

  auto clone = [](const auto src, auto &dst) {
    dst.resize(src.size());
    llvm::copy(src, dst.begin());
  };

  clone(partial_axes_, res.partial_axes);
  res.partial_type = partial_type_;
  clone(static_halo_sizes_, res.static_halo_sizes);
  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
  clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);

  return res;
void ShardShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult()[0], "shard_shape");
}
  build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
        SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
void ShardOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding_annotated");
}
  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
    // Walk the use-list of the value being sharded and look for other ShardOps
    // annotating the same value.
    Value value = op.getSrc();
    for (auto &use : value.getUses()) {
      if (use.getOwner() != op.getOperation()) {
        auto otherOp = dyn_cast<ShardOp>(use.getOwner());
        if (!otherOp || !otherOp->isBeforeInBlock(op)) {
          return failure();
        }
        // If the two shardings are equal, replace all uses of the other
        // ShardOp with the result of `op`; otherwise chain the annotations.
        MeshSharding currentSharding(op.getSharding());
        MeshSharding otherSharding(otherOp.getSharding());
        if (currentSharding == otherSharding) {
          b.replaceAllUsesWith(otherOp.getResult(), op.getResult());
          b.eraseOp(otherOp.getOperation());
        } else {
          // Use the other sharding as input for `op`.
          op.getSrcMutable().assign(otherOp.getResult());
        }
        return success();
      }
    }
    return failure();
  }
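// Illustrative effect (hypothetical IR, types elided): two annotations of %0
// with the same sharding %s,
//   %1 = mesh.shard %0 to %s
//   %2 = mesh.shard %0 to %s
// fold into a single mesh.shard whose result serves both use sites.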
  results.add<FoldDuplicateShardOp>(context);
  size_t expectedResultsCount =
      getAxes().empty() ? mesh->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }
  build(odsBuilder, odsState,
        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
        mesh.getSymName(), ArrayRef<MeshAxis>());
}

void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                StringRef mesh, ArrayRef<MeshAxis> axes) {
  build(odsBuilder, odsState,
        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
        MeshAxesAttr::get(odsBuilder.getContext(), axes));
}

void ProcessMultiIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "proc_linear_idx");
}
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, MeshOp mesh) {
  build(odsBuilder, odsState, mesh.getSymName());
}

void ProcessLinearIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "proc_linear_idx");
}
void NeighborsLinearIndicesOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getNeighborDown(), "down_linear_idx");
  setNameFn(getNeighborUp(), "up_linear_idx");
}
template <typename Op>
struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    auto meshAxes = op.getMeshAxes();
    if (!meshAxes.empty()) {
      return failure();
    }
    if (op.getInput().getType() != op.getResult().getType()) {
      return failure();
    }

    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
    rewriter.eraseOp(op.getOperation());
    return success();
  }
};
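// Rationale: a collective whose mesh_axes list is empty operates on a device
// group of size 1, so it performs no communication and can be folded away to
// its input (provided the input and result types match).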
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
                                         ArrayRef<int64_t> device,
                                         Operation::operand_range deviceDynamic,
                                         ArrayRef<MeshAxis> meshAxes,
                                         ArrayRef<int64_t> meshShape) {
  if (device.size() != meshAxes.size()) {
    return emitError(loc) << "In-group device \"" << deviceName
                          << "\" has unexpected multi-index size "
                          << device.size() << ". Expected " << meshAxes.size()
                          << ".";
  }

  for (size_t i = 0; i < device.size(); ++i) {
    if (!ShapedType::isDynamic(device[i]) &&
        !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
        meshShape[meshAxes[i]] <= device[i]) {
      return emitError(loc)
             << "Out of bounds coordinate " << i << " for in-group device \""
             << deviceName << "\"."
             << " Got " << device[i] << ", but expected value in the range [0, "
             << (meshShape[meshAxes[i]] - 1) << "].";
    }
  }
  return success();
}
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  return std::accumulate(begin, end, static_cast<ElementType>(1),
                         std::multiplies<ElementType>());
}
template <typename R>
static auto product(R &&range) {
  return product(adl_begin(range), adl_end(range));
}
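// For example, a device group spanning mesh axes of sizes {2, 3} has
// product(...) == 6 processes; this is what collectiveProcessGroupSize
// computes for the shape-verification helpers below.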
static LogicalResult verifyDimensionCompatibility(Location loc,
                                                  int64_t expectedDimSize,
                                                  int64_t resultDimSize,
                                                  int64_t resultAxis) {
  if (!ShapedType::isDynamic(resultDimSize) &&
      expectedDimSize != resultDimSize) {
    return emitError(loc) << "Dimension size mismatch for result axis "
                          << resultAxis << ". Expected "
                          << (ShapedType::isDynamic(expectedDimSize)
                                  ? Twine("dynamic")
                                  : Twine(expectedDimSize))
                          << ", but got " << resultDimSize << ".";
  }
  return success();
}
static LogicalResult verifyGatherOperandAndResultShape(
    Value operand, Value result, int64_t gatherAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  auto resultRank = cast<ShapedType>(result.getType()).getRank();
  if (gatherAxis < 0 || gatherAxis >= resultRank) {
    return emitError(result.getLoc())
           << "Gather axis " << gatherAxis << " is out of bounds [0, "
           << resultRank << ").";
  }

  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
    auto expectedResultDimSize =
        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
    if (failed(verifyDimensionCompatibility(
            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
      return failure();
    }
  }
  return success();
}
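// Worked example (hypothetical values): all-gathering a 2x3 operand along
// gatherAxis = 1 across a device group of size 4 must produce a 2x12 result;
// only the gather axis is scaled by the group size.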
static LogicalResult verifyAllToAllOperandAndResultShape(
    Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  if (splitAxis == concatAxis) {
    return success();
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
  DimensionSize expectedResultConcatDimSize =
      operandConcatDimSize * deviceGroupSize;
  DimensionSize expectedResultSplitDimSize =
      operandSplitDimSize / deviceGroupSize;
  if (!expectedResultSplitDimSize.isDynamic() &&
      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
    expectedResultSplitDimSize = DimensionSize::dynamic();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultConcatDimSize.value(),
          resultType.getDimSize(concatAxis), concatAxis))) {
    return failure();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultSplitDimSize.value(),
          resultType.getDimSize(splitAxis), splitAxis))) {
    return failure();
  }
  return success();
}
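// Worked example (hypothetical values): all_to_all on an 8x3 operand with
// splitAxis = 0, concatAxis = 1 and a device group of size 4 expects a
// result of shape (8/4)x(3*4) = 2x12.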
static LogicalResult verifyScatterOrSliceOperandAndResultShape(
    Value operand, Value result, int64_t tensorAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if (axis != tensorAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  auto operandScatterDimSize =
      DimensionSize(operandType.getDimSize(tensorAxis));
  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
      int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
    return emitError(result.getLoc())
           << "Operand dimension size " << int64_t(operandScatterDimSize)
           << " is not divisible by collective device group size "
           << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
           << ".";
  }
  DimensionSize expectedResultTensorDimSize =
      operandScatterDimSize / deviceGroupSize;
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultTensorDimSize.value(),
          resultType.getDimSize(tensorAxis), tensorAxis))) {
    return failure();
  }
  return success();
}
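// Worked example (hypothetical values): scattering a 12x4 operand along
// tensorAxis = 0 across a device group of size 4 expects a 3x4 result; a
// size of 13 would be rejected as not divisible by the group size.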
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
                                        ArrayRef<MeshAxis> meshAxes,
                                        int64_t sliceAxis) {
  RankedTensorType operandRankedTensorType =
      cast<RankedTensorType>(operandType);
  DimensionSize operandSliceAxisSize =
      operandRankedTensorType.getShape()[sliceAxis];
  SmallVector<int64_t> resultShape =
      llvm::to_vector(operandRankedTensorType.getShape());

  resultShape[sliceAxis] =
      operandSliceAxisSize /
      DimensionSize(collectiveProcessGroupSize(meshAxes, mesh.getShape()));
  return operandRankedTensorType.clone(resultShape);
}
  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
                                           gatherAxis, getMeshAxes(),
                                           mesh.value().getShape());
}

void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
}

void AllGatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_gather");
}
void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}

void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        Value input, StringRef mesh,
                        ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
        reduction);
}

void AllReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_reduce");
}
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
      mesh.value().getShape());
}

void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
                       int64_t sliceAxis) {
  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
        sliceAxis);
}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Type resultType, Value input, StringRef mesh,
                       ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
  // Store the slice axis as an APInt sized to the C++ type of `sliceAxis`.
  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
        APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}

void AllSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_slice");
}
  return verifyAllToAllOperandAndResultShape(
      getOperand(), getResult(), getSplitAxis().getSExtValue(),
      getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
}

void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
}

void AllToAllOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_to_all");
}
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
}

void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "broadcast");
}
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }

  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
                                           getMeshAxes(),
                                           mesh.value().getShape());
}

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
}

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
  if (getSource() &&
      failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
                                 getSource().value(), getSourceDynamic(),
                                 getMeshAxes(), mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
}

void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "recv");
}
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce");
}
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
      mesh.value().getShape());
}

void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}

void ReduceScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce_scatter");
}
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }

  auto scatterAxis = getScatterAxis().getSExtValue();
  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
                                                   scatterAxis, getMeshAxes(),
                                                   mesh.value().getShape());
}

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
}

void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
                                 getDestination(), getDestinationDynamic(),
                                 getMeshAxes(), mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
}

void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "send");
}
  auto meshAxes = getMeshAxes();
  auto shiftAxis = getShiftAxis().getZExtValue();
  if (!llvm::is_contained(meshAxes, shiftAxis)) {
    return emitError() << "Invalid shift axis " << shiftAxis
                       << ". It must be one of the grouping mesh axes.";
  }
  return success();
}

void ShiftOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "shift");
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"

#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"