MLIR  19.0.0git
Spmdization.cpp
Go to the documentation of this file.
1 //===- Spmdization.cpp --------------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <type_traits>
41 
42 namespace mlir::mesh {
43 
// Check that every target partial axis also appears among the source partial
// axes, i.e. the target's partial axes are a subset of the source's.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  for (const auto &targetAxis : targetAxes) {
    if (!sourceAxes.contains(targetAxis)) {
      return false;
    }
  }
  return true;
}
51 
52 // Return the reduced value and its corresponding sharding.
53 // Example:
54 // sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
55 // targetSharding = <@mesh_1d, [[]]>
56 // Then will apply all-reduce on the source value
57 // and return it with the sharding <@mesh_1d, [[0]]>.
58 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
60  MeshShardingAttr sourceSharding,
61  MeshShardingAttr targetSharding,
62  TypedValue<ShapedType> sourceShard) {
63  if (sourceSharding.getPartialAxes().empty() &&
64  targetSharding.getPartialAxes().empty()) {
65  return {sourceShard, sourceSharding};
66  }
67  assert(targetSharding.getPartialAxes().empty() ||
68  (!sourceSharding.getPartialAxes().empty() &&
69  sourceSharding.getPartialType() == targetSharding.getPartialType()));
70  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
71  using AxisSet = llvm::SmallDenseSet<Axis>;
72  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
73  sourceSharding.getPartialAxes().end());
74  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
75  targetSharding.getPartialAxes().end());
76  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
77  targetShardingPartialAxesSet));
78  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
79  llvm::copy_if(sourceShardingPartialAxesSet,
80  std::back_inserter(allReduceMeshAxes),
81  [&targetShardingPartialAxesSet](Axis a) {
82  return !targetShardingPartialAxesSet.contains(a);
83  });
84  if (allReduceMeshAxes.empty()) {
85  return {sourceShard, sourceSharding};
86  }
87 
88  builder.setInsertionPointAfterValue(sourceShard);
89  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
90  builder
91  .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
92  sourceSharding.getMesh().getLeafReference(),
93  allReduceMeshAxes, sourceShard,
94  sourceSharding.getPartialType())
95  .getResult());
96 
97  llvm::SmallVector<MeshAxis> remainingPartialAxes;
98  llvm::copy_if(sourceShardingPartialAxesSet,
99  std::back_inserter(allReduceMeshAxes),
100  [&targetShardingPartialAxesSet](Axis a) {
101  return targetShardingPartialAxesSet.contains(a);
102  });
103  MeshShardingAttr resultSharding =
104  MeshShardingAttr::get(builder.getContext(), sourceSharding.getMesh(),
105  sourceSharding.getSplitAxes(), remainingPartialAxes,
106  sourceSharding.getPartialType());
107  return {resultValue, resultSharding};
108 }
109 
110 static MeshShardingAttr
112  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
113  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
114  llvm::to_vector(sourceSharding.getSplitAxes());
115  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
116  splitTensorAxis) {
117  targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
118  }
119  auto targetSplitAxes =
120  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
121  targetSplitAxes.push_back(splitMeshAxis);
122  targetShardingSplitAxes[splitTensorAxis] =
123  MeshAxesAttr::get(ctx, targetSplitAxes);
124  return MeshShardingAttr::get(
125  ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
126  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
127 }
128 
129 // Split a replicated tensor along a mesh axis.
130 // e.g. [[0, 1]] -> [[0, 1, 2]].
131 // Returns the spmdized target value with its sharding.
132 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
134  MeshShardingAttr sourceSharding,
135  TypedValue<ShapedType> sourceShard, MeshOp mesh,
136  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
137  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
138  builder
139  .create<AllSliceOp>(sourceShard, mesh,
140  ArrayRef<MeshAxis>(splitMeshAxis),
141  splitTensorAxis)
142  .getResult());
144  builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
145  return {targetShard, targetSharding};
146 }
147 
148 // Detect if the resharding is of type e.g.
149 // [[0, 1]] -> [[0, 1, 2]].
150 // If detected, returns the corresponding tensor axis mesh axis pair.
151 // Does not detect insertions like
152 // [[0, 1]] -> [[0, 2, 1]].
153 static std::optional<std::tuple<int64_t, MeshAxis>>
155  MeshShardingAttr targetSharding) {
156  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
157  ++tensorAxis) {
158  if (sourceSharding.getSplitAxes().size() > tensorAxis) {
159  if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
160  targetSharding.getSplitAxes()[tensorAxis].size()) {
161  continue;
162  }
163  if (!llvm::equal(
164  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
165  llvm::make_range(
166  targetSharding.getSplitAxes()[tensorAxis]
167  .asArrayRef()
168  .begin(),
169  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
170  1))) {
171  continue;
172  }
173  } else {
174  if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
175  continue;
176  }
177  }
178  return std::make_tuple(
179  tensorAxis,
180  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
181  }
182  return std::nullopt;
183 }
184 
185 static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
187  MeshShardingAttr sourceSharding,
188  MeshShardingAttr targetSharding,
189  TypedValue<ShapedType> sourceShard) {
190  if (auto detectRes =
191  detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
192  auto [tensorAxis, meshAxis] = detectRes.value();
193  return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
194  tensorAxis, meshAxis);
195  }
196 
197  return std::nullopt;
198 }
199 
200 // Detect if the resharding is of type e.g.
201 // [[0, 1, 2]] -> [[0, 1]].
202 // If detected, returns the corresponding tensor axis mesh axis pair.
203 static std::optional<std::tuple<int64_t, MeshAxis>>
205  MeshShardingAttr targetSharding) {
206  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
207  ++tensorAxis) {
208  if (targetSharding.getSplitAxes().size() > tensorAxis) {
209  if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
210  targetSharding.getSplitAxes()[tensorAxis].size() + 1)
211  continue;
212  if (!llvm::equal(
213  llvm::make_range(
214  sourceSharding.getSplitAxes()[tensorAxis]
215  .asArrayRef()
216  .begin(),
217  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
218  1),
219  targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
220  continue;
221  } else {
222  if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
223  continue;
224  }
225  return std::make_tuple(
226  tensorAxis,
227  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
228  }
229  return std::nullopt;
230 }
231 
232 static MeshShardingAttr
234  MeshShardingAttr sourceSharding,
235  int64_t splitTensorAxis) {
236  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
237  llvm::to_vector(sourceSharding.getSplitAxes());
238  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
239  splitTensorAxis);
240  auto targetSplitAxes =
241  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
242 
243  targetSplitAxes.pop_back();
244  targetShardingSplitAxes[splitTensorAxis] =
245  MeshAxesAttr::get(ctx, targetSplitAxes);
246  return MeshShardingAttr::get(
247  ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
248  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
249 }
250 
252  ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
253  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
254  targetShape[splitTensorAxis] =
255  gatherDimension(targetShape[splitTensorAxis], splitCount);
256  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
257 }
258 
259 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
261  MeshShardingAttr sourceSharding,
262  ShapedType sourceUnshardedShape,
263  TypedValue<ShapedType> sourceShard, MeshOp mesh,
264  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
265  MLIRContext *ctx = builder.getContext();
266  builder.setInsertionPointAfterValue(sourceShard);
267 
268  MeshShardingAttr targetSharding =
269  targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
270  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
271  sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
272  Value allGatherResult = builder.create<AllGatherOp>(
273  RankedTensorType::get(allGatherResultShape.getShape(),
274  allGatherResultShape.getElementType()),
275  mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
276  APInt(64, splitTensorAxis));
277  ShapedType targetShape =
278  shardShapedType(sourceUnshardedShape, mesh, targetSharding);
279  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
280  builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
281  return {targetShard, targetSharding};
282 }
283 
284 static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
286  MeshShardingAttr sourceSharding,
287  MeshShardingAttr targetSharding,
288  ShapedType sourceUnshardedShape,
289  TypedValue<ShapedType> sourceShard) {
290  if (auto detectRes =
291  detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
292  auto [tensorAxis, meshAxis] = detectRes.value();
293  return unsplitLastAxisInResharding(builder, sourceSharding,
294  sourceUnshardedShape, sourceShard, mesh,
295  tensorAxis, meshAxis);
296  }
297 
298  return std::nullopt;
299 }
300 
301 // Detect if the resharding is of type e.g.
302 // [[0, 1], [2]] -> [[0], [1, 2]].
303 // Only moving the last axis counts.
304 // If detected, returns the corresponding (source_tensor_axis,
305 // target_tensor_axis, mesh_axis) tuple.
306 static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
308  MeshShardingAttr targetSharding) {
309  for (size_t sourceTensorAxis = 0;
310  sourceTensorAxis < sourceSharding.getSplitAxes().size();
311  ++sourceTensorAxis) {
312  for (size_t targetTensorAxis = 0;
313  targetTensorAxis < targetSharding.getSplitAxes().size();
314  ++targetTensorAxis) {
315  if (sourceTensorAxis == targetTensorAxis)
316  continue;
317  if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
318  targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
319  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
320  targetSharding.getSplitAxes()[targetTensorAxis]
321  .asArrayRef()
322  .back())
323  continue;
324  if (!llvm::equal(
325  llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
326  .asArrayRef()
327  .begin(),
328  sourceSharding.getSplitAxes()[sourceTensorAxis]
329  .asArrayRef()
330  .end() -
331  1),
332  llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
333  .asArrayRef()
334  .begin(),
335  targetSharding.getSplitAxes()[targetTensorAxis]
336  .asArrayRef()
337  .end() -
338  1)))
339  continue;
340  return std::make_tuple(
341  sourceTensorAxis, targetTensorAxis,
342  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
343  }
344  }
345  return std::nullopt;
346 }
347 
348 static MeshShardingAttr
350  int64_t sourceTensorAxis,
351  int64_t targetTensorAxis) {
352  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
353  llvm::to_vector(sourceSharding.getSplitAxes());
354  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
355  targetTensorAxis) {
356  targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
357  }
358 
359  auto sourceSplitAxes =
360  llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
361  assert(!sourceSplitAxes.empty());
362  auto meshAxis = sourceSplitAxes.back();
363  sourceSplitAxes.pop_back();
364  targetShardingSplitAxes[sourceTensorAxis] =
365  MeshAxesAttr::get(ctx, sourceSplitAxes);
366 
367  auto targetSplitAxes =
368  llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
369  targetSplitAxes.push_back(meshAxis);
370  targetShardingSplitAxes[targetTensorAxis] =
371  MeshAxesAttr::get(ctx, targetSplitAxes);
372 
373  return MeshShardingAttr::get(
374  ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
375  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
376 }
377 
378 static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
379  int64_t splitCount,
380  int64_t sourceTensorAxis,
381  int64_t targetTensorAxis) {
382  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
383  targetShape[sourceTensorAxis] =
384  gatherDimension(targetShape[sourceTensorAxis], splitCount);
385  targetShape[targetTensorAxis] =
386  shardDimension(targetShape[targetTensorAxis], splitCount);
387  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
388 }
389 
390 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
392  MeshShardingAttr sourceSharding,
393  ShapedType sourceUnshardedShape,
394  TypedValue<ShapedType> sourceShard,
395  int64_t sourceTensorAxis,
396  int64_t targetTensorAxis, MeshAxis meshAxis) {
397  MLIRContext *ctx = builder.getContext();
398  builder.setInsertionPointAfterValue(sourceShard);
399 
401  ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
402  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
403  sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
404  targetTensorAxis);
405  Value allToAllResult = builder.create<AllToAllOp>(
406  RankedTensorType::get(allToAllResultShape.getShape(),
407  allToAllResultShape.getElementType()),
408  mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
409  APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
410  ShapedType targetShape =
411  shardShapedType(sourceUnshardedShape, mesh, targetSharding);
412  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
413  builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
414  return {targetShard, targetSharding};
415 }
416 
417 static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
419  MeshShardingAttr sourceSharding,
420  MeshShardingAttr targetSharding,
421  ShapedType sourceUnshardedShape,
422  TypedValue<ShapedType> sourceShard) {
423  if (auto detectRes =
424  detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
425  auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
427  builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
428  sourceTensorAxis, targetTensorAxis, meshAxis);
429  }
430 
431  return std::nullopt;
432 }
433 
434 // Handles only resharding on a 1D mesh.
435 // Currently the sharded tensor axes must be exactly divisible by the single
436 // mesh axis size.
439  MeshShardingAttr sourceSharding,
440  MeshShardingAttr targetSharding,
441  TypedValue<ShapedType> sourceUnshardedValue,
442  TypedValue<ShapedType> sourceShard) {
443  assert(sourceShard.getType() ==
444  shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
445  [[maybe_unused]] ShapedType targetShardType =
446  shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
447  assert(sourceShard.getType().getRank() == targetShardType.getRank());
448  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
449 
450  auto [reducedSourceShard, reducedSourceSharding] =
451  handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
452  sourceShard);
453 
454  if (reducedSourceSharding == targetSharding) {
455  return reducedSourceShard;
456  }
457 
458  TypedValue<ShapedType> targetShard;
459  MeshShardingAttr actualTargetSharding;
460  if (auto tryRes = tryMoveLastSplitAxisInResharding(
461  builder, mesh, reducedSourceSharding, targetSharding,
462  sourceUnshardedValue.getType(), reducedSourceShard)) {
463  std::tie(targetShard, actualTargetSharding) = tryRes.value();
464  } else if (auto tryRes = trySplitLastAxisInResharding(
465  builder, mesh, reducedSourceSharding, targetSharding,
466  reducedSourceShard)) {
467  std::tie(targetShard, actualTargetSharding) = tryRes.value();
468  } else if (auto tryRes = tryUnsplitLastAxisInResharding(
469  builder, mesh, reducedSourceSharding, targetSharding,
470  sourceUnshardedValue.getType(), reducedSourceShard)) {
471  std::tie(targetShard, actualTargetSharding) = tryRes.value();
472  } else {
473  assert(false && "Did not find any pattern to apply.");
474  }
475 
476  assert(actualTargetSharding == targetSharding);
477  assert(targetShard.getType() == targetShardType);
478  return targetShard;
479 }
480 
482  MeshShardingAttr sourceSharding,
483  MeshShardingAttr targetSharding,
484  TypedValue<ShapedType> sourceUnshardedValue,
485  TypedValue<ShapedType> sourceShard) {
486  // Resort to handling only 1D meshes since the general case is complicated if
487  // it needs to be communication efficient in terms of minimizing the data
488  // transfered between devices.
489  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
490  sourceUnshardedValue, sourceShard);
491 }
492 
493 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
494  ShardOp target,
495  TypedValue<ShapedType> sourceShardValue) {
496  assert(!source.getAnnotateForUsers());
497  assert(target.getAnnotateForUsers());
498  assert(source.getResult() == target.getOperand());
499  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
500  return reshard(
501  implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
502  cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
503 }
504 
505 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
506  ShardOp target,
507  TypedValue<ShapedType> sourceShardValue,
508  SymbolTableCollection &symbolTableCollection) {
509  MeshOp srcMesh = getMesh(source, symbolTableCollection);
510  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
511  return reshard(builder, srcMesh, source, target, sourceShardValue);
512 }
513 
515  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
516 }
517 
518 #define GEN_PASS_DEF_SPMDIZATION
519 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
520 
522 
523 // Get the types of block arguments for an spmdized block.
524 // Reads the sharding annotations of the arguments to deduce the sharded types.
525 // Types that are not ranked tensors are left unchanged.
528  SymbolTableCollection &symbolTableCollection) {
529  SmallVector<Type> res;
530  llvm::transform(
531  block.getArguments(), std::back_inserter(res),
532  [&symbolTableCollection](BlockArgument arg) {
533  auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
534  if (!rankedTensorArg) {
535  return arg.getType();
536  }
537 
538  assert(rankedTensorArg.hasOneUse());
539  Operation *useOp = *rankedTensorArg.getUsers().begin();
540  ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
541  assert(shardOp);
542  MeshOp mesh = getMesh(shardOp, symbolTableCollection);
543  return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
544  shardOp.getShardAttr()));
545  });
546  return res;
547 }
548 
550  Operation &op, ArrayRef<Value> spmdizedOperands,
551  ArrayRef<MeshShardingAttr> operandShardings,
552  ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
553  SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
554  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
555  if (!shardingInterface) {
556  // If there is no sharding interface we are conservative and assume that
557  // the op should be fully replicated no all devices.
558  spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
559  resultShardings, spmdizationMap,
560  symbolTableCollection, builder);
561  } else {
562  if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
563  resultShardings, spmdizationMap,
564  symbolTableCollection, builder))) {
565  return failure();
566  }
567  }
568 
569  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
570  return spmdizationMap.contains(result);
571  }));
572 
573  return success();
574 }
575 
576 // Retrieve the sharding annotations for the operands of the given operation.
577 // If the type is not a ranked tensor it is not require to have an annotation.
580  res.reserve(op.getNumOperands());
581  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
582  TypedValue<RankedTensorType> rankedTensor =
583  dyn_cast<TypedValue<RankedTensorType>>(operand);
584  if (!rankedTensor) {
585  return MeshShardingAttr();
586  }
587 
588  Operation *definingOp = operand.getDefiningOp();
589  assert(definingOp);
590  ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
591  return shardOp.getShard();
592  });
593  return res;
594 }
595 
596 // Retrieve the sharding annotations for the results of the given operation.
597 // If the type is not a ranked tensor it is not require to have an annotation.
600  res.reserve(op.getNumResults());
601  llvm::transform(op.getResults(), std::back_inserter(res),
602  [](OpResult result) {
603  TypedValue<RankedTensorType> rankedTensor =
604  dyn_cast<TypedValue<RankedTensorType>>(result);
605  if (!rankedTensor) {
606  return MeshShardingAttr();
607  }
608 
609  assert(result.hasOneUse());
610  Operation *userOp = *result.getUsers().begin();
611  ShardOp shardOp = llvm::cast<ShardOp>(userOp);
612  return shardOp.getShard();
613  });
614  return res;
615 }
616 
// Spmdize a shard op: either forward the already-spmdized operand (single
// annotation) or insert a resharding (two chained shard ops with different
// shardings).
static LogicalResult
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  Value targetSpmdValue;

  // Check if 2 shard ops are chained. If not there is no need for resharding
  // as the source and target share the same sharding.
  ShardOp srcShardOp =
      dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
  if (!srcShardOp) {
    targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
  } else {
    // Insert resharding.
    assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
    TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
        spmdizationMap.lookup(srcShardOp.getOperand()));
    targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                              symbolTableCollection);
  }

  // The shard op itself produces no code; its result simply maps to the
  // spmdized value.
  assert(!spmdizationMap.contains(shardOp.getResult()));
  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
  return success();
}
642 
643 static LogicalResult
644 spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
645  SymbolTableCollection &symbolTableCollection,
646  OpBuilder &builder) {
647  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
648  if (shardOp) {
649  return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
650  builder);
651  }
652 
653  SmallVector<Value> spmdizedOperands;
654  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
655  [&spmdizationMap](Value operand) {
656  assert(spmdizationMap.contains(operand));
657  return spmdizationMap.lookup(operand);
658  });
659  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
660  getResultShardings(op), spmdizationMap,
661  symbolTableCollection, builder);
662 }
663 
664 static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
665  SymbolTableCollection &symbolTableCollection,
666  OpBuilder &builder) {
667  SmallVector<Location> argLocations;
668  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
669  [](BlockArgument arg) { return arg.getLoc(); });
670  Block *newBlock = builder.createBlock(
671  block.getParent(), {},
672  shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
673  for (auto [unshardedBlockArg, spmdizedBlockArg] :
674  llvm::zip(block.getArguments(), newBlock->getArguments())) {
675  spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
676  }
677 
678  OpBuilder::InsertionGuard insertionGuard(builder);
679  builder.setInsertionPointToEnd(newBlock);
680  for (Operation &op : block.getOperations()) {
681  if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
682  builder))) {
683  return failure();
684  }
685  }
686 
687  return success();
688 }
689 
// Spmdize an entire function: rebuild each block in sharded form, erase the
// original blocks, and update the function signature to match the new
// argument and return types.
static LogicalResult
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
              SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks to not mess up the iteration when adding new
  // blocks.
  SmallVector<Block *> originalBlocks;
  llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
                  [](Block &b) { return &b; });

  for (Block *block : originalBlocks) {
    if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
                            builder))) {
      return failure();
    }
  }

  // All ops have been recreated in the new blocks; drop the originals.
  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find a return op and change the function results signature to its operands
  // signature.
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    if (block.empty()) {
      continue;
    }

    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
      returnOp = &block.back();
      break;
    }
  }
  assert(returnOp);
  op.setType(FunctionType::get(op->getContext(),
                               op.getFunctionBody().front().getArgumentTypes(),
                               returnOp->getOperandTypes()));

  return success();
}
732 
733 namespace {
734 
735 struct Spmdization : public impl::SpmdizationBase<Spmdization> {
736  void runOnOperation() override {
737  IRMapping spmdizationMap;
738  SymbolTableCollection symbolTableCollection;
739  if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
740  symbolTableCollection))) {
741  return signalPassFailure();
742  }
743  }
744 
745  void getDependentDialects(DialectRegistry &registry) const override {
747  registry.insert<mesh::MeshDialect>();
748  }
749 };
750 
751 } // namespace
752 
753 } // namespace mlir::mesh
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:30
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
OpListType & getOperations()
Definition: Block.h:134
BlockArgListType getArguments()
Definition: Block.h:84
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:93
MLIRContext * getContext() const
Definition: Builders.h:55
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition: IRMapping.h:51
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:423
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_type_range getOperandTypes()
Definition: Operation.h:392
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
mesh::MeshShardingAttr MeshShardingAttr
static LogicalResult spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > handlePartialAxesDuringResharding(OpBuilder &builder, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
Definition: Spmdization.cpp:59
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
Definition: Spmdization.cpp:45
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:110
static std::optional< std::tuple< int64_t, MeshAxis > > detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:57
static SmallVector< MeshShardingAttr > getOperandShardings(Operation &op)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:101
static MeshShardingAttr targetShardingInUnsplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis)
TypedValue< ShapedType > reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static SmallVector< MeshShardingAttr > getResultShardings(Operation &op)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
void reshardingRegisterDependentDialects(DialectRegistry &registry)
static std::optional< std::tuple< int64_t, MeshAxis > > detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
int16_t MeshAxis
Definition: MeshOps.h:25
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:162
TypedValue< ShapedType > reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static TypedValue< ShapedType > reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static MeshShardingAttr targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static MeshShardingAttr targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static LogicalResult spmdizeOperation(Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This trait indicates that a terminator operation is "return-like".