MLIR  22.0.0git
ShardingExtensions.cpp
Go to the documentation of this file.
1 //===- ShardingExtensions.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 
15 using namespace mlir;
16 using namespace mlir::tensor;
17 using namespace mlir::shard;
18 
19 namespace {
20 
21 // Sharding of tensor.empty/tensor.splat
22 template <typename OpTy>
23 struct CreatorOpShardingInterface
24  : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
25  OpTy> {
26  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
27  auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
29  utils::IteratorType::parallel);
30  }
31 
32  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
33  MLIRContext *ctx = op->getContext();
34  Value val = op->getResult(0);
35  auto type = dyn_cast<RankedTensorType>(val.getType());
36  if (!type)
37  return {};
39  op->getNumOperands() + op->getNumResults(),
40  {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
41  }
42 
43  LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
44  ArrayRef<Sharding> operandShardings,
45  ArrayRef<Sharding> resultShardings,
46  IRMapping &partitionMap,
47  SymbolTableCollection &symbolTable,
48  OpBuilder &builder) const {
49  assert(resultShardings.size() == 1);
50  auto resType = cast<RankedTensorType>(op->getResult(0).getType());
52  ShapedType shardType;
53  if (resType.getRank() > 0) {
54  grid = shard::getGrid(op, resultShardings[0].getGridAttr(), symbolTable);
55  shardType =
56  cast<ShapedType>(shard::shardType(resType, grid, resultShardings[0]));
57  } else {
58  shardType = resType;
59  }
60  Operation *newOp = nullptr;
61  // if the sharding introduces a new dynamic dimension, we take it from
62  // the dynamic sharding info. For now bail out if it's not
63  // provided.
64  if (!shardType.hasStaticShape()) {
65  assert(op->getResult(0).hasOneUse());
66  SmallVector<Value> newOperands;
67  auto oldType = cast<ShapedType>(resType);
68  assert(oldType.getRank() == shardType.getRank());
69  int currOldOprndNum = -1;
70  shard::ShardShapeOp shapeForDevice;
71  ValueRange device;
72  Operation *newSharding = nullptr;
73  for (auto i = 0; i < oldType.getRank(); ++i) {
74  if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
75  if (!newSharding) {
76  newSharding =
77  ShardingOp::create(builder, op->getLoc(), resultShardings[0]);
78  device =
79  shard::ProcessMultiIndexOp::create(builder, op->getLoc(), grid)
80  .getResults();
81  shapeForDevice = shard::ShardShapeOp::create(
82  builder, op->getLoc(), oldType.getShape(), partitionedOperands,
83  newSharding->getResult(0), device);
84  }
85  newOperands.emplace_back(shapeForDevice.getResult()[i]);
86  } else if (oldType.isDynamicDim(i)) {
87  assert(shardType.isDynamicDim(i));
88  newOperands.emplace_back(partitionedOperands[++currOldOprndNum]);
89  }
90  }
91  newOp = OpTy::create(builder, op->getLoc(), shardType, newOperands);
92  partitionMap.map(op->getResult(0), newOp->getResult(0));
93  } else {
94  // `clone` will populate the mapping of old to new results.
95  newOp = builder.clone(*op, partitionMap);
96  }
97  newOp->getResult(0).setType(shardType);
98 
99  return success();
100  }
101 };
102 } // namespace
103 
105  DialectRegistry &registry) {
106 
107  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
108  EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
109  *ctx);
110  SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
111  *ctx);
112  });
113 }
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:552
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition: Value.h:116
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
shard::GridOp GridOp
Type shardType(Type type, GridOp grid, Sharding sharding)
Definition: ShardOps.cpp:291
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition: ShardOps.h:121
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.