MLIR  20.0.0git
Spmdization.cpp
Go to the documentation of this file.
1 //===- Spmdization.cpp --------------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/Location.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "mlir/IR/Value.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Support/LLVM.h"
31 #include "llvm/ADT/APInt.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/Casting.h"
36 #include <iterator>
37 #include <optional>
38 #include <tuple>
39 #include <type_traits>
40 
41 namespace mlir::mesh {
42 
// Returns true iff the target partial axes form a subset of the source
// partial axes, i.e. every axis the target wants to keep partial is already
// partial in the source.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  for (const auto &targetAxis : targetAxes) {
    if (!sourceAxes.contains(targetAxis))
      return false;
  }
  return true;
}
50 
51 // Return the reduced value and its corresponding sharding.
52 // Example:
53 // sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
54 // targetSharding = <@mesh_1d, [[]]>
55 // Then will apply all-reduce on the source value
56 // and return it with the sharding <@mesh_1d, [[0]]>.
57 static std::tuple<TypedValue<ShapedType>, MeshSharding>
59  MeshSharding sourceSharding,
60  MeshSharding targetSharding,
61  TypedValue<ShapedType> sourceShard) {
62  if (sourceSharding.getPartialAxes().empty() &&
63  targetSharding.getPartialAxes().empty()) {
64  return {sourceShard, sourceSharding};
65  }
66  assert(targetSharding.getPartialAxes().empty() ||
67  (!sourceSharding.getPartialAxes().empty() &&
68  sourceSharding.getPartialType() == targetSharding.getPartialType()));
69  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
70  using AxisSet = llvm::SmallDenseSet<Axis>;
71  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
72  sourceSharding.getPartialAxes().end());
73  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
74  targetSharding.getPartialAxes().end());
75  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
76  targetShardingPartialAxesSet));
77  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
78  llvm::copy_if(sourceShardingPartialAxesSet,
79  std::back_inserter(allReduceMeshAxes),
80  [&targetShardingPartialAxesSet](Axis a) {
81  return !targetShardingPartialAxesSet.contains(a);
82  });
83  if (allReduceMeshAxes.empty()) {
84  return {sourceShard, sourceSharding};
85  }
86 
87  builder.setInsertionPointAfterValue(sourceShard);
88  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
89  builder
90  .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
91  sourceSharding.getMeshAttr().getLeafReference(),
92  allReduceMeshAxes, sourceShard,
93  sourceSharding.getPartialType())
94  .getResult());
95 
96  llvm::SmallVector<MeshAxis> remainingPartialAxes;
97  llvm::copy_if(sourceShardingPartialAxesSet,
98  std::back_inserter(allReduceMeshAxes),
99  [&targetShardingPartialAxesSet](Axis a) {
100  return targetShardingPartialAxesSet.contains(a);
101  });
102  MeshSharding resultSharding = MeshSharding::get(
103  sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(),
104  remainingPartialAxes, sourceSharding.getPartialType());
105  return {resultValue, resultSharding};
106 }
107 
// Build the sharding obtained from `sourceSharding` by additionally splitting
// tensor axis `splitTensorAxis` along mesh axis `splitMeshAxis`. Partial axes
// and the partial type are carried over unchanged.
// NOTE(review): the declaration line (function name and an `MLIRContext *ctx`
// parameter — `ctx` is used below) is missing from this extract; restore it
// from the original source.
                              MeshSharding sourceSharding,
                              int64_t splitTensorAxis,
                              MeshAxis splitMeshAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  // Pad with empty per-axis mesh-axis lists up to the split tensor axis.
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         splitTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }
  // Append the new mesh axis after the axes already splitting this tensor
  // axis.
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.push_back(splitMeshAxis);
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
127 
// Split a replicated tensor along a mesh axis.
// E.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value with its sharding.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
// NOTE(review): the line with the function name and leading parameters
// (an ImplicitLocOpBuilder `builder` is used below) is missing from this
// extract; restore it from the original source.
                          MeshSharding sourceSharding,
                          TypedValue<ShapedType> sourceShard, MeshOp mesh,
                          int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  // mesh.all_slice carves out this device's slice of the tensor axis.
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder
          .create<AllSliceOp>(sourceShard, mesh,
                              ArrayRef<MeshAxis>(splitMeshAxis),
                              splitTensorAxis)
          .getResult());
  // NOTE(review): a line is missing here in this extract — presumably
  // `MeshSharding targetSharding = targetShardingInSplitLastAxis(`; confirm
  // against the original source.
      builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
  return {targetShard, targetSharding};
}
146 
// Detect if the resharding is of type e.g.
// [[0, 1]] -> [[0, 1, 2]].
// If detected, returns the corresponding tensor axis mesh axis pair.
// Does not detect insertions like
// [[0, 1]] -> [[0, 2, 1]].
static std::optional<std::tuple<int64_t, MeshAxis>>
// NOTE(review): the line with the function name and the first parameter
// (`MeshSharding sourceSharding,`) is missing from this extract; restore it
// from the original source.
                                MeshSharding targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
      // The target must have exactly one more mesh axis on this tensor axis...
      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
          targetSharding.getSplitAxes()[tensorAxis].size()) {
        continue;
      }
      // ...and the source axes must be a prefix of the target axes.
      if (!llvm::equal(
              sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
              llvm::make_range(
                  targetSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1))) {
        continue;
      }
    } else {
      // Tensor axis not sharded in the source: target must split it by
      // exactly one mesh axis.
      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
        continue;
      }
    }
    // Report the tensor axis and the newly-added (last) mesh axis.
    return std::make_tuple(
        tensorAxis,
        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}
