14 #include "llvm/Support/Debug.h"
16 #define DEBUG_TYPE "tensor-sharding-impl"
17 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
26 struct EmptyOpShardingInterface
27 :
public ShardingInterface::ExternalModel<EmptyOpShardingInterface,
32 utils::IteratorType::parallel);
38 auto type = dyn_cast<RankedTensorType>(val.
getType());
52 mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable),
58 assert(resultShardings.size() == 1);
63 assert(oldType.getRank() ==
shardType.getRank());
64 int currOldOprndNum = -1;
65 mesh::ShardShapeOp shapeForDevice;
68 for (
auto i = 0; i < oldType.getRank(); ++i) {
69 if (!oldType.isDynamicDim(i) &&
shardType.isDynamicDim(i)) {
72 builder.
create<ShardingOp>(op->
getLoc(), resultShardings[0]);
73 device = builder.
create<mesh::ProcessLinearIndexOp>(
74 op->
getLoc(), resultShardings[0].getMesh());
75 shapeForDevice = builder.
create<mesh::ShardShapeOp>(
79 newOperands.emplace_back(shapeForDevice.getResult()[i]);
80 }
else if (oldType.isDynamicDim(i)) {
82 newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
90 newOp = builder.
clone(*op, spmdizationMap);
103 EmptyOp::template attachInterface<EmptyOpShardingInterface>(*ctx);
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents a collection of SymbolTables.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.