MLIR  22.0.0git
Macros | Functions
ShardOps.cpp File Reference
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>
#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"

Go to the source code of this file.

Macros

#define DEBUG_TYPE   "shard-ops"
 
#define GET_OP_LIST
 
#define GET_ATTRDEF_LIST
 
#define GET_TYPEDEF_LIST
 
#define GET_OP_CLASSES
 
#define GET_ATTRDEF_CLASSES
 
#define GET_TYPEDEF_CLASSES
 

Functions

static DimensionSize operator/ (DimensionSize lhs, DimensionSize rhs)
 
static DimensionSize operator* (DimensionSize lhs, DimensionSize rhs)
 
static FailureOr< GridOp > getGridAndVerify (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable)
 
template<typename It >
bool isUnique (It begin, It end)
 
static LogicalResult verifyGridAxes (Location loc, ArrayRef< GridAxis > axes, GridOp grid)
 
template<typename Op >
static FailureOr< GridOp > getGridAndVerifyAxes (Op op, SymbolTableCollection &symbolTable)
 
template<typename InShape , typename GridShape , typename SplitAxes , typename OutShape >
static void shardShape (const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
 
static void maybeInsertTargetShardingAnnotationImpl (Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
 
static LogicalResult verifyInGroupDevice (Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
 
template<typename It >
static auto product (It begin, It end)
 
template<typename R >
static auto product (R &&range)
 
static LogicalResult verifyDimensionCompatibility (Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
 
static LogicalResult verifyGatherOperandAndResultShape (Value operand, Value result, int64_t gatherAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
 
static LogicalResult verifyAllToAllOperandAndResultShape (Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
 
static LogicalResult verifyScatterOrSliceOperandAndResultShape (Value operand, Value result, int64_t tensorAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
 
static RankedTensorType sliceResultType (Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
 

Macro Definition Documentation

◆ DEBUG_TYPE

#define DEBUG_TYPE   "shard-ops"

Definition at line 40 of file ShardOps.cpp.

◆ GET_ATTRDEF_CLASSES

#define GET_ATTRDEF_CLASSES

Definition at line 1533 of file ShardOps.cpp.

◆ GET_ATTRDEF_LIST

#define GET_ATTRDEF_LIST

◆ GET_OP_CLASSES

#define GET_OP_CLASSES

Definition at line 1530 of file ShardOps.cpp.

◆ GET_OP_LIST

#define GET_OP_LIST

◆ GET_TYPEDEF_CLASSES

#define GET_TYPEDEF_CLASSES

Definition at line 1536 of file ShardOps.cpp.

◆ GET_TYPEDEF_LIST

#define GET_TYPEDEF_LIST

Function Documentation

◆ getGridAndVerify()

static FailureOr<GridOp> getGridAndVerify ( Operation *  op,
FlatSymbolRefAttr  gridSymbol,
SymbolTableCollection &  symbolTable 
)
static

◆ getGridAndVerifyAxes()

template<typename Op >
static FailureOr<GridOp> getGridAndVerifyAxes ( Op  op,
SymbolTableCollection &  symbolTable 
)
static

◆ isUnique()

template<typename It >
bool isUnique ( It  begin,
It  end 
)

◆ maybeInsertTargetShardingAnnotationImpl()

static void maybeInsertTargetShardingAnnotationImpl ( Sharding  sharding,
Value &  operandValue,
Operation *  operandOp,
OpBuilder &  builder,
ShardOp &  newShardOp 
)
static

◆ operator*()

static DimensionSize operator* ( DimensionSize  lhs,
DimensionSize  rhs 
)
static

Definition at line 69 of file ShardOps.cpp.

◆ operator/()

static DimensionSize operator/ ( DimensionSize  lhs,
DimensionSize  rhs 
)
static

Definition at line 62 of file ShardOps.cpp.

◆ product() [1/2]

template<typename It >
static auto product ( It  begin,
It  end 
)
static

Definition at line 1014 of file ShardOps.cpp.

Referenced by product().

◆ product() [2/2]

template<typename R >
static auto product ( R &&  range)
static

Definition at line 1021 of file ShardOps.cpp.

References product().

◆ shardShape()

template<typename InShape , typename GridShape , typename SplitAxes , typename OutShape >
static void shardShape ( const InShape &  inShape,
const GridShape &  gridShape,
const SplitAxes &  splitAxes,
OutShape &  outShape,
ArrayRef< int64_t >  shardedDimsOffsets = {},
ArrayRef< int64_t >  haloSizes = {} 
)
static

Definition at line 214 of file ShardOps.cpp.

Referenced by mlir::shard::shardShapedType().

◆ sliceResultType()

static RankedTensorType sliceResultType ( Type  operandType,
GridOp  grid,
ArrayRef< GridAxis >  gridAxes,
int64_t  sliceAxis 
)
static

Definition at line 1151 of file ShardOps.cpp.

References mlir::shard::collectiveProcessGroupSize().

◆ verifyAllToAllOperandAndResultShape()

static LogicalResult verifyAllToAllOperandAndResultShape ( Value  operand,
Value  result,
int64_t  splitAxis,
int64_t  concatAxis,
ArrayRef< GridAxis >  gridAxes,
ArrayRef< int64_t >  gridShape 
)
static

◆ verifyDimensionCompatibility()

static LogicalResult verifyDimensionCompatibility ( Location  loc,
int64_t  expectedDimSize,
int64_t  resultDimSize,
int64_t  resultAxis 
)
static

◆ verifyGatherOperandAndResultShape()

static LogicalResult verifyGatherOperandAndResultShape ( Value  operand,
Value  result,
int64_t  gatherAxis,
ArrayRef< GridAxis >  gridAxes,
ArrayRef< int64_t >  gridShape 
)
static

◆ verifyGridAxes()

static LogicalResult verifyGridAxes ( Location  loc,
ArrayRef< GridAxis >  axes,
GridOp  grid 
)
static

Definition at line 177 of file ShardOps.cpp.

References mlir::emitError(), and isUnique().

Referenced by getGridAndVerifyAxes().

◆ verifyInGroupDevice()

static LogicalResult verifyInGroupDevice ( Location  loc,
StringRef  deviceName,
ArrayRef< int64_t >  device,
Operation::operand_range  deviceDynamic,
ArrayRef< GridAxis >  gridAxes,
ArrayRef< int64_t >  gridShape 
)
static

Definition at line 987 of file ShardOps.cpp.

References mlir::emitError().

◆ verifyScatterOrSliceOperandAndResultShape()

static LogicalResult verifyScatterOrSliceOperandAndResultShape ( Value  operand,
Value  result,
int64_t  tensorAxis,
ArrayRef< GridAxis >  gridAxes,
ArrayRef< int64_t >  gridShape 
)
static