MLIR 22.0.0git
Partition.cpp
Go to the documentation of this file.
1//===- Partition.cpp --------------------------------------------- C++ --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15#include "mlir/IR/Builders.h"
19#include "mlir/IR/Diagnostics.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/IR/Value.h"
27#include "mlir/Pass/Pass.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Support/Casting.h"
32#include <iterator>
33#include <optional>
34#include <tuple>
36namespace mlir::shard {
37
38template <typename SourceAxes, typename TargetAxes>
39static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
40 const TargetAxes &targetAxes) {
41 return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
42 return sourceAxes.contains(targetAxis);
43 });
44}
45
47 Sharding sourceSharding,
48 int64_t splitTensorAxis,
49 GridAxis splitGridAxis) {
50 SmallVector<GridAxesAttr> targetShardingSplitAxes =
51 llvm::to_vector(sourceSharding.getSplitAxes());
52 while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
53 splitTensorAxis) {
54 targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
55 }
56 auto targetSplitAxes =
57 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
58 targetSplitAxes.push_back(splitGridAxis);
59 targetShardingSplitAxes[splitTensorAxis] =
60 GridAxesAttr::get(ctx, targetSplitAxes);
61 return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
63
64// Split a replicated tensor along a grid axis.
65// E.g. [[0, 1]] -> [[0, 1, 2]].
66// Returns the partitioned target value with its sharding.
67static std::tuple<TypedValue<ShapedType>, Sharding>
69 Sharding sourceSharding,
70 TypedValue<ShapedType> sourceShard, GridOp grid,
71 int64_t splitTensorAxis, GridAxis splitGridAxis) {
72 TypedValue<ShapedType> targetShard =
73 AllSliceOp::create(builder, sourceShard, grid,
74 ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
75 .getResult();
77 builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
78 return {targetShard, targetSharding};
79}
81// Detect if the resharding is of type e.g.
82// [[0, 1]] -> [[0, 1, 2]].
83// If detected, returns the corresponding tensor axis grid axis pair.
84// Does not detect insertions like
85// [[0, 1]] -> [[0, 2, 1]].
86static std::optional<std::tuple<int64_t, GridAxis>>
88 Sharding targetSharding) {
89 for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
90 ++tensorAxis) {
91 if (sourceSharding.getSplitAxes().size() > tensorAxis) {
92 if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
93 targetSharding.getSplitAxes()[tensorAxis].size()) {
94 continue;
95 }
96 if (!llvm::equal(
97 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
98 llvm::make_range(
99 targetSharding.getSplitAxes()[tensorAxis]
100 .asArrayRef()
101 .begin(),
102 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
103 1))) {
104 continue;
105 }
106 } else {
107 if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
108 continue;
109 }
110 }
111 return std::make_tuple(
112 tensorAxis,
113 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
114 }
115 return std::nullopt;
116}
117
118static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
120 Sharding sourceSharding, Sharding targetSharding,
121 TypedValue<ShapedType> sourceShard) {
122 if (auto detectRes =
123 detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
124 auto [tensorAxis, gridAxis] = detectRes.value();
125 return splitLastAxisInResharding(builder, sourceSharding, sourceShard, grid,
126 tensorAxis, gridAxis);
127 }
128
129 return std::nullopt;
130}
131
132// Detect if the resharding is of type e.g.
133// [[0, 1, 2]] -> [[0, 1]].
134// If detected, returns the corresponding tensor axis grid axis pair.
135static std::optional<std::tuple<int64_t, GridAxis>>
137 Sharding targetSharding) {
138 for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
139 ++tensorAxis) {
140 if (targetSharding.getSplitAxes().size() > tensorAxis) {
141 if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
142 targetSharding.getSplitAxes()[tensorAxis].size() + 1)
143 continue;
144 if (!llvm::equal(
145 llvm::make_range(
146 sourceSharding.getSplitAxes()[tensorAxis]
147 .asArrayRef()
148 .begin(),
149 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
150 1),
151 targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
152 continue;
153 } else {
154 if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
155 continue;
156 }
157 return std::make_tuple(
158 tensorAxis,
159 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
160 }
161 return std::nullopt;
162}
163
165 Sharding sourceSharding,
166 int64_t splitTensorAxis) {
167 SmallVector<GridAxesAttr> targetShardingSplitAxes =
168 llvm::to_vector(sourceSharding.getSplitAxes());
169 assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
170 splitTensorAxis);
171 auto targetSplitAxes =
172 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
173
174 targetSplitAxes.pop_back();
175 targetShardingSplitAxes[splitTensorAxis] =
176 GridAxesAttr::get(ctx, targetSplitAxes);
177 return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
178}
179
181 ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
182 SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
183 targetShape[splitTensorAxis] =
184 gatherDimension(targetShape[splitTensorAxis], splitCount);
185 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
186}
187
188static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
189 ImplicitLocOpBuilder &builder, Sharding sourceSharding,
190 ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard,
191 GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) {
192 MLIRContext *ctx = builder.getContext();
193 builder.setInsertionPointAfterValue(sourceShard);
194
195 Sharding targetSharding =
196 targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
197 ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
198 sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
199 Value allGatherResult = AllGatherOp::create(
200 builder,
201 RankedTensorType::get(allGatherResultShape.getShape(),
202 allGatherResultShape.getElementType()),
203 grid.getSymName(), SmallVector<GridAxis>({splitGridAxis}), sourceShard,
204 APInt(64, splitTensorAxis));
205 ShapedType targetShape =
206 shardShapedType(sourceUnshardedShape, grid, targetSharding);
207 TypedValue<ShapedType> targetShard =
208 tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
209 return {targetShard, targetSharding};
210}
211
212static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
214 Sharding sourceSharding, Sharding targetSharding,
215 ShapedType sourceUnshardedShape,
216 TypedValue<ShapedType> sourceShard) {
217 if (auto detectRes =
218 detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
219 auto [tensorAxis, gridAxis] = detectRes.value();
220 return unsplitLastAxisInResharding(builder, sourceSharding,
221 sourceUnshardedShape, sourceShard, grid,
222 tensorAxis, gridAxis);
223 }
224
225 return std::nullopt;
226}
227
228// Detect if the resharding is of type e.g.
229// [[0, 1], [2]] -> [[0], [1, 2]].
230// Only moving the last axis counts.
231// If detected, returns the corresponding (source_tensor_axis,
232// target_tensor_axis, grid_axis) tuple.
233static std::optional<std::tuple<int64_t, int64_t, GridAxis>>
235 Sharding targetSharding) {
236 for (size_t sourceTensorAxis = 0;
237 sourceTensorAxis < sourceSharding.getSplitAxes().size();
238 ++sourceTensorAxis) {
239 for (size_t targetTensorAxis = 0;
240 targetTensorAxis < targetSharding.getSplitAxes().size();
241 ++targetTensorAxis) {
242 if (sourceTensorAxis == targetTensorAxis)
243 continue;
244 if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
245 targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
246 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
247 targetSharding.getSplitAxes()[targetTensorAxis]
248 .asArrayRef()
249 .back())
250 continue;
251 if (!llvm::equal(
252 llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
253 .asArrayRef()
254 .begin(),
255 sourceSharding.getSplitAxes()[sourceTensorAxis]
256 .asArrayRef()
257 .end() -
258 1),
259 llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
260 .asArrayRef()
261 .begin(),
262 targetSharding.getSplitAxes()[targetTensorAxis]
263 .asArrayRef()
264 .end() -
265 1)))
266 continue;
267 return std::make_tuple(
268 sourceTensorAxis, targetTensorAxis,
269 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
270 }
271 }
272 return std::nullopt;
273}
274
276 Sharding sourceSharding,
277 int64_t sourceTensorAxis,
278 int64_t targetTensorAxis) {
279 SmallVector<GridAxesAttr> targetShardingSplitAxes =
280 llvm::to_vector(sourceSharding.getSplitAxes());
281 while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
282 targetTensorAxis) {
283 targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
284 }
285
286 auto sourceSplitAxes =
287 llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
288 assert(!sourceSplitAxes.empty());
289 auto gridAxis = sourceSplitAxes.back();
290 sourceSplitAxes.pop_back();
291 targetShardingSplitAxes[sourceTensorAxis] =
292 GridAxesAttr::get(ctx, sourceSplitAxes);
293
294 auto targetSplitAxes =
295 llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
296 targetSplitAxes.push_back(gridAxis);
297 targetShardingSplitAxes[targetTensorAxis] =
298 GridAxesAttr::get(ctx, targetSplitAxes);
299
300 return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
301}
302
303static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
304 int64_t splitCount,
305 int64_t sourceTensorAxis,
306 int64_t targetTensorAxis) {
307 SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
308 targetShape[sourceTensorAxis] =
309 gatherDimension(targetShape[sourceTensorAxis], splitCount);
310 targetShape[targetTensorAxis] =
311 shardDimension(targetShape[targetTensorAxis], splitCount);
312 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
313}
314
315static std::tuple<TypedValue<ShapedType>, Sharding>
317 Sharding sourceSharding,
318 ShapedType sourceUnshardedShape,
319 TypedValue<ShapedType> sourceShard,
320 int64_t sourceTensorAxis,
321 int64_t targetTensorAxis, GridAxis gridAxis) {
322 MLIRContext *ctx = builder.getContext();
323 builder.setInsertionPointAfterValue(sourceShard);
324
325 Sharding targetSharding = targetShardingInMoveLastAxis(
326 ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
327 ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
328 sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis,
329 targetTensorAxis);
330 Value allToAllResult = AllToAllOp::create(
331 builder,
332 RankedTensorType::get(allToAllResultShape.getShape(),
333 allToAllResultShape.getElementType()),
334 grid.getSymName(), SmallVector<GridAxis>({gridAxis}), sourceShard,
335 APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
336 ShapedType targetShape =
337 shardShapedType(sourceUnshardedShape, grid, targetSharding);
338 TypedValue<ShapedType> targetShard =
339 tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
340 return {targetShard, targetSharding};
341}
342
343static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
345 Sharding sourceSharding,
346 Sharding targetSharding,
347 ShapedType sourceUnshardedShape,
348 TypedValue<ShapedType> sourceShard) {
349 if (auto detectRes =
350 detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
351 auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value();
353 builder, grid, sourceSharding, sourceUnshardedShape, sourceShard,
354 sourceTensorAxis, targetTensorAxis, gridAxis);
355 }
356
357 return std::nullopt;
358}
359
360// Detect a change in the halo size (only) and create necessary operations if
361// needed. A changed halo sizes requires copying the "core" of the source tensor
362// into the "core" of the destination tensor followed by an update halo
363// operation.
364static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
366 Sharding sourceSharding, Sharding targetSharding,
367 ShapedType sourceUnshardedShape,
368 TypedValue<ShapedType> sourceShard) {
369 // Currently handles only cases where halo sizes differ but everything else
370 // stays the same (from source to destination sharding).
371 if (!sourceSharding.equalSplitAxes(targetSharding) ||
372 !sourceSharding.getStaticShardedDimsOffsets().empty() ||
373 !targetSharding.getStaticShardedDimsOffsets().empty() ||
374 sourceSharding.equalHaloSizes(targetSharding)) {
375 return std::nullopt;
376 }
377
378 auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
379 auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
380 assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
381 assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
382 ShapedType::isStaticShape(tgtHaloSizes) &&
383 sourceShard.getType().hasStaticShape()) &&
384 "dynamic shapes/halos are not supported yet for shard-partition");
385 auto rank = sourceShard.getType().getRank();
386 auto splitAxes = sourceSharding.getSplitAxes();
387 SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
388 strides(rank, 1), outShape(sourceShard.getType().getShape()),
389 coreShape(sourceShard.getType().getShape());
390
391 // Determine "core" of source and destination.
392 // The core is the local part of the shard excluding halo regions.
393 for (auto i = 0u; i < rank; ++i) {
394 if (i < splitAxes.size() && !splitAxes[i].empty()) {
395 if (!srcHaloSizes.empty()) {
396 coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
397 srcCoreOffs[i] = srcHaloSizes[i * 2];
398 }
399 tgtCoreOffs[i] = tgtHaloSizes[i * 2];
400 outShape[i] =
401 coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
402 }
403 }
404
405 // Extract core from source and copy into destination core.
406 auto noVals = ValueRange{};
407 auto initVal =
408 tensor::EmptyOp::create(builder, sourceShard.getLoc(), outShape,
409 sourceShard.getType().getElementType());
410 auto core = tensor::ExtractSliceOp::create(
411 builder, sourceShard.getLoc(),
412 RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
413 sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
414 auto initOprnd = tensor::InsertSliceOp::create(
415 builder, sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
416 tgtCoreOffs, coreShape, strides);
417
418 // Finally update the halo.
419 auto updateHaloResult =
420 UpdateHaloOp::create(
421 builder, sourceShard.getLoc(),
422 RankedTensorType::get(outShape,
423 sourceShard.getType().getElementType()),
424 initOprnd, grid.getSymName(),
425 GridAxesArrayAttr::get(builder.getContext(),
426 sourceSharding.getSplitAxes()),
427 targetSharding.getDynamicHaloSizes(),
428 targetSharding.getStaticHaloSizes())
429 .getResult();
430 return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
431 targetSharding);
432}
433
434// Handles only resharding on a 1D shard.
435// Currently the sharded tensor axes must be exactly divisible by the single
436// grid axis size.
439 Sharding sourceSharding, Sharding targetSharding,
440 TypedValue<ShapedType> sourceUnshardedValue,
441 TypedValue<ShapedType> sourceShard) {
442 assert(sourceShard.getType() ==
443 shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
444 [[maybe_unused]] ShapedType targetShardType =
445 shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);
446 assert(sourceShard.getType().getRank() == targetShardType.getRank());
447 assert(grid.getRank() == 1 && "Only 1D grides are currently supported.");
448
449 if (sourceSharding == targetSharding) {
450 return sourceShard;
451 }
452
453 TypedValue<ShapedType> targetShard;
454 Sharding actualTargetSharding;
455 if (sourceSharding.getStaticShardedDimsOffsets().empty() &&
456 targetSharding.getStaticShardedDimsOffsets().empty() &&
457 sourceSharding.getStaticHaloSizes().empty() &&
458 targetSharding.getStaticHaloSizes().empty()) {
459 if (auto tryRes = tryMoveLastSplitAxisInResharding(
460 builder, grid, sourceSharding, targetSharding,
461 sourceUnshardedValue.getType(), sourceShard)) {
462 std::tie(targetShard, actualTargetSharding) = tryRes.value();
463 } else if (auto tryRes =
464 trySplitLastAxisInResharding(builder, grid, sourceSharding,
465 targetSharding, sourceShard)) {
466 std::tie(targetShard, actualTargetSharding) = tryRes.value();
467 } else if (auto tryRes = tryUnsplitLastAxisInResharding(
468 builder, grid, sourceSharding, targetSharding,
469 sourceUnshardedValue.getType(), sourceShard)) {
470 std::tie(targetShard, actualTargetSharding) = tryRes.value();
471 }
472 }
473 assert(targetShard && "Did not find any pattern to apply.");
474 assert(actualTargetSharding == targetSharding);
475 assert(targetShard.getType() == targetShardType);
476 return targetShard;
477}
478
480reshard(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding,
481 Sharding targetSharding, TypedValue<ShapedType> sourceUnshardedValue,
482 TypedValue<ShapedType> sourceShard) {
483 // If source and destination sharding are the same, no need to do anything.
484 if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
485 isFullReplication(targetSharding))) {
486 return sourceShard;
487 }
488
489 // Tries to handle the case where the resharding is needed because the halo
490 // sizes are different. Supports arbitrary grid dimensionality.
491 if (auto tryRes = tryUpdateHaloInResharding(
492 builder, grid, sourceSharding, targetSharding,
493 sourceUnshardedValue.getType(), sourceShard)) {
494 return std::get<0>(tryRes.value()); // targetShard
495 }
496
497 // Resort to handling only 1D grids since the general case is complicated if
498 // it needs to be communication efficient in terms of minimizing the data
499 // transfered between devices.
500 return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding,
501 sourceUnshardedValue, sourceShard);
502}
503
504TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
505 ShardOp target,
506 TypedValue<ShapedType> sourceShardValue) {
507 assert(source.getResult() == target.getSrc());
508 auto sourceSharding = source.getSharding();
509 auto targetSharding = target.getSharding();
510 ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
511 return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
512 source.getSrc(), sourceShardValue);
513}
514
515TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
516 ShardOp target,
517 TypedValue<ShapedType> sourceShardValue,
518 SymbolTableCollection &symbolTableCollection) {
519 GridOp srcGrid = getGrid(source, symbolTableCollection);
520 assert(srcGrid && srcGrid == getGrid(target, symbolTableCollection));
521 return reshard(builder, srcGrid, source, target, sourceShardValue);
522}
523
525 registry.insert<shard::ShardDialect, tensor::TensorDialect>();
526}
527
528#define GEN_PASS_DEF_PARTITION
529#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
530
532
533// Get the types of block arguments for an partitioned block.
534// Reads the sharding annotations of the arguments to deduce the sharded types.
535// Types that are not ranked tensors are left unchanged.
538 SymbolTableCollection &symbolTableCollection) {
540 llvm::transform(
541 block.getArguments(), std::back_inserter(res),
542 [&symbolTableCollection](BlockArgument arg) {
543 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
544 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
545 return arg.getType();
546 }
547
548 assert(rankedTensorArg.hasOneUse());
549 Operation *useOp = *rankedTensorArg.getUsers().begin();
550 ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
551 assert(shardOp);
552 GridOp grid = getGrid(shardOp, symbolTableCollection);
553 return cast<Type>(shardShapedType(rankedTensorArg.getType(), grid,
554 shardOp.getSharding()));
555 });
556 return res;
557}
558
559static LogicalResult
561 ArrayRef<Sharding> operandShardings,
562 ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
563 SymbolTableCollection &symbolTableCollection,
564 OpBuilder &builder) {
565 ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
566 if (!shardingInterface) {
567 // If there is no sharding interface we are conservative and assume that
568 // the op should be fully replicated no all devices.
569 partitionFullyReplicatedOperation(op, partitionedOperands, operandShardings,
570 resultShardings, partitionMap,
571 symbolTableCollection, builder);
572 } else {
573 if (failed(shardingInterface.partition(
574 partitionedOperands, operandShardings, resultShardings,
575 partitionMap, symbolTableCollection, builder))) {
576 return failure();
577 }
578 }
579
580 assert(llvm::all_of(op.getResults(), [&partitionMap](OpResult result) {
581 return partitionMap.contains(result);
582 }));
583
584 return success();
585}
586
587// Retrieve the sharding annotations for the operands of the given operation.
588// If the type is not a ranked tensor it is not require to have an annotation.
589static std::vector<Sharding> getOperandShardings(Operation &op) {
590 std::vector<Sharding> res;
591 res.reserve(op.getNumOperands());
592 llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
593 TypedValue<RankedTensorType> rankedTensor =
594 dyn_cast<TypedValue<RankedTensorType>>(operand);
595 if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
596 return Sharding();
597 }
598
599 Operation *definingOp = operand.getDefiningOp();
600 assert(definingOp);
601 ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
602 return Sharding(shardOp.getSharding());
603 });
604 return res;
605}
606
607// Retrieve the sharding annotations for the results of the given operation.
608// If the type is not a ranked tensor it is not require to have an annotation.
609static std::vector<Sharding> getResultShardings(Operation &op) {
610 std::vector<Sharding> res;
611 res.reserve(op.getNumResults());
612 llvm::transform(
613 op.getResults(), std::back_inserter(res), [&op](OpResult result) {
614 if (!result.hasOneUse() || result.use_empty()) {
615 return Sharding();
616 }
617 TypedValue<RankedTensorType> rankedTensor =
619 if (!rankedTensor) {
620 return Sharding();
621 }
622 Operation *userOp = *result.getUsers().begin();
623 ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
624 if (shardOp) {
625 return Sharding(shardOp.getSharding());
626 }
627 if (rankedTensor.getType().getRank() == 0) {
628 // This is a 0d tensor result without explicit sharding.
629 // Find grid symbol from operands, if any.
630 // Shardings without grid are not always fully supported yet.
631 for (auto operand : op.getOperands()) {
632 if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
633 return Sharding(sharding.getGridAttr());
634 }
635 }
636 }
637 return Sharding();
638 });
639 return res;
640}
641
642static LogicalResult
643partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
644 SymbolTableCollection &symbolTableCollection,
645 OpBuilder &builder) {
646 Value targetPartitionValue;
647
648 // Check if 2 shard ops are chained. If not there is no need for resharding
649 // as the source and target shared the same sharding.
650 ShardOp srcShardOp = shardOp.getSrc().getDefiningOp<ShardOp>();
651 if (!srcShardOp) {
652 targetPartitionValue = partitionMap.lookup(shardOp.getSrc());
653 } else {
654 // Insert resharding.
655 TypedValue<ShapedType> srcPartitionValue =
656 cast<TypedValue<ShapedType>>(partitionMap.lookup(srcShardOp));
657 targetPartitionValue = reshard(builder, srcShardOp, shardOp,
658 srcPartitionValue, symbolTableCollection);
659 }
660
661 assert(!partitionMap.contains(shardOp.getResult()));
662 partitionMap.map(shardOp.getResult(), targetPartitionValue);
663 return success();
664}
665
666static LogicalResult
668 SymbolTableCollection &symbolTableCollection,
669 OpBuilder &builder) {
670 if (isa<ShardingOp>(op)) {
671 return success();
672 }
673 if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
674 auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
675 if (!shardOp) {
676 return op.emitError("expected a shard op as source of get_sharding");
677 }
678 auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
679 partitionMap.map(op.getResult(0), newSharding->getResult(0));
680 return success();
681 }
682
683 ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
684 if (shardOp) {
685 return partitionOperation(shardOp, partitionMap, symbolTableCollection,
686 builder);
687 }
688
689 SmallVector<Value> partitionedOperands;
690 llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
691 [&partitionMap](Value operand) {
692 assert(partitionMap.contains(operand));
693 return partitionMap.lookup(operand);
694 });
695 return partitionOperation(op, partitionedOperands, getOperandShardings(op),
696 getResultShardings(op), partitionMap,
697 symbolTableCollection, builder);
698}
699
700static LogicalResult
701partitionBlock(Block &block, IRMapping &partitionMap,
702 SymbolTableCollection &symbolTableCollection,
703 OpBuilder &builder) {
704
705 SmallVector<Location> argLocations;
706 llvm::transform(block.getArguments(), std::back_inserter(argLocations),
707 [](BlockArgument arg) { return arg.getLoc(); });
708 Block *newBlock = builder.createBlock(
709 block.getParent(), {},
710 shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
711 for (auto [unshardedBlockArg, partitionedBlockArg] :
712 llvm::zip(block.getArguments(), newBlock->getArguments())) {
713 partitionMap.map(unshardedBlockArg, partitionedBlockArg);
714 }
715
716 OpBuilder::InsertionGuard insertionGuard(builder);
717 builder.setInsertionPointToEnd(newBlock);
718 for (Operation &op : block.getOperations()) {
719 if (failed(partitionOperation(op, partitionMap, symbolTableCollection,
720 builder))) {
721 return failure();
722 }
723 }
724
725 return success();
726}
727
// Partition a whole function: replace every block containing shard
// annotations by a partitioned clone, then retype the function from the new
// entry-block arguments and the return op's operands.
static LogicalResult
partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap,
                SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks to not mess up the iteration when adding new
  // blocks.
  SmallVector<Block *> originalBlocks;
  for (Block &b : op.getBlocks()) {
    // Only blocks that actually contain shard ops are partitioned.
    if (llvm::any_of(b.getOperations(),
                     [](Operation &op) { return isa<ShardOp>(op); })) {
      originalBlocks.push_back(&b);
    }
  }

  for (Block *block : originalBlocks) {
    if (failed(partitionBlock(*block, partitionMap, symbolTableCollection,
                              builder))) {
      return failure();
    }
  }

  // The partitioned blocks fully replace the originals.
  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find a return op and change the function results signature to its operands
  // signature.
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    if (block.empty()) {
      continue;
    }

    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
      returnOp = &block.back();
      break;
    }
  }
  // NOTE(review): functions without a return-like terminator keep their
  // original type — presumably only declarations reach this; confirm.
  if (returnOp) {
    op.setType(FunctionType::get(
        op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
        returnOp->getOperandTypes()));
  }

  return success();
}
775
776namespace {
777
778struct Partition : public impl::PartitionBase<Partition> {
779 void runOnOperation() override {
780 IRMapping partitionMap;
781 SymbolTableCollection symbolTableCollection;
782 if (failed(partitionFuncOp(getOperation(), partitionMap,
783 symbolTableCollection))) {
784 return signalPassFailure();
785 }
786 }
787
788 void getDependentDialects(DialectRegistry &registry) const override {
790 registry.insert<shard::ShardDialect>();
791 }
792};
793
794} // namespace
795
796} // namespace mlir::shard
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:137
BlockArgListType getArguments()
Definition Block.h:87
MLIRContext * getContext() const
Definition Builders.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition IRMapping.h:51
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:663
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:421
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
unsigned getNumOperands()
Definition Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
Definition Operation.h:397
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition ShardOps.cpp:770
bool equalSplitAxes(const Sharding &rhs) const
Definition ShardOps.cpp:689
ArrayRef< int64_t > getStaticHaloSizes() const
Definition ShardOps.h:63
::mlir::FlatSymbolRefAttr getGridAttr() const
Definition ShardOps.h:60
ArrayRef< Value > getDynamicHaloSizes() const
Definition ShardOps.h:67
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition ShardOps.h:64
ArrayRef< GridAxesAttr > getSplitAxes() const
Definition ShardOps.h:62
bool equalHaloSizes(const Sharding &rhs) const
Definition ShardOps.cpp:728
static std::optional< std::tuple< int64_t, int64_t, GridAxis > > detectMoveLastSplitAxisInResharding(Sharding sourceSharding, Sharding targetSharding)
static std::tuple< TypedValue< ShapedType >, Sharding > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
Definition Partition.cpp:68
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:281
static std::tuple< TypedValue< ShapedType >, Sharding > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis)
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::tuple< TypedValue< ShapedType >, Sharding > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
static std::optional< std::tuple< int64_t, GridAxis > > detectUnsplitLastAxisInResharding(Sharding sourceSharding, Sharding targetSharding)
static SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static LogicalResult partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)
static TypedValue< ShapedType > reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
Definition Partition.cpp:39
bool isFullReplication(Sharding sharding)
Definition ShardOps.h:106
int16_t GridAxis
Definition ShardOps.h:26
static LogicalResult partitionBlock(Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::vector< Sharding > getOperandShardings(Operation &op)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard)
DenseMap< Value, Value > UnshardedToShardedValueMap
static std::vector< Sharding > getResultShardings(Operation &op)
static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis)
static std::optional< std::tuple< int64_t, GridAxis > > detectSplitLastAxisInResharding(Sharding sourceSharding, Sharding targetSharding)
Definition Partition.cpp:87
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:168
TypedValue< ShapedType > reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis)
Definition Partition.cpp:46
void reshardingRegisterDependentDialects(DialectRegistry &registry)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:121
static LogicalResult partitionOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:177
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
This trait indicates that a terminator operation is "return-like".