27 struct ConstantShardingInterface
28 :
public ShardingInterface::ExternalModel<ConstantShardingInterface,
33 ndims = type.getRank();
36 utils::IteratorType::parallel);
50 FailureOr<ShardingOption>
53 assert(resultShardings.size() == 1 &&
54 "Expecting exactly one result sharding for arith.constant");
55 auto resultSharding = resultShardings[0];
56 if (!resultSharding) {
60 ShardingArray axesArray(resultSharding.getSplitAxes().size());
62 axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
75 auto cOp = cast<ConstantOp>(op);
76 if (
auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
77 if (!value.isSplat() || !resultShardings[0]) {
81 auto sharding = resultShardings[0];
82 auto newType = cast<RankedTensorType>(
shardType(
83 cOp.getType(),
getGrid(op, sharding.getGridAttr(), symbolTable),
85 auto newValue = value.resizeSplat(newType);
86 auto newOp = ConstantOp::create(builder, op->
getLoc(), newType, newValue);
88 partitionMap.
map(op, newOp.getOperation());
91 (void)builder.
clone(*op, partitionMap);
102 ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents a collection of SymbolTables.
Type getType() const
Return the type of this value.
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static GridOp getGrid(Operation *op, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, SymbolTableCollection &symbolTable)
Type shardType(Type type, GridOp grid, Sharding sharding)
Include the generated interface declarations.