#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "mesh-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
struct DimensionSize {
  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
  DimensionSize(int64_t val) : val(val) {}
  int64_t value() const { return val; }
  operator int64_t() const { return val; }
  bool isDynamic() const { return ShapedType::isDynamic(val); }

private:
  int64_t val;
};
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() / rhs.value();
}

static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() * rhs.value();
}
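
// Illustrative sketch, not part of the upstream file: how the DimensionSize
// helpers above behave. Any arithmetic that touches a dynamic extent stays
// dynamic instead of producing a meaningless number. The function name is
// hypothetical and exists only for illustration.
static bool dimensionSizeExample() {
  DimensionSize a(12), b(4);
  bool ok = (a / b).value() == 3;                   // static / static -> 3
  ok &= (a * b).value() == 48;                      // static * static -> 48
  ok &= (a / DimensionSize::dynamic()).isDynamic(); // dynamic propagates
  ok &= (DimensionSize::dynamic() * b).isDynamic();
  return ok;
}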
void MeshDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
      >();
  addInterface<MeshInlinerInterface>();
}
Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
                                          FlatSymbolRefAttr meshSymbol,
                                          SymbolTableCollection &symbolTable) {
  mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
  if (!mesh) {
    return op->emitError() << "Undefined required mesh symbol \""
                           << meshSymbol.getValue() << "\".";
  }
  return mesh;
}
template <typename It>
bool isUnique(It begin, It end) {
  if (begin == end) {
    return true;
  }
  It next = std::next(begin);
  if (next == end) {
    return true;
  }
  for (; next != end; ++next, ++begin) {
    if (*begin == *next) {
      return false;
    }
  }
  return true;
}
static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
                                    MeshOp mesh) {
  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
  llvm::sort(sorted);
  if (!isUnique(sorted.begin(), sorted.end())) {
    return emitError(loc) << "Mesh axes contains duplicate elements.";
  }

  MeshAxis rank = mesh.getRank();
  for (auto axis : axes) {
    if (axis >= rank || axis < 0) {
      return emitError(loc)
             << "0-based mesh axis index " << axis
             << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
             << "\" is of rank " << rank << ".";
    }
  }

  return success();
}
template <typename Op>
static FailureOr<MeshOp> getMeshAndVerifyAxes(Op op,
                                              SymbolTableCollection &symbolTable) {
  auto mesh =
      ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
  if (succeeded(mesh) &&
      failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
    return failure();
  }
  return mesh;
}
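
// Illustrative sketch, not part of the upstream file: a standalone model of
// the axis validation performed by verifyMeshAxes above. For a rank-3 mesh,
// axes {0, 2} pass, {0, 0} fail (duplicate), and {3} or {-1} fail (out of
// bounds). The function name is hypothetical.
static bool axesAreValidExample(llvm::ArrayRef<MeshAxis> axes, MeshAxis rank) {
  llvm::SmallSet<MeshAxis, 4> seen;
  for (MeshAxis a : axes)
    if (a < 0 || a >= rank || !seen.insert(a).second)
      return false; // out of bounds or duplicate
  return true;
}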
template <typename InShape, typename MeshShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                       const SplitAxes &splitAxes, OutShape &outShape,
                       ArrayRef<int64_t> shardedDimsOffsets = {},
                       ArrayRef<int64_t> haloSizes = {}) {
  // 0d tensors cannot be sharded and must get replicated.
  if (inShape.empty()) {
    assert(outShape.empty());
    return;
  }

  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));

  if (!shardedDimsOffsets.empty()) {
    auto isDynShape = ShapedType::isDynamicShape(meshShape);
    uint64_t pos = 1;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (!innerSplitAxes.empty()) {
        auto sz = shardedDimsOffsets[pos];
        bool same = !isDynShape;
        if (same) {
          // Find sharded dims in shardedDimsOffsets with the same static size
          // on all devices. Use kDynamic for dimensions with dynamic or
          // non-uniform offsets in shardedDimsOffsets.
          uint64_t numShards = 0;
          for (auto i : innerSplitAxes.asArrayRef()) {
            numShards += meshShape[i];
          }
          for (size_t i = 1; i < numShards; ++i) {
            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
                sz) {
              same = false;
              break;
            }
          }
          pos += numShards + 1;
        }
        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
      }
    }
  } else {
    // ... (plain even split: each split tensor dimension is divided by the
    // collective process group size of its inner split axes)

    if (!haloSizes.empty()) {
      // Add halo sizes to the sharded dimensions that have them.
      int haloAxis = 0;
      for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
        if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
            !innerSplitAxes.empty()) {
          if (haloSizes[haloAxis * 2] >= 0 &&
              haloSizes[haloAxis * 2 + 1] >= 0) {
            outShape[tensorAxis] +=
                haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
            ++haloAxis;
          } else {
            outShape[tensorAxis] = ShapedType::kDynamic;
          }
        }
      }
    }
  }
}
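
// Illustrative sketch, not part of the upstream file: a worked example of the
// plain even-split path of shardShape above. An 8x16 tensor on a 2x4 mesh,
// with tensor dim 0 split over mesh axis 0 and dim 1 over mesh axis 1, gives
// each device a 4x4 shard (8/2 and 16/4). The function name is hypothetical.
static llvm::SmallVector<int64_t> shardShapeExample() {
  const int64_t inShape[] = {8, 16};
  const int64_t meshShape[] = {2, 4};
  const llvm::SmallVector<llvm::SmallVector<MeshAxis>> splitAxes = {{0}, {1}};
  llvm::SmallVector<int64_t> outShape;
  for (auto [dim, axes] : llvm::enumerate(splitAxes)) {
    int64_t groupSize = 1;
    for (MeshAxis a : axes)
      groupSize *= meshShape[a]; // devices sharing this tensor dimension
    outShape.push_back(inShape[dim] / groupSize);
  }
  return outShape; // {4, 4}
}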
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
                           MeshSharding sharding) {
  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
  SmallVector<Dim> resShapeArr(shape.getShape().size());
  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
             resShapeArr, sharding.getStaticShardedDimsOffsets(),
             sharding.getStaticHaloSizes());
  return shape.clone(resShapeArr);
}
Type shardType(Type type, MeshOp mesh, MeshSharding sharding) {
  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
    return shardShapedType(rankedTensorType, mesh, sharding);
  }
  return type;
}
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                     OpOperand &operand,
                                                     OpBuilder &builder,
                                                     ShardOp &newShardOp) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Value operandValue = operand.get();
  Operation *operandOp = operand.getOwner();
  builder.setInsertionPointAfterValue(operandValue);
  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      !shardOp.getAnnotateForUsers()) {
    // No need to insert anything if the correct sharding is already set.
    if (!newShardOp) {
      newShardOp = shardOp;
    }
    return;
  }

  if (!newShardOp) {
    auto shardingOp =
        builder.create<ShardingOp>(operandValue.getLoc(), sharding);
    newShardOp =
        builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                                /*annotate_for_users=*/false);
  }
  IRRewriter rewriter(builder);
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (!shardOp || shardOp.getAnnotateForUsers()) {
    return;
  }

  auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
                                             newShardOp.getSharding(),
                                             /*annotate_for_users=*/true);
  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
}

void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                     OpResult result,
                                                     OpBuilder &builder) {
  ShardOp newShardOp;
  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
    maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
  }
}
void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                                     OpOperand &operand,
                                                     OpBuilder &builder) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Value operandValue = operand.get();
  Operation *operandOp = operand.getOwner();
  Operation *operandSrcOp = operandValue.getDefiningOp();
  bool isBlockArg = !operandSrcOp;
  {
    [[maybe_unused]] auto opType =
        dyn_cast<mlir::RankedTensorType>(operandValue.getType());
    assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
  }
  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
      operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
    return;
  }

  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);

  if (shardOp && sharding == shardOp.getSharding() &&
      shardOp.getAnnotateForUsers()) {
    // No need to do anything if the correct sharding is already set.
    return;
  }

  builder.setInsertionPoint(operandOp);
  auto shardingOp =
      builder.create<ShardingOp>(operandValue.getLoc(), sharding);
  auto newShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users=*/true);
  IRRewriter rewriter(builder);
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
    // No need for resharding.
    return;
  }

  builder.setInsertionPoint(newShardOp);
  auto newPreceedingShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users=*/false);
  rewriter.replaceUsesWithIf(
      newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
        return use.getOwner() == newShardOp.getOperation();
      });
}
LogicalResult MeshOp::verify() {
  int64_t rank = getRank();

  if (rank <= 0)
    return emitOpError("rank of mesh is expected to be a positive integer");

  for (int64_t dimSize : getShape()) {
    if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
      return emitOpError("dimension size of a mesh is expected to be "
                         "non-negative or dynamic");
  }

  return success();
}
LogicalResult
MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
  if (failed(mesh))
    return failure();

  size_t expectedResultsCount =
      getAxes().empty() ? mesh->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }

  return success();
}
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        MeshOp mesh) {
  build(odsBuilder, odsState,
        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
        mesh.getSymName(), MeshAxesAttr());
}

void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        MeshOp mesh, ArrayRef<MeshAxis> axes) {
  assert(!axes.empty());
  build(odsBuilder, odsState,
        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()),
        mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
}

void MeshShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "mesh_shape");
}
void ShardingOp::build(/* ... overload parameters and bodies elided ... */);
LogicalResult ShardingOp::verify() {
  llvm::SmallSet<MeshAxis, 4> visitedAxes;

  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
    for (MeshAxis axis : axesArray) {
      if (axis < 0)
        return emitError() << "mesh axis is expected to be non-negative";
      if (!visitedAxes.insert(axis).second)
        return emitError() << "mesh axis duplicated";
    }
    return success();
  };

  for (auto subAxes : getSplitAxes().getAxes()) {
    ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
    if (failed(checkMeshAxis(subAxesArray)))
      return failure();
  }
  if (getPartialAxes().has_value() &&
      failed(checkMeshAxis(getPartialAxes().value())))
    return failure();

  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
    return emitOpError("halo sizes and shard offsets are mutually exclusive");
  }

  if (!getStaticHaloSizes().empty()) {
    auto numSplitAxes = getSplitAxes().getAxes().size();
    for (auto splitAxis : getSplitAxes().getAxes()) {
      if (splitAxis.empty()) {
        --numSplitAxes;
      }
    }
    if (getStaticHaloSizes().size() != numSplitAxes * 2) {
      return emitError() << "halo sizes must be specified for all split axes.";
    }
  }

  return success();
}
void ShardingOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding");
}
LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
      getStaticShardedDimsOffsets().size() > 0) {
    return emitError() << "sharded dims offsets are not allowed for "
                          "devices meshes with dynamic shape.";
  }

  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
  if (!shardedDimsOffsets.empty()) {
    auto meshShape = mesh.value().getShape();
    assert(!ShapedType::isDynamicShape(meshShape));
    uint64_t pos = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
      if (!innerSplitAxes.empty()) {
        int64_t numShards = 0, off = 0;
        for (auto i : innerSplitAxes.asArrayRef()) {
          numShards += meshShape[i];
        }
        for (int64_t i = 0; i <= numShards; ++i) {
          if (shardedDimsOffsets.size() <= pos + i) {
            return emitError() << "sharded dims offsets has wrong size.";
          }
          if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
            if (shardedDimsOffsets[pos + i] < off) {
              return emitError()
                     << "sharded dims offsets must be non-decreasing.";
            }
            off = shardedDimsOffsets[pos + i];
          }
        }
        pos += numShards + 1;
      }
    }
  }

  return success();
}
namespace {
// Normalizes a ShardingOp: halo sizes and sharded-dims offsets that carry no
// information are dropped.
struct NormalizeSharding : public OpRewritePattern<ShardingOp> {
  using OpRewritePattern<ShardingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardingOp op,
                                PatternRewriter &b) const override {
    auto mixedHalos =
        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
                                    op.getDynamicShardedDimsOffsets(), b);
    // ... (fold constants, then decompose back into the static/dynamic lists
    // staticHalos/dynamicHalos and staticOffs/dynamicOffs)

    // All-zero halo sizes carry no information and can be dropped.
    if (dynamicHalos.empty() && !staticHalos.empty()) {
      if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
        staticHalos.clear();
      }
    }

    // Equi-distant sharded-dims offsets describe the same partition as a
    // plain even split and can be dropped as well.
    if (dynamicOffs.empty() && !staticOffs.empty()) {
      assert(staticOffs.size() >= 2);
      auto diff = staticOffs[1] - staticOffs[0];
      bool all_same = staticOffs.size() > 2;
      for (auto i = 2u; i < staticOffs.size(); ++i) {
        if (staticOffs[i] - staticOffs[i - 1] != diff) {
          all_same = false;
          break;
        }
      }
      // ... (clear staticOffs if all_same)
    }
    // ...

    op.setStaticHaloSizes(staticHalos);
    op.getDynamicHaloSizesMutable().assign(dynamicHalos);
    op.setStaticShardedDimsOffsets(staticOffs);
    op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
    return success();
  }
};
} // namespace
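
// Illustrative sketch, not part of the upstream file: a standalone model of
// the offsets check in the normalization above. Uniformly spaced offsets such
// as {0, 4, 8, 12} describe the same partition as a plain even split, so the
// pattern can drop them; {0, 3, 8, 12} must be kept. The function name is
// hypothetical.
static bool offsetsDescribeEvenSplit(llvm::ArrayRef<int64_t> offs) {
  if (offs.size() < 2)
    return false;
  const int64_t diff = offs[1] - offs[0];
  for (size_t i = 2; i < offs.size(); ++i)
    if (offs[i] - offs[i - 1] != diff)
      return false; // non-uniform spacing: the offsets carry information
  return true;      // uniform spacing: equivalent to an even split
}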
void ShardingOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<NormalizeSharding>(context);
}
bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
  // ... (mesh symbol and partial-axes comparisons elided)
  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
                                    getSplitAxes().begin() + minSize),
                   llvm::make_range(rhs.getSplitAxes().begin(),
                                    rhs.getSplitAxes().begin() + minSize))) {
    return false;
  }
  // Any trailing split axes must be empty on both sides.
  return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
                      std::mem_fn(&MeshAxesAttr::empty)) &&
         llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
                      std::mem_fn(&MeshAxesAttr::empty));
}
  return !(*this == rhs);
MeshSharding::MeshSharding(Value rhs) {
  auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
  assert(shardingOp && "expected sharding op");
  auto splitAxes = shardingOp.getSplitAxes().getAxes();
  auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
  // With no split and no partial axes the value is fully replicated.
  if (splitAxes.empty() && partialAxes.empty()) {
    *this = MeshSharding(shardingOp.getMeshAttr());
    return;
  }
  *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
              shardingOp.getPartialType().value_or(ReductionKind::Sum),
              shardingOp.getStaticHaloSizes(),
              shardingOp.getStaticShardedDimsOffsets(),
              shardingOp.getDynamicHaloSizes(),
              shardingOp.getDynamicShardedDimsOffsets());
}
MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
                               ArrayRef<MeshAxesAttr> split_axes_,
                               ArrayRef<MeshAxis> partial_axes_,
                               ReductionKind partial_type_,
                               ArrayRef<int64_t> static_halo_sizes_,
                               ArrayRef<int64_t> static_sharded_dims_offsets_,
                               ArrayRef<Value> dynamic_halo_sizes_,
                               ArrayRef<Value> dynamic_sharded_dims_offsets_) {
  MeshSharding res(mesh_);
  if (split_axes_.empty() && partial_axes_.empty()) {
    return res;
  }

  res.split_axes.resize(split_axes_.size());
  // ... (copy each entry of split_axes_ into res.split_axes)

  auto clone = [](const auto src, auto &dst) {
    dst.resize(src.size());
    llvm::copy(src, dst.begin());
  };

  clone(partial_axes_, res.partial_axes);
  res.partial_type = partial_type_;
  clone(static_halo_sizes_, res.static_halo_sizes);
  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
  clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);

  return res;
}
void ShardShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult()[0], "shard_shape");
}
  build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
        /* ... remaining operands elided ... */);
void ShardOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding_annotated");
}
namespace {
// Fold a mesh.shard that duplicates an equivalent, dominating mesh.shard on
// the same source value.
struct FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
  using OpRewritePattern<ShardOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
    // Get the source of this sharding op.
    Value value = op.getSrc();
    // ...
    // Check other uses of the source value for an equivalent mesh.shard that
    // comes earlier in the block.
    for (auto &use : value.getUses()) {
      if (use.getOwner() != op.getOperation()) {
        auto otherOp = dyn_cast<ShardOp>(use.getOwner());
        if (!otherOp || !otherOp->isBeforeInBlock(op)) {
          continue;
        }
        // ... (currentSharding/otherSharding extracted from the two ops)
        if (currentSharding == otherSharding) {
          // Reuse the other sharding op's result as this op's source.
          op.getSrcMutable().assign(otherOp.getResult());
          return success();
        }
      }
    }
    return failure();
  }
};
} // namespace
void ShardOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldDuplicateShardOp>(context);
}
LogicalResult
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
  if (failed(mesh))
    return failure();

  size_t expectedResultsCount =
      getAxes().empty() ? mesh->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }

  return success();
}
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                MeshOp mesh) {
  build(odsBuilder, odsState,
        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
        mesh.getSymName(), ArrayRef<MeshAxis>());
}

void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                StringRef mesh, ArrayRef<MeshAxis> axes) {
  build(odsBuilder, odsState,
        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
        MeshAxesAttr::get(odsBuilder.getContext(), axes));
}

void ProcessMultiIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "proc_linear_idx");
}
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, MeshOp mesh) {
  build(odsBuilder, odsState, mesh.getSymName());
}

void ProcessLinearIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "proc_linear_idx");
}
void NeighborsLinearIndicesOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getNeighborDown(), "down_linear_idx");
  setNameFn(getNeighborUp(), "up_linear_idx");
}
template <typename Op>
struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    auto meshAxes = op.getMeshAxes();
    if (!meshAxes.empty()) {
      return failure();
    }
    if (op.getInput().getType() != op.getResult().getType()) {
      return failure();
    }
    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
    rewriter.eraseOp(op.getOperation());
    return success();
  }
};
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
                                         ArrayRef<int64_t> device,
                                         Operation::operand_range deviceDynamic,
                                         ArrayRef<MeshAxis> meshAxes,
                                         ArrayRef<int64_t> meshShape) {
  if (device.size() != meshAxes.size()) {
    return emitError(loc) << "In-group device \"" << deviceName
                          << "\" has unexpected multi-index size "
                          << device.size() << ". Expected " << meshAxes.size()
                          << ".";
  }

  for (size_t i = 0; i < device.size(); ++i) {
    if (!ShapedType::isDynamic(device[i]) &&
        !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
        meshShape[meshAxes[i]] <= device[i]) {
      return emitError(loc)
             << "Out of bounds coordinate " << i << " for in-group device \""
             << deviceName << "\"."
             << " Got " << device[i] << ", but expected value in the range [0, "
             << (meshShape[meshAxes[i]] - 1) << "].";
    }
  }
  return success();
}
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  return std::accumulate(begin, end, static_cast<ElementType>(1),
                         std::multiplies<ElementType>());
}

template <typename R>
static auto product(R &&range) {
  return product(adl_begin(range), adl_end(range));
}
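
// Illustrative sketch, not part of the upstream file: the size of a
// collective device group is the product (computed by product() above) of the
// mesh dimensions spanned by the group's mesh axes. On a 2x3x4 mesh, axes
// {0, 2} give a group of 2 * 4 = 8 devices. The function name is hypothetical.
static int64_t groupSizeExample() {
  const int64_t meshShape[] = {2, 3, 4}; // a 2x3x4 device mesh
  const MeshAxis axes[] = {0, 2};        // the group spans mesh axes 0 and 2
  int64_t size = 1;
  for (MeshAxis a : axes)
    size *= meshShape[a];
  return size; // 8
}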
static LogicalResult verifyDimensionCompatibility(Location loc,
                                                  int64_t expectedDimSize,
                                                  int64_t resultDimSize,
                                                  int64_t resultAxis) {
  if (!ShapedType::isDynamic(resultDimSize) &&
      expectedDimSize != resultDimSize) {
    return emitError(loc) << "Dimension size mismatch for result axis "
                          << resultAxis << ". Expected "
                          << (ShapedType::isDynamic(expectedDimSize)
                                  ? Twine("dynamic")
                                  : Twine(expectedDimSize))
                          << ", but got " << resultDimSize << ".";
  }

  return success();
}
static LogicalResult verifyGatherOperandAndResultShape(
    Value operand, Value result, int64_t gatherAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  auto resultRank = cast<ShapedType>(result.getType()).getRank();
  if (gatherAxis < 0 || gatherAxis >= resultRank) {
    return emitError(result.getLoc())
           << "Gather axis " << gatherAxis << " is out of bounds [0, "
           << resultRank << ").";
  }

  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
    auto expectedResultDimSize =
        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
    if (failed(verifyDimensionCompatibility(
            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
      return failure();
    }
  }
  return success();
}
static LogicalResult verifyAllToAllOperandAndResultShape(
    Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  if (splitAxis == concatAxis) {
    return success();
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
  DimensionSize expectedResultConcatDimSize =
      operandConcatDimSize * deviceGroupSize;
  DimensionSize expectedResultSplitDimSize =
      operandSplitDimSize / deviceGroupSize;
  if (!expectedResultSplitDimSize.isDynamic() &&
      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
    expectedResultSplitDimSize = DimensionSize::dynamic();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultConcatDimSize.value(),
          resultType.getDimSize(concatAxis), concatAxis))) {
    return failure();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultSplitDimSize.value(),
          resultType.getDimSize(splitAxis), splitAxis))) {
    return failure();
  }
  return success();
}
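
// Illustrative sketch, not part of the upstream file: the shape arithmetic
// checked above. For a device group of size 4 with splitAxis = 0 and
// concatAxis = 1, an 8x3 operand yields a 2x12 result: the split dimension
// shrinks by the group size and the concat dimension grows by it. The
// function name is hypothetical.
static llvm::SmallVector<int64_t> allToAllShapeExample() {
  const int64_t groupSize = 4;
  llvm::SmallVector<int64_t> shape = {8, 3}; // operand shape
  shape[0] /= groupSize; // split dimension: 8 -> 2
  shape[1] *= groupSize; // concat dimension: 3 -> 12
  return shape;          // expected result shape {2, 12}
}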
static LogicalResult verifyScatterOrSliceOperandAndResultShape(
    Value operand, Value result, int64_t tensorAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if (axis != tensorAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  auto operandScatterDimSize =
      DimensionSize(operandType.getDimSize(tensorAxis));
  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
      int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
    return emitError(result.getLoc())
           << "Operand dimension size " << int64_t(operandScatterDimSize)
           << " is not divisible by collective device group size "
           << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
           << ".";
  }
  DimensionSize expectedResultTensorDimSize =
      operandScatterDimSize / deviceGroupSize;
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultTensorDimSize.value(),
          resultType.getDimSize(tensorAxis), tensorAxis))) {
    return failure();
  }
  return success();
}
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
                                        ArrayRef<MeshAxis> meshAxes,
                                        int64_t sliceAxis) {
  RankedTensorType operandRankedTensorType =
      cast<RankedTensorType>(operandType);
  DimensionSize operandSliceAxisSize =
      operandRankedTensorType.getShape()[sliceAxis];
  SmallVector<int64_t> resultShape =
      llvm::to_vector(operandRankedTensorType.getShape());

  resultShape[sliceAxis] =
      operandSliceAxisSize /
      DimensionSize(collectiveProcessGroupSize(meshAxes, mesh.getShape()));
  return operandRankedTensorType.clone(resultShape);
}
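
// Illustrative sketch, not part of the upstream file: what sliceResultType
// computes. Slicing a tensor<8x16xf32> along sliceAxis = 1 across a device
// group of 4 gives tensor<8x4xf32>; a dynamic or non-divisible extent would
// become dynamic via DimensionSize. The function name is hypothetical.
static llvm::SmallVector<int64_t> sliceShapeExample() {
  llvm::SmallVector<int64_t> shape = {8, 16}; // tensor<8x16xf32>
  const int64_t sliceAxis = 1, groupSize = 4;
  shape[sliceAxis] /= groupSize; // 16 / 4 = 4, giving {8, 4}
  return shape;
}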
LogicalResult AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh))
    return failure();
  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
                                           gatherAxis, getMeshAxes(),
                                           mesh.value().getShape());
}

void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
}

void AllGatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_gather");
}
void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}

void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        Value input, StringRef mesh,
                        ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
        reduction);
}

void AllReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_reduce");
}
LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh))
    return failure();
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
      mesh.value().getShape());
}

void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
                       int64_t sliceAxis) {
  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
        sliceAxis);
}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Type resultType, Value input, StringRef mesh,
                       ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
        APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}

void AllSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_slice");
}
LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh))
    return failure();
  return verifyAllToAllOperandAndResultShape(
      getOperand(), getResult(), getSplitAxis().getSExtValue(),
      getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
}

void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
}

void AllToAllOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_to_all");
}
LogicalResult BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh))
    return failure();
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
}

void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "broadcast");
}
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh))
    return failure();
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }
  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
                                           getMeshAxes(),
                                           mesh.value().getShape());
}

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
}

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh))
    return failure();
  if (getSource() &&
      failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
                                 getSource().value(), getSourceDynamic(),
                                 getMeshAxes(), mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
}

void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "recv");
}
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh))
    return failure();
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce");
}
LogicalResult
ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh))
    return failure();
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
      mesh.value().getShape());
}

void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}

void ReduceScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce_scatter");
}
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh))
    return failure();
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }
  auto scatterAxis = getScatterAxis().getSExtValue();
  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
                                                   scatterAxis, getMeshAxes(),
                                                   mesh.value().getShape());
}

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
}

void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh))
    return failure();
  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
                                 getDestination(), getDestinationDynamic(),
                                 getMeshAxes(), mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
}

void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "send");
}
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh))
    return failure();
  auto meshAxes = getMeshAxes();
  auto shiftAxis = getShiftAxis().getZExtValue();
  if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
    return emitError() << "Invalid shift axis " << shiftAxis
                       << ". It must be one of the grouping mesh axes.";
  }
  return success();
}

void ShiftOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "shift");
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"

#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"