14 #include "llvm/Support/Debug.h"
28 struct ConstantShardingInterface
29 :
public ShardingInterface::ExternalModel<ConstantShardingInterface,
34 ndims = type.getRank();
37 utils::IteratorType::parallel);
51 FailureOr<ShardingOption>
54 assert(resultShardings.size() == 1 &&
55 "Expecting exactly one result sharding for arith.constant");
56 auto resultSharding = resultShardings[0];
57 if (!resultSharding) {
61 ShardingArray axesArray(resultSharding.getSplitAxes().size());
63 axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
76 auto cOp = cast<ConstantOp>(op);
77 if (
auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
78 if (!value.isSplat() || !resultShardings[0]) {
82 auto sharding = resultShardings[0];
83 auto newType = cast<RankedTensorType>(
shardType(
84 cOp.getType(),
getMesh(op, sharding.getMeshAttr(), symbolTable),
86 auto newValue = value.resizeSplat(newType);
87 auto newOp = builder.
create<ConstantOp>(op->
getLoc(), newType, newValue);
89 spmdizationMap.
map(op, newOp.getOperation());
92 (void)builder.
clone(*op, spmdizationMap);
103 ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents a collection of SymbolTables.
Type getType() const
Return the type of this value.
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.