MLIR  20.0.0git
Macros | Functions
MeshOps.cpp File Reference
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>
#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"

Go to the source code of this file.

Macros

#define DEBUG_TYPE   "mesh-ops"
 
#define DBGS()   (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 
#define GET_OP_LIST
 
#define GET_ATTRDEF_LIST
 
#define GET_TYPEDEF_LIST
 
#define GET_OP_CLASSES
 
#define GET_ATTRDEF_CLASSES
 
#define GET_TYPEDEF_CLASSES
 

Functions

static DimensionSize operator/ (DimensionSize lhs, DimensionSize rhs)
 
static DimensionSize operator* (DimensionSize lhs, DimensionSize rhs)
 
static FailureOr< MeshOp > getMeshAndVerify (Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
 
template<typename It >
bool isUnique (It begin, It end)
 
static LogicalResult verifyMeshAxes (Location loc, ArrayRef< MeshAxis > axes, MeshOp mesh)
 
template<typename Op >
static FailureOr< MeshOp > getMeshAndVerifyAxes (Op op, SymbolTableCollection &symbolTable)
 
template<typename InShape , typename MeshShape , typename SplitAxes , typename OutShape >
static void shardShape (const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsSizes={}, ArrayRef< int64_t > haloSizes={})
 
static LogicalResult verifyInGroupDevice (Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
 
template<typename It >
static auto product (It begin, It end)
 
template<typename R >
static auto product (R &&range)
 
static LogicalResult verifyDimensionCompatibility (Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
 
static LogicalResult verifyGatherOperandAndResultShape (Value operand, Value result, int64_t gatherAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
 
static LogicalResult verifyAllToAllOperandAndResultShape (Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
 
static LogicalResult verifyScatterOrSliceOperandAndResultShape (Value operand, Value result, int64_t tensorAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
 
static RankedTensorType sliceResultType (Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
 

Macro Definition Documentation

◆ DBGS

#define DBGS ( )    (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

Definition at line 42 of file MeshOps.cpp.

◆ DEBUG_TYPE

#define DEBUG_TYPE   "mesh-ops"

Definition at line 41 of file MeshOps.cpp.

◆ GET_ATTRDEF_CLASSES

#define GET_ATTRDEF_CLASSES

Definition at line 1329 of file MeshOps.cpp.

◆ GET_ATTRDEF_LIST

#define GET_ATTRDEF_LIST

◆ GET_OP_CLASSES

#define GET_OP_CLASSES

Definition at line 1326 of file MeshOps.cpp.

◆ GET_OP_LIST

#define GET_OP_LIST

◆ GET_TYPEDEF_CLASSES

#define GET_TYPEDEF_CLASSES

Definition at line 1332 of file MeshOps.cpp.

◆ GET_TYPEDEF_LIST

#define GET_TYPEDEF_LIST

Function Documentation

◆ getMeshAndVerify()

static FailureOr<MeshOp> getMeshAndVerify ( Operation *  op,
FlatSymbolRefAttr  meshSymbol,
SymbolTableCollection &  symbolTable 
)
static

Definition at line 127 of file MeshOps.cpp.

◆ getMeshAndVerifyAxes()

template<typename Op >
static FailureOr<MeshOp> getMeshAndVerifyAxes ( Op  op,
SymbolTableCollection &  symbolTable 
)
static

Definition at line 179 of file MeshOps.cpp.

◆ isUnique()

template<typename It >
bool isUnique ( It  begin,
It  end 
)

◆ operator*()

static DimensionSize operator* ( DimensionSize  lhs,
DimensionSize  rhs 
)
static

Definition at line 71 of file MeshOps.cpp.

◆ operator/()

static DimensionSize operator/ ( DimensionSize  lhs,
DimensionSize  rhs 
)
static

Definition at line 64 of file MeshOps.cpp.

◆ product() [1/2]

template<typename It >
static auto product ( It  begin,
It  end 
)
static

Definition at line 809 of file MeshOps.cpp.

Referenced by product().

◆ product() [2/2]

template<typename R >
static auto product ( R &&  range)
static

Definition at line 816 of file MeshOps.cpp.

References product().

◆ shardShape()

template<typename InShape , typename MeshShape , typename SplitAxes , typename OutShape >
static void shardShape ( const InShape &  inShape,
const MeshShape &  meshShape,
const SplitAxes &  splitAxes,
OutShape &  outShape,
ArrayRef< int64_t >  shardedDimsSizes = {},
ArrayRef< int64_t >  haloSizes = {} 
)
static

Definition at line 193 of file MeshOps.cpp.

Referenced by mlir::mesh::shardShapedType().

◆ sliceResultType()

static RankedTensorType sliceResultType ( Type  operandType,
MeshOp  mesh,
ArrayRef< MeshAxis >  meshAxes,
int64_t  sliceAxis 
)
static

Definition at line 947 of file MeshOps.cpp.

References mlir::mesh::collectiveProcessGroupSize().

◆ verifyAllToAllOperandAndResultShape()

static LogicalResult verifyAllToAllOperandAndResultShape ( Value  operand,
Value  result,
int64_t  splitAxis,
int64_t  concatAxis,
ArrayRef< MeshAxis >  meshAxes,
ArrayRef< int64_t >  meshShape 
)
static

◆ verifyDimensionCompatibility()

static LogicalResult verifyDimensionCompatibility ( Location  loc,
int64_t  expectedDimSize,
int64_t  resultDimSize,
int64_t  resultAxis 
)
static

◆ verifyGatherOperandAndResultShape()

static LogicalResult verifyGatherOperandAndResultShape ( Value  operand,
Value  result,
int64_t  gatherAxis,
ArrayRef< MeshAxis >  meshAxes,
ArrayRef< int64_t >  meshShape 
)
static

◆ verifyInGroupDevice()

static LogicalResult verifyInGroupDevice ( Location  loc,
StringRef  deviceName,
ArrayRef< int64_t >  device,
Operation::operand_range  deviceDynamic,
ArrayRef< MeshAxis >  meshAxes,
ArrayRef< int64_t >  meshShape 
)
static

Definition at line 782 of file MeshOps.cpp.

References mlir::emitError().

◆ verifyMeshAxes()

static LogicalResult verifyMeshAxes ( Location  loc,
ArrayRef< MeshAxis >  axes,
MeshOp  mesh 
)
static

Definition at line 156 of file MeshOps.cpp.

References mlir::emitError(), and isUnique().

◆ verifyScatterOrSliceOperandAndResultShape()

static LogicalResult verifyScatterOrSliceOperandAndResultShape ( Value  operand,
Value  result,
int64_t  tensorAxis,
ArrayRef< MeshAxis >  meshAxes,
ArrayRef< int64_t >  meshShape 
)
static