MLIR  21.0.0git
ShardingInterface.cpp
Go to the documentation of this file.
1 //===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

#include <utility>
22 
23 #define DEBUG_TYPE "sharding-interface"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
25 
26 using namespace mlir;
27 using namespace mlir::mesh;
28 
29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
30 
31 //===----------------------------------------------------------------------===//
32 // common util functions
33 //===----------------------------------------------------------------------===//
34 
35 static LogicalResult
37  SmallVectorImpl<bool> &seenIds) {
38  switch (expr.getKind()) {
39  case AffineExprKind::Add: {
40  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
41  AffineExpr lhs = binOpExpr.getLHS();
42  AffineExpr rhs = binOpExpr.getRHS();
43  if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
44  return failure();
45  if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
46  return failure();
47  return success();
48  }
49  case AffineExprKind::Mul: {
50  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
51  AffineExpr lhs = binOpExpr.getLHS();
52  AffineExpr rhs = binOpExpr.getRHS();
53  AffineExpr dimExpr;
54  if (lhs.getKind() == AffineExprKind::DimId &&
56  dimExpr = lhs;
57  } else if (rhs.getKind() == AffineExprKind::DimId &&
59  dimExpr = rhs;
60  } else
61  return failure();
62  unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
63  if ((size_t)position >= seenIds.size() || seenIds[position])
64  return failure();
65  seenIds[position] = true;
66  return success();
67  }
68  case AffineExprKind::DimId: {
69  unsigned position = cast<AffineDimExpr>(expr).getPosition();
70  if ((size_t)position >= seenIds.size() || seenIds[position])
71  return failure();
72  seenIds[position] = true;
73  return success();
74  }
75  default:
76  return failure();
77  }
78 }
79 
80 static FailureOr<llvm::SmallSet<unsigned, 2>>
81 checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
82  SmallVector<bool> seenIds(numDims, false);
83  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
84  return failure();
85 
86  llvm::SmallSet<unsigned, 2> positions;
87  for (auto it : llvm::enumerate(seenIds)) {
88  if (it.value())
89  positions.insert((unsigned)it.index());
90  }
91  return positions;
92 }
93 
94 template <typename T>
98  for (const auto &v : vec) {
99  res.emplace_back(MeshAxesAttr::get(ctxt, v));
100  }
101  return res;
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // mesh::getMeshSharding
106 //===----------------------------------------------------------------------===//
107 
108 FailureOr<std::pair<bool, MeshSharding>>
110  Value val = cast<Value>(result);
111  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
112  auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
113  if (!shardOp)
114  return false;
115  return !shardOp.getAnnotateForUsers();
116  });
117 
118  if (anyShardedForDef) {
119  // expected to have exact one use if it has a use of `mesh.shard` without
120  // unit attr annotate_for_users
121  if (!val.hasOneUse())
122  return failure();
123  auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
124  return std::make_pair(false, MeshSharding(shardOp.getSharding()));
125  }
126 
127  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
128  auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
129  if (!shardOp)
130  return false;
131  return shardOp.getAnnotateForUsers();
132  });
133  if (anyShardedForUsers) {
134  SmallVector<ShardOp> shardOps;
135  for (Operation *user : val.getUsers()) {
136  ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
137  if (shardOp)
138  shardOps.push_back(shardOp);
139  }
140  MeshSharding shardForDef = shardOps[0].getSharding();
141  for (size_t i = 1; i < shardOps.size(); ++i) {
142  // TODO: Deduce a reasonable mesh sharding attr for def when they are
143  // different
144  assert(shardForDef == shardOps[i].getSharding() &&
145  "only support all shard ops have the same mesh sharding attr");
146  }
147  return std::make_pair(true, shardForDef);
148  }
149  return failure();
150 }
151 
152 FailureOr<std::pair<bool, MeshSharding>>
154  Value val = opOperand.get();
155  if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
156  return std::make_pair(shardOp.getAnnotateForUsers(),
157  MeshSharding(shardOp.getSharding()));
158 
159  return failure();
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // ShardingInterface::verifyShardingInterfaceImpl
164 //===----------------------------------------------------------------------===//
165 
166 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
167  Operation *op = getOperation();
168 
169  // check operands and results type
170  for (Type type : op->getOperandTypes())
171  if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
172  return failure();
173  for (Type type : op->getResultTypes())
174  if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
175  return failure();
176 
177  // check maps
178  SmallVector<AffineMap> maps = getIndexingMaps();
179  if (maps.empty())
180  return failure();
181  unsigned numOperands = op->getNumOperands();
182  unsigned numResults = op->getNumResults();
183  if (numOperands + numResults != maps.size())
184  return failure();
185 
186  for (OpResult result : op->getResults()) {
187  auto resultType = dyn_cast<RankedTensorType>(result.getType());
188  if (!resultType)
189  return failure();
190  AffineMap map = maps[numOperands + result.getResultNumber()];
191  if (!map.isProjectedPermutation()) {
192  return failure();
193  }
194  }
195 
196  return success();
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // ShardingInterface::printLoopTypesAndIndexingMaps
201 //===----------------------------------------------------------------------===//
202 
203 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
204  os << "print loop types and indexing maps for: \n";
205  getOperation()->print(os);
206  os << "\n";
207  os << "loop types: [";
208  for (utils::IteratorType type : getLoopIteratorTypes()) {
209  os << stringifyEnum(type) << " ";
210  }
211  os << "]\n";
212  os << "indexing maps: \n";
213  for (AffineMap map : getIndexingMaps())
214  os << map << "\n";
215  os << "\n";
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // detail::defaultGetShardingOption
220 //===----------------------------------------------------------------------===//
221 
222 namespace {
223 
224 // Update the given `shardingOption` according to `meshAxes` and `loopIdx`
225 static LogicalResult fillShardingOption(Operation *op,
226  ShardingOption &shardingOption,
227  FlatSymbolRefAttr mesh,
228  ArrayRef<MeshAxis> meshAxes,
229  unsigned loopIdx) {
230  if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
231  (!shardingOption.shardingArray[loopIdx].empty() &&
232  shardingOption.shardingArray[loopIdx] != meshAxes)) {
233  LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
234  << loopIdx << "\n");
235  return failure();
236  }
237  for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
238  if (i == loopIdx)
239  continue;
240 
241  for (MeshAxis axis : meshAxes) {
242  if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
243  LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
244  << axis << " duplicate");
245  return failure();
246  }
247  }
248  }
249  if (mesh)
250  shardingOption.mesh = mesh;
251  if (shardingOption.shardingArray[loopIdx].empty())
252  shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
253  meshAxes.end());
254  return success();
255 }
256 
257 } // namespace
258 
259 FailureOr<ShardingOption>
261  ArrayRef<MeshSharding> operandShardings,
262  ArrayRef<MeshSharding> resultShardings) {
263  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
264  ShardingOption shardingOption;
265 
266  if (failed(shardingOp.verifyShardingInterfaceImpl()))
267  return op->emitOpError() << "invalid sharding interface implementation";
269  shardingOp.getLoopIteratorTypes();
270  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
271  unsigned numOperands = op->getNumOperands();
272  shardingOption.shardingArray.resize(loopTypes.size());
273  llvm::SmallVector<MeshAxis> partialMeshAxes;
274  llvm::SmallSet<unsigned, 4> visitedLoopIndices;
275  bool anyShardingInResultsOrOperands = false;
276 
277  // 1. Fill sharding option based on op results
278  for (auto shardingIt : llvm::enumerate(resultShardings)) {
279  MeshSharding shardAttr = shardingIt.value();
280  if (!shardAttr)
281  continue;
282  AffineMap map = maps[numOperands + shardingIt.index()];
283  anyShardingInResultsOrOperands = true;
284  if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) {
285  shardingOption.mesh = shardAttr.getMeshAttr();
286  } else {
287  // Handle the split axes: calculate the corresponding loop index for each
288  // split axes sub-array, and then store the sub-array to
289  // shardingOption[index]
290  for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
291  AffineExpr expr = std::get<0>(it);
292  ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
293  auto dim = cast<AffineDimExpr>(expr);
294  unsigned index = dim.getPosition();
295  visitedLoopIndices.insert(index);
296  if (failed(fillShardingOption(op, shardingOption,
297  shardAttr.getMeshAttr(), axes, index)))
298  return failure();
299  }
300  }
301 
302  // Handle the partial axes: at this stage, the exact loop index/indices
303  // cannot be decided because there could be multiple reduction loops.
304  ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
305  if (!partialAxes.empty()) {
306  if (!partialMeshAxes.empty())
307  return op->emitOpError() << "at most one result with partial axes is "
308  "supported at present";
309  partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
310  // Add all the reduction loop indices to `visitedLoopIndices` if
311  // `partialAxes` is not empty
312  for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
313  if (isReductionLoop(loopTypes[loopIdx]))
314  visitedLoopIndices.insert(loopIdx);
315  }
316  }
317  }
318 
319  // 2. Fill sharding option based on operands
320  for (auto shardingIt : llvm::enumerate(operandShardings)) {
321  MeshSharding shardAttr = shardingIt.value();
322  if (!shardAttr)
323  continue;
324 
325  anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty();
326  AffineMap map = maps[shardingIt.index()];
327  unsigned numDims = map.getNumDims();
328 
329  // Handle the split axes. Partial axes don't need to be handled because they
330  // only affect the defining op of the operand.
331  //
332  // TODO: Change to process the operands with single loop index first and
333  // then the operands with multiple loop indices.
334  for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
335  AffineExpr expr = std::get<0>(it);
336  ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
337  FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
338  checkOperandAffineExpr(expr, numDims);
339  if (failed(loopIndices))
340  return op->emitOpError()
341  << "operand's affine expression is restricted to const_i * "
342  "dim_i + const_j + dim_j + ...";
343  if (loopIndices->empty())
344  continue;
345  if (loopIndices->size() == 1) {
346  unsigned loopIdx = *loopIndices->begin();
347  visitedLoopIndices.insert(loopIdx);
348  if (failed(fillShardingOption(op, shardingOption,
349  shardAttr.getMeshAttr(), axes, loopIdx)))
350  return failure();
351  }
352  // If multiple loop indices correspond to a dimension of an operand, it is
353  // difficult to infer which loop indices are responsible for sharding.
354  // Therefore, the exact loop index must be specified by others.
355  if (loopIndices->size() > 1) {
356  bool seenLoopIndices = false;
357  for (unsigned loopIdx : *loopIndices) {
358  if (visitedLoopIndices.contains(loopIdx)) {
359  seenLoopIndices = true;
360  break;
361  }
362  }
363  if (!seenLoopIndices)
364  return op->emitOpError()
365  << "the operand " << shardingIt.index()
366  << " has multiple loop indices in a dimension, but none of "
367  "them could be found in the exactly specified annotation "
368  "of op results or operands.";
369  }
370  }
371  }
372 
373  // 3. Finalize sharding option
374  if (!partialMeshAxes.empty()) {
375  bool anyNonEmptyReductionLoop = llvm::any_of(
376  llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
377  SmallVector<MeshAxis> &subArray = it.value();
378  int64_t idx = it.index();
379  return isReductionLoop(loopTypes[idx]) && !subArray.empty();
380  });
381  if (!anyNonEmptyReductionLoop) {
382  bool filled = false;
383  for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
384  if (isReductionLoop(loopTypes[idx])) {
385  std::ignore = fillShardingOption(op, shardingOption, nullptr,
386  partialMeshAxes, idx);
387  filled = true;
388  break;
389  }
390  }
391  if (!filled)
392  return op->emitOpError() << "no matched reduction loop found for the "
393  "result's partial type";
394  }
395  }
397  if (!anyShardingInResultsOrOperands)
398  shardingOption.empty = true;
399  return shardingOption;
400 }
401 
402 // Get the sharding attributed for the given result and sharding option.
403 MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
405  ArrayRef<ReductionKind> reductionLoopKinds) {
406  auto resultType = cast<RankedTensorType>(result.getType());
407  SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
408  SmallVector<MeshAxis> partialAxes;
409 
410  // process the split axes
411  for (auto it : llvm::enumerate(map.getResults())) {
412  SmallVector<MeshAxis> tmp_axes;
413  AffineExpr expr = it.value();
414  // `expr` must be an `AffineDimExpr` because `map` is verified by
415  // isProjectedPermutation
416  auto dim = cast<AffineDimExpr>(expr);
417  unsigned loopIdx = dim.getPosition();
418  if (loopIdx < shardingOption.shardingArray.size())
419  splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
420  }
421 
422  // process the partial axes
423  // partialType will be ignored if partialAxes is empty
424  ReductionKind partialType = ReductionKind::Sum;
425  size_t reductionLoopKindsIdx = 0;
426  for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
427  utils::IteratorType iType = std::get<0>(it);
428  if (isReductionLoop(iType)) {
429  ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
430  ++reductionLoopKindsIdx;
431  if (!partialAxes.empty())
432  assert(partialType == curPartialType &&
433  "Only one reduction type is supported");
434  partialType = curPartialType;
435  const SmallVector<MeshAxis> &axis = std::get<1>(it);
436  partialAxes.append(axis);
437  }
438  }
439 
440  removeTrailingEmptySubArray(splitAxes);
441  return MeshSharding::get(shardingOption.mesh,
442  fromArrayOfVector(result.getContext(), splitAxes),
443  partialAxes, partialType);
444 }
445 
446 static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
447  const ShardingOption &shardingOption,
448  AffineMap map) {
449  Value operandValue = opOperand.get();
450  auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
451  if (!operandType) {
452  if (operandValue.getType().isIntOrIndexOrFloat())
453  return MeshSharding();
454  return failure();
455  }
456  // 0d tensors cannot be sharded and must get replicated
457  if (operandType.getRank() == 0) {
458  return MeshSharding(shardingOption.mesh);
459  }
460  SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
461  unsigned numDims = map.getNumDims();
462  for (auto it : llvm::enumerate(map.getResults())) {
463  int64_t idx = it.index();
464  AffineExpr expr = it.value();
465  FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
466  checkOperandAffineExpr(expr, numDims);
467  if (failed(loopIndices))
468  return failure();
469  SmallVector<unsigned> shardedLoopIndices;
470  for (unsigned loopIdx : *loopIndices) {
471  if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
472  !shardingOption.shardingArray[loopIdx].empty())
473  shardedLoopIndices.push_back(loopIdx);
474  }
475  // mostly one sharded loop index is accepted
476  if (shardedLoopIndices.size() > 1)
477  return failure();
478  if (shardedLoopIndices.size() == 1) {
479  splitAxes[idx].append(
480  shardingOption.shardingArray[shardedLoopIndices[0]]);
481  }
482  }
483 
484  removeTrailingEmptySubArray(splitAxes);
485  return MeshSharding::get(
486  shardingOption.mesh,
487  fromArrayOfVector(opOperand.get().getContext(), splitAxes));
488 }
489 
490 FailureOr<std::vector<MeshSharding>>
492  Operation *op, const ShardingOption &shardingOption) {
493  std::vector<MeshSharding> res;
494 
495  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
497  shardingOp.getLoopIteratorTypes();
498  SmallVector<ReductionKind> reductionKinds =
499  shardingOp.getReductionLoopIteratorKinds();
500  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
501  unsigned numOperands = op->getNumOperands();
502 
503  for (OpOperand &opOperand : op->getOpOperands()) {
504  FailureOr<MeshSharding> shardingAttr = getSharding(
505  opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
506  if (failed(shardingAttr))
507  return failure();
508  res.push_back(*shardingAttr);
509  }
510 
511  for (OpResult result : op->getResults()) {
512  res.push_back(getSharding(result, shardingOption,
513  maps[numOperands + result.getResultNumber()],
514  loopTypes, reductionKinds));
515  }
516 
517  return res;
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // detail::defaultAddShardingAnnotations
522 //===----------------------------------------------------------------------===//
523 
524 // To add a `mesh.shard` op for the given result, based on the details provided
525 // in `shardingOption`, `map`, and `loopTypes`.
526 static LogicalResult addShardOp(OpBuilder &b, OpResult result,
527  const ShardingOption &shardingOption,
528  AffineMap map,
530  ArrayRef<ReductionKind> reductionLoopKinds) {
531  MeshSharding sharding =
532  getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);
533  maybeInsertTargetShardingAnnotation(sharding, result, b);
534 
535  return success();
536 }
537 
538 // To add a `mesh.shard` op for the given operand, based on the details provided
539 // in `shardingOption`, `map`, and `loopTypes`.
540 static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
541  const ShardingOption &shardingOption,
542  AffineMap map) {
543 
544  FailureOr<MeshSharding> sharding =
545  getSharding(opOperand, shardingOption, map);
546  if (failed(sharding)) {
547  return failure();
548  }
549  OpBuilder::InsertionGuard guard(b);
550  maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b);
551 
552  return success();
553 }
554 
556  Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
557  assert(!shardingOption.empty && shardingOption.mesh);
558 
559  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
561  shardingOp.getLoopIteratorTypes();
562  SmallVector<ReductionKind> reductionKinds =
563  shardingOp.getReductionLoopIteratorKinds();
564  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
565  unsigned numOperands = op->getNumOperands();
566 
567  // 1. add mesh.shard ops for all op results
568  for (OpResult result : op->getResults()) {
569  if (failed(addShardOp(b, result, shardingOption,
570  maps[numOperands + result.getResultNumber()],
571  loopTypes, reductionKinds)))
572  return failure();
573  }
574 
575  // 2. add mesh.shard ops for all operands
576  for (OpOperand &opOperand : op->getOpOperands()) {
577  if (failed(addShardOp(b, opOperand, shardingOption,
578  maps[opOperand.getOperandNumber()])))
579  return failure();
580  }
581 
582  return success();
583 }
584 
585 #ifndef NDEBUG
586 static bool
588  MeshSharding sharding) {
589  if (isa<RankedTensorType>(value.getType())) {
590  return isFullReplication(sharding);
591  }
592 
593  return !sharding;
594 }
595 
596 template <typename ValueRange, typename MeshShardingRage>
597 static bool
599  MeshShardingRage &&shardings) {
600  if (std::size(values) != std::size(shardings)) {
601  return false;
602  }
603  return llvm::all_of(
604  llvm::zip_equal(std::forward<ValueRange>(values),
605  std::forward<MeshShardingRage>(shardings)),
606  [](auto valueAndSharding) {
608  std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
609  });
610 }
611 #endif // NDEBUG
612 
614  Operation &op, ArrayRef<Value> spmdizedOperands,
615  ArrayRef<MeshSharding> operandShardings,
616  ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
617  SymbolTableCollection &symbolTable, OpBuilder &builder) {
618  assert(spmdizedOperands.size() == operandShardings.size());
620  operandShardings));
622  resultShardings));
623  // `clone` will populate the mapping of old to new results.
624  builder.clone(op, spmdizationMap);
625 }
626 
628  ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
629  SmallVector<std::optional<SmallVector<MeshAxis>>>
630  &meshAxesAssignmentForLoopIterators) {
631  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
632  unsigned loopIteratorIdx = affineDimExpr.getPosition();
633  if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
634  assert(llvm::equal(meshAxesAssignmentForTensorAxis,
635  *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
636  } else {
637  meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
638  llvm::to_vector(meshAxesAssignmentForTensorAxis);
639  }
640 }
641 
643  ArrayRef<MeshSharding> operandShardings,
644  ArrayRef<MeshSharding> resultShardings,
645  ArrayRef<utils::IteratorType> loopIteratorTypes,
646  ArrayRef<AffineMap> indexingMaps) {
648  meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
649  std::vector<MeshSharding> operatorAndResultShardings;
650  operatorAndResultShardings.reserve(operandShardings.size() +
651  resultShardings.size());
652  llvm::append_range(operatorAndResultShardings, operandShardings);
653  for (auto [sharding, affineMap] :
654  llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
655  if (!sharding) {
656  continue;
657  }
658  for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
659  llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
661  meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
662  meshAxisAssignmentForLoopIterators);
663  }
664  // Missing trailing split axes means replication on those tensor dimensions.
665  for (unsigned i = sharding.getSplitAxes().size();
666  i < affineMap.getNumResults(); ++i) {
668  {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
669  }
670  }
671 
672  ShardingArray res;
673  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
674  [](std::optional<SmallVector<MeshAxis>> &axes) {
675  if (!axes) {
676  return SmallVector<MeshAxis>();
677  };
678  return std::move(*axes);
679  });
680  return res;
681 }
682 
684  ArrayRef<utils::IteratorType> loopIteratorTypes,
685  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
686  for (auto [loopIteratorType, meshAxisAssignment] :
687  llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
688  if (loopIteratorType == utils::IteratorType::reduction &&
689  !meshAxisAssignment.empty()) {
690  return true;
691  }
692  }
693  return false;
694 }
695 
697  ArrayRef<utils::IteratorType> loopIteratorTypes,
698  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
699  SmallVector<MeshAxis> meshAxes;
700  for (auto [loopIteratorType, meshAxisAssignment] :
701  llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
702  if (loopIteratorType == utils::IteratorType::reduction) {
703  llvm::append_range(meshAxes, meshAxisAssignment);
704  }
705  }
706  return meshAxes;
707 }
708 
710  Operation &op, ArrayRef<Value> spmdizedOperands,
711  ArrayRef<MeshSharding> operandShardings,
712  ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
713  SymbolTableCollection &symbolTable, OpBuilder &builder) {
714  // `clone` will populate the mapping of old to new results.
715  Operation *newOp = builder.clone(op, spmdizationMap);
716  // Set the result types to the sharded counterparts.
717  for (auto [oldResult, newResult, sharding] :
718  llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
719  newResult.setType(shardType(
720  newResult.getType(),
721  getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
722  }
723 }
SmallVector< MeshAxesAttr > fromArrayOfVector(MLIRContext *ctxt, const SmallVector< SmallVector< T >> &vec)
static void updateMeshAxisAssignmentForLoopIterators(ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)
static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, MeshShardingRage &&shardings)
static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshSharding sharding)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)
static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
#define DBGS()
static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)
MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:223
unsigned getPosition() const
Definition: AffineExpr.cpp:348
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
This class represents an operand of an operation.
Definition: Value.h:243
This is a value defined by a result of an operation.
Definition: Value.h:433
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumOperands()
Definition: Operation.h:346
operand_type_range getOperandTypes()
Definition: Operation.h:397
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:204
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:191
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:64
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:66
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:67
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition: MeshOps.cpp:780
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
mesh::ReductionKind ReductionKind
mesh::MeshSharding MeshSharding
FailureOr< std::vector< MeshSharding > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:120
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:329
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:270
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder, ShardOp &newShardOp)
Definition: MeshOps.cpp:278
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
bool isReductionLoop(utils::IteratorType iType)
Definition: MeshOps.h:100
bool isFullReplication(MeshSharding sharding)
Definition: MeshOps.h:112
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
int16_t MeshAxis
Definition: MeshOps.h:26
FailureOr< std::pair< bool, MeshSharding > > getMeshSharding(OpResult result)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Definition: MeshOps.h:106
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.