MLIR  22.0.0git
ShardingInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ShardingInterfaceImpl.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 
15 using namespace mlir;
16 using namespace mlir::arith;
17 using namespace mlir::mesh;
18 
19 namespace {
20 
21 // Sharding of arith.constant
22 // RankedTensor constants can be sharded like any other tensor.
23 // %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
24 // %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
25 // Scalar constants are always replicated and need no sharding annotation.
26 
27 struct ConstantShardingInterface
28  : public ShardingInterface::ExternalModel<ConstantShardingInterface,
29  ConstantOp> {
30  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
31  auto ndims = 0;
32  if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
33  ndims = type.getRank();
34  }
36  utils::IteratorType::parallel);
37  }
38 
39  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
40  if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
42  type.getRank(), op->getContext())});
43  }
44  return {};
45  }
46 
47  // Indicate failure if no result sharding exists.
48  // Otherwise mirror result sharding if it is a tensor constant.
49  // Otherwise return replication option.
50  FailureOr<ShardingOption>
51  getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
52  ArrayRef<MeshSharding> resultShardings) const {
53  assert(resultShardings.size() == 1 &&
54  "Expecting exactly one result sharding for arith.constant");
55  auto resultSharding = resultShardings[0];
56  if (!resultSharding) {
57  return failure();
58  }
59  if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
60  ShardingArray axesArray(resultSharding.getSplitAxes().size());
61  for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
62  axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
63  }
64  return ShardingOption(axesArray, resultSharding.getMeshAttr());
65  }
66  return ShardingOption({}, resultSharding.getMeshAttr());
67  }
68 
69  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
70  ArrayRef<MeshSharding> operandShardings,
71  ArrayRef<MeshSharding> resultShardings,
72  IRMapping &spmdizationMap,
73  SymbolTableCollection &symbolTable,
74  OpBuilder &builder) const {
75  auto cOp = cast<ConstantOp>(op);
76  if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
77  if (!value.isSplat() || !resultShardings[0]) {
78  // Currently non-splat constants are not supported.
79  return failure();
80  }
81  auto sharding = resultShardings[0];
82  auto newType = cast<RankedTensorType>(shardType(
83  cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
84  sharding));
85  auto newValue = value.resizeSplat(newType);
86  auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
87  spmdizationMap.map(op->getResult(0), newOp.getResult());
88  spmdizationMap.map(op, newOp.getOperation());
89  } else {
90  // `clone` will populate the mapping of old to new results.
91  (void)builder.clone(*op, spmdizationMap);
92  }
93  return success();
94  }
95 };
96 } // namespace
97 
99  DialectRegistry &registry) {
100 
101  registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
102  ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
103  });
104 }
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:330
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:548
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Type getType() const
Return the type of this value.
Definition: Value.h:105
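void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Attaches the ShardingInterface external model for arith.constant by adding an extension on the Arith dialect to the given registry.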
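constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Note: this cross-reference appears to be mis-resolved. The call in getShardingOption is llvm::enumerate over a range of split axes, not this std::tuple overload.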
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:292
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:121
Include the generated interface declarations.