MLIR 22.0.0git
ShardingInterfaceImpl.cpp File Reference
#include "mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <numeric>
#include <optional>
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

Go to the source code of this file.

Namespaces

namespace  mlir
 Include the generated interface declarations.
namespace  mlir::linalg

Macros

#define GET_OP_LIST

Typedefs

using mlir::linalg::GridAxis = shard::GridAxis
using mlir::linalg::ReductionKind = shard::ReductionKind
using mlir::linalg::Sharding = shard::Sharding
using mlir::linalg::ShardingArray = shard::ShardingArray
using mlir::linalg::GridOp = shard::GridOp

Functions

static ReductionKind mlir::linalg::getReductionKind (Operation *op)
static std::optional< Operation * > mlir::linalg::getCombinerOp (LinalgOp op)
static ReductionKind mlir::linalg::getReductionKindOfLinalgOp (LinalgOp op)
static GridOp mlir::linalg::getGrid (Operation *op, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, SymbolTableCollection &symbolTable)
static Value mlir::linalg::createDestinationPassingStyleInitOperand (LinalgOp op, int operandNumber, Value partitionedOperand, ArrayRef< GridAxis > reductionGridAxes, GridOp gridOp, ImplicitLocOpBuilder &builder)
static SmallVector< Value > mlir::linalg::createDestinationPassingStyleInitOperands (LinalgOp op, GridOp gridOp, ArrayRef< Value > partitionedOperands, ArrayRef< GridAxis > reductionGridAxes, IRMapping &partitionMap, ImplicitLocOpBuilder &builder)
static void mlir::linalg::createAllReduceForResultsWithoutPartialShardings (LinalgOp unshardedOp, ArrayRef< GridAxis > opReductionGridAxes, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, ImplicitLocOpBuilder &builder)
static void mlir::linalg::partitionLinalgOpWithShardedReduction (LinalgOp op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators, IRMapping &partitionMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder)
template<typename OpType>
static void mlir::linalg::registerOne (MLIRContext *ctx)
template<typename... OpTypes>
static void mlir::linalg::registerAll (MLIRContext *ctx)
 Variadic helper function.
void mlir::linalg::registerShardingInterfaceExternalModels (DialectRegistry &registry)

Macro Definition Documentation

◆ GET_OP_LIST

#define GET_OP_LIST