MLIR 22.0.0git
ShardingInterfaceImpl.cpp
Go to the documentation of this file.
1//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14
15using namespace mlir;
16using namespace mlir::arith;
17using namespace mlir::shard;
18
19namespace {
20
21// Sharding of arith.constant
22// RankedTensor constants can be sharded like any other tensor.
23// %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
24// %sharding = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
25// Scalar constants are always replicated and need no sharding annotation.
26
27struct ConstantShardingInterface
28 : public ShardingInterface::ExternalModel<ConstantShardingInterface,
29 ConstantOp> {
30 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
31 auto ndims = 0;
32 if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
33 ndims = type.getRank();
34 }
35 return SmallVector<utils::IteratorType>(ndims,
36 utils::IteratorType::parallel);
37 }
38
39 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
40 if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
41 return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
42 type.getRank(), op->getContext())});
43 }
44 return {};
45 }
46
47 // Indicate failure if no result sharding exists.
48 // Otherwise mirror result sharding if it is a tensor constant.
49 // Otherwise return replication option.
50 FailureOr<ShardingOption>
51 getShardingOption(Operation *op, ArrayRef<Sharding> operandShardings,
52 ArrayRef<Sharding> resultShardings) const {
53 assert(resultShardings.size() == 1 &&
54 "Expecting exactly one result sharding for arith.constant");
55 auto resultSharding = resultShardings[0];
56 if (!resultSharding) {
57 return failure();
58 }
59 if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
60 ShardingArray axesArray(resultSharding.getSplitAxes().size());
61 for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
62 axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
63 }
64 return ShardingOption(axesArray, resultSharding.getGridAttr());
65 }
66 return ShardingOption({}, resultSharding.getGridAttr());
67 }
68
69 LogicalResult partition(Operation *op, ArrayRef<Value> partitiondOperands,
70 ArrayRef<Sharding> operandShardings,
71 ArrayRef<Sharding> resultShardings,
72 IRMapping &partitionMap,
73 SymbolTableCollection &symbolTable,
74 OpBuilder &builder) const {
75 auto cOp = cast<ConstantOp>(op);
76 if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
77 if (!value.isSplat() || !resultShardings[0]) {
78 // Currently non-splat constants are not supported.
79 return failure();
80 }
81 auto sharding = resultShardings[0];
82 auto newType = cast<RankedTensorType>(shardType(
83 cOp.getType(), getGrid(op, sharding.getGridAttr(), symbolTable),
84 sharding));
85 auto newValue = value.resizeSplat(newType);
86 auto newOp = ConstantOp::create(builder, op->getLoc(), newType, newValue);
87 partitionMap.map(op->getResult(0), newOp.getResult());
88 partitionMap.map(op, newOp.getOperation());
89 } else {
90 // `clone` will populate the mapping of old to new results.
91 (void)builder.clone(*op, partitionMap);
92 }
93 return success();
94 }
95};
96} // namespace
97
99 DialectRegistry &registry) {
100
101 registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
102 ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
103 });
104}
return success()
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
Type getType() const
Return the type of this value.
Definition Value.h:105
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Type shardType(Type type, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:291
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:121
SmallVector< SmallVector< GridAxis > > ShardingArray
Include the generated interface declarations.