9 #ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
10 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
20 #include "llvm/Support/MathExtras.h"
33 #include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
35 #define GET_ATTRDEF_CLASSES
36 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
61 ArrayRef<Value> dynamic_halo_sizes_ = {},
62 ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
70 return static_sharded_dims_sizes;
74 return dynamic_sharded_dims_sizes;
76 operator bool()
const {
return (!mesh) ==
false; }
88 #define GET_TYPEDEF_CLASSES
89 #include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
91 #define GET_OP_CLASSES
92 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
98 return iType == utils::IteratorType::reduction;
101 template <
typename T>
103 while (!array.empty() && array.back().empty())
111 return axes.asArrayRef().empty();
130 template <
typename Op>
132 return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
140 cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
141 symbolTableCollection);
146 template <
typename MeshAxesRange,
typename MeshShapeRange>
148 MeshShapeRange &&meshShape) {
152 auto axisSize = *(std::begin(meshShape) + axis);
153 if (ShapedType::isDynamic(axisSize)) {
154 return ShapedType::kDynamic;
162 template <
typename MeshAxesRange>
170 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
171 return ShapedType::kDynamic;
173 assert(dimSize % shardCount == 0);
174 return dimSize / shardCount;
179 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
180 return ShapedType::kDynamic;
182 return dimSize * shardCount;
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from'. Returns nullptr if no valid symbol was found.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
bool equalHaloAndShardSizes(const MeshSharding &rhs) const
::mlir::FlatSymbolRefAttr getMeshAttr() const
ArrayRef< MeshAxesAttr > getSplitAxes() const
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_sizes_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_sizes_={})
bool operator!=(Value rhs) const
ReductionKind getPartialType() const
ArrayRef< Value > getDynamicShardedDimsSizes() const
ArrayRef< int64_t > getStaticShardedDimsSizes() const
bool operator==(Value rhs) const
ArrayRef< MeshAxis > getPartialAxes() const
ArrayRef< Value > getDynamicHaloSizes() const
::llvm::StringRef getMesh() const
ArrayRef< int64_t > getStaticHaloSizes() const
mesh::ReductionKind ReductionKind
mesh::MeshSharding MeshSharding
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
bool isReductionLoop(utils::IteratorType iType)
bool isFullReplication(MeshSharding sharding)
mesh::MeshOp getMesh< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr