MLIR  20.0.0git
Spmdization.cpp
Go to the documentation of this file.
1 //===- Spmdization.cpp --------------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/Location.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "mlir/IR/Value.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Support/LLVM.h"
31 #include "llvm/ADT/APInt.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/Casting.h"
36 #include <iterator>
37 #include <optional>
38 #include <tuple>
39 #include <type_traits>
40 
41 namespace mlir::mesh {
42 
// Returns true when every axis listed in `targetAxes` is also present in
// `sourceAxes`, i.e. the target's partial axes form a subset of the source's.
// `sourceAxes` must provide a `contains(axis)` member; `targetAxes` only needs
// to be iterable.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  for (const auto &axis : targetAxes) {
    if (!sourceAxes.contains(axis))
      return false;
  }
  return true;
}
50 
51 // Return the reduced value and its corresponding sharding.
52 // Example:
53 // sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
54 // targetSharding = <@mesh_1d, [[]]>
55 // Then will apply all-reduce on the source value
56 // and return it with the sharding <@mesh_1d, [[0]]>.
57 static std::tuple<TypedValue<ShapedType>, MeshSharding>
59  MeshSharding sourceSharding,
60  MeshSharding targetSharding,
61  TypedValue<ShapedType> sourceShard) {
62  if (sourceSharding.getPartialAxes().empty() &&
63  targetSharding.getPartialAxes().empty()) {
64  return {sourceShard, sourceSharding};
65  }
66  assert(targetSharding.getPartialAxes().empty() ||
67  (!sourceSharding.getPartialAxes().empty() &&
68  sourceSharding.getPartialType() == targetSharding.getPartialType()));
69  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
70  using AxisSet = llvm::SmallDenseSet<Axis>;
71  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
72  sourceSharding.getPartialAxes().end());
73  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
74  targetSharding.getPartialAxes().end());
75  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
76  targetShardingPartialAxesSet));
77  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
78  llvm::copy_if(sourceShardingPartialAxesSet,
79  std::back_inserter(allReduceMeshAxes),
80  [&targetShardingPartialAxesSet](Axis a) {
81  return !targetShardingPartialAxesSet.contains(a);
82  });
83  if (allReduceMeshAxes.empty()) {
84  return {sourceShard, sourceSharding};
85  }
86 
87  builder.setInsertionPointAfterValue(sourceShard);
88  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
89  builder
90  .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
91  sourceSharding.getMeshAttr().getLeafReference(),
92  allReduceMeshAxes, sourceShard,
93  sourceSharding.getPartialType())
94  .getResult());
95 
96  llvm::SmallVector<MeshAxis> remainingPartialAxes;
97  llvm::copy_if(sourceShardingPartialAxesSet,
98  std::back_inserter(allReduceMeshAxes),
99  [&targetShardingPartialAxesSet](Axis a) {
100  return targetShardingPartialAxesSet.contains(a);
101  });
102  MeshSharding resultSharding = MeshSharding::get(
103  sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(),
104  remainingPartialAxes, sourceSharding.getPartialType());
105  return {resultValue, resultSharding};
106 }
107 
109  MeshSharding sourceSharding,
110  int64_t splitTensorAxis,
111  MeshAxis splitMeshAxis) {
  // Returns a copy of `sourceSharding` in which `splitMeshAxis` is appended to
  // the split axes of tensor axis `splitTensorAxis`. The mesh, partial axes and
  // partial type are carried over unchanged.
112  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
113  llvm::to_vector(sourceSharding.getSplitAxes());
  // Pad with empty split-axes entries so that `splitTensorAxis` is a valid
  // index even if the source lists fewer tensor axes.
114  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
115  splitTensorAxis) {
116  targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
117  }
118  auto targetSplitAxes =
119  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
120  targetSplitAxes.push_back(splitMeshAxis);
121  targetShardingSplitAxes[splitTensorAxis] =
122  MeshAxesAttr::get(ctx, targetSplitAxes);
123  return MeshSharding::get(
124  sourceSharding.getMeshAttr(), targetShardingSplitAxes,
125  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
126 }
127 
128 // Split a replicated tensor along a mesh axis.
129 // e.g. [[0, 1]] -> [[0, 1, 2]].
130 // Returns the spmdized target value with its sharding.
// The source need not be fully replicated: as the example shows, an already
// split tensor axis is split further by one more mesh axis. The slicing is
// done with an AllSliceOp over `splitMeshAxis` along tensor axis
// `splitTensorAxis`. The returned sharding is derived from `sourceSharding`,
// `splitTensorAxis` and `splitMeshAxis` (presumably by appending
// `splitMeshAxis` to that tensor axis's split axes -- confirm against the
// target-sharding helper).
// NOTE(review): unlike the unsplit/move helpers, the visible body does not call
// setInsertionPointAfterValue, so the slice is created at the builder's
// current insertion point -- confirm this is intended.
131 static std::tuple<TypedValue<ShapedType>, MeshSharding>
133  MeshSharding sourceSharding,
134  TypedValue<ShapedType> sourceShard, MeshOp mesh,
135  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
136  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
137  builder
138  .create<AllSliceOp>(sourceShard, mesh,
139  ArrayRef<MeshAxis>(splitMeshAxis),
140  splitTensorAxis)
141  .getResult());
143  builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
144  return {targetShard, targetSharding};
145 }
146 
147 // Detect if the resharding is of type e.g.
148 // [[0, 1]] -> [[0, 1, 2]].
149 // If detected, returns the corresponding tensor axis mesh axis pair.
150 // Does not detect insertions like
151 // [[0, 1]] -> [[0, 2, 1]].
152 static std::optional<std::tuple<int64_t, MeshAxis>>
154  MeshSharding targetSharding) {
155  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
156  ++tensorAxis) {
157  if (sourceSharding.getSplitAxes().size() > tensorAxis) {
158  if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
159  targetSharding.getSplitAxes()[tensorAxis].size()) {
160  continue;
161  }
162  if (!llvm::equal(
163  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
164  llvm::make_range(
165  targetSharding.getSplitAxes()[tensorAxis]
166  .asArrayRef()
167  .begin(),
168  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
169  1))) {
170  continue;
171  }
172  } else {
173  if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
174  continue;
175  }
176  }
177  return std::make_tuple(
178  tensorAxis,
179  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
180  }
181  return std::nullopt;
182 }
183 
184 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
186  MeshSharding sourceSharding,
187  MeshSharding targetSharding,
188  TypedValue<ShapedType> sourceShard) {
189  if (auto detectRes =
190  detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
191  auto [tensorAxis, meshAxis] = detectRes.value();
192  return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
193  tensorAxis, meshAxis);
194  }
195 
196  return std::nullopt;
197 }
198 
199 // Detect if the resharding is of type e.g.
200 // [[0, 1, 2]] -> [[0, 1]].
201 // If detected, returns the corresponding tensor axis mesh axis pair.
202 static std::optional<std::tuple<int64_t, MeshAxis>>
204  MeshSharding targetSharding) {
205  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
206  ++tensorAxis) {
207  if (targetSharding.getSplitAxes().size() > tensorAxis) {
208  if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
209  targetSharding.getSplitAxes()[tensorAxis].size() + 1)
210  continue;
211  if (!llvm::equal(
212  llvm::make_range(
213  sourceSharding.getSplitAxes()[tensorAxis]
214  .asArrayRef()
215  .begin(),
216  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
217  1),
218  targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
219  continue;
220  } else {
221  if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
222  continue;
223  }
224  return std::make_tuple(
225  tensorAxis,
226  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
227  }
228  return std::nullopt;
229 }
230 
232  MeshSharding sourceSharding,
233  int64_t splitTensorAxis) {
234  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
235  llvm::to_vector(sourceSharding.getSplitAxes());
236  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
237  splitTensorAxis);
238  auto targetSplitAxes =
239  llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
240 
241  targetSplitAxes.pop_back();
242  targetShardingSplitAxes[splitTensorAxis] =
243  MeshAxesAttr::get(ctx, targetSplitAxes);
244  return MeshSharding::get(
245  sourceSharding.getMeshAttr(), targetShardingSplitAxes,
246  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
247 }
248 
250  ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
251  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
252  targetShape[splitTensorAxis] =
253  gatherDimension(targetShape[splitTensorAxis], splitCount);
254  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
255 }
256 
257 static std::tuple<TypedValue<ShapedType>, MeshSharding>
259  MeshSharding sourceSharding,
260  ShapedType sourceUnshardedShape,
261  TypedValue<ShapedType> sourceShard, MeshOp mesh,
262  int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
263  MLIRContext *ctx = builder.getContext();
264  builder.setInsertionPointAfterValue(sourceShard);
265 
266  MeshSharding targetSharding =
267  targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
268  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
269  sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
270  Value allGatherResult = builder.create<AllGatherOp>(
271  RankedTensorType::get(allGatherResultShape.getShape(),
272  allGatherResultShape.getElementType()),
273  mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
274  APInt(64, splitTensorAxis));
275  ShapedType targetShape =
276  shardShapedType(sourceUnshardedShape, mesh, targetSharding);
277  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
278  builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
279  return {targetShard, targetSharding};
280 }
281 
282 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
284  MeshSharding sourceSharding,
285  MeshSharding targetSharding,
286  ShapedType sourceUnshardedShape,
287  TypedValue<ShapedType> sourceShard) {
288  if (auto detectRes =
289  detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
290  auto [tensorAxis, meshAxis] = detectRes.value();
291  return unsplitLastAxisInResharding(builder, sourceSharding,
292  sourceUnshardedShape, sourceShard, mesh,
293  tensorAxis, meshAxis);
294  }
295 
296  return std::nullopt;
297 }
298 
299 // Detect if the resharding is of type e.g.
300 // [[0, 1], [2]] -> [[0], [1, 2]].
301 // Only moving the last axis counts.
302 // If detected, returns the corresponding (source_tensor_axis,
303 // target_tensor_axis, mesh_axis) tuple.
304 static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
306  MeshSharding targetSharding) {
307  for (size_t sourceTensorAxis = 0;
308  sourceTensorAxis < sourceSharding.getSplitAxes().size();
309  ++sourceTensorAxis) {
310  for (size_t targetTensorAxis = 0;
311  targetTensorAxis < targetSharding.getSplitAxes().size();
312  ++targetTensorAxis) {
313  if (sourceTensorAxis == targetTensorAxis)
314  continue;
315  if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
316  targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
317  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
318  targetSharding.getSplitAxes()[targetTensorAxis]
319  .asArrayRef()
320  .back())
321  continue;
322  if (!llvm::equal(
323  llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
324  .asArrayRef()
325  .begin(),
326  sourceSharding.getSplitAxes()[sourceTensorAxis]
327  .asArrayRef()
328  .end() -
329  1),
330  llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
331  .asArrayRef()
332  .begin(),
333  targetSharding.getSplitAxes()[targetTensorAxis]
334  .asArrayRef()
335  .end() -
336  1)))
337  continue;
338  return std::make_tuple(
339  sourceTensorAxis, targetTensorAxis,
340  sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
341  }
342  }
343  return std::nullopt;
344 }
345 
347  MeshSharding sourceSharding,
348  int64_t sourceTensorAxis,
349  int64_t targetTensorAxis) {
  // Returns a copy of `sourceSharding` in which the last split mesh axis of
  // tensor axis `sourceTensorAxis` is moved to the end of the split axes of
  // tensor axis `targetTensorAxis`. The mesh, partial axes and partial type
  // are carried over unchanged.
350  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
351  llvm::to_vector(sourceSharding.getSplitAxes());
  // Pad with empty split-axes entries so that `targetTensorAxis` is a valid
  // index.
352  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
353  targetTensorAxis) {
354  targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
355  }
356

  // `sourceTensorAxis` is not padded above: it must already index an existing
  // entry, and that entry must be non-empty (asserted below).
