28 #include "llvm/ADT/ArrayRef.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallSet.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Casting.h"
41 #define DEBUG_TYPE "mesh-ops"
42 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
47 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
51 struct DimensionSize {
52 static DimensionSize dynamic() {
return DimensionSize(ShapedType::kDynamic); }
53 DimensionSize(int64_t val) : val(val) {}
54 int64_t value()
const {
return val; }
55 operator int64_t()
const {
return val; }
56 bool isDynamic()
const {
return ShapedType::isDynamic(val); }
64 static DimensionSize
operator/(DimensionSize lhs, DimensionSize rhs) {
65 if (lhs.isDynamic() || rhs.isDynamic()) {
66 return DimensionSize::dynamic();
68 return lhs.value() / rhs.value();
71 static DimensionSize
operator*(DimensionSize lhs, DimensionSize rhs) {
72 if (lhs.isDynamic() || rhs.isDynamic()) {
73 return DimensionSize::dynamic();
75 return lhs.value() * rhs.value();
102 void MeshDialect::initialize() {
105 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
108 #define GET_ATTRDEF_LIST
109 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
112 #define GET_TYPEDEF_LIST
113 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
115 addInterface<MeshInlinerInterface>();
120 return arith::ConstantOp::materialize(builder, value, type, loc);
132 return op->
emitError() <<
"Undefined required mesh symbol \""
/// Returns true iff no two *adjacent* elements of [begin, end) compare equal.
/// For a sorted range (as used by verifyMeshAxes) this means all elements are
/// unique. Empty and single-element ranges are trivially unique.
template <typename It>
bool isUnique(It begin, It end) {
  return std::adjacent_find(begin, end) == end;
}
160 if (!
isUnique(sorted.begin(), sorted.end())) {
161 return emitError(loc) <<
"Mesh axes contains duplicate elements.";
165 for (
auto axis : axes) {
166 if (axis >= rank || axis < 0) {
168 <<
"0-based mesh axis index " << axis
169 <<
" is out of bounds. The referenced mesh \"" << mesh.getSymName()
170 <<
"\" is of rank " << rank <<
".";
177 template <
typename Op>
178 static FailureOr<MeshOp>
191 template <
typename InShape,
typename MeshShape,
typename SplitAxes,
193 static void shardShape(
const InShape &inShape,
const MeshShape &meshShape,
194 const SplitAxes &splitAxes, OutShape &outShape,
197 std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
198 llvm::adl_begin(outShape));
200 if (!shardedDimsOffsets.empty()) {
201 auto isDynShape = ShapedType::isDynamicShape(meshShape);
204 if (!innerSplitAxes.empty()) {
205 auto sz = shardedDimsOffsets[pos];
206 bool same = !isDynShape;
211 uint64_t numShards = 0;
212 for (
auto i : innerSplitAxes.asArrayRef()) {
213 numShards += meshShape[i];
215 for (
size_t i = 1; i < numShards; ++i) {
216 if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
222 pos += numShards + 1;
224 outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
234 if (!haloSizes.empty()) {
238 if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
239 !innerSplitAxes.empty()) {
240 if (haloSizes[haloAxis * 2] >= 0 &&
241 haloSizes[haloAxis * 2 + 1] >= 0) {
242 outShape[tensorAxis] +=
243 haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
246 outShape[tensorAxis] = ShapedType::kDynamic;
256 using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
261 return shape.clone(resShapeArr);
265 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
266 if (rankedTensorType) {
279 ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
280 if (shardOp && sharding == shardOp.getSharding() &&
281 !shardOp.getAnnotateForUsers()) {
286 auto shardingOp = builder.
create<ShardingOp>(operandValue.
getLoc(), sharding);
288 builder.
create<ShardOp>(operandValue.
getLoc(), operandValue, shardingOp,
291 rewriter.replaceUsesWithIf(
292 operandValue, newShardOp, [operandOp, operandValue](
OpOperand &use) {
293 return use.
getOwner() == operandOp && use.
get() == operandValue;
296 if (!shardOp || shardOp.getAnnotateForUsers()) {
301 builder.
create<ShardOp>(operandValue.
getLoc(), newShardOp, shardingOp,
303 rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
309 for (
auto &use : llvm::make_early_inc_range(result.
getUses())) {
321 bool isBlockArg = !operandSrcOp;
322 ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
324 if (shardOp && sharding == shardOp.getSharding() &&
325 shardOp.getAnnotateForUsers()) {
334 builder.
create<ShardOp>(operandValue.
getLoc(), operandValue, shardingOp,
337 rewriter.replaceUsesWithIf(
338 operandValue, newShardOp, [operandOp, operandValue](
OpOperand &use) {
339 return use.
getOwner() == operandOp && use.
get() == operandValue;
342 if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
348 auto newPreceedingShardOp =
349 builder.
create<ShardOp>(operandValue.
getLoc(), operandValue, shardingOp,
352 newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](
OpOperand &use) {
353 return use.
getOwner() == newShardOp.getOperation();
362 int64_t rank = getRank();
365 return emitOpError(
"rank of mesh is expected to be a positive integer");
367 for (int64_t dimSize :
getShape()) {
368 if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
369 return emitOpError(
"dimension size of a mesh is expected to be "
370 "non-negative or dynamic");
390 size_t expectedResultsCount =
391 getAxes().empty() ? mesh->getRank() : getAxes().size();
392 if (getResult().size() != expectedResultsCount) {
393 return emitError() <<
"Unexpected number of results " << getResult().size()
394 <<
". Expected " << expectedResultsCount <<
".";
407 build(odsBuilder, odsState,
415 assert(!axes.empty());
416 build(odsBuilder, odsState,
421 void MeshShapeOp::getAsmResultNames(
423 setNameFn(getResults()[0],
"mesh_shape");
443 static_sharded_dims_offsets),
456 void ShardingOp::build(
493 llvm::SmallSet<MeshAxis, 4> visitedAxes;
498 return emitError() <<
"mesh axis is expected to be non-negative";
499 if (!visitedAxes.insert(axis).second)
500 return emitError() <<
"mesh axis duplicated";
505 for (
auto subAxes : getSplitAxes().getAxes()) {
507 if (failed(checkMeshAxis(subAxesArray)))
510 if (getPartialAxes().has_value() &&
511 failed(checkMeshAxis(getPartialAxes().value())))
514 if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
515 return emitOpError(
"halo sizes and shard offsets are mutually exclusive");
518 if (!getStaticHaloSizes().empty()) {
519 auto numSplitAxes = getSplitAxes().getAxes().size();
520 for (
auto splitAxis : getSplitAxes().getAxes()) {
521 if (splitAxis.empty()) {
525 if (getStaticHaloSizes().size() != numSplitAxes * 2) {
526 return emitError() <<
"halo sizes must be specified for all split axes.";
533 void ShardingOp::getAsmResultNames(
535 setNameFn(getResult(),
"sharding");
543 if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
544 getStaticShardedDimsOffsets().size() > 0) {
545 return emitError() <<
"sharded dims offsets are not allowed for "
546 "devices meshes with dynamic shape.";
549 auto shardedDimsOffsets = getStaticShardedDimsOffsets();
550 if (!shardedDimsOffsets.empty()) {
551 auto meshShape = mesh.value().getShape();
552 assert(!ShapedType::isDynamicShape(meshShape));
554 for (
auto [tensorAxis, innerSplitAxes] :
llvm::enumerate(getSplitAxes())) {
555 if (!innerSplitAxes.empty()) {
556 int64_t numShards = 0, off = 0;
557 for (
auto i : innerSplitAxes.asArrayRef()) {
558 numShards += meshShape[i];
560 for (int64_t i = 0; i <= numShards; ++i) {
561 if (shardedDimsOffsets.size() <= pos + i) {
562 return emitError() <<
"sharded dims offsets has wrong size.";
564 if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
565 if (shardedDimsOffsets[pos + i] < off) {
567 <<
"sharded dims offsets must be non-decreasing.";
569 off = shardedDimsOffsets[pos + i];
572 pos += numShards + 1;
588 LogicalResult matchAndRewrite(ShardingOp op,
591 getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
593 op.getDynamicShardedDimsOffsets(), b);
604 op.setStaticHaloSizes(halos.first);
605 op.getDynamicHaloSizesMutable().assign(halos.second);
606 op.setStaticShardedDimsOffsets(offs.first);
607 op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
616 results.
add<FoldDynamicLists>(context);
638 if (!llvm::equal(llvm::make_range(
getSplitAxes().begin(),
645 return llvm::all_of(llvm::make_range(
getSplitAxes().begin() + minSize,
647 std::mem_fn(&MeshAxesAttr::empty)) &&
648 llvm::all_of(llvm::make_range(rhs.
getSplitAxes().begin() + minSize,
650 std::mem_fn(&MeshAxesAttr::empty));
707 return !(*
this == rhs);
711 auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.
getDefiningOp());
712 assert(shardingOp &&
"expected sharding op");
713 *
this =
get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
715 shardingOp.getPartialType().value_or(ReductionKind::Sum),
716 shardingOp.getStaticHaloSizes(),
717 shardingOp.getStaticShardedDimsOffsets(),
732 res.split_axes.resize(split_axes_.size());
738 auto clone = [](
const auto src,
auto &dst) {
739 dst.resize(src.size());
743 clone(partial_axes_, res.partial_axes);
744 res.partial_type = partial_type_;
745 clone(static_halo_sizes_, res.static_halo_sizes);
746 clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
747 clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
748 clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
762 build(odsBuilder, odsState, resType, shape, sharding, device);
769 void ShardOp::getAsmResultNames(
771 setNameFn(getResult(),
"sharding_annotated");
788 size_t expectedResultsCount =
789 getAxes().empty() ? mesh->getRank() : getAxes().size();
790 if (getResult().size() != expectedResultsCount) {
791 return emitError() <<
"Unexpected number of results " << getResult().size()
792 <<
". Expected " << expectedResultsCount <<
".";
800 build(odsBuilder, odsState,
807 build(odsBuilder, odsState,
812 void ProcessMultiIndexOp::getAsmResultNames(
814 setNameFn(getResults()[0],
"proc_linear_idx");
830 void ProcessLinearIndexOp::build(
OpBuilder &odsBuilder,
832 build(odsBuilder, odsState, mesh.getSymName());
835 void ProcessLinearIndexOp::getAsmResultNames(
837 setNameFn(getResult(),
"proc_linear_idx");
846 template <
typename Op>
849 LogicalResult matchAndRewrite(
Op op,
851 auto meshAxes = op.getMeshAxes();
852 if (!meshAxes.empty()) {
855 if (op.getInput().getType() != op.getResult().getType()) {
872 if (device.size() != meshAxes.size()) {
873 return emitError(loc) <<
"In-group device \"" << deviceName
874 <<
"\" has unexpected multi-index size "
875 << device.size() <<
". Expected " << meshAxes.size()
879 for (
size_t i = 0; i < device.size(); ++i) {
880 if (!ShapedType::isDynamic(device[i]) &&
881 !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
882 meshShape[meshAxes[i]] <= device[i]) {
884 <<
"Out of bounds coordinate " << i <<
" for in-group device \""
885 << deviceName <<
"\"."
886 <<
" Got " << device[i] <<
", but expected value in the range [0, "
887 << (meshShape[meshAxes[i]] - 1) <<
"].";
/// Multiplies all elements of [begin, end), returning the multiplicative
/// identity (1) for an empty range. The accumulator type is the decayed
/// element type of the iterator, so integer ranges fold in their own type.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  ElementType result = static_cast<ElementType>(1);
  for (; begin != end; ++begin)
    result = result * *begin;
  return result;
}
/// Range overload: folds the whole range with the iterator-pair overload,
/// using ADL-found begin/end so it works with both containers and ArrayRef.
template <typename R>
static auto product(R &&range) {
  return product(adl_begin(range), adl_end(range));
}
906 int64_t expectedDimSize,
907 int64_t resultDimSize,
908 int64_t resultAxis) {
909 if (!ShapedType::isDynamic(resultDimSize) &&
910 expectedDimSize != resultDimSize) {
911 return emitError(loc) <<
"Dimension size mismatch for result axis "
912 << resultAxis <<
". Expected "
913 << (ShapedType::isDynamic(expectedDimSize)
915 : Twine(expectedDimSize))
916 <<
", but got " << resultDimSize <<
".";
923 Value operand,
Value result, int64_t gatherAxis,
925 auto resultRank = cast<ShapedType>(result.
getType()).getRank();
926 if (gatherAxis < 0 || gatherAxis >= resultRank) {
928 <<
"Gather axis " << gatherAxis <<
" is out of bounds [0, "
929 << resultRank <<
").";
932 ShapedType operandType = cast<ShapedType>(operand.
getType());
933 ShapedType resultType = cast<ShapedType>(result.
getType());
934 auto deviceGroupSize =
936 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
937 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
938 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
939 auto expectedResultDimSize =
940 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
942 result.
getLoc(), expectedResultDimSize, resultDimSize, axis))) {
950 Value operand,
Value result, int64_t splitAxis, int64_t concatAxis,
952 ShapedType operandType = cast<ShapedType>(operand.
getType());
953 ShapedType resultType = cast<ShapedType>(result.
getType());
954 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
955 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
957 result.
getLoc(), operandType.getDimSize(axis),
958 resultType.getDimSize(axis), axis))) {
964 if (splitAxis == concatAxis) {
968 auto deviceGroupSize =
970 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
971 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
972 DimensionSize expectedResultConcatDimSize =
973 operandConcatDimSize * deviceGroupSize;
974 DimensionSize expectedResultSplitDimSize =
975 operandSplitDimSize / deviceGroupSize;
976 if (!expectedResultSplitDimSize.isDynamic() &&
977 int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
978 expectedResultSplitDimSize = DimensionSize::dynamic();
981 result.
getLoc(), expectedResultConcatDimSize.value(),
982 resultType.getDimSize(concatAxis), concatAxis))) {
986 result.
getLoc(), expectedResultSplitDimSize.value(),
987 resultType.getDimSize(splitAxis), splitAxis))) {
995 Value operand,
Value result, int64_t tensorAxis,
997 ShapedType operandType = cast<ShapedType>(operand.
getType());
998 ShapedType resultType = cast<ShapedType>(result.
getType());
999 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1000 if (axis != tensorAxis) {
1002 result.
getLoc(), operandType.getDimSize(axis),
1003 resultType.getDimSize(axis), axis))) {
1009 auto deviceGroupSize =
1011 auto operandScatterDimSize =
1012 DimensionSize(operandType.getDimSize(tensorAxis));
1013 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1014 int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1016 <<
"Operand dimension size " << int64_t(operandScatterDimSize)
1017 <<
" is not divisible by collective device group size "
1018 << int64_t(deviceGroupSize) <<
" for tensor axis " << tensorAxis
1021 DimensionSize expectedResultTensorDimSize =
1022 operandScatterDimSize / deviceGroupSize;
1024 result.
getLoc(), expectedResultTensorDimSize.value(),
1025 resultType.getDimSize(tensorAxis), tensorAxis))) {
1034 int64_t sliceAxis) {
1035 RankedTensorType operandRankedTensorType =
1036 cast<RankedTensorType>(operandType);
1037 DimensionSize operandSliceAxisSize =
1038 operandRankedTensorType.getShape()[sliceAxis];
1040 llvm::to_vector(operandRankedTensorType.getShape());
1042 resultShape[sliceAxis] =
1043 operandSliceAxisSize /
1045 return operandRankedTensorType.clone(resultShape);
1058 auto gatherAxis = getGatherAxis().getSExtValue();
1060 gatherAxis, getMeshAxes(),
1061 mesh.value().getShape());
1066 patterns.
add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
1069 void AllGatherOp::getAsmResultNames(
1071 setNameFn(getResult(),
"all_gather");
1085 patterns.
add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
1089 Value input, StringRef mesh,
1091 build(odsBuilder, odsState, input.
getType(), mesh, meshAxes, input,
1095 void AllReduceOp::getAsmResultNames(
1097 setNameFn(getResult(),
"all_reduce");
1110 getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
1111 mesh.value().getShape());
1116 patterns.
add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
1121 int64_t sliceAxis) {
1123 build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
1128 Type resultType,
Value input, StringRef mesh,
1130 build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
1131 APInt(
sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1134 void AllSliceOp::getAsmResultNames(
1136 setNameFn(getResult(),
"all_slice");
1150 getOperand(), getResult(), getSplitAxis().getSExtValue(),
1151 getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
1156 patterns.
add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
1159 void AllToAllOp::getAsmResultNames(
1161 setNameFn(getResult(),
"all_to_all");
1175 getRootDynamic(), getMeshAxes(),
1176 mesh.value().getShape()))) {
1185 patterns.
add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
1188 void BroadcastOp::getAsmResultNames(
1190 setNameFn(getResult(),
"broadcast");
1203 getRootDynamic(), getMeshAxes(),
1204 mesh.value().getShape()))) {
1208 auto gatherAxis = getGatherAxis().getSExtValue();
1211 mesh.value().getShape());
1216 patterns.
add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
1219 void GatherOp::getAsmResultNames(
1221 setNameFn(getResult(),
"gather");
1235 getSource().value(), getSourceDynamic(),
1236 getMeshAxes(), mesh.value().getShape()))) {
1244 patterns.
add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
1247 void RecvOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
1248 setNameFn(getResult(),
"recv");
1261 getRootDynamic(), getMeshAxes(),
1262 mesh.value().getShape()))) {
1271 patterns.
add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
1274 void ReduceOp::getAsmResultNames(
1276 setNameFn(getResult(),
"reduce");
1291 getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
1292 mesh.value().getShape());
1297 patterns.
add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1300 void ReduceScatterOp::getAsmResultNames(
1302 setNameFn(getResult(),
"reduce_scatter");
1315 getRootDynamic(), getMeshAxes(),
1316 mesh.value().getShape()))) {
1320 auto scatterAxis = getScatterAxis().getSExtValue();
1322 scatterAxis, getMeshAxes(),
1323 mesh.value().getShape());
1328 patterns.
add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
1331 void ScatterOp::getAsmResultNames(
1333 setNameFn(getResult(),
"scatter");
1346 getDestination(), getDestinationDynamic(),
1347 getMeshAxes(), mesh.value().getShape()))) {
1355 patterns.
add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
1358 void SendOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
1359 setNameFn(getResult(),
"send");
1372 auto meshAxes = getMeshAxes();
1373 auto shiftAxis = getShiftAxis().getZExtValue();
1374 if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
1375 return emitError() <<
"Invalid shift axis " << shiftAxis
1376 <<
". It must be one of the grouping mesh axes.";
1388 void ShiftOp::getAsmResultNames(
1390 setNameFn(getResult(),
"shift");
1411 #define GET_OP_CLASSES
1412 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
1414 #define GET_ATTRDEF_CLASSES
1415 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
1417 #define GET_TYPEDEF_CLASSES
1418 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
1420 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static FailureOr< MeshOp > getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
static FailureOr< MeshOp > getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
static auto product(It begin, It end)
bool isUnique(It begin, It end)
static LogicalResult verifyMeshAxes(Location loc, ArrayRef< MeshAxis > axes, MeshOp mesh)
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
DenseI16ArrayAttr getDenseI16ArrayAttr(ArrayRef< int16_t > values)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this operation with the provided values if the given callback returns true...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Operation * getOwner() const
Return the owner of this operand.
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
bool equalHaloAndShardSizes(const MeshSharding &rhs) const
::mlir::FlatSymbolRefAttr getMeshAttr() const
bool equalHaloSizes(const MeshSharding &rhs) const
ArrayRef< MeshAxesAttr > getSplitAxes() const
bool operator!=(Value rhs) const
ReductionKind getPartialType() const
ArrayRef< Value > getDynamicShardedDimsOffsets() const
bool operator==(Value rhs) const
ArrayRef< MeshAxis > getPartialAxes() const
ArrayRef< Value > getDynamicHaloSizes() const
::llvm::StringRef getMesh() const
ArrayRef< int64_t > getStaticHaloSizes() const
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalShardSizes(const MeshSharding &rhs) const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
mesh::ReductionKind ReductionKind
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineExpr operator*(int64_t val, AffineExpr expr)
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.