MLIR  20.0.0git
Macros | Functions
ShardingInterface.cpp File Reference
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
#include <utility>
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"

Go to the source code of this file.

Macros

#define DEBUG_TYPE   "sharding-interface"
 
#define DBGS()   (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 

Functions

static LogicalResult checkOperandAffineExprRecursively (AffineExpr expr, SmallVectorImpl< bool > &seenIds)
 
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr (AffineExpr expr, unsigned numDims)
 
template<typename T >
SmallVector< MeshAxesAttr > fromArrayOfVector (MLIRContext *ctxt, const SmallVector< SmallVector< T >> &vec)
 
MeshSharding getSharding (OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
 
static FailureOr< MeshSharding > getSharding (OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map)
 
static LogicalResult addShardOp (OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
 
static LogicalResult addShardOp (OpBuilder &b, OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map)
 
static bool isValueCompatibleWithFullReplicationSharding (Value value, MeshSharding sharding)
 
template<typename ValueRange , typename MeshShardingRage >
static bool areValuesCompatibleWithFullReplicationShardings (ValueRange &&values, MeshShardingRage &&shardings)
 
static void updateMeshAxisAssignmentForLoopIterators (ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)
 

Macro Definition Documentation

◆ DBGS

#define DBGS ( )    (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

Definition at line 24 of file ShardingInterface.cpp.

◆ DEBUG_TYPE

#define DEBUG_TYPE   "sharding-interface"

Definition at line 23 of file ShardingInterface.cpp.

Function Documentation

◆ addShardOp() [1/2]

static LogicalResult addShardOp ( OpBuilder & b,
OpOperand & opOperand,
const ShardingOption & shardingOption,
AffineMap  map 
)
static

◆ addShardOp() [2/2]

static LogicalResult addShardOp ( OpBuilder & b,
OpResult  result,
const ShardingOption & shardingOption,
AffineMap  map,
ArrayRef< utils::IteratorType >  loopTypes,
ArrayRef< ReductionKind >  reductionLoopKinds 
)
static

◆ areValuesCompatibleWithFullReplicationShardings()

template<typename ValueRange , typename MeshShardingRage >
static bool areValuesCompatibleWithFullReplicationShardings ( ValueRange &&  values,
MeshShardingRage &&  shardings 
)
static

◆ checkOperandAffineExpr()

static FailureOr<llvm::SmallSet<unsigned, 2> > checkOperandAffineExpr ( AffineExpr  expr,
unsigned  numDims 
)
static

◆ checkOperandAffineExprRecursively()

static LogicalResult checkOperandAffineExprRecursively ( AffineExpr  expr,
SmallVectorImpl< bool > &  seenIds 
)
static

◆ fromArrayOfVector()

template<typename T >
SmallVector<MeshAxesAttr> fromArrayOfVector ( MLIRContext * ctxt,
const SmallVector< SmallVector< T >> &  vec 
)

Definition at line 96 of file ShardingInterface.cpp.

References mlir::detail::DenseArrayAttrImpl< T >::get().

Referenced by getSharding().

◆ getSharding() [1/2]

static FailureOr<MeshSharding> getSharding ( OpOperand & opOperand,
const ShardingOption & shardingOption,
AffineMap  map 
)
static

◆ getSharding() [2/2]

MeshSharding getSharding ( OpResult  result,
const ShardingOption & shardingOption,
AffineMap  map,
ArrayRef< utils::IteratorType >  loopTypes,
ArrayRef< ReductionKind >  reductionLoopKinds 
)

◆ isValueCompatibleWithFullReplicationSharding()

static bool isValueCompatibleWithFullReplicationSharding ( Value  value,
MeshSharding  sharding 
)
static

◆ updateMeshAxisAssignmentForLoopIterators()

static void updateMeshAxisAssignmentForLoopIterators ( ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis,
AffineExpr  indexingExpr,
SmallVector< std::optional< SmallVector< MeshAxis >>> &  meshAxesAssignmentForLoopIterators 
)
static