MLIR  20.0.0git
MeshShardingExtensions.cpp
Go to the documentation of this file.
1 //===- MeshShardingExtensions.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "llvm/Support/Debug.h"
15 
16 #define DEBUG_TYPE "tensor-sharding-impl"
17 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
18 
19 using namespace mlir;
20 using namespace mlir::tensor;
21 using namespace mlir::mesh;
22 
23 namespace {
24 
25 // Sharding of tensor.empty
26 struct EmptyOpShardingInterface
27  : public ShardingInterface::ExternalModel<EmptyOpShardingInterface,
28  tensor::EmptyOp> {
29  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
30  auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
32  utils::IteratorType::parallel);
33  }
34 
35  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
36  MLIRContext *ctx = op->getContext();
37  Value val = op->getResult(0);
38  auto type = dyn_cast<RankedTensorType>(val.getType());
39  if (!type)
40  return {};
41  return {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)};
42  }
43 
44  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
45  ArrayRef<MeshSharding> operandShardings,
46  ArrayRef<MeshSharding> resultShardings,
47  IRMapping &spmdizationMap,
48  SymbolTableCollection &symbolTable,
49  OpBuilder &builder) const {
50  auto shardType = cast<ShapedType>(mesh::shardType(
51  op->getResult(0).getType(),
52  mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable),
53  resultShardings[0]));
54  Operation *newOp = nullptr;
55  // if the sharding introduces a new dynamic dimension, we take it from
56  // the dynamic sharding info. For now bail out if it's not
57  // provided.
58  assert(resultShardings.size() == 1);
59  if (!shardType.hasStaticShape()) {
60  assert(op->getResult(0).hasOneUse());
61  SmallVector<Value> newOperands;
62  auto oldType = cast<ShapedType>(op->getResult(0).getType());
63  assert(oldType.getRank() == shardType.getRank());
64  int currOldOprndNum = -1;
65  mesh::ShardShapeOp shapeForDevice;
66  Value device;
67  Operation *newSharding = nullptr;
68  for (auto i = 0; i < oldType.getRank(); ++i) {
69  if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
70  if (!newSharding) {
71  newSharding =
72  builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
73  device = builder.create<mesh::ProcessLinearIndexOp>(
74  op->getLoc(), resultShardings[0].getMesh());
75  shapeForDevice = builder.create<mesh::ShardShapeOp>(
76  op->getLoc(), oldType.getShape(), newSharding->getResult(0),
77  device);
78  }
79  newOperands.emplace_back(shapeForDevice.getResult()[i]);
80  } else if (oldType.isDynamicDim(i)) {
81  assert(shardType.isDynamicDim(i));
82  newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
83  }
84  }
85  newOp =
86  builder.create<tensor::EmptyOp>(op->getLoc(), shardType, newOperands);
87  spmdizationMap.map(op->getResult(0), newOp->getResult(0));
88  } else {
89  // `clone` will populate the mapping of old to new results.
90  newOp = builder.clone(*op, spmdizationMap);
91  }
92  newOp->getResult(0).setType(shardType);
93 
94  return success();
95  }
96 };
97 } // namespace
98 
100  DialectRegistry &registry) {
101 
102  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
103  EmptyOp::template attachInterface<EmptyOpShardingInterface>(*ctx);
104  });
105 }
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:215
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Definition: Builders.cpp:588
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition: Value.h:140
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:264
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:126
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.