MLIR 22.0.0git
ShardOps.cpp File Reference
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>
#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"

Go to the source code of this file.

Macros

#define DEBUG_TYPE   "shard-ops"
#define GET_OP_LIST
#define GET_ATTRDEF_LIST
#define GET_TYPEDEF_LIST
#define GET_OP_CLASSES
#define GET_ATTRDEF_CLASSES
#define GET_TYPEDEF_CLASSES

Functions

static DimensionSize operator/ (DimensionSize lhs, DimensionSize rhs)
static DimensionSize operator* (DimensionSize lhs, DimensionSize rhs)
static FailureOr< GridOp > getGridAndVerify (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable)
template<typename It>
static bool isUnique (It begin, It end)
static LogicalResult verifyGridAxes (Location loc, ArrayRef< GridAxis > axes, GridOp grid)
template<typename Op>
static FailureOr< GridOp > getGridAndVerifyAxes (Op op, SymbolTableCollection &symbolTable)
template<typename InShape, typename GridShape, typename SplitAxes, typename OutShape>
static void shardShape (const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
static void maybeInsertTargetShardingAnnotationImpl (Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
static LogicalResult verifyInGroupDevice (Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static LogicalResult verifyDimensionCompatibility (Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static LogicalResult verifyGatherOperandAndResultShape (Value operand, Value result, int64_t gatherAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static LogicalResult verifyAllToAllOperandAndResultShape (Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static LogicalResult verifyScatterOrSliceOperandAndResultShape (Value operand, Value result, int64_t tensorAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static RankedTensorType sliceResultType (Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)

Macro Definition Documentation

◆ DEBUG_TYPE

#define DEBUG_TYPE   "shard-ops"

Definition at line 40 of file ShardOps.cpp.

◆ GET_ATTRDEF_CLASSES

#define GET_ATTRDEF_CLASSES

Definition at line 1521 of file ShardOps.cpp.

◆ GET_ATTRDEF_LIST

#define GET_ATTRDEF_LIST

◆ GET_OP_CLASSES

#define GET_OP_CLASSES

Definition at line 1518 of file ShardOps.cpp.

◆ GET_OP_LIST

#define GET_OP_LIST

◆ GET_TYPEDEF_CLASSES

#define GET_TYPEDEF_CLASSES

Definition at line 1524 of file ShardOps.cpp.

◆ GET_TYPEDEF_LIST

#define GET_TYPEDEF_LIST

Function Documentation

◆ getGridAndVerify()

FailureOr< GridOp > getGridAndVerify ( Operation * op,
FlatSymbolRefAttr gridSymbol,
SymbolTableCollection & symbolTable )
static

◆ getGridAndVerifyAxes()

template<typename Op>
FailureOr< GridOp > getGridAndVerifyAxes ( Op op,
SymbolTableCollection & symbolTable )
static

◆ isUnique()

template<typename It>
bool isUnique ( It begin,
It end )
static

◆ maybeInsertTargetShardingAnnotationImpl()

void maybeInsertTargetShardingAnnotationImpl ( Sharding sharding,
Value & operandValue,
Operation * operandOp,
OpBuilder & builder,
ShardOp & newShardOp )
static

◆ operator*()

DimensionSize operator* ( DimensionSize lhs,
DimensionSize rhs )
static

Definition at line 69 of file ShardOps.cpp.

References lhs and rhs.

◆ operator/()

DimensionSize operator/ ( DimensionSize lhs,
DimensionSize rhs )
static

Definition at line 62 of file ShardOps.cpp.

References lhs and rhs.

◆ shardShape()

template<typename InShape, typename GridShape, typename SplitAxes, typename OutShape>
void shardShape ( const InShape & inShape,
const GridShape & gridShape,
const SplitAxes & splitAxes,
OutShape & outShape,
ArrayRef< int64_t > shardedDimsOffsets = {},
ArrayRef< int64_t > haloSizes = {} )
static

Definition at line 214 of file ShardOps.cpp.

Referenced by mlir::shard::shardShapedType().

◆ sliceResultType()

RankedTensorType sliceResultType ( Type operandType,
GridOp grid,
ArrayRef< GridAxis > gridAxes,
int64_t sliceAxis )
static

Definition at line 1139 of file ShardOps.cpp.

References mlir::shard::collectiveProcessGroupSize().

◆ verifyAllToAllOperandAndResultShape()

LogicalResult verifyAllToAllOperandAndResultShape ( Value operand,
Value result,
int64_t splitAxis,
int64_t concatAxis,
ArrayRef< GridAxis > gridAxes,
ArrayRef< int64_t > gridShape )
static

◆ verifyDimensionCompatibility()

LogicalResult verifyDimensionCompatibility ( Location loc,
int64_t expectedDimSize,
int64_t resultDimSize,
int64_t resultAxis )
static

◆ verifyGatherOperandAndResultShape()

LogicalResult verifyGatherOperandAndResultShape ( Value operand,
Value result,
int64_t gatherAxis,
ArrayRef< GridAxis > gridAxes,
ArrayRef< int64_t > gridShape )
static

◆ verifyGridAxes()

LogicalResult verifyGridAxes ( Location loc,
ArrayRef< GridAxis > axes,
GridOp grid )
static

Definition at line 177 of file ShardOps.cpp.

References mlir::emitError(), isUnique(), and success().

Referenced by getGridAndVerifyAxes().

◆ verifyInGroupDevice()

LogicalResult verifyInGroupDevice ( Location loc,
StringRef deviceName,
ArrayRef< int64_t > device,
Operation::operand_range deviceDynamic,
ArrayRef< GridAxis > gridAxes,
ArrayRef< int64_t > gridShape )
static

Definition at line 987 of file ShardOps.cpp.

References mlir::emitError(), and success().

◆ verifyScatterOrSliceOperandAndResultShape()

LogicalResult verifyScatterOrSliceOperandAndResultShape ( Value operand,
Value result,
int64_t tensorAxis,
ArrayRef< GridAxis > gridAxes,
ArrayRef< int64_t > gridShape )
static