MLIR  19.0.0git
Spmdization.cpp
Go to the documentation of this file.
1 //===- Spmdization.cpp --------------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/Location.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "mlir/IR/Value.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Support/LLVM.h"
31 #include "llvm/ADT/APInt.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/Casting.h"
36 #include <iterator>
37 #include <optional>
38 #include <tuple>
39 #include <type_traits>
40 
41 namespace mlir::mesh {
42 
// Checks that the target's partial axes are a subset of the source's:
// returns true iff each element of `targetAxes` is reported as present by
// `sourceAxes.contains(...)`. An empty `targetAxes` is trivially compatible.
// `SourceAxes` must expose a `contains` member; `TargetAxes` must be iterable.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  for (const auto &axis : targetAxes) {
    // Stop at the first target axis the source does not know about.
    if (!sourceAxes.contains(axis))
      return false;
  }
  return true;
}
50 
51 // Return the reduced value and its corresponding sharding.
52 // Example:
53 // sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
54 // targetSharding = <@mesh_1d, [[]]>
55 // Then will apply all-reduce on the source value
56 // and return it with the sharding <@mesh_1d, [[0]]>.
57 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
59  MeshShardingAttr sourceSharding,
60  MeshShardingAttr targetSharding,
61  TypedValue<ShapedType> sourceShard) {
62  if (sourceSharding.getPartialAxes().empty() &&
63  targetSharding.getPartialAxes().empty()) {
64  return {sourceShard, sourceSharding};
65  }
66  assert(targetSharding.getPartialAxes().empty() ||
67  (!sourceSharding.getPartialAxes().empty() &&
68  sourceSharding.getPartialType() == targetSharding.getPartialType()));
69  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
70  using AxisSet = llvm::SmallDenseSet<Axis>;
71  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
72  sourceSharding.getPartialAxes().end());
73  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
74  targetSharding.getPartialAxes().end());
75  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
76  targetShardingPartialAxesSet));
77  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
78  llvm::copy_if(sourceShardingPartialAxesSet,
79  std::back_inserter(allReduceMeshAxes),
80  [&targetShardingPartialAxesSet](Axis a) {
81  return !targetShardingPartialAxesSet.contains(a);
82  });
83  if (allReduceMeshAxes.empty()) {
84  return {sourceShard, sourceSharding};
85  }
86 
87  builder.setInsertionPointAfterValue(sourceShard);
88  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
89  builder
90  .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
91  sourceSharding.getMesh().getLeafReference(),
92  allReduceMeshAxes, sourceShard,
93  sourceSharding.getPartialType())
94  .getResult());
95 
96  llvm::SmallVector<MeshAxis> remainingPartialAxes;
97  llvm::copy_if(sourceShardingPartialAxesSet,
98  std::back_inserter(allReduceMeshAxes),
99  [&targetShardingPartialAxesSet](Axis a) {
100  return targetShardingPartialAxesSet.contains(a);
101  });
102  MeshShardingAttr resultSharding =
103  MeshShardingAttr::get(builder.getContext(), sourceSharding.getMesh(),
104  sourceSharding.getSplitAxes(), remainingPartialAxes,
105  sourceSharding.getPartialType());
106  return {resultValue, resultSharding};
107 }
108 
// Builds the sharding that results from additionally splitting tensor axis
// `splitTensorAxis` along mesh axis `splitMeshAxis`: `splitMeshAxis` is
// appended to the end of that tensor axis' split-axes list. Partial axes and
// partial type are copied from `sourceSharding` unchanged.
// NOTE(review): the declarator (original line 110) is missing from this
// listing; the function name and the `ctx` / `sourceSharding` parameters used
// below must be declared there.
109 static MeshShardingAttr
111  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
112  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
113  llvm::to_vector(sourceSharding.getSplitAxes());
// Pad with empty split-axes entries so that `splitTensorAxis` is addressable
// even when the source sharding lists fewer tensor axes.
114  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
115  splitTensorAxis) {
116  targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
117  }
118  auto targetSplitAxes =
119  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
// The new mesh axis becomes the innermost (last) split of this tensor axis.
120  targetSplitAxes.push_back(splitMeshAxis);
121  targetShardingSplitAxes[splitTensorAxis] =
122  MeshAxesAttr::get(ctx, targetSplitAxes);
123  return MeshShardingAttr::get(
124  ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
125  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
126 }
127 
128 // Split a replicated tensor along a mesh axis.
129 // e.g. [[0, 1]] -> [[0, 1, 2]].
130 // Returns the spmdized target value with its sharding.
// "Replicated" refers to the mesh axis being split on: the shard is not yet
// split along `splitMeshAxis`. An AllSliceOp slices tensor axis
// `splitTensorAxis` along that mesh axis.
// NOTE(review): two lines are missing from this listing. Original line 132
// (the declarator) should be, per the tooltip listing:
// splitLastAxisInResharding(ImplicitLocOpBuilder &builder, ...
// Original line 142 should start the declaration of `targetSharding`, which is
// initialised from the call whose arguments appear on line 143.
131 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
133  MeshShardingAttr sourceSharding,
134  TypedValue<ShapedType> sourceShard, MeshOp mesh,
135  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
136  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
137  builder
138  .create<AllSliceOp>(sourceShard, mesh,
139  ArrayRef<MeshAxis>(splitMeshAxis),
140  splitTensorAxis)
141  .getResult());
143  builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
144  return {targetShard, targetSharding};
145 }
146 
147 // Detect if the resharding is of type e.g.
148 // [[0, 1]] -> [[0, 1, 2]].
149 // If detected, returns the corresponding tensor axis mesh axis pair.
150 // Does not detect insertions like
151 // [[0, 1]] -> [[0, 2, 1]].
// Scans the target's tensor axes in order and returns the first one whose
// split-axes list equals the source's list with exactly one mesh axis
// appended. When the source lists fewer tensor axes, a missing source entry is
// treated as empty, so the target entry must then hold exactly one axis.
// NOTE(review): only the matched tensor axis is compared; the other tensor
// axes are not checked for equality here. The caller (reshardOn1DMesh)
// asserts that the resulting sharding equals the target.
// NOTE(review): the declarator (original line 153, with the function name and
// the `sourceSharding` parameter) is missing from this listing.
152 static std::optional<std::tuple<int64_t, MeshAxis>>
154  MeshShardingAttr targetSharding) {
155  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
156  ++tensorAxis) {
157  if (sourceSharding.getSplitAxes().size() > tensorAxis) {
// The target entry must be exactly one mesh axis longer...
158  if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
159  targetSharding.getSplitAxes()[tensorAxis].size()) {
160  continue;
161  }
// ...and start with the complete source entry (target minus its last axis).
162  if (!llvm::equal(
163  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
164  llvm::make_range(
165  targetSharding.getSplitAxes()[tensorAxis]
166  .asArrayRef()
167  .begin(),
168  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
169  1))) {
170  continue;
171  }
172  } else {
173  if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
174  continue;
175  }
176  }
// The appended mesh axis is the last entry of the target list.
177  return std::make_tuple(
178  tensorAxis,
179  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
180  }
181  return std::nullopt;
182 }
183 
// Applies the "split last axis" resharding pattern if
// detectSplitLastAxisInResharding recognises it. Returns the spmdized value and
// its resulting sharding, or std::nullopt if the pattern does not apply (in
// which case no ops are created).
// NOTE(review): the declarator (original line 185, with the function name and
// the `builder` / `mesh` parameters) is missing from this listing.
184 static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
186  MeshShardingAttr sourceSharding,
187  MeshShardingAttr targetSharding,
188  TypedValue<ShapedType> sourceShard) {
189  if (auto detectRes =
190  detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
191  auto [tensorAxis, meshAxis] = detectRes.value();
192  return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
193  tensorAxis, meshAxis);
194  }
195
196  return std::nullopt;
197 }
198 
199 // Detect if the resharding is of type e.g.
200 // [[0, 1, 2]] -> [[0, 1]].
201 // If detected, returns the corresponding tensor axis mesh axis pair.
// Scans the source's tensor axes in order and returns the first one whose
// target split-axes list equals the source list with its last mesh axis
// removed. A tensor axis missing from the target is treated as empty, so the
// source entry must then hold exactly one axis. The returned mesh axis is the
// one being removed (the last entry of the source list).
// NOTE(review): as in the split case, other tensor axes are not compared here.
// NOTE(review): the declarator (original line 203, with the function name and
// the `sourceSharding` parameter) is missing from this listing.
202 static std::optional<std::tuple<int64_t, MeshAxis>>
204  MeshShardingAttr targetSharding) {
205  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
206  ++tensorAxis) {
207  if (targetSharding.getSplitAxes().size() > tensorAxis) {
208  if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
209  targetSharding.getSplitAxes()[tensorAxis].size() + 1)
210  continue;
// The source list minus its last axis must equal the target list.
211  if (!llvm::equal(
212  llvm::make_range(
213  sourceSharding.getSplitAxes()[tensorAxis]
214  .asArrayRef()
215  .begin(),
216  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
217  1),
218  targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
219  continue;
220  } else {
221  if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
222  continue;
223  }
224  return std::make_tuple(
225  tensorAxis,
226  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
227  }
228  return std::nullopt;
229 }
230 
// Builds the sharding obtained by removing the last (innermost) mesh axis from
// the split-axes list of tensor axis `splitTensorAxis`. Partial axes and partial
// type are copied from `sourceSharding` unchanged.
// Precondition: `splitTensorAxis` is listed in the source sharding (asserted)
// and its split-axes list is non-empty (required by pop_back; callers obtain
// the axis from detectUnsplitLastAxisInResharding, which only reports
// non-empty entries).
// NOTE(review): the declarator (original line 232, with the function name and
// the `ctx` parameter) is missing from this listing.
231 static MeshShardingAttr
233  MeshShardingAttr sourceSharding,
234  int64_t splitTensorAxis) {
235  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
236  llvm::to_vector(sourceSharding.getSplitAxes());
237  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
238  splitTensorAxis);
239  auto targetSplitAxes =
240  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
241
242  targetSplitAxes.pop_back();
243  targetShardingSplitAxes[splitTensorAxis] =
244  MeshAxesAttr::get(ctx, targetSplitAxes);
245  return MeshShardingAttr::get(
246  ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
247  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
248 }
249 
// NOTE(review): the declarator (original line 250) is missing from this
// listing; per the tooltip listing it is
// `static ShapedType allGatherResultShapeInUnsplitLastAxis(`.
// Computes the result type of the all-gather used when unsplitting: the
// dimension at `splitTensorAxis` is recomputed with gatherDimension for
// `splitCount` shards; the other dimensions and the element type are kept.
251  ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
252  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
253  targetShape[splitTensorAxis] =
254  gatherDimension(targetShape[splitTensorAxis], splitCount);
255  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
256 }
257 
// Undoes the innermost split of tensor axis `splitTensorAxis` along mesh axis
// `splitMeshAxis`, e.g. [[0, 1, 2]] -> [[0, 1]].
// Steps: compute the target sharding, all-gather the shard along
// `splitMeshAxis` into a ranked tensor of the gathered shape, then tensor.cast
// it to the shard type that shardShapedType derives for the target sharding.
// The new ops are inserted right after `sourceShard`.
// Returns the spmdized value together with its sharding.
// NOTE(review): the declarator (original line 259) is missing from this
// listing; per the tooltip listing it is
// unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ...
258 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
260  MeshShardingAttr sourceSharding,
261  ShapedType sourceUnshardedShape,
262  TypedValue<ShapedType> sourceShard, MeshOp mesh,
263  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
264  MLIRContext *ctx = builder.getContext();
265  builder.setInsertionPointAfterValue(sourceShard);
266
267  MeshShardingAttr targetSharding =
268  targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
// The gather factor is the size of the mesh along `splitMeshAxis`.
269  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
270  sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
271  Value allGatherResult = builder.create<AllGatherOp>(
272  RankedTensorType::get(allGatherResultShape.getShape(),
273  allGatherResultShape.getElementType()),
274  mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
275  APInt(64, splitTensorAxis));
// Cast to the exact shard type expected for the target sharding.
276  ShapedType targetShape =
277  shardShapedType(sourceUnshardedShape, mesh, targetSharding);
278  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
279  builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
280  return {targetShard, targetSharding};
281 }
282 
// Applies the "unsplit last axis" resharding pattern if
// detectUnsplitLastAxisInResharding recognises it. Returns the spmdized value
// and its sharding, or std::nullopt if the pattern does not apply (no ops are
// created in that case).
// NOTE(review): the declarator (original line 284, with the function name and
// the `builder` / `mesh` parameters) is missing from this listing.
283 static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
285  MeshShardingAttr sourceSharding,
286  MeshShardingAttr targetSharding,
287  ShapedType sourceUnshardedShape,
288  TypedValue<ShapedType> sourceShard) {
289  if (auto detectRes =
290  detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
291  auto [tensorAxis, meshAxis] = detectRes.value();
292  return unsplitLastAxisInResharding(builder, sourceSharding,
293  sourceUnshardedShape, sourceShard, mesh,
294  tensorAxis, meshAxis);
295  }
296
297  return std::nullopt;
298 }
299 
300 // Detect if the resharding is of type e.g.
301 // [[0, 1], [2]] -> [[0], [1, 2]].
302 // Only moving the last axis counts.
303 // If detected, returns the corresponding (source_tensor_axis,
304 // target_tensor_axis, mesh_axis) tuple.
// Tries every pair of distinct tensor axes (s, t) and returns the first pair
// for which:
// - both split-axes lists are non-empty,
// - both end with the same mesh axis, and
// - source[s] and target[t] agree on all but that last axis.
// NOTE(review): the check compares source[s] with target[t]. It does not
// compare source[s] minus its last axis with target[s], nor target[t] minus its
// last axis with source[t]. As a result the example above
// ([[0, 1], [2]] -> [[0], [1, 2]]) is NOT detected, since the last axes 1 and 2
// differ. With single-axis entries, as on the 1D meshes reshardOn1DMesh
// supports, the prefixes are empty and the check reduces to "same mesh axis".
// Verify before relying on it for multi-axis entries.
// NOTE(review): the declarator (original line 306, with the function name and
// the `sourceSharding` parameter) is missing from this listing.
305 static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
307  MeshShardingAttr targetSharding) {
308  for (size_t sourceTensorAxis = 0;
309  sourceTensorAxis < sourceSharding.getSplitAxes().size();
310  ++sourceTensorAxis) {
311  for (size_t targetTensorAxis = 0;
312  targetTensorAxis < targetSharding.getSplitAxes().size();
313  ++targetTensorAxis) {
314  if (sourceTensorAxis == targetTensorAxis)
315  continue;
// Both entries must exist and end with the same (moved) mesh axis.
316  if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
317  targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
318  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
319  targetSharding.getSplitAxes()[targetTensorAxis]
320  .asArrayRef()
321  .back())
322  continue;
// All axes except the last must match between the two entries.
323  if (!llvm::equal(
324  llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
325  .asArrayRef()
326  .begin(),
327  sourceSharding.getSplitAxes()[sourceTensorAxis]
328  .asArrayRef()
329  .end() -
330  1),
331  llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
332  .asArrayRef()
333  .begin(),
334  targetSharding.getSplitAxes()[targetTensorAxis]
335  .asArrayRef()
336  .end() -
337  1)))
338  continue;
339  return std::make_tuple(
340  sourceTensorAxis, targetTensorAxis,
341  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
342  }
343  }
344  return std::nullopt;
345 }
346 
// Builds the sharding obtained by moving the last mesh axis of tensor axis
// `sourceTensorAxis` to the end of tensor axis `targetTensorAxis`'s split-axes
// list. Partial axes and partial type are copied from `sourceSharding`.
// Precondition: the source entry is non-empty (asserted).
// NOTE(review): the declarator (original line 348, with the function name and
// the `ctx` / `sourceSharding` parameters) is missing from this listing.
// NOTE(review): padding only guarantees `targetTensorAxis` is addressable;
// `sourceTensorAxis` is assumed to already be listed (it is when it comes from
// detectMoveLastSplitAxisInResharding).
347 static MeshShardingAttr
349  int64_t sourceTensorAxis,
350  int64_t targetTensorAxis) {
351  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
352  llvm::to_vector(sourceSharding.getSplitAxes());
353  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
354  targetTensorAxis) {
355  targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
356  }
357
// Remove the moved mesh axis from the source tensor axis...
358  auto sourceSplitAxes =
359  llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
360  assert(!sourceSplitAxes.empty());
361  auto meshAxis = sourceSplitAxes.back();
362  sourceSplitAxes.pop_back();
363  targetShardingSplitAxes[sourceTensorAxis] =
364  MeshAxesAttr::get(ctx, sourceSplitAxes);
365
// ...and append it as the innermost split of the target tensor axis.
366  auto targetSplitAxes =
367  llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
368  targetSplitAxes.push_back(meshAxis);
369  targetShardingSplitAxes[targetTensorAxis] =
370  MeshAxesAttr::get(ctx, targetSplitAxes);
371
372  return MeshShardingAttr::get(
373  ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
374  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
375 }
376 
377 static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
378  int64_t splitCount,
379  int64_t sourceTensorAxis,
380  int64_t targetTensorAxis) {
381  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
382  targetShape[sourceTensorAxis] =
383  gatherDimension(targetShape[sourceTensorAxis], splitCount);
384  targetShape[targetTensorAxis] =
385  shardDimension(targetShape[targetTensorAxis], splitCount);
386  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
387 }
388 
// Moves mesh axis `meshAxis` from the end of tensor axis `sourceTensorAxis`'s
// split list to the end of tensor axis `targetTensorAxis`'s split list.
// Steps:
// - compute the target sharding;
// - emit an all-to-all over `meshAxis` whose result type comes from
//   allToAllResultShapeInMoveLastAxis;
// - tensor.cast that result to the shard type shardShapedType derives for the
//   target sharding.
// Ops are inserted right after `sourceShard`. Returns the spmdized value and
// its sharding.
// NOTE(review): two lines are missing from this listing. Original line 390 is
// the declarator, which per the tooltip listing declares `builder` and `mesh`.
// Original line 399 should start the declaration of `targetSharding`,
// initialised from the call whose arguments are on line 400.
389 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
391  MeshShardingAttr sourceSharding,
392  ShapedType sourceUnshardedShape,
393  TypedValue<ShapedType> sourceShard,
394  int64_t sourceTensorAxis,
395  int64_t targetTensorAxis, MeshAxis meshAxis) {
396  MLIRContext *ctx = builder.getContext();
397  builder.setInsertionPointAfterValue(sourceShard);
398
400  ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
401  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
402  sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
403  targetTensorAxis);
// NOTE(review): the two APInt arguments are presumably (split_axis,
// concat_axis): split the target tensor axis, concatenate the source tensor
// axis. This matches the shape computed above; confirm against the
// AllToAllOp builder.
404  Value allToAllResult = builder.create<AllToAllOp>(
405  RankedTensorType::get(allToAllResultShape.getShape(),
406  allToAllResultShape.getElementType()),
407  mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
408  APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
409  ShapedType targetShape =
410  shardShapedType(sourceUnshardedShape, mesh, targetSharding);
411  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
412  builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
413  return {targetShard, targetSharding};
414 }
415 
// Applies the "move last split axis" resharding pattern if
// detectMoveLastSplitAxisInResharding recognises it. Returns the spmdized
// value and its sharding, or std::nullopt if the pattern does not apply (no
// ops are created in that case).
// NOTE(review): two lines are missing from this listing. Original line 417 is
// the declarator (per the tooltip listing it declares `builder` and `mesh`).
// Original line 425 is presumably `return moveLastSplitAxisInResharding(`,
// whose arguments continue on lines 426-427.
416 static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
418  MeshShardingAttr sourceSharding,
419  MeshShardingAttr targetSharding,
420  ShapedType sourceUnshardedShape,
421  TypedValue<ShapedType> sourceShard) {
422  if (auto detectRes =
423  detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
424  auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
426  builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
427  sourceTensorAxis, targetTensorAxis, meshAxis);
428  }
429
430  return std::nullopt;
431 }
432 
433 // Handles only resharding on a 1D mesh.
434 // Currently the sharded tensor axes must be exactly divisible by the single
435 // mesh axis size.
// Algorithm:
// 1. All-reduce away partial axes the target does not keep
//    (handlePartialAxesDuringResharding). If that already yields the target
//    sharding, return.
// 2. Otherwise apply the first matching single-step pattern, in this order:
//    move-last-split-axis, split-last-axis, unsplit-last-axis.
// The final asserts check that the chosen pattern produced exactly the target
// sharding and shard type.
// NOTE(review): the declarator (original lines 436-437, with the function name
// and the `builder` / `mesh` parameters) is missing from this listing.
438  MeshShardingAttr sourceSharding,
439  MeshShardingAttr targetSharding,
440  TypedValue<ShapedType> sourceUnshardedValue,
441  TypedValue<ShapedType> sourceShard) {
442  assert(sourceShard.getType() ==
443  shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
// Only used by the asserts below; hence [[maybe_unused]] for NDEBUG builds.
444  [[maybe_unused]] ShapedType targetShardType =
445  shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
446  assert(sourceShard.getType().getRank() == targetShardType.getRank());
447  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
448
449  auto [reducedSourceShard, reducedSourceSharding] =
450  handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
451  sourceShard);
452
453  if (reducedSourceSharding == targetSharding) {
454  return reducedSourceShard;
455  }
456
457  TypedValue<ShapedType> targetShard;
458  MeshShardingAttr actualTargetSharding;
459  if (auto tryRes = tryMoveLastSplitAxisInResharding(
460  builder, mesh, reducedSourceSharding, targetSharding,
461  sourceUnshardedValue.getType(), reducedSourceShard)) {
462  std::tie(targetShard, actualTargetSharding) = tryRes.value();
463  } else if (auto tryRes = trySplitLastAxisInResharding(
464  builder, mesh, reducedSourceSharding, targetSharding,
465  reducedSourceShard)) {
466  std::tie(targetShard, actualTargetSharding) = tryRes.value();
467  } else if (auto tryRes = tryUnsplitLastAxisInResharding(
468  builder, mesh, reducedSourceSharding, targetSharding,
469  sourceUnshardedValue.getType(), reducedSourceShard)) {
470  std::tie(targetShard, actualTargetSharding) = tryRes.value();
471  } else {
// NOTE(review): with NDEBUG this assert compiles away, and the function then
// returns a null `targetShard`. Consider llvm_unreachable or a failure path.
472  assert(false && "Did not find any pattern to apply.");
473  }
474
475  assert(actualTargetSharding == targetSharding);
476  assert(targetShard.getType() == targetShardType);
477  return targetShard;
478 }
479 
// Reshards `sourceShard` (the shard of `sourceUnshardedValue` under
// `sourceSharding`) into a shard under `targetSharding`, creating the needed
// communication ops with `builder`.
// NOTE(review): the declarator (original line 480, with the function name and
// the `builder` / `mesh` parameters) is missing from this listing.
481  MeshShardingAttr sourceSharding,
482  MeshShardingAttr targetSharding,
483  TypedValue<ShapedType> sourceUnshardedValue,
484  TypedValue<ShapedType> sourceShard) {
485  // Resort to handling only 1D meshes since the general case is complicated if
486  // it needs to be communication efficient in terms of minimizing the data
487  // transferred between devices.
488  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
489  sourceUnshardedValue, sourceShard);
490 }
491 
492 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
493  ShardOp target,
494  TypedValue<ShapedType> sourceShardValue) {
495  assert(source.getResult() == target.getOperand());
496  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
497  return reshard(
498  implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
499  cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
500 }
501 
502 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
503  ShardOp target,
504  TypedValue<ShapedType> sourceShardValue,
505  SymbolTableCollection &symbolTableCollection) {
506  MeshOp srcMesh = getMesh(source, symbolTableCollection);
507  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
508  return reshard(builder, srcMesh, source, target, sourceShardValue);
509 }
510 
// NOTE(review): the declarator of this function (original line 511) is missing
// from this listing. The body registers the Mesh and Tensor dialects in
// `registry`. Presumably this is because resharding creates ops from both
// dialects (mesh collectives and tensor::CastOp) — confirm.
512  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
513 }
514 
515 #define GEN_PASS_DEF_SPMDIZATION
516 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
517 
519 
520 // Get the types of block arguments for an spmdized block.
521 // Reads the sharding annotations of the arguments to deduce the sharded types.
522 // Types that are not ranked tensors are left unchanged.
// Each ranked-tensor argument must have exactly one use, and that use must be
// a ShardOp (both asserted). The sharded type is computed by shardShapedType
// from the argument type, the mesh the ShardOp refers to (resolved via
// `symbolTableCollection`) and the ShardOp's sharding attribute.
// NOTE(review): the declarator (original lines 523-524) is missing from this
// listing; per the tooltip listing it is
// SmallVector<Type> shardedBlockArgumentTypes(Block &block,
//     SymbolTableCollection &symbolTableCollection)
525  SymbolTableCollection &symbolTableCollection) {
526  SmallVector<Type> res;
527  llvm::transform(
528  block.getArguments(), std::back_inserter(res),
529  [&symbolTableCollection](BlockArgument arg) {
530  auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
531  if (!rankedTensorArg) {
532  return arg.getType();
533  }
534
535  assert(rankedTensorArg.hasOneUse());
536  Operation *useOp = *rankedTensorArg.getUsers().begin();
537  ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
538  assert(shardOp);
539  MeshOp mesh = getMesh(shardOp, symbolTableCollection);
540  return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
541  shardOp.getShardAttr()));
542  });
543  return res;
544 }
545 
546 static LogicalResult spmdizeOperation(
547  Operation &op, ArrayRef<Value> spmdizedOperands,
548  ArrayRef<MeshShardingAttr> operandShardings,
549  ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
550  SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
551  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
552  if (!shardingInterface) {
553  // If there is no sharding interface we are conservative and assume that
554  // the op should be fully replicated no all devices.
555  spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
556  resultShardings, spmdizationMap,
557  symbolTableCollection, builder);
558  } else {
559  if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
560  resultShardings, spmdizationMap,
561  symbolTableCollection, builder))) {
562  return failure();
563  }
564  }
565 
566  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
567  return spmdizationMap.contains(result);
568  }));
569 
570  return success();
571 }
572 
573 // Retrieve the sharding annotations for the operands of the given operation.
574 // If the type is not a ranked tensor it is not required to have an annotation.
// Returns one entry per operand, in operand order. Non-ranked-tensor operands
// get a null MeshShardingAttr. A ranked-tensor operand must be produced by a
// ShardOp: `definingOp` is asserted non-null and llvm::cast asserts it is a
// ShardOp.
// NOTE(review): original lines 575-576 are missing from this listing. Per the
// tooltip listing they hold the declarator
// `static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op)`
// and presumably the declaration of the result vector `res`.
577  res.reserve(op.getNumOperands());
578  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
579  TypedValue<RankedTensorType> rankedTensor =
580  dyn_cast<TypedValue<RankedTensorType>>(operand);
581  if (!rankedTensor) {
582  return MeshShardingAttr();
583  }
584
585  Operation *definingOp = operand.getDefiningOp();
586  assert(definingOp);
587  ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
588  return shardOp.getShard();
589  });
590  return res;
591 }
592 
593 // Retrieve the sharding annotations for the results of the given operation.
594 // If the type is not a ranked tensor it is not required to have an annotation.
// Returns one entry per result, in result order. Non-ranked-tensor results get
// a null MeshShardingAttr. A ranked-tensor result must have exactly one use
// (asserted), and that user must be a ShardOp (llvm::cast asserts this).
// NOTE(review): original lines 595-596 are missing from this listing;
// presumably they hold the declarator (taking `Operation &op`) and the
// declaration of the result vector `res`.
597  res.reserve(op.getNumResults());
598  llvm::transform(op.getResults(), std::back_inserter(res),
599  [](OpResult result) {
600  TypedValue<RankedTensorType> rankedTensor =
601  dyn_cast<TypedValue<RankedTensorType>>(result);
602  if (!rankedTensor) {
603  return MeshShardingAttr();
604  }
605
606  assert(result.hasOneUse());
607  Operation *userOp = *result.getUsers().begin();
608  ShardOp shardOp = llvm::cast<ShardOp>(userOp);
609  return shardOp.getShard();
610  });
611  return res;
612 }
613 
614 static LogicalResult
615 spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
616  SymbolTableCollection &symbolTableCollection,
617  OpBuilder &builder) {
618  Value targetSpmdValue;
619 
620  // Check if 2 shard ops are chained. If not there is no need for resharding
621  // as the source and target shared the same sharding.
622  ShardOp srcShardOp =
623  dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
624  if (!srcShardOp) {
625  targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
626  } else {
627  // Insert resharding.
628  TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
629  spmdizationMap.lookup(srcShardOp.getOperand()));
630  targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
631  symbolTableCollection);
632  }
633 
634  assert(!spmdizationMap.contains(shardOp.getResult()));
635  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
636  return success();
637 }
638 
639 static LogicalResult
640 spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
641  SymbolTableCollection &symbolTableCollection,
642  OpBuilder &builder) {
643  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
644  if (shardOp) {
645  return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
646  builder);
647  }
648 
649  SmallVector<Value> spmdizedOperands;
650  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
651  [&spmdizationMap](Value operand) {
652  assert(spmdizationMap.contains(operand));
653  return spmdizationMap.lookup(operand);
654  });
655  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
656  getResultShardings(op), spmdizationMap,
657  symbolTableCollection, builder);
658 }
659 
660 static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
661  SymbolTableCollection &symbolTableCollection,
662  OpBuilder &builder) {
663  SmallVector<Location> argLocations;
664  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
665  [](BlockArgument arg) { return arg.getLoc(); });
666  Block *newBlock = builder.createBlock(
667  block.getParent(), {},
668  shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
669  for (auto [unshardedBlockArg, spmdizedBlockArg] :
670  llvm::zip(block.getArguments(), newBlock->getArguments())) {
671  spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
672  }
673 
674  OpBuilder::InsertionGuard insertionGuard(builder);
675  builder.setInsertionPointToEnd(newBlock);
676  for (Operation &op : block.getOperations()) {
677  if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
678  builder))) {
679  return failure();
680  }
681  }
682 
683  return success();
684 }
685 
686 static LogicalResult
687 spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
688  SymbolTableCollection &symbolTableCollection) {
689  OpBuilder builder(op.getFunctionBody());
690 
691  // Snapshot the original blocks to not mess up the iteration when adding new
692  // blocks.
693  SmallVector<Block *> originalBlocks;
694  llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
695  [](Block &b) { return &b; });
696 
697  for (Block *block : originalBlocks) {
698  if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
699  builder))) {
700  return failure();
701  }
702  }
703 
704  for (Block *block : originalBlocks) {
705  block->erase();
706  }
707 
708  // Find a return op and change the function results signature to its operands
709  // signature.
710  Operation *returnOp = nullptr;
711  for (Block &block : op.getFunctionBody()) {
712  if (block.empty()) {
713  continue;
714  }
715 
716  if (block.back().hasTrait<OpTrait::ReturnLike>()) {
717  returnOp = &block.back();
718  break;
719  }
720  }
721  assert(returnOp);
722  op.setType(FunctionType::get(op->getContext(),
723  op.getFunctionBody().front().getArgumentTypes(),
724  returnOp->getOperandTypes()));
725 
726  return success();
727 }
728 
729 namespace {
730
// The spmdization pass: rewrites the operation it runs on (passed to
// spmdizeFuncOp as a function) into its per-device SPMD form, using the shard
// annotations in its body.
731 struct Spmdization : public impl::SpmdizationBase<Spmdization> {
// Fresh mapping/symbol-table state per run. Any spmdization failure is
// reported as a pass failure.
732  void runOnOperation() override {
733  IRMapping spmdizationMap;
734  SymbolTableCollection symbolTableCollection;
735  if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
736  symbolTableCollection))) {
737  return signalPassFailure();
738  }
739  }
740
// Registers the dialects whose ops the pass may create.
// NOTE(review): original line 742 is missing from this listing; it
// presumably registers additional dependent dialects (e.g. the resharding
// ones) — confirm.
741  void getDependentDialects(DialectRegistry &registry) const override {
743  registry.insert<mesh::MeshDialect>();
744  }
745 };
746
747 } // namespace
748 
749 } // namespace mlir::mesh
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:93
MLIRContext * getContext() const
Definition: Builders.h:55
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition: IRMapping.h:51
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:423
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_type_range getOperandTypes()
Definition: Operation.h:392
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
mesh::MeshShardingAttr MeshShardingAttr
static LogicalResult spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > handlePartialAxesDuringResharding(OpBuilder &builder, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
Definition: Spmdization.cpp:58
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
Definition: Spmdization.cpp:44
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:121
static std::optional< std::tuple< int64_t, MeshAxis > > detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:67
static SmallVector< MeshShardingAttr > getOperandShardings(Operation &op)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:112
static MeshShardingAttr targetShardingInUnsplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis)
TypedValue< ShapedType > reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static SmallVector< MeshShardingAttr > getResultShardings(Operation &op)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
void reshardingRegisterDependentDialects(DialectRegistry &registry)
static std::optional< std::tuple< int64_t, MeshAxis > > detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
int16_t MeshAxis
Definition: MeshOps.h:25
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:163
TypedValue< ShapedType > reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static TypedValue< ShapedType > reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static MeshShardingAttr targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static MeshShardingAttr targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static LogicalResult spmdizeOperation(Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This trait indicates that a terminator operation is "return-like".