357  auto sourceSplitAxes =
358  llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
359  assert(!sourceSplitAxes.empty());
360  auto meshAxis = sourceSplitAxes.back();
361  sourceSplitAxes.pop_back();
362  targetShardingSplitAxes[sourceTensorAxis] =
363  MeshAxesAttr::get(ctx, sourceSplitAxes);
364

365  auto targetSplitAxes =
366  llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
367  targetSplitAxes.push_back(meshAxis);
368  targetShardingSplitAxes[targetTensorAxis] =
369  MeshAxesAttr::get(ctx, targetSplitAxes);
370

371  return MeshSharding::get(
372  sourceSharding.getMeshAttr(), targetShardingSplitAxes,
373  sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
374 }
375 
376 static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
377  int64_t splitCount,
378  int64_t sourceTensorAxis,
379  int64_t targetTensorAxis) {
380  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
381  targetShape[sourceTensorAxis] =
382  gatherDimension(targetShape[sourceTensorAxis], splitCount);
383  targetShape[targetTensorAxis] =
384  shardDimension(targetShape[targetTensorAxis], splitCount);
385  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
386 }
387 
// Moves the last split mesh axis `meshAxis` of tensor axis `sourceTensorAxis`
// to tensor axis `targetTensorAxis` with a single all-to-all over `meshAxis`,
// inserted right after `sourceShard` is defined. The all-to-all result is cast
// to the shard type of `sourceUnshardedShape` under the target sharding.
// Returns the spmdized target value with its sharding.
// NOTE(review): AllToAllOp is given APInt(targetTensorAxis) followed by
// APInt(sourceTensorAxis); presumably these are its split and concat axes --
// confirm against the AllToAllOp builder.
388 static std::tuple<TypedValue<ShapedType>, MeshSharding>
390  MeshSharding sourceSharding,
391  ShapedType sourceUnshardedShape,
392  TypedValue<ShapedType> sourceShard,
393  int64_t sourceTensorAxis,
394  int64_t targetTensorAxis, MeshAxis meshAxis) {
395  MLIRContext *ctx = builder.getContext();
396  builder.setInsertionPointAfterValue(sourceShard);
397

399  ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
400  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
401  sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
402  targetTensorAxis);
403  Value allToAllResult = builder.create<AllToAllOp>(
404  RankedTensorType::get(allToAllResultShape.getShape(),
405  allToAllResultShape.getElementType()),
406  mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
407  APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  // Cast the all-to-all result to the shard type implied by the target
  // sharding.
408  ShapedType targetShape =
409  shardShapedType(sourceUnshardedShape, mesh, targetSharding);
410  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
411  builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
412  return {targetShard, targetSharding};
413 }
414 
// Applies the "move last split axis" pattern when
// detectMoveLastSplitAxisInResharding matches the requested resharding.
// Returns std::nullopt when it does not match, in which case no IR is created.
415 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
417  MeshSharding sourceSharding,
418  MeshSharding targetSharding,
419  ShapedType sourceUnshardedShape,
420  TypedValue<ShapedType> sourceShard) {
421  if (auto detectRes =
422  detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
423  auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
425  builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
426  sourceTensorAxis, targetTensorAxis, meshAxis);
427  }
428

