MLIR
19.0.0git
|
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>
#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
Go to the source code of this file.
Macros
#define DEBUG_TYPE "mesh-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define GET_OP_LIST
#define GET_ATTRDEF_LIST
#define GET_OP_CLASSES
#define GET_ATTRDEF_CLASSES
Functions
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs)
static FailureOr<MeshOp> getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
template<typename It> bool isUnique(It begin, It end)
static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, MeshOp mesh)
template<typename InShape, typename MeshShape, typename SplitAxes, typename OutShape> static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape)
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef<int64_t> device, Operation::operand_range deviceDynamic, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)
template<typename Op> static FailureOr<MeshOp> getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
template<typename It> static auto product(It begin, It end)
template<typename R> static auto product(R &&range)
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis)
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
Definition at line 39 of file MeshOps.cpp.
#define DEBUG_TYPE "mesh-ops"
Definition at line 38 of file MeshOps.cpp.
#define GET_ATTRDEF_CLASSES
Definition at line 963 of file MeshOps.cpp.
#define GET_ATTRDEF_LIST
#define GET_OP_CLASSES
Definition at line 960 of file MeshOps.cpp.
#define GET_OP_LIST
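static FailureOr<MeshOp> getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)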
Definition at line 99 of file MeshOps.cpp.
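template<typename Op> static FailureOr<MeshOp> getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)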
Definition at line 444 of file MeshOps.cpp.
template<typename It> bool isUnique(It begin, It end)
Definition at line 112 of file MeshOps.cpp.
Referenced by mlir::sparse_tensor::SparseTensorType::isCOOType(), and verifyMeshAxes().
static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs)
Definition at line 68 of file MeshOps.cpp.
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition at line 61 of file MeshOps.cpp.
template<typename It> static auto product(It begin, It end)
Definition at line 457 of file MeshOps.cpp.
Referenced by product().
template<typename R> static auto product(R &&range)
Definition at line 464 of file MeshOps.cpp.
References product().
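template<typename InShape, typename MeshShape, typename SplitAxes, typename OutShape> static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape)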
Definition at line 151 of file MeshOps.cpp.
References mlir::mesh::collectiveProcessGroupSize(), copy(), mlir::detail::enumerate(), and mlir::mesh::shardDimension().
Referenced by mlir::mesh::shardShapedType().
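static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis)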
Definition at line 595 of file MeshOps.cpp.
References mlir::mesh::collectiveProcessGroupSize().
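static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)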
Definition at line 512 of file MeshOps.cpp.
References mlir::mesh::collectiveProcessGroupSize(), mlir::failed(), mlir::failure(), mlir::Value::getLoc(), mlir::Value::getType(), mlir::success(), and verifyDimensionCompatibility().
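static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)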
Definition at line 468 of file MeshOps.cpp.
References mlir::emitError(), and mlir::success().
Referenced by verifyAllToAllOperandAndResultShape(), verifyGatherOperandAndResultShape(), and verifyScatterOrSliceOperandAndResultShape().
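static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)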
Definition at line 485 of file MeshOps.cpp.
References mlir::mesh::collectiveProcessGroupSize(), mlir::emitError(), mlir::failed(), mlir::failure(), mlir::Value::getLoc(), mlir::Value::getType(), mlir::success(), and verifyDimensionCompatibility().
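static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef<int64_t> device, Operation::operand_range deviceDynamic, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)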
Definition at line 416 of file MeshOps.cpp.
References mlir::emitError(), and mlir::success().
static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, MeshOp mesh)
Definition at line 128 of file MeshOps.cpp.
References mlir::emitError(), isUnique(), and mlir::success().
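static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape)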
Definition at line 557 of file MeshOps.cpp.
References mlir::mesh::collectiveProcessGroupSize(), mlir::emitError(), mlir::failed(), mlir::failure(), mlir::Value::getLoc(), mlir::Value::getType(), mlir::success(), and verifyDimensionCompatibility().