MLIR 22.0.0git
ShardOps.h File Reference
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Shard/IR/ShardEnums.h.inc"
#include "mlir/Dialect/Shard/IR/ShardAttributes.h.inc"
#include "mlir/Dialect/Shard/IR/ShardTypes.h.inc"
#include "mlir/Dialect/Shard/IR/ShardOps.h.inc"

Go to the source code of this file.

Classes

class  mlir::shard::Sharding

Namespaces

namespace  mlir
 Include the generated interface declarations.
namespace  mlir::shard

Macros

#define GET_ATTRDEF_CLASSES
#define GET_TYPEDEF_CLASSES
#define GET_OP_CLASSES

Typedefs

using mlir::shard::GridAxis = int16_t
using mlir::shard::GridAxesAttr = DenseI16ArrayAttr
using mlir::shard::ShardShapeAttr = DenseI64ArrayAttr
using mlir::shard::HaloSizePairAttr = DenseI64ArrayAttr

Functions

bool mlir::shard::isReductionLoop (utils::IteratorType iType)
template<typename T>
void mlir::shard::removeTrailingEmptySubArray (SmallVector< SmallVector< T > > &array)
bool mlir::shard::isFullReplication (Sharding sharding)
shard::GridOp mlir::shard::getGridOrNull (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
shard::GridOp mlir::shard::getGrid (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
template<typename Op>
shard::GridOp mlir::shard::getGrid (Op op, SymbolTableCollection &symbolTableCollection)
template<>
shard::GridOp mlir::shard::getGrid< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection)
template<typename GridAxesRange, typename GridShapeRange>
int64_t mlir::shard::collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
template<typename GridAxesRange>
int64_t mlir::shard::collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridOp grid)
int64_t mlir::shard::shardDimension (int64_t dimSize, int64_t shardCount)
int64_t mlir::shard::gatherDimension (int64_t dimSize, int64_t shardCount)
ShapedType mlir::shard::shardShapedType (ShapedType shape, GridOp grid, Sharding sharding)
Type mlir::shard::shardType (Type type, GridOp grid, Sharding sharding)
void mlir::shard::maybeInsertTargetShardingAnnotation (Sharding sharding, OpResult result, OpBuilder &builder)
void mlir::shard::maybeInsertSourceShardingAnnotation (Sharding sharding, OpOperand &operand, OpBuilder &builder)
SmallVector< Value > mlir::shard::getMixedAsValues (OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
 Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type.

Macro Definition Documentation

◆ GET_ATTRDEF_CLASSES

#define GET_ATTRDEF_CLASSES

Definition at line 36 of file ShardOps.h.

◆ GET_OP_CLASSES

#define GET_OP_CLASSES

Definition at line 88 of file ShardOps.h.

◆ GET_TYPEDEF_CLASSES

#define GET_TYPEDEF_CLASSES

Definition at line 85 of file ShardOps.h.