MLIR  22.0.0git
Macros | Functions
ShardingInterface.cpp File Reference
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
#include <utility>
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc"

Go to the source code of this file.

Macros

#define DEBUG_TYPE   "sharding-interface"
 
#define DBGS()   (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 

Functions

static LogicalResult checkOperandAffineExprRecursively (AffineExpr expr, SmallVectorImpl< bool > &seenIds)
 
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr (AffineExpr expr, unsigned numDims)
 
template<typename T >
SmallVector< GridAxesAttr > fromArrayOfVector (MLIRContext *ctxt, const SmallVector< SmallVector< T >> &vec)
 
static Sharding getSharding (OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes)
 
static FailureOr< Sharding > getSharding (OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map)
 
static LogicalResult addShardOp (OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes)
 
static LogicalResult addShardOp (OpBuilder &b, OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map)
 
static bool isValueCompatibleWithFullReplicationSharding (Value value, const Sharding &sharding)
 
template<typename ValueRange , typename ShardingRage >
static bool areValuesCompatibleWithFullReplicationShardings (ValueRange &&values, ShardingRage &&shardings)
 
static void updateGridAxisAssignmentForLoopIterators (ArrayRef< GridAxis > gridAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< GridAxis >>> &gridAxesAssignmentForLoopIterators)
 

Macro Definition Documentation

◆ DBGS

#define DBGS ( )    (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

Definition at line 24 of file ShardingInterface.cpp.

◆ DEBUG_TYPE

#define DEBUG_TYPE   "sharding-interface"

Definition at line 23 of file ShardingInterface.cpp.

Function Documentation

◆ addShardOp() [1/2]

static LogicalResult addShardOp ( OpBuilder & b,
OpOperand & opOperand,
const ShardingOption & shardingOption,
AffineMap  map 
)
static

◆ addShardOp() [2/2]

static LogicalResult addShardOp ( OpBuilder & b,
OpResult  result,
const ShardingOption & shardingOption,
AffineMap  map,
ArrayRef< utils::IteratorType >  loopTypes 
)
static

◆ areValuesCompatibleWithFullReplicationShardings()

template<typename ValueRange , typename ShardingRage >
static bool areValuesCompatibleWithFullReplicationShardings ( ValueRange &&  values,
ShardingRage &&  shardings 
)
static

◆ checkOperandAffineExpr()

static FailureOr<llvm::SmallSet<unsigned, 2> > checkOperandAffineExpr ( AffineExpr  expr,
unsigned  numDims 
)
static

◆ checkOperandAffineExprRecursively()

static LogicalResult checkOperandAffineExprRecursively ( AffineExpr  expr,
SmallVectorImpl< bool > &  seenIds 
)
static

◆ fromArrayOfVector()

template<typename T >
SmallVector<GridAxesAttr> fromArrayOfVector ( MLIRContext * ctxt,
const SmallVector< SmallVector< T >> &  vec 
)

Definition at line 97 of file ShardingInterface.cpp.

References mlir::detail::DenseArrayAttrImpl< T >::get().

Referenced by getSharding().

◆ getSharding() [1/2]

static FailureOr<Sharding> getSharding ( OpOperand & opOperand,
const ShardingOption & shardingOption,
AffineMap  map 
)
static

◆ getSharding() [2/2]

static Sharding getSharding ( OpResult  result,
const ShardingOption & shardingOption,
AffineMap  map,
ArrayRef< utils::IteratorType >  loopTypes 
)
static

◆ isValueCompatibleWithFullReplicationSharding()

static bool isValueCompatibleWithFullReplicationSharding ( Value  value,
const Sharding & sharding 
)
static

◆ updateGridAxisAssignmentForLoopIterators()

static void updateGridAxisAssignmentForLoopIterators ( ArrayRef< GridAxis > gridAxesAssignmentForTensorAxis,
AffineExpr  indexingExpr,
SmallVector< std::optional< SmallVector< GridAxis >>> &  gridAxesAssignmentForLoopIterators 
)
static