14 #include "llvm/Support/Debug.h"
// Debug category for LLVM's debug-output facilities (llvm/Support/Debug.h);
// output tagged with it can be selected with -debug-only=tensor-sharding-impl.
#define DEBUG_TYPE "tensor-sharding-impl"
// Debug-stream helper: llvm::dbgs() with a "[tensor-sharding-impl]: " prefix.
// The prefix is assembled by adjacent-string-literal concatenation at compile
// time, so only a single string is inserted into the stream.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
26 template <
typename OpTy>
27 struct CreatorOpShardingInterface
28 :
public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
33 utils::IteratorType::parallel);
39 auto type = dyn_cast<RankedTensorType>(val.
getType());
44 {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
53 assert(resultShardings.size() == 1);
57 if (resType.getRank() > 0) {
58 mesh =
mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
71 auto oldType = cast<ShapedType>(resType);
72 assert(oldType.getRank() ==
shardType.getRank());
73 int currOldOprndNum = -1;
74 mesh::ShardShapeOp shapeForDevice;
77 for (
auto i = 0; i < oldType.getRank(); ++i) {
78 if (!oldType.isDynamicDim(i) &&
shardType.isDynamicDim(i)) {
81 builder.
create<ShardingOp>(op->
getLoc(), resultShardings[0]);
83 builder.
create<mesh::ProcessMultiIndexOp>(op->
getLoc(), mesh)
85 shapeForDevice = builder.
create<mesh::ShardShapeOp>(
86 op->
getLoc(), oldType.getShape(), spmdizedOperands,
89 newOperands.emplace_back(shapeForDevice.getResult()[i]);
90 }
else if (oldType.isDynamicDim(i)) {
92 newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
99 newOp = builder.
clone(*op, spmdizationMap);
112 EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
114 SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.