MLIR  21.0.0git
ShardingInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ShardingInterfaceImpl.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "llvm/Support/Debug.h"
15 
16 using namespace mlir;
17 using namespace mlir::arith;
18 using namespace mlir::mesh;
19 
20 namespace {
21 
22 // Sharding of arith.constant
23 // RankedTensor constants can be sharded like any other tensor.
24 // %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
25 // %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
26 // Scalar constants are always replicated and need no sharding annotation.
27 
28 struct ConstantShardingInterface
29  : public ShardingInterface::ExternalModel<ConstantShardingInterface,
30  ConstantOp> {
31  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
32  auto ndims = 0;
33  if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
34  ndims = type.getRank();
35  }
37  utils::IteratorType::parallel);
38  }
39 
40  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
41  if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
43  type.getRank(), op->getContext())});
44  }
45  return {};
46  }
47 
48  // Indicate failure if no result sharding exists.
49  // Otherwise mirror result sharding if it is a tensor constant.
50  // Otherwise return replication option.
51  FailureOr<ShardingOption>
52  getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
53  ArrayRef<MeshSharding> resultShardings) const {
54  assert(resultShardings.size() == 1 &&
55  "Expecting exactly one result sharding for arith.constant");
56  auto resultSharding = resultShardings[0];
57  if (!resultSharding) {
58  return failure();
59  }
60  if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
61  ShardingArray axesArray(resultSharding.getSplitAxes().size());
62  for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
63  axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
64  }
65  return ShardingOption(axesArray, resultSharding.getMeshAttr());
66  }
67  return ShardingOption({}, resultSharding.getMeshAttr());
68  }
69 
70  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
71  ArrayRef<MeshSharding> operandShardings,
72  ArrayRef<MeshSharding> resultShardings,
73  IRMapping &spmdizationMap,
74  SymbolTableCollection &symbolTable,
75  OpBuilder &builder) const {
76  auto cOp = cast<ConstantOp>(op);
77  if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
78  if (!value.isSplat() || !resultShardings[0]) {
79  // Currently non-splat constants are not supported.
80  return failure();
81  }
82  auto sharding = resultShardings[0];
83  auto newType = cast<RankedTensorType>(shardType(
84  cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
85  sharding));
86  auto newValue = value.resizeSplat(newType);
87  auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
88  spmdizationMap.map(op->getResult(0), newOp.getResult());
89  spmdizationMap.map(op, newOp.getOperation());
90  } else {
91  // `clone` will populate the mapping of old to new results.
92  (void)builder.clone(*op, spmdizationMap);
93  }
94  return success();
95  }
96 };
97 } // namespace
98 
100  DialectRegistry &registry) {
101 
102  registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
103  ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
104  });
105 }
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:549
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Type getType() const
Return the type of this value.
Definition: Value.h:105
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:270
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:128
Include the generated interface declarations.