429  return std::nullopt;
430 }
431 
432 // Handles only resharding on a 1D mesh.
433 // Currently the sharded tensor axes must be exactly divisible by the single
434 // mesh axis size.
// Steps:
// 1. Resolve partial axes with handlePartialAxesDuringResharding. If that
//    already yields `targetSharding`, the reduced value is returned.
// 2. Otherwise, if neither sharding has static halo sizes or static sharded-dim
//    sizes, try in order: move-last-split-axis, split-last-axis and
//    unsplit-last-axis.
// If no pattern applies, the "Did not find any pattern to apply." assertion
// fires; in builds without assertions a null value would be returned.
437  MeshSharding sourceSharding, MeshSharding targetSharding,
438  TypedValue<ShapedType> sourceUnshardedValue,
439  TypedValue<ShapedType> sourceShard) {
440  assert(sourceShard.getType() ==
441  shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
442  [[maybe_unused]] ShapedType targetShardType =
443  shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
444  assert(sourceShard.getType().getRank() == targetShardType.getRank());
445  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
446

447  auto [reducedSourceShard, reducedSourceSharding] =
448  handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
449  sourceShard);
450

451  if (reducedSourceSharding == targetSharding) {
452  return reducedSourceShard;
453  }
454

455  TypedValue<ShapedType> targetShard;
456  MeshSharding actualTargetSharding;
  // The pattern-based resharding below only handles shardings without static
  // halo sizes and without static sharded-dim sizes.
457  if (reducedSourceSharding.getStaticHaloSizes().empty() &&
458  targetSharding.getStaticHaloSizes().empty() &&
459  reducedSourceSharding.getStaticShardedDimsSizes().empty() &&
460  targetSharding.getStaticShardedDimsSizes().empty()) {
461  if (auto tryRes = tryMoveLastSplitAxisInResharding(
462  builder, mesh, reducedSourceSharding, targetSharding,
463  sourceUnshardedValue.getType(), reducedSourceShard)) {
464  std::tie(targetShard, actualTargetSharding) = tryRes.value();
465  } else if (auto tryRes = trySplitLastAxisInResharding(
466  builder, mesh, reducedSourceSharding, targetSharding,
467  reducedSourceShard)) {
468  std::tie(targetShard, actualTargetSharding) = tryRes.value();
469  } else if (auto tryRes = tryUnsplitLastAxisInResharding(
470  builder, mesh, reducedSourceSharding, targetSharding,
471  sourceUnshardedValue.getType(), reducedSourceShard)) {
472  std::tie(targetShard, actualTargetSharding) = tryRes.value();
473  }
474  }
  // A single pattern must take the reduced source all the way to the target.
475  assert(targetShard && "Did not find any pattern to apply.");
476  assert(actualTargetSharding == targetSharding);
477  assert(targetShard.getType() == targetShardType);
478  return targetShard;
479 }
480 
482  MeshSharding sourceSharding,
483  MeshSharding targetSharding,
484  TypedValue<ShapedType> sourceUnshardedValue,
485  TypedValue<ShapedType> sourceShard) {
486  // Resort to handling only 1D meshes since the general case is complicated if
487  // it needs to be communication efficient in terms of minimizing the data
488  // transferred between devices.
  // reshardOn1DMesh asserts that `mesh` has rank 1.
489  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
490  sourceUnshardedValue, sourceShard);
491 }
492 
493 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
494  ShardOp target,
495  TypedValue<ShapedType> sourceShardValue) {
496  assert(source.getResult() == target.getSrc());
497  auto sourceSharding = source.getSharding();
498  auto targetSharding = target.getSharding();
499  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
500  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
501  cast<TypedValue<ShapedType>>(source.getSrc()),
502  sourceShardValue);
503 }
504 
505 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
506  ShardOp target,
507  TypedValue<ShapedType> sourceShardValue,
508  SymbolTableCollection &symbolTableCollection) {
509  MeshOp srcMesh = getMesh(source, symbolTableCollection);
510  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
511  return reshard(builder, srcMesh, source, target, sourceShardValue);
512 }
513 
  // Register the dialects whose ops the resharding code above creates: mesh
  // collectives (e.g. AllReduceOp, AllGatherOp, AllToAllOp) and tensor::CastOp.
