MLIR  19.0.0git
Spmdization.cpp
Go to the documentation of this file.
1 //===- Spmdization.cpp --------------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/Location.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "mlir/IR/Value.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Support/LLVM.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
37 #include <iterator>
38 #include <optional>
39 #include <tuple>
40 #include <type_traits>
41 
42 namespace mlir::mesh {
43 
// Returns true if every axis in `targetAxes` also occurs in `sourceAxes`,
// i.e. the target's partial axes are a subset of the source's.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  for (const auto &targetAxis : targetAxes) {
    if (!sourceAxes.contains(targetAxis))
      return false;
  }
  return true;
}
51 
52 // Return the reduced value and its corresponding sharding.
53 // Example:
54 // sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
55 // targetSharding = <@mesh_1d, [[]]>
56 // Then will apply all-reduce on the source value
57 // and return it with the sharding <@mesh_1d, [[0]]>.
58 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
60  MeshShardingAttr sourceSharding,
61  MeshShardingAttr targetSharding,
62  TypedValue<ShapedType> sourceShard) {
63  if (sourceSharding.getPartialAxes().empty() &&
64  targetSharding.getPartialAxes().empty()) {
65  return {sourceShard, sourceSharding};
66  }
67  assert(targetSharding.getPartialAxes().empty() ||
68  (!sourceSharding.getPartialAxes().empty() &&
69  sourceSharding.getPartialType() == targetSharding.getPartialType()));
70  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
71  using AxisSet = llvm::SmallDenseSet<Axis>;
72  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
73  sourceSharding.getPartialAxes().end());
74  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
75  targetSharding.getPartialAxes().end());
76  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
77  targetShardingPartialAxesSet));
78  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
79  llvm::copy_if(sourceShardingPartialAxesSet,
80  std::back_inserter(allReduceMeshAxes),
81  [&targetShardingPartialAxesSet](Axis a) {
82  return !targetShardingPartialAxesSet.contains(a);
83  });
84  if (allReduceMeshAxes.empty()) {
85  return {sourceShard, sourceSharding};
86  }
87 
88  builder.setInsertionPointAfterValue(sourceShard);
89  TypedValue<ShapedType> resultValue =
90  builder
91  .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
92  sourceSharding.getMesh().getLeafReference(),
93  allReduceMeshAxes, sourceShard,
94  sourceSharding.getPartialType())
95  .getResult()
96  .cast<TypedValue<ShapedType>>();
97 
98  llvm::SmallVector<MeshAxis> remainingPartialAxes;
99  llvm::copy_if(sourceShardingPartialAxesSet,
100  std::back_inserter(allReduceMeshAxes),
101  [&targetShardingPartialAxesSet](Axis a) {
102  return targetShardingPartialAxesSet.contains(a);
103  });
104  MeshShardingAttr resultSharding =
105  MeshShardingAttr::get(builder.getContext(), sourceSharding.getMesh(),
106  sourceSharding.getSplitAxes(), remainingPartialAxes,
107  sourceSharding.getPartialType());
108  return {resultValue, resultSharding};
109 }
110 
111 static MeshShardingAttr
113  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
114  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
115  llvm::to_vector(sourceSharding.getSplitAxes());
116  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
117  splitTensorAxis) {
118  targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
119  }
120  auto targetSplitAxes =
121  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
122  targetSplitAxes.push_back(splitMeshAxis);
123  targetShardingSplitAxes[splitTensorAxis] =
124  MeshAxesAttr::get(ctx, targetSplitAxes);
125  return MeshShardingAttr::get(
126  ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
127  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
128 }
129 
130 // Split a replicated tensor along a mesh axis.
131 // e.g. [[0, 1]] -> [[0, 1, 2]].
132 // Returns the spmdized target value with its sharding.
133 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
135  MeshShardingAttr sourceSharding,
136  TypedValue<ShapedType> sourceShard, MeshOp mesh,
137  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
138  TypedValue<ShapedType> targetShard =
139  builder
140  .create<AllSliceOp>(sourceShard, mesh,
141  ArrayRef<MeshAxis>(splitMeshAxis),
142  splitTensorAxis)
143  .getResult()
144  .cast<TypedValue<ShapedType>>();
146  builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
147  return {targetShard, targetSharding};
148 }
149 
150 // Detect if the resharding is of type e.g.
151 // [[0, 1]] -> [[0, 1, 2]].
152 // If detected, returns the corresponding tensor axis mesh axis pair.
153 // Does not detect insertions like
154 // [[0, 1]] -> [[0, 2, 1]].
155 static std::optional<std::tuple<int64_t, MeshAxis>>
157  MeshShardingAttr targetSharding) {
158  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
159  ++tensorAxis) {
160  if (sourceSharding.getSplitAxes().size() > tensorAxis) {
161  if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
162  targetSharding.getSplitAxes()[tensorAxis].size()) {
163  continue;
164  }
165  if (!llvm::equal(
166  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
167  llvm::make_range(
168  targetSharding.getSplitAxes()[tensorAxis]
169  .asArrayRef()
170  .begin(),
171  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
172  1))) {
173  continue;
174  }
175  } else {
176  if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
177  continue;
178  }
179  }
180  return std::make_tuple(
181  tensorAxis,
182  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
183  }
184  return std::nullopt;
185 }
186 
187 static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
189  MeshShardingAttr sourceSharding,
190  MeshShardingAttr targetSharding,
191  TypedValue<ShapedType> sourceShard) {
192  if (auto detectRes =
193  detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
194  auto [tensorAxis, meshAxis] = detectRes.value();
195  return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
196  tensorAxis, meshAxis);
197  }
198 
199  return std::nullopt;
200 }
201 
202 // Detect if the resharding is of type e.g.
203 // [[0, 1, 2]] -> [[0, 1]].
204 // If detected, returns the corresponding tensor axis mesh axis pair.
205 static std::optional<std::tuple<int64_t, MeshAxis>>
207  MeshShardingAttr targetSharding) {
208  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
209  ++tensorAxis) {
210  if (targetSharding.getSplitAxes().size() > tensorAxis) {
211  if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
212  targetSharding.getSplitAxes()[tensorAxis].size() + 1)
213  continue;
214  if (!llvm::equal(
215  llvm::make_range(
216  sourceSharding.getSplitAxes()[tensorAxis]
217  .asArrayRef()
218  .begin(),
219  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
220  1),
221  targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
222  continue;
223  } else {
224  if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
225  continue;
226  }
227  return std::make_tuple(
228  tensorAxis,
229  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
230  }
231  return std::nullopt;
232 }
233 
234 static MeshShardingAttr
236  MeshShardingAttr sourceSharding,
237  int64_t splitTensorAxis) {
238  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
239  llvm::to_vector(sourceSharding.getSplitAxes());
240  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
241  splitTensorAxis);
242  auto targetSplitAxes =
243  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
244 
245  targetSplitAxes.pop_back();
246  targetShardingSplitAxes[splitTensorAxis] =
247  MeshAxesAttr::get(ctx, targetSplitAxes);
248  return MeshShardingAttr::get(
249  ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
250  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
251 }
252 
254  ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
255  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
256  targetShape[splitTensorAxis] =
257  gatherDimension(targetShape[splitTensorAxis], splitCount);
258  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
259 }
260 
261 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
263  MeshShardingAttr sourceSharding,
264  ShapedType sourceUnshardedShape,
265  TypedValue<ShapedType> sourceShard, MeshOp mesh,
266  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
267  MLIRContext *ctx = builder.getContext();
268  builder.setInsertionPointAfterValue(sourceShard);
269 
270  MeshShardingAttr targetSharding =
271  targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
272  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
273  sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
274  Value allGatherResult = builder.create<AllGatherOp>(
275  RankedTensorType::get(allGatherResultShape.getShape(),
276  allGatherResultShape.getElementType()),
277  mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
278  APInt(64, splitTensorAxis));
279  ShapedType targetShape =
280  shardShapedType(sourceUnshardedShape, mesh, targetSharding);
281  TypedValue<ShapedType> targetShard =
282  builder.create<tensor::CastOp>(targetShape, allGatherResult)
283  .getResult()
284  .cast<TypedValue<ShapedType>>();
285  return {targetShard, targetSharding};
286 }
287 
288 static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
290  MeshShardingAttr sourceSharding,
291  MeshShardingAttr targetSharding,
292  ShapedType sourceUnshardedShape,
293  TypedValue<ShapedType> sourceShard) {
294  if (auto detectRes =
295  detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
296  auto [tensorAxis, meshAxis] = detectRes.value();
297  return unsplitLastAxisInResharding(builder, sourceSharding,
298  sourceUnshardedShape, sourceShard, mesh,
299  tensorAxis, meshAxis);
300  }
301 
302  return std::nullopt;
303 }
304 
305 // Detect if the resharding is of type e.g.
306 // [[0, 1], [2]] -> [[0], [1, 2]].
307 // Only moving the last axis counts.
308 // If detected, returns the corresponding (source_tensor_axis,
309 // target_tensor_axis, mesh_axis) tuple.
310 static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
312  MeshShardingAttr targetSharding) {
313  for (size_t sourceTensorAxis = 0;
314  sourceTensorAxis < sourceSharding.getSplitAxes().size();
315  ++sourceTensorAxis) {
316  for (size_t targetTensorAxis = 0;
317  targetTensorAxis < targetSharding.getSplitAxes().size();
318  ++targetTensorAxis) {
319  if (sourceTensorAxis == targetTensorAxis)
320  continue;
321  if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
322  targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
323  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
324  targetSharding.getSplitAxes()[targetTensorAxis]
325  .asArrayRef()
326  .back())
327  continue;
328  if (!llvm::equal(
329  llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
330  .asArrayRef()
331  .begin(),
332  sourceSharding.getSplitAxes()[sourceTensorAxis]
333  .asArrayRef()
334  .end() -
335  1),
336  llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
337  .asArrayRef()
338  .begin(),
339  targetSharding.getSplitAxes()[targetTensorAxis]
340  .asArrayRef()
341  .end() -
342  1)))
343  continue;
344  return std::make_tuple(
345  sourceTensorAxis, targetTensorAxis,
346  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
347  }
348  }
349  return std::nullopt;
350 }
351 
352 static MeshShardingAttr
354  int64_t sourceTensorAxis,
355  int64_t targetTensorAxis) {
356  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
357  llvm::to_vector(sourceSharding.getSplitAxes());
358  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
359  targetTensorAxis) {
360  targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
361  }
362 
363  auto sourceSplitAxes =
364  llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
365  assert(!sourceSplitAxes.empty());
366  auto meshAxis = sourceSplitAxes.back();
367  sourceSplitAxes.pop_back();
368  targetShardingSplitAxes[sourceTensorAxis] =
369  MeshAxesAttr::get(ctx, sourceSplitAxes);
370 
371  auto targetSplitAxes =
372  llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
373  targetSplitAxes.push_back(meshAxis);
374  targetShardingSplitAxes[targetTensorAxis] =
375  MeshAxesAttr::get(ctx, targetSplitAxes);
376 
377  return MeshShardingAttr::get(
378  ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
379  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
380 }
381 
382 static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
383  int64_t splitCount,
384  int64_t sourceTensorAxis,
385  int64_t targetTensorAxis) {
386  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
387  targetShape[sourceTensorAxis] =
388  gatherDimension(targetShape[sourceTensorAxis], splitCount);
389  targetShape[targetTensorAxis] =
390  shardDimension(targetShape[targetTensorAxis], splitCount);
391  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
392 }
393 
394 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
396  MeshShardingAttr sourceSharding,
397  ShapedType sourceUnshardedShape,
398  TypedValue<ShapedType> sourceShard,
399  int64_t sourceTensorAxis,
400  int64_t targetTensorAxis, MeshAxis meshAxis) {
401  MLIRContext *ctx = builder.getContext();
402  builder.setInsertionPointAfterValue(sourceShard);
403 
405  ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
406  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
407  sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
408  targetTensorAxis);
409  Value allToAllResult = builder.create<AllToAllOp>(
410  RankedTensorType::get(allToAllResultShape.getShape(),
411  allToAllResultShape.getElementType()),
412  mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
413  APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
414  ShapedType targetShape =
415  shardShapedType(sourceUnshardedShape, mesh, targetSharding);
416  TypedValue<ShapedType> targetShard =
417  builder.create<tensor::CastOp>(targetShape, allToAllResult)
418  .getResult()
419  .cast<TypedValue<ShapedType>>();
420  return {targetShard, targetSharding};
421 }
422 
423 static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
425  MeshShardingAttr sourceSharding,
426  MeshShardingAttr targetSharding,
427  ShapedType sourceUnshardedShape,
428  TypedValue<ShapedType> sourceShard) {
429  if (auto detectRes =
430  detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
431  auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
433  builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
434  sourceTensorAxis, targetTensorAxis, meshAxis);
435  }
436 
437  return std::nullopt;
438 }
439 
440 // Handles only resharding on a 1D mesh.
441 // Currently the sharded tensor axes must be exactly divisible by the single
442 // mesh axis size.
445  MeshShardingAttr sourceSharding,
446  MeshShardingAttr targetSharding,
447  TypedValue<ShapedType> sourceUnshardedValue,
448  TypedValue<ShapedType> sourceShard) {
449  assert(sourceShard.getType() ==
450  shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
451  [[maybe_unused]] ShapedType targetShardType =
452  shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
453  assert(sourceShard.getType().getRank() == targetShardType.getRank());
454  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
455 
456  auto [reducedSourceShard, reducedSourceSharding] =
457  handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
458  sourceShard);
459 
460  if (reducedSourceSharding == targetSharding) {
461  return reducedSourceShard;
462  }
463 
464  TypedValue<ShapedType> targetShard;
465  MeshShardingAttr actualTargetSharding;
466  if (auto tryRes = tryMoveLastSplitAxisInResharding(
467  builder, mesh, reducedSourceSharding, targetSharding,
468  sourceUnshardedValue.getType(), reducedSourceShard)) {
469  std::tie(targetShard, actualTargetSharding) = tryRes.value();
470  } else if (auto tryRes = trySplitLastAxisInResharding(
471  builder, mesh, reducedSourceSharding, targetSharding,
472  reducedSourceShard)) {
473  std::tie(targetShard, actualTargetSharding) = tryRes.value();
474  } else if (auto tryRes = tryUnsplitLastAxisInResharding(
475  builder, mesh, reducedSourceSharding, targetSharding,
476  sourceUnshardedValue.getType(), reducedSourceShard)) {
477  std::tie(targetShard, actualTargetSharding) = tryRes.value();
478  } else {
479  assert(false && "Did not find any pattern to apply.");
480  }
481 
482  assert(actualTargetSharding == targetSharding);
483  assert(targetShard.getType() == targetShardType);
484  return targetShard;
485 }
486 
488  MeshShardingAttr sourceSharding,
489  MeshShardingAttr targetSharding,
490  TypedValue<ShapedType> sourceUnshardedValue,
491  TypedValue<ShapedType> sourceShard) {
492  // Resort to handling only 1D meshes since the general case is complicated if
493  // it needs to be communication efficient in terms of minimizing the data
494  // transfered between devices.
495  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
496  sourceUnshardedValue, sourceShard);
497 }
498 
499 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
500  ShardOp target,
501  TypedValue<ShapedType> sourceShardValue) {
502  assert(!source.getAnnotateForUsers());
503  assert(target.getAnnotateForUsers());
504  assert(source.getResult() == target.getOperand());
505  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
506  return reshard(
507  implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
508  source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
509 }
510 
511 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
512  ShardOp target,
513  TypedValue<ShapedType> sourceShardValue,
514  SymbolTableCollection &symbolTableCollection) {
515  MeshOp srcMesh = getMesh(source, symbolTableCollection);
516  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
517  return reshard(builder, srcMesh, source, target, sourceShardValue);
518 }
519 
521  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
522 }
523 
524 #define GEN_PASS_DEF_SPMDIZATION
525 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
526 
528 
529 // Get the types of block arguments for an spmdized block.
530 // Reads the sharding annotations of the arguments to deduce the sharded types.
531 // Types that are not ranked tensors are left unchanged.
534  SymbolTableCollection &symbolTableCollection) {
535  SmallVector<Type> res;
536  llvm::transform(block.getArguments(), std::back_inserter(res),
537  [&symbolTableCollection](BlockArgument arg) {
538  auto rankedTensorArg =
539  arg.dyn_cast<TypedValue<RankedTensorType>>();
540  if (!rankedTensorArg) {
541  return arg.getType();
542  }
543 
544  assert(rankedTensorArg.hasOneUse());
545  Operation *useOp = *rankedTensorArg.getUsers().begin();
546  ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
547  assert(shardOp);
548  MeshOp mesh = getMesh(shardOp, symbolTableCollection);
549  return shardShapedType(rankedTensorArg.getType(), mesh,
550  shardOp.getShardAttr())
551  .cast<Type>();
552  });
553  return res;
554 }
555 
557  Operation &op, ArrayRef<Value> spmdizedOperands,
558  ArrayRef<MeshShardingAttr> operandShardings,
559  ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
560  SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
561  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
562  if (!shardingInterface) {
563  // If there is no sharding interface we are conservative and assume that
564  // the op should be fully replicated no all devices.
565  spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
566  resultShardings, spmdizationMap,
567  symbolTableCollection, builder);
568  } else {
569  if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
570  resultShardings, spmdizationMap,
571  symbolTableCollection, builder))) {
572  return failure();
573  }
574  }
575 
576  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
577  return spmdizationMap.contains(result);
578  }));
579 
580  return success();
581 }
582 
583 // Retrieve the sharding annotations for the operands of the given operation.
584 // If the type is not a ranked tensor it is not require to have an annotation.
587  res.reserve(op.getNumOperands());
588  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
589  TypedValue<RankedTensorType> rankedTensor =
590  operand.dyn_cast<TypedValue<RankedTensorType>>();
591  if (!rankedTensor) {
592  return MeshShardingAttr();
593  }
594 
595  Operation *definingOp = operand.getDefiningOp();
596  assert(definingOp);
597  ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
598  return shardOp.getShard();
599  });
600  return res;
601 }
602 
603 // Retrieve the sharding annotations for the results of the given operation.
604 // If the type is not a ranked tensor it is not require to have an annotation.
607  res.reserve(op.getNumResults());
608  llvm::transform(op.getResults(), std::back_inserter(res),
609  [](OpResult result) {
610  TypedValue<RankedTensorType> rankedTensor =
611  result.dyn_cast<TypedValue<RankedTensorType>>();
612  if (!rankedTensor) {
613  return MeshShardingAttr();
614  }
615 
616  assert(result.hasOneUse());
617  Operation *userOp = *result.getUsers().begin();
618  ShardOp shardOp = llvm::cast<ShardOp>(userOp);
619  return shardOp.getShard();
620  });
621  return res;
622 }
623 
// Spmdize a `mesh.shard` annotation op: either forward the already-spmdized
// operand, or insert resharding communication between a chained pair of
// shard ops, and record the result in `spmdizationMap`.
static LogicalResult
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  Value targetSpmdValue;

  // Check if 2 shard ops are chained. If not there is no need for resharding
  // as the source and target shared the same sharding.
  ShardOp srcShardOp =
      dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
  if (!srcShardOp) {
    targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
  } else {
    // Insert resharding.
    assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
    TypedValue<ShapedType> srcSpmdValue =
        spmdizationMap.lookup(srcShardOp.getOperand())
            .cast<TypedValue<ShapedType>>();
    targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                              symbolTableCollection);
  }

  // The shard op itself produces no spmdized op; only map its result.
  assert(!spmdizationMap.contains(shardOp.getResult()));
  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
  return success();
}
650 
651 static LogicalResult
652 spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
653  SymbolTableCollection &symbolTableCollection,
654  OpBuilder &builder) {
655  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
656  if (shardOp) {
657  return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
658  builder);
659  }
660 
661  SmallVector<Value> spmdizedOperands;
662  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
663  [&spmdizationMap](Value operand) {
664  assert(spmdizationMap.contains(operand));
665  return spmdizationMap.lookup(operand);
666  });
667  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
668  getResultShardings(op), spmdizationMap,
669  symbolTableCollection, builder);
670 }
671 
672 static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
673  SymbolTableCollection &symbolTableCollection,
674  OpBuilder &builder) {
675  SmallVector<Location> argLocations;
676  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
677  [](BlockArgument arg) { return arg.getLoc(); });
678  Block *newBlock = builder.createBlock(
679  block.getParent(), {},
680  shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
681  for (auto [unshardedBlockArg, spmdizedBlockArg] :
682  llvm::zip(block.getArguments(), newBlock->getArguments())) {
683  spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
684  }
685 
686  OpBuilder::InsertionGuard insertionGuard(builder);
687  builder.setInsertionPointToEnd(newBlock);
688  for (Operation &op : block.getOperations()) {
689  if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
690  builder))) {
691  return failure();
692  }
693  }
694 
695  return success();
696 }
697 
// Spmdize an entire function: rewrite each block with sharded argument types
// and spmdized operations, erase the original blocks, then update the
// function type to match the new entry arguments and return operands.
static LogicalResult
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
              SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks to not mess up the iteration when adding new
  // blocks.
  SmallVector<Block *> originalBlocks;
  llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
                  [](Block &b) { return &b; });

  for (Block *block : originalBlocks) {
    if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
                            builder))) {
      return failure();
    }
  }

  // The spmdized blocks were created alongside; the originals are now dead.
  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find a return op and change the function results signature to its operands
  // signature.
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    if (block.empty()) {
      continue;
    }

    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
      returnOp = &block.back();
      break;
    }
  }
  assert(returnOp);
  op.setType(FunctionType::get(op->getContext(),
                               op.getFunctionBody().front().getArgumentTypes(),
                               returnOp->getOperandTypes()));

  return success();
}
740 
741 namespace {
742 
743 struct Spmdization : public impl::SpmdizationBase<Spmdization> {
744  void runOnOperation() override {
745  IRMapping spmdizationMap;
746  SymbolTableCollection symbolTableCollection;
747  if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
748  symbolTableCollection))) {
749  return signalPassFailure();
750  }
751  }
752 
753  void getDependentDialects(DialectRegistry &registry) const override {
755  registry.insert<mesh::MeshDialect>();
756  }
757 };
758 
759 } // namespace
760 
761 } // namespace mlir::mesh
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
OpListType & getOperations()
Definition: Block.h:134
BlockArgListType getArguments()
Definition: Block.h:84
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:93
MLIRContext * getContext() const
Definition: Builders.h:55
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition: IRMapping.h:51
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:423
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This is a value defined by a result of an operation.
Definition: Value.h:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_type_range getOperandTypes()
Definition: Operation.h:392
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:224
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:211
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
mesh::MeshShardingAttr MeshShardingAttr
static LogicalResult spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > handlePartialAxesDuringResharding(OpBuilder &builder, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
Definition: Spmdization.cpp:59
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
Definition: Spmdization.cpp:45
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:110
static std::optional< std::tuple< int64_t, MeshAxis > > detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:57
static SmallVector< MeshShardingAttr > getOperandShardings(Operation &op)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:101
static MeshShardingAttr targetShardingInUnsplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis)
TypedValue< ShapedType > reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static SmallVector< MeshShardingAttr > getResultShardings(Operation &op)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
void reshardingRegisterDependentDialects(DialectRegistry &registry)
static std::optional< std::tuple< int64_t, MeshAxis > > detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
int16_t MeshAxis
Definition: MeshOps.h:25
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:162
TypedValue< ShapedType > reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static TypedValue< ShapedType > reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static MeshShardingAttr targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static MeshShardingAttr targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static LogicalResult spmdizeOperation(Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:494
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This trait indicates that a terminator operation is "return-like".