27struct ConstantShardingInterface
28 :
public ShardingInterface::ExternalModel<ConstantShardingInterface,
30 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op)
const {
33 ndims = type.getRank();
35 return SmallVector<utils::IteratorType>(ndims,
36 utils::IteratorType::parallel);
39 SmallVector<AffineMap> getIndexingMaps(Operation *op)
const {
50 FailureOr<ShardingOption>
51 getShardingOption(Operation *op, ArrayRef<Sharding> operandShardings,
52 ArrayRef<Sharding> resultShardings)
const {
53 assert(resultShardings.size() == 1 &&
54 "Expecting exactly one result sharding for arith.constant");
55 auto resultSharding = resultShardings[0];
56 if (!resultSharding) {
60 ShardingArray axesArray(resultSharding.getSplitAxes().size());
61 for (
auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
62 axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
64 return ShardingOption(axesArray, resultSharding.getGridAttr());
66 return ShardingOption({}, resultSharding.getGridAttr());
69 LogicalResult partition(Operation *op, ArrayRef<Value> partitiondOperands,
70 ArrayRef<Sharding> operandShardings,
71 ArrayRef<Sharding> resultShardings,
72 IRMapping &partitionMap,
73 SymbolTableCollection &symbolTable,
74 OpBuilder &builder)
const {
75 auto cOp = cast<ConstantOp>(op);
76 if (
auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
77 if (!value.isSplat() || !resultShardings[0]) {
81 auto sharding = resultShardings[0];
82 auto newType = cast<RankedTensorType>(
shardType(
83 cOp.getType(),
getGrid(op, sharding.getGridAttr(), symbolTable),
85 auto newValue = value.resizeSplat(newType);
86 auto newOp = ConstantOp::create(builder, op->
getLoc(), newType, newValue);
88 partitionMap.
map(op, newOp.getOperation());
91 (void)builder.
clone(*op, partitionMap);
102 ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
Type getType() const
Return the type of this value.
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Type shardType(Type type, GridOp grid, Sharding sharding)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
SmallVector< SmallVector< GridAxis > > ShardingArray
Include the generated interface declarations.