#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "mesh-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
struct DimensionSize {
  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
  DimensionSize(int64_t val) : val(val) {}
  int64_t value() const { return val; }
  operator int64_t() const { return val; }
  bool isDynamic() const { return ShapedType::isDynamic(val); }

private:
  int64_t val;
};
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() / rhs.value();
}
static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() * rhs.value();
}
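// Illustrative note (not part of the original source): DimensionSize folds
// dynamic extents through shape arithmetic, e.g. DimensionSize(8) /
// DimensionSize(2) yields 4, while DimensionSize::dynamic() * DimensionSize(2)
// stays dynamic. The collective shape verifiers below rely on this to skip
// checks whenever one side of a dimension is unknown at compile time.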
void MeshDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
      >();
}
Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
    return op->emitError() << "Undefined required mesh symbol \""
                           << meshSymbol.getValue() << "\".";
template <typename It>
bool isUnique(It begin, It end) {
  if (begin == end)
    return true;
  It next = std::next(begin);
  for (; next != end; ++next, ++begin)
    if (*begin == *next)
      return false;
  return true;
}
static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
                                    MeshOp mesh) {
  auto sorted = llvm::to_vector(axes);
  llvm::sort(sorted);
  if (!isUnique(sorted.begin(), sorted.end())) {
    return emitError(loc) << "Mesh axes contains duplicate elements.";
  }

  auto rank = mesh.getRank();
  for (auto axis : axes) {
    if (axis >= rank || axis < 0) {
      return emitError(loc)
             << "0-based mesh axis index " << axis
             << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
             << "\" is of rank " << rank << ".";
    }
  }

  return success();
}
template <typename Op>
static FailureOr<MeshOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
template <typename InShape, typename MeshShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                       const SplitAxes &splitAxes, OutShape &outShape,
                       ArrayRef<int64_t> shardedDimsSizes = {},
                       ArrayRef<int64_t> haloSizes = {}) {
  // Start from the unsharded shape and refine the split dimensions below.
  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));
  if (!shardedDimsSizes.empty()) {
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (innerSplitAxes.empty()) {
        for (auto dimSz : shardedDimsSizes) {
          auto inAxis = dimSz % inShape.size();
          assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
                 inShape[inAxis] == ShapedType::kDynamic);
        }
      } else {
        // Use the per-shard size only if it is the same on every shard of
        // this tensor axis; otherwise the sharded dimension is dynamic.
        auto sz = shardedDimsSizes[tensorAxis];
        bool same = true;
        for (size_t i = tensorAxis + inShape.size();
             i < shardedDimsSizes.size(); i += inShape.size()) {
          if (shardedDimsSizes[i] != sz) {
            same = false;
            break;
          }
        }
        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
      }
    }
  }
  if (!haloSizes.empty()) {
      if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
          !innerSplitAxes.empty()) {
        if (haloSizes[haloAxis * 2] >= 0 &&
            haloSizes[haloAxis * 2 + 1] >= 0) {
          outShape[tensorAxis] +=
              haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
        } else {
          outShape[tensorAxis] = ShapedType::kDynamic;
        }
  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;

  return shape.clone(resShapeArr);
}
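// Illustrative note (not part of the original source): shardShape computes the
// per-shard shape. For an 8x16 tensor split as {tensor axis 0 -> mesh axis 0,
// tensor axis 1 -> mesh axis 1} on a 2x4 mesh the result is 4x4; halo sizes of
// [1, 1] on tensor axis 0 would widen that axis to 6, and any dynamic mesh or
// tensor dimension propagates as ShapedType::kDynamic.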
  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
  if (rankedTensorType) {
  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      !shardOp.getAnnotateForUsers()) {
    // Nothing to do if the consumer is already annotated with this sharding.
    return;
  }

  auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
  auto newShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users=*/false);
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (!shardOp || shardOp.getAnnotateForUsers()) {
    return;
  }

  auto newShardOp2 =
      builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
                              /*annotate_for_users=*/true);
  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
}
  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
  bool isBlockArg = !operandSrcOp;
  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      shardOp.getAnnotateForUsers()) {
    // Nothing to do if the producer is already annotated with this sharding.
    return;
  }

  auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
  auto newShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users=*/true);
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
    return;
  }

  auto newPreceedingShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users=*/false);
  rewriter.replaceUsesWithIf(
      newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
        return use.getOwner() == newShardOp.getOperation();
      });
}
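// Illustrative note (not part of the original source): the two helpers above
// implement the dialect's double-annotation convention. The producer side gets
// a mesh.shard without annotate_for_users, the consumer side gets one with
// annotate_for_users, and when both are needed they are chained so that later
// spmdization passes can see where a resharding is required.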
LogicalResult MeshOp::verify() {
  int64_t rank = getRank();

  if (rank <= 0)
    return emitOpError("rank of mesh is expected to be a positive integer");

  for (int64_t dimSize : getShape()) {
    if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
      return emitOpError("dimension size of a mesh is expected to be "
                         "non-negative or dynamic");
  }

  return success();
}
  size_t expectedResultsCount =
      getAxes().empty() ? mesh->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }
  build(odsBuilder, odsState,

  assert(!axes.empty());
  build(odsBuilder, odsState,
void MeshShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "mesh_shape");
}

void ShardingOp::build(
  llvm::SmallSet<MeshAxis, 4> visitedAxes;

  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
    for (MeshAxis axis : axesArray) {
      if (axis < 0)
        return emitError() << "mesh axis is expected to be non-negative";
      if (!visitedAxes.insert(axis).second)
        return emitError() << "mesh axis duplicated";
    }
    return success();
  };
  for (auto subAxes : getSplitAxes().getAxes()) {
    ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
    if (failed(checkMeshAxis(subAxesArray)))
      return failure();
  }

  if (getPartialAxes().has_value() &&
      failed(checkMeshAxis(getPartialAxes().value())))
    return failure();
  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsSizes().empty()) {
    return emitOpError("halo sizes and shard shapes are mutually exclusive");
  }
  if (!getStaticHaloSizes().empty()) {
    auto numSplitAxes = getSplitAxes().getAxes().size();
    for (auto splitAxis : getSplitAxes().getAxes()) {
      if (splitAxis.empty()) {
        --numSplitAxes;
      }
    }
    if (getStaticHaloSizes().size() != numSplitAxes * 2) {
      return emitError() << "halo sizes must be specified for all split axes.";
    }
  }
void ShardingOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding");
}
  if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
      getStaticShardedDimsSizes().size() > 0) {
    return emitError() << "sharded dims sizes are not allowed for "
                          "devices meshes with dynamic shape.";
  }
  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),

  return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
                                       getSplitAxes().end()),
                      std::mem_fn(&MeshAxesAttr::empty)) &&
         llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
                                       rhs.getSplitAxes().end()),
                      std::mem_fn(&MeshAxesAttr::empty));
  return !(*this == rhs);
  auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
  assert(shardingOp && "expected sharding op");
  *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
              shardingOp.getPartialType().value_or(ReductionKind::Sum),
              shardingOp.getStaticHaloSizes(),
              shardingOp.getStaticShardedDimsSizes(),
  res.split_axes.resize(split_axes_.size());

  auto clone = [](const auto src, auto &dst) {
    dst.resize(src.size());
    llvm::copy(src, dst.begin());
  };

  clone(partial_axes_, res.partial_axes);
  res.partial_type = partial_type_;
  clone(static_halo_sizes_, res.static_halo_sizes);
  clone(static_sharded_dims_sizes_, res.static_sharded_dims_sizes);
  clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
  clone(dynamic_sharded_dims_sizes_, res.dynamic_sharded_dims_sizes);
  build(odsBuilder, odsState, resType, shape, sharding, device);

void ShardOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding_annotated");
}
  size_t expectedResultsCount =
      getAxes().empty() ? mesh->getRank() : getAxes().size();
  if (getResult().size() != expectedResultsCount) {
    return emitError() << "Unexpected number of results " << getResult().size()
                       << ". Expected " << expectedResultsCount << ".";
  }
  build(odsBuilder, odsState,

  build(odsBuilder, odsState,
void ProcessMultiIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "proc_linear_idx");
}
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, MeshOp mesh) {
  build(odsBuilder, odsState, mesh.getSymName());
}
void ProcessLinearIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "proc_linear_idx");
}
template <typename Op>
struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    auto meshAxes = op.getMeshAxes();
    if (!meshAxes.empty()) {
      return failure();
    }
    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
    rewriter.eraseOp(op.getOperation());
    return success();
  }
};
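// Illustrative note (not part of the original source): a collective whose
// mesh_axes list is empty addresses a device group of size 1 and is therefore
// the identity on its input; the pattern above folds such ops away by
// forwarding the input to the result's uses and erasing the op.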
  if (device.size() != meshAxes.size()) {
    return emitError(loc) << "In-group device \"" << deviceName
                          << "\" has unexpected multi-index size "
                          << device.size() << ". Expected " << meshAxes.size()
                          << ".";
  }
  for (size_t i = 0; i < device.size(); ++i) {
    if (!ShapedType::isDynamic(device[i]) &&
        !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
        meshShape[meshAxes[i]] <= device[i]) {
      return emitError(loc)
             << "Out of bounds coordinate " << i << " for in-group device \""
             << deviceName << "\"."
             << " Got " << device[i] << ", but expected value in the range [0, "
             << (meshShape[meshAxes[i]] - 1) << "].";
    }
  }

  return success();
}
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  return std::accumulate(begin, end, static_cast<ElementType>(1),
                         std::multiplies<ElementType>());
}
template <typename R>
static auto product(R &&range) {
  return product(adl_begin(range), adl_end(range));
}
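// Illustrative note (not part of the original source): these product helpers
// are what the device group size of a collective boils down to; e.g. on a
// mesh of shape 2x3x4 with mesh_axes = [0, 2] the group size is 2 * 4 = 8,
// and a dynamic mesh dimension among the selected axes makes the group size
// dynamic as well.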
static LogicalResult verifyDimensionCompatibility(Location loc,
                                                  int64_t expectedDimSize,
                                                  int64_t resultDimSize,
                                                  int64_t resultAxis) {
  if (!ShapedType::isDynamic(resultDimSize) &&
      expectedDimSize != resultDimSize) {
    return emitError(loc) << "Dimension size mismatch for result axis "
                          << resultAxis << ". Expected "
                          << (ShapedType::isDynamic(expectedDimSize)
                                  ? Twine("dynamic")
                                  : Twine(expectedDimSize))
                          << ", but got " << resultDimSize << ".";
  }

  return success();
}
static LogicalResult verifyGatherOperandAndResultShape(
    Value operand, Value result, int64_t gatherAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  auto resultRank = cast<ShapedType>(result.getType()).getRank();
  if (gatherAxis < 0 || gatherAxis >= resultRank) {
    return emitError(result.getLoc())
           << "Gather axis " << gatherAxis << " is out of bounds [0, "
           << resultRank << ").";
  }

  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
    auto expectedResultDimSize =
        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
    if (failed(verifyDimensionCompatibility(
            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
      return failure();
    }
  }
  return success();
}
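// Illustrative note (not part of the original source): for mesh.all_gather
// over a device group of size 4 with gather_axis = 1, an operand of shape 2x3
// must produce a result of shape 2x12; every other axis must match the operand
// exactly or be dynamic.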
static LogicalResult verifyAllToAllOperandAndResultShape(
    Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  if (splitAxis == concatAxis) {
    return success();
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
  DimensionSize expectedResultConcatDimSize =
      operandConcatDimSize * deviceGroupSize;
  DimensionSize expectedResultSplitDimSize =
      operandSplitDimSize / deviceGroupSize;
  if (!expectedResultSplitDimSize.isDynamic() &&
      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
    expectedResultSplitDimSize = DimensionSize::dynamic();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultConcatDimSize.value(),
          resultType.getDimSize(concatAxis), concatAxis))) {
    return failure();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultSplitDimSize.value(),
          resultType.getDimSize(splitAxis), splitAxis))) {
    return failure();
  }
  return success();
}
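// Illustrative note (not part of the original source): for mesh.all_to_all
// over a device group of size 3 with split_axis = 0 and concat_axis = 1, an
// operand of shape 3x6 maps to a result of shape 1x18: the split axis shrinks
// by the group size while the concat axis grows by it. If the split dimension
// is not statically divisible by the group size, the result's split dimension
// must be dynamic.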
static LogicalResult verifyScatterOrSliceOperandAndResultShape(
    Value operand, Value result, int64_t tensorAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  ShapedType operandType = cast<ShapedType>(operand.getType());
  ShapedType resultType = cast<ShapedType>(result.getType());
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if (axis != tensorAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  auto deviceGroupSize =
      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
  auto operandScatterDimSize =
      DimensionSize(operandType.getDimSize(tensorAxis));
  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
      int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
    return emitError(result.getLoc())
           << "Operand dimension size " << int64_t(operandScatterDimSize)
           << " is not divisible by collective device group size "
           << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
           << ".";
  }

  DimensionSize expectedResultTensorDimSize =
      operandScatterDimSize / deviceGroupSize;
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultTensorDimSize.value(),
          resultType.getDimSize(tensorAxis), tensorAxis))) {
    return failure();
  }

  return success();
}
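// Illustrative note (not part of the original source): mesh.scatter,
// mesh.reduce_scatter and mesh.all_slice all shrink a single tensor axis by
// the device group size; e.g. a group of 4 devices slicing axis 0 of a
// tensor<8x3> yields tensor<2x3>, while a static size of 7 on that axis is
// rejected as not divisible by the group size.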
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
                                        ArrayRef<MeshAxis> meshAxes,
                                        int64_t sliceAxis) {
  RankedTensorType operandRankedTensorType =
      cast<RankedTensorType>(operandType);
  DimensionSize operandSliceAxisSize =
      operandRankedTensorType.getShape()[sliceAxis];
  SmallVector<int64_t> resultShape =
      llvm::to_vector(operandRankedTensorType.getShape());

  resultShape[sliceAxis] =
      operandSliceAxisSize /
      DimensionSize(collectiveProcessGroupSize(meshAxes, mesh.getShape()));

  return operandRankedTensorType.clone(resultShape);
}
  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
                                           gatherAxis, getMeshAxes(),
                                           mesh.value().getShape());
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
void AllGatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_gather");
}
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
                          Value input, StringRef mesh,

  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
void AllReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_reduce");
}
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
      mesh.value().getShape());
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
                        int64_t sliceAxis) {
  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
                        Type resultType, Value input, StringRef mesh,

  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
        APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
void AllSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_slice");
}
  return verifyAllToAllOperandAndResultShape(
      getOperand(), getResult(), getSplitAxis().getSExtValue(),
      getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
void AllToAllOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_to_all");
}
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "broadcast");
}
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
                                           getMeshAxes(),
                                           mesh.value().getShape());
  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
                                 getSource().value(), getSourceDynamic(),
                                 getMeshAxes(), mesh.value().getShape()))) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
void RecvOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "recv");
}
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce");
}
  return verifyScatterOrSliceOperandAndResultShape(
      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
      mesh.value().getShape());
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
void ReduceScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce_scatter");
}
                                 getRootDynamic(), getMeshAxes(),
                                 mesh.value().getShape()))) {
  auto scatterAxis = getScatterAxis().getSExtValue();
  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
                                                   scatterAxis, getMeshAxes(),
                                                   mesh.value().getShape());
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
                                 getDestination(), getDestinationDynamic(),
                                 getMeshAxes(), mesh.value().getShape()))) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
void SendOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "send");
}
  auto meshAxes = getMeshAxes();
  auto shiftAxis = getShiftAxis().getZExtValue();
  if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
    return emitError() << "Invalid shift axis " << shiftAxis
                       << ". It must be one of the grouping mesh axes.";
  }

  return success();
}
void ShiftOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "shift");
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"

#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"