28 #include "llvm/ADT/ArrayRef.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallSet.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Casting.h"
41 #define DEBUG_TYPE "mesh-ops"
42 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
47 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
51 struct DimensionSize {
52 static DimensionSize dynamic() {
return DimensionSize(ShapedType::kDynamic); }
53 DimensionSize(int64_t val) : val(val) {}
54 int64_t value()
const {
return val; }
55 operator int64_t()
const {
return val; }
56 bool isDynamic()
const {
return ShapedType::isDynamic(val); }
64 static DimensionSize
operator/(DimensionSize lhs, DimensionSize rhs) {
65 if (lhs.isDynamic() || rhs.isDynamic()) {
66 return DimensionSize::dynamic();
68 return lhs.value() / rhs.value();
71 static DimensionSize
operator*(DimensionSize lhs, DimensionSize rhs) {
72 if (lhs.isDynamic() || rhs.isDynamic()) {
73 return DimensionSize::dynamic();
75 return lhs.value() * rhs.value();
102 void MeshDialect::initialize() {
105 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
108 #define GET_ATTRDEF_LIST
109 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
112 #define GET_TYPEDEF_LIST
113 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
115 addInterface<MeshInlinerInterface>();
120 return arith::ConstantOp::materialize(builder, value, type, loc);
132 return op->
emitError() <<
"Undefined required mesh symbol \""
139 template <
typename It>
144 It next = std::next(begin);
148 for (; next != end; ++next, ++begin) {
149 if (*begin == *next) {
160 if (!
isUnique(sorted.begin(), sorted.end())) {
161 return emitError(loc) <<
"Mesh axes contains duplicate elements.";
165 for (
auto axis : axes) {
166 if (axis >= rank || axis < 0) {
168 <<
"0-based mesh axis index " << axis
169 <<
" is out of bounds. The referenced mesh \"" << mesh.getSymName()
170 <<
"\" is of rank " << rank <<
".";
177 template <
typename Op>
178 static FailureOr<MeshOp>
191 template <
typename InShape,
typename MeshShape,
typename SplitAxes,
193 static void shardShape(
const InShape &inShape,
const MeshShape &meshShape,
194 const SplitAxes &splitAxes, OutShape &outShape,
197 std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
198 llvm::adl_begin(outShape));
200 if (!shardedDimsOffsets.empty()) {
201 auto isDynShape = ShapedType::isDynamicShape(meshShape);
204 if (!innerSplitAxes.empty()) {
205 auto sz = shardedDimsOffsets[pos];
206 bool same = !isDynShape;
211 uint64_t numShards = 0;
212 for (
auto i : innerSplitAxes.asArrayRef()) {
213 numShards += meshShape[i];
215 for (
size_t i = 1; i < numShards; ++i) {
216 if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
222 pos += numShards + 1;
224 outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
234 if (!haloSizes.empty()) {
238 if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
239 !innerSplitAxes.empty()) {
240 if (haloSizes[haloAxis * 2] >= 0 &&
241 haloSizes[haloAxis * 2 + 1] >= 0) {
242 outShape[tensorAxis] +=
243 haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
246 outShape[tensorAxis] = ShapedType::kDynamic;
256 using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
261 return shape.clone(resShapeArr);
265 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
266 if (rankedTensorType) {
279 ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
280 if (shardOp && sharding == shardOp.getSharding() &&
281 !shardOp.getAnnotateForUsers()) {
286 auto shardingOp = builder.
create<ShardingOp>(operandValue.
getLoc(), sharding);
288 builder.
create<ShardOp>(operandValue.
getLoc(), operandValue, shardingOp,
291 rewriter.replaceUsesWithIf(
292 operandValue, newShardOp, [operandOp, operandValue](
OpOperand &use) {
293 return use.
getOwner() == operandOp && use.
get() == operandValue;
296 if (!shardOp || shardOp.getAnnotateForUsers()) {
301 builder.
create<ShardOp>(operandValue.
getLoc(), newShardOp, shardingOp,
303 rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
309 for (
auto &use : llvm::make_early_inc_range(result.
getUses())) {
321 bool isBlockArg = !operandSrcOp;
322 ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
324 if (shardOp && sharding == shardOp.getSharding() &&
325 shardOp.getAnnotateForUsers()) {
334 builder.
create<ShardOp>(operandValue.
getLoc(), operandValue, shardingOp,
337 rewriter.replaceUsesWithIf(
338 operandValue, newShardOp, [operandOp, operandValue](
OpOperand &use) {
339 return use.
getOwner() == operandOp && use.
get() == operandValue;
342 if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
348 auto newPreceedingShardOp =
349 builder.
create<ShardOp>(operandValue.
getLoc(), operandValue, shardingOp,
352 newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](
OpOperand &use) {
353 return use.
getOwner() == newShardOp.getOperation();
362 int64_t rank = getRank();
365 return emitOpError(
"rank of mesh is expected to be a positive integer");
367 for (int64_t dimSize :
getShape()) {
368 if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
369 return emitOpError(
"dimension size of a mesh is expected to be "
370 "non-negative or dynamic");
390 size_t expectedResultsCount =
391 getAxes().empty() ? mesh->getRank() : getAxes().size();
392 if (getResult().size() != expectedResultsCount) {
393 return emitError() <<
"Unexpected number of results " << getResult().size()
394 <<
". Expected " << expectedResultsCount <<
".";
407 build(odsBuilder, odsState,
415 assert(!axes.empty());
416 build(odsBuilder, odsState,
421 void MeshShapeOp::getAsmResultNames(
423 setNameFn(getResults()[0],
"mesh_shape");
443 static_sharded_dims_offsets),
456 void ShardingOp::build(
493 llvm::SmallSet<MeshAxis, 4> visitedAxes;
498 return emitError() <<
"mesh axis is expected to be non-negative";
499 if (!visitedAxes.insert(axis).second)
500 return emitError() <<
"mesh axis duplicated";
505 for (
auto subAxes : getSplitAxes().getAxes()) {
507 if (failed(checkMeshAxis(subAxesArray)))
510 if (getPartialAxes().has_value() &&
511 failed(checkMeshAxis(getPartialAxes().value())))
514 if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
515 return emitOpError(
"halo sizes and shard offsets are mutually exclusive");
518 if (!getStaticHaloSizes().empty()) {
519 auto numSplitAxes = getSplitAxes().getAxes().size();
520 for (
auto splitAxis : getSplitAxes().getAxes()) {
521 if (splitAxis.empty()) {
525 if (getStaticHaloSizes().size() != numSplitAxes * 2) {
526 return emitError() <<
"halo sizes must be specified for all split axes.";
533 void ShardingOp::getAsmResultNames(
535 setNameFn(getResult(),
"sharding");
543 if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
544 getStaticShardedDimsOffsets().size() > 0) {
545 return emitError() <<
"sharded dims offsets are not allowed for "
546 "devices meshes with dynamic shape.";
549 auto shardedDimsOffsets = getStaticShardedDimsOffsets();
550 if (!shardedDimsOffsets.empty()) {
551 auto meshShape = mesh.value().getShape();
552 assert(!ShapedType::isDynamicShape(meshShape));
554 for (
auto [tensorAxis, innerSplitAxes] :
llvm::enumerate(getSplitAxes())) {
555 if (!innerSplitAxes.empty()) {
556 int64_t numShards = 0, off = 0;
557 for (
auto i : innerSplitAxes.asArrayRef()) {
558 numShards += meshShape[i];
560 for (int64_t i = 0; i <= numShards; ++i) {
561 if (shardedDimsOffsets.size() <= pos + i) {
562 return emitError() <<
"sharded dims offsets has wrong size.";
564 if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
565 if (shardedDimsOffsets[pos + i] < off) {
567 <<
"sharded dims offsets must be non-decreasing.";
569 off = shardedDimsOffsets[pos + i];
572 pos += numShards + 1;
588 LogicalResult matchAndRewrite(ShardingOp op,
591 getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
593 op.getDynamicShardedDimsOffsets(), b);
604 op.setStaticHaloSizes(halos.first);
605 op.getDynamicHaloSizesMutable().assign(halos.second);
606 op.setStaticShardedDimsOffsets(offs.first);
607 op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
616 results.
add<FoldDynamicLists>(context);
638 if (!llvm::equal(llvm::make_range(
getSplitAxes().begin(),
645 return llvm::all_of(llvm::make_range(
getSplitAxes().begin() + minSize,
647 std::mem_fn(&MeshAxesAttr::empty)) &&
648 llvm::all_of(llvm::make_range(rhs.
getSplitAxes().begin() + minSize,
650 std::mem_fn(&MeshAxesAttr::empty));
707 return !(*
this == rhs);
711 auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.
getDefiningOp());
712 assert(shardingOp &&
"expected sharding op");
713 *
this =
get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
715 shardingOp.getPartialType().value_or(ReductionKind::Sum),
716 shardingOp.getStaticHaloSizes(),
717 shardingOp.getStaticShardedDimsOffsets(),
732 res.split_axes.resize(split_axes_.size());
738 auto clone = [](
const auto src,
auto &dst) {
739 dst.resize(src.size());
743 clone(partial_axes_, res.partial_axes);
744 res.partial_type = partial_type_;
745 clone(static_halo_sizes_, res.static_halo_sizes);
746 clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
747 clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
748 clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
762 build(odsBuilder, odsState, resType, shape, sharding, device);
769 void ShardOp::getAsmResultNames(
771 setNameFn(getResult(),
"sharding_annotated");
788 size_t expectedResultsCount =
789 getAxes().empty() ? mesh->getRank() : getAxes().size();
790 if (getResult().size() != expectedResultsCount) {
791 return emitError() <<
"Unexpected number of results " << getResult().size()
792 <<
". Expected " << expectedResultsCount <<
".";
800 build(odsBuilder, odsState,
807 build(odsBuilder, odsState,
812 void ProcessMultiIndexOp::getAsmResultNames(
814 setNameFn(getResults()[0],
"proc_linear_idx");
830 void ProcessLinearIndexOp::build(
OpBuilder &odsBuilder,
832 build(odsBuilder, odsState, mesh.getSymName());
835 void ProcessLinearIndexOp::getAsmResultNames(
837 setNameFn(getResult(),
"proc_linear_idx");
853 void NeighborsLinearIndicesOp::getAsmResultNames(
855 setNameFn(getNeighborDown(),
"down_linear_idx");
856 setNameFn(getNeighborUp(),
"up_linear_idx");
865 template <
typename Op>
868 LogicalResult matchAndRewrite(
Op op,
870 auto meshAxes = op.getMeshAxes();
871 if (!meshAxes.empty()) {
874 if (op.getInput().getType() != op.getResult().getType()) {
891 if (device.size() != meshAxes.size()) {
892 return emitError(loc) <<
"In-group device \"" << deviceName
893 <<
"\" has unexpected multi-index size "
894 << device.size() <<
". Expected " << meshAxes.size()
898 for (
size_t i = 0; i < device.size(); ++i) {
899 if (!ShapedType::isDynamic(device[i]) &&
900 !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
901 meshShape[meshAxes[i]] <= device[i]) {
903 <<
"Out of bounds coordinate " << i <<
" for in-group device \""
904 << deviceName <<
"\"."
905 <<
" Got " << device[i] <<
", but expected value in the range [0, "
906 << (meshShape[meshAxes[i]] - 1) <<
"].";
/// Multiplies all elements in the range [begin, end); an empty range yields
/// the multiplicative identity (1). The accumulator type is deduced from the
/// iterator's element type so the fold does not narrow or truncate.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  return std::accumulate(begin, end, static_cast<ElementType>(1),
                         std::multiplies<ElementType>());
}
/// Range overload: multiplies all elements of `range`, dispatching through
/// ADL begin/end (llvm::adl_begin/adl_end) so both std and LLVM containers
/// are accepted.
template <typename R>
static auto product(R &&range) {
  return product(adl_begin(range), adl_end(range));
}
925 int64_t expectedDimSize,
926 int64_t resultDimSize,
927 int64_t resultAxis) {
928 if (!ShapedType::isDynamic(resultDimSize) &&
929 expectedDimSize != resultDimSize) {
930 return emitError(loc) <<
"Dimension size mismatch for result axis "
931 << resultAxis <<
". Expected "
932 << (ShapedType::isDynamic(expectedDimSize)
934 : Twine(expectedDimSize))
935 <<
", but got " << resultDimSize <<
".";
942 Value operand,
Value result, int64_t gatherAxis,
944 auto resultRank = cast<ShapedType>(result.
getType()).getRank();
945 if (gatherAxis < 0 || gatherAxis >= resultRank) {
947 <<
"Gather axis " << gatherAxis <<
" is out of bounds [0, "
948 << resultRank <<
").";
951 ShapedType operandType = cast<ShapedType>(operand.
getType());
952 ShapedType resultType = cast<ShapedType>(result.
getType());
953 auto deviceGroupSize =
955 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
956 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
957 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
958 auto expectedResultDimSize =
959 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
961 result.
getLoc(), expectedResultDimSize, resultDimSize, axis))) {
969 Value operand,
Value result, int64_t splitAxis, int64_t concatAxis,
971 ShapedType operandType = cast<ShapedType>(operand.
getType());
972 ShapedType resultType = cast<ShapedType>(result.
getType());
973 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
974 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
976 result.
getLoc(), operandType.getDimSize(axis),
977 resultType.getDimSize(axis), axis))) {
983 if (splitAxis == concatAxis) {
987 auto deviceGroupSize =
989 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
990 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
991 DimensionSize expectedResultConcatDimSize =
992 operandConcatDimSize * deviceGroupSize;
993 DimensionSize expectedResultSplitDimSize =
994 operandSplitDimSize / deviceGroupSize;
995 if (!expectedResultSplitDimSize.isDynamic() &&
996 int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
997 expectedResultSplitDimSize = DimensionSize::dynamic();
1000 result.
getLoc(), expectedResultConcatDimSize.value(),
1001 resultType.getDimSize(concatAxis), concatAxis))) {
1005 result.
getLoc(), expectedResultSplitDimSize.value(),
1006 resultType.getDimSize(splitAxis), splitAxis))) {
1014 Value operand,
Value result, int64_t tensorAxis,
1016 ShapedType operandType = cast<ShapedType>(operand.
getType());
1017 ShapedType resultType = cast<ShapedType>(result.
getType());
1018 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1019 if (axis != tensorAxis) {
1021 result.
getLoc(), operandType.getDimSize(axis),
1022 resultType.getDimSize(axis), axis))) {
1028 auto deviceGroupSize =
1030 auto operandScatterDimSize =
1031 DimensionSize(operandType.getDimSize(tensorAxis));
1032 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1033 int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1035 <<
"Operand dimension size " << int64_t(operandScatterDimSize)
1036 <<
" is not divisible by collective device group size "
1037 << int64_t(deviceGroupSize) <<
" for tensor axis " << tensorAxis
1040 DimensionSize expectedResultTensorDimSize =
1041 operandScatterDimSize / deviceGroupSize;
1043 result.
getLoc(), expectedResultTensorDimSize.value(),
1044 resultType.getDimSize(tensorAxis), tensorAxis))) {
1053 int64_t sliceAxis) {
1054 RankedTensorType operandRankedTensorType =
1055 cast<RankedTensorType>(operandType);
1056 DimensionSize operandSliceAxisSize =
1057 operandRankedTensorType.getShape()[sliceAxis];
1059 llvm::to_vector(operandRankedTensorType.getShape());
1061 resultShape[sliceAxis] =
1062 operandSliceAxisSize /
1064 return operandRankedTensorType.clone(resultShape);
1077 auto gatherAxis = getGatherAxis().getSExtValue();
1079 gatherAxis, getMeshAxes(),
1080 mesh.value().getShape());
1085 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
1088 void AllGatherOp::getAsmResultNames(
1090 setNameFn(getResult(),
"all_gather");
1104 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
1108 Value input, StringRef mesh,
1110 build(odsBuilder, odsState, input.
getType(), mesh, meshAxes, input,
1114 void AllReduceOp::getAsmResultNames(
1116 setNameFn(getResult(),
"all_reduce");
1129 getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
1130 mesh.value().getShape());
1135 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
1140 int64_t sliceAxis) {
1142 build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
1147 Type resultType,
Value input, StringRef mesh,
1149 build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
1150 APInt(
sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1153 void AllSliceOp::getAsmResultNames(
1155 setNameFn(getResult(),
"all_slice");
1169 getOperand(), getResult(), getSplitAxis().getSExtValue(),
1170 getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
1175 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
1178 void AllToAllOp::getAsmResultNames(
1180 setNameFn(getResult(),
"all_to_all");
1194 getRootDynamic(), getMeshAxes(),
1195 mesh.value().getShape()))) {
1204 patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
1207 void BroadcastOp::getAsmResultNames(
1209 setNameFn(getResult(),
"broadcast");
1222 getRootDynamic(), getMeshAxes(),
1223 mesh.value().getShape()))) {
1227 auto gatherAxis = getGatherAxis().getSExtValue();
1230 mesh.value().getShape());
1235 patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
1238 void GatherOp::getAsmResultNames(
1240 setNameFn(getResult(),
"gather");
1254 getSource().value(), getSourceDynamic(),
1255 getMeshAxes(), mesh.value().getShape()))) {
1263 patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
1266 void RecvOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
1267 setNameFn(getResult(),
"recv");
1280 getRootDynamic(), getMeshAxes(),
1281 mesh.value().getShape()))) {
1290 patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
1293 void ReduceOp::getAsmResultNames(
1295 setNameFn(getResult(),
"reduce");
1310 getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
1311 mesh.value().getShape());
1316 patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1319 void ReduceScatterOp::getAsmResultNames(
1321 setNameFn(getResult(),
"reduce_scatter");
1334 getRootDynamic(), getMeshAxes(),
1335 mesh.value().getShape()))) {
1339 auto scatterAxis = getScatterAxis().getSExtValue();
1341 scatterAxis, getMeshAxes(),
1342 mesh.value().getShape());
1347 patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
1350 void ScatterOp::getAsmResultNames(
1352 setNameFn(getResult(),
"scatter");
1365 getDestination(), getDestinationDynamic(),
1366 getMeshAxes(), mesh.value().getShape()))) {
1374 patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
1377 void SendOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
1378 setNameFn(getResult(),
"send");
1391 auto meshAxes = getMeshAxes();
1392 auto shiftAxis = getShiftAxis().getZExtValue();
1393 if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
1394 return emitError() <<
"Invalid shift axis " << shiftAxis
1395 <<
". It must be one of the grouping mesh axes.";
1407 void ShiftOp::getAsmResultNames(
1409 setNameFn(getResult(),
"shift");
1430 #define GET_OP_CLASSES
1431 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
1433 #define GET_ATTRDEF_CLASSES
1434 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
1436 #define GET_TYPEDEF_CLASSES
1437 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
1439 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static FailureOr< MeshOp > getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
static FailureOr< MeshOp > getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
static auto product(It begin, It end)
bool isUnique(It begin, It end)
static LogicalResult verifyMeshAxes(Location loc, ArrayRef< MeshAxis > axes, MeshOp mesh)
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
DenseI16ArrayAttr getDenseI16ArrayAttr(ArrayRef< int16_t > values)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this operation with the provided values if the given callback returns true...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Operation * getOwner() const
Return the owner of this operand.
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
bool equalHaloAndShardSizes(const MeshSharding &rhs) const
::mlir::FlatSymbolRefAttr getMeshAttr() const
bool equalHaloSizes(const MeshSharding &rhs) const
ArrayRef< MeshAxesAttr > getSplitAxes() const
bool operator!=(Value rhs) const
ReductionKind getPartialType() const
ArrayRef< Value > getDynamicShardedDimsOffsets() const
bool operator==(Value rhs) const
ArrayRef< MeshAxis > getPartialAxes() const
ArrayRef< Value > getDynamicHaloSizes() const
::llvm::StringRef getMesh() const
ArrayRef< int64_t > getStaticHaloSizes() const
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalShardSizes(const MeshSharding &rhs) const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
mesh::ReductionKind ReductionKind
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
AffineExpr operator*(int64_t val, AffineExpr expr)
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.