MLIR  20.0.0git
Namespaces | Macros | Typedefs | Functions
Spmdization.cpp File Reference
#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <type_traits>
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"

Go to the source code of this file.

Namespaces

 mlir
 Include the generated interface declarations.
 
 mlir::mesh
 

Macros

#define GEN_PASS_DEF_SPMDIZATION
 

Typedefs

using mlir::mesh::UnshardedToShardedValueMap = DenseMap< Value, Value >
 

Functions

template<typename SourceAxes , typename TargetAxes >
static bool mlir::mesh::arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
 
static std::tuple< TypedValue< ShapedType >, MeshSharding > mlir::mesh::handlePartialAxesDuringResharding (OpBuilder &builder, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
 
static MeshSharding mlir::mesh::targetShardingInSplitLastAxis (MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshSharding > mlir::mesh::splitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::optional< std::tuple< int64_t, MeshAxis > > mlir::mesh::detectSplitLastAxisInResharding (MeshSharding sourceSharding, MeshSharding targetSharding)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > mlir::mesh::trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, MeshAxis > > mlir::mesh::detectUnsplitLastAxisInResharding (MeshSharding sourceSharding, MeshSharding targetSharding)
 
static MeshSharding mlir::mesh::targetShardingInUnsplitLastAxis (MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis)
 
static ShapedType mlir::mesh::allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshSharding > mlir::mesh::unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > mlir::mesh::tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > mlir::mesh::detectMoveLastSplitAxisInResharding (MeshSharding sourceSharding, MeshSharding targetSharding)
 
static MeshSharding mlir::mesh::targetShardingInMoveLastAxis (MLIRContext *ctx, MeshSharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static ShapedType mlir::mesh::allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshSharding > mlir::mesh::moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > mlir::mesh::tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > mlir::mesh::tryUpdateHaloInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static TypedValue< ShapedType > mlir::mesh::reshardOn1DMesh (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
TypedValue< ShapedType > mlir::mesh::reshard (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
TypedValue< ShapedType > mlir::mesh::reshard (OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
 
TypedValue< ShapedType > mlir::mesh::reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection)
 
void mlir::mesh::reshardingRegisterDependentDialects (DialectRegistry &registry)
 
SmallVector< Type > mlir::mesh::shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection)
 
void mlir::mesh::spmdizeTriviallyShardableOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
 
static LogicalResult mlir::mesh::spmdizeOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static std::vector< MeshSharding > mlir::mesh::getOperandShardings (Operation &op)
 
static std::vector< MeshSharding > mlir::mesh::getResultShardings (Operation &op)
 
static LogicalResult mlir::mesh::spmdizeOperation (ShardOp shardOp, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult mlir::mesh::spmdizeOperation (Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult mlir::mesh::spmdizeBlock (Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult mlir::mesh::spmdizeFuncOp (FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
 

Macro Definition Documentation

◆ GEN_PASS_DEF_SPMDIZATION

#define GEN_PASS_DEF_SPMDIZATION

Definition at line 608 of file Spmdization.cpp.