MLIR  21.0.0git
MeshShardingExtensions.cpp
Go to the documentation of this file.
//===- MeshShardingExtensions.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "llvm/Support/Debug.h"
15 
16 #define DEBUG_TYPE "tensor-sharding-impl"
17 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
18 
19 using namespace mlir;
20 using namespace mlir::tensor;
21 using namespace mlir::mesh;
22 
23 namespace {
24 
25 // Sharding of tensor.empty/tensor.splat
26 template <typename OpTy>
27 struct CreatorOpShardingInterface
28  : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
29  OpTy> {
30  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
31  auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
33  utils::IteratorType::parallel);
34  }
35 
36  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
37  MLIRContext *ctx = op->getContext();
38  Value val = op->getResult(0);
39  auto type = dyn_cast<RankedTensorType>(val.getType());
40  if (!type)
41  return {};
43  op->getNumOperands() + op->getNumResults(),
44  {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
45  }
46 
47  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
48  ArrayRef<MeshSharding> operandShardings,
49  ArrayRef<MeshSharding> resultShardings,
50  IRMapping &spmdizationMap,
51  SymbolTableCollection &symbolTable,
52  OpBuilder &builder) const {
53  assert(resultShardings.size() == 1);
54  auto resType = cast<RankedTensorType>(op->getResult(0).getType());
55  mlir::mesh::MeshOp mesh;
56  ShapedType shardType;
57  if (resType.getRank() > 0) {
58  mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
59  shardType =
60  cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0]));
61  } else {
62  shardType = resType;
63  }
64  Operation *newOp = nullptr;
65  // if the sharding introduces a new dynamic dimension, we take it from
66  // the dynamic sharding info. For now bail out if it's not
67  // provided.
68  if (!shardType.hasStaticShape()) {
69  assert(op->getResult(0).hasOneUse());
70  SmallVector<Value> newOperands;
71  auto oldType = cast<ShapedType>(resType);
72  assert(oldType.getRank() == shardType.getRank());
73  int currOldOprndNum = -1;
74  mesh::ShardShapeOp shapeForDevice;
75  ValueRange device;
76  Operation *newSharding = nullptr;
77  for (auto i = 0; i < oldType.getRank(); ++i) {
78  if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
79  if (!newSharding) {
80  newSharding =
81  builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
82  device =
83  builder.create<mesh::ProcessMultiIndexOp>(op->getLoc(), mesh)
84  .getResults();
85  shapeForDevice = builder.create<mesh::ShardShapeOp>(
86  op->getLoc(), oldType.getShape(), spmdizedOperands,
87  newSharding->getResult(0), device);
88  }
89  newOperands.emplace_back(shapeForDevice.getResult()[i]);
90  } else if (oldType.isDynamicDim(i)) {
91  assert(shardType.isDynamicDim(i));
92  newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
93  }
94  }
95  newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
96  spmdizationMap.map(op->getResult(0), newOp->getResult(0));
97  } else {
98  // `clone` will populate the mapping of old to new results.
99  newOp = builder.clone(*op, spmdizationMap);
100  }
101  newOp->getResult(0).setType(shardType);
102 
103  return success();
104  }
105 };
106 } // namespace
107 
109  DialectRegistry &registry) {
110 
111  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
112  EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
113  *ctx);
114  SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
115  *ctx);
116  });
117 }
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:549
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition: Value.h:116
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:191
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:270
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:128
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Attaches the CreatorOpShardingInterface external model to tensor.empty and tensor.splat through a TensorDialect extension added to the registry.