MLIR 20.0.0git — MeshOps.cpp
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>
#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
Go to the source code of this file.
Macros
#define DEBUG_TYPE "mesh-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define GET_OP_LIST
#define GET_ATTRDEF_LIST
#define GET_TYPEDEF_LIST
#define GET_OP_CLASSES
#define GET_ATTRDEF_CLASSES
#define GET_TYPEDEF_CLASSES
Functions
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs)
static FailureOr<MeshOp> getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
template <typename It> bool isUnique(It begin, It end)
static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, MeshOp mesh)
template <typename Op> static FailureOr<MeshOp> getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
template <typename InShape, typename MeshShape, typename SplitAxes, typename OutShape> static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef<int64_t> shardedDimsOffsets = {}, ArrayRef<int64_t> haloSizes = {})
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef<int64_t> device, Operation::operand_range deviceDynamic, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)
template <typename It> static auto product(It begin, It end)
template <typename R> static auto product(R &&range)
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis)
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
Definition at line 42 of file MeshOps.cpp.
#define DEBUG_TYPE "mesh-ops"
Definition at line 41 of file MeshOps.cpp.
#define GET_ATTRDEF_CLASSES
Definition at line 1414 of file MeshOps.cpp.
#define GET_ATTRDEF_LIST
#define GET_OP_CLASSES
Definition at line 1411 of file MeshOps.cpp.
#define GET_OP_LIST
#define GET_TYPEDEF_CLASSES
Definition at line 1417 of file MeshOps.cpp.
#define GET_TYPEDEF_LIST
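static FailureOr<MeshOp> getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)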
Definition at line 127 of file MeshOps.cpp.
References mlir::Operation::emitError(), mlir::mesh::getMeshOrNull(), and mlir::FlatSymbolRefAttr::getValue().
Referenced by getMeshAndVerifyAxes().
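template <typename Op>
static FailureOr<MeshOp> getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)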
Definition at line 179 of file MeshOps.cpp.
References mlir::OpState::getLoc(), getMeshAndVerify(), mlir::Op< ConcreteType, Traits >::getOperation(), and verifyMeshAxes().
template <typename It>
bool isUnique(It begin, It end)
Definition at line 140 of file MeshOps.cpp.
Referenced by mlir::sparse_tensor::SparseTensorType::isCOOType(), and verifyMeshAxes().
static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs)
Definition at line 71 of file MeshOps.cpp.
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition at line 64 of file MeshOps.cpp.
template <typename It>
static auto product(It begin, It end)
Definition at line 894 of file MeshOps.cpp.
Referenced by product().
template <typename R>
static auto product(R &&range)
Definition at line 901 of file MeshOps.cpp.
References product().
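template <typename InShape, typename MeshShape, typename SplitAxes, typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef<int64_t> shardedDimsOffsets = {}, ArrayRef<int64_t> haloSizes = {})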
Definition at line 193 of file MeshOps.cpp.
Referenced by mlir::mesh::shardShapedType().
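static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis)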
Definition at line 1032 of file MeshOps.cpp.
References mlir::mesh::collectiveProcessGroupSize().
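static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)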
Definition at line 949 of file MeshOps.cpp.
References mlir::mesh::collectiveProcessGroupSize(), mlir::Value::getLoc(), mlir::Value::getType(), and verifyDimensionCompatibility().
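static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)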
Definition at line 905 of file MeshOps.cpp.
References mlir::emitError().
Referenced by verifyAllToAllOperandAndResultShape(), verifyGatherOperandAndResultShape(), and verifyScatterOrSliceOperandAndResultShape().
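static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)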
Definition at line 922 of file MeshOps.cpp.
References mlir::mesh::collectiveProcessGroupSize(), mlir::emitError(), mlir::Value::getLoc(), mlir::Value::getType(), and verifyDimensionCompatibility().
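static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef<int64_t> device, Operation::operand_range deviceDynamic, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)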
Definition at line 867 of file MeshOps.cpp.
References mlir::emitError().
static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, MeshOp mesh)
Definition at line 156 of file MeshOps.cpp.
References mlir::emitError(), and isUnique().
Referenced by getMeshAndVerifyAxes().
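static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)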
Definition at line 994 of file MeshOps.cpp.
References mlir::mesh::collectiveProcessGroupSize(), mlir::emitError(), mlir::Value::getLoc(), mlir::Value::getType(), and verifyDimensionCompatibility().