MLIR 22.0.0git
ShardingInterface.cpp File Reference
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
#include <utility>
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc"

Go to the source code of this file.

Macros

#define DEBUG_TYPE   "sharding-interface"
#define DBGS()

Functions

static LogicalResult checkOperandAffineExprRecursively (AffineExpr expr, SmallVectorImpl< bool > &seenIds)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr (AffineExpr expr, unsigned numDims)
template<typename T>
SmallVector< GridAxesAttr > fromArrayOfVector (MLIRContext *ctxt, const SmallVector< SmallVector< T > > &vec)
static Sharding getSharding (OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes)
static FailureOr< Sharding > getSharding (OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map)
static LogicalResult addShardOp (OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes)
static LogicalResult addShardOp (OpBuilder &b, OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map)
static bool isValueCompatibleWithFullReplicationSharding (Value value, const Sharding &sharding)
template<typename ValueRange, typename ShardingRage>
static bool areValuesCompatibleWithFullReplicationShardings (ValueRange &&values, ShardingRage &&shardings)
static void updateGridAxisAssignmentForLoopIterators (ArrayRef< GridAxis > gridAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< GridAxis > > > &gridAxesAssignmentForLoopIterators)

Macro Definition Documentation

◆ DBGS

#define DBGS ( )
Value:
(llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DEBUG_TYPE

Definition at line 24 of file ShardingInterface.cpp.

◆ DEBUG_TYPE

#define DEBUG_TYPE   "sharding-interface"

Definition at line 23 of file ShardingInterface.cpp.

Function Documentation

◆ addShardOp() [1/2]

LogicalResult addShardOp ( OpBuilder & b,
OpOperand & opOperand,
const ShardingOption & shardingOption,
AffineMap map )
static

◆ addShardOp() [2/2]

LogicalResult addShardOp ( OpBuilder & b,
OpResult result,
const ShardingOption & shardingOption,
AffineMap map,
ArrayRef< utils::IteratorType > loopTypes )
static

◆ areValuesCompatibleWithFullReplicationShardings()

template<typename ValueRange, typename ShardingRage>
bool areValuesCompatibleWithFullReplicationShardings ( ValueRange && values,
ShardingRage && shardings )
static

◆ checkOperandAffineExpr()

FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr ( AffineExpr expr,
unsigned numDims )
static

◆ checkOperandAffineExprRecursively()

LogicalResult checkOperandAffineExprRecursively ( AffineExpr expr,
SmallVectorImpl< bool > & seenIds )
static

◆ fromArrayOfVector()

template<typename T>
SmallVector< GridAxesAttr > fromArrayOfVector ( MLIRContext * ctxt,
const SmallVector< SmallVector< T > > & vec )

Definition at line 97 of file ShardingInterface.cpp.

References mlir::detail::DenseArrayAttrImpl< int16_t >::get().

Referenced by both overloads of getSharding().

◆ getSharding() [1/2]

◆ getSharding() [2/2]

◆ isValueCompatibleWithFullReplicationSharding()

bool isValueCompatibleWithFullReplicationSharding ( Value value,
const Sharding & sharding )
static

◆ updateGridAxisAssignmentForLoopIterators()

void updateGridAxisAssignmentForLoopIterators ( ArrayRef< GridAxis > gridAxesAssignmentForTensorAxis,
AffineExpr indexingExpr,
SmallVector< std::optional< SmallVector< GridAxis > > > & gridAxesAssignmentForLoopIterators )
static