MLIR 23.0.0git
Partition.cpp File Reference
#include "mlir/Dialect/Shard/Transforms/Partition.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <array>
#include <iterator>
#include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"

Go to the source code of this file.

Classes

class  mlir::shard::ReshardingPattern
 Base class for resharding patterns. More...
class  mlir::shard::SplitLastAxisPattern
 Split a replicated axis: e.g. [[0, 1]] -> [[0, 1, 2]]. More...
class  mlir::shard::UnsplitLastAxesPattern
 Unsplit trailing axes: e.g. [[0, 1, 2]] -> [[0, 1]] or [[0, 1, 2]] -> []. More...
class  mlir::shard::MoveSplitAxisPattern
 Move a split axis between tensor dimensions: e.g. [[0], []] -> [[], [0]]. More...
class  mlir::shard::UpdateHaloPattern
 Update halo sizes: handles cases where only the halo sizes differ between source and target sharding. More...

Namespaces

namespace  mlir
 Include the generated interface declarations.
namespace  mlir::shard

Macros

#define GEN_PASS_DEF_PARTITION

Typedefs

using mlir::shard::UnshardedToShardedValueMap = DenseMap<Value, Value>

Functions

template<typename SourceAxes, typename TargetAxes>
static bool mlir::shard::arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
static TypedValue< ShapedType > mlir::shard::reshard (ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &srcSharding, const Sharding &tgtSharding, TypedValue< ShapedType > unshardedSrc, TypedValue< ShapedType > shardedSrc)
TypedValue< ShapedType > mlir::shard::reshard (OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
TypedValue< ShapedType > mlir::shard::reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection)
void mlir::shard::reshardingRegisterDependentDialects (DialectRegistry &registry)
static SmallVector< Type > mlir::shard::shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection)
static LogicalResult mlir::shard::partitionOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::vector< Sharding > mlir::shard::getOperandShardings (Operation &op)
static std::vector< Sharding > mlir::shard::getResultShardings (Operation &op)
static LogicalResult mlir::shard::partitionOperation (ShardOp shardOp, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult mlir::shard::checkFullyAnnotated (Block &block)
static LogicalResult mlir::shard::checkFullyAnnotated (Operation *op)
static LogicalResult mlir::shard::partitionOperation (Operation &op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult mlir::shard::partitionBlock (Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult mlir::shard::partitionFuncOp (FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)

Macro Definition Documentation

◆ GEN_PASS_DEF_PARTITION

#define GEN_PASS_DEF_PARTITION

Definition at line 519 of file Partition.cpp.