MLIR  22.0.0git
Namespaces | Macros | Typedefs | Functions
Partition.cpp File Reference
#include "mlir/Dialect/Shard/Transforms/Partition.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"

Go to the source code of this file.

Namespaces

 mlir
 Top-level namespace of the MLIR project.
 
 mlir::shard
 

Macros

#define GEN_PASS_DEF_PARTITION
 

Typedefs

using mlir::shard::UnshardedToShardedValueMap = DenseMap< Value, Value >
 

Functions

template<typename SourceAxes , typename TargetAxes >
static bool mlir::shard::arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
 
static Sharding mlir::shard::targetShardingInSplitLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis)
 
static std::tuple< TypedValue< ShapedType >, Sharding > mlir::shard::splitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
 
static std::optional< std::tuple< int64_t, GridAxis > > mlir::shard::detectSplitLastAxisInResharding (Sharding sourceSharding, Sharding targetSharding)
 
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > mlir::shard::trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, GridAxis > > mlir::shard::detectUnsplitLastAxisInResharding (Sharding sourceSharding, Sharding targetSharding)
 
static Sharding mlir::shard::targetShardingInUnsplitLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis)
 
static ShapedType mlir::shard::allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, Sharding > mlir::shard::unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > mlir::shard::tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, int64_t, GridAxis > > mlir::shard::detectMoveLastSplitAxisInResharding (Sharding sourceSharding, Sharding targetSharding)
 
static Sharding mlir::shard::targetShardingInMoveLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static ShapedType mlir::shard::allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, Sharding > mlir::shard::moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > mlir::shard::tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > mlir::shard::tryUpdateHaloInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static TypedValue< ShapedType > mlir::shard::reshardOn1DGrid (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
static TypedValue< ShapedType > mlir::shard::reshard (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
TypedValue< ShapedType > mlir::shard::reshard (OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
 
TypedValue< ShapedType > mlir::shard::reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection)
 
void mlir::shard::reshardingRegisterDependentDialects (DialectRegistry &registry)
 
static SmallVector< Type > mlir::shard::shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection)
 
static LogicalResult mlir::shard::partitionOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static std::vector< Sharding > mlir::shard::getOperandShardings (Operation &op)
 
static std::vector< Sharding > mlir::shard::getResultShardings (Operation &op)
 
static LogicalResult mlir::shard::partitionOperation (ShardOp shardOp, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult mlir::shard::partitionOperation (Operation &op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult mlir::shard::partitionBlock (Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult mlir::shard::partitionFuncOp (FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)
 

Macro Definition Documentation

◆ GEN_PASS_DEF_PARTITION

#define GEN_PASS_DEF_PARTITION

Definition at line 530 of file Partition.cpp.