16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "tosa-sharding-impl"
19 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
32 struct MatMulOpSharding
33 :
public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
40 utils::IteratorType::parallel);
41 types[tensorType.getRank()] = utils::IteratorType::reduction;
46 getReductionLoopIteratorKinds(
Operation *op)
const {
65 struct NegateOpSharding
66 :
public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
69 auto type = dyn_cast<RankedTensorType>(val.
getType());
73 utils::IteratorType::parallel);
80 auto type = dyn_cast<RankedTensorType>(val.
getType());
83 int64_t rank = type.getRank();
98 resultShardings, spmdizationMap,
99 symbolTable, builder);
104 template <
typename OpType>
105 static void registerElemwiseOne(
MLIRContext *ctx) {
106 OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
110 template <
typename... OpTypes>
111 static void registerElemwiseAll(
MLIRContext *ctx) {
112 (registerElemwiseOne<OpTypes>(ctx), ...);
122 ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
123 BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
124 LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
125 MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
126 LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
127 GreaterOp, GreaterEqualOp>(ctx);
129 MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
130 NegateOp::attachInterface<NegateOpSharding>(*ctx);
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents a collection of SymbolTables.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.