Partition.cpp
//===- Partition.cpp -------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Shard/Transforms/Partition.h"

#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Shard/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <utility>

namespace mlir::shard {

template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
    return sourceAxes.contains(targetAxis);
  });
}
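
// Create the sharding that results from splitting tensor dimension
// `splitTensorAxis` additionally along grid axis `splitGridAxis`. The
// per-dimension split-axes list is padded with empty entries as needed, and
// the new grid axis is appended to the dimension's list, e.g. appending grid
// axis 2 on tensor axis 0 turns [[0, 1]] into [[0, 1, 2]].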
static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx,
                                              const Sharding &sourceSharding,
                                              int64_t splitTensorAxis,
                                              GridAxis splitGridAxis) {
  SmallVector<GridAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         splitTensorAxis) {
    targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
  }
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.push_back(splitGridAxis);
  targetShardingSplitAxes[splitTensorAxis] =
      GridAxesAttr::get(ctx, targetSplitAxes);
  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}

// Split a replicated tensor along a grid axis.
// E.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the partitioned target value with its sharding.
static std::tuple<TypedValue<ShapedType>, Sharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                          Sharding sourceSharding,
                          TypedValue<ShapedType> sourceShard, GridOp grid,
                          int64_t splitTensorAxis, GridAxis splitGridAxis) {
  TypedValue<ShapedType> targetShard =
      AllSliceOp::create(builder, sourceShard, grid,
                         ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
          .getResult();
  Sharding targetSharding = targetShardingInSplitLastAxis(
      builder.getContext(), std::move(sourceSharding), splitTensorAxis,
      splitGridAxis);
  return {targetShard, targetSharding};
}

// Detect if the resharding is of type e.g.
// [[0, 1]] -> [[0, 1, 2]].
// If detected, returns the corresponding (tensor axis, grid axis) pair.
// Does not detect insertions like
// [[0, 1]] -> [[0, 2, 1]].
static std::optional<std::tuple<int64_t, GridAxis>>
detectSplitLastAxisInResharding(const Sharding &sourceSharding,
                                const Sharding &targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
          targetSharding.getSplitAxes()[tensorAxis].size()) {
        continue;
      }
      if (!llvm::equal(
              sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
              llvm::make_range(
                  targetSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1))) {
        continue;
      }
    } else {
      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
        continue;
      }
    }
    return std::make_tuple(
        tensorAxis,
        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                             const Sharding &sourceSharding,
                             Sharding targetSharding,
                             TypedValue<ShapedType> sourceShard) {
  if (auto detectRes = detectSplitLastAxisInResharding(
          sourceSharding, std::move(targetSharding))) {
    auto [tensorAxis, gridAxis] = detectRes.value();
    return splitLastAxisInResharding(builder, sourceSharding, sourceShard, grid,
                                     tensorAxis, gridAxis);
  }

  return std::nullopt;
}

// Detect if the resharding removes trailing split axes along a tensor
// dimension, e.g.
// [[0, 1, 2]] -> [[0, 1]], [[0, 1, 2]] -> [[0]] or [[0, 1, 2]] -> [].
// If detected, returns the corresponding (tensor dim, grid axes) pair, where
// the "grid axes" are the removed trailing split axes.
static std::optional<std::tuple<int64_t, SmallVector<GridAxis>>>
detectUnsplitLastAxesInResharding(const Sharding &srcSharding,
                                  const Sharding &tgtSharding) {
  size_t dimOff = 0;
  size_t srcSize = srcSharding.getSplitAxes().size();
  for (size_t tensorDim = 0; tensorDim < srcSize; ++tensorDim) {
    auto srcSplitAxes = srcSharding.getSplitAxes()[tensorDim].asArrayRef();
    if (tgtSharding.getSplitAxes().size() > tensorDim) {
      auto tgtSplitAxes = tgtSharding.getSplitAxes()[tensorDim].asArrayRef();
      // No match if the target sharding does not have fewer split axes than
      // the source sharding along the current tensor dimension.
      if (srcSplitAxes.size() <= tgtSplitAxes.size())
        continue;
      // No match if the split axes of the target sharding differ from the
      // first split axes of the source sharding.
      if (!std::equal(tgtSplitAxes.begin(), tgtSplitAxes.end(),
                      srcSplitAxes.begin()))
        continue;
      dimOff = tgtSplitAxes.size();
    } else {
      // Here the target dimension is replicated; there is nothing to do if the
      // source dimension is also replicated.
      if (srcSplitAxes.size() == 0)
        continue;
      dimOff = 0;
    }
    // This is a match. Return the current tensor dimension and the trailing
    // grid axes of the source sharding along this dimension.
    ArrayRef<GridAxis> trailingAxes = srcSplitAxes.drop_front(dimOff);
    SmallVector<GridAxis> unsplitAxes(trailingAxes.begin(), trailingAxes.end());
    return std::make_tuple(tensorDim, unsplitAxes);
  }
  return std::nullopt;
}

// Return the resulting Sharding if the unsplit-last-axes resharding is
// applied.
static Sharding targetShardingInUnsplitLastAxes(MLIRContext *ctx,
                                                const Sharding &sourceSharding,
                                                int64_t splitTensorDim,
                                                size_t numUnsplitAxes) {
  SmallVector<GridAxesAttr> resSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  assert(static_cast<int64_t>(resSplitAxes.size()) > splitTensorDim);
  ArrayRef<GridAxis> srcSplitAxes = resSplitAxes[splitTensorDim].asArrayRef();
  assert(srcSplitAxes.size() >= numUnsplitAxes);
  size_t numSplitAxes = srcSplitAxes.size() - numUnsplitAxes;
  SmallVector<GridAxis> newSplitAxes(srcSplitAxes.begin(),
                                     srcSplitAxes.begin() + numSplitAxes);
  resSplitAxes[splitTensorDim] = GridAxesAttr::get(ctx, newSplitAxes);
  return Sharding::get(sourceSharding.getGridAttr(), resSplitAxes);
}

// Return the resulting tensor type after applying the unsplit-last-axes
// resharding.
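// For a static dimension, gathering scales the local extent by the size of
// each unsplit grid axis; e.g. a local extent of 4 all-gathered over a grid
// axis of size 3 becomes 12.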
static ShapedType allGatherResultTypeInUnsplitLastAxes(
    ShapedType sourceType, int64_t splitTensorDim, ArrayRef<int64_t> gridShape,
    ArrayRef<GridAxis> unsplitAxes) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceType.getShape());
  for (GridAxis gridAxis : unsplitAxes)
    targetShape[splitTensorDim] =
        gatherDimension(targetShape[splitTensorDim], gridShape[gridAxis]);
  return sourceType.cloneWith(targetShape, sourceType.getElementType());
}

// Perform the resharding for the unsplit-last-axes case.
// This basically performs an all-gather along the unsplit grid axes.
static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxesInResharding(
    ImplicitLocOpBuilder &builder, Sharding sourceSharding,
    ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard,
    GridOp grid, int64_t splitTensorDim, ArrayRef<GridAxis> unsplitAxes) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  Sharding targetSharding = targetShardingInUnsplitLastAxes(
      ctx, std::move(sourceSharding), splitTensorDim, unsplitAxes.size());
  ShapedType allGatherResultType = allGatherResultTypeInUnsplitLastAxes(
      sourceShard.getType(), splitTensorDim, grid.getShape(), unsplitAxes);
  Value allGatherResult = AllGatherOp::create(
      builder,
      RankedTensorType::get(allGatherResultType.getShape(),
                            allGatherResultType.getElementType()),
      grid.getSymName(), unsplitAxes, sourceShard, APInt(64, splitTensorDim));
  ShapedType targetType =
      shardShapedType(sourceUnshardedShape, grid, targetSharding);
  TypedValue<ShapedType> targetShard =
      tensor::CastOp::create(builder, targetType, allGatherResult).getResult();
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
tryUnsplitLastAxesInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                               const Sharding &sourceSharding,
                               Sharding targetSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard) {
  if (auto detectRes = detectUnsplitLastAxesInResharding(
          sourceSharding, std::move(targetSharding))) {
    auto [tensorDim, gridAxes] = detectRes.value();
    return unsplitLastAxesInResharding(builder, sourceSharding,
                                       sourceUnshardedShape, sourceShard, grid,
                                       tensorDim, gridAxes);
  }

  return std::nullopt;
}

// Detect if the resharding is of type e.g.
// [[0, 1], [2]] -> [[0], [1, 2]].
// Only moving the last axis counts.
// If detected, returns the corresponding (source_tensor_axis,
// target_tensor_axis, grid_axis) tuple.
static std::optional<std::tuple<int64_t, int64_t, GridAxis>>
detectMoveLastSplitAxisInResharding(const Sharding &sourceSharding,
                                    const Sharding &targetSharding) {
  for (size_t sourceTensorAxis = 0;
       sourceTensorAxis < sourceSharding.getSplitAxes().size();
       ++sourceTensorAxis) {
    for (size_t targetTensorAxis = 0;
         targetTensorAxis < targetSharding.getSplitAxes().size();
         ++targetTensorAxis) {
      if (sourceTensorAxis == targetTensorAxis)
        continue;
      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
              targetSharding.getSplitAxes()[targetTensorAxis]
                  .asArrayRef()
                  .back())
        continue;
      if (!llvm::equal(
              llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               sourceSharding.getSplitAxes()[sourceTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1),
              llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               targetSharding.getSplitAxes()[targetTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1)))
        continue;
      return std::make_tuple(
          sourceTensorAxis, targetTensorAxis,
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
    }
  }
  return std::nullopt;
}

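// Create the sharding that results from popping the trailing split axis of
// tensor dimension `sourceTensorAxis` and appending it to tensor dimension
// `targetTensorAxis`, e.g. [[0, 1], [2]] becomes [[0], [2, 1]] when moving
// from dimension 0 to dimension 1.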
static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx,
                                             const Sharding &sourceSharding,
                                             int64_t sourceTensorAxis,
                                             int64_t targetTensorAxis) {
  SmallVector<GridAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         targetTensorAxis) {
    targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
  }

  auto sourceSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
  assert(!sourceSplitAxes.empty());
  auto gridAxis = sourceSplitAxes.back();
  sourceSplitAxes.pop_back();
  targetShardingSplitAxes[sourceTensorAxis] =
      GridAxesAttr::get(ctx, sourceSplitAxes);

  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
  targetSplitAxes.push_back(gridAxis);
  targetShardingSplitAxes[targetTensorAxis] =
      GridAxesAttr::get(ctx, targetSplitAxes);

  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}

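// Compute the shard shape after an all-to-all that moves a split axis: the
// source dimension grows by `splitCount` (it is gathered) while the target
// dimension shrinks by `splitCount` (it is scattered), leaving the total
// element count of the shard unchanged.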
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
                                                    int64_t splitCount,
                                                    int64_t sourceTensorAxis,
                                                    int64_t targetTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[sourceTensorAxis] =
      gatherDimension(targetShape[sourceTensorAxis], splitCount);
  targetShape[targetTensorAxis] =
      shardDimension(targetShape[targetTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}

static std::tuple<TypedValue<ShapedType>, Sharding>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                              Sharding sourceSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis, GridAxis gridAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  Sharding targetSharding = targetShardingInMoveLastAxis(
      ctx, std::move(sourceSharding), sourceTensorAxis, targetTensorAxis);
  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
      sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis,
      targetTensorAxis);
  Value allToAllResult = AllToAllOp::create(
      builder,
      RankedTensorType::get(allToAllResultShape.getShape(),
                            allToAllResultShape.getElementType()),
      grid.getSymName(), SmallVector<GridAxis>({gridAxis}), sourceShard,
      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, grid, targetSharding);
  TypedValue<ShapedType> targetShard =
      tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                                 const Sharding &sourceSharding,
                                 Sharding targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {
  if (auto detectRes = detectMoveLastSplitAxisInResharding(
          sourceSharding, std::move(targetSharding))) {
    auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value();
    return moveLastSplitAxisInResharding(
        builder, grid, sourceSharding, sourceUnshardedShape, sourceShard,
        sourceTensorAxis, targetTensorAxis, gridAxis);
  }

  return std::nullopt;
}

// Detect a change in the halo size (only) and create the necessary operations
// if needed. A halo size change requires copying the "core" of the source
// tensor into the "core" of the destination tensor, followed by an update-halo
// operation.
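// Halo sizes are stored interleaved per tensor dimension as
// [low_0, high_0, low_1, high_1, ...], hence the `i * 2` indexing below.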
static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                          const Sharding &sourceSharding,
                          const Sharding &targetSharding,
                          ShapedType sourceUnshardedShape,
                          TypedValue<ShapedType> sourceShard) {
  // Currently handles only cases where halo sizes differ but everything else
  // stays the same (from source to destination sharding).
  if (!sourceSharding.equalSplitAxes(targetSharding) ||
      !sourceSharding.getStaticShardedDimsOffsets().empty() ||
      !targetSharding.getStaticShardedDimsOffsets().empty() ||
      sourceSharding.equalHaloSizes(targetSharding)) {
    return std::nullopt;
  }

  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
  assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
          ShapedType::isStaticShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for shard-partition");
  auto rank = sourceShard.getType().getRank();
  auto splitAxes = sourceSharding.getSplitAxes();
  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
      strides(rank, 1), outShape(sourceShard.getType().getShape()),
      coreShape(sourceShard.getType().getShape());

  // Determine the "core" of the source and the destination.
  // The core is the local part of the shard excluding the halo regions.
  for (auto i = 0u; i < rank; ++i) {
    if (i < splitAxes.size() && !splitAxes[i].empty()) {
      if (!srcHaloSizes.empty()) {
        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
        srcCoreOffs[i] = srcHaloSizes[i * 2];
      }
      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
      outShape[i] =
          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
    }
  }

  // Extract the core from the source and copy it into the destination core.
  auto noVals = ValueRange{};
  auto initVal =
      tensor::EmptyOp::create(builder, sourceShard.getLoc(), outShape,
                              sourceShard.getType().getElementType());
  auto core = tensor::ExtractSliceOp::create(
      builder, sourceShard.getLoc(),
      RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
      sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
  auto initOprnd = tensor::InsertSliceOp::create(
      builder, sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
      tgtCoreOffs, coreShape, strides);

  // Finally, update the halo.
  auto updateHaloResult =
      UpdateHaloOp::create(
          builder, sourceShard.getLoc(),
          RankedTensorType::get(outShape,
                                sourceShard.getType().getElementType()),
          initOprnd, grid.getSymName(),
          GridAxesArrayAttr::get(builder.getContext(),
                                 sourceSharding.getSplitAxes()),
          targetSharding.getDynamicHaloSizes(),
          targetSharding.getStaticHaloSizes())
          .getResult();
  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
                         targetSharding);
}

// In most cases the sharded tensor axes must be exactly divisible by the size
// of the single grid axis they are split along. Only halo size changes can
// deal with non-divisible cases.
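// The patterns below are tried in order: first a pure halo-size update; then,
// when neither sharding uses halos or sharded-dims offsets, move-last-split-
// axis (all_to_all), split-last-axis (all_slice), and unsplit-last-axes
// (all_gather).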
static TypedValue<ShapedType>
reshard(ImplicitLocOpBuilder &builder, GridOp grid,
        const Sharding &sourceSharding, const Sharding &targetSharding,
        TypedValue<ShapedType> sourceUnshardedValue,
        TypedValue<ShapedType> sourceShard) {
  // If the source and destination shardings are the same, there is nothing to
  // do.
  if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
                                           isFullReplication(targetSharding))) {
    return sourceShard;
  }

  // Try to handle the case where resharding is needed because the halo sizes
  // differ. Supports arbitrary grid dimensionality.
  if (auto tryRes = tryUpdateHaloInResharding(
          builder, grid, sourceSharding, targetSharding,
          sourceUnshardedValue.getType(), sourceShard)) {
    return std::get<0>(tryRes.value()); // targetShard
  }

  assert(sourceShard.getType() ==
         shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
  [[maybe_unused]] ShapedType targetShardType =
      shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);
  assert(sourceShard.getType().getRank() == targetShardType.getRank());

  TypedValue<ShapedType> targetShard;
  Sharding actualTargetSharding;
  if (sourceSharding.getStaticShardedDimsOffsets().empty() &&
      targetSharding.getStaticShardedDimsOffsets().empty() &&
      sourceSharding.getStaticHaloSizes().empty() &&
      targetSharding.getStaticHaloSizes().empty()) {
    if (auto tryRes = tryMoveLastSplitAxisInResharding(
            builder, grid, sourceSharding, targetSharding,
            sourceUnshardedValue.getType(), sourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes =
                   trySplitLastAxisInResharding(builder, grid, sourceSharding,
                                                targetSharding, sourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = tryUnsplitLastAxesInResharding(
                   builder, grid, sourceSharding, targetSharding,
                   sourceUnshardedValue.getType(), sourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    }
  }

  assert(targetShard && "Did not find any pattern to apply.");
  assert(actualTargetSharding == targetSharding);
  assert(targetShard.getType() == targetShardType);
  return targetShard;
}

TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue) {
  assert(source.getResult() == target.getSrc());
  auto sourceSharding = source.getSharding();
  auto targetSharding = target.getSharding();
  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
  return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
                 source.getSrc(), sourceShardValue);
}

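// Overload that looks up the common grid of the two shard ops via the symbol
// tables and dispatches to the grid-aware overload above.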
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue,
                               SymbolTableCollection &symbolTableCollection) {
  GridOp srcGrid = getGrid(source, symbolTableCollection);
  assert(srcGrid && srcGrid == getGrid(target, symbolTableCollection));
  return reshard(builder, srcGrid, source, target, sourceShardValue);
}

void reshardingRegisterDependentDialects(DialectRegistry &registry) {
  registry.insert<shard::ShardDialect, tensor::TensorDialect>();
}

#define GEN_PASS_DEF_PARTITION
#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"

using UnshardedToShardedValueMap = DenseMap<Value, Value>;

// Get the types of block arguments for a partitioned block.
// Reads the sharding annotations of the arguments to deduce the sharded types.
// Types that are not ranked tensors are left unchanged.
static SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
                          SymbolTableCollection &symbolTableCollection) {
  SmallVector<Type> res;
  llvm::transform(
      block.getArguments(), std::back_inserter(res),
      [&symbolTableCollection](BlockArgument arg) {
        auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
          return arg.getType();
        }

        assert(rankedTensorArg.hasOneUse());
        Operation *useOp = *rankedTensorArg.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
        assert(shardOp);
        GridOp grid = getGrid(shardOp, symbolTableCollection);
        return cast<Type>(shardShapedType(rankedTensorArg.getType(), grid,
                                          shardOp.getSharding()));
      });
  return res;
}

static LogicalResult
partitionOperation(Operation &op, ArrayRef<Value> partitionedOperands,
                   ArrayRef<Sharding> operandShardings,
                   ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
                   SymbolTableCollection &symbolTableCollection,
                   OpBuilder &builder) {
  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
  if (!shardingInterface) {
    // If there is no sharding interface we are conservative and assume that
    // the op should be fully replicated on all devices.
    partitionFullyReplicatedOperation(op, partitionedOperands, operandShardings,
                                      resultShardings, partitionMap,
                                      symbolTableCollection, builder);
  } else {
    if (failed(shardingInterface.partition(
            partitionedOperands, operandShardings, resultShardings,
            partitionMap, symbolTableCollection, builder))) {
      return failure();
    }
  }

  assert(llvm::all_of(op.getResults(), [&partitionMap](OpResult result) {
    return partitionMap.contains(result);
  }));

  return success();
}

// Retrieve the sharding annotations for the operands of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
static std::vector<Sharding> getOperandShardings(Operation &op) {
  std::vector<Sharding> res;
  res.reserve(op.getNumOperands());
  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
    TypedValue<RankedTensorType> rankedTensor =
        dyn_cast<TypedValue<RankedTensorType>>(operand);
    if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
      return Sharding();
    }

    Operation *definingOp = operand.getDefiningOp();
    assert(definingOp);
    ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
    return Sharding(shardOp.getSharding());
  });
  return res;
}

// Retrieve the sharding annotations for the results of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
static std::vector<Sharding> getResultShardings(Operation &op) {
  std::vector<Sharding> res;
  res.reserve(op.getNumResults());
  llvm::transform(
      op.getResults(), std::back_inserter(res), [&op](OpResult result) {
        if (!result.hasOneUse() || result.use_empty()) {
          return Sharding();
        }
        TypedValue<RankedTensorType> rankedTensor =
            dyn_cast<TypedValue<RankedTensorType>>(result);
        if (!rankedTensor) {
          return Sharding();
        }
        Operation *userOp = *result.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
        if (shardOp) {
          return Sharding(shardOp.getSharding());
        }
        if (rankedTensor.getType().getRank() == 0) {
          // This is a 0d tensor result without explicit sharding.
          // Find a grid symbol from the operands, if any.
          // Shardings without a grid are not always fully supported yet.
          for (auto operand : op.getOperands()) {
            if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
              return Sharding(sharding.getGridAttr());
            }
          }
        }
        return Sharding();
      });
  return res;
}

static LogicalResult
partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
                   SymbolTableCollection &symbolTableCollection,
                   OpBuilder &builder) {
  Value targetPartitionValue;

  // Check if two shard ops are chained. If not, there is no need for
  // resharding, as the source and target share the same sharding.
  ShardOp srcShardOp = shardOp.getSrc().getDefiningOp<ShardOp>();
  if (!srcShardOp) {
    targetPartitionValue = partitionMap.lookup(shardOp.getSrc());
  } else {
    // Insert resharding.
    TypedValue<ShapedType> srcPartitionValue =
        cast<TypedValue<ShapedType>>(partitionMap.lookup(srcShardOp));
    targetPartitionValue = reshard(builder, srcShardOp, shardOp,
                                   srcPartitionValue, symbolTableCollection);
  }

  assert(!partitionMap.contains(shardOp.getResult()));
  partitionMap.map(shardOp.getResult(), targetPartitionValue);
  return success();
}

// Check if the block args are correctly annotated with sharding information:
// - non-tensor and 0d-tensor args are ignored
// - each tensor arg must have exactly one use, which must be a shard.shard
//   operation
static LogicalResult checkFullyAnnotated(Block &block) {
  for (const BlockArgument &arg : block.getArguments()) {
    auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
    if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0)
      continue;

    if (rankedTensorArg.getNumUses() > 1)
      return emitError(block.getParent()->getLoc())
             << "Cannot partition: expected a single use for block argument "
             << arg.getArgNumber() << " in block "
             << block.computeBlockNumber();
    Operation *useOp = *rankedTensorArg.getUsers().begin();
    auto shardOp = dyn_cast<ShardOp>(useOp);
    if (!shardOp)
      return emitError(block.getParent()->getLoc())
             << "Cannot partition: expected a shard.shard op for block "
             << "argument " << arg.getArgNumber() << " in block "
             << block.computeBlockNumber();
  }
  return success();
}

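// As a rough sketch (the exact assembly syntax may differ), a fully annotated
// op is expected to look like:
//   %in  = shard.shard %arg0 to %s0 annotate_for_users : tensor<8xf32>
//   %out = "some.op"(%in) : (tensor<8xf32>) -> tensor<8xf32>
//   %res = shard.shard %out to %s1 : tensor<8xf32>
// with 'annotate_for_users' set on the operand-side annotation only.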
// Check if the operation is correctly and fully annotated with sharding
// information:
// - Operation results must have exactly one use (e.g. the shard operation).
// - All operands and all results must be annotated, i.e. they must be
//   produced by/consumed by a shard.shard operation.
// - Result annotations must not include the 'annotate_for_users' attribute.
// - Operand annotations must include the 'annotate_for_users' attribute.
// Emits an error if the operation is not correctly and fully annotated.
static LogicalResult checkFullyAnnotated(Operation *op) {
  // Constant ops do not need to have sharding annotations.
  if (op->hasTrait<OpTrait::ConstantLike>())
    return success();

  for (OpOperand &operand : op->getOpOperands()) {
    // Non-tensor and 0d-tensor operands are ignored.
    auto rankedTT = dyn_cast<RankedTensorType>(operand.get().getType());
    if (!rankedTT || rankedTT.getRank() == 0)
      continue;

    auto shard = operand.get().getDefiningOp<ShardOp>();
    if (!shard)
      return op->emitError() << "Cannot partition: tensor operand "
                             << operand.getOperandNumber()
                             << " must be defined by a shard.shard operation.";
    if (!shard.getAnnotateForUsers())
      return op->emitError()
             << "Cannot partition: shard.shard for operand "
             << operand.getOperandNumber() << " must set 'annotate_for_users'.";
  }
  for (const OpResult &result : op->getResults()) {
    if (!result.hasOneUse())
      return op->emitError()
             << "Cannot partition: result " << result.getResultNumber()
             << " must have exactly one use.";
    auto shard = dyn_cast<ShardOp>(*result.user_begin());
    if (!shard)
      return op->emitError()
             << "Cannot partition: user of result " << result.getResultNumber()
             << " must be a shard.shard operation.";
    if (shard.getAnnotateForUsers())
      return op->emitError() << "Cannot partition: shard.shard for result "
                             << result.getResultNumber()
                             << " must not set 'annotate_for_users'.";
  }
  return success();
}

static LogicalResult
partitionOperation(Operation &op, IRMapping &partitionMap,
                   SymbolTableCollection &symbolTableCollection,
                   OpBuilder &builder) {
  if (isa<ShardingOp>(op)) {
    return success();
  }

  if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
    auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
    if (!shardOp) {
      return op.emitError("expected a shard op as source of get_sharding");
    }
    auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
    partitionMap.map(op.getResult(0), newSharding->getResult(0));
    return success();
  }

  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
  if (shardOp) {
    return partitionOperation(shardOp, partitionMap, symbolTableCollection,
                              builder);
  }

  // Check if the operation is correctly and fully annotated.
  if (failed(checkFullyAnnotated(&op)))
    return failure();

  SmallVector<Value> partitionedOperands;
  llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
                  [&partitionMap](Value operand) {
                    assert(partitionMap.contains(operand));
                    return partitionMap.lookup(operand);
                  });
  return partitionOperation(op, partitionedOperands, getOperandShardings(op),
                            getResultShardings(op), partitionMap,
                            symbolTableCollection, builder);
}

static LogicalResult
partitionBlock(Block &block, IRMapping &partitionMap,
               SymbolTableCollection &symbolTableCollection,
               OpBuilder &builder) {
  if (failed(checkFullyAnnotated(block)))
    return failure();

  SmallVector<Location> argLocations;
  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
                  [](BlockArgument arg) { return arg.getLoc(); });
  Block *newBlock = builder.createBlock(
      block.getParent(), {},
      shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
  for (auto [unshardedBlockArg, partitionedBlockArg] :
       llvm::zip(block.getArguments(), newBlock->getArguments())) {
    partitionMap.map(unshardedBlockArg, partitionedBlockArg);
  }

  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointToEnd(newBlock);
  for (Operation &op : block.getOperations()) {
    if (failed(partitionOperation(op, partitionMap, symbolTableCollection,
                                  builder))) {
      return failure();
    }
  }

  return success();
}

static LogicalResult
partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap,
                SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks so that adding new blocks does not disturb
  // the iteration.
  SmallVector<Block *> originalBlocks;
  for (Block &b : op.getBlocks()) {
    if (llvm::any_of(b.getOperations(),
                     [](Operation &op) { return isa<ShardOp>(op); })) {
      originalBlocks.push_back(&b);
    }
  }

  for (Block *block : originalBlocks) {
    if (failed(partitionBlock(*block, partitionMap, symbolTableCollection,
                              builder))) {
      return failure();
    }
  }

  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find a return op and change the function's result signature to the
  // signature of its operands.
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    if (block.empty()) {
      continue;
    }

    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
      returnOp = &block.back();
      break;
    }
  }
  if (returnOp) {
    op.setType(FunctionType::get(
        op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
        returnOp->getOperandTypes()));
  }

  return success();
}

namespace {

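// The partition pass: rewrites the function it runs on by partitioning every
// annotated operation and block (see partitionFuncOp above).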
struct Partition : public impl::PartitionBase<Partition> {
  void runOnOperation() override {
    IRMapping partitionMap;
    SymbolTableCollection symbolTableCollection;
    if (failed(partitionFuncOp(getOperation(), partitionMap,
                               symbolTableCollection))) {
      return signalPassFailure();
    }
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    reshardingRegisterDependentDialects(registry);
    registry.insert<shard::ShardDialect>();
  }
};

} // namespace

} // namespace mlir::shard