MLIR  19.0.0git
Namespaces | Macros | Typedefs | Functions
Spmdization.cpp File Reference
#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <type_traits>
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"

Go to the source code of this file.

Namespaces

 mlir
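 Include the generated interface declarations. (Note: this brief text appears to be a stray source comment that Doxygen attached to the `mlir` namespace; it does not describe the namespace itself, which is the top-level MLIR namespace.)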
 
 mlir::mesh
 

Macros

#define GEN_PASS_DEF_SPMDIZATION
 

Typedefs

using mlir::mesh::UnshardedToShardedValueMap = DenseMap< Value, Value >
 

Functions

template<typename SourceAxes , typename TargetAxes >
static bool mlir::mesh::arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
 
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > mlir::mesh::handlePartialAxesDuringResharding (OpBuilder &builder, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
 
static MeshShardingAttr mlir::mesh::targetShardingInSplitLastAxis (MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > mlir::mesh::splitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::optional< std::tuple< int64_t, MeshAxis > > mlir::mesh::detectSplitLastAxisInResharding (MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > mlir::mesh::trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, MeshAxis > > mlir::mesh::detectUnsplitLastAxisInResharding (MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
 
static MeshShardingAttr mlir::mesh::targetShardingInUnsplitLastAxis (MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis)
 
static ShapedType mlir::mesh::allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > mlir::mesh::unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > mlir::mesh::tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > mlir::mesh::detectMoveLastSplitAxisInResharding (MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
 
static MeshShardingAttr mlir::mesh::targetShardingInMoveLastAxis (MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static ShapedType mlir::mesh::allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > mlir::mesh::moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > mlir::mesh::tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static TypedValue< ShapedType > mlir::mesh::reshardOn1DMesh (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
TypedValue< ShapedType > mlir::mesh::reshard (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
TypedValue< ShapedType > mlir::mesh::reshard (OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
 
TypedValue< ShapedType > mlir::mesh::reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection)
 
void mlir::mesh::reshardingRegisterDependentDialects (DialectRegistry &registry)
 
SmallVector< Type > mlir::mesh::shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection)
 
static LogicalResult mlir::mesh::spmdizeOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static SmallVector< MeshShardingAttr > mlir::mesh::getOperandShardings (Operation &op)
 
static SmallVector< MeshShardingAttr > mlir::mesh::getResultShardings (Operation &op)
 
static LogicalResult mlir::mesh::spmdizeOperation (ShardOp shardOp, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult mlir::mesh::spmdizeOperation (Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult mlir::mesh::spmdizeBlock (Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult mlir::mesh::spmdizeFuncOp (FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
 

Macro Definition Documentation

◆ GEN_PASS_DEF_SPMDIZATION

#define GEN_PASS_DEF_SPMDIZATION

Definition at line 518 of file Spmdization.cpp.