22template <
typename OpTy>
23struct CreatorOpShardingInterface
24 :
public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
26 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op)
const {
28 return SmallVector<utils::IteratorType>(ndims,
29 utils::IteratorType::parallel);
32 SmallVector<AffineMap> getIndexingMaps(Operation *op)
const {
35 auto type = dyn_cast<RankedTensorType>(val.
getType());
38 return SmallVector<AffineMap>(
40 {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
43 LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
44 ArrayRef<Sharding> operandShardings,
45 ArrayRef<Sharding> resultShardings,
46 IRMapping &partitionMap,
47 SymbolTableCollection &symbolTable,
48 OpBuilder &builder)
const {
49 assert(resultShardings.size() == 1);
51 mlir::shard::GridOp grid;
53 if (resType.getRank() > 0) {
54 grid =
shard::getGrid(op, resultShardings[0].getGridAttr(), symbolTable);
60 Operation *newOp =
nullptr;
66 SmallVector<Value> newOperands;
67 auto oldType = cast<ShapedType>(resType);
68 assert(oldType.getRank() ==
shardType.getRank());
69 int currOldOprndNum = -1;
70 shard::ShardShapeOp shapeForDevice;
72 Operation *newSharding =
nullptr;
73 for (
auto i = 0; i < oldType.getRank(); ++i) {
74 if (!oldType.isDynamicDim(i) &&
shardType.isDynamicDim(i)) {
77 ShardingOp::create(builder, op->
getLoc(), resultShardings[0]);
79 shard::ProcessMultiIndexOp::create(builder, op->
getLoc(), grid)
81 shapeForDevice = shard::ShardShapeOp::create(
82 builder, op->
getLoc(), oldType.getShape(), partitionedOperands,
85 newOperands.emplace_back(shapeForDevice.getResult()[i]);
86 }
else if (oldType.isDynamicDim(i)) {
88 newOperands.emplace_back(partitionedOperands[++currOldOprndNum]);
95 newOp = builder.
clone(*op, partitionMap);
108 EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
110 SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumResults()
Return the number of results held by this operation.
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Type shardType(Type type, GridOp grid, Sharding sharding)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.