9#ifndef MLIR_DIALECT_SHARD_IR_SHARDOPS_H
10#define MLIR_DIALECT_SHARD_IR_SHARDOPS_H
22#include "llvm/Support/MathExtras.h"
35#include "mlir/Dialect/Shard/IR/ShardEnums.h.inc"
37#define GET_ATTRDEF_CLASSES
38#include "mlir/Dialect/Shard/IR/ShardAttributes.h.inc"
62 ::llvm::StringRef
getGrid()
const {
return grid ? grid.getValue() :
""; }
66 return static_sharded_dims_offsets;
70 return dynamic_sharded_dims_offsets;
72 operator bool()
const {
return (!grid) ==
false; }
83llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const Sharding &sharding);
87 llvm::raw_string_ostream os(str);
89 return diag << os.str();
95#define GET_TYPEDEF_CLASSES
96#include "mlir/Dialect/Shard/IR/ShardTypes.h.inc"
99#include "mlir/Dialect/Shard/IR/ShardOps.h.inc"
105 return iType == utils::IteratorType::reduction;
111 while (array.size() > 1 && array.back().empty())
118 return axes.asArrayRef().empty();
133 shard::GridOp gridOp =
getGridOrNull(op, gridSymbol, symbolTableCollection);
139template <
typename Op>
149 cast<ShardingOp>(op.getSharding().getDefiningOp()).getGridAttr(),
150 symbolTableCollection);
155template <
typename Gr
idAxesRange,
typename Gr
idShapeRange>
157 GridShapeRange &&gridShape) {
161 auto axisSize = *(std::begin(gridShape) + axis);
162 if (ShapedType::isDynamic(axisSize)) {
163 return ShapedType::kDynamic;
171template <
typename Gr
idAxesRange>
179 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
180 return ShapedType::kDynamic;
182 assert(dimSize % shardCount == 0);
183 return dimSize / shardCount;
188 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
189 return ShapedType::kDynamic;
191 return dimSize * shardCount;
static std::string diag(const llvm::Value &value)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
ArrayRef< Value > getDynamicShardedDimsOffsets() const
bool operator!=(Value rhs) const
bool equalShardSizes(const Sharding &rhs) const
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalSplitAxes(const Sharding &rhs) const
ArrayRef< int64_t > getStaticHaloSizes() const
::mlir::FlatSymbolRefAttr getGridAttr() const
::llvm::StringRef getGrid() const
bool equalHaloAndShardSizes(const Sharding &rhs) const
bool operator==(Value rhs) const
ArrayRef< Value > getDynamicHaloSizes() const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
ArrayRef< GridAxesAttr > getSplitAxes() const
bool equalHaloSizes(const Sharding &rhs) const
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const Sharding &sharding)
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
shard::GridOp getGrid< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
DenseI16ArrayAttr GridAxesAttr
void removeTrailingEmptySubArray(SmallVector< SmallVector< T > > &array)
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
bool isFullReplication(Sharding sharding)
DenseI64ArrayAttr HaloSizePairAttr
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
bool isReductionLoop(utils::IteratorType iType)
DenseI64ArrayAttr ShardShapeAttr
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Type shardType(Type type, GridOp grid, Sharding sharding)
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type.
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr