MLIR  21.0.0git
Spmdization.cpp
Go to the documentation of this file.
1 //===- Spmdization.cpp --------------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/Location.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "mlir/IR/Value.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Support/LLVM.h"
31 #include "llvm/ADT/APInt.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/Casting.h"
36 #include <iterator>
37 #include <optional>
38 #include <tuple>
39 #include <type_traits>
40 
41 namespace mlir::mesh {
42 
// Returns true if the target partial axes form a subset of the source
// partial axes, i.e. every axis the target keeps is also present in the
// source.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  for (const auto &axis : targetAxes) {
    if (!sourceAxes.contains(axis))
      return false;
  }
  return true;
}
50 
51 // Return the reduced value and its corresponding sharding.
52 // Example:
53 // sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
54 // targetSharding = <@mesh_1d, [[]]>
55 // Then will apply all-reduce on the source value
56 // and return it with the sharding <@mesh_1d, [[0]]>.
57 static std::tuple<TypedValue<ShapedType>, MeshSharding>
59  MeshSharding sourceSharding,
60  MeshSharding targetSharding,
61  TypedValue<ShapedType> sourceShard) {
62  if (sourceSharding.getPartialAxes().empty() &&
63  targetSharding.getPartialAxes().empty()) {
64  return {sourceShard, sourceSharding};
65  }
66  assert(targetSharding.getPartialAxes().empty() ||
67  (!sourceSharding.getPartialAxes().empty() &&
68  sourceSharding.getPartialType() == targetSharding.getPartialType()));
69  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
70  using AxisSet = llvm::SmallDenseSet<Axis>;
71  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
72  sourceSharding.getPartialAxes().end());
73  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
74  targetSharding.getPartialAxes().end());
75  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
76  targetShardingPartialAxesSet));
77  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
78  llvm::copy_if(sourceShardingPartialAxesSet,
79  std::back_inserter(allReduceMeshAxes),
80  [&targetShardingPartialAxesSet](Axis a) {
81  return !targetShardingPartialAxesSet.contains(a);
82  });
83  if (allReduceMeshAxes.empty()) {
84  return {sourceShard, sourceSharding};
85  }
86 
87  builder.setInsertionPointAfterValue(sourceShard);
88  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
89  builder
90  .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
91  sourceSharding.getMeshAttr().getLeafReference(),
92  allReduceMeshAxes, sourceShard,
93  sourceSharding.getPartialType())
94  .getResult());
95 
96  llvm::SmallVector<MeshAxis> remainingPartialAxes;
97  llvm::copy_if(sourceShardingPartialAxesSet,
98  std::back_inserter(allReduceMeshAxes),
99  [&targetShardingPartialAxesSet](Axis a) {
100  return targetShardingPartialAxesSet.contains(a);
101  });
102  MeshSharding resultSharding = MeshSharding::get(
103  sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(),
104  remainingPartialAxes, sourceSharding.getPartialType());
105  return {resultValue, resultSharding};
106 }
107 
109  MeshSharding sourceSharding,
110  int64_t splitTensorAxis,
111  MeshAxis splitMeshAxis) {
112  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
113  llvm::to_vector(sourceSharding.getSplitAxes());
114  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
115  splitTensorAxis) {
116  targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
117  }
118  auto targetSplitAxes =
119  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
120  targetSplitAxes.push_back(splitMeshAxis);
121  targetShardingSplitAxes[splitTensorAxis] =
122  MeshAxesAttr::get(ctx, targetSplitAxes);
123  return MeshSharding::get(
124  sourceSharding.getMeshAttr(), targetShardingSplitAxes,
125  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
126 }
127 
128 // Split a replicated tensor along a mesh axis.
129 // E.g. [[0, 1]] -> [[0, 1, 2]].
130 // Returns the spmdized target value with its sharding.
131 static std::tuple<TypedValue<ShapedType>, MeshSharding>
133  MeshSharding sourceSharding,
134  TypedValue<ShapedType> sourceShard, MeshOp mesh,
135  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
136  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
137  builder
138  .create<AllSliceOp>(sourceShard, mesh,
139  ArrayRef<MeshAxis>(splitMeshAxis),
140  splitTensorAxis)
141  .getResult());
143  builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
144  return {targetShard, targetSharding};
145 }
146 
147 // Detect if the resharding is of type e.g.
148 // [[0, 1]] -> [[0, 1, 2]].
149 // If detected, returns the corresponding tensor axis mesh axis pair.
150 // Does not detect insertions like
151 // [[0, 1]] -> [[0, 2, 1]].
152 static std::optional<std::tuple<int64_t, MeshAxis>>
154  MeshSharding targetSharding) {
155  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
156  ++tensorAxis) {
157  if (sourceSharding.getSplitAxes().size() > tensorAxis) {
158  if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
159  targetSharding.getSplitAxes()[tensorAxis].size()) {
160  continue;
161  }
162  if (!llvm::equal(
163  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
164  llvm::make_range(
165  targetSharding.getSplitAxes()[tensorAxis]
166  .asArrayRef()
167  .begin(),
168  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
169  1))) {
170  continue;
171  }
172  } else {
173  if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
174  continue;
175  }
176  }
177  return std::make_tuple(
178  tensorAxis,
179  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
180  }
181  return std::nullopt;
182 }
183 
184 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
186  MeshSharding sourceSharding,
187  MeshSharding targetSharding,
188  TypedValue<ShapedType> sourceShard) {
189  if (auto detectRes =
190  detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
191  auto [tensorAxis, meshAxis] = detectRes.value();
192  return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
193  tensorAxis, meshAxis);
194  }
195 
196  return std::nullopt;
197 }
198 
199 // Detect if the resharding is of type e.g.
200 // [[0, 1, 2]] -> [[0, 1]].
201 // If detected, returns the corresponding tensor axis mesh axis pair.
202 static std::optional<std::tuple<int64_t, MeshAxis>>
204  MeshSharding targetSharding) {
205  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
206  ++tensorAxis) {
207  if (targetSharding.getSplitAxes().size() > tensorAxis) {
208  if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
209  targetSharding.getSplitAxes()[tensorAxis].size() + 1)
210  continue;
211  if (!llvm::equal(
212  llvm::make_range(
213  sourceSharding.getSplitAxes()[tensorAxis]
214  .asArrayRef()
215  .begin(),
216  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
217  1),
218  targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
219  continue;
220  } else {
221  if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
222  continue;
223  }
224  return std::make_tuple(
225  tensorAxis,
226  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
227  }
228  return std::nullopt;
229 }
230 
232  MeshSharding sourceSharding,
233  int64_t splitTensorAxis) {
234  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
235  llvm::to_vector(sourceSharding.getSplitAxes());
236  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
237  splitTensorAxis);
238  auto targetSplitAxes =
239  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
240 
241  targetSplitAxes.pop_back();
242  targetShardingSplitAxes[splitTensorAxis] =
243  MeshAxesAttr::get(ctx, targetSplitAxes);
244  return MeshSharding::get(
245  sourceSharding.getMeshAttr(), targetShardingSplitAxes,
246  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
247 }
248 
250  ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
251  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
252  targetShape[splitTensorAxis] =
253  gatherDimension(targetShape[splitTensorAxis], splitCount);
254  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
255 }
256 
257 static std::tuple<TypedValue<ShapedType>, MeshSharding>
259  MeshSharding sourceSharding,
260  ShapedType sourceUnshardedShape,
261  TypedValue<ShapedType> sourceShard, MeshOp mesh,
262  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
263  MLIRContext *ctx = builder.getContext();
264  builder.setInsertionPointAfterValue(sourceShard);
265 
266  MeshSharding targetSharding =
267  targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
268  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
269  sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
270  Value allGatherResult = builder.create<AllGatherOp>(
271  RankedTensorType::get(allGatherResultShape.getShape(),
272  allGatherResultShape.getElementType()),
273  mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
274  APInt(64, splitTensorAxis));
275  ShapedType targetShape =
276  shardShapedType(sourceUnshardedShape, mesh, targetSharding);
277  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
278  builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
279  return {targetShard, targetSharding};
280 }
281 
282 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
284  MeshSharding sourceSharding,
285  MeshSharding targetSharding,
286  ShapedType sourceUnshardedShape,
287  TypedValue<ShapedType> sourceShard) {
288  if (auto detectRes =
289  detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
290  auto [tensorAxis, meshAxis] = detectRes.value();
291  return unsplitLastAxisInResharding(builder, sourceSharding,
292  sourceUnshardedShape, sourceShard, mesh,
293  tensorAxis, meshAxis);
294  }
295 
296  return std::nullopt;
297 }
298 
299 // Detect if the resharding is of type e.g.
300 // [[0, 1], [2]] -> [[0], [1, 2]].
301 // Only moving the last axis counts.
302 // If detected, returns the corresponding (source_tensor_axis,
303 // target_tensor_axis, mesh_axis) tuple.
304 static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
306  MeshSharding targetSharding) {
307  for (size_t sourceTensorAxis = 0;
308  sourceTensorAxis < sourceSharding.getSplitAxes().size();
309  ++sourceTensorAxis) {
310  for (size_t targetTensorAxis = 0;
311  targetTensorAxis < targetSharding.getSplitAxes().size();
312  ++targetTensorAxis) {
313  if (sourceTensorAxis == targetTensorAxis)
314  continue;
315  if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
316  targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
317  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
318  targetSharding.getSplitAxes()[targetTensorAxis]
319  .asArrayRef()
320  .back())
321  continue;
322  if (!llvm::equal(
323  llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
324  .asArrayRef()
325  .begin(),
326  sourceSharding.getSplitAxes()[sourceTensorAxis]
327  .asArrayRef()
328  .end() -
329  1),
330  llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
331  .asArrayRef()
332  .begin(),
333  targetSharding.getSplitAxes()[targetTensorAxis]
334  .asArrayRef()
335  .end() -
336  1)))
337  continue;
338  return std::make_tuple(
339  sourceTensorAxis, targetTensorAxis,
340  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
341  }
342  }
343  return std::nullopt;
344 }
345 
347  MeshSharding sourceSharding,
348  int64_t sourceTensorAxis,
349  int64_t targetTensorAxis) {
350  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
351  llvm::to_vector(sourceSharding.getSplitAxes());
352  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
353  targetTensorAxis) {
354  targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
355  }
356 
357  auto sourceSplitAxes =
358  llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
359  assert(!sourceSplitAxes.empty());
360  auto meshAxis = sourceSplitAxes.back();
361  sourceSplitAxes.pop_back();
362  targetShardingSplitAxes[sourceTensorAxis] =
363  MeshAxesAttr::get(ctx, sourceSplitAxes);
364 
365  auto targetSplitAxes =
366  llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
367  targetSplitAxes.push_back(meshAxis);
368  targetShardingSplitAxes[targetTensorAxis] =
369  MeshAxesAttr::get(ctx, targetSplitAxes);
370 
371  return MeshSharding::get(
372  sourceSharding.getMeshAttr(), targetShardingSplitAxes,
373  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
374 }
375 
376 static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
377  int64_t splitCount,
378  int64_t sourceTensorAxis,
379  int64_t targetTensorAxis) {
380  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
381  targetShape[sourceTensorAxis] =
382  gatherDimension(targetShape[sourceTensorAxis], splitCount);
383  targetShape[targetTensorAxis] =
384  shardDimension(targetShape[targetTensorAxis], splitCount);
385  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
386 }
387 
388 static std::tuple<TypedValue<ShapedType>, MeshSharding>
390  MeshSharding sourceSharding,
391  ShapedType sourceUnshardedShape,
392  TypedValue<ShapedType> sourceShard,
393  int64_t sourceTensorAxis,
394  int64_t targetTensorAxis, MeshAxis meshAxis) {
395  MLIRContext *ctx = builder.getContext();
396  builder.setInsertionPointAfterValue(sourceShard);
397 
399  ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
400  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
401  sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
402  targetTensorAxis);
403  Value allToAllResult = builder.create<AllToAllOp>(
404  RankedTensorType::get(allToAllResultShape.getShape(),
405  allToAllResultShape.getElementType()),
406  mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
407  APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
408  ShapedType targetShape =
409  shardShapedType(sourceUnshardedShape, mesh, targetSharding);
410  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
411  builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
412  return {targetShard, targetSharding};
413 }
414 
415 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
417  MeshSharding sourceSharding,
418  MeshSharding targetSharding,
419  ShapedType sourceUnshardedShape,
420  TypedValue<ShapedType> sourceShard) {
421  if (auto detectRes =
422  detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
423  auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
425  builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
426  sourceTensorAxis, targetTensorAxis, meshAxis);
427  }
428 
429  return std::nullopt;
430 }
431 
432 // Detect a change in the halo size (only) and create necessary operations if
433 // needed. A changed halo sizes requires copying the "core" of the source tensor
434 // into the "core" of the destination tensor followed by an update halo
435 // operation.
436 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
438  MeshSharding sourceSharding,
439  MeshSharding targetSharding,
440  ShapedType sourceUnshardedShape,
441  TypedValue<ShapedType> sourceShard) {
442  // Currently handles only cases where halo sizes differ but everything else
443  // stays the same (from source to destination sharding).
444  if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
445  !sourceSharding.getPartialAxes().empty() ||
446  !targetSharding.getPartialAxes().empty() ||
447  !sourceSharding.getStaticShardedDimsOffsets().empty() ||
448  !targetSharding.getStaticShardedDimsOffsets().empty() ||
449  sourceSharding.equalHaloSizes(targetSharding)) {
450  return std::nullopt;
451  }
452 
453  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
454  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
455  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
456  assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
457  !ShapedType::isDynamicShape(tgtHaloSizes) &&
458  sourceShard.getType().hasStaticShape()) &&
459  "dynamic shapes/halos are not supported yet for mesh-spmdization");
460  auto rank = sourceShard.getType().getRank();
461  auto splitAxes = sourceSharding.getSplitAxes();
462  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
463  strides(rank, 1), outShape(sourceShard.getType().getShape()),
464  coreShape(sourceShard.getType().getShape());
465 
466  // Determine "core" of source and destination.
467  // The core is the local part of the shard excluding halo regions.
468  for (auto i = 0u; i < rank; ++i) {
469  if (i < splitAxes.size() && !splitAxes[i].empty()) {
470  if (!srcHaloSizes.empty()) {
471  coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
472  srcCoreOffs[i] = srcHaloSizes[i * 2];
473  }
474  tgtCoreOffs[i] = tgtHaloSizes[i * 2];
475  outShape[i] =
476  coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
477  }
478  }
479 
480  // Extract core from source and copy into destination core.
481  auto noVals = ValueRange{};
482  auto initVal = builder.create<tensor::EmptyOp>(
483  sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
484  auto core = builder.create<tensor::ExtractSliceOp>(
485  sourceShard.getLoc(),
486  RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
487  sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
488  auto initOprnd = builder.create<tensor::InsertSliceOp>(
489  sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
490  coreShape, strides);
491 
492  // Finally update the halo.
493  auto updateHaloResult =
494  builder
495  .create<UpdateHaloOp>(
496  sourceShard.getLoc(),
497  RankedTensorType::get(outShape,
498  sourceShard.getType().getElementType()),
499  initOprnd, mesh.getSymName(),
501  sourceSharding.getSplitAxes()),
502  targetSharding.getDynamicHaloSizes(),
503  targetSharding.getStaticHaloSizes())
504  .getResult();
505  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
506  targetSharding);
507 }
508 
509 // Handles only resharding on a 1D mesh.
510 // Currently the sharded tensor axes must be exactly divisible by the single
511 // mesh axis size.
514  MeshSharding sourceSharding, MeshSharding targetSharding,
515  TypedValue<ShapedType> sourceUnshardedValue,
516  TypedValue<ShapedType> sourceShard) {
517  assert(sourceShard.getType() ==
518  shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
519  [[maybe_unused]] ShapedType targetShardType =
520  shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
521  assert(sourceShard.getType().getRank() == targetShardType.getRank());
522  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
523 
524  auto [reducedSourceShard, reducedSourceSharding] =
525  handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
526  sourceShard);
527 
528  if (reducedSourceSharding == targetSharding) {
529  return reducedSourceShard;
530  }
531 
532  TypedValue<ShapedType> targetShard;
533  MeshSharding actualTargetSharding;
534  if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
535  targetSharding.getStaticShardedDimsOffsets().empty() &&
536  reducedSourceSharding.getStaticHaloSizes().empty() &&
537  targetSharding.getStaticHaloSizes().empty()) {
538  if (auto tryRes = tryMoveLastSplitAxisInResharding(
539  builder, mesh, reducedSourceSharding, targetSharding,
540  sourceUnshardedValue.getType(), reducedSourceShard)) {
541  std::tie(targetShard, actualTargetSharding) = tryRes.value();
542  } else if (auto tryRes = trySplitLastAxisInResharding(
543  builder, mesh, reducedSourceSharding, targetSharding,
544  reducedSourceShard)) {
545  std::tie(targetShard, actualTargetSharding) = tryRes.value();
546  } else if (auto tryRes = tryUnsplitLastAxisInResharding(
547  builder, mesh, reducedSourceSharding, targetSharding,
548  sourceUnshardedValue.getType(), reducedSourceShard)) {
549  std::tie(targetShard, actualTargetSharding) = tryRes.value();
550  }
551  }
552  assert(targetShard && "Did not find any pattern to apply.");
553  assert(actualTargetSharding == targetSharding);
554  assert(targetShard.getType() == targetShardType);
555  return targetShard;
556 }
557 
559  MeshSharding sourceSharding,
560  MeshSharding targetSharding,
561  TypedValue<ShapedType> sourceUnshardedValue,
562  TypedValue<ShapedType> sourceShard) {
563  // If source and destination sharding are the same, no need to do anything.
564  if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
565  isFullReplication(targetSharding))) {
566  return sourceShard;
567  }
568 
569  // Tries to handle the case where the resharding is needed because the halo
570  // sizes are different. Supports arbitrary mesh dimensionality.
571  if (auto tryRes = tryUpdateHaloInResharding(
572  builder, mesh, sourceSharding, targetSharding,
573  sourceUnshardedValue.getType(), sourceShard)) {
574  return std::get<0>(tryRes.value()); // targetShard
575  }
576 
577  // Resort to handling only 1D meshes since the general case is complicated if
578  // it needs to be communication efficient in terms of minimizing the data
579  // transfered between devices.
580  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
581  sourceUnshardedValue, sourceShard);
582 }
583 
584 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
585  ShardOp target,
586  TypedValue<ShapedType> sourceShardValue) {
587  assert(source.getResult() == target.getSrc());
588  auto sourceSharding = source.getSharding();
589  auto targetSharding = target.getSharding();
590  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
591  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
592  cast<TypedValue<ShapedType>>(source.getSrc()),
593  sourceShardValue);
594 }
595 
596 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
597  ShardOp target,
598  TypedValue<ShapedType> sourceShardValue,
599  SymbolTableCollection &symbolTableCollection) {
600  MeshOp srcMesh = getMesh(source, symbolTableCollection);
601  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
602  return reshard(builder, srcMesh, source, target, sourceShardValue);
603 }
604 
606  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
607 }
608 
609 #define GEN_PASS_DEF_SPMDIZATION
610 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
611 
613 
614 // Get the types of block arguments for an spmdized block.
615 // Reads the sharding annotations of the arguments to deduce the sharded types.
616 // Types that are not ranked tensors are left unchanged.
619  SymbolTableCollection &symbolTableCollection) {
620  SmallVector<Type> res;
621  llvm::transform(
622  block.getArguments(), std::back_inserter(res),
623  [&symbolTableCollection](BlockArgument arg) {
624  auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
625  if (!rankedTensorArg) {
626  return arg.getType();
627  }
628 
629  assert(rankedTensorArg.hasOneUse());
630  Operation *useOp = *rankedTensorArg.getUsers().begin();
631  ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
632  assert(shardOp);
633  MeshOp mesh = getMesh(shardOp, symbolTableCollection);
634  return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
635  shardOp.getSharding()));
636  });
637  return res;
638 }
639 
640 static LogicalResult spmdizeOperation(
641  Operation &op, ArrayRef<Value> spmdizedOperands,
642  ArrayRef<MeshSharding> operandShardings,
643  ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
644  SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
645  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
646  if (!shardingInterface) {
647  // If there is no sharding interface we are conservative and assume that
648  // the op should be fully replicated no all devices.
649  spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
650  resultShardings, spmdizationMap,
651  symbolTableCollection, builder);
652  } else {
653  if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
654  resultShardings, spmdizationMap,
655  symbolTableCollection, builder))) {
656  return failure();
657  }
658  }
659 
660  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
661  return spmdizationMap.contains(result);
662  }));
663 
664  return success();
665 }
666 
667 // Retrieve the sharding annotations for the operands of the given operation.
668 // If the type is not a ranked tensor it is not require to have an annotation.
669 static std::vector<MeshSharding> getOperandShardings(Operation &op) {
670  std::vector<MeshSharding> res;
671  res.reserve(op.getNumOperands());
672  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
673  TypedValue<RankedTensorType> rankedTensor =
674  dyn_cast<TypedValue<RankedTensorType>>(operand);
675  if (!rankedTensor) {
676  return MeshSharding();
677  }
678 
679  Operation *definingOp = operand.getDefiningOp();
680  assert(definingOp);
681  ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
682  return MeshSharding(shardOp.getSharding());
683  });
684  return res;
685 }
686 
687 // Retrieve the sharding annotations for the results of the given operation.
688 // If the type is not a ranked tensor it is not require to have an annotation.
689 static std::vector<MeshSharding> getResultShardings(Operation &op) {
690  std::vector<MeshSharding> res;
691  res.reserve(op.getNumResults());
692  llvm::transform(op.getResults(), std::back_inserter(res),
693  [](OpResult result) {
694  TypedValue<RankedTensorType> rankedTensor =
695  dyn_cast<TypedValue<RankedTensorType>>(result);
696  if (!rankedTensor) {
697  return MeshSharding();
698  }
699  if (!result.hasOneUse()) {
700  return MeshSharding();
701  }
702  Operation *userOp = *result.getUsers().begin();
703  ShardOp shardOp = llvm::cast<ShardOp>(userOp);
704  return MeshSharding(shardOp.getSharding());
705  });
706  return res;
707 }
708 
709 static LogicalResult
710 spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
711  SymbolTableCollection &symbolTableCollection,
712  OpBuilder &builder) {
713  Value targetSpmdValue;
714 
715  // Check if 2 shard ops are chained. If not there is no need for resharding
716  // as the source and target shared the same sharding.
717  ShardOp srcShardOp =
718  dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
719  if (!srcShardOp) {
720  targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
721  } else {
722  // Insert resharding.
723  TypedValue<ShapedType> srcSpmdValue =
724  cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
725  targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
726  symbolTableCollection);
727  }
728 
729  assert(!spmdizationMap.contains(shardOp.getResult()));
730  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
731  return success();
732 }
733 
734 static LogicalResult
735 spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
736  SymbolTableCollection &symbolTableCollection,
737  OpBuilder &builder) {
738  if (isa<ShardingOp>(op)) {
739  return success();
740  }
741  if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
742  auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
743  if (!shardOp) {
744  return op.emitError("expected a shard op as source of get_sharding");
745  }
746  auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
747  spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
748  return success();
749  }
750 
751  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
752  if (shardOp) {
753  return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
754  builder);
755  }
756 
757  SmallVector<Value> spmdizedOperands;
758  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
759  [&spmdizationMap](Value operand) {
760  assert(spmdizationMap.contains(operand));
761  return spmdizationMap.lookup(operand);
762  });
763  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
764  getResultShardings(op), spmdizationMap,
765  symbolTableCollection, builder);
766 }
767 
768 static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
769  SymbolTableCollection &symbolTableCollection,
770  OpBuilder &builder) {
771 
772  SmallVector<Location> argLocations;
773  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
774  [](BlockArgument arg) { return arg.getLoc(); });
775  Block *newBlock = builder.createBlock(
776  block.getParent(), {},
777  shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
778  for (auto [unshardedBlockArg, spmdizedBlockArg] :
779  llvm::zip(block.getArguments(), newBlock->getArguments())) {
780  spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
781  }
782 
783  OpBuilder::InsertionGuard insertionGuard(builder);
784  builder.setInsertionPointToEnd(newBlock);
785  for (Operation &op : block.getOperations()) {
786  if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
787  builder))) {
788  return failure();
789  }
790  }
791 
792  return success();
793 }
794 
795 static LogicalResult
796 spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
797  SymbolTableCollection &symbolTableCollection) {
798  OpBuilder builder(op.getFunctionBody());
799 
800  // Snapshot the original blocks to not mess up the iteration when adding new
801  // blocks.
802  SmallVector<Block *> originalBlocks;
803  for (Block &b : op.getBlocks()) {
804  if (llvm::any_of(b.getOperations(),
805  [](Operation &op) { return isa<ShardOp>(op); })) {
806  originalBlocks.push_back(&b);
807  }
808  }
809 
810  for (Block *block : originalBlocks) {
811  if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
812  builder))) {
813  return failure();
814  }
815  }
816 
817  for (Block *block : originalBlocks) {
818  block->erase();
819  }
820 
821  // Find a return op and change the function results signature to its operands
822  // signature.
823  Operation *returnOp = nullptr;
824  for (Block &block : op.getFunctionBody()) {
825  if (block.empty()) {
826  continue;
827  }
828 
829  if (block.back().hasTrait<OpTrait::ReturnLike>()) {
830  returnOp = &block.back();
831  break;
832  }
833  }
834  if (returnOp) {
835  op.setType(FunctionType::get(
836  op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
837  returnOp->getOperandTypes()));
838  }
839 
840  return success();
841 }
842 
843 namespace {
844 
845 struct Spmdization : public impl::SpmdizationBase<Spmdization> {
846  void runOnOperation() override {
847  IRMapping spmdizationMap;
848  SymbolTableCollection symbolTableCollection;
849  if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
850  symbolTableCollection))) {
851  return signalPassFailure();
852  }
853  }
854 
855  void getDependentDialects(DialectRegistry &registry) const override {
857  registry.insert<mesh::MeshDialect>();
858  }
859 };
860 
861 } // namespace
862 
863 } // namespace mlir::mesh
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
MLIRContext * getContext() const
Definition: Builders.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition: IRMapping.h:51
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:419
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
unsigned getNumOperands()
Definition: Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:397
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:689
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition: MeshOps.h:70
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:64
bool equalHaloSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:744
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:66
ReductionKind getPartialType() const
Definition: MeshOps.h:68
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:67
ArrayRef< Value > getDynamicHaloSizes() const
Definition: MeshOps.h:73
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: MeshOps.h:69
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition: MeshOps.cpp:796
mesh::MeshSharding MeshSharding
static LogicalResult spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
static std::tuple< TypedValue< ShapedType >, MeshSharding > handlePartialAxesDuringResharding(OpBuilder &builder, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
Definition: Spmdization.cpp:58
SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
static std::vector< MeshSharding > getResultShardings(Operation &op)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
Definition: Spmdization.cpp:44
static std::optional< std::tuple< int64_t, MeshAxis > > detectUnsplitLastAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:182
static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis)
static std::tuple< TypedValue< ShapedType >, MeshSharding > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:260
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:126
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:173
TypedValue< ShapedType > reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
bool isFullReplication(MeshSharding sharding)
Definition: MeshOps.h:112
static TypedValue< ShapedType > reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
void reshardingRegisterDependentDialects(DialectRegistry &registry)
static std::tuple< TypedValue< ShapedType >, MeshSharding > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static std::vector< MeshSharding > getOperandShardings(Operation &op)
int16_t MeshAxis
Definition: MeshOps.h:26
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static LogicalResult spmdizeOperation(Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::tuple< TypedValue< ShapedType >, MeshSharding > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
TypedValue< ShapedType > reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::optional< std::tuple< int64_t, MeshAxis > > detectSplitLastAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This trait indicates that a terminator operation is "return-like".