MLIR  20.0.0git
Spmdization.cpp
Go to the documentation of this file.
1 //===- Spmdization.cpp --------------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/Location.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "mlir/IR/Value.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Support/LLVM.h"
31 #include "llvm/ADT/APInt.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/Casting.h"
36 #include <iterator>
37 #include <optional>
38 #include <tuple>
39 #include <type_traits>
40 
41 namespace mlir::mesh {
42 
// Returns true if every axis in `targetAxes` is also present in `sourceAxes`,
// i.e. the target's partial axes are a subset of the source's.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  for (const auto &axis : targetAxes)
    if (!sourceAxes.contains(axis))
      return false;
  return true;
}
50 
51 // Return the reduced value and its corresponding sharding.
52 // Example:
53 // sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
54 // targetSharding = <@mesh_1d, [[]]>
55 // Then will apply all-reduce on the source value
56 // and return it with the sharding <@mesh_1d, [[0]]>.
57 static std::tuple<TypedValue<ShapedType>, MeshSharding>
59  MeshSharding sourceSharding,
60  MeshSharding targetSharding,
61  TypedValue<ShapedType> sourceShard) {
62  if (sourceSharding.getPartialAxes().empty() &&
63  targetSharding.getPartialAxes().empty()) {
64  return {sourceShard, sourceSharding};
65  }
66  assert(targetSharding.getPartialAxes().empty() ||
67  (!sourceSharding.getPartialAxes().empty() &&
68  sourceSharding.getPartialType() == targetSharding.getPartialType()));
69  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
70  using AxisSet = llvm::SmallDenseSet<Axis>;
71  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
72  sourceSharding.getPartialAxes().end());
73  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
74  targetSharding.getPartialAxes().end());
75  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
76  targetShardingPartialAxesSet));
77  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
78  llvm::copy_if(sourceShardingPartialAxesSet,
79  std::back_inserter(allReduceMeshAxes),
80  [&targetShardingPartialAxesSet](Axis a) {
81  return !targetShardingPartialAxesSet.contains(a);
82  });
83  if (allReduceMeshAxes.empty()) {
84  return {sourceShard, sourceSharding};
85  }
86 
87  builder.setInsertionPointAfterValue(sourceShard);
88  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
89  builder
90  .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
91  sourceSharding.getMeshAttr().getLeafReference(),
92  allReduceMeshAxes, sourceShard,
93  sourceSharding.getPartialType())
94  .getResult());
95 
96  llvm::SmallVector<MeshAxis> remainingPartialAxes;
97  llvm::copy_if(sourceShardingPartialAxesSet,
98  std::back_inserter(allReduceMeshAxes),
99  [&targetShardingPartialAxesSet](Axis a) {
100  return targetShardingPartialAxesSet.contains(a);
101  });
102  MeshSharding resultSharding = MeshSharding::get(
103  sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(),
104  remainingPartialAxes, sourceSharding.getPartialType());
105  return {resultValue, resultSharding};
106 }
107 
109  MeshSharding sourceSharding,
110  int64_t splitTensorAxis,
111  MeshAxis splitMeshAxis) {
112  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
113  llvm::to_vector(sourceSharding.getSplitAxes());
114  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
115  splitTensorAxis) {
116  targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
117  }
118  auto targetSplitAxes =
119  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
120  targetSplitAxes.push_back(splitMeshAxis);
121  targetShardingSplitAxes[splitTensorAxis] =
122  MeshAxesAttr::get(ctx, targetSplitAxes);
123  return MeshSharding::get(
124  sourceSharding.getMeshAttr(), targetShardingSplitAxes,
125  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
126 }
127 
128 // Split a replicated tensor along a mesh axis.
129 // E.g. [[0, 1]] -> [[0, 1, 2]].
130 // Returns the spmdized target value with its sharding.
131 static std::tuple<TypedValue<ShapedType>, MeshSharding>
133  MeshSharding sourceSharding,
134  TypedValue<ShapedType> sourceShard, MeshOp mesh,
135  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
136  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
137  builder
138  .create<AllSliceOp>(sourceShard, mesh,
139  ArrayRef<MeshAxis>(splitMeshAxis),
140  splitTensorAxis)
141  .getResult());
143  builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
144  return {targetShard, targetSharding};
145 }
146 
147 // Detect if the resharding is of type e.g.
148 // [[0, 1]] -> [[0, 1, 2]].
149 // If detected, returns the corresponding tensor axis mesh axis pair.
150 // Does not detect insertions like
151 // [[0, 1]] -> [[0, 2, 1]].
152 static std::optional<std::tuple<int64_t, MeshAxis>>
154  MeshSharding targetSharding) {
155  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
156  ++tensorAxis) {
157  if (sourceSharding.getSplitAxes().size() > tensorAxis) {
158  if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
159  targetSharding.getSplitAxes()[tensorAxis].size()) {
160  continue;
161  }
162  if (!llvm::equal(
163  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
164  llvm::make_range(
165  targetSharding.getSplitAxes()[tensorAxis]
166  .asArrayRef()
167  .begin(),
168  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
169  1))) {
170  continue;
171  }
172  } else {
173  if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
174  continue;
175  }
176  }
177  return std::make_tuple(
178  tensorAxis,
179  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
180  }
181  return std::nullopt;
182 }
183 
184 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
186  MeshSharding sourceSharding,
187  MeshSharding targetSharding,
188  TypedValue<ShapedType> sourceShard) {
189  if (auto detectRes =
190  detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
191  auto [tensorAxis, meshAxis] = detectRes.value();
192  return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
193  tensorAxis, meshAxis);
194  }
195 
196  return std::nullopt;
197 }
198 
199 // Detect if the resharding is of type e.g.
200 // [[0, 1, 2]] -> [[0, 1]].
201 // If detected, returns the corresponding tensor axis mesh axis pair.
202 static std::optional<std::tuple<int64_t, MeshAxis>>
204  MeshSharding targetSharding) {
205  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
206  ++tensorAxis) {
207  if (targetSharding.getSplitAxes().size() > tensorAxis) {
208  if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
209  targetSharding.getSplitAxes()[tensorAxis].size() + 1)
210  continue;
211  if (!llvm::equal(
212  llvm::make_range(
213  sourceSharding.getSplitAxes()[tensorAxis]
214  .asArrayRef()
215  .begin(),
216  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
217  1),
218  targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
219  continue;
220  } else {
221  if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
222  continue;
223  }
224  return std::make_tuple(
225  tensorAxis,
226  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
227  }
228  return std::nullopt;
229 }
230 
232  MeshSharding sourceSharding,
233  int64_t splitTensorAxis) {
234  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
235  llvm::to_vector(sourceSharding.getSplitAxes());
236  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
237  splitTensorAxis);
238  auto targetSplitAxes =
239  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
240 
241  targetSplitAxes.pop_back();
242  targetShardingSplitAxes[splitTensorAxis] =
243  MeshAxesAttr::get(ctx, targetSplitAxes);
244  return MeshSharding::get(
245  sourceSharding.getMeshAttr(), targetShardingSplitAxes,
246  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
247 }
248 
250  ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
251  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
252  targetShape[splitTensorAxis] =
253  gatherDimension(targetShape[splitTensorAxis], splitCount);
254  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
255 }
256 
257 static std::tuple<TypedValue<ShapedType>, MeshSharding>
259  MeshSharding sourceSharding,
260  ShapedType sourceUnshardedShape,
261  TypedValue<ShapedType> sourceShard, MeshOp mesh,
262  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
263  MLIRContext *ctx = builder.getContext();
264  builder.setInsertionPointAfterValue(sourceShard);
265 
266  MeshSharding targetSharding =
267  targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
268  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
269  sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
270  Value allGatherResult = builder.create<AllGatherOp>(
271  RankedTensorType::get(allGatherResultShape.getShape(),
272  allGatherResultShape.getElementType()),
273  mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
274  APInt(64, splitTensorAxis));
275  ShapedType targetShape =
276  shardShapedType(sourceUnshardedShape, mesh, targetSharding);
277  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
278  builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
279  return {targetShard, targetSharding};
280 }
281 
282 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
284  MeshSharding sourceSharding,
285  MeshSharding targetSharding,
286  ShapedType sourceUnshardedShape,
287  TypedValue<ShapedType> sourceShard) {
288  if (auto detectRes =
289  detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
290  auto [tensorAxis, meshAxis] = detectRes.value();
291  return unsplitLastAxisInResharding(builder, sourceSharding,
292  sourceUnshardedShape, sourceShard, mesh,
293  tensorAxis, meshAxis);
294  }
295 
296  return std::nullopt;
297 }
298 
299 // Detect if the resharding is of type e.g.
300 // [[0, 1], [2]] -> [[0], [1, 2]].
301 // Only moving the last axis counts.
302 // If detected, returns the corresponding (source_tensor_axis,
303 // target_tensor_axis, mesh_axis) tuple.
304 static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
306  MeshSharding targetSharding) {
307  for (size_t sourceTensorAxis = 0;
308  sourceTensorAxis < sourceSharding.getSplitAxes().size();
309  ++sourceTensorAxis) {
310  for (size_t targetTensorAxis = 0;
311  targetTensorAxis < targetSharding.getSplitAxes().size();
312  ++targetTensorAxis) {
313  if (sourceTensorAxis == targetTensorAxis)
314  continue;
315  if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
316  targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
317  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
318  targetSharding.getSplitAxes()[targetTensorAxis]
319  .asArrayRef()
320  .back())
321  continue;
322  if (!llvm::equal(
323  llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
324  .asArrayRef()
325  .begin(),
326  sourceSharding.getSplitAxes()[sourceTensorAxis]
327  .asArrayRef()
328  .end() -
329  1),
330  llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
331  .asArrayRef()
332  .begin(),
333  targetSharding.getSplitAxes()[targetTensorAxis]
334  .asArrayRef()
335  .end() -
336  1)))
337  continue;
338  return std::make_tuple(
339  sourceTensorAxis, targetTensorAxis,
340  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
341  }
342  }
343  return std::nullopt;
344 }
345 
347  MeshSharding sourceSharding,
348  int64_t sourceTensorAxis,
349  int64_t targetTensorAxis) {
350  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
351  llvm::to_vector(sourceSharding.getSplitAxes());
352  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
353  targetTensorAxis) {
354  targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
355  }
356 
357  auto sourceSplitAxes =
358  llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
359  assert(!sourceSplitAxes.empty());
360  auto meshAxis = sourceSplitAxes.back();
361  sourceSplitAxes.pop_back();
362  targetShardingSplitAxes[sourceTensorAxis] =
363  MeshAxesAttr::get(ctx, sourceSplitAxes);
364 
365  auto targetSplitAxes =
366  llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
367  targetSplitAxes.push_back(meshAxis);
368  targetShardingSplitAxes[targetTensorAxis] =
369  MeshAxesAttr::get(ctx, targetSplitAxes);
370 
371  return MeshSharding::get(
372  sourceSharding.getMeshAttr(), targetShardingSplitAxes,
373  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
374 }
375 
376 static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
377  int64_t splitCount,
378  int64_t sourceTensorAxis,
379  int64_t targetTensorAxis) {
380  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
381  targetShape[sourceTensorAxis] =
382  gatherDimension(targetShape[sourceTensorAxis], splitCount);
383  targetShape[targetTensorAxis] =
384  shardDimension(targetShape[targetTensorAxis], splitCount);
385  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
386 }
387 
388 static std::tuple<TypedValue<ShapedType>, MeshSharding>
390  MeshSharding sourceSharding,
391  ShapedType sourceUnshardedShape,
392  TypedValue<ShapedType> sourceShard,
393  int64_t sourceTensorAxis,
394  int64_t targetTensorAxis, MeshAxis meshAxis) {
395  MLIRContext *ctx = builder.getContext();
396  builder.setInsertionPointAfterValue(sourceShard);
397 
399  ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
400  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
401  sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
402  targetTensorAxis);
403  Value allToAllResult = builder.create<AllToAllOp>(
404  RankedTensorType::get(allToAllResultShape.getShape(),
405  allToAllResultShape.getElementType()),
406  mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
407  APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
408  ShapedType targetShape =
409  shardShapedType(sourceUnshardedShape, mesh, targetSharding);
410  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
411  builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
412  return {targetShard, targetSharding};
413 }
414 
415 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
417  MeshSharding sourceSharding,
418  MeshSharding targetSharding,
419  ShapedType sourceUnshardedShape,
420  TypedValue<ShapedType> sourceShard) {
421  if (auto detectRes =
422  detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
423  auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
425  builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
426  sourceTensorAxis, targetTensorAxis, meshAxis);
427  }
428 
429  return std::nullopt;
430 }
431 
432 // Detect a change in the halo size (only) and create necessary operations if
433 // needed. A changed halo sizes requires copying the "core" of the source tensor
434 // into the "core" of the destination tensor followed by an update halo
435 // operation.
436 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
438  MeshSharding sourceSharding,
439  MeshSharding targetSharding,
440  ShapedType sourceUnshardedShape,
441  TypedValue<ShapedType> sourceShard) {
442  // Currently handles only cases where halo sizes differ but everything else
443  // stays the same (from source to destination sharding).
444  if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
445  !sourceSharding.getPartialAxes().empty() ||
446  !targetSharding.getPartialAxes().empty() ||
447  !sourceSharding.getStaticShardedDimsOffsets().empty() ||
448  !targetSharding.getStaticShardedDimsOffsets().empty() ||
449  sourceSharding.equalHaloSizes(targetSharding)) {
450  return std::nullopt;
451  }
452 
453  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
454  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
455  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
456  assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
457  !ShapedType::isDynamicShape(tgtHaloSizes) &&
458  sourceShard.getType().hasStaticShape()) &&
459  "dynamic shapes/halos are not supported yet for mesh-spmdization");
460  auto rank = sourceShard.getType().getRank();
461  auto splitAxes = sourceSharding.getSplitAxes();
462  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
463  strides(rank, 1), outShape(sourceShard.getType().getShape()),
464  coreShape(sourceShard.getType().getShape());
465 
466  // Determine "core" of source and destination.
467  // The core is the local part of the shard excluding halo regions.
468  for (auto i = 0u; i < rank; ++i) {
469  if (i < splitAxes.size() && !splitAxes[i].empty()) {
470  if (!srcHaloSizes.empty()) {
471  coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
472  srcCoreOffs[i] = srcHaloSizes[i * 2];
473  }
474  tgtCoreOffs[i] = tgtHaloSizes[i * 2];
475  outShape[i] =
476  coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
477  }
478  }
479 
480  // Extract core from source and copy into destination core.
481  auto noVals = ValueRange{};
482  auto initVal = builder.create<tensor::EmptyOp>(
483  sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
484  auto core = builder.create<tensor::ExtractSliceOp>(
485  sourceShard.getLoc(),
486  RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
487  sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
488  auto initOprnd = builder.create<tensor::InsertSliceOp>(
489  sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
490  coreShape, strides);
491 
492  // Finally update the halo.
493  auto updateHaloResult =
494  builder
495  .create<UpdateHaloOp>(
496  sourceShard.getLoc(),
497  RankedTensorType::get(outShape,
498  sourceShard.getType().getElementType()),
499  initOprnd, mesh.getSymName(),
501  sourceSharding.getSplitAxes()),
502  targetSharding.getDynamicHaloSizes(),
503  targetSharding.getStaticHaloSizes())
504  .getResult();
505  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
506  targetSharding);
507 }
508 
509 // Handles only resharding on a 1D mesh.
510 // Currently the sharded tensor axes must be exactly divisible by the single
511 // mesh axis size.
514  MeshSharding sourceSharding, MeshSharding targetSharding,
515  TypedValue<ShapedType> sourceUnshardedValue,
516  TypedValue<ShapedType> sourceShard) {
517  assert(sourceShard.getType() ==
518  shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
519  [[maybe_unused]] ShapedType targetShardType =
520  shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
521  assert(sourceShard.getType().getRank() == targetShardType.getRank());
522  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
523 
524  auto [reducedSourceShard, reducedSourceSharding] =
525  handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
526  sourceShard);
527 
528  if (reducedSourceSharding == targetSharding) {
529  return reducedSourceShard;
530  }
531 
532  TypedValue<ShapedType> targetShard;
533  MeshSharding actualTargetSharding;
534  if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
535  targetSharding.getStaticShardedDimsOffsets().empty() &&
536  reducedSourceSharding.getStaticHaloSizes().empty() &&
537  targetSharding.getStaticHaloSizes().empty()) {
538  if (auto tryRes = tryMoveLastSplitAxisInResharding(
539  builder, mesh, reducedSourceSharding, targetSharding,
540  sourceUnshardedValue.getType(), reducedSourceShard)) {
541  std::tie(targetShard, actualTargetSharding) = tryRes.value();
542  } else if (auto tryRes = trySplitLastAxisInResharding(
543  builder, mesh, reducedSourceSharding, targetSharding,
544  reducedSourceShard)) {
545  std::tie(targetShard, actualTargetSharding) = tryRes.value();
546  } else if (auto tryRes = tryUnsplitLastAxisInResharding(
547  builder, mesh, reducedSourceSharding, targetSharding,
548  sourceUnshardedValue.getType(), reducedSourceShard)) {
549  std::tie(targetShard, actualTargetSharding) = tryRes.value();
550  }
551  }
552  assert(targetShard && "Did not find any pattern to apply.");
553  assert(actualTargetSharding == targetSharding);
554  assert(targetShard.getType() == targetShardType);
555  return targetShard;
556 }
557 
559  MeshSharding sourceSharding,
560  MeshSharding targetSharding,
561  TypedValue<ShapedType> sourceUnshardedValue,
562  TypedValue<ShapedType> sourceShard) {
563  // If source and destination sharding are the same, no need to do anything.
564  if (sourceSharding == targetSharding) {
565  return sourceShard;
566  }
567 
568  // Tries to handle the case where the resharding is needed because the halo
569  // sizes are different. Supports arbitrary mesh dimensionality.
570  if (auto tryRes = tryUpdateHaloInResharding(
571  builder, mesh, sourceSharding, targetSharding,
572  sourceUnshardedValue.getType(), sourceShard)) {
573  return std::get<0>(tryRes.value()); // targetShard
574  }
575 
576  // Resort to handling only 1D meshes since the general case is complicated if
577  // it needs to be communication efficient in terms of minimizing the data
578  // transfered between devices.
579  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
580  sourceUnshardedValue, sourceShard);
581 }
582 
583 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
584  ShardOp target,
585  TypedValue<ShapedType> sourceShardValue) {
586  assert(source.getResult() == target.getSrc());
587  auto sourceSharding = source.getSharding();
588  auto targetSharding = target.getSharding();
589  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
590  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
591  cast<TypedValue<ShapedType>>(source.getSrc()),
592  sourceShardValue);
593 }
594 
595 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
596  ShardOp target,
597  TypedValue<ShapedType> sourceShardValue,
598  SymbolTableCollection &symbolTableCollection) {
599  MeshOp srcMesh = getMesh(source, symbolTableCollection);
600  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
601  return reshard(builder, srcMesh, source, target, sourceShardValue);
602 }
603 
605  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
606 }
607 
608 #define GEN_PASS_DEF_SPMDIZATION
609 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
610 
612 
613 // Get the types of block arguments for an spmdized block.
614 // Reads the sharding annotations of the arguments to deduce the sharded types.
615 // Types that are not ranked tensors are left unchanged.
618  SymbolTableCollection &symbolTableCollection) {
619  SmallVector<Type> res;
620  llvm::transform(
621  block.getArguments(), std::back_inserter(res),
622  [&symbolTableCollection](BlockArgument arg) {
623  auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
624  if (!rankedTensorArg) {
625  return arg.getType();
626  }
627 
628  assert(rankedTensorArg.hasOneUse());
629  Operation *useOp = *rankedTensorArg.getUsers().begin();
630  ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
631  assert(shardOp);
632  MeshOp mesh = getMesh(shardOp, symbolTableCollection);
633  return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
634  shardOp.getSharding()));
635  });
636  return res;
637 }
638 
// Declaration only; the definition lives outside this file.
// NOTE(review): not referenced in the visible portion of this file — confirm
// it is still needed here.
void spmdizeTriviallyShardableOperation(Operation &op,
                                        ArrayRef<Value> spmdizedOperands,
                                        ArrayRef<MeshSharding> operandShardings,
                                        ArrayRef<MeshSharding> resultShardings,
                                        IRMapping &spmdizationMap,
                                        SymbolTableCollection &symbolTable,
                                        OpBuilder &builder);
646 
647 static LogicalResult spmdizeOperation(
648  Operation &op, ArrayRef<Value> spmdizedOperands,
649  ArrayRef<MeshSharding> operandShardings,
650  ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
651  SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
652  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
653  if (!shardingInterface) {
654  // If there is no sharding interface we are conservative and assume that
655  // the op should be fully replicated no all devices.
656  spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
657  resultShardings, spmdizationMap,
658  symbolTableCollection, builder);
659  } else {
660  if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
661  resultShardings, spmdizationMap,
662  symbolTableCollection, builder))) {
663  return failure();
664  }
665  }
666 
667  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
668  return spmdizationMap.contains(result);
669  }));
670 
671  return success();
672 }
673 
674 // Retrieve the sharding annotations for the operands of the given operation.
675 // If the type is not a ranked tensor it is not require to have an annotation.
676 static std::vector<MeshSharding> getOperandShardings(Operation &op) {
677  std::vector<MeshSharding> res;
678  res.reserve(op.getNumOperands());
679  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
680  TypedValue<RankedTensorType> rankedTensor =
681  dyn_cast<TypedValue<RankedTensorType>>(operand);
682  if (!rankedTensor) {
683  return MeshSharding();
684  }
685 
686  Operation *definingOp = operand.getDefiningOp();
687  assert(definingOp);
688  ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
689  return MeshSharding(shardOp.getSharding());
690  });
691  return res;
692 }
693 
694 // Retrieve the sharding annotations for the results of the given operation.
695 // If the type is not a ranked tensor it is not require to have an annotation.
696 static std::vector<MeshSharding> getResultShardings(Operation &op) {
697  std::vector<MeshSharding> res;
698  res.reserve(op.getNumResults());
699  llvm::transform(op.getResults(), std::back_inserter(res),
700  [](OpResult result) {
701  TypedValue<RankedTensorType> rankedTensor =
702  dyn_cast<TypedValue<RankedTensorType>>(result);
703  if (!rankedTensor) {
704  return MeshSharding();
705  }
706 
707  assert(result.hasOneUse());
708  Operation *userOp = *result.getUsers().begin();
709  ShardOp shardOp = llvm::cast<ShardOp>(userOp);
710  return MeshSharding(shardOp.getSharding());
711  });
712  return res;
713 }
714 
715 static LogicalResult
716 spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
717  SymbolTableCollection &symbolTableCollection,
718  OpBuilder &builder) {
719  Value targetSpmdValue;
720 
721  // Check if 2 shard ops are chained. If not there is no need for resharding
722  // as the source and target shared the same sharding.
723  ShardOp srcShardOp =
724  dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
725  if (!srcShardOp) {
726  targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
727  } else {
728  // Insert resharding.
729  TypedValue<ShapedType> srcSpmdValue =
730  cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
731  targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
732  symbolTableCollection);
733  }
734 
735  assert(!spmdizationMap.contains(shardOp.getResult()));
736  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
737  return success();
738 }
739 
740 static LogicalResult
741 spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
742  SymbolTableCollection &symbolTableCollection,
743  OpBuilder &builder) {
744  if (isa<ShardingOp>(op)) {
745  return success();
746  }
747 
748  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
749  if (shardOp) {
750  return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
751  builder);
752  }
753 
754  SmallVector<Value> spmdizedOperands;
755  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
756  [&spmdizationMap](Value operand) {
757  assert(spmdizationMap.contains(operand));
758  return spmdizationMap.lookup(operand);
759  });
760  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
761  getResultShardings(op), spmdizationMap,
762  symbolTableCollection, builder);
763 }
764 
765 static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
766  SymbolTableCollection &symbolTableCollection,
767  OpBuilder &builder) {
768  SmallVector<Location> argLocations;
769  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
770  [](BlockArgument arg) { return arg.getLoc(); });
771  Block *newBlock = builder.createBlock(
772  block.getParent(), {},
773  shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
774  for (auto [unshardedBlockArg, spmdizedBlockArg] :
775  llvm::zip(block.getArguments(), newBlock->getArguments())) {
776  spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
777  }
778 
779  OpBuilder::InsertionGuard insertionGuard(builder);
780  builder.setInsertionPointToEnd(newBlock);
781  for (Operation &op : block.getOperations()) {
782  if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
783  builder))) {
784  return failure();
785  }
786  }
787 
788  return success();
789 }
790 
// Spmdize a whole function: every original block is replaced by a spmdized
// clone, the originals are erased, and the function type is rewritten to
// match the new argument and return-operand types.
static LogicalResult
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
              SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks to not mess up the iteration when adding new
  // blocks.
  SmallVector<Block *> originalBlocks;
  llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
                  [](Block &b) { return &b; });

  for (Block *block : originalBlocks) {
    if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
                            builder))) {
      return failure();
    }
  }

  // The spmdized blocks fully replace the originals, which are now dead.
  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find a return op and change the function results signature to its operands
  // signature.
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    if (block.empty()) {
      continue;
    }

    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
      returnOp = &block.back();
      break;
    }
  }
  assert(returnOp);
  op.setType(FunctionType::get(op->getContext(),
                               op.getFunctionBody().front().getArgumentTypes(),
                               returnOp->getOperandTypes()));

  return success();
}
833 
834 namespace {
835 
836 struct Spmdization : public impl::SpmdizationBase<Spmdization> {
837  void runOnOperation() override {
838  IRMapping spmdizationMap;
839  SymbolTableCollection symbolTableCollection;
840  if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
841  symbolTableCollection))) {
842  return signalPassFailure();
843  }
844  }
845 
846  void getDependentDialects(DialectRegistry &registry) const override {
848  registry.insert<mesh::MeshDialect>();
849  }
850 };
851 
852 } // namespace
853 
854 } // namespace mlir::mesh
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
MLIRContext * getContext() const
Definition: Builders.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition: IRMapping.h:51
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location at each call site.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:445
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:430
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumOperands()
Definition: Operation.h:346
operand_type_range getOperandTypes()
Definition: Operation.h:397
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:623
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition: MeshOps.h:70
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:64
bool equalHaloSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:678
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:66
ReductionKind getPartialType() const
Definition: MeshOps.h:68
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:67
ArrayRef< Value > getDynamicHaloSizes() const
Definition: MeshOps.h:73
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: MeshOps.h:69
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition: MeshOps.cpp:722
mesh::MeshSharding MeshSharding
static LogicalResult spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
static std::tuple< TypedValue< ShapedType >, MeshSharding > handlePartialAxesDuringResharding(OpBuilder &builder, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
Definition: Spmdization.cpp:58
SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
static std::vector< MeshSharding > getResultShardings(Operation &op)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
Definition: Spmdization.cpp:44
static std::optional< std::tuple< int64_t, MeshAxis > > detectUnsplitLastAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:182
static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis)
static std::tuple< TypedValue< ShapedType >, MeshSharding > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:254
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:126
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:173
TypedValue< ShapedType > reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
static TypedValue< ShapedType > reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
void reshardingRegisterDependentDialects(DialectRegistry &registry)
static std::tuple< TypedValue< ShapedType >, MeshSharding > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static std::vector< MeshSharding > getOperandShardings(Operation &op)
int16_t MeshAxis
Definition: MeshOps.h:26
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static LogicalResult spmdizeOperation(Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::tuple< TypedValue< ShapedType >, MeshSharding > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
TypedValue< ShapedType > reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::optional< std::tuple< int64_t, MeshAxis > > detectSplitLastAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
This trait indicates that a terminator operation is "return-like".