9 #ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
10 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
31 #include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
33 #define GET_ATTRDEF_CLASSES
34 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
36 #define GET_OP_CLASSES
37 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
43 return iType == utils::IteratorType::reduction;
48 while (!array.empty() && array.back().empty())
54 return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
64 template <
typename Op>
66 return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
72 return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
73 symbolTableCollection);
78 template <
typename MeshAxesRange,
typename MeshShapeRange>
80 MeshShapeRange &&meshShape) {
84 auto axisSize = *(std::begin(meshShape) + axis);
85 if (ShapedType::isDynamic(axisSize)) {
86 return ShapedType::kDynamic;
94 template <
typename MeshAxesRange>
102 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
103 return ShapedType::kDynamic;
105 assert(dimSize % shardCount == 0);
106 return ceilDiv(dimSize, shardCount);
111 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
112 return ShapedType::kDynamic;
114 return dimSize * shardCount;
A symbol reference with a reference path containing a single element.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
mesh::MeshShardingAttr MeshShardingAttr
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
bool isFullReplication(MeshShardingAttr attr)
bool isReductionLoop(utils::IteratorType iType)
mesh::MeshOp getMesh< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Include the generated interface declarations.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr