26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallSet.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
38 #define DEBUG_TYPE "mesh-ops"
39 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
44 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
48 struct DimensionSize {
49 static DimensionSize dynamic() {
return DimensionSize(ShapedType::kDynamic); }
50 DimensionSize(int64_t val) : val(val) {}
51 int64_t value()
const {
return val; }
52 operator int64_t()
const {
return val; }
53 bool isDynamic()
const {
return ShapedType::isDynamic(val); }
61 static DimensionSize
operator/(DimensionSize lhs, DimensionSize rhs) {
62 if (lhs.isDynamic() || rhs.isDynamic()) {
63 return DimensionSize::dynamic();
65 return lhs.value() / rhs.value();
68 static DimensionSize
operator*(DimensionSize lhs, DimensionSize rhs) {
69 if (lhs.isDynamic() || rhs.isDynamic()) {
70 return DimensionSize::dynamic();
72 return lhs.value() * rhs.value();
79 void MeshDialect::initialize() {
82 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
85 #define GET_ATTRDEF_LIST
86 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
92 return arith::ConstantOp::materialize(builder, value, type, loc);
104 return op->
emitError() <<
"Undefined required mesh symbol \""
/// Returns true if the range [begin, end) contains no duplicate elements.
/// Assumes the range is already sorted, so only adjacent elements need to be
/// compared (callers sort before invoking, e.g. in verifyMeshAxes).
template <typename It>
bool isUnique(It begin, It end) {
  // Empty range trivially has no duplicates.
  if (begin == end) {
    return true;
  }
  It next = std::next(begin);
  // Single-element range is also trivially unique.
  if (next == end) {
    return true;
  }
  for (; next != end; ++next, ++begin) {
    if (*begin == *next) {
      return false;
    }
  }
  return true;
}
132 if (!
isUnique(sorted.begin(), sorted.end())) {
133 return emitError(loc) <<
"Mesh axes contains duplicate elements.";
137 for (
auto axis : axes) {
138 if (axis >= rank || axis < 0) {
140 <<
"0-based mesh axis index " << axis
141 <<
" is out of bounds. The referenced mesh \"" << mesh.getSymName()
142 <<
"\" is of rank " << rank <<
".";
149 template <
typename InShape,
typename MeshShape,
typename SplitAxes,
151 static void shardShape(
const InShape &inShape,
const MeshShape &meshShape,
152 const SplitAxes &splitAxes, OutShape &outShape) {
153 std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
154 llvm::adl_begin(outShape));
164 using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
166 shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
168 return shape.clone(resShapeArr);
172 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
173 if (rankedTensorType) {
186 int64_t rank = getRank();
189 return emitOpError(
"rank of mesh is expected to be a positive integer");
191 for (int64_t dimSize :
getShape()) {
192 if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
193 return emitOpError(
"dimension size of a mesh is expected to be "
194 "non-negative or dynamic");
214 size_t expectedResultsCount =
215 getAxes().empty() ? mesh->getRank() : getAxes().size();
216 if (getResult().size() != expectedResultsCount) {
217 return emitError() <<
"Unexpected number of results " << getResult().size()
218 <<
". Expected " << expectedResultsCount <<
".";
231 build(odsBuilder, odsState,
239 assert(!axes.empty());
240 build(odsBuilder, odsState,
245 void MeshShapeOp::getAsmResultNames(
247 setNameFn(getResults()[0],
"mesh_shape");
261 llvm::SmallSet<MeshAxis, 4> visitedAxes;
266 return emitError() <<
"mesh axis is expected to be non-negative";
267 if (!visitedAxes.insert(axis).second)
268 return emitError() <<
"mesh axis duplicated";
275 if (
failed(checkMeshAxis(subAxesArray)))
278 if (
failed(checkMeshAxis(partialAxes)))
285 mlir::dyn_cast<MeshShardingAttr>(rhs);
286 return rhsAsMeshShardingAttr && *
this == rhsAsMeshShardingAttr;
290 if (
getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
294 if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
298 auto minSize =
std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
299 if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
300 getSplitAxes().begin() + minSize),
301 llvm::make_range(rhs.getSplitAxes().begin(),
302 rhs.getSplitAxes().begin() + minSize))) {
306 return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
307 getSplitAxes().end()),
308 std::mem_fn(&MeshAxesAttr::empty)) &&
309 llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
310 rhs.getSplitAxes().end()),
311 std::mem_fn(&MeshAxesAttr::empty));
318 void ShardOp::getAsmResultNames(
320 setNameFn(getResult(),
"sharding_annotated");
337 size_t expectedResultsCount =
338 getAxes().empty() ? mesh->getRank() : getAxes().size();
339 if (getResult().size() != expectedResultsCount) {
340 return emitError() <<
"Unexpected number of results " << getResult().size()
341 <<
". Expected " << expectedResultsCount <<
".";
349 build(odsBuilder, odsState,
356 build(odsBuilder, odsState,
361 void ProcessMultiIndexOp::getAsmResultNames(
363 setNameFn(getResults()[0],
"proc_linear_idx");
379 void ProcessLinearIndexOp::build(
OpBuilder &odsBuilder,
381 build(odsBuilder, odsState, mesh.getSymName());
384 void ProcessLinearIndexOp::getAsmResultNames(
386 setNameFn(getResult(),
"proc_linear_idx");
395 template <
typename Op>
400 auto meshAxes = op.getMeshAxes();
401 if (!meshAxes.empty()) {
409 rewriter.
eraseOp(op.getOperation());
421 if (device.size() != meshAxes.size()) {
422 return emitError(loc) <<
"In-group device \"" << deviceName
423 <<
"\" has unexpected multi-index size "
424 << device.size() <<
". Expected " << meshAxes.size()
428 for (
size_t i = 0; i < device.size(); ++i) {
429 if (!ShapedType::isDynamic(device[i]) &&
430 !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
431 meshShape[meshAxes[i]] <= device[i]) {
433 <<
"Out of bounds coordinate " << i <<
" for in-group device \""
434 << deviceName <<
"\"."
435 <<
" Got " << device[i] <<
", but expected value in the range [0, "
436 << (meshShape[meshAxes[i]] - 1) <<
"].";
442 template <
typename Op>
/// Multiplies together all elements in [begin, end); returns 1 (the
/// multiplicative identity, cast to the element type) for an empty range.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  return std::accumulate(begin, end, static_cast<ElementType>(1),
                         std::multiplies<ElementType>());
}
/// Range overload of product: multiplies all elements of `range`, using
/// ADL-aware begin/end so both containers and ArrayRef-like views work.
template <typename R>
static auto product(R &&range) {
  return product(adl_begin(range), adl_end(range));
}
469 int64_t expectedDimSize,
470 int64_t resultDimSize,
471 int64_t resultAxis) {
472 if (!ShapedType::isDynamic(resultDimSize) &&
473 expectedDimSize != resultDimSize) {
474 return emitError(loc) <<
"Dimension size mismatch for result axis "
475 << resultAxis <<
". Expected "
476 << (ShapedType::isDynamic(expectedDimSize)
478 : Twine(expectedDimSize))
479 <<
", but got " << resultDimSize <<
".";
486 Value operand,
Value result, int64_t gatherAxis,
488 auto resultRank = cast<ShapedType>(result.
getType()).getRank();
489 if (gatherAxis < 0 || gatherAxis >= resultRank) {
491 <<
"Gather axis " << gatherAxis <<
" is out of bounds [0, "
492 << resultRank <<
").";
495 ShapedType operandType = cast<ShapedType>(operand.
getType());
496 ShapedType resultType = cast<ShapedType>(result.
getType());
497 auto deviceGroupSize =
499 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
500 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
501 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
502 auto expectedResultDimSize =
503 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
505 result.
getLoc(), expectedResultDimSize, resultDimSize, axis))) {
513 Value operand,
Value result, int64_t splitAxis, int64_t concatAxis,
515 ShapedType operandType = cast<ShapedType>(operand.
getType());
516 ShapedType resultType = cast<ShapedType>(result.
getType());
517 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
518 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
520 result.
getLoc(), operandType.getDimSize(axis),
521 resultType.getDimSize(axis), axis))) {
527 if (splitAxis == concatAxis) {
531 auto deviceGroupSize =
533 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
534 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
535 DimensionSize expectedResultConcatDimSize =
536 operandConcatDimSize * deviceGroupSize;
537 DimensionSize expectedResultSplitDimSize =
538 operandSplitDimSize / deviceGroupSize;
539 if (!expectedResultSplitDimSize.isDynamic() &&
540 int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
541 expectedResultSplitDimSize = DimensionSize::dynamic();
544 result.
getLoc(), expectedResultConcatDimSize.value(),
545 resultType.getDimSize(concatAxis), concatAxis))) {
549 result.
getLoc(), expectedResultSplitDimSize.value(),
550 resultType.getDimSize(splitAxis), splitAxis))) {
558 Value operand,
Value result, int64_t tensorAxis,
560 ShapedType operandType = cast<ShapedType>(operand.
getType());
561 ShapedType resultType = cast<ShapedType>(result.
getType());
562 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
563 if (axis != tensorAxis) {
565 result.
getLoc(), operandType.getDimSize(axis),
566 resultType.getDimSize(axis), axis))) {
572 auto deviceGroupSize =
574 auto operandScatterDimSize =
575 DimensionSize(operandType.getDimSize(tensorAxis));
576 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
577 int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
579 <<
"Operand dimension size " << int64_t(operandScatterDimSize)
580 <<
" is not divisible by collective device group size "
581 << int64_t(deviceGroupSize) <<
" for tensor axis " << tensorAxis
584 DimensionSize expectedResultTensorDimSize =
585 operandScatterDimSize / deviceGroupSize;
587 result.
getLoc(), expectedResultTensorDimSize.value(),
588 resultType.getDimSize(tensorAxis), tensorAxis))) {
598 RankedTensorType operandRankedTensorType =
599 cast<RankedTensorType>(operandType);
600 DimensionSize operandSliceAxisSize =
601 operandRankedTensorType.getShape()[sliceAxis];
603 llvm::to_vector(operandRankedTensorType.getShape());
605 resultShape[sliceAxis] =
606 operandSliceAxisSize /
608 return operandRankedTensorType.clone(resultShape);
621 auto gatherAxis = getGatherAxis().getSExtValue();
623 gatherAxis, getMeshAxes(),
624 mesh.value().getShape());
629 patterns.
add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
632 void AllGatherOp::getAsmResultNames(
634 setNameFn(getResult(),
"all_gather");
648 patterns.
add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
652 Value input, StringRef mesh,
654 build(odsBuilder, odsState, input.
getType(), mesh, meshAxes, input,
658 void AllReduceOp::getAsmResultNames(
660 setNameFn(getResult(),
"all_reduce");
673 getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
674 mesh.value().getShape());
679 patterns.
add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
686 build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
691 Type resultType,
Value input, StringRef mesh,
693 build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
694 APInt(
sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
697 void AllSliceOp::getAsmResultNames(
699 setNameFn(getResult(),
"all_slice");
713 getOperand(), getResult(), getSplitAxis().getSExtValue(),
714 getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
719 patterns.
add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
722 void AllToAllOp::getAsmResultNames(
724 setNameFn(getResult(),
"all_to_all");
738 getRootDynamic(), getMeshAxes(),
739 mesh.value().getShape()))) {
748 patterns.
add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
751 void BroadcastOp::getAsmResultNames(
753 setNameFn(getResult(),
"broadcast");
766 getRootDynamic(), getMeshAxes(),
767 mesh.value().getShape()))) {
771 auto gatherAxis = getGatherAxis().getSExtValue();
774 mesh.value().getShape());
779 patterns.
add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
782 void GatherOp::getAsmResultNames(
784 setNameFn(getResult(),
"gather");
798 getSource().value(), getSourceDynamic(),
799 getMeshAxes(), mesh.value().getShape()))) {
807 patterns.
add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
811 setNameFn(getResult(),
"recv");
824 getRootDynamic(), getMeshAxes(),
825 mesh.value().getShape()))) {
834 patterns.
add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
837 void ReduceOp::getAsmResultNames(
839 setNameFn(getResult(),
"reduce");
854 getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
855 mesh.value().getShape());
860 patterns.
add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
863 void ReduceScatterOp::getAsmResultNames(
865 setNameFn(getResult(),
"reduce_scatter");
878 getRootDynamic(), getMeshAxes(),
879 mesh.value().getShape()))) {
883 auto scatterAxis = getScatterAxis().getSExtValue();
885 scatterAxis, getMeshAxes(),
886 mesh.value().getShape());
891 patterns.
add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
894 void ScatterOp::getAsmResultNames(
896 setNameFn(getResult(),
"scatter");
909 getDestination(), getDestinationDynamic(),
910 getMeshAxes(), mesh.value().getShape()))) {
918 patterns.
add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
922 setNameFn(getResult(),
"send");
935 auto meshAxes = getMeshAxes();
936 auto shiftAxis = getShiftAxis().getZExtValue();
937 if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
938 return emitError() <<
"Invalid shift axis " << shiftAxis
939 <<
". It must be one of the grouping mesh axes.";
951 void ShiftOp::getAsmResultNames(
953 setNameFn(getResult(),
"shift");
960 #define GET_OP_CLASSES
961 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
963 #define GET_ATTRDEF_CLASSES
964 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
966 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static FailureOr< MeshOp > getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
static FailureOr< MeshOp > getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static auto product(It begin, It end)
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape)
bool isUnique(It begin, It end)
static LogicalResult verifyMeshAxes(Location loc, ArrayRef< MeshAxis > axes, MeshOp mesh)
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
This class provides support for representing a failure result, or a valid value of type T.
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around an Attribute.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
mesh::ReductionKind ReductionKind
mesh::MeshShardingAttr MeshShardingAttr
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
bool operator==(const Fraction &x, const Fraction &y)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
AffineExpr operator*(int64_t val, AffineExpr expr)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.