MLIR  21.0.0git
ShardingInterface.cpp
Go to the documentation of this file.
1 //===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/Debug.h"
20 
21 #include <utility>
22 
23 #define DEBUG_TYPE "sharding-interface"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
25 
26 using namespace mlir;
27 using namespace mlir::mesh;
28 
29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
30 
31 //===----------------------------------------------------------------------===//
32 // common util functions
33 //===----------------------------------------------------------------------===//
34 
// Recursively checks that `expr` has a shape accepted for operand indexing
// and marks every loop dimension it references in `seenIds`.
//   - Add:   both sub-expressions must pass the same check.
//   - Mul:   one side must be a DimId; that dimension is marked.
//   - DimId: the dimension is marked.
// Every other expression kind fails. A dimension whose position lies outside
// `seenIds`, or that is already marked, also fails, so each dimension may be
// referenced at most once.
// NOTE(review): listing lines 36, 55 and 58 (the function name/first
// parameter and the second half of both Mul conditions) are not present in
// this copy; confirm against the upstream file before editing.
35 static LogicalResult
37  SmallVectorImpl<bool> &seenIds) {
38  switch (expr.getKind()) {
39  case AffineExprKind::Add: {
  // A sum is valid only if both addends are valid on their own.
40  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
41  AffineExpr lhs = binOpExpr.getLHS();
42  AffineExpr rhs = binOpExpr.getRHS();
43  if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
44  return failure();
45  if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
46  return failure();
47  return success();
48  }
49  case AffineExprKind::Mul: {
  // Pick whichever side of the product is the dimension expression.
50  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
51  AffineExpr lhs = binOpExpr.getLHS();
52  AffineExpr rhs = binOpExpr.getRHS();
53  AffineExpr dimExpr;
54  if (lhs.getKind() == AffineExprKind::DimId &&
56  dimExpr = lhs;
57  } else if (rhs.getKind() == AffineExprKind::DimId &&
59  dimExpr = rhs;
60  } else {
61  return failure();
62  }
63  unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
  // Reject out-of-range positions and dimensions that were already used.
64  if ((size_t)position >= seenIds.size() || seenIds[position])
65  return failure();
66  seenIds[position] = true;
67  return success();
68  }
69  case AffineExprKind::DimId: {
70  unsigned position = cast<AffineDimExpr>(expr).getPosition();
  // Same range/duplicate check as in the Mul case.
71  if ((size_t)position >= seenIds.size() || seenIds[position])
72  return failure();
73  seenIds[position] = true;
74  return success();
75  }
76  default:
77  return failure();
78  }
79 }
80 
81 static FailureOr<llvm::SmallSet<unsigned, 2>>
82 checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
83  SmallVector<bool> seenIds(numDims, false);
84  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
85  return failure();
86 
87  llvm::SmallSet<unsigned, 2> positions;
88  for (auto it : llvm::enumerate(seenIds)) {
89  if (it.value())
90  positions.insert((unsigned)it.index());
91  }
92  return positions;
93 }
94 
// Builds one MeshAxesAttr (created in `ctxt`) per inner vector of `vec`,
// appends them to `res` in order and returns `res`.
// NOTE(review): listing lines 96-98 (the signature and, presumably, the
// declaration of `res`) are not present in this copy.
95 template <typename T>
99  for (const auto &v : vec) {
100  res.emplace_back(MeshAxesAttr::get(ctxt, v));
101  }
102  return res;
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // mesh::getMeshSharding
107 //===----------------------------------------------------------------------===//
108 
// Looks up the sharding attached to `result` through its `mesh.shard` users.
// Returns a pair {annotateForUsers, sharding}:
//   - {false, sharding} when a shard op without `annotate_for_users` uses the
//     value; this requires the value to have exactly one use, else failure.
//   - {true, sharding} when only shard ops with `annotate_for_users` use it;
//     the sharding of the first such op is returned.
// Fails when no `mesh.shard` op uses the value.
// NOTE(review): listing line 110 (function name and parameter) is not present
// in this copy.
109 FailureOr<std::pair<bool, MeshSharding>>
111  Value val = cast<Value>(result);
112  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
113  auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
114  if (!shardOp)
115  return false;
116  return !shardOp.getAnnotateForUsers();
117  });
118 
119  if (anyShardedForDef) {
120  // expected to have exact one use if it has a use of `mesh.shard` without
121  // unit attr annotate_for_users
122  if (!val.hasOneUse())
123  return failure();
124  auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
125  return std::make_pair(false, MeshSharding(shardOp.getSharding()));
126  }
127 
128  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
129  auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
130  if (!shardOp)
131  return false;
132  return shardOp.getAnnotateForUsers();
133  });
134  if (anyShardedForUsers) {
  // Gather every shard-op user; non-shard users are skipped.
135  SmallVector<ShardOp> shardOps;
136  for (Operation *user : val.getUsers()) {
137  ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
138  if (shardOp)
139  shardOps.push_back(shardOp);
140  }
141  MeshSharding shardForDef = shardOps[0].getSharding();
  // NOTE: the consistency check below is an assert only; when asserts are
  // compiled out, differing shardings go undetected and the first is used.
142  for (size_t i = 1; i < shardOps.size(); ++i) {
143  // TODO: Deduce a reasonable mesh sharding attr for def when they are
144  // different
145  assert(shardForDef == shardOps[i].getSharding() &&
146  "only support all shard ops have the same mesh sharding attr");
147  }
148  return std::make_pair(true, shardForDef);
149  }
150  return failure();
151 }
152 
// Looks up the sharding of an operand from the op that defines its value.
// If the value is produced by a `mesh.shard` op, returns that op's
// `annotate_for_users` flag together with its sharding; otherwise fails.
// NOTE(review): listing line 154 (function name and parameter) is not present
// in this copy.
153 FailureOr<std::pair<bool, MeshSharding>>
155  Value val = opOperand.get();
156  if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
157  return std::make_pair(shardOp.getAnnotateForUsers(),
158  MeshSharding(shardOp.getSharding()));
159 
160  return failure();
161 }
162 
163 //===----------------------------------------------------------------------===//
164 // ShardingInterface::verifyShardingInterfaceImpl
165 //===----------------------------------------------------------------------===//
166 
167 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
168  Operation *op = getOperation();
169 
170  // check operands and results type
171  for (Type type : op->getOperandTypes())
172  if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
173  return failure();
174  for (Type type : op->getResultTypes())
175  if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
176  return failure();
177 
178  // check maps
179  SmallVector<AffineMap> maps = getIndexingMaps();
180  if (maps.empty())
181  return failure();
182  unsigned numOperands = op->getNumOperands();
183  unsigned numResults = op->getNumResults();
184  if (numOperands + numResults != maps.size())
185  return failure();
186 
187  for (OpResult result : op->getResults()) {
188  auto resultType = dyn_cast<RankedTensorType>(result.getType());
189  if (!resultType)
190  return failure();
191  AffineMap map = maps[numOperands + result.getResultNumber()];
192  if (!map.isProjectedPermutation()) {
193  return failure();
194  }
195  }
196 
197  return success();
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // ShardingInterface::printLoopTypesAndIndexingMaps
202 //===----------------------------------------------------------------------===//
203 
204 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
205  os << "print loop types and indexing maps for: \n";
206  getOperation()->print(os);
207  os << "\n";
208  os << "loop types: [";
209  for (utils::IteratorType type : getLoopIteratorTypes()) {
210  os << stringifyEnum(type) << " ";
211  }
212  os << "]\n";
213  os << "indexing maps: \n";
214  for (AffineMap map : getIndexingMaps())
215  os << map << "\n";
216  os << "\n";
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // detail::defaultGetShardingOption
221 //===----------------------------------------------------------------------===//
222 
223 namespace {
224 
225 // Update the given `shardingOption` according to `meshAxes` and `loopIdx`
226 static LogicalResult fillShardingOption(Operation *op,
227  ShardingOption &shardingOption,
228  FlatSymbolRefAttr mesh,
229  ArrayRef<MeshAxis> meshAxes,
230  unsigned loopIdx) {
231  if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
232  (!shardingOption.shardingArray[loopIdx].empty() &&
233  shardingOption.shardingArray[loopIdx] != meshAxes)) {
234  LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
235  << loopIdx << "\n");
236  return failure();
237  }
238  for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
239  if (i == loopIdx)
240  continue;
241 
242  for (MeshAxis axis : meshAxes) {
243  if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
244  LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
245  << axis << " duplicate");
246  return failure();
247  }
248  }
249  }
250  if (mesh)
251  shardingOption.mesh = mesh;
252  if (shardingOption.shardingArray[loopIdx].empty())
253  shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
254  meshAxes.end());
255  return success();
256 }
257 
258 } // namespace
259 
// Derives a ShardingOption (a mesh plus the mesh axes assigned to each loop
// iterator) for `op` from the shardings known for its operands and results.
//   1. Results: split axes are mapped onto loops through the result's
//      indexing map; partial axes are collected (one result at most).
//   2. Operands: split axes are mapped onto loops when the operand dimension
//      refers to a single loop; a dimension referring to several loops needs
//      at least one of them to have been visited already.
//   3. If partial axes were collected and no reduction loop is sharded yet,
//      they are assigned to the first reduction loop.
// Returns failure (with a diagnostic in most cases) on unsupported input.
// NOTE(review): listing lines 261, 269 and 397 are not present in this copy
// (function name/first parameter, the declaration that `loopTypes` is
// initialized from, and one statement before the final `empty` check).
260 FailureOr<ShardingOption>
262  ArrayRef<MeshSharding> operandShardings,
263  ArrayRef<MeshSharding> resultShardings) {
264  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
265  ShardingOption shardingOption;
266 
267  if (failed(shardingOp.verifyShardingInterfaceImpl()))
268  return op->emitOpError() << "invalid sharding interface implementation";
270  shardingOp.getLoopIteratorTypes();
271  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
272  unsigned numOperands = op->getNumOperands();
273  shardingOption.shardingArray.resize(loopTypes.size());
274  llvm::SmallVector<MeshAxis> partialMeshAxes;
275  llvm::SmallSet<unsigned, 4> visitedLoopIndices;
276  bool anyShardingInResultsOrOperands = false;
277 
278  // 1. Fill sharding option based on op results
279  for (auto shardingIt : llvm::enumerate(resultShardings)) {
280  MeshSharding shardAttr = shardingIt.value();
281  if (!shardAttr)
282  continue;
  // Result maps follow the operand maps in `maps`.
283  AffineMap map = maps[numOperands + shardingIt.index()];
284  anyShardingInResultsOrOperands = true;
285  if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) {
286  shardingOption.mesh = shardAttr.getMeshAttr();
287  } else {
288  // Handle the split axes: calculate the corresponding loop index for each
289  // split axes sub-array, and then store the sub-array to
290  // shardingOption[index]
291  for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
292  AffineExpr expr = std::get<0>(it);
293  ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
  // The cast relies on result maps being projected permutations, which
  // verifyShardingInterfaceImpl checked above.
294  auto dim = cast<AffineDimExpr>(expr);
295  unsigned index = dim.getPosition();
296  visitedLoopIndices.insert(index);
297  if (failed(fillShardingOption(op, shardingOption,
298  shardAttr.getMeshAttr(), axes, index)))
299  return failure();
300  }
301  }
302 
303  // Handle the partial axes: at this stage, the exact loop index/indices
304  // cannot be decided because there could be multiple reduction loops.
305  ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
306  if (!partialAxes.empty()) {
307  if (!partialMeshAxes.empty())
308  return op->emitOpError() << "at most one result with partial axes is "
309  "supported at present";
310  partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
311  // Add all the reduction loop indices to `visitedLoopIndices` if
312  // `partialAxes` is not empty
313  for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
314  if (isReductionLoop(loopTypes[loopIdx]))
315  visitedLoopIndices.insert(loopIdx);
316  }
317  }
318  }
319 
320  // 2. Fill sharding option based on operands
321  for (auto shardingIt : llvm::enumerate(operandShardings)) {
322  MeshSharding shardAttr = shardingIt.value();
323  if (!shardAttr)
324  continue;
325 
  // NOTE(review): this assigns rather than ORs the flag, so an operand with
  // no split axes resets a `true` set by a result or an earlier operand.
  // Confirm whether `|=` was intended.
326  anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty();
327  AffineMap map = maps[shardingIt.index()];
328  unsigned numDims = map.getNumDims();
329 
330  // Handle the split axes. Partial axes don't need to be handled because they
331  // only affect the defining op of the operand.
332  //
333  // TODO: Change to process the operands with single loop index first and
334  // then the operands with multiple loop indices.
335  for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
336  AffineExpr expr = std::get<0>(it);
337  ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
338  FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
339  checkOperandAffineExpr(expr, numDims);
340  if (failed(loopIndices))
341  return op->emitOpError()
342  << "operand's affine expression is restricted to const_i * "
343  "dim_i + const_j + dim_j + ...";
344  if (loopIndices->empty())
345  continue;
346  if (loopIndices->size() == 1) {
347  unsigned loopIdx = *loopIndices->begin();
348  visitedLoopIndices.insert(loopIdx);
349  if (failed(fillShardingOption(op, shardingOption,
350  shardAttr.getMeshAttr(), axes, loopIdx)))
351  return failure();
352  }
353  // If multiple loop indices correspond to a dimension of an operand, it is
354  // difficult to infer which loop indices are responsible for sharding.
355  // Therefore, the exact loop index must be specified by others.
356  if (loopIndices->size() > 1) {
357  bool seenLoopIndices = false;
358  for (unsigned loopIdx : *loopIndices) {
359  if (visitedLoopIndices.contains(loopIdx)) {
360  seenLoopIndices = true;
361  break;
362  }
363  }
364  if (!seenLoopIndices)
365  return op->emitOpError()
366  << "the operand " << shardingIt.index()
367  << " has multiple loop indices in a dimension, but none of "
368  "them could be found in the exactly specified annotation "
369  "of op results or operands.";
370  }
371  }
372  }
373 
374  // 3. Finalize sharding option
375  if (!partialMeshAxes.empty()) {
376  bool anyNonEmptyReductionLoop = llvm::any_of(
377  llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
378  SmallVector<MeshAxis> &subArray = it.value();
379  int64_t idx = it.index();
380  return isReductionLoop(loopTypes[idx]) && !subArray.empty();
381  });
382  if (!anyNonEmptyReductionLoop) {
  // Assign the partial axes to the first reduction loop only. The result
  // of fillShardingOption is deliberately discarded here.
383  bool filled = false;
384  for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
385  if (isReductionLoop(loopTypes[idx])) {
386  std::ignore = fillShardingOption(op, shardingOption, nullptr,
387  partialMeshAxes, idx);
388  filled = true;
389  break;
390  }
391  }
392  if (!filled)
393  return op->emitOpError() << "no matched reduction loop found for the "
394  "result's partial type";
395  }
396  }
398  if (!anyShardingInResultsOrOperands)
399  shardingOption.empty = true;
400  return shardingOption;
401 }
402 
403 // Get the sharding attribute for the given result and sharding option.
// Split axes: each result dimension takes the mesh axes of the loop its
// indexing expression refers to. Partial axes: the mesh axes of all reduction
// loops are concatenated, and the reduction kind of those loops becomes the
// partial type.
// NOTE(review): listing line 405 (the `map` and `loopTypes` parameters) is not
// present in this copy.
404 MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
406  ArrayRef<ReductionKind> reductionLoopKinds) {
407  auto resultType = cast<RankedTensorType>(result.getType());
408  SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
409  SmallVector<MeshAxis> partialAxes;
410 
411  // process the split axes
412  for (auto it : llvm::enumerate(map.getResults())) {
413  AffineExpr expr = it.value();
414  // `expr` must be an `AffineDimExpr` because `map` is verified by
415  // isProjectedPermutation
416  auto dim = cast<AffineDimExpr>(expr);
417  unsigned loopIdx = dim.getPosition();
  // Loops beyond the end of the sharding array contribute no axes.
418  if (loopIdx < shardingOption.shardingArray.size())
419  splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
420  }
421 
422  // process the partial axes
423  // partialType will be ignored if partialAxes is empty
424  ReductionKind partialType = ReductionKind::Sum;
425  size_t reductionLoopKindsIdx = 0;
426  for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
427  utils::IteratorType iType = std::get<0>(it);
428  if (isReductionLoop(iType)) {
  // Indexed by the running count of reduction loops, without a bounds
  // check; assumes `reductionLoopKinds` has one entry per reduction loop
  // -- TODO confirm.
429  ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
430  ++reductionLoopKindsIdx;
431  if (!partialAxes.empty())
432  assert(partialType == curPartialType &&
433  "Only one reduction type is supported");
434  partialType = curPartialType;
435  const SmallVector<MeshAxis> &axis = std::get<1>(it);
436  partialAxes.append(axis);
437  }
438  }
439 
440  removeTrailingEmptySubArray(splitAxes);
441  return MeshSharding::get(shardingOption.mesh,
442  fromArrayOfVector(result.getContext(), splitAxes),
443  partialAxes, partialType);
444 }
445 
446 static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
447  const ShardingOption &shardingOption,
448  AffineMap map) {
449  Value operandValue = opOperand.get();
450  auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
451  if (!operandType) {
452  if (operandValue.getType().isIntOrIndexOrFloat())
453  return MeshSharding();
454  return failure();
455  }
456  // 0d tensors cannot be sharded and must get replicated
457  if (operandType.getRank() == 0) {
458  return MeshSharding(shardingOption.mesh);
459  }
460  SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
461  unsigned numDims = map.getNumDims();
462  for (auto it : llvm::enumerate(map.getResults())) {
463  int64_t idx = it.index();
464  AffineExpr expr = it.value();
465  FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
466  checkOperandAffineExpr(expr, numDims);
467  if (failed(loopIndices))
468  return failure();
469  SmallVector<unsigned> shardedLoopIndices;
470  for (unsigned loopIdx : *loopIndices) {
471  if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
472  !shardingOption.shardingArray[loopIdx].empty())
473  shardedLoopIndices.push_back(loopIdx);
474  }
475  // mostly one sharded loop index is accepted
476  if (shardedLoopIndices.size() > 1)
477  return failure();
478  if (shardedLoopIndices.size() == 1) {
479  splitAxes[idx].append(
480  shardingOption.shardingArray[shardedLoopIndices[0]]);
481  }
482  }
483 
484  removeTrailingEmptySubArray(splitAxes);
485  return MeshSharding::get(
486  shardingOption.mesh,
487  fromArrayOfVector(opOperand.get().getContext(), splitAxes));
488 }
489 
// Computes the shardings implied by `shardingOption` for every operand and
// result of `op`. The returned vector holds the operand shardings first, in
// operand order, followed by the result shardings. Fails if the sharding of
// any operand cannot be derived.
// NOTE(review): listing lines 491 and 496 (function name, and the declaration
// that `loopTypes` is initialized from) are not present in this copy.
490 FailureOr<std::vector<MeshSharding>>
492  Operation *op, const ShardingOption &shardingOption) {
493  std::vector<MeshSharding> res;
494 
495  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
497  shardingOp.getLoopIteratorTypes();
498  SmallVector<ReductionKind> reductionKinds =
499  shardingOp.getReductionLoopIteratorKinds();
500  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
501  unsigned numOperands = op->getNumOperands();
502 
  // Operand maps come first in `maps`.
503  for (OpOperand &opOperand : op->getOpOperands()) {
504  FailureOr<MeshSharding> shardingAttr = getSharding(
505  opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
506  if (failed(shardingAttr))
507  return failure();
508  res.push_back(*shardingAttr);
509  }
510 
  // Result maps follow the operand maps.
511  for (OpResult result : op->getResults()) {
512  res.push_back(getSharding(result, shardingOption,
513  maps[numOperands + result.getResultNumber()],
514  loopTypes, reductionKinds));
515  }
516 
517  return res;
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // detail::defaultAddShardingAnnotations
522 //===----------------------------------------------------------------------===//
523 
524 // To add a `mesh.shard` op for the given result, based on the details provided
525 // in `shardingOption`, `map`, and `loopTypes`.
// The sharding is computed with getSharding and handed to
// maybeInsertTargetShardingAnnotation. This overload always returns success.
// NOTE(review): listing line 529 (the `loopTypes` parameter) is not present in
// this copy.
526 static LogicalResult addShardOp(OpBuilder &b, OpResult result,
527  const ShardingOption &shardingOption,
528  AffineMap map,
530  ArrayRef<ReductionKind> reductionLoopKinds) {
531  MeshSharding sharding =
532  getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);
533  maybeInsertTargetShardingAnnotation(sharding, result, b);
534 
535  return success();
536 }
537 
538 // To add a `mesh.shard` op for the given operand, based on the details provided
539 // in `shardingOption`, `map`, and `loopTypes`.
540 static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
541  const ShardingOption &shardingOption,
542  AffineMap map) {
543 
544  FailureOr<MeshSharding> sharding =
545  getSharding(opOperand, shardingOption, map);
546  if (failed(sharding)) {
547  return failure();
548  }
549  OpBuilder::InsertionGuard guard(b);
550  maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b);
551 
552  return success();
553 }
554 
// Adds `mesh.shard` annotations for all results and then all operands of
// `op`, as described by `shardingOption`. The option must be non-empty and
// name a mesh (checked by assert only). Fails if annotating any result or
// operand fails.
// NOTE(review): listing lines 555 and 560 (return type/function name, and the
// declaration that `loopTypes` is initialized from) are not present in this
// copy.
556  Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
557  assert(!shardingOption.empty && shardingOption.mesh);
558 
559  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
561  shardingOp.getLoopIteratorTypes();
562  SmallVector<ReductionKind> reductionKinds =
563  shardingOp.getReductionLoopIteratorKinds();
564  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
565  unsigned numOperands = op->getNumOperands();
566 
567  // 1. add mesh.shard ops for all op results
568  for (OpResult result : op->getResults()) {
569  if (failed(addShardOp(b, result, shardingOption,
570  maps[numOperands + result.getResultNumber()],
571  loopTypes, reductionKinds)))
572  return failure();
573  }
574 
575  // 2. add mesh.shard ops for all operands
576  for (OpOperand &opOperand : op->getOpOperands()) {
577  if (failed(addShardOp(b, opOperand, shardingOption,
578  maps[opOperand.getOperandNumber()])))
579  return failure();
580  }
581 
582  return success();
583 }
584 
585 #ifndef NDEBUG
// Debug-only helper (compiled under `#ifndef NDEBUG`). Returns whether
// `sharding` is acceptable for `value` on a fully replicated operation:
// ranked tensors need a full-replication sharding, any other value needs a
// null sharding.
// NOTE(review): listing line 587 (function name and first parameter) is not
// present in this copy.
586 static bool
588  MeshSharding sharding) {
589  if (isa<RankedTensorType>(value.getType())) {
590  return isFullReplication(sharding);
591  }
592 
593  return !sharding;
594 }
595 
// Debug-only helper (compiled under `#ifndef NDEBUG`). Returns true when
// `values` and `shardings` have the same length and the predicate below holds
// for every value/sharding pair.
// NOTE(review): listing lines 598 and 607 (function name/first parameter, and
// the call made inside the lambda) are not present in this copy.
596 template <typename ValueRange, typename MeshShardingRage>
597 static bool
599  MeshShardingRage &&shardings) {
  // zip_equal below requires equal lengths, so reject mismatches up front.
600  if (std::size(values) != std::size(shardings)) {
601  return false;
602  }
603  return llvm::all_of(
604  llvm::zip_equal(std::forward<ValueRange>(values),
605  std::forward<MeshShardingRage>(shardings)),
606  [](auto valueAndSharding) {
608  std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
609  });
610 }
611 #endif // NDEBUG
612 
// Spmdizes an operation whose operands and results are all fully replicated:
// the op is cloned through `spmdizationMap` without further changes.
// The preconditions are checked by asserts only.
// NOTE(review): listing lines 613, 619 and 621 (return type/function name and
// the first half of two assert statements) are not present in this copy.
614  Operation &op, ArrayRef<Value> spmdizedOperands,
615  ArrayRef<MeshSharding> operandShardings,
616  ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
617  SymbolTableCollection &symbolTable, OpBuilder &builder) {
618  assert(spmdizedOperands.size() == operandShardings.size());
620  operandShardings));
622  resultShardings));
623  // `clone` will populate the mapping of old to new results.
624  builder.clone(op, spmdizationMap);
625 }
626 
// Records the mesh axes of one tensor dimension for the loop iterator that
// `indexingExpr` refers to. `indexingExpr` must be an AffineDimExpr (enforced
// by `cast`). If the loop already has an assignment, it must equal the new
// one (checked by assert only); otherwise the axes are stored.
// NOTE(review): listing line 627 (return type and function name) is not
// present in this copy.
628  ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
629  SmallVector<std::optional<SmallVector<MeshAxis>>>
630  &meshAxesAssignmentForLoopIterators) {
631  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
632  unsigned loopIteratorIdx = affineDimExpr.getPosition();
633  if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
634  assert(llvm::equal(meshAxesAssignmentForTensorAxis,
635  *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
636  } else {
637  meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
638  llvm::to_vector(meshAxesAssignmentForTensorAxis);
639  }
640 }
641 
// Computes, for every loop iterator, the mesh axes it is sharded along, from
// the shardings paired with `indexingMaps`. Loop iterators that receive no
// assignment end up with an empty axis list in the returned ShardingArray.
// NOTE(review): listing lines 642, 647, 660 and 667 (return type/function
// name, the type of `meshAxisAssignmentForLoopIterators`, and the callee name
// of the two calls inside the loops) are not present in this copy.
643  ArrayRef<MeshSharding> operandShardings,
644  ArrayRef<MeshSharding> resultShardings,
645  ArrayRef<utils::IteratorType> loopIteratorTypes,
646  ArrayRef<AffineMap> indexingMaps) {
648  meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
649  std::vector<MeshSharding> operatorAndResultShardings;
650  operatorAndResultShardings.reserve(operandShardings.size() +
651  resultShardings.size());
  // NOTE(review): capacity is reserved for operands and results, but only
  // `operandShardings` is appended in the visible lines, and the vector is
  // then zipped with `indexingMaps` via zip_equal (which requires equal
  // lengths). Confirm whether `resultShardings` should be appended as well.
652  llvm::append_range(operatorAndResultShardings, operandShardings);
653  for (auto [sharding, affineMap] :
654  llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
655  if (!sharding) {
656  continue;
657  }
658  for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
659  llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
661  meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
662  meshAxisAssignmentForLoopIterators);
663  }
664  // Missing trailing split axes means replication on those tensor dimensions.
665  for (unsigned i = sharding.getSplitAxes().size();
666  i < affineMap.getNumResults(); ++i) {
668  {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
669  }
670  }
671 
  // Convert the optionals into plain vectors; unset entries become empty and
  // set entries are moved out of the temporary container.
