9 #ifndef MLIR_DIALECT_SHARD_IR_SHARDOPS_H
10 #define MLIR_DIALECT_SHARD_IR_SHARDOPS_H
21 #include "llvm/Support/MathExtras.h"
34 #include "mlir/Dialect/Shard/IR/ShardEnums.h.inc"
36 #define GET_ATTRDEF_CLASSES
37 #include "mlir/Dialect/Shard/IR/ShardAttributes.h.inc"
65 return static_sharded_dims_offsets;
69 return dynamic_sharded_dims_offsets;
71 operator bool()
const {
return (!grid) ==
false; }
85 #define GET_TYPEDEF_CLASSES
86 #include "mlir/Dialect/Shard/IR/ShardTypes.h.inc"
88 #define GET_OP_CLASSES
89 #include "mlir/Dialect/Shard/IR/ShardOps.h.inc"
95 return iType == utils::IteratorType::reduction;
101 while (array.size() > 1 && array.back().empty())
108 return axes.asArrayRef().empty();
129 template <
typename Op>
139 cast<ShardingOp>(op.getSharding().getDefiningOp()).getGridAttr(),
140 symbolTableCollection);
145 template <
typename Gr
idAxesRange,
typename Gr
idShapeRange>
147 GridShapeRange &&gridShape) {
151 auto axisSize = *(std::begin(gridShape) + axis);
152 if (ShapedType::isDynamic(axisSize)) {
153 return ShapedType::kDynamic;
161 template <
typename Gr
idAxesRange>
169 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
170 return ShapedType::kDynamic;
172 assert(dimSize % shardCount == 0);
173 return dimSize / shardCount;
178 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
179 return ShapedType::kDynamic;
181 return dimSize * shardCount;
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
bool operator!=(Value rhs) const
bool equalShardSizes(const Sharding &rhs) const
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalSplitAxes(const Sharding &rhs) const
::mlir::FlatSymbolRefAttr getGridAttr() const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
::llvm::StringRef getGrid() const
bool equalHaloAndShardSizes(const Sharding &rhs) const
bool operator==(Value rhs) const
ArrayRef< int64_t > getStaticHaloSizes() const
ArrayRef< Value > getDynamicShardedDimsOffsets() const
ArrayRef< Value > getDynamicHaloSizes() const
ArrayRef< GridAxesAttr > getSplitAxes() const
bool equalHaloSizes(const Sharding &rhs) const
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
shard::GridOp getGrid< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
bool isFullReplication(Sharding sharding)
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
bool isReductionLoop(utils::IteratorType iType)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Type shardType(Type type, GridOp grid, Sharding sharding)
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type.
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr