#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "mesh-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
struct DimensionSize {
  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
  DimensionSize(int64_t val) : val(val) {}
  int64_t value() const { return val; }
  operator int64_t() const { return val; }
  bool isDynamic() const { return ShapedType::isDynamic(val); }

private:
  int64_t val;
};
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() / rhs.value();
}
static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() * rhs.value();
}
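// Illustrative usage (not part of the original file): DimensionSize treats
// ShapedType::kDynamic as an absorbing element, so shape arithmetic never
// fabricates a static extent from a dynamic one.
//
//   DimensionSize a(8);
//   DimensionSize d = DimensionSize::dynamic();
//   assert(int64_t(a / DimensionSize(2)) == 4); // static stays static
//   assert((a * d).isDynamic());                // dynamic absorbs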
void MeshDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
      >();
  addInterface<MeshInlinerInterface>();
}
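// Hypothetical usage sketch: like any MLIR dialect, MeshDialect must be
// loaded into a context before its ops, attributes, and types can be built.
//
//   mlir::MLIRContext ctx;
//   ctx.loadDialect<mlir::mesh::MeshDialect>();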
Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
                                          FlatSymbolRefAttr meshSymbol,
                                          SymbolTableCollection &symbolTable) {
  // ...
    return op->emitError() << "Undefined required mesh symbol \""
                           << meshSymbol.getValue() << "\".";
  // ...
}
template <typename It>
bool isUnique(It begin, It end) {
  if (begin == end) {
    return true;
  }
  It next = std::next(begin);
  if (next == end) {
    return true;
  }
  for (; next != end; ++next, ++begin) {
    if (*begin == *next) {
      return false;
    }
  }
  return true;
}
static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
                                    MeshOp mesh) {
  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
  llvm::sort(sorted);
  if (!isUnique(sorted.begin(), sorted.end())) {
    return emitError(loc) << "Mesh axes contains duplicate elements.";
  }

  MeshAxis rank = mesh.getRank();
  for (auto axis : axes) {
    if (axis >= rank || axis < 0) {
      return emitError(loc) << "0-based mesh axis index " << axis
                            << " is out of bounds. The referenced mesh \""
                            << mesh.getSymName() << "\" is of rank " << rank
                            << ".";
    }
  }

  return success();
}
template <typename Op>
static FailureOr<MeshOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
  // ... (elided: looks up the mesh and verifies the op's mesh axes against it)
}
template <typename InShape, typename MeshShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                       const SplitAxes &splitAxes, OutShape &outShape,
                       ArrayRef<int64_t> shardedDimsOffsets = {},
                       ArrayRef<int64_t> haloSizes = {}) {
  // 0d tensors cannot be sharded and must get replicated.
  if (inShape.empty()) {
    assert(outShape.empty());
    return;
  }

  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));
  if (!shardedDimsOffsets.empty()) {
    auto isDynShape = ShapedType::isDynamicShape(meshShape);
    uint64_t pos = 1;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (!innerSplitAxes.empty()) {
        auto sz = shardedDimsOffsets[pos];
        bool same = !isDynShape;
        if (same) {
          // The dimension gets a static size only if all offsets along it are
          // static and equidistant; otherwise it becomes dynamic.
          uint64_t numShards = 0;
          for (auto i : innerSplitAxes.asArrayRef()) {
            numShards += meshShape[i];
          }
          for (size_t i = 1; i < numShards; ++i) {
            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
                sz) {
              same = false;
              break;
            }
          }
          pos += numShards + 1;
        }
        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
      }
    }
  } else {
    // ... (elided: the default path divides each split dimension by its
    // shard count)
  }
  if (!haloSizes.empty()) {
    size_t haloAxis = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
          !innerSplitAxes.empty()) {
        if (haloSizes[haloAxis * 2] >= 0 &&
            haloSizes[haloAxis * 2 + 1] >= 0) {
          outShape[tensorAxis] +=
              haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
        } else {
          // A dynamic halo size makes the dimension size dynamic.
          outShape[tensorAxis] = ShapedType::kDynamic;
        }
        ++haloAxis;
      }
    }
  }
}
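// Worked example (assuming the elided default path divides each split
// dimension by its shard count): a tensor<8x6xf32> on a 2x3 mesh with
// split_axes = [[0], [1]] shards to tensor<4x2xf32>; requesting halos of
// [1, 1] on the first split axis widens it to tensor<6x2xf32>, and any
// dynamic or non-uniform offset turns the affected dimension into kDynamic.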
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
                           MeshSharding sharding) {
  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
  SmallVector<Dim> resShapeArr(shape.getShape().size());
  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
             resShapeArr, sharding.getStaticShardedDimsOffsets(),
             sharding.getStaticHaloSizes());
  return shape.clone(resShapeArr);
}
Type shardType(Type type, MeshOp mesh, MeshSharding sharding) {
  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
    return shardShapedType(rankedTensorType, mesh, sharding);
  }
  return type;
}
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                         OpOperand &operand, OpBuilder &builder,
                                         ShardOp &newShardOp) {
  // ...
  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      !shardOp.getAnnotateForUsers()) {
    // The operand already carries an identical annotation; reuse it.
    newShardOp = shardOp;
    // ...
  }
  // ...
  auto shardingOp =
      builder.create<ShardingOp>(operandValue.getLoc(), sharding);
  newShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users=*/false);
  // ...
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (!shardOp || shardOp.getAnnotateForUsers()) {
    // ...
    auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
                                               newShardOp.getSharding(),
                                               /*annotate_for_users=*/true);
    // ...
  }
}

// ... (elided: an overload walks all uses of a result and delegates to the
// helper above)
  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
    // ...
  }
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                         OpOperand &operand,
                                         OpBuilder &builder) {
  // ...
  bool isBlockArg = !operandSrcOp;
  // ...
  [[maybe_unused]] auto opType =
      dyn_cast<mlir::RankedTensorType>(operandValue.getType());
  // ...
  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
      /*...*/) {
    // ...
  }

  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);

  if (shardOp && sharding == shardOp.getSharding() &&
      shardOp.getAnnotateForUsers()) {
    // The producer is already annotated for users with the same sharding.
    // ...
  }
  // ...
  auto newShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users=*/true);
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
    // ...
  }
  // ...
  auto newPreceedingShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*...*/);
  rewriter.replaceUsesWithIf(
      newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
        return use.getOwner() == newShardOp.getOperation();
      });
}
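// Rough effect of the two annotation helpers above (illustrative IR syntax):
//
//   %v = some.op ...
//   ... = use.op %v
//
// becomes
//
//   %v = some.op ...
//   %s = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
//   %t = mesh.shard %v to %s : tensor<4x8xf32>
//   ... = use.op %t
//
// where the source-side variant additionally sets annotate_for_users so that
// consumers, not producers, are told the expected distribution.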
LogicalResult MeshOp::verify() {
  int64_t rank = getRank();

  if (rank <= 0)
    return emitOpError("rank of mesh is expected to be a positive integer");

  for (int64_t dimSize : getShape()) {
    if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
      return emitOpError("dimension size of a mesh is expected to be "
                         "non-negative or dynamic");
  }

  return success();
}
  size_t expectedResultsCount =
      getAxes().empty() ? mesh->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }
  build(odsBuilder, odsState, /*...*/);
  // ...
  assert(!axes.empty());
  build(odsBuilder, odsState, /*...*/);
  // ...

void MeshShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "mesh_shape");
}
void ShardingOp::build( /*...*/ ) {
  // ...
}
LogicalResult ShardingOp::verify() {
  llvm::SmallSet<MeshAxis, 4> visitedAxes;

  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
    for (MeshAxis axis : axesArray) {
      if (axis < 0)
        return emitError() << "mesh axis is expected to be non-negative";
      if (!visitedAxes.insert(axis).second)
        return emitError() << "mesh axis duplicated";
    }
    return success();
  };

  for (auto subAxes : getSplitAxes().getAxes()) {
    ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
    if (failed(checkMeshAxis(subAxesArray)))
      return failure();
  }
  if (getPartialAxes().has_value() &&
      failed(checkMeshAxis(getPartialAxes().value())))
    return failure();

  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
    return emitOpError("halo sizes and shard offsets are mutually exclusive");
  }

  if (!getStaticHaloSizes().empty()) {
    auto numSplitAxes = getSplitAxes().getAxes().size();
    for (auto splitAxis : getSplitAxes().getAxes()) {
      if (splitAxis.empty()) {
        --numSplitAxes;
      }
    }
    if (getStaticHaloSizes().size() != numSplitAxes * 2) {
      return emitError() << "halo sizes must be specified for all split axes.";
    }
  }

  return success();
}
void ShardingOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding");
}
LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // ...
  if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
      getStaticShardedDimsOffsets().size() > 0) {
    return emitError() << "sharded dims offsets are not allowed for "
                          "devices meshes with dynamic shape.";
  }

  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
  if (!shardedDimsOffsets.empty()) {
    auto meshShape = mesh.value().getShape();
    assert(!ShapedType::isDynamicShape(meshShape));
    uint64_t pos = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
      if (!innerSplitAxes.empty()) {
        int64_t numShards = 0, off = 0;
        for (auto i : innerSplitAxes.asArrayRef()) {
          numShards += meshShape[i];
        }
        for (int64_t i = 0; i <= numShards; ++i) {
          if (shardedDimsOffsets.size() <= pos + i) {
            return emitError() << "sharded dims offsets has wrong size.";
          }
          if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
            if (shardedDimsOffsets[pos + i] < off) {
              return emitError()
                     << "sharded dims offsets must be non-decreasing.";
            }
            off = shardedDimsOffsets[pos + i];
          }
        }
        pos += numShards + 1;
      }
    }
  }

  return success();
}
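// Layout example for the checks above: on a mesh <2x3> with
// split_axes = [[0], [1]], each non-empty split group contributes
// numShards + 1 monotonic entries, so a valid attribute could be
//   sharded_dims_offsets = [0, 4, 8,  0, 3, 5, 9]
// (3 entries for the 2-shard axis, 4 for the 3-shard axis). Similarly,
// halo_sizes needs one (low, high) pair per non-empty split group, e.g.
// halo_sizes = [1, 2, 2, 2] for the two groups here.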
struct NormalizeSharding : public OpRewritePattern<ShardingOp> {
  using OpRewritePattern<ShardingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardingOp op,
                                PatternRewriter &b) const override {
    auto mixedHalos =
        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
                                    op.getDynamicShardedDimsOffsets(), b);
    // ...
    // Drop halo sizes if they are all statically zero.
    if (dynamicHalos.empty() && !staticHalos.empty()) {
      if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
        // ...
      }
    }
    // Drop sharded-dims offsets if they are evenly spaced, i.e. they carry no
    // information beyond the default balanced sharding.
    if (dynamicOffs.empty() && !staticOffs.empty()) {
      assert(staticOffs.size() >= 2);
      auto diff = staticOffs[1] - staticOffs[0];
      bool all_same = staticOffs.size() > 2;
      for (auto i = 2u; i < staticOffs.size(); ++i) {
        if (staticOffs[i] - staticOffs[i - 1] != diff) {
          all_same = false;
          break;
        }
      }
      // ...
    }
    // ...
    op.setStaticHaloSizes(staticHalos);
    op.getDynamicHaloSizesMutable().assign(dynamicHalos);
    op.setStaticShardedDimsOffsets(staticOffs);
    op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
    // ...
  }
};
void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                             mlir::MLIRContext *context) {
  results.add<NormalizeSharding>(context);
}
bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
  // ...
  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
                                    getSplitAxes().begin() + minSize),
                   llvm::make_range(rhs.getSplitAxes().begin(),
                                    rhs.getSplitAxes().begin() + minSize))) {
    return false;
  }
  // Any split axes beyond the common prefix must be empty on both sides.
  return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
                                       getSplitAxes().end()),
                      std::mem_fn(&MeshAxesAttr::empty)) &&
         llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
                                       rhs.getSplitAxes().end()),
                      std::mem_fn(&MeshAxesAttr::empty));
}
  return !(*this == rhs);
MeshSharding::MeshSharding(Value rhs) {
  auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
  assert(shardingOp && "expected sharding op");
  auto splitAxes = shardingOp.getSplitAxes().getAxes();
  // ...
  if (splitAxes.empty() && partialAxes.empty()) {
    // ...
  }
  *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
              shardingOp.getPartialType().value_or(ReductionKind::Sum),
              shardingOp.getStaticHaloSizes(),
              shardingOp.getStaticShardedDimsOffsets(),
              shardingOp.getDynamicHaloSizes(),
              shardingOp.getDynamicShardedDimsOffsets());
}
MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
                               ArrayRef<MeshAxesAttr> split_axes_,
                               ArrayRef<MeshAxis> partial_axes_,
                               ReductionKind partial_type_,
                               ArrayRef<int64_t> static_halo_sizes_,
                               ArrayRef<int64_t> static_sharded_dims_offsets_,
                               ArrayRef<Value> dynamic_halo_sizes_,
                               ArrayRef<Value> dynamic_sharded_dims_offsets_) {
  MeshSharding res(mesh_);
  if (split_axes_.empty() && partial_axes_.empty()) {
    return res;
  }

  res.split_axes.resize(split_axes_.size());
  // ...
  auto clone = [](const auto src, auto &dst) {
    dst.resize(src.size());
    llvm::copy(src, dst.begin());
  };

  clone(partial_axes_, res.partial_axes);
  res.partial_type = partial_type_;
  clone(static_halo_sizes_, res.static_halo_sizes);
  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
  clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);

  return res;
}
void ShardShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult()[0], "shard_shape");
}
  build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
        /*...*/);
void ShardOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding_annotated");
}
struct FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
  using OpRewritePattern<ShardOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
    // Get the source value annotated by this shard op.
    Value value = op.getSrc();
    // ...
    for (auto &use : value.getUses()) {
      if (use.getOwner() != op.getOperation()) {
        auto otherOp = dyn_cast<ShardOp>(use.getOwner());
        if (!otherOp || !otherOp->isBeforeInBlock(op)) {
          continue;
        }
        // ...
        if (currentSharding == otherSharding) {
          // Same sharding annotated twice on the same value: consume the
          // earlier annotation instead of duplicating it.
          op.getSrcMutable().assign(otherOp.getResult());
          // ...
        }
      }
    }
    // ...
  }
};
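// Illustrative fold (hypothetical IR): two identical annotations of the same
// value
//
//   %a = mesh.shard %v to %s : tensor<4xf32>
//   %b = mesh.shard %v to %s : tensor<4xf32>
//
// are rewired so the later op consumes the earlier one,
//
//   %a = mesh.shard %v to %s : tensor<4xf32>
//   %b = mesh.shard %a to %s : tensor<4xf32>
//
// leaving %b trivially redundant for later cleanup.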
void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                          mlir::MLIRContext *context) {
  results.add<FoldDuplicateShardOp>(context);
}
  size_t expectedResultsCount =
      getAxes().empty() ? mesh->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }
  build(odsBuilder, odsState, /*...*/);
  // ...
  build(odsBuilder, odsState, /*...*/);
  // ...

void ProcessMultiIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "proc_linear_idx");
}
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, MeshOp mesh) {
  build(odsBuilder, odsState, mesh.getSymName());
}

void ProcessLinearIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "proc_linear_idx");
}
void NeighborsLinearIndicesOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getNeighborDown(), "down_linear_idx");
  setNameFn(getNeighborUp(), "up_linear_idx");
}
template <typename Op>
struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    auto meshAxes = op.getMeshAxes();
    if (!meshAxes.empty()) {
      return failure();
    }
    if (op.getInput().getType() != op.getResult().getType()) {
      return failure();
    }
    // A collective over an empty set of mesh axes is a no-op: forward the
    // input and erase the op.
    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
    rewriter.eraseOp(op.getOperation());
    return success();
  }
};
  if (device.size() != meshAxes.size()) {
    return emitError(loc) << "In-group device \"" << deviceName
                          << "\" has unexpected multi-index size "
                          << device.size() << ". Expected " << meshAxes.size()
                          << ".";
  }

  for (size_t i = 0; i < device.size(); ++i) {
    if (!ShapedType::isDynamic(device[i]) &&
        !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
        meshShape[meshAxes[i]] <= device[i]) {
      return emitError(loc)
             << "Out of bounds coordinate " << i << " for in-group device \""
             << deviceName << "\"."
             << " Got " << device[i] << ", but expected value in the range [0, "
             << (meshShape[meshAxes[i]] - 1) << "].";
    }
  }

  return success();
}
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  return std::accumulate(begin, end, static_cast<ElementType>(1),
                         std::multiplies<ElementType>());
}
template <typename R>
static auto product(R &&range) {
  return product(adl_begin(range), adl_end(range));
}
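// Illustrative usage (hypothetical values): the collective device group size
// used by the verifiers below is exactly such a product over the mesh
// dimensions selected by the group's mesh axes. For a mesh <2x3x4> grouped
// over axes {0, 2}:
//
//   SmallVector<int64_t> groupDims = {2, 4};
//   assert(product(groupDims) == 8); // 8 devices per group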
static LogicalResult verifyDimensionCompatibility(Location loc,
                                                  int64_t expectedDimSize,
                                                  int64_t resultDimSize,
                                                  int64_t resultAxis) {
  if (!ShapedType::isDynamic(resultDimSize) &&
      expectedDimSize != resultDimSize) {
    return emitError(loc) << "Dimension size mismatch for result axis "
                          << resultAxis << ". Expected "
                          << (ShapedType::isDynamic(expectedDimSize)
                                  ? Twine("dynamic")
                                  : Twine(expectedDimSize))
                          << ", but got " << resultDimSize << ".";
  }

  return success();
}
static LogicalResult verifyGatherOperandAndResultShape(
    Value operand, Value result, int64_t gatherAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  auto resultRank = cast<ShapedType>(result.getType()).getRank();
  if (gatherAxis < 0 || gatherAxis >= resultRank) {
    return emitError(result.getLoc())
           << "Gather axis " << gatherAxis << " is out of bounds [0, "
           << resultRank << ").";
  }

  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
    auto expectedResultDimSize =
        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
    if (failed(verifyDimensionCompatibility(
            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
      return failure();
    }
  }
  return success();
}
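// Shape arithmetic example for the check above: with a device group of size 4
// (say a 2x2 mesh gathered over axes [0, 1]) and operand tensor<3x5xf32>,
// gather_axis = 0 requires result tensor<12x5xf32>; every other axis must
// match the operand exactly.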
static LogicalResult verifyAllToAllOperandAndResultShape(
    Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  if (splitAxis == concatAxis) {
    return success();
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
  DimensionSize expectedResultConcatDimSize =
      operandConcatDimSize * deviceGroupSize;
  DimensionSize expectedResultSplitDimSize =
      operandSplitDimSize / deviceGroupSize;
  if (!expectedResultSplitDimSize.isDynamic() &&
      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
    expectedResultSplitDimSize = DimensionSize::dynamic();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultConcatDimSize.value(),
          resultType.getDimSize(concatAxis), concatAxis))) {
    return failure();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultSplitDimSize.value(),
          resultType.getDimSize(splitAxis), splitAxis))) {
    return failure();
  }
  return success();
}
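// Example for the all-to-all checks above: with a device group of size 2,
// operand tensor<4x6xf32>, split_axis = 0, and concat_axis = 1, the expected
// result is tensor<2x12xf32>; a split dimension not divisible by the group
// size is only accepted as dynamic.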
static LogicalResult verifyScatterOrSliceOperandAndResultShape(
    Value operand, Value result, int64_t tensorAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if (axis != tensorAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  auto operandScatterDimSize =
      DimensionSize(operandType.getDimSize(tensorAxis));
  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
      int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
    return emitError(result.getLoc())
           << "Operand dimension size " << int64_t(operandScatterDimSize)
           << " is not divisible by collective device group size "
           << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
           << ".";
  }
  DimensionSize expectedResultTensorDimSize =
      operandScatterDimSize / deviceGroupSize;
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultTensorDimSize.value(),
          resultType.getDimSize(tensorAxis), tensorAxis))) {
    return failure();
  }
  return success();
}
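// Example for the divisibility rule above: scattering tensor<9x2xf32> along
// tensor axis 0 over a device group of size 2 is rejected (9 % 2 != 0), while
// tensor<8x2xf32> yields tensor<4x2xf32>.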
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
                                        ArrayRef<MeshAxis> meshAxes,
                                        int64_t sliceAxis) {
  RankedTensorType operandRankedTensorType =
      cast<RankedTensorType>(operandType);
  DimensionSize operandSliceAxisSize =
      operandRankedTensorType.getShape()[sliceAxis];
  SmallVector<int64_t> resultShape =
      llvm::to_vector(operandRankedTensorType.getShape());

  resultShape[sliceAxis] =
      operandSliceAxisSize /
      DimensionSize(collectiveProcessGroupSize(meshAxes, mesh.getShape()));
  return operandRankedTensorType.clone(resultShape);
}
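// Example: slicing a tensor<8x6xf32> along sliceAxis = 0 over mesh axes whose
// dimensions multiply to 4 produces tensor<2x6xf32>; a dynamic group size or
// input extent propagates kDynamic into the result type instead.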
LogicalResult AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // ...
  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
                                           gatherAxis, getMeshAxes(),
                                           mesh.value().getShape());
}

void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
}
void AllGatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_gather");
}
void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}
void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        Value input, StringRef mesh,
                        ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
        reduction);
}

void AllReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_reduce");
}
LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // ...
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
      mesh.value().getShape());
}

void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
}
void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
                       int64_t sliceAxis) {
  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
        sliceAxis);
}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Type resultType, Value input, StringRef mesh,
                       ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
        APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}

void AllSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_slice");
}
LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // ...
  return verifyAllToAllOperandAndResultShape(
      getOperand(), getResult(), getSplitAxis().getSExtValue(),
      getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
}

void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
}
void AllToAllOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_to_all");
}
LogicalResult BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // ...
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
}
void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "broadcast");
}
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // ...
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }
  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
                                           getMeshAxes(),
                                           mesh.value().getShape());
}

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
}
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // ...
  if (getSource() &&
      failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
                                 getSource().value(), getSourceDynamic(),
                                 getMeshAxes(), mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
}
void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "recv");
}
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // ...
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
}
void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce");
}
LogicalResult
ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // ...
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
      mesh.value().getShape());
}

void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}
void ReduceScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce_scatter");
}
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // ...
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }
  auto scatterAxis = getScatterAxis().getSExtValue();
  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
                                                   scatterAxis, getMeshAxes(),
                                                   mesh.value().getShape());
}

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
}
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // ...
  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
                                 getDestination(), getDestinationDynamic(),
                                 getMeshAxes(), mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
}
void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "send");
}
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // ...
  auto meshAxes = getMeshAxes();
  auto shiftAxis = getShiftAxis().getZExtValue();
  // Example: a shift grouped over mesh axes [0, 2] accepts shift_axis 0 or 2;
  // shift_axis = 1 is rejected because the rotation must happen along one of
  // the axes that define the device groups.
  if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
    return emitError() << "Invalid shift axis " << shiftAxis
                       << ". It must be one of the grouping mesh axes.";
  }
  return success();
}
void ShiftOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "shift");
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"

#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"