16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "tosa-sharding-impl"
19 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
32 struct MatMulOpSharding
33 :
public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
40 utils::IteratorType::parallel);
41 types[tensorType.getRank()] = utils::IteratorType::reduction;
46 getReductionLoopIteratorKinds(
Operation *op)
const {
63 template <
typename OpType>
65 OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
69 template <
typename... OpTypes>
71 (registerElemwiseOne<OpTypes>(ctx), ...);
81 ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
82 BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
83 LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
84 MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
85 LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
86 GreaterOp, GreaterEqualOp>(ctx);
88 MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Type getType() const
Return the type of this value.
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.