515  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
516 }
517 
518 #define GEN_PASS_DEF_SPMDIZATION
519 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
520 
522 
523 // Get the types of block arguments for an spmdized block.
524 // Reads the sharding annotations of the arguments to deduce the sharded types.
525 // Types that are not ranked tensors are left unchanged.
528  SymbolTableCollection &symbolTableCollection) {
529  SmallVector<Type> res;
530  llvm::transform(
531  block.getArguments(), std::back_inserter(res),
532  [&symbolTableCollection](BlockArgument arg) {
533  auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
534  if (!rankedTensorArg) {
535  return arg.getType();
536  }
537 
538  assert(rankedTensorArg.hasOneUse());
539  Operation *useOp = *rankedTensorArg.getUsers().begin();
540  ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
541  assert(shardOp);
542  MeshOp mesh = getMesh(shardOp, symbolTableCollection);
543  return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
544  shardOp.getSharding()));
545  });
546  return res;
547 }
548 
549 void spmdizeTriviallyShardableOperation(Operation &op,
550  ArrayRef<Value> spmdizedOperands,
551  ArrayRef<MeshSharding> operandShardings,
552  ArrayRef<MeshSharding> resultShardings,
553  IRMapping &spmdizationMap,
554  SymbolTableCollection &symbolTable,
555  OpBuilder &builder);
556 
557 static LogicalResult spmdizeOperation(
558  Operation &op, ArrayRef<Value> spmdizedOperands,
559  ArrayRef<MeshSharding> operandShardings,
560  ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
561  SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
562  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
563  if (!shardingInterface) {
564  // If there is no sharding interface we are conservative and assume that
565  // the op should be fully replicated no all devices.
566  spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
567  resultShardings, spmdizationMap,
568  symbolTableCollection, builder);
569  } else {
570  if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
571  resultShardings, spmdizationMap,
572  symbolTableCollection, builder))) {
573  return failure();
574  }
575  }
576 
577  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
578  return spmdizationMap.contains(result);
579  }));
580 
581  return success();
582 }
583 
584 // Retrieve the sharding annotations for the operands of the given operation.
585 // If the type is not a ranked tensor it is not require to have an annotation.
586 static std::vector<MeshSharding> getOperandShardings(Operation &op) {
587  std::vector<MeshSharding> res;
588  res.reserve(op.getNumOperands());
589  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
590  TypedValue<RankedTensorType> rankedTensor =
591  dyn_cast<TypedValue<RankedTensorType>>(operand);
592  if (!rankedTensor) {
593  return MeshSharding();
594  }
595 
596  Operation *definingOp = operand.getDefiningOp();
597  assert(definingOp);
598  ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
599  return MeshSharding(shardOp.getSharding());
600  });
601  return res;
602 }
603 
604 // Retrieve the sharding annotations for the results of the given operation.
605 // If the type is not a ranked tensor it is not require to have an annotation.
606 static std::vector<MeshSharding> getResultShardings(Operation &op) {
607  std::vector<MeshSharding> res;
608  res.reserve(op.getNumResults());
609  llvm::transform(op.getResults(), std::back_inserter(res),
610  [](OpResult result) {
611  TypedValue<RankedTensorType> rankedTensor =
612  dyn_cast<TypedValue<RankedTensorType>>(result);
613  if (!rankedTensor) {
614  return MeshSharding();
615  }
616 
617  assert(result.hasOneUse());
618  Operation *userOp = *result.getUsers().begin();
619  ShardOp shardOp = llvm::cast<ShardOp>(userOp);
620  return MeshSharding(shardOp.getSharding());
621  });
622  return res;
623 }
624 
625 static LogicalResult
626 spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
627  SymbolTableCollection &symbolTableCollection,
628  OpBuilder &builder) {
629  Value targetSpmdValue;
630 
631  // Check if 2 shard ops are chained. If not there is no need for resharding
632  // as the source and target shared the same sharding.
633  ShardOp srcShardOp =
634  dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
635  if (!srcShardOp) {
636  targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
637  } else {
638  // Insert resharding.
639  TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
640  spmdizationMap.lookup(srcShardOp.getSrc()));
641  targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
642  symbolTableCollection);
643  }
644 
645  assert(!spmdizationMap.contains(shardOp.getResult()));
646  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
647  return success();
648 }
649 
650 static LogicalResult
651 spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
652  SymbolTableCollection &symbolTableCollection,
653  OpBuilder &builder) {
654  if (isa<ShardingOp>(op)) {
655  return success();
656  }
657 
658  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
659  if (shardOp) {
660  return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
661  builder);
662  }
663 
664  SmallVector<Value> spmdizedOperands;
665  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
666  [&spmdizationMap](Value operand) {
667  assert(spmdizationMap.contains(operand));
668  return spmdizationMap.lookup(operand);
669  });
670  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
671  getResultShardings(op), spmdizationMap,
672  symbolTableCollection, builder);
673 }
674 
675 static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
676  SymbolTableCollection &symbolTableCollection,
677  OpBuilder &builder) {
678  SmallVector<Location> argLocations;
679  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
680  [](BlockArgument arg) { return arg.getLoc(); });
681  Block *newBlock = builder.createBlock(
682  block.getParent(), {},
683  shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
684  for (auto [unshardedBlockArg, spmdizedBlockArg] :
685  llvm::zip(block.getArguments(), newBlock->getArguments())) {
686  spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
687  }
688 
689  OpBuilder::InsertionGuard insertionGuard(builder);
690  builder.setInsertionPointToEnd(newBlock);
691  for (Operation &op : block.getOperations()) {
692  if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
693  builder))) {
694  return failure();
695  }
696  }
697 
698  return success();
699 }
700 
701 static LogicalResult
702 spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
703  SymbolTableCollection &symbolTableCollection) {
704  OpBuilder builder(op.getFunctionBody());
705 
706  // Snapshot the original blocks to not mess up the iteration when adding new
707  // blocks.
708  SmallVector<Block *> originalBlocks;
709  llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
710  [](Block &b) { return &b; });
711 
712  for (Block *block : originalBlocks) {
713  if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
714  builder))) {
715  return failure();
716  }
717  }
718 
719  for (Block *block : originalBlocks) {
720  block->erase();
721  }
722 
723  // Find a return op and change the function results signature to its operands
724  // signature.
725  Operation *returnOp = nullptr;
726  for (Block &block : op.getFunctionBody()) {
727  if (block.empty()) {
728  continue;
729  }
730 
731  if (block.back().hasTrait<OpTrait::ReturnLike>()) {
732  returnOp = &block.back();
733  break;
734  }
735  }
736  assert(returnOp);
737  op.setType(FunctionType::get(op->getContext(),
738  op.getFunctionBody().front().getArgumentTypes(),
739  returnOp->getOperandTypes()));
740 
741  return success();
742 }
743 
744 namespace {
745 
// Pass driver: spmdizes the operation the pass runs on via spmdizeFuncOp.
// It uses a fresh value mapping and symbol-table cache for each run, and
// signals pass failure if spmdization fails.
746 struct Spmdization : public impl::SpmdizationBase<Spmdization> {
747  void runOnOperation() override {
748  IRMapping spmdizationMap;
749  SymbolTableCollection symbolTableCollection;
750  if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
751  symbolTableCollection))) {
752  return signalPassFailure();
753  }
754  }
755

  // Dialects whose ops this pass may create must be declared as dependent so
  // they are loaded before the pass runs.
