MLIR 23.0.0git
Partition.cpp
//===- Partition.cpp --------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Shard/Transforms/Partition.h"

#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <utility>

namespace mlir::shard {

template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
    return sourceAxes.contains(targetAxis);
  });
}

static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx,
                                              const Sharding &sourceSharding,
                                              int64_t splitTensorAxis,
                                              GridAxis splitGridAxis) {
  SmallVector<GridAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         splitTensorAxis) {
    targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
  }
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.push_back(splitGridAxis);
  targetShardingSplitAxes[splitTensorAxis] =
      GridAxesAttr::get(ctx, targetSplitAxes);
  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}
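
// Illustrative example (hypothetical values, not from the source): starting
// from split axes [[0, 1]] with splitTensorAxis = 0 and splitGridAxis = 2,
// the returned sharding has split axes [[0, 1, 2]]; for splitTensorAxis = 1
// the split-axes list is first padded with an empty entry, giving
// [[0, 1], [2]].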

// Split a replicated tensor along a grid axis.
// E.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the partitioned target value with its sharding.
static std::tuple<TypedValue<ShapedType>, Sharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                          Sharding sourceSharding,
                          TypedValue<ShapedType> sourceShard, GridOp grid,
                          int64_t splitTensorAxis, GridAxis splitGridAxis) {
  TypedValue<ShapedType> targetShard =
      AllSliceOp::create(builder, sourceShard, grid,
                         ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
          .getResult();
  Sharding targetSharding = targetShardingInSplitLastAxis(
      builder.getContext(), std::move(sourceSharding), splitTensorAxis,
      splitGridAxis);
  return {targetShard, targetSharding};
}
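
// Illustrative shapes (hypothetical): on a grid of shape [2], all-slicing a
// replicated tensor<8x4xf32> shard at splitTensorAxis = 0 along grid axis 0
// leaves each device with a tensor<4x4xf32> shard (8 / 2 == 4).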

// Detect if the resharding is of type e.g.
// [[0, 1]] -> [[0, 1, 2]].
// If detected, returns the corresponding tensor-axis/grid-axis pair.
// Does not detect insertions like
// [[0, 1]] -> [[0, 2, 1]].
static std::optional<std::tuple<int64_t, GridAxis>>
detectSplitLastAxisInResharding(const Sharding &sourceSharding,
                                const Sharding &targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
          targetSharding.getSplitAxes()[tensorAxis].size()) {
        continue;
      }
      if (!llvm::equal(
              sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
              llvm::make_range(
                  targetSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1))) {
        continue;
      }
    } else {
      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
        continue;
      }
    }
    return std::make_tuple(
        tensorAxis,
        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}
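
// Illustrative detections (hypothetical shardings): [[0]] -> [[0, 1]] matches
// with (tensorAxis = 0, gridAxis = 1), and [] -> [[0]] matches with
// (tensorAxis = 0, gridAxis = 0); [[0, 1]] -> [[0, 2, 1]] does not match.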

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                             const Sharding &sourceSharding,
                             Sharding targetSharding,
                             TypedValue<ShapedType> sourceShard) {
  if (auto detectRes = detectSplitLastAxisInResharding(
          sourceSharding, std::move(targetSharding))) {
    auto [tensorAxis, gridAxis] = detectRes.value();
    return splitLastAxisInResharding(builder, sourceSharding, sourceShard, grid,
                                     tensorAxis, gridAxis);
  }

  return std::nullopt;
}

// Detect if the resharding is of type e.g.
// [[0, 1, 2]] -> [[0, 1]].
// If detected, returns the corresponding tensor-axis/grid-axis pair.
static std::optional<std::tuple<int64_t, GridAxis>>
detectUnsplitLastAxisInResharding(const Sharding &sourceSharding,
                                  const Sharding &targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (targetSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
        continue;
      if (!llvm::equal(
              llvm::make_range(
                  sourceSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1),
              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
        continue;
    } else {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
        continue;
    }
    return std::make_tuple(
        tensorAxis,
        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}
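
// E.g. (hypothetical): [[0, 1]] -> [[0]] matches with (tensorAxis = 0,
// gridAxis = 1), while [[0, 1]] -> [[1]] does not, since the remaining
// prefix differs.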

static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                                const Sharding &sourceSharding,
                                                int64_t splitTensorAxis) {
  SmallVector<GridAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
         splitTensorAxis);
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());

  targetSplitAxes.pop_back();
  targetShardingSplitAxes[splitTensorAxis] =
      GridAxesAttr::get(ctx, targetSplitAxes);
  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}

static ShapedType allGatherResultShapeInUnsplitLastAxis(
    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      gatherDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
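
// E.g. (hypothetical): for a tensor<4x8xf32> shard, splitCount = 2 and
// splitTensorAxis = 0, the all-gather result type is tensor<8x8xf32>, since
// gatherDimension(4, 2) == 8.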

static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
    ImplicitLocOpBuilder &builder, Sharding sourceSharding,
    ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard,
    GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  Sharding targetSharding = targetShardingInUnsplitLastAxis(
      ctx, std::move(sourceSharding), splitTensorAxis);
  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
      sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
  Value allGatherResult = AllGatherOp::create(
      builder,
      RankedTensorType::get(allGatherResultShape.getShape(),
                            allGatherResultShape.getElementType()),
      grid.getSymName(), SmallVector<GridAxis>({splitGridAxis}), sourceShard,
      APInt(64, splitTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, grid, targetSharding);
  TypedValue<ShapedType> targetShard =
      tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                               const Sharding &sourceSharding,
                               Sharding targetSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard) {
  if (auto detectRes = detectUnsplitLastAxisInResharding(
          sourceSharding, std::move(targetSharding))) {
    auto [tensorAxis, gridAxis] = detectRes.value();
    return unsplitLastAxisInResharding(builder, sourceSharding,
                                       sourceUnshardedShape, sourceShard, grid,
                                       tensorAxis, gridAxis);
  }

  return std::nullopt;
}

// Detect if the resharding is of type e.g.
// [[0, 1], [2]] -> [[0], [1, 2]].
// Only moving the last axis counts.
// If detected, returns the corresponding (source_tensor_axis,
// target_tensor_axis, grid_axis) tuple.
static std::optional<std::tuple<int64_t, int64_t, GridAxis>>
detectMoveLastSplitAxisInResharding(const Sharding &sourceSharding,
                                    const Sharding &targetSharding) {
  for (size_t sourceTensorAxis = 0;
       sourceTensorAxis < sourceSharding.getSplitAxes().size();
       ++sourceTensorAxis) {
    for (size_t targetTensorAxis = 0;
         targetTensorAxis < targetSharding.getSplitAxes().size();
         ++targetTensorAxis) {
      if (sourceTensorAxis == targetTensorAxis)
        continue;
      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
              targetSharding.getSplitAxes()[targetTensorAxis]
                  .asArrayRef()
                  .back())
        continue;
      if (!llvm::equal(
              llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .end() -
                                   1),
              llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .end() -
                                   1)))
        continue;
      return std::make_tuple(
          sourceTensorAxis, targetTensorAxis,
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
    }
  }
  return std::nullopt;
}
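
// E.g. (hypothetical, the common 1-D grid case): [[0], []] -> [[], [0]]
// matches with (sourceTensorAxis = 0, targetTensorAxis = 1, gridAxis = 0).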

static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx,
                                             const Sharding &sourceSharding,
                                             int64_t sourceTensorAxis,
                                             int64_t targetTensorAxis) {
  SmallVector<GridAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         targetTensorAxis) {
    targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
  }

  auto sourceSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
  assert(!sourceSplitAxes.empty());
  auto gridAxis = sourceSplitAxes.back();
  sourceSplitAxes.pop_back();
  targetShardingSplitAxes[sourceTensorAxis] =
      GridAxesAttr::get(ctx, sourceSplitAxes);

  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
  targetSplitAxes.push_back(gridAxis);
  targetShardingSplitAxes[targetTensorAxis] =
      GridAxesAttr::get(ctx, targetSplitAxes);

  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}

static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
                                                    int64_t splitCount,
                                                    int64_t sourceTensorAxis,
                                                    int64_t targetTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[sourceTensorAxis] =
      gatherDimension(targetShape[sourceTensorAxis], splitCount);
  targetShape[targetTensorAxis] =
      shardDimension(targetShape[targetTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
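
// E.g. (hypothetical): for a tensor<4x8xf32> shard with splitCount = 2,
// sourceTensorAxis = 0 and targetTensorAxis = 1, the all-to-all result type
// is tensor<8x4xf32>: axis 0 is gathered (4 * 2) while axis 1 is split
// (8 / 2).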

static std::tuple<TypedValue<ShapedType>, Sharding>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                              Sharding sourceSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis, GridAxis gridAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  Sharding targetSharding = targetShardingInMoveLastAxis(
      ctx, std::move(sourceSharding), sourceTensorAxis, targetTensorAxis);
  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
      sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis,
      targetTensorAxis);
  Value allToAllResult = AllToAllOp::create(
      builder,
      RankedTensorType::get(allToAllResultShape.getShape(),
                            allToAllResultShape.getElementType()),
      grid.getSymName(), SmallVector<GridAxis>({gridAxis}), sourceShard,
      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, grid, targetSharding);
  TypedValue<ShapedType> targetShard =
      tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                                 const Sharding &sourceSharding,
                                 Sharding targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {
  if (auto detectRes = detectMoveLastSplitAxisInResharding(
          sourceSharding, std::move(targetSharding))) {
    auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value();
    return moveLastSplitAxisInResharding(
        builder, grid, sourceSharding, sourceUnshardedShape, sourceShard,
        sourceTensorAxis, targetTensorAxis, gridAxis);
  }

  return std::nullopt;
}

// Detect a change in the halo size (only) and create the necessary operations
// if needed. A change in halo sizes requires copying the "core" of the source
// tensor into the "core" of the destination tensor, followed by an update halo
// operation.
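// Illustrative 1-d example (hypothetical sizes): for a local shard
// tensor<6xf32> with source halos [1, 1] and target halos [2, 2], the core
// has size 6 - 1 - 1 = 4 at offset 1 in the source; it is copied to offset 2
// of a new tensor<8xf32> (4 + 2 + 2), whose halo regions are then filled by
// UpdateHaloOp.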
static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                          const Sharding &sourceSharding,
                          const Sharding &targetSharding,
                          ShapedType sourceUnshardedShape,
                          TypedValue<ShapedType> sourceShard) {
  // Currently handles only cases where halo sizes differ but everything else
  // stays the same (from source to destination sharding).
  if (!sourceSharding.equalSplitAxes(targetSharding) ||
      !sourceSharding.getStaticShardedDimsOffsets().empty() ||
      !targetSharding.getStaticShardedDimsOffsets().empty() ||
      sourceSharding.equalHaloSizes(targetSharding)) {
    return std::nullopt;
  }

  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
  assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
          ShapedType::isStaticShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for shard-partition");
  auto rank = sourceShard.getType().getRank();
  auto splitAxes = sourceSharding.getSplitAxes();
  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
      strides(rank, 1), outShape(sourceShard.getType().getShape()),
      coreShape(sourceShard.getType().getShape());

  // Determine the "core" of the source and the destination.
  // The core is the local part of the shard, excluding the halo regions.
  for (auto i = 0u; i < rank; ++i) {
    if (i < splitAxes.size() && !splitAxes[i].empty()) {
      if (!srcHaloSizes.empty()) {
        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
        srcCoreOffs[i] = srcHaloSizes[i * 2];
      }
      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
      outShape[i] =
          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
    }
  }

  // Extract the core from the source and copy it into the destination core.
  auto noVals = ValueRange{};
  auto initVal =
      tensor::EmptyOp::create(builder, sourceShard.getLoc(), outShape,
                              sourceShard.getType().getElementType());
  auto core = tensor::ExtractSliceOp::create(
      builder, sourceShard.getLoc(),
      RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
      sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
  auto initOprnd = tensor::InsertSliceOp::create(
      builder, sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
      tgtCoreOffs, coreShape, strides);

  // Finally, update the halo.
  auto updateHaloResult =
      UpdateHaloOp::create(
          builder, sourceShard.getLoc(),
          RankedTensorType::get(outShape,
                                sourceShard.getType().getElementType()),
          initOprnd, grid.getSymName(),
          GridAxesArrayAttr::get(builder.getContext(),
                                 sourceSharding.getSplitAxes()),
          targetSharding.getDynamicHaloSizes(),
          targetSharding.getStaticHaloSizes())
          .getResult();
  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
                         targetSharding);
}

// Handles only resharding on a 1D grid.
// Currently the sharded tensor dimensions must be exactly divisible by the
// single grid axis size.
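// E.g. (hypothetical): resharding a tensor<8xf32> from split axes [[0]] to
// [[]] on a grid of shape [4] is supported, since 8 % 4 == 0.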
static TypedValue<ShapedType>
reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
                const Sharding &sourceSharding, const Sharding &targetSharding,
                TypedValue<ShapedType> sourceUnshardedValue,
                TypedValue<ShapedType> sourceShard) {
  assert(sourceShard.getType() ==
         shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
  [[maybe_unused]] ShapedType targetShardType =
      shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);
  assert(sourceShard.getType().getRank() == targetShardType.getRank());
  assert(grid.getRank() == 1 && "Only 1D grids are currently supported.");

  if (sourceSharding == targetSharding) {
    return sourceShard;
  }

  TypedValue<ShapedType> targetShard;
  Sharding actualTargetSharding;
  if (sourceSharding.getStaticShardedDimsOffsets().empty() &&
      targetSharding.getStaticShardedDimsOffsets().empty() &&
      sourceSharding.getStaticHaloSizes().empty() &&
      targetSharding.getStaticHaloSizes().empty()) {
    if (auto tryRes = tryMoveLastSplitAxisInResharding(
            builder, grid, sourceSharding, targetSharding,
            sourceUnshardedValue.getType(), sourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes =
                   trySplitLastAxisInResharding(builder, grid, sourceSharding,
                                                targetSharding, sourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = tryUnsplitLastAxisInResharding(
                   builder, grid, sourceSharding, targetSharding,
                   sourceUnshardedValue.getType(), sourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    }
  }
  assert(targetShard && "Did not find any pattern to apply.");
  assert(actualTargetSharding == targetSharding);
  assert(targetShard.getType() == targetShardType);
  return targetShard;
}

TypedValue<ShapedType>
reshard(ImplicitLocOpBuilder &builder, GridOp grid,
        const Sharding &sourceSharding, const Sharding &targetSharding,
        TypedValue<ShapedType> sourceUnshardedValue,
        TypedValue<ShapedType> sourceShard) {
  // If source and destination sharding are the same, there is nothing to do.
  if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
                                           isFullReplication(targetSharding))) {
    return sourceShard;
  }

  // Tries to handle the case where the resharding is needed because the halo
  // sizes are different. Supports arbitrary grid dimensionality.
  if (auto tryRes = tryUpdateHaloInResharding(
          builder, grid, sourceSharding, targetSharding,
          sourceUnshardedValue.getType(), sourceShard)) {
    return std::get<0>(tryRes.value()); // targetShard
  }

  // Resort to handling only 1D grids, since the general case is complicated if
  // it needs to be communication efficient in terms of minimizing the data
  // transferred between devices.
  return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding,
                         sourceUnshardedValue, sourceShard);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue) {
  assert(source.getResult() == target.getSrc());
  auto sourceSharding = source.getSharding();
  auto targetSharding = target.getSharding();
  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
  return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
                 source.getSrc(), sourceShardValue);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue,
                               SymbolTableCollection &symbolTableCollection) {
  GridOp srcGrid = getGrid(source, symbolTableCollection);
  assert(srcGrid && srcGrid == getGrid(target, symbolTableCollection));
  return reshard(builder, srcGrid, source, target, sourceShardValue);
}

void reshardingRegisterDependentDialects(DialectRegistry &registry) {
  registry.insert<shard::ShardDialect, tensor::TensorDialect>();
}

#define GEN_PASS_DEF_PARTITION
#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"

using UnshardedToShardedValueMap = DenseMap<Value, Value>;

// Get the types of block arguments for a partitioned block.
// Reads the sharding annotations of the arguments to deduce the sharded types.
// Types that are not ranked tensors are left unchanged.
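// E.g. (hypothetical): a block argument of type tensor<8x8xf32> whose single
// use is a shard.shard op with split axes [[0]] on a grid of shape [2] maps
// to tensor<4x8xf32>.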
static SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
                          SymbolTableCollection &symbolTableCollection) {
  SmallVector<Type> res;
  llvm::transform(
      block.getArguments(), std::back_inserter(res),
      [&symbolTableCollection](BlockArgument arg) {
        auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
          return arg.getType();
        }

        assert(rankedTensorArg.hasOneUse());
        Operation *useOp = *rankedTensorArg.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
        assert(shardOp);
        GridOp grid = getGrid(shardOp, symbolTableCollection);
        return cast<Type>(shardShapedType(rankedTensorArg.getType(), grid,
                                          shardOp.getSharding()));
      });
  return res;
}

static LogicalResult
partitionOperation(Operation &op, ArrayRef<Value> partitionedOperands,
                   ArrayRef<Sharding> operandShardings,
                   ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
                   SymbolTableCollection &symbolTableCollection,
                   OpBuilder &builder) {
  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
  if (!shardingInterface) {
    // If there is no sharding interface, be conservative and assume that
    // the op should be fully replicated on all devices.
    partitionFullyReplicatedOperation(op, partitionedOperands, operandShardings,
                                      resultShardings, partitionMap,
                                      symbolTableCollection, builder);
  } else {
    if (failed(shardingInterface.partition(
            partitionedOperands, operandShardings, resultShardings,
            partitionMap, symbolTableCollection, builder))) {
      return failure();
    }
  }

  assert(llvm::all_of(op.getResults(), [&partitionMap](OpResult result) {
    return partitionMap.contains(result);
  }));

  return success();
}

// Retrieve the sharding annotations for the operands of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
static std::vector<Sharding> getOperandShardings(Operation &op) {
  std::vector<Sharding> res;
  res.reserve(op.getNumOperands());
  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
    TypedValue<RankedTensorType> rankedTensor =
        dyn_cast<TypedValue<RankedTensorType>>(operand);
    if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
      return Sharding();
    }

    Operation *definingOp = operand.getDefiningOp();
    assert(definingOp);
    ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
    return Sharding(shardOp.getSharding());
  });
  return res;
}

// Retrieve the sharding annotations for the results of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
static std::vector<Sharding> getResultShardings(Operation &op) {
  std::vector<Sharding> res;
  res.reserve(op.getNumResults());
  llvm::transform(
      op.getResults(), std::back_inserter(res), [&op](OpResult result) {
        if (!result.hasOneUse() || result.use_empty()) {
          return Sharding();
        }
        TypedValue<RankedTensorType> rankedTensor =
            dyn_cast<TypedValue<RankedTensorType>>(result);
        if (!rankedTensor) {
          return Sharding();
        }
        Operation *userOp = *result.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
        if (shardOp) {
          return Sharding(shardOp.getSharding());
        }
        if (rankedTensor.getType().getRank() == 0) {
          // This is a 0d tensor result without explicit sharding.
          // Find a grid symbol from the operands, if any.
          // Shardings without a grid are not always fully supported yet.
          for (auto operand : op.getOperands()) {
            if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
              return Sharding(sharding.getGridAttr());
            }
          }
        }
        return Sharding();
      });
  return res;
}

static LogicalResult
partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
                   SymbolTableCollection &symbolTableCollection,
                   OpBuilder &builder) {
  Value targetPartitionValue;

  // Check if two shard ops are chained. If not, there is no need for
  // resharding, as the source and target share the same sharding.
  ShardOp srcShardOp = shardOp.getSrc().getDefiningOp<ShardOp>();
  if (!srcShardOp) {
    targetPartitionValue = partitionMap.lookup(shardOp.getSrc());
  } else {
    // Insert resharding.
    TypedValue<ShapedType> srcPartitionValue =
        cast<TypedValue<ShapedType>>(partitionMap.lookup(srcShardOp));
    targetPartitionValue = reshard(builder, srcShardOp, shardOp,
                                   srcPartitionValue, symbolTableCollection);
  }

  assert(!partitionMap.contains(shardOp.getResult()));
  partitionMap.map(shardOp.getResult(), targetPartitionValue);
  return success();
}

// Check if the block args are correctly annotated with sharding information:
// - non-tensor and 0d-tensor args are ignored
// - each tensor arg must have exactly one use, which must be a shard.shard
//   operation
static LogicalResult checkFullyAnnotated(Block &block) {
  for (const BlockArgument &arg : block.getArguments()) {
    auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
    if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0)
      continue;

    if (rankedTensorArg.getNumUses() > 1)
      return emitError(block.getParent()->getLoc())
             << "Cannot partition: expected a single use for block argument "
             << arg.getArgNumber() << " in block "
             << block.computeBlockNumber();
    Operation *useOp = *rankedTensorArg.getUsers().begin();
    auto shardOp = dyn_cast<ShardOp>(useOp);
    if (!shardOp)
      return emitError(block.getParent()->getLoc())
             << "Cannot partition: expected a shard.shard op for block "
             << "argument " << arg.getArgNumber() << " in block "
             << block.computeBlockNumber();
  }
  return success();
}

// Check if the operation is correctly and fully annotated with sharding
// information:
// - Operation results must have exactly one use (e.g. the shard operation).
// - All operands and all results must be annotated, i.e. they must be
//   produced by/consumed by a shard.shard operation.
// - Result annotations must not include the 'annotate_for_users' attribute.
// - Operand annotations must include the 'annotate_for_users' attribute.
// Raises an error if the operation is not correctly and fully annotated.
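// Illustrative IR sketch (hypothetical op and value names, syntax
// abbreviated):
//   %in  = shard.shard %0 to %s annotate_for_users : tensor<8xf32>
//   %out = some_dialect.some_op(%in) : tensor<8xf32>
//   %res = shard.shard %out to %s : tensor<8xf32>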
static LogicalResult checkFullyAnnotated(Operation *op) {
  // Constant ops do not need to have sharding annotations.
  if (op->hasTrait<OpTrait::ConstantLike>())
    return success();

  for (OpOperand &operand : op->getOpOperands()) {
    // Non-tensor and 0d-tensor operands are ignored.
    auto rankedTT = dyn_cast<RankedTensorType>(operand.get().getType());
    if (!rankedTT || rankedTT.getRank() == 0)
      continue;

    auto shard = operand.get().getDefiningOp<ShardOp>();
    if (!shard)
      return op->emitError() << "Cannot partition: tensor operand "
                             << operand.getOperandNumber()
                             << " must be defined by a shard.shard operation.";
    if (!shard.getAnnotateForUsers())
      return op->emitError()
             << "Cannot partition: shard.shard for operand "
             << operand.getOperandNumber() << " must set 'annotate_for_users'.";
  }
  for (const OpResult &result : op->getResults()) {
    if (!result.hasOneUse())
      return op->emitError()
             << "Cannot partition: result " << result.getResultNumber()
             << " must have exactly one use.";
    auto shard = dyn_cast<ShardOp>(*result.user_begin());
    if (!shard)
      return op->emitError()
             << "Cannot partition: user of result " << result.getResultNumber()
             << " must be a shard.shard operation.";
    if (shard.getAnnotateForUsers())
      return op->emitError() << "Cannot partition: shard.shard for result "
                             << result.getResultNumber()
                             << " must not set 'annotate_for_users'.";
  }
  return success();
}

static LogicalResult
partitionOperation(Operation &op, IRMapping &partitionMap,
                   SymbolTableCollection &symbolTableCollection,
                   OpBuilder &builder) {
  if (isa<ShardingOp>(op)) {
    return success();
  }

  if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
    auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
    if (!shardOp) {
      return op.emitError("expected a shard op as source of get_sharding");
    }
    auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
    partitionMap.map(op.getResult(0), newSharding->getResult(0));
    return success();
  }

  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
  if (shardOp) {
    return partitionOperation(shardOp, partitionMap, symbolTableCollection,
                              builder);
  }

  // Check if the operation is correctly and fully annotated.
  if (failed(checkFullyAnnotated(&op)))
    return failure();

  SmallVector<Value> partitionedOperands;
  llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
                  [&partitionMap](Value operand) {
                    assert(partitionMap.contains(operand));
                    return partitionMap.lookup(operand);
                  });
  return partitionOperation(op, partitionedOperands, getOperandShardings(op),
                            getResultShardings(op), partitionMap,
                            symbolTableCollection, builder);
}

static LogicalResult
partitionBlock(Block &block, IRMapping &partitionMap,
               SymbolTableCollection &symbolTableCollection,
               OpBuilder &builder) {

  if (failed(checkFullyAnnotated(block)))
    return failure();

  SmallVector<Location> argLocations;
  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
                  [](BlockArgument arg) { return arg.getLoc(); });
  Block *newBlock = builder.createBlock(
      block.getParent(), {},
      shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
  for (auto [unshardedBlockArg, partitionedBlockArg] :
       llvm::zip(block.getArguments(), newBlock->getArguments())) {
    partitionMap.map(unshardedBlockArg, partitionedBlockArg);
  }

  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointToEnd(newBlock);
  for (Operation &op : block.getOperations()) {
    if (failed(partitionOperation(op, partitionMap, symbolTableCollection,
                                  builder))) {
      return failure();
    }
  }

  return success();
}

static LogicalResult
partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap,
                SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks so that the iteration is not disturbed when
  // new blocks are added.
  SmallVector<Block *> originalBlocks;
  for (Block &b : op.getBlocks()) {
    if (llvm::any_of(b.getOperations(),
                     [](Operation &op) { return isa<ShardOp>(op); })) {
      originalBlocks.push_back(&b);
    }
  }

  for (Block *block : originalBlocks) {
    if (failed(partitionBlock(*block, partitionMap, symbolTableCollection,
                              builder))) {
      return failure();
    }
  }

  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find a return op and change the function results signature to its operands
  // signature.
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    if (block.empty()) {
      continue;
    }

    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
      returnOp = &block.back();
      break;
    }
  }
  if (returnOp) {
    op.setType(FunctionType::get(
        op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
        returnOp->getOperandTypes()));
  }

  return success();
}

namespace {

struct Partition : public impl::PartitionBase<Partition> {
  void runOnOperation() override {
    IRMapping partitionMap;
    SymbolTableCollection symbolTableCollection;
    if (failed(partitionFuncOp(getOperation(), partitionMap,
                               symbolTableCollection))) {
      return signalPassFailure();
    }
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    reshardingRegisterDependentDialects(registry);
    registry.insert<shard::ShardDialect>();
  }
};

} // namespace

} // namespace mlir::shard