9 #ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
10 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
20 #include "llvm/Support/MathExtras.h"
31 #include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
33 #define GET_ATTRDEF_CLASSES
34 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
36 #define GET_OP_CLASSES
37 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
43 return iType == utils::IteratorType::reduction;
48 while (!array.empty() && array.back().empty())
54 return attr.getPartialAxes().empty() &&
55 llvm::all_of(attr.getSplitAxes(), [](
MeshAxesAttr axes) {
56 return axes.asArrayRef().empty();
75 template <
typename Op>
77 return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
83 return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
84 symbolTableCollection);
89 template <
typename MeshAxesRange,
typename MeshShapeRange>
91 MeshShapeRange &&meshShape) {
95 auto axisSize = *(std::begin(meshShape) + axis);
96 if (ShapedType::isDynamic(axisSize)) {
97 return ShapedType::kDynamic;
105 template <
typename MeshAxesRange>
113 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
114 return ShapedType::kDynamic;
116 assert(dimSize % shardCount == 0);
117 return dimSize / shardCount;
122 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
123 return ShapedType::kDynamic;
125 return dimSize * shardCount;
A symbol reference with a reference path containing a single element.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
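Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was found.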
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
mesh::MeshShardingAttr MeshShardingAttr
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
bool isFullReplication(MeshShardingAttr attr)
bool isReductionLoop(utils::IteratorType iType)
mesh::MeshOp getMesh< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr