MLIR 22.0.0git — ShardOps.cpp File Reference
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>
#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"
Go to the source code of this file.
Macros
#define DEBUG_TYPE "shard-ops"
#define GET_OP_LIST
#define GET_ATTRDEF_LIST
#define GET_TYPEDEF_LIST
#define GET_OP_CLASSES
#define GET_ATTRDEF_CLASSES
#define GET_TYPEDEF_CLASSES
Functions
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs)
static FailureOr<GridOp> getGridAndVerify(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable)
template<typename It> bool isUnique(It begin, It end)
static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes, GridOp grid)
template<typename Op> static FailureOr<GridOp> getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
template<typename InShape, typename GridShape, typename SplitAxes, typename OutShape> static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef<int64_t> shardedDimsOffsets = {}, ArrayRef<int64_t> haloSizes = {})
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef<int64_t> device, Operation::operand_range deviceDynamic, ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape)
template<typename It> static auto product(It begin, It end)
template<typename R> static auto product(R &&range)
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape)
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape)
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape)
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef<GridAxis> gridAxes, int64_t sliceAxis)
#define DEBUG_TYPE "shard-ops"
Definition at line 40 of file ShardOps.cpp.
#define GET_ATTRDEF_CLASSES
Definition at line 1533 of file ShardOps.cpp.
#define GET_ATTRDEF_LIST
#define GET_OP_CLASSES
Definition at line 1530 of file ShardOps.cpp.
#define GET_OP_LIST
#define GET_TYPEDEF_CLASSES
Definition at line 1536 of file ShardOps.cpp.
#define GET_TYPEDEF_LIST
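static FailureOr<GridOp> getGridAndVerify(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable)
Definition at line 148 of file ShardOps.cpp.
References mlir::Operation::emitError(), mlir::shard::getGridOrNull(), and mlir::FlatSymbolRefAttr::getValue().
Referenced by getGridAndVerifyAxes().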
template<typename Op> static FailureOr<GridOp> getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition at line 200 of file ShardOps.cpp.
References mlir::remark::failed(), getGridAndVerify(), mlir::OpState::getLoc(), mlir::Op< ConcreteType, Traits >::getOperation(), and verifyGridAxes().
template<typename It> bool isUnique(It begin, It end)
Definition at line 161 of file ShardOps.cpp.
Referenced by mlir::sparse_tensor::SparseTensorType::isCOOType(), and verifyGridAxes().
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
Definition at line 299 of file ShardOps.cpp.
References mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Value::getLoc(), mlir::detail::IROperandBase::getOwner(), mlir::Value::replaceUsesWithIf(), and mlir::OpBuilder::setInsertionPointAfterValue().
Referenced by mlir::shard::maybeInsertTargetShardingAnnotation().
static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs)
Definition at line 69 of file ShardOps.cpp.
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition at line 62 of file ShardOps.cpp.
template<typename It> static auto product(It begin, It end)
Definition at line 1014 of file ShardOps.cpp.
Referenced by product().
template<typename R> static auto product(R &&range)
Definition at line 1021 of file ShardOps.cpp.
References product().
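template<typename InShape, typename GridShape, typename SplitAxes, typename OutShape> static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef<int64_t> shardedDimsOffsets = {}, ArrayRef<int64_t> haloSizes = {})
Definition at line 214 of file ShardOps.cpp.
Referenced by mlir::shard::shardShapedType().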
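static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef<GridAxis> gridAxes, int64_t sliceAxis)
Definition at line 1151 of file ShardOps.cpp.
References mlir::shard::collectiveProcessGroupSize().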
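static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape)
Definition at line 1068 of file ShardOps.cpp.
References mlir::shard::collectiveProcessGroupSize(), mlir::remark::failed(), mlir::Value::getLoc(), mlir::Value::getType(), and verifyDimensionCompatibility().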
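static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
Definition at line 1025 of file ShardOps.cpp.
References mlir::emitError().
Referenced by verifyAllToAllOperandAndResultShape(), verifyGatherOperandAndResultShape(), and verifyScatterOrSliceOperandAndResultShape().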
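static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape)
Definition at line 1041 of file ShardOps.cpp.
References mlir::shard::collectiveProcessGroupSize(), mlir::emitError(), mlir::remark::failed(), mlir::Value::getLoc(), mlir::Value::getType(), and verifyDimensionCompatibility().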
static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes, GridOp grid)
Definition at line 177 of file ShardOps.cpp.
References mlir::emitError(), and isUnique().
Referenced by getGridAndVerifyAxes().
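static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef<int64_t> device, Operation::operand_range deviceDynamic, ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape)
Definition at line 987 of file ShardOps.cpp.
References mlir::emitError().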
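static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape)
Definition at line 1113 of file ShardOps.cpp.
References mlir::shard::collectiveProcessGroupSize(), mlir::emitError(), mlir::remark::failed(), mlir::Value::getLoc(), mlir::Value::getType(), and verifyDimensionCompatibility().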