22 template <
typename OpTy>
23 struct CreatorOpShardingInterface
24 :
public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
29 utils::IteratorType::parallel);
35 auto type = dyn_cast<RankedTensorType>(val.
getType());
40 {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
49 assert(resultShardings.size() == 1);
53 if (resType.getRank() > 0) {
54 grid =
shard::getGrid(op, resultShardings[0].getGridAttr(), symbolTable);
67 auto oldType = cast<ShapedType>(resType);
68 assert(oldType.getRank() ==
shardType.getRank());
69 int currOldOprndNum = -1;
70 shard::ShardShapeOp shapeForDevice;
73 for (
auto i = 0; i < oldType.getRank(); ++i) {
74 if (!oldType.isDynamicDim(i) &&
shardType.isDynamicDim(i)) {
77 ShardingOp::create(builder, op->
getLoc(), resultShardings[0]);
79 shard::ProcessMultiIndexOp::create(builder, op->
getLoc(), grid)
81 shapeForDevice = shard::ShardShapeOp::create(
82 builder, op->
getLoc(), oldType.getShape(), partitionedOperands,
85 newOperands.emplace_back(shapeForDevice.getResult()[i]);
86 }
else if (oldType.isDynamicDim(i)) {
88 newOperands.emplace_back(partitionedOperands[++currOldOprndNum]);
95 newOp = builder.
clone(*op, partitionMap);
108 EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
110 SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Type shardType(Type type, GridOp grid, Sharding sharding)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.