MLIR 22.0.0git
ShardingExtensions.cpp
Go to the documentation of this file.
//===- ShardingExtensions.cpp ---------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14
15using namespace mlir;
16using namespace mlir::tensor;
17using namespace mlir::shard;
18
19namespace {
20
21// Sharding of tensor.empty/tensor.splat
22template <typename OpTy>
23struct CreatorOpShardingInterface
24 : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
25 OpTy> {
26 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
27 auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
28 return SmallVector<utils::IteratorType>(ndims,
29 utils::IteratorType::parallel);
30 }
31
32 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
33 MLIRContext *ctx = op->getContext();
34 Value val = op->getResult(0);
35 auto type = dyn_cast<RankedTensorType>(val.getType());
36 if (!type)
37 return {};
38 return SmallVector<AffineMap>(
39 op->getNumOperands() + op->getNumResults(),
40 {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
41 }
42
43 LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
44 ArrayRef<Sharding> operandShardings,
45 ArrayRef<Sharding> resultShardings,
46 IRMapping &partitionMap,
47 SymbolTableCollection &symbolTable,
48 OpBuilder &builder) const {
49 assert(resultShardings.size() == 1);
50 auto resType = cast<RankedTensorType>(op->getResult(0).getType());
51 mlir::shard::GridOp grid;
52 ShapedType shardType;
53 if (resType.getRank() > 0) {
54 grid = shard::getGrid(op, resultShardings[0].getGridAttr(), symbolTable);
55 shardType =
56 cast<ShapedType>(shard::shardType(resType, grid, resultShardings[0]));
57 } else {
58 shardType = resType;
59 }
60 Operation *newOp = nullptr;
61 // if the sharding introduces a new dynamic dimension, we take it from
62 // the dynamic sharding info. For now bail out if it's not
63 // provided.
64 if (!shardType.hasStaticShape()) {
65 assert(op->getResult(0).hasOneUse());
66 SmallVector<Value> newOperands;
67 auto oldType = cast<ShapedType>(resType);
68 assert(oldType.getRank() == shardType.getRank());
69 int currOldOprndNum = -1;
70 shard::ShardShapeOp shapeForDevice;
71 ValueRange device;
72 Operation *newSharding = nullptr;
73 for (auto i = 0; i < oldType.getRank(); ++i) {
74 if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
75 if (!newSharding) {
76 newSharding =
77 ShardingOp::create(builder, op->getLoc(), resultShardings[0]);
78 device =
79 shard::ProcessMultiIndexOp::create(builder, op->getLoc(), grid)
80 .getResults();
81 shapeForDevice = shard::ShardShapeOp::create(
82 builder, op->getLoc(), oldType.getShape(), partitionedOperands,
83 newSharding->getResult(0), device);
84 }
85 newOperands.emplace_back(shapeForDevice.getResult()[i]);
86 } else if (oldType.isDynamicDim(i)) {
87 assert(shardType.isDynamicDim(i));
88 newOperands.emplace_back(partitionedOperands[++currOldOprndNum]);
89 }
90 }
91 newOp = OpTy::create(builder, op->getLoc(), shardType, newOperands);
92 partitionMap.map(op->getResult(0), newOp->getResult(0));
93 } else {
94 // `clone` will populate the mapping of old to new results.
95 newOp = builder.clone(*op, partitionMap);
96 }
97 newOp->getResult(0).setType(shardType);
98
99 return success();
100 }
101};
102} // namespace
103
105 DialectRegistry &registry) {
106
107 registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
108 EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
109 *ctx);
110 SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
111 *ctx);
112 });
113}
return success()
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
Type getType() const
Return the type of this value.
Definition Value.h:105
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Type shardType(Type type, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:291
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:121
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.