#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "mesh-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
struct DimensionSize {
  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
  DimensionSize(int64_t val) : val(val) {}
  int64_t value() const { return val; }
  operator int64_t() const { return val; }
  bool isDynamic() const { return ShapedType::isDynamic(val); }

private:
  int64_t val;
};
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() / rhs.value();
}

static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() * rhs.value();
}
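// A quick sanity sketch of the semantics above: dynamic extents absorb any
// arithmetic, so callers can divide and multiply shard extents without
// guarding each operand:
//   DimensionSize(8) / DimensionSize(2)          // == 4
//   DimensionSize(8) * DimensionSize::dynamic()  // stays dynamic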
void MeshDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
      >();
  addInterface<MeshInlinerInterface>();
}

Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
                                          FlatSymbolRefAttr meshSymbol,
                                          SymbolTableCollection &symbolTable) {
  mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
  if (!mesh) {
    return op->emitError() << "Undefined required mesh symbol \""
                           << meshSymbol.getValue() << "\".";
  }
  return mesh;
}

template <typename It>
bool isUnique(It begin, It end) {
  if (begin == end) {
    return true;
  }
  It next = std::next(begin);
  if (next == end) {
    return true;
  }
  for (; next != end; ++next, ++begin) {
    if (*begin == *next) {
      return false;
    }
  }
  return true;
}
static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
                                    MeshOp mesh) {
  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
  llvm::sort(sorted);
  if (!isUnique(sorted.begin(), sorted.end())) {
    return emitError(loc) << "Mesh axes contains duplicate elements.";
  }

  MeshAxis rank = mesh.getRank();
  for (auto axis : axes) {
    if (axis >= rank || axis < 0) {
      return emitError(loc)
             << "0-based mesh axis index " << axis
             << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
             << "\" is of rank " << rank << ".";
    }
  }

  return success();
}
template <typename Op>
static FailureOr<MeshOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
  auto mesh =
      ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
  if (failed(mesh) ||
      failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
    return failure();
  }
  return mesh;
}

template <typename InShape, typename MeshShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                       const SplitAxes &splitAxes, OutShape &outShape,
                       ArrayRef<int64_t> shardedDimsOffsets = {},
                       ArrayRef<int64_t> haloSizes = {}) {
  // 0-d tensors cannot be sharded and get replicated.
  if (inShape.empty()) {
    assert(outShape.empty());
    return;
  }

  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));
  if (!shardedDimsOffsets.empty()) {
    auto isDynShape = ShapedType::isDynamicShape(meshShape);
    uint64_t pos = 1;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (!innerSplitAxes.empty()) {
        auto sz = shardedDimsOffsets[pos];
        bool same = !isDynShape;
        if (same) {
          // The result dimension is static only if all shards in
          // shardedDimsOffsets have the same static size.
          uint64_t numShards = 0;
          for (auto i : innerSplitAxes.asArrayRef()) {
            numShards += meshShape[i];
          }
          for (size_t i = 1; i < numShards; ++i) {
            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
                sz) {
              same = false;
              break;
            }
          }
          pos += numShards + 1;
        }
        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
      }
    }
  } else {
    // ... (default even split: each split dim is divided by the collective
    // process group size of its mesh axes)
  }

  if (!haloSizes.empty()) {
    // Add halo sizes to the sharded dims where both halos are static.
    int64_t haloAxis = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
          !innerSplitAxes.empty()) {
        if (haloSizes[haloAxis * 2] >= 0 &&
            haloSizes[haloAxis * 2 + 1] >= 0) {
          outShape[tensorAxis] +=
              haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
          ++haloAxis;
        } else {
          outShape[tensorAxis] = ShapedType::kDynamic;
        }
      }
    }
  }
}
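// Worked example: inShape = [8, 16] on meshShape = [2, 4] with
// splitAxes = [[0]] shards tensor axis 0 across mesh axis 0 (the even split
// itself happens in the elided default path), giving outShape = [4, 16].
// With haloSizes = [1, 2] on that axis the local shard becomes [7, 16]; a
// negative halo entry turns the dimension into kDynamic instead.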
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
                           MeshSharding sharding) {
  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
  SmallVector<Dim> resShapeArr(shape.getShape().size());
  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
             resShapeArr, sharding.getStaticShardedDimsOffsets(),
             sharding.getStaticHaloSizes());
  return shape.clone(resShapeArr);
}

Type shardType(Type type, MeshOp mesh, MeshSharding sharding) {
  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
    return shardShapedType(rankedTensorType, mesh, sharding);
  }
  return type;
}
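// Note: only ranked tensors of rank >= 1 get a sharded shape; 0-d tensors
// and non-shaped types pass through unchanged, since there is no dimension
// to split.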
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                     OpOperand &operand,
                                                     OpBuilder &builder,
                                                     ShardOp &newShardOp) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Value operandValue = operand.get();
  Operation *operandOp = operand.getOwner();
  builder.setInsertionPointAfterValue(operandValue);
  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      !shardOp.getAnnotateForUsers()) {
    // Nothing to do if the correct sharding is already set.
    if (!newShardOp) {
      newShardOp = shardOp;
    }
    return;
  }

  if (!newShardOp) {
    auto shardingOp =
        builder.create<ShardingOp>(operandValue.getLoc(), sharding);
    newShardOp =
        builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                                /*annotate_for_users*/ false);
  }
  IRRewriter rewriter(builder);
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (!shardOp || shardOp.getAnnotateForUsers()) {
    return;
  }

  auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
                                             newShardOp.getSharding(),
                                             /*annotate_for_users*/ true);
  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
}
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                     OpResult result,
                                                     OpBuilder &builder) {
  ShardOp newShardOp;
  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
    maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
  }
}
void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                                     OpOperand &operand,
                                                     OpBuilder &builder) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Value operandValue = operand.get();
  Operation *operandOp = operand.getOwner();
  Operation *operandSrcOp = operandValue.getDefiningOp();
  bool isBlockArg = !operandSrcOp;
  [[maybe_unused]] auto opType =
      dyn_cast<mlir::RankedTensorType>(operandValue.getType());
  // ... (an assert over opType elided)

  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
      /* ... */) {
    return;
  }

  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      shardOp.getAnnotateForUsers()) {
    // Nothing to do if the correct sharding is already set.
    return;
  }

  builder.setInsertionPoint(operandOp);
  auto shardingOp =
      builder.create<ShardingOp>(operandValue.getLoc(), sharding);
  auto newShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users*/ true);
  IRRewriter rewriter(builder);
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
    // No need for resharding.
    return;
  }

  builder.setInsertionPoint(newShardOp);
  auto newPreceedingShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users*/ false);
  rewriter.replaceUsesWithIf(
      newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
        return use.getOwner() == newShardOp.getOperation();
      });
}
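// Illustrative IR for the producer/consumer annotation pair created above
// (names and types are hypothetical):
//   %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
//   %0 = mesh.shard %v to %sharding : tensor<4x8xf32>
//   %1 = mesh.shard %0 to %sharding annotate_for_users : tensor<4x8xf32>
// The producer-side shard op carries annotate_for_users = false, the
// consumer-side one annotate_for_users = true.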
LogicalResult MeshOp::verify() {
  int64_t rank = getRank();

  if (rank <= 0)
    return emitOpError("rank of mesh is expected to be a positive integer");

  for (int64_t dimSize : getShape()) {
    if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
      return emitOpError("dimension size of a mesh is expected to be "
                         "non-negative or dynamic");
  }

  return success();
}
LogicalResult
MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  size_t expectedResultsCount =
      getAxes().empty() ? mesh->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }

  return success();
}
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        MeshOp mesh) {
  build(odsBuilder, odsState, /* ... */);
}

void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        MeshOp mesh, ArrayRef<MeshAxis> axes) {
  assert(!axes.empty());
  build(odsBuilder, odsState, /* ... */);
}

void MeshShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "mesh_shape");
}
void ShardingOp::build(/* ... */) {
  // ...
}

LogicalResult ShardingOp::verify() {
  llvm::SmallSet<MeshAxis, 4> visitedAxes;

  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
    for (MeshAxis axis : axesArray) {
      if (axis < 0)
        return emitError() << "mesh axis is expected to be non-negative";
      if (!visitedAxes.insert(axis).second)
        return emitError() << "mesh axis duplicated";
    }
    return success();
  };

  for (auto subAxes : getSplitAxes().getAxes()) {
    ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
    if (failed(checkMeshAxis(subAxesArray)))
      return failure();
  }

  if (getPartialAxes().has_value() &&
      failed(checkMeshAxis(getPartialAxes().value())))
    return failure();

  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
    return emitOpError("halo sizes and shard offsets are mutually exclusive");
  }

  if (!getStaticHaloSizes().empty()) {
    auto numSplitAxes = getSplitAxes().getAxes().size();
    for (auto splitAxis : getSplitAxes().getAxes()) {
      if (splitAxis.empty()) {
        --numSplitAxes;
      }
    }
    if (getStaticHaloSizes().size() != numSplitAxes * 2) {
      return emitError() << "halo sizes must be specified for all split axes.";
    }
  }

  return success();
}
void ShardingOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding");
}
LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
      getStaticShardedDimsOffsets().size() > 0) {
    return emitError() << "sharded dims offsets are not allowed for "
                          "devices meshes with dynamic shape.";
  }

  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
  if (!shardedDimsOffsets.empty()) {
    auto meshShape = mesh.value().getShape();
    assert(!ShapedType::isDynamicShape(meshShape));
    uint64_t pos = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
      if (!innerSplitAxes.empty()) {
        int64_t numShards = 0, off = 0;
        for (auto i : innerSplitAxes.asArrayRef()) {
          numShards += meshShape[i];
        }
        for (int64_t i = 0; i <= numShards; ++i) {
          if (shardedDimsOffsets.size() <= pos + i) {
            return emitError() << "sharded dims offsets has wrong size.";
          }
          if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
            if (shardedDimsOffsets[pos + i] < off) {
              return emitError()
                     << "sharded dims offsets must be non-decreasing.";
            }
            off = shardedDimsOffsets[pos + i];
          }
        }
        pos += numShards + 1;
      }
    }
  }
  return success();
}
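// Layout being verified: for every non-empty split axis the offsets vector
// carries numShards + 1 non-decreasing entries, e.g. a dim cut into 4 shards
// of sizes 1|2|2|3 is encoded as [0, 1, 3, 5, 8]; pos then advances to the
// leading 0 of the next split dimension.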
class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
public:
  using OpRewritePattern<ShardingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardingOp op,
                                PatternRewriter &b) const override {
    auto mixedHalos =
        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
                                    op.getDynamicShardedDimsOffsets(), b);

    // Fold constant dynamic operands into their static counterparts.
    bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
                    succeeded(foldDynamicIndexList(mixedOffs, true));

    auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
    auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);

    // All-zero halo sizes are the default and can be dropped.
    if (dynamicHalos.empty() && !staticHalos.empty()) {
      if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
        staticHalos.clear();
        modified = true;
      }
    }

    // Equi-distant static offsets describe the default even sharding and can
    // be dropped as well.
    if (dynamicOffs.empty() && !staticOffs.empty()) {
      assert(staticOffs.size() >= 2);
      auto diff = staticOffs[1] - staticOffs[0];
      bool all_same = staticOffs.size() > 2;
      for (auto i = 2u; i < staticOffs.size(); ++i) {
        if (staticOffs[i] - staticOffs[i - 1] != diff) {
          all_same = false;
          break;
        }
      }
      if (all_same) {
        staticOffs.clear();
        modified = true;
      }
    }

    if (!modified) {
      return failure();
    }

    op.setStaticHaloSizes(staticHalos);
    op.getDynamicHaloSizesMutable().assign(dynamicHalos);
    op.setStaticShardedDimsOffsets(staticOffs);
    op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);

    return success();
  }
};
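// Net effect (example): halo_sizes = [0, 0, 0, 0] is dropped entirely, and
// equi-distant offsets such as sharded_dims_offsets = [0, 4, 8, 12, 16] fold
// away to the implicit even sharding, after constant dynamic operands have
// been folded into their static counterparts.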
void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                             mlir::MLIRContext *context) {
  results.add<NormalizeSharding>(context);
}
bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
  // ... (mesh symbol and partial-axes comparisons elided)
  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
                                    getSplitAxes().begin() + minSize),
                   llvm::make_range(rhs.getSplitAxes().begin(),
                                    rhs.getSplitAxes().begin() + minSize))) {
    return false;
  }
  // Any split axes beyond the common prefix must be empty on both sides.
  return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
                      std::mem_fn(&MeshAxesAttr::empty)) &&
         llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
                      std::mem_fn(&MeshAxesAttr::empty));
}
bool MeshSharding::operator!=(const MeshSharding &rhs) const {
  return !(*this == rhs);
}
MeshSharding::MeshSharding(Value rhs) {
  auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
  assert(shardingOp && "expected sharding op");
  auto splitAxes = shardingOp.getSplitAxes().getAxes();
  auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
  // If splitAxes and partialAxes are empty, use the plain mesh constructor.
  if (splitAxes.empty() && partialAxes.empty()) {
    *this = MeshSharding(shardingOp.getMeshAttr());
    return;
  }
  *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
              shardingOp.getPartialType().value_or(ReductionKind::Sum),
              shardingOp.getStaticHaloSizes(),
              shardingOp.getStaticShardedDimsOffsets(),
              /* ... dynamic halos and offsets ... */);
}
MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
                               ArrayRef<MeshAxesAttr> split_axes_,
                               ArrayRef<MeshAxis> partial_axes_,
                               ReductionKind partial_type_,
                               ArrayRef<int64_t> static_halo_sizes_,
                               ArrayRef<int64_t> static_sharded_dims_offsets_,
                               ArrayRef<Value> dynamic_halo_sizes_,
                               ArrayRef<Value> dynamic_sharded_dims_offsets_) {
  MeshSharding res(mesh_);
  if (split_axes_.empty() && partial_axes_.empty()) {
    return res;
  }

  res.split_axes.resize(split_axes_.size());
  // ... (each MeshAxesAttr is copied into res.split_axes)

  auto clone = [](const auto src, auto &dst) {
    dst.resize(src.size());
    llvm::copy(src, dst.begin());
  };

  clone(partial_axes_, res.partial_axes);
  res.partial_type = partial_type_;
  clone(static_halo_sizes_, res.static_halo_sizes);
  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
  clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);

  return res;
}
void ShardShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult()[0], "shard_shape");
}

void ShardShapeOp::build(/* ... */) {
  // ...
  build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
        /* ... */);
}
void ShardOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding_annotated");
}
class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
public:
  using OpRewritePattern<ShardOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
    // Get the use list of the ShardOp's source value.
    Value value = op.getSrc();

    // Check if a preceding ShardOp user annotates the same value.
    for (auto &use : value.getUses()) {
      if (use.getOwner() != op.getOperation()) {
        auto otherOp = dyn_cast<ShardOp>(use.getOwner());
        if (!otherOp || !otherOp->isBeforeInBlock(op)) {
          continue;
        }
        // Compare the shardings of both ops.
        MeshSharding currentSharding(op.getSharding());
        MeshSharding otherSharding(otherOp.getSharding());
        if (currentSharding == otherSharding) {
          b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
          b.eraseOp(op.getOperation());
        } else {
          // Use the other sharding as input for op.
          op.getSrcMutable().assign(otherOp.getResult());
        }
        return success();
      }
    }
    return failure();
  }
};
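// Example: two mesh.shard ops annotating the same source with equal
// shardings collapse into the earlier one; if the shardings differ, the
// later op is rewired to consume the earlier op's result, keeping the
// sharding chain linear.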
void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                          mlir::MLIRContext *context) {
  results.add<FoldDuplicateShardOp>(context);
}
LogicalResult
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  // ...
  size_t expectedResultsCount =
      getAxes().empty() ? mesh->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }

  return success();
}
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                MeshOp mesh) {
  build(odsBuilder, odsState, /* ... */);
}

void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                StringRef mesh, ArrayRef<MeshAxis> axes) {
  build(odsBuilder, odsState, /* ... */);
}

void ProcessMultiIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "proc_linear_idx");
}
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, MeshOp mesh) {
  build(odsBuilder, odsState, mesh.getSymName());
}

void ProcessLinearIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "proc_linear_idx");
}
void NeighborsLinearIndicesOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getNeighborDown(), "down_linear_idx");
  setNameFn(getNeighborUp(), "up_linear_idx");
}
template <typename Op>
struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    auto meshAxes = op.getMeshAxes();
    if (!meshAxes.empty()) {
      return failure();
    }

    if (op.getInput().getType() != op.getResult().getType()) {
      return failure();
    }

    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
    rewriter.eraseOp(op.getOperation());
    return success();
  }
};
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
                                         ArrayRef<int64_t> device,
                                         Operation::operand_range deviceDynamic,
                                         ArrayRef<MeshAxis> meshAxes,
                                         ArrayRef<int64_t> meshShape) {
  if (device.size() != meshAxes.size()) {
    return emitError(loc) << "In-group device \"" << deviceName
                          << "\" has unexpected multi-index size "
                          << device.size() << ". Expected " << meshAxes.size()
                          << ".";
  }

  for (size_t i = 0; i < device.size(); ++i) {
    if (!ShapedType::isDynamic(device[i]) &&
        !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
        meshShape[meshAxes[i]] <= device[i]) {
      return emitError(loc)
             << "Out of bounds coordinate " << i << " for in-group device \""
             << deviceName << "\"."
             << " Got " << device[i] << ", but expected value in the range [0, "
             << (meshShape[meshAxes[i]] - 1) << "].";
    }
  }
  return success();
}
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  return std::accumulate(begin, end, static_cast<ElementType>(1),
                         std::multiplies<ElementType>());
}

template <typename R>
static auto product(R &&range) {
  return product(adl_begin(range), adl_end(range));
}
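// E.g. meshAxes = [0, 2] on a 2x3x4 mesh selects extents {2, 4}, so the
// collective device group has product({2, 4}) == 8 processes; this is what
// collectiveProcessGroupSize computes for the shape checks below.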
static LogicalResult verifyDimensionCompatibility(Location loc,
                                                  int64_t expectedDimSize,
                                                  int64_t resultDimSize,
                                                  int64_t resultAxis) {
  if (!ShapedType::isDynamic(resultDimSize) &&
      expectedDimSize != resultDimSize) {
    return emitError(loc) << "Dimension size mismatch for result axis "
                          << resultAxis << ". Expected "
                          << (ShapedType::isDynamic(expectedDimSize)
                                  ? Twine("dynamic")
                                  : Twine(expectedDimSize))
                          << ", but got " << resultDimSize << ".";
  }

  return success();
}
static LogicalResult verifyGatherOperandAndResultShape(
    Value operand, Value result, int64_t gatherAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  auto resultRank = cast<ShapedType>(result.getType()).getRank();
  if (gatherAxis < 0 || gatherAxis >= resultRank) {
    return emitError(result.getLoc())
           << "Gather axis " << gatherAxis << " is out of bounds [0, "
           << resultRank << ").";
  }

  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
    auto expectedResultDimSize =
        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
    if (failed(verifyDimensionCompatibility(
            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
      return failure();
    }
  }
  return success();
}
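// Example: all-gathering tensor<2x3xi8> along gatherAxis = 1 over a device
// group of size 4 must yield tensor<2x12xi8>; a dynamic result dimension
// always satisfies the per-axis compatibility check.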
static LogicalResult verifyAllToAllOperandAndResultShape(
    Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  if (splitAxis == concatAxis) {
    return success();
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
  DimensionSize expectedResultConcatDimSize =
      operandConcatDimSize * deviceGroupSize;
  DimensionSize expectedResultSplitDimSize =
      operandSplitDimSize / deviceGroupSize;
  if (!expectedResultSplitDimSize.isDynamic() &&
      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
    expectedResultSplitDimSize = DimensionSize::dynamic();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultConcatDimSize.value(),
          resultType.getDimSize(concatAxis), concatAxis))) {
    return failure();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultSplitDimSize.value(),
          resultType.getDimSize(splitAxis), splitAxis))) {
    return failure();
  }
  return success();
}
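// Example: with a device group of size 3, splitAxis = 0 and concatAxis = 1,
// tensor<6x2xi8> must map to tensor<2x6xi8>: the split dim shrinks by the
// group size while the concat dim grows by it. A split dim not divisible by
// the group size is only accepted if the result dim is dynamic.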
static LogicalResult verifyScatterOrSliceOperandAndResultShape(
    Value operand, Value result, int64_t tensorAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if (axis != tensorAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  auto operandScatterDimSize =
      DimensionSize(operandType.getDimSize(tensorAxis));
  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
      int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
    return emitError(result.getLoc())
           << "Operand dimension size " << int64_t(operandScatterDimSize)
           << " is not divisible by collective device group size "
           << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
           << ".";
  }
  DimensionSize expectedResultTensorDimSize =
      operandScatterDimSize / deviceGroupSize;
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultTensorDimSize.value(),
          resultType.getDimSize(tensorAxis), tensorAxis))) {
    return failure();
  }
  return success();
}
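// Example: slicing tensor<8x4xi8> on tensorAxis = 0 across a group of 4
// devices yields tensor<2x4xi8>; tensor<9x4xi8> is rejected outright since
// 9 is not divisible by the group size.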
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
                                        ArrayRef<MeshAxis> meshAxes,
                                        int64_t sliceAxis) {
  RankedTensorType operandRankedTensorType =
      cast<RankedTensorType>(operandType);
  DimensionSize operandSliceAxisSize =
      operandRankedTensorType.getShape()[sliceAxis];
  SmallVector<int64_t> resultShape =
      llvm::to_vector(operandRankedTensorType.getShape());

  resultShape[sliceAxis] =
      operandSliceAxisSize /
      DimensionSize(collectiveProcessGroupSize(meshAxes, mesh.getShape()));
  return operandRankedTensorType.clone(resultShape);
}
LogicalResult
AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
                                           gatherAxis, getMeshAxes(),
                                           mesh.value().getShape());
}

void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
}

void AllGatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_gather");
}
void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}

void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        Value input, StringRef mesh,
                        ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
        reduction);
}

void AllReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_reduce");
}
LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
      mesh.value().getShape());
}

void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
                       int64_t sliceAxis) {
  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
        sliceAxis);
}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Type resultType, Value input, StringRef mesh,
                       ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
        APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}

void AllSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_slice");
}
LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  return verifyAllToAllOperandAndResultShape(
      getOperand(), getResult(), getSplitAxis().getSExtValue(),
      getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
}

void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
}

void AllToAllOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_to_all");
}
LogicalResult BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }

  return success();
}

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
}

void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "broadcast");
}
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }
  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
                                           getMeshAxes(),
                                           mesh.value().getShape());
}

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
}

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  if (getSource() &&
      failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
                                 getSource().value(), getSourceDynamic(),
                                 getMeshAxes(), mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
}

void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "recv");
}
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }

  return success();
}

void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce");
}
LogicalResult
ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
      mesh.value().getShape());
}

void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}

void ReduceScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce_scatter");
}
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
    return failure();
  }
  auto scatterAxis = getScatterAxis().getSExtValue();
  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
                                                   scatterAxis, getMeshAxes(),
                                                   mesh.value().getShape());
}

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
}

void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
                                 getDestination(), getDestinationDynamic(),
                                 getMeshAxes(), mesh.value().getShape()))) {
    return failure();
  }
  return success();
}

void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
}

void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "send");
}
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  auto meshAxes = getMeshAxes();
  auto shiftAxis = getShiftAxis().getZExtValue();
  if (!llvm::is_contained(meshAxes, shiftAxis)) {
    return emitError() << "Invalid shift axis " << shiftAxis
                       << ". It must be one of the grouping mesh axes.";
  }

  return success();
}

void ShiftOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "shift");
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"

#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"