756  void getDependentDialects(DialectRegistry &registry) const override {
758  registry.insert<mesh::MeshDialect>();
759  }
760 };
761 
762 } // namespace
763 
764 } // namespace mlir::mesh
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
MLIRContext * getContext() const
Definition: Builders.h:55
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition: IRMapping.h:51
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_type_range getOperandTypes()
Definition: Operation.h:392
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:63
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:65
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_sizes_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_sizes_={})
Definition: MeshOps.cpp:637
ReductionKind getPartialType() const
Definition: MeshOps.h:67
ArrayRef< int64_t > getStaticShardedDimsSizes() const
Definition: MeshOps.h:69
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:66
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: MeshOps.h:68
mesh::MeshSharding MeshSharding
static LogicalResult spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
static std::tuple< TypedValue< ShapedType >, MeshSharding > handlePartialAxesDuringResharding(OpBuilder &builder, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
Definition: Spmdization.cpp:58
SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
static std::vector< MeshSharding > getResultShardings(Operation &op)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
Definition: Spmdization.cpp:44
static std::optional< std::tuple< int64_t, MeshAxis > > detectUnsplitLastAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:179
static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis)
static std::tuple< TypedValue< ShapedType >, MeshSharding > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:253
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:123
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:170
TypedValue< ShapedType > reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
static TypedValue< ShapedType > reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
void reshardingRegisterDependentDialects(DialectRegistry &registry)
static std::tuple< TypedValue< ShapedType >, MeshSharding > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static std::vector< MeshSharding > getOperandShardings(Operation &op)
int16_t MeshAxis
Definition: MeshOps.h:25
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static LogicalResult spmdizeOperation(Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static std::tuple< TypedValue< ShapedType >, MeshSharding > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
TypedValue< ShapedType > reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::optional< std::tuple< int64_t, MeshAxis > > detectSplitLastAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This trait indicates that a terminator operation is "return-like".