183 
// Try to lower the resharding as a single split of a tensor axis' last mesh
// axis (e.g. [[0, 1]] -> [[0, 1, 2]]). Returns std::nullopt when the pattern
// does not apply.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
// NOTE(review): the line with the function name and leading parameters
// (`builder` and `mesh` are used below) is missing from this extract;
// restore it from the original source.
                            MeshSharding sourceSharding,
                            MeshSharding targetSharding,
                            TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
                                     tensorAxis, meshAxis);
  }

  return std::nullopt;
}
198 
// Detect if the resharding is of type e.g.
// [[0, 1, 2]] -> [[0, 1]].
// If detected, returns the corresponding tensor axis mesh axis pair.
static std::optional<std::tuple<int64_t, MeshAxis>>
// NOTE(review): the line with the function name and the first parameter
// (`MeshSharding sourceSharding,`) is missing from this extract; restore it
// from the original source.
                                  MeshSharding targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (targetSharding.getSplitAxes().size() > tensorAxis) {
      // The source must have exactly one more mesh axis on this tensor
      // axis...
      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
        continue;
      // ...and the target axes must be a prefix of the source axes.
      if (!llvm::equal(
              llvm::make_range(
                  sourceSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1),
              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
        continue;
    } else {
      // Tensor axis absent from the target split axes: the source must split
      // it by exactly one mesh axis, which gets removed.
      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
        continue;
    }
    // Report the tensor axis and the removed (last) mesh axis.
    return std::make_tuple(
        tensorAxis,
        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}
230 
// Build the sharding obtained from `sourceSharding` by removing the last mesh
// axis splitting tensor axis `splitTensorAxis`. Partial axes and the partial
// type are carried over unchanged.
// NOTE(review): the declaration line (function name and an `MLIRContext *ctx`
// parameter — `ctx` is used below; the name is confirmed by the call in
// unsplitLastAxisInResharding) is missing from this extract; restore it from
// the original source.
                                MeshSharding sourceSharding,
                                int64_t splitTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
         splitTensorAxis);
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  // Drop the last mesh axis from this tensor axis' split description.
  targetSplitAxes.pop_back();
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
248 
// Shape of the all-gather result when unsplitting the last mesh axis:
// the gathered tensor axis grows by `splitCount` (the mesh axis size).
// NOTE(review): the declaration line (`static ShapedType
// allGatherResultShapeInUnsplitLastAxis(` — name confirmed by the call in
// unsplitLastAxisInResharding) is missing from this extract; restore it from
// the original source.
    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      gatherDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
256 
// Unsplit the last mesh axis from a tensor axis via mesh.all_gather
// (e.g. [[0, 1, 2]] -> [[0, 1]]). Returns the gathered shard and its
// sharding.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
// NOTE(review): the line with the function name and the first parameter
// (an ImplicitLocOpBuilder `builder` is used below) is missing from this
// extract; restore it from the original source.
                            MeshSharding sourceSharding,
                            ShapedType sourceUnshardedShape,
                            TypedValue<ShapedType> sourceShard, MeshOp mesh,
                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshSharding targetSharding =
      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
      sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
  Value allGatherResult = builder.create<AllGatherOp>(
      RankedTensorType::get(allGatherResultShape.getShape(),
                            allGatherResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
      APInt(64, splitTensorAxis));
  // Cast the gather result to the exact target shard type.
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
  return {targetShard, targetSharding};
}
281 
// Try to lower the resharding as an unsplit of a tensor axis' last mesh axis
// (e.g. [[0, 1, 2]] -> [[0, 1]]). Returns std::nullopt when the pattern does
// not apply.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
// NOTE(review): the line with the function name and leading parameters
// (`builder` and `mesh` are used below) is missing from this extract;
// restore it from the original source.
                              MeshSharding sourceSharding,
                              MeshSharding targetSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return unsplitLastAxisInResharding(builder, sourceSharding,
                                       sourceUnshardedShape, sourceShard, mesh,
                                       tensorAxis, meshAxis);
  }

  return std::nullopt;
}
298 
// Detect if the resharding is of type e.g.
// [[0, 1], [2]] -> [[0], [1, 2]].
// Only moving the last axis counts.
// If detected, returns the corresponding (source_tensor_axis,
// target_tensor_axis, mesh_axis) tuple.
static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
// NOTE(review): the line with the function name and the first parameter
// (`MeshSharding sourceSharding,`) is missing from this extract; restore it
// from the original source.
                                    MeshSharding targetSharding) {
  for (size_t sourceTensorAxis = 0;
       sourceTensorAxis < sourceSharding.getSplitAxes().size();
       ++sourceTensorAxis) {
    for (size_t targetTensorAxis = 0;
         targetTensorAxis < targetSharding.getSplitAxes().size();
         ++targetTensorAxis) {
      if (sourceTensorAxis == targetTensorAxis)
        continue;
      // Both axes must be split and agree on the moved (last) mesh axis.
      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
              targetSharding.getSplitAxes()[targetTensorAxis]
                  .asArrayRef()
                  .back())
        continue;
      // The remaining (prefix) mesh axes of both tensor axes must match.
      if (!llvm::equal(
              llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               sourceSharding.getSplitAxes()[sourceTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1),
              llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               targetSharding.getSplitAxes()[targetTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1)))
        continue;
      return std::make_tuple(
          sourceTensorAxis, targetTensorAxis,
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
    }
  }
  return std::nullopt;
}
345 
// Build the sharding obtained by moving the last mesh axis that splits
// `sourceTensorAxis` onto `targetTensorAxis`. Partial axes and the partial
// type are carried over unchanged.
// NOTE(review): the declaration line (function name and an `MLIRContext *ctx`
// parameter — `ctx` is used below) is missing from this extract; restore it
// from the original source.
                           MeshSharding sourceSharding,
                           int64_t sourceTensorAxis,
                           int64_t targetTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  // Pad with empty per-axis mesh-axis lists up to the target tensor axis.
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         targetTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }

  // Pop the moved mesh axis off the source tensor axis...
  auto sourceSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
  assert(!sourceSplitAxes.empty());
  auto meshAxis = sourceSplitAxes.back();
  sourceSplitAxes.pop_back();
  targetShardingSplitAxes[sourceTensorAxis] =
      MeshAxesAttr::get(ctx, sourceSplitAxes);

  // ...and append it to the target tensor axis.
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
  targetSplitAxes.push_back(meshAxis);
  targetShardingSplitAxes[targetTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);

  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
375 
376 static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
377  int64_t splitCount,
378  int64_t sourceTensorAxis,
379  int64_t targetTensorAxis) {
380  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
381  targetShape[sourceTensorAxis] =
382  gatherDimension(targetShape[sourceTensorAxis], splitCount);
383  targetShape[targetTensorAxis] =
384  shardDimension(targetShape[targetTensorAxis], splitCount);
385  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
386 }
387 
// Move the last mesh axis splitting `sourceTensorAxis` onto
// `targetTensorAxis` via mesh.all_to_all. Returns the new shard and its
// sharding.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
// NOTE(review): the line with the function name and leading parameters
// (`builder` and `mesh` are used below) is missing from this extract;
// restore it from the original source.
                             MeshSharding sourceSharding,
                             ShapedType sourceUnshardedShape,
                             TypedValue<ShapedType> sourceShard,
                             int64_t sourceTensorAxis,
                             int64_t targetTensorAxis, MeshAxis meshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  // NOTE(review): a line is missing here in this extract — presumably
  // `MeshSharding targetSharding = targetShardingInMoveLastAxis(`; confirm
  // against the original source.
      ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
      sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
      targetTensorAxis);
  Value allToAllResult = builder.create<AllToAllOp>(
      RankedTensorType::get(allToAllResultShape.getShape(),
                            allToAllResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  // Cast the all-to-all result to the exact target shard type.
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
  return {targetShard, targetSharding};
}
414 
// Try to lower the resharding as a move of a last split mesh axis between two
// tensor axes (e.g. [[0, 1], [2]] -> [[0], [1, 2]]). Returns std::nullopt
// when the pattern does not apply.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
// NOTE(review): the line with the function name and leading parameters
// (`builder` and `mesh` are passed through below) is missing from this
// extract; restore it from the original source.
                                 MeshSharding sourceSharding,
                                 MeshSharding targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
    auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
    // NOTE(review): a line is missing here in this extract — presumably
    // `return moveLastSplitAxisInResharding(`; confirm against the original
    // source.
        builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
        sourceTensorAxis, targetTensorAxis, meshAxis);
  }

  return std::nullopt;
}
431 
// Detect a change in the halo size (only) and create necessary operations if
// needed. A changed halo sizes requires copying the "core" of the source tensor
// into the "core" of the destination tensor followed by an update halo
// operation.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
// NOTE(review): the line with the function name and leading parameters
// (`builder` and `mesh` are used below) is missing from this extract;
// restore it from the original source.
                         MeshSharding sourceSharding,
                         MeshSharding targetSharding,
                         ShapedType sourceUnshardedShape,
                         TypedValue<ShapedType> sourceShard) {
  // Currently handles only cases where halo sizes differ but everything else
  // stays the same (from source to destination sharding).
  if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
      !sourceSharding.getPartialAxes().empty() ||
      !targetSharding.getPartialAxes().empty() ||
      !sourceSharding.getStaticShardedDimsOffsets().empty() ||
      !targetSharding.getStaticShardedDimsOffsets().empty() ||
      sourceSharding.equalHaloSizes(targetSharding)) {
    return std::nullopt;
  }

  // Halo sizes come in (low, high) pairs per tensor dimension.
  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
  assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
          !ShapedType::isDynamicShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for mesh-spmdization");
  auto rank = sourceShard.getType().getRank();
  auto splitAxes = sourceSharding.getSplitAxes();
  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
      strides(rank, 1), outShape(sourceShard.getType().getShape()),
      coreShape(sourceShard.getType().getShape());

  // Determine "core" of source and destination.
  // The core is the local part of the shard excluding halo regions.
  for (auto i = 0u; i < rank; ++i) {
    if (i < splitAxes.size() && !splitAxes[i].empty()) {
      if (!srcHaloSizes.empty()) {
        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
        srcCoreOffs[i] = srcHaloSizes[i * 2];
      }
      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
      outShape[i] =
          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
    }
  }

  // Extract core from source and copy into destination core.
  auto noVals = ValueRange{};
  auto initVal = builder.create<tensor::EmptyOp>(
      sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
  auto core = builder.create<tensor::ExtractSliceOp>(
      sourceShard.getLoc(),
      RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
      sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
  auto initOprnd = builder.create<tensor::InsertSliceOp>(
      sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
      coreShape, strides);

  // Finally update the halo.
  auto updateHaloResult =
      builder
          .create<UpdateHaloOp>(
              sourceShard.getLoc(),
              RankedTensorType::get(outShape,
                                    sourceShard.getType().getElementType()),
              sourceShard, initOprnd, mesh.getSymName(),
              // NOTE(review): a line is missing here in this extract —
              // presumably an attribute construction wrapping
              // `sourceSharding.getSplitAxes()`; confirm against the
              // original source.
                  sourceSharding.getSplitAxes()),
              sourceSharding.getDynamicHaloSizes(),
              sourceSharding.getStaticHaloSizes(),
              targetSharding.getDynamicHaloSizes(),
              targetSharding.getStaticHaloSizes())
          .getResult();
  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
                         targetSharding);
}
510 
// Handles only resharding on a 1D mesh.
// Currently the sharded tensor axes must be exactly divisible by the single
// mesh axis size.
// NOTE(review): the declaration lines (return type, function name and
// leading parameters — `builder` and `mesh` are used below) are missing from
// this extract; restore them from the original source.
                MeshSharding sourceSharding, MeshSharding targetSharding,
                TypedValue<ShapedType> sourceUnshardedValue,
                TypedValue<ShapedType> sourceShard) {
  assert(sourceShard.getType() ==
         shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
  [[maybe_unused]] ShapedType targetShardType =
      shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
  assert(sourceShard.getType().getRank() == targetShardType.getRank());
  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");

  // First reduce away partial axes the target does not keep.
  auto [reducedSourceShard, reducedSourceSharding] =
      handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
                                        sourceShard);

  if (reducedSourceSharding == targetSharding) {
    return reducedSourceShard;
  }

  // Then try the supported communication patterns in order:
  // move-last-split-axis, split-last-axis, unsplit-last-axis.
  TypedValue<ShapedType> targetShard;
  MeshSharding actualTargetSharding;
  if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
      targetSharding.getStaticShardedDimsOffsets().empty() &&
      reducedSourceSharding.getStaticHaloSizes().empty() &&
      targetSharding.getStaticHaloSizes().empty()) {
    if (auto tryRes = tryMoveLastSplitAxisInResharding(
            builder, mesh, reducedSourceSharding, targetSharding,
            sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = trySplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = tryUnsplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    }
  }
  assert(targetShard && "Did not find any pattern to apply.");
  assert(actualTargetSharding == targetSharding);
  assert(targetShard.getType() == targetShardType);
  return targetShard;
}
559 
561  MeshSharding sourceSharding,
562  MeshSharding targetSharding,
563  TypedValue<ShapedType> sourceUnshardedValue,
564  TypedValue<ShapedType> sourceShard) {
565  // If source and destination sharding are the same, no need to do anything.
566  if (sourceSharding == targetSharding) {
567  return sourceShard;
568  }
569 
570  // Tries to handle the case where the resharding is needed because the halo
571  // sizes are different. Supports arbitrary mesh dimensionality.
572  if (auto tryRes = tryUpdateHaloInResharding(
573  builder, mesh, sourceSharding, targetSharding,
574  sourceUnshardedValue.getType(), sourceShard)) {
575  return std::get<0>(tryRes.value()); // targetShard
576  }
577 
578  // Resort to handling only 1D meshes since the general case is complicated if
579  // it needs to be communication efficient in terms of minimizing the data
580  // transfered between devices.
581  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
582  sourceUnshardedValue, sourceShard);
583 }
584 
585 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
586  ShardOp target,
587  TypedValue<ShapedType> sourceShardValue) {
588  assert(source.getResult() == target.getSrc());
589  auto sourceSharding = source.getSharding();
590  auto targetSharding = target.getSharding();
591  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
592  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
593  cast<TypedValue<ShapedType>>(source.getSrc()),
594  sourceShardValue);
595 }
596 
597 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
598  ShardOp target,
599  TypedValue<ShapedType> sourceShardValue,
600  SymbolTableCollection &symbolTableCollection) {
601  MeshOp srcMesh = getMesh(source, symbolTableCollection);
602  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
603  return reshard(builder, srcMesh, source, target, sourceShardValue);
604 }
605 
// Registers the dialects whose ops resharding may create, so passes calling
// reshard() can declare them as dependent dialects.
// NOTE(review): the declaration line (function name taking a
// `DialectRegistry &registry`) is missing from this extract; restore it from
// the original source.
  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
}
609 
610 #define GEN_PASS_DEF_SPMDIZATION
611 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
612 
614 
// Get the types of block arguments for an spmdized block.
// Reads the sharding annotations of the arguments to deduce the sharded types.
// Types that are not ranked tensors are left unchanged.
// NOTE(review): the declaration lines (return type, function name and the
// `Block &block` parameter) are missing from this extract; restore them from
// the original source.
                         SymbolTableCollection &symbolTableCollection) {
  SmallVector<Type> res;
  llvm::transform(
      block.getArguments(), std::back_inserter(res),
      [&symbolTableCollection](BlockArgument arg) {
        auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
        if (!rankedTensorArg) {
          // Non-tensor arguments keep their type.
          return arg.getType();
        }

        // A ranked-tensor argument must have exactly one use: its shard
        // annotation, from which the sharded type is derived.
        assert(rankedTensorArg.hasOneUse());
        Operation *useOp = *rankedTensorArg.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
        assert(shardOp);
        MeshOp mesh = getMesh(shardOp, symbolTableCollection);
        return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
                                          shardOp.getSharding()));
      });
  return res;
}
640 
// Forward declaration; the definition is not visible in this extract.
// Spmdizes an operation whose sharding can be handled trivially from its
// operand/result sharding annotations.
void spmdizeTriviallyShardableOperation(Operation &op,
                                        ArrayRef<Value> spmdizedOperands,
                                        ArrayRef<MeshSharding> operandShardings,
                                        ArrayRef<MeshSharding> resultShardings,
                                        IRMapping &spmdizationMap,
                                        SymbolTableCollection &symbolTable,
                                        OpBuilder &builder);
648 
649 static LogicalResult spmdizeOperation(
650  Operation &op, ArrayRef<Value> spmdizedOperands,
651  ArrayRef<MeshSharding> operandShardings,
652  ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
653  SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
654  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
655  if (!shardingInterface) {
656  // If there is no sharding interface we are conservative and assume that
657  // the op should be fully replicated no all devices.
658  spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
659  resultShardings, spmdizationMap,
660  symbolTableCollection, builder);
661  } else {
662  if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
663  resultShardings, spmdizationMap,
664  symbolTableCollection, builder))) {
665  return failure();
666  }
667  }
668 
669  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
670  return spmdizationMap.contains(result);
671  }));
672 
673  return success();
674 }
675 
676 // Retrieve the sharding annotations for the operands of the given operation.
677 // If the type is not a ranked tensor it is not require to have an annotation.
678 static std::vector<MeshSharding> getOperandShardings(Operation &op) {
679  std::vector<MeshSharding> res;
680  res.reserve(op.getNumOperands());
681  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
682  TypedValue<RankedTensorType> rankedTensor =
683  dyn_cast<TypedValue<RankedTensorType>>(operand);
684  if (!rankedTensor) {
685  return MeshSharding();
686  }
687 
688  Operation *definingOp = operand.getDefiningOp();
689  assert(definingOp);
690  ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
691  return MeshSharding(shardOp.getSharding());
692  });
693  return res;
694 }
695 
696 // Retrieve the sharding annotations for the results of the given operation.
697 // If the type is not a ranked tensor it is not require to have an annotation.
698 static std::vector<MeshSharding> getResultShardings(Operation &op) {
699  std::vector<MeshSharding> res;
700  res.reserve(op.getNumResults());
701  llvm::transform(op.getResults(), std::back_inserter(res),
702  [](OpResult result) {
703  TypedValue<RankedTensorType> rankedTensor =
704  dyn_cast<TypedValue<RankedTensorType>>(result);
705  if (!rankedTensor) {
706  return MeshSharding();
707  }
708 
709  assert(result.hasOneUse());
710  Operation *userOp = *result.getUsers().begin();
711  ShardOp shardOp = llvm::cast<ShardOp>(userOp);
712  return MeshSharding(shardOp.getSharding());
713  });
714  return res;
715 }
716 
static LogicalResult
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  Value targetSpmdValue;

  // Check if 2 shard ops are chained. If not there is no need for resharding
  // as the source and target shared the same sharding.
  ShardOp srcShardOp =
      dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
  if (!srcShardOp) {
    // Single annotation: reuse the already-spmdized source value.
    targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
  } else {
    // Insert resharding.
    TypedValue<ShapedType> srcSpmdValue =
        cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
    targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                              symbolTableCollection);
  }

  // Record the spmdized value for the shard op's result.
  assert(!spmdizationMap.contains(shardOp.getResult()));
  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
  return success();
}
741 
742 static LogicalResult
743 spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
744  SymbolTableCollection &symbolTableCollection,
745  OpBuilder &builder) {
746  if (isa<ShardingOp>(op)) {
747  return success();
748  }
749 
750  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
751  if (shardOp) {
752  return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
753  builder);
754  }
755 
756  SmallVector<Value> spmdizedOperands;
757  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
758  [&spmdizationMap](Value operand) {
759  assert(spmdizationMap.contains(operand));
760  return spmdizationMap.lookup(operand);
761  });
762  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
763  getResultShardings(op), spmdizationMap,
764  symbolTableCollection, builder);
765 }
766 
767 static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
768  SymbolTableCollection &symbolTableCollection,
769  OpBuilder &builder) {
770  SmallVector<Location> argLocations;
771  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
772  [](BlockArgument arg) { return arg.getLoc(); });
773  Block *newBlock = builder.createBlock(
774  block.getParent(), {},
775  shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
776  for (auto [unshardedBlockArg, spmdizedBlockArg] :
777  llvm::zip(block.getArguments(), newBlock->getArguments())) {
778  spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
779  }
780 
781  OpBuilder::InsertionGuard insertionGuard(builder);
782  builder.setInsertionPointToEnd(newBlock);
783  for (Operation &op : block.getOperations()) {
784  if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
785  builder))) {
786  return failure();
787  }
788  }
789 
790  return success();
791 }
792 
793 static LogicalResult
794 spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
795  SymbolTableCollection &symbolTableCollection) {
796  OpBuilder builder(op.getFunctionBody());
797 
798  // Snapshot the original blocks to not mess up the iteration when adding new
799  // blocks.
800  SmallVector<Block *> originalBlocks;
801  llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
802  [](Block &b) { return &b; });
803 
804  for (Block *block : originalBlocks) {
805  if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
806  builder))) {
807  return failure();
808  }
809  }
810 
811  for (Block *block : originalBlocks) {
812  block->erase();
813  }
814 
815  // Find a return op and change the function results signature to its operands
816  // signature.
817  Operation *returnOp = nullptr;
818  for (Block &block : op.getFunctionBody()) {
819  if (block.empty()) {
820  continue;
821  }
822 
823  if (block.back().hasTrait<OpTrait::ReturnLike>()) {
824  returnOp = &block.back();
825  break;
826  }
827  }
828  assert(returnOp);
829  op.setType(FunctionType::get(op->getContext(),
830  op.getFunctionBody().front().getArgumentTypes(),
831  returnOp->getOperandTypes()));
832 
833  return success();
834 }
835 
836 namespace {
837 
// Pass that rewrites each function into its per-device SPMD form using the
// sharding annotations in the IR.
struct Spmdization : public impl::SpmdizationBase<Spmdization> {
  void runOnOperation() override {
    IRMapping spmdizationMap;
    SymbolTableCollection symbolTableCollection;
    if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
                             symbolTableCollection))) {
      return signalPassFailure();
    }
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    // NOTE(review): a line is missing here in this extract — presumably a
    // call to reshardingRegisterDependentDialects(registry); confirm against
    // the original source.
    registry.insert<mesh::MeshDialect>();
  }
};
853 
854 } // namespace
855 
856 } // namespace mlir::mesh
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
MLIRContext * getContext() const
Definition: Builders.h:55
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition: IRMapping.h:51
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumOperands()
Definition: Operation.h:341
operand_type_range getOperandTypes()
Definition: Operation.h:392
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:623
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition: MeshOps.h:70
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:64
bool equalHaloSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:678
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:66
ReductionKind getPartialType() const
Definition: MeshOps.h:68
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:67
ArrayRef< Value > getDynamicHaloSizes() const
Definition: MeshOps.h:73
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: MeshOps.h:69
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition: MeshOps.cpp:722
mesh::MeshSharding MeshSharding
static LogicalResult spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
static std::tuple< TypedValue< ShapedType >, MeshSharding > handlePartialAxesDuringResharding(OpBuilder &builder, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
Definition: Spmdization.cpp:58
SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
static std::vector< MeshSharding > getResultShardings(Operation &op)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
Definition: Spmdization.cpp:44
static std::optional< std::tuple< int64_t, MeshAxis > > detectUnsplitLastAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:182
static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis)
static std::tuple< TypedValue< ShapedType >, MeshSharding > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:254
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:126
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:173
TypedValue< ShapedType > reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
static TypedValue< ShapedType > reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
void reshardingRegisterDependentDialects(DialectRegistry &registry)
static std::tuple< TypedValue< ShapedType >, MeshSharding > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static std::vector< MeshSharding > getOperandShardings(Operation &op)
int16_t MeshAxis
Definition: MeshOps.h:26
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static LogicalResult spmdizeOperation(Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::tuple< TypedValue< ShapedType >, MeshSharding > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
TypedValue< ShapedType > reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::optional< std::tuple< int64_t, MeshAxis > > detectSplitLastAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This trait indicates that a terminator operation is "return-like".