9 #ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
10 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
21 #include "llvm/Support/MathExtras.h"
34 #include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
36 #define GET_ATTRDEF_CLASSES
37 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
62 ArrayRef<Value> dynamic_halo_sizes_ = {},
63 ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
71 return static_sharded_dims_offsets;
75 return dynamic_sharded_dims_offsets;
77 operator bool()
const {
return (!mesh) ==
false; }
91 #define GET_TYPEDEF_CLASSES
92 #include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
94 #define GET_OP_CLASSES
95 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
101 return iType == utils::IteratorType::reduction;
105 template <
typename T>
107 while (array.size() > 1 && array.back().empty())
115 return axes.asArrayRef().empty();
134 template <
typename Op>
144 cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
145 symbolTableCollection);
150 template <
typename MeshAxesRange,
typename MeshShapeRange>
152 MeshShapeRange &&meshShape) {
156 auto axisSize = *(std::begin(meshShape) + axis);
157 if (ShapedType::isDynamic(axisSize)) {
158 return ShapedType::kDynamic;
166 template <
typename MeshAxesRange>
174 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
175 return ShapedType::kDynamic;
177 assert(dimSize % shardCount == 0);
178 return dimSize / shardCount;
183 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
184 return ShapedType::kDynamic;
186 return dimSize * shardCount;
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
bool equalHaloAndShardSizes(const MeshSharding &rhs) const
::mlir::FlatSymbolRefAttr getMeshAttr() const
bool equalHaloSizes(const MeshSharding &rhs) const
ArrayRef< MeshAxesAttr > getSplitAxes() const
bool operator!=(Value rhs) const
ReductionKind getPartialType() const
ArrayRef< Value > getDynamicShardedDimsOffsets() const
bool operator==(Value rhs) const
ArrayRef< MeshAxis > getPartialAxes() const
ArrayRef< Value > getDynamicHaloSizes() const
::llvm::StringRef getMesh() const
ArrayRef< int64_t > getStaticHaloSizes() const
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalShardSizes(const MeshSharding &rhs) const
mesh::ReductionKind ReductionKind
mesh::MeshSharding MeshSharding
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
bool isReductionLoop(utils::IteratorType iType)
bool isFullReplication(MeshSharding sharding)
mesh::MeshOp getMesh< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr