// Debug-type tag for this translation unit. By LLVM convention, LLVM_DEBUG
// output is filtered on this name (e.g. `-debug-only=tosa-sharding-impl`).
#define DEBUG_TYPE "tosa-sharding-impl"
// Starts a debug line on llvm::dbgs() prefixed with "[tosa-sharding-impl]: "
// and yields the stream so callers can keep chaining `<<`.
#define DBGS() (llvm::dbgs() << "[" << DEBUG_TYPE << "]: ")
31 struct MatMulOpSharding
32 :
public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
39 utils::IteratorType::parallel);
40 types[tensorType.getRank()] = utils::IteratorType::reduction;
45 getReductionLoopIteratorKinds(
Operation *op)
const {
64 struct NegateOpSharding
65 :
public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
68 auto type = dyn_cast<RankedTensorType>(val.
getType());
72 utils::IteratorType::parallel);
79 auto type = dyn_cast<RankedTensorType>(val.
getType());
82 int64_t rank = type.getRank();
97 operandShardings, resultShardings,
98 partitionMap, symbolTable, builder);
103 template <
typename OpType>
104 static void registerElemwiseOne(
MLIRContext *ctx) {
105 OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
109 template <
typename... OpTypes>
110 static void registerElemwiseAll(
MLIRContext *ctx) {
111 (registerElemwiseOne<OpTypes>(ctx), ...);
121 ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
122 BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
123 LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
124 MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
125 LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
126 GreaterOp, GreaterEqualOp>(ctx);
128 MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
129 NegateOp::attachInterface<NegateOpSharding>(*ctx);
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents a collection of SymbolTables.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
void partitionTriviallyShardableOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.