MLIR  22.0.0git
Partition.cpp
Go to the documentation of this file.
1 //===- Partition.cpp --------------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/SymbolTable.h"
24 #include "mlir/IR/Value.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Support/LLVM.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/Support/Casting.h"
32 #include <iterator>
33 #include <optional>
34 #include <tuple>
35 
36 namespace mlir::shard {
37 
// Returns true when every axis listed in `targetAxes` is also present in
// `sourceAxes`. An empty `targetAxes` is trivially compatible.
// `SourceAxes` must provide a `contains(axis)` query; `TargetAxes` must be
// iterable.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  for (const auto &axis : targetAxes) {
    // Stop at the first target axis the source does not know about.
    if (!sourceAxes.contains(axis))
      return false;
  }
  return true;
}
45 
47  Sharding sourceSharding,
48  int64_t splitTensorAxis,
49  GridAxis splitGridAxis) {
50  SmallVector<GridAxesAttr> targetShardingSplitAxes =
51  llvm::to_vector(sourceSharding.getSplitAxes());
52  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
53  splitTensorAxis) {
54  targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
55  }
56  auto targetSplitAxes =
57  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
58  targetSplitAxes.push_back(splitGridAxis);
59  targetShardingSplitAxes[splitTensorAxis] =
60  GridAxesAttr::get(ctx, targetSplitAxes);
61  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
62 }
63 
64 // Split a replicated tensor along a grid axis.
65 // E.g. [[0, 1]] -> [[0, 1, 2]].
66 // Returns the partitioned target value with its sharding.
67 static std::tuple<TypedValue<ShapedType>, Sharding>
69  Sharding sourceSharding,
70  TypedValue<ShapedType> sourceShard, GridOp grid,
71  int64_t splitTensorAxis, GridAxis splitGridAxis) {
72  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
73  AllSliceOp::create(builder, sourceShard, grid,
74  ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
75  .getResult());
76  Sharding targetSharding = targetShardingInSplitLastAxis(
77  builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
78  return {targetShard, targetSharding};
79 }
80 
81 // Detect if the resharding is of type e.g.
82 // [[0, 1]] -> [[0, 1, 2]].
83 // If detected, returns the corresponding tensor axis grid axis pair.
84 // Does not detect insertions like
85 // [[0, 1]] -> [[0, 2, 1]].
86 static std::optional<std::tuple<int64_t, GridAxis>>
88  Sharding targetSharding) {
89  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
90  ++tensorAxis) {
91  if (sourceSharding.getSplitAxes().size() > tensorAxis) {
92  if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
93  targetSharding.getSplitAxes()[tensorAxis].size()) {
94  continue;
95  }
96  if (!llvm::equal(
97  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
98  llvm::make_range(
99  targetSharding.getSplitAxes()[tensorAxis]
100  .asArrayRef()
101  .begin(),
102  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
103  1))) {
104  continue;
105  }
106  } else {
107  if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
108  continue;
109  }
110  }
111  return std::make_tuple(
112  tensorAxis,
113  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
114  }
115  return std::nullopt;
116 }
117 
118 static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
120  Sharding sourceSharding, Sharding targetSharding,
121  TypedValue<ShapedType> sourceShard) {
122  if (auto detectRes =
123  detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
124  auto [tensorAxis, gridAxis] = detectRes.value();
125  return splitLastAxisInResharding(builder, sourceSharding, sourceShard, grid,
126  tensorAxis, gridAxis);
127  }
128 
129  return std::nullopt;
130 }
131 
132 // Detect if the resharding is of type e.g.
133 // [[0, 1, 2]] -> [[0, 1]].
134 // If detected, returns the corresponding tensor axis grid axis pair.
135 static std::optional<std::tuple<int64_t, GridAxis>>
137  Sharding targetSharding) {
138  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
139  ++tensorAxis) {
140  if (targetSharding.getSplitAxes().size() > tensorAxis) {
141  if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
142  targetSharding.getSplitAxes()[tensorAxis].size() + 1)
143  continue;
144  if (!llvm::equal(
145  llvm::make_range(
146  sourceSharding.getSplitAxes()[tensorAxis]
147  .asArrayRef()
148  .begin(),
149  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
150  1),
151  targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
152  continue;
153  } else {
154  if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
155  continue;
156  }
157  return std::make_tuple(
158  tensorAxis,
159  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
160  }
161  return std::nullopt;
162 }
163 
165  Sharding sourceSharding,
166  int64_t splitTensorAxis) {
167  SmallVector<GridAxesAttr> targetShardingSplitAxes =
168  llvm::to_vector(sourceSharding.getSplitAxes());
169  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
170  splitTensorAxis);
171  auto targetSplitAxes =
172  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
173 
174  targetSplitAxes.pop_back();
175  targetShardingSplitAxes[splitTensorAxis] =
176  GridAxesAttr::get(ctx, targetSplitAxes);
177  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
178 }
179 
181  ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
182  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
183  targetShape[splitTensorAxis] =
184  gatherDimension(targetShape[splitTensorAxis], splitCount);
185  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
186 }
187 
188 static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
189  ImplicitLocOpBuilder &builder, Sharding sourceSharding,
190  ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard,
191  GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) {
192  MLIRContext *ctx = builder.getContext();
193  builder.setInsertionPointAfterValue(sourceShard);
194 
195  Sharding targetSharding =
196  targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
197  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
198  sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
199  Value allGatherResult = AllGatherOp::create(
200  builder,
201  RankedTensorType::get(allGatherResultShape.getShape(),
202  allGatherResultShape.getElementType()),
203  grid.getSymName(), SmallVector<GridAxis>({splitGridAxis}), sourceShard,
204  APInt(64, splitTensorAxis));
205  ShapedType targetShape =
206  shardShapedType(sourceUnshardedShape, grid, targetSharding);
207  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
208  tensor::CastOp::create(builder, targetShape, allGatherResult)
209  .getResult());
210  return {targetShard, targetSharding};
211 }
212 
213 static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
215  Sharding sourceSharding, Sharding targetSharding,
216  ShapedType sourceUnshardedShape,
217  TypedValue<ShapedType> sourceShard) {
218  if (auto detectRes =
219  detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
220  auto [tensorAxis, gridAxis] = detectRes.value();
221  return unsplitLastAxisInResharding(builder, sourceSharding,
222  sourceUnshardedShape, sourceShard, grid,
223  tensorAxis, gridAxis);
224  }
225 
226  return std::nullopt;
227 }
228 
229 // Detect if the resharding is of type e.g.
230 // [[0, 1], [2]] -> [[0], [1, 2]].
231 // Only moving the last axis counts.
232 // If detected, returns the corresponding (source_tensor_axis,
233 // target_tensor_axis, grid_axis) tuple.
234 static std::optional<std::tuple<int64_t, int64_t, GridAxis>>
236  Sharding targetSharding) {
237  for (size_t sourceTensorAxis = 0;
238  sourceTensorAxis < sourceSharding.getSplitAxes().size();
239  ++sourceTensorAxis) {
240  for (size_t targetTensorAxis = 0;
241  targetTensorAxis < targetSharding.getSplitAxes().size();
242  ++targetTensorAxis) {
243  if (sourceTensorAxis == targetTensorAxis)
244  continue;
245  if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
246  targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
247  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
248  targetSharding.getSplitAxes()[targetTensorAxis]
249  .asArrayRef()
250  .back())
251  continue;
252  if (!llvm::equal(
253  llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
254  .asArrayRef()
255  .begin(),
256  sourceSharding.getSplitAxes()[sourceTensorAxis]
257  .asArrayRef()
258  .end() -
259  1),
260  llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
261  .asArrayRef()
262  .begin(),
263  targetSharding.getSplitAxes()[targetTensorAxis]
264  .asArrayRef()
265  .end() -
266  1)))
267  continue;
268  return std::make_tuple(
269  sourceTensorAxis, targetTensorAxis,
270  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
271  }
272  }
273  return std::nullopt;
274 }
275 
277  Sharding sourceSharding,
278  int64_t sourceTensorAxis,
279  int64_t targetTensorAxis) {
280  SmallVector<GridAxesAttr> targetShardingSplitAxes =
281  llvm::to_vector(sourceSharding.getSplitAxes());
282  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
283  targetTensorAxis) {
284  targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
285  }
286 
287  auto sourceSplitAxes =
288  llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
289  assert(!sourceSplitAxes.empty());
290  auto gridAxis = sourceSplitAxes.back();
291  sourceSplitAxes.pop_back();
292  targetShardingSplitAxes[sourceTensorAxis] =
293  GridAxesAttr::get(ctx, sourceSplitAxes);
294 
295  auto targetSplitAxes =
296  llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
297  targetSplitAxes.push_back(gridAxis);
298  targetShardingSplitAxes[targetTensorAxis] =
299  GridAxesAttr::get(ctx, targetSplitAxes);
300 
301  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
302 }
303 
304 static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
305  int64_t splitCount,
306  int64_t sourceTensorAxis,
307  int64_t targetTensorAxis) {
308  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
309  targetShape[sourceTensorAxis] =
310  gatherDimension(targetShape[sourceTensorAxis], splitCount);
311  targetShape[targetTensorAxis] =
312  shardDimension(targetShape[targetTensorAxis], splitCount);
313  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
314 }
315 
316 static std::tuple<TypedValue<ShapedType>, Sharding>
318  Sharding sourceSharding,
319  ShapedType sourceUnshardedShape,
320  TypedValue<ShapedType> sourceShard,
321  int64_t sourceTensorAxis,
322  int64_t targetTensorAxis, GridAxis gridAxis) {
323  MLIRContext *ctx = builder.getContext();
324  builder.setInsertionPointAfterValue(sourceShard);
325 
326  Sharding targetSharding = targetShardingInMoveLastAxis(
327  ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
328  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
329  sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis,
330  targetTensorAxis);
331  Value allToAllResult = AllToAllOp::create(
332  builder,
333  RankedTensorType::get(allToAllResultShape.getShape(),
334  allToAllResultShape.getElementType()),
335  grid.getSymName(), SmallVector<GridAxis>({gridAxis}), sourceShard,
336  APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
337  ShapedType targetShape =
338  shardShapedType(sourceUnshardedShape, grid, targetSharding);
339  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
340  tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
341  return {targetShard, targetSharding};
342 }
343 
344 static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
346  Sharding sourceSharding,
347  Sharding targetSharding,
348  ShapedType sourceUnshardedShape,
349  TypedValue<ShapedType> sourceShard) {
350  if (auto detectRes =
351  detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
352  auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value();
354  builder, grid, sourceSharding, sourceUnshardedShape, sourceShard,
355  sourceTensorAxis, targetTensorAxis, gridAxis);
356  }
357 
358  return std::nullopt;
359 }
360 
361 // Detect a change in the halo sizes (only) and create the necessary operations
362 // if needed. Changed halo sizes require copying the "core" of the source tensor
363 // into the "core" of the destination tensor, followed by an update-halo
364 // operation.
365 static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
367  Sharding sourceSharding, Sharding targetSharding,
368  ShapedType sourceUnshardedShape,
369  TypedValue<ShapedType> sourceShard) {
370  // Currently handles only cases where halo sizes differ but everything else
371  // stays the same (from source to destination sharding).
372  if (!sourceSharding.equalSplitAxes(targetSharding) ||
373  !sourceSharding.getStaticShardedDimsOffsets().empty() ||
374  !targetSharding.getStaticShardedDimsOffsets().empty() ||
375  sourceSharding.equalHaloSizes(targetSharding)) {
376  return std::nullopt;
377  }
378 
379  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
380  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
381  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
382  assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
383  ShapedType::isStaticShape(tgtHaloSizes) &&
384  sourceShard.getType().hasStaticShape()) &&
385  "dynamic shapes/halos are not supported yet for shard-partition");
386  auto rank = sourceShard.getType().getRank();
387  auto splitAxes = sourceSharding.getSplitAxes();
388  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
389  strides(rank, 1), outShape(sourceShard.getType().getShape()),
390  coreShape(sourceShard.getType().getShape());
391 
392  // Determine "core" of source and destination.
393  // The core is the local part of the shard excluding halo regions.
394  for (auto i = 0u; i < rank; ++i) {
395  if (i < splitAxes.size() && !splitAxes[i].empty()) {
396  if (!srcHaloSizes.empty()) {
397  coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
398  srcCoreOffs[i] = srcHaloSizes[i * 2];
399  }
400  tgtCoreOffs[i] = tgtHaloSizes[i * 2];
401  outShape[i] =
402  coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
403  }
404  }
405 
406  // Extract core from source and copy into destination core.
407  auto noVals = ValueRange{};
408  auto initVal =
409  tensor::EmptyOp::create(builder, sourceShard.getLoc(), outShape,
410  sourceShard.getType().getElementType());
411  auto core = tensor::ExtractSliceOp::create(
412  builder, sourceShard.getLoc(),
413  RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
414  sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
415  auto initOprnd = tensor::InsertSliceOp::create(
416  builder, sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
417  tgtCoreOffs, coreShape, strides);
418 
419  // Finally update the halo.
420  auto updateHaloResult =
421  UpdateHaloOp::create(
422  builder, sourceShard.getLoc(),
423  RankedTensorType::get(outShape,
424  sourceShard.getType().getElementType()),
425  initOprnd, grid.getSymName(),
427  sourceSharding.getSplitAxes()),
428  targetSharding.getDynamicHaloSizes(),
429  targetSharding.getStaticHaloSizes())
430  .getResult();
431  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
432  targetSharding);
433 }
434 
435 // Handles only resharding on a 1D grid.
436 // Currently the sharded tensor axes must be exactly divisible by the single
437 // grid axis size.
440  Sharding sourceSharding, Sharding targetSharding,
441  TypedValue<ShapedType> sourceUnshardedValue,
442  TypedValue<ShapedType> sourceShard) {
443  assert(sourceShard.getType() ==
444  shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
445  [[maybe_unused]] ShapedType targetShardType =
446  shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);
447  assert(sourceShard.getType().getRank() == targetShardType.getRank());
448  assert(grid.getRank() == 1 && "Only 1D grides are currently supported.");
449 
450  if (sourceSharding == targetSharding) {
451  return sourceShard;
452  }
453 
454  TypedValue<ShapedType> targetShard;
455  Sharding actualTargetSharding;
456  if (sourceSharding.getStaticShardedDimsOffsets().empty() &&
457  targetSharding.getStaticShardedDimsOffsets().empty() &&
458  sourceSharding.getStaticHaloSizes().empty() &&
459  targetSharding.getStaticHaloSizes().empty()) {
460  if (auto tryRes = tryMoveLastSplitAxisInResharding(
461  builder, grid, sourceSharding, targetSharding,
462  sourceUnshardedValue.getType(), sourceShard)) {
463  std::tie(targetShard, actualTargetSharding) = tryRes.value();
464  } else if (auto tryRes =
465  trySplitLastAxisInResharding(builder, grid, sourceSharding,
466  targetSharding, sourceShard)) {
467  std::tie(targetShard, actualTargetSharding) = tryRes.value();
468  } else if (auto tryRes = tryUnsplitLastAxisInResharding(
469  builder, grid, sourceSharding, targetSharding,
470  sourceUnshardedValue.getType(), sourceShard)) {
471  std::tie(targetShard, actualTargetSharding) = tryRes.value();
472  }
473  }
474  assert(targetShard && "Did not find any pattern to apply.");
475  assert(actualTargetSharding == targetSharding);
476  assert(targetShard.getType() == targetShardType);
477  return targetShard;
478 }
479 
481 reshard(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding,
482  Sharding targetSharding, TypedValue<ShapedType> sourceUnshardedValue,
483  TypedValue<ShapedType> sourceShard) {
484  // If source and destination sharding are the same, no need to do anything.
485  if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
486  isFullReplication(targetSharding))) {
487  return sourceShard;
488  }
489 
490  // Tries to handle the case where the resharding is needed because the halo
491  // sizes are different. Supports arbitrary grid dimensionality.
492  if (auto tryRes = tryUpdateHaloInResharding(
493  builder, grid, sourceSharding, targetSharding,
494  sourceUnshardedValue.getType(), sourceShard)) {
495  return std::get<0>(tryRes.value()); // targetShard
496  }
497 
498  // Resort to handling only 1D grids since the general case is complicated if
499  // it needs to be communication efficient in terms of minimizing the data
500  // transfered between devices.
501  return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding,
502  sourceUnshardedValue, sourceShard);
503 }
504 
505 TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
506  ShardOp target,
507  TypedValue<ShapedType> sourceShardValue) {
508  assert(source.getResult() == target.getSrc());
509  auto sourceSharding = source.getSharding();
510  auto targetSharding = target.getSharding();
511  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
512  return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
513  cast<TypedValue<ShapedType>>(source.getSrc()),
514  sourceShardValue);
515 }
516 
517 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
518  ShardOp target,
519  TypedValue<ShapedType> sourceShardValue,
520  SymbolTableCollection &symbolTableCollection) {
521  GridOp srcGrid = getGrid(source, symbolTableCollection);
522  assert(srcGrid && srcGrid == getGrid(target, symbolTableCollection));
523  return reshard(builder, srcGrid, source, target, sourceShardValue);
524 }
525 
527  registry.insert<shard::ShardDialect, tensor::TensorDialect>();
528 }
529 
530 #define GEN_PASS_DEF_PARTITION
531 #include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
532 
534 
535 // Get the types of block arguments for a partitioned block.
536 // Reads the sharding annotations of the arguments to deduce the sharded types.
537 // Types that are not ranked tensors, as well as 0-d tensors, are left unchanged.
538 static SmallVector<Type>
540  SymbolTableCollection &symbolTableCollection) {
541  SmallVector<Type> res;
542  llvm::transform(
543  block.getArguments(), std::back_inserter(res),
544  [&symbolTableCollection](BlockArgument arg) {
545  auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
546  if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
547  return arg.getType();
548  }
549 
550  assert(rankedTensorArg.hasOneUse());
551  Operation *useOp = *rankedTensorArg.getUsers().begin();
552  ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
553  assert(shardOp);
554  GridOp grid = getGrid(shardOp, symbolTableCollection);
555  return cast<Type>(shardShapedType(rankedTensorArg.getType(), grid,
556  shardOp.getSharding()));
557  });
558  return res;
559 }
560 
561 static LogicalResult
563  ArrayRef<Sharding> operandShardings,
564  ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
565  SymbolTableCollection &symbolTableCollection,
566  OpBuilder &builder) {
567  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
568  if (!shardingInterface) {
569  // If there is no sharding interface we are conservative and assume that
570  // the op should be fully replicated no all devices.
571  partitionFullyReplicatedOperation(op, partitionedOperands, operandShardings,
572  resultShardings, partitionMap,
573  symbolTableCollection, builder);
574  } else {
575  if (failed(shardingInterface.partition(
576  partitionedOperands, operandShardings, resultShardings,
577  partitionMap, symbolTableCollection, builder))) {
578  return failure();
579  }
580  }
581 
582  assert(llvm::all_of(op.getResults(), [&partitionMap](OpResult result) {
583  return partitionMap.contains(result);
584  }));
585 
586  return success();
587 }
588 
589 // Retrieve the sharding annotations for the operands of the given operation.
590 // If the type is not a ranked tensor it is not require to have an annotation.
591 static std::vector<Sharding> getOperandShardings(Operation &op) {
592  std::vector<Sharding> res;
593  res.reserve(op.getNumOperands());
594  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
595  TypedValue<RankedTensorType> rankedTensor =
596  dyn_cast<TypedValue<RankedTensorType>>(operand);
597  if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
598  return Sharding();
599  }
600 
601  Operation *definingOp = operand.getDefiningOp();
602  assert(definingOp);
603  ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
604  return Sharding(shardOp.getSharding());
605  });
606  return res;
607 }
608 
609 // Retrieve the sharding annotations for the results of the given operation.
610 // If the type is not a ranked tensor it is not require to have an annotation.
611 static std::vector<Sharding> getResultShardings(Operation &op) {
612  std::vector<Sharding> res;
613  res.reserve(op.getNumResults());
614  llvm::transform(
615  op.getResults(), std::back_inserter(res), [&op](OpResult result) {
616  if (!result.hasOneUse() || result.use_empty()) {
617  return Sharding();
618  }
619  TypedValue<RankedTensorType> rankedTensor =
620  dyn_cast<TypedValue<RankedTensorType>>(result);
621  if (!rankedTensor) {
622  return Sharding();
623  }
624  Operation *userOp = *result.getUsers().begin();
625  ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
626  if (shardOp) {
627  return Sharding(shardOp.getSharding());
628  }
629  if (rankedTensor.getType().getRank() == 0) {
630  // This is a 0d tensor result without explicit sharding.
631  // Find grid symbol from operands, if any.
632  // Shardings without grid are not always fully supported yet.
633  for (auto operand : op.getOperands()) {
634  if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
635  return Sharding(sharding.getGridAttr());
636  }
637  }
638  }
639  return Sharding();
640  });
641  return res;
642 }
643 
644 static LogicalResult
645 partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
646  SymbolTableCollection &symbolTableCollection,
647  OpBuilder &builder) {
648  Value targetPartitionValue;
649 
650  // Check if 2 shard ops are chained. If not there is no need for resharding
651  // as the source and target shared the same sharding.
652  ShardOp srcShardOp = shardOp.getSrc().getDefiningOp<ShardOp>();
653  if (!srcShardOp) {
654  targetPartitionValue = partitionMap.lookup(shardOp.getSrc());
655  } else {
656  // Insert resharding.
657  TypedValue<ShapedType> srcPartitionValue =
658  cast<TypedValue<ShapedType>>(partitionMap.lookup(srcShardOp));
659  targetPartitionValue = reshard(builder, srcShardOp, shardOp,
660  srcPartitionValue, symbolTableCollection);
661  }
662 
663  assert(!partitionMap.contains(shardOp.getResult()));
664  partitionMap.map(shardOp.getResult(), targetPartitionValue);
665  return success();
666 }
667 
668 static LogicalResult
670  SymbolTableCollection &symbolTableCollection,
671  OpBuilder &builder) {
672  if (isa<ShardingOp>(op)) {
673  return success();
674  }
675  if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
676  auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
677  if (!shardOp) {
678  return op.emitError("expected a shard op as source of get_sharding");
679  }
680  auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
681  partitionMap.map(op.getResult(0), newSharding->getResult(0));
682  return success();
683  }
684 
685  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
686  if (shardOp) {
687  return partitionOperation(shardOp, partitionMap, symbolTableCollection,
688  builder);
689  }
690 
691  SmallVector<Value> partitionedOperands;
692  llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
693  [&partitionMap](Value operand) {
694  assert(partitionMap.contains(operand));
695  return partitionMap.lookup(operand);
696  });
697  return partitionOperation(op, partitionedOperands, getOperandShardings(op),
698  getResultShardings(op), partitionMap,
699  symbolTableCollection, builder);
700 }
701 
702 static LogicalResult
703 partitionBlock(Block &block, IRMapping &partitionMap,
704  SymbolTableCollection &symbolTableCollection,
705  OpBuilder &builder) {
706 
707  SmallVector<Location> argLocations;
708  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
709  [](BlockArgument arg) { return arg.getLoc(); });
710  Block *newBlock = builder.createBlock(
711  block.getParent(), {},
712  shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
713  for (auto [unshardedBlockArg, partitionedBlockArg] :
714  llvm::zip(block.getArguments(), newBlock->getArguments())) {
715  partitionMap.map(unshardedBlockArg, partitionedBlockArg);
716  }
717 
718  OpBuilder::InsertionGuard insertionGuard(builder);
719  builder.setInsertionPointToEnd(newBlock);
720  for (Operation &op : block.getOperations()) {
721  if (failed(partitionOperation(op, partitionMap, symbolTableCollection,
722  builder))) {
723  return failure();
724  }
725  }
726 
727  return success();
728 }
729 
730 static LogicalResult
731 partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap,
732  SymbolTableCollection &symbolTableCollection) {
733  OpBuilder builder(op.getFunctionBody());
734 
735  // Snapshot the original blocks to not mess up the iteration when adding new
736  // blocks.
737  SmallVector<Block *> originalBlocks;
738  for (Block &b : op.getBlocks()) {
739  if (llvm::any_of(b.getOperations(),
740  [](Operation &op) { return isa<ShardOp>(op); })) {
741  originalBlocks.push_back(&b);
742  }
743  }
744 
745  for (Block *block : originalBlocks) {
746  if (failed(partitionBlock(*block, partitionMap, symbolTableCollection,
747  builder))) {
748  return failure();
749  }
750  }
751 
752  for (Block *block : originalBlocks) {
753  block->erase();
754  }
755 
756  // Find a return op and change the function results signature to its operands
757  // signature.
758  Operation *returnOp = nullptr;
759  for (Block &block : op.getFunctionBody()) {
760  if (block.empty()) {
761  continue;
762  }
763 
764  if (block.back().hasTrait<OpTrait::ReturnLike>()) {
765  returnOp = &block.back();
766  break;
767  }
768  }
769  if (returnOp) {
770  op.setType(FunctionType::get(
771  op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
772  returnOp->getOperandTypes()));
773  }
774 
775  return success();
776 }
777 
778 namespace {
779 
780 struct Partition : public impl::PartitionBase<Partition> {
781  void runOnOperation() override {
782  IRMapping partitionMap;
783  SymbolTableCollection symbolTableCollection;
784  if (failed(partitionFuncOp(getOperation(), partitionMap,
785  symbolTableCollection))) {
786  return signalPassFailure();
787  }
788  }
789 
790  void getDependentDialects(DialectRegistry &registry) const override {
792  registry.insert<shard::ShardDialect>();
793  }
794 };
795 
796 } // namespace
797 
798 } // namespace mlir::shard
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
MLIRContext * getContext() const
Definition: Builders.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition: IRMapping.h:51
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition: Builders.h:623
Location getLoc() const
Accessors for the implied location.
Definition: Builders.h:656
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:429
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:552
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:421
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
unsigned getNumOperands()
Definition: Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
operand_type_range getOperandTypes()
Definition: Operation.h:397
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition: ShardOps.cpp:770
bool equalSplitAxes(const Sharding &rhs) const
Definition: ShardOps.cpp:689
::mlir::FlatSymbolRefAttr getGridAttr() const
Definition: ShardOps.h:60
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition: ShardOps.h:64
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: ShardOps.h:63
ArrayRef< Value > getDynamicHaloSizes() const
Definition: ShardOps.h:67
ArrayRef< GridAxesAttr > getSplitAxes() const
Definition: ShardOps.h:62
bool equalHaloSizes(const Sharding &rhs) const
Definition: ShardOps.cpp:728
shard::Sharding Sharding
shard::GridOp GridOp
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
int16_t GridAxis
Definition: ShardOps.h:26
static std::optional< std::tuple< int64_t, GridAxis > > detectSplitLastAxisInResharding(Sharding sourceSharding, Sharding targetSharding)
Definition: Partition.cpp:87
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
Definition: ShardOps.cpp:281
static LogicalResult partitionOperation(Operation &op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
Definition: Partition.cpp:669
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
Definition: Partition.cpp:180
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard)
Definition: Partition.cpp:119
void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
Definition: Partition.cpp:214
static std::optional< std::tuple< int64_t, int64_t, GridAxis > > detectMoveLastSplitAxisInResharding(Sharding sourceSharding, Sharding targetSharding)
Definition: Partition.cpp:235
static LogicalResult partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)
Definition: Partition.cpp:731
static TypedValue< ShapedType > reshard(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
Definition: Partition.cpp:481
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
Definition: Partition.cpp:366
static std::tuple< TypedValue< ShapedType >, Sharding > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
Definition: Partition.cpp:68
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
Definition: Partition.cpp:345
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
Definition: Partition.cpp:39
bool isFullReplication(Sharding sharding)
Definition: ShardOps.h:106
static std::tuple< TypedValue< ShapedType >, Sharding > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
Definition: Partition.cpp:188
static TypedValue< ShapedType > reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
Definition: Partition.cpp:439
static LogicalResult partitionBlock(Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
Definition: Partition.cpp:703
static std::vector< Sharding > getOperandShardings(Operation &op)
Definition: Partition.cpp:591
static std::tuple< TypedValue< ShapedType >, Sharding > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis)
Definition: Partition.cpp:317
static SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
Definition: Partition.cpp:539
static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
Definition: Partition.cpp:276
static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis)
Definition: Partition.cpp:164
static std::optional< std::tuple< int64_t, GridAxis > > detectUnsplitLastAxisInResharding(Sharding sourceSharding, Sharding targetSharding)
Definition: Partition.cpp:136
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: ShardOps.h:168
TypedValue< ShapedType > reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
Definition: Partition.cpp:505
static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis)
Definition: Partition.cpp:46
void reshardingRegisterDependentDialects(DialectRegistry &registry)
Definition: Partition.cpp:526
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
Definition: Partition.cpp:304
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition: ShardOps.h:121
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: ShardOps.h:177
static std::vector< Sharding > getResultShardings(Operation &op)
Definition: Partition.cpp:611
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This trait indicates that a terminator operation is "return-like".