672  ShardingArray res;
673  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
674  [](std::optional<SmallVector<MeshAxis>> &axes) {
675  if (!axes) {
676  return SmallVector<MeshAxis>();
677  };
678  return std::move(*axes);
679  });
680  return res;
681 }
682 
// Returns true if at least one loop iterator of type `reduction` has a
// non-empty mesh axis assignment. The two arrays are walked in lockstep with
// zip_equal, which requires them to have the same length.
// NOTE(review): listing line 683 (return type and function name) is not
// present in this copy.
684  ArrayRef<utils::IteratorType> loopIteratorTypes,
685  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
686  for (auto [loopIteratorType, meshAxisAssignment] :
687  llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
688  if (loopIteratorType == utils::IteratorType::reduction &&
689  !meshAxisAssignment.empty()) {
690  return true;
691  }
692  }
693  return false;
694 }
695 
// Returns the concatenation, in loop order, of the mesh axes assigned to all
// loop iterators of type `reduction`. The two arrays are walked in lockstep
// with zip_equal, which requires them to have the same length.
// NOTE(review): listing line 696 (return type and function name) is not
// present in this copy.
697  ArrayRef<utils::IteratorType> loopIteratorTypes,
698  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
699  SmallVector<MeshAxis> meshAxes;
700  for (auto [loopIteratorType, meshAxisAssignment] :
701  llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
702  if (loopIteratorType == utils::IteratorType::reduction) {
703  llvm::append_range(meshAxes, meshAxisAssignment);
704  }
705  }
706  return meshAxes;
707 }
708 
// Spmdizes an operation by cloning it through `spmdizationMap` and then
// rewriting each result type of the clone to its sharded counterpart, as
// computed by shardType from the matching entry of `resultShardings`.
// `spmdizedOperands` and `operandShardings` are not used in the visible body.
// NOTE(review): listing line 709 (return type and function name) is not
// present in this copy.
710  Operation &op, ArrayRef<Value> spmdizedOperands,
711  ArrayRef<MeshSharding> operandShardings,
712  ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
713  SymbolTableCollection &symbolTable, OpBuilder &builder) {
714  // `clone` will populate the mapping of old to new results.
715  Operation *newOp = builder.clone(op, spmdizationMap);
716  // Set the result types to the sharded counterparts.
  // zip_equal requires one sharding per result.
717  for (auto [oldResult, newResult, sharding] :
718  llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
719  newResult.setType(shardType(
720  newResult.getType(),
721  getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
722  }
723 }
SmallVector< MeshAxesAttr > fromArrayOfVector(MLIRContext *ctxt, const SmallVector< SmallVector< T >> &vec)
static void updateMeshAxisAssignmentForLoopIterators(ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)
static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, MeshShardingRage &&shardings)
static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshSharding sharding)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)
static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
#define DBGS()
static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)
MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:223
unsigned getPosition() const
Definition: AffineExpr.cpp:348
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition: AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:549
This class represents an operand of an operation.
Definition: Value.h:243
This is a value defined by a result of an operation.
Definition: Value.h:433
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumOperands()
Definition: Operation.h:346
operand_type_range getOperandTypes()
Definition: Operation.h:397
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:204
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:191
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:64
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:66
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:67
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition: MeshOps.cpp:780
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
mesh::ReductionKind ReductionKind
mesh::MeshSharding MeshSharding
FailureOr< std::vector< MeshSharding > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:120
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:329
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:270
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder, ShardOp &newShardOp)
Definition: MeshOps.cpp:278
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
bool isReductionLoop(utils::IteratorType iType)
Definition: MeshOps.h:100
bool isFullReplication(MeshSharding sharding)
Definition: MeshOps.h:112
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
int16_t MeshAxis
Definition: MeshOps.h:26
FailureOr< std::pair< bool, MeshSharding > > getMeshSharding(OpResult result)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Definition: MeshOps.h:106
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.