MLIR  20.0.0git
ShardingInterface.cpp
Go to the documentation of this file.
1 //===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/Debug.h"
20 
21 #include <utility>
22 
23 #define DEBUG_TYPE "sharding-interface"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
25 
26 using namespace mlir;
27 using namespace mlir::mesh;
28 
29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
30 
31 //===----------------------------------------------------------------------===//
32 // common util functions
33 //===----------------------------------------------------------------------===//
34 
35 static LogicalResult
37  SmallVectorImpl<bool> &seenIds) {
38  switch (expr.getKind()) {
39  case AffineExprKind::Add: {
40  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
41  AffineExpr lhs = binOpExpr.getLHS();
42  AffineExpr rhs = binOpExpr.getRHS();
43  if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
44  return failure();
45  if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
46  return failure();
47  return success();
48  }
49  case AffineExprKind::Mul: {
50  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
51  AffineExpr lhs = binOpExpr.getLHS();
52  AffineExpr rhs = binOpExpr.getRHS();
53  AffineExpr dimExpr;
54  if (lhs.getKind() == AffineExprKind::DimId &&
56  dimExpr = lhs;
57  } else if (rhs.getKind() == AffineExprKind::DimId &&
59  dimExpr = rhs;
60  } else
61  return failure();
62  unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
63  if ((size_t)position >= seenIds.size() || seenIds[position])
64  return failure();
65  seenIds[position] = true;
66  return success();
67  }
68  case AffineExprKind::DimId: {
69  unsigned position = cast<AffineDimExpr>(expr).getPosition();
70  if ((size_t)position >= seenIds.size() || seenIds[position])
71  return failure();
72  seenIds[position] = true;
73  return success();
74  }
75  default:
76  return failure();
77  }
78 }
79 
80 static FailureOr<llvm::SmallSet<unsigned, 2>>
81 checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
82  SmallVector<bool> seenIds(numDims, false);
83  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
84  return failure();
85 
86  llvm::SmallSet<unsigned, 2> positions;
87  for (auto it : llvm::enumerate(seenIds)) {
88  if (it.value())
89  positions.insert((unsigned)it.index());
90  }
91  return positions;
92 }
93 
94 template <typename T>
98  for (const auto &v : vec) {
99  res.emplace_back(MeshAxesAttr::get(ctxt, v));
100  }
101  return res;
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // mesh::getMeshSharding
106 //===----------------------------------------------------------------------===//
107 
108 FailureOr<std::pair<bool, MeshSharding>>
110  Value val = cast<Value>(result);
111  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
112  auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
113  if (!shardOp)
114  return false;
115  return !shardOp.getAnnotateForUsers();
116  });
117 
118  if (anyShardedForDef) {
119  // expected to have exact one use if it has a use of `mesh.shard` without
120  // unit attr annotate_for_users
121  if (!val.hasOneUse())
122  return failure();
123  auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
124  return std::make_pair(false, MeshSharding(shardOp.getSharding()));
125  }
126 
127  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
128  auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
129  if (!shardOp)
130  return false;
131  return shardOp.getAnnotateForUsers();
132  });
133  if (anyShardedForUsers) {
134  SmallVector<ShardOp> shardOps;
135  for (Operation *user : val.getUsers()) {
136  ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
137  if (shardOp)
138  shardOps.push_back(shardOp);
139  }
140  MeshSharding shardForDef = shardOps[0].getSharding();
141  for (size_t i = 1; i < shardOps.size(); ++i) {
142  // TODO: Deduce a reasonable mesh sharding attr for def when they are
143  // different
144  assert(shardForDef == shardOps[i].getSharding() &&
145  "only support all shard ops have the same mesh sharding attr");
146  }
147  return std::make_pair(true, shardForDef);
148  }
149  return failure();
150 }
151 
152 FailureOr<std::pair<bool, MeshSharding>>
154  Value val = opOperand.get();
155  if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
156  return std::make_pair(shardOp.getAnnotateForUsers(),
157  MeshSharding(shardOp.getSharding()));
158 
159  return failure();
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // ShardingInterface::verifyShardingInterfaceImpl
164 //===----------------------------------------------------------------------===//
165 
166 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
167  Operation *op = getOperation();
168 
169  // check operands and results type
170  for (Type type : op->getOperandTypes())
171  if (!llvm::isa<RankedTensorType>(type))
172  return failure();
173  for (Type type : op->getResultTypes())
174  if (!llvm::isa<RankedTensorType>(type))
175  return failure();
176 
177  // check loop types
178  SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
179  if (loopTypes.empty())
180  return failure();
181 
182  // check maps
183  SmallVector<AffineMap> maps = getIndexingMaps();
184  if (maps.empty())
185  return failure();
186  unsigned numOperands = op->getNumOperands();
187  unsigned numResults = op->getNumResults();
188  if (numOperands + numResults != maps.size())
189  return failure();
190 
191  for (OpResult result : op->getResults()) {
192  auto resultType = dyn_cast<RankedTensorType>(result.getType());
193  if (!resultType)
194  return failure();
195  AffineMap map = maps[numOperands + result.getResultNumber()];
196  if (!map.isProjectedPermutation()) {
197  return failure();
198  }
199  }
200 
201  return success();
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // ShardingInterface::printLoopTypesAndIndexingMaps
206 //===----------------------------------------------------------------------===//
207 
208 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
209  os << "print loop types and indexing maps for: \n";
210  getOperation()->print(os);
211  os << "\n";
212  os << "loop types: [";
213  for (utils::IteratorType type : getLoopIteratorTypes()) {
214  os << stringifyEnum(type) << " ";
215  }
216  os << "]\n";
217  os << "indexing maps: \n";
218  for (AffineMap map : getIndexingMaps())
219  os << map << "\n";
220  os << "\n";
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // detail::defaultGetShardingOption
225 //===----------------------------------------------------------------------===//
226 
227 namespace {
228 
229 // Update the given `shardingOption` according to `meshAxes` and `loopIdx`
230 static LogicalResult fillShardingOption(Operation *op,
231  ShardingOption &shardingOption,
232  FlatSymbolRefAttr mesh,
233  ArrayRef<MeshAxis> meshAxes,
234  unsigned loopIdx) {
235  if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
236  (!shardingOption.shardingArray[loopIdx].empty() &&
237  shardingOption.shardingArray[loopIdx] != meshAxes)) {
238  LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
239  << loopIdx << "\n");
240  return failure();
241  }
242  for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
243  if (i == loopIdx)
244  continue;
245 
246  for (MeshAxis axis : meshAxes) {
247  if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
248  LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
249  << axis << " duplicate");
250  return failure();
251  }
252  }
253  }
254  if (mesh)
255  shardingOption.mesh = mesh;
256  if (shardingOption.shardingArray[loopIdx].empty())
257  shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
258  meshAxes.end());
259  return success();
260 }
261 
262 } // namespace
263 
264 FailureOr<ShardingOption>
266  ArrayRef<MeshSharding> operandShardings,
267  ArrayRef<MeshSharding> resultShardings) {
268  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
269  ShardingOption shardingOption;
270 
271  if (failed(shardingOp.verifyShardingInterfaceImpl()))
272  return op->emitOpError() << "invalid sharding interface implementation";
274  shardingOp.getLoopIteratorTypes();
275  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
276  unsigned numOperands = op->getNumOperands();
277  shardingOption.shardingArray.resize(loopTypes.size());
278  llvm::SmallVector<MeshAxis> partialMeshAxes;
279  llvm::SmallSet<unsigned, 4> visitedLoopIndices;
280  bool anyShardingInResultsOrOperands = false;
281 
282  // 1. Fill sharding option based on op results
283  for (auto shardingIt : llvm::enumerate(resultShardings)) {
284  MeshSharding shardAttr = shardingIt.value();
285  if (!shardAttr)
286  continue;
287  AffineMap map = maps[numOperands + shardingIt.index()];
288  anyShardingInResultsOrOperands = true;
289  // Handle the split axes: calculate the corresponding loop index for each
290  // split axes sub-array, and then store the sub-array to
291  // shardingOption[index]
292  for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
293  AffineExpr expr = std::get<0>(it);
294  ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
295  auto dim = cast<AffineDimExpr>(expr);
296  unsigned index = dim.getPosition();
297  visitedLoopIndices.insert(index);
298  if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(),
299  axes, index)))
300  return failure();
301  }
302 
303  // Handle the partial axes: at this stage, the exact loop index/indices
304  // cannot be decided because there could be multiple reduction loops.
305  ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
306  if (!partialAxes.empty()) {
307  if (!partialMeshAxes.empty())
308  return op->emitOpError() << "at most one result with partial axes is "
309  "supported at present";
310  partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
311  // Add all the reduction loop indices to `visitedLoopIndices` if
312  // `partialAxes` is not empty
313  for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
314  if (isReductionLoop(loopTypes[loopIdx]))
315  visitedLoopIndices.insert(loopIdx);
316  }
317  }
318  }
319 
320  // 2. Fill sharding option based on operands
321  for (auto shardingIt : llvm::enumerate(operandShardings)) {
322  MeshSharding shardAttr = shardingIt.value();
323  if (!shardAttr)
324  continue;
325 
326  anyShardingInResultsOrOperands = true;
327  AffineMap map = maps[shardingIt.index()];
328  unsigned numDims = map.getNumDims();
329 
330  // Handle the split axes. Partial axes don't need to be handled because they
331  // only affect the defining op of the operand.
332  //
333  // TODO: Change to process the operands with single loop index first and
334  // then the operands with multiple loop indices.
335  for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
336  AffineExpr expr = std::get<0>(it);
337  ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
338  FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
339  checkOperandAffineExpr(expr, numDims);
340  if (failed(loopIndices))
341  return op->emitOpError()
342  << "operand's affine expression is restricted to const_i * "
343  "dim_i + const_j + dim_j + ...";
344  if (loopIndices->empty())
345  continue;
346  if (loopIndices->size() == 1) {
347  unsigned loopIdx = *loopIndices->begin();
348  visitedLoopIndices.insert(loopIdx);
349  if (failed(fillShardingOption(op, shardingOption,
350  shardAttr.getMeshAttr(), axes, loopIdx)))
351  return failure();
352  }
353  // If multiple loop indices correspond to a dimension of an operand, it is
354  // difficult to infer which loop indices are responsible for sharding.
355  // Therefore, the exact loop index must be specified by others.
356  if (loopIndices->size() > 1) {
357  bool seenLoopIndices = false;
358  for (unsigned loopIdx : *loopIndices) {
359  if (visitedLoopIndices.contains(loopIdx)) {
360  seenLoopIndices = true;
361  break;
362  }
363  }
364  if (!seenLoopIndices)
365  return op->emitOpError()
366  << "the operand " << shardingIt.index()
367  << " has multiple loop indices in a dimension, but none of "
368  "them could be found in the exactly specified annotation "
369  "of op results or operands.";
370  }
371  }
372  }
373 
374  // 3. Finalize sharding option
375  if (!partialMeshAxes.empty()) {
376  bool anyNonEmptyReductionLoop = llvm::any_of(
377  llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
378  SmallVector<MeshAxis> &subArray = it.value();
379  int64_t idx = it.index();
380  return isReductionLoop(loopTypes[idx]) && !subArray.empty();
381  });
382  if (!anyNonEmptyReductionLoop) {
383  bool filled = false;
384  for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
385  if (isReductionLoop(loopTypes[idx])) {
386  std::ignore = fillShardingOption(op, shardingOption, nullptr,
387  partialMeshAxes, idx);
388  filled = true;
389  break;
390  }
391  }
392  if (!filled)
393  return op->emitOpError() << "no matched reduction loop found for the "
394  "result's partial type";
395  }
396  }
398  if (!anyShardingInResultsOrOperands)
399  shardingOption.empty = true;
400  return shardingOption;
401 }
402 
403 // Get the sharding attributed for the given result and sharding option.
404 MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
406  ArrayRef<ReductionKind> reductionLoopKinds) {
407  auto resultType = cast<RankedTensorType>(result.getType());
408  SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
409  SmallVector<MeshAxis> partialAxes;
410 
411  // process the split axes
412  for (auto it : llvm::enumerate(map.getResults())) {
413  SmallVector<MeshAxis> tmp_axes;
414  AffineExpr expr = it.value();
415  // `expr` must be an `AffineDimExpr` because `map` is verified by
416  // isProjectedPermutation
417  auto dim = cast<AffineDimExpr>(expr);
418  unsigned loopIdx = dim.getPosition();
419  if (loopIdx < shardingOption.shardingArray.size())
420  splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
421  }
422 
423  // process the partial axes
424  // partialType will be ignored if partialAxes is empty
425  ReductionKind partialType = ReductionKind::Sum;
426  size_t reductionLoopKindsIdx = 0;
427  for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
428  utils::IteratorType iType = std::get<0>(it);
429  if (isReductionLoop(iType)) {
430  ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
431  ++reductionLoopKindsIdx;
432  if (!partialAxes.empty())
433  assert(partialType == curPartialType &&
434  "Only one reduction type is supported");
435  partialType = curPartialType;
436  const SmallVector<MeshAxis> &axis = std::get<1>(it);
437  partialAxes.append(axis);
438  }
439  }
440 
441  removeTrailingEmptySubArray(splitAxes);
442  return MeshSharding::get(shardingOption.mesh,
443  fromArrayOfVector(result.getContext(), splitAxes),
444  partialAxes, partialType);
445 }
446 
447 static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
448  const ShardingOption &shardingOption,
449  AffineMap map) {
450  Value operandValue = opOperand.get();
451  auto operandType = cast<RankedTensorType>(operandValue.getType());
452  SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
453  unsigned numDims = map.getNumDims();
454  for (auto it : llvm::enumerate(map.getResults())) {
455  int64_t idx = it.index();
456  AffineExpr expr = it.value();
457  FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
458  checkOperandAffineExpr(expr, numDims);
459  if (failed(loopIndices))
460  return failure();
461  SmallVector<unsigned> shardedLoopIndices;
462  for (unsigned loopIdx : *loopIndices) {
463  if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
464  !shardingOption.shardingArray[loopIdx].empty())
465  shardedLoopIndices.push_back(loopIdx);
466  }
467  // mostly one sharded loop index is accepted
468  if (shardedLoopIndices.size() > 1)
469  return failure();
470  if (shardedLoopIndices.size() == 1) {
471  splitAxes[idx].append(
472  shardingOption.shardingArray[shardedLoopIndices[0]]);
473  }
474  }
475 
476  removeTrailingEmptySubArray(splitAxes);
477  return MeshSharding::get(
478  shardingOption.mesh,
479  fromArrayOfVector(opOperand.get().getContext(), splitAxes));
480 }
481 
482 FailureOr<std::vector<MeshSharding>>
484  Operation *op, const ShardingOption &shardingOption) {
485  std::vector<MeshSharding> res;
486 
487  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
489  shardingOp.getLoopIteratorTypes();
490  SmallVector<ReductionKind> reductionKinds =
491  shardingOp.getReductionLoopIteratorKinds();
492  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
493  unsigned numOperands = op->getNumOperands();
494 
495  for (OpOperand &opOperand : op->getOpOperands()) {
496  FailureOr<MeshSharding> shardingAttr = getSharding(
497  opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
498  if (failed(shardingAttr))
499  return failure();
500  res.push_back(*shardingAttr);
501  }
502 
503  for (OpResult result : op->getResults()) {
504  res.push_back(getSharding(result, shardingOption,
505  maps[numOperands + result.getResultNumber()],
506  loopTypes, reductionKinds));
507  }
508 
509  return res;
510 }
511 
512 //===----------------------------------------------------------------------===//
513 // detail::defaultAddShardingAnnotations
514 //===----------------------------------------------------------------------===//
515 
516 // To add a `mesh.shard` op for the given result, based on the details provided
517 // in `shardingOption`, `map`, and `loopTypes`.
518 static LogicalResult addShardOp(OpBuilder &b, OpResult result,
519  const ShardingOption &shardingOption,
520  AffineMap map,
522  ArrayRef<ReductionKind> reductionLoopKinds) {
523  MeshSharding sharding =
524  getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);
525  maybeInsertTargetShardingAnnotation(sharding, result, b);
526 
527  return success();
528 }
529 
530 // To add a `mesh.shard` op for the given operand, based on the details provided
531 // in `shardingOption`, `map`, and `loopTypes`.
532 static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
533  const ShardingOption &shardingOption,
534  AffineMap map) {
535 
536  FailureOr<MeshSharding> sharding =
537  getSharding(opOperand, shardingOption, map);
538  if (failed(sharding)) {
539  return failure();
540  }
541  OpBuilder::InsertionGuard guard(b);
542  maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b);
543 
544  return success();
545 }
546 
548  Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
549  assert(!shardingOption.empty && shardingOption.mesh);
550 
551  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
553  shardingOp.getLoopIteratorTypes();
554  SmallVector<ReductionKind> reductionKinds =
555  shardingOp.getReductionLoopIteratorKinds();
556  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
557  unsigned numOperands = op->getNumOperands();
558 
559  // 1. add mesh.shard ops for all op results
560  for (OpResult result : op->getResults()) {
561  if (failed(addShardOp(b, result, shardingOption,
562  maps[numOperands + result.getResultNumber()],
563  loopTypes, reductionKinds)))
564  return failure();
565  }
566 
567  // 2. add mesh.shard ops for all operands
568  for (OpOperand &opOperand : op->getOpOperands()) {
569  if (failed(addShardOp(b, opOperand, shardingOption,
570  maps[opOperand.getOperandNumber()])))
571  return failure();
572  }
573 
574  return success();
575 }
576 
577 #ifndef NDEBUG
578 static bool
580  MeshSharding sharding) {
581  if (isa<RankedTensorType>(value.getType())) {
582  return sharding && isFullReplication(sharding);
583  }
584 
585  return !sharding;
586 }
587 
588 template <typename ValueRange, typename MeshShardingRage>
589 static bool
591  MeshShardingRage &&shardings) {
592  if (std::size(values) != std::size(shardings)) {
593  return false;
594  }
595  return llvm::all_of(
596  llvm::zip_equal(std::forward<ValueRange>(values),
597  std::forward<MeshShardingRage>(shardings)),
598  [](auto valueAndSharding) {
600  std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
601  });
602 }
603 #endif // NDEBUG
604 
606  Operation &op, ArrayRef<Value> spmdizedOperands,
607  ArrayRef<MeshSharding> operandShardings,
608  ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
609  SymbolTableCollection &symbolTable, OpBuilder &builder) {
610  assert(spmdizedOperands.size() == operandShardings.size());
612  operandShardings));
614  resultShardings));
615  // `clone` will populate the mapping of old to new results.
616  builder.clone(op, spmdizationMap);
617 }
618 
620  ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
621  SmallVector<std::optional<SmallVector<MeshAxis>>>
622  &meshAxesAssignmentForLoopIterators) {
623  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
624  unsigned loopIteratorIdx = affineDimExpr.getPosition();
625  if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
626  assert(llvm::equal(meshAxesAssignmentForTensorAxis,
627  *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
628  } else {
629  meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
630  llvm::to_vector(meshAxesAssignmentForTensorAxis);
631  }
632 }
633 
635  ArrayRef<MeshSharding> operandShardings,
636  ArrayRef<MeshSharding> resultShardings,
637  ArrayRef<utils::IteratorType> loopIteratorTypes,
638  ArrayRef<AffineMap> indexingMaps) {
640  meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
641  std::vector<MeshSharding> operatorAndResultShardings;
642  operatorAndResultShardings.reserve(operandShardings.size() +
643  resultShardings.size());
644  llvm::append_range(operatorAndResultShardings, operandShardings);
645  for (auto [sharding, affineMap] :
646  llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
647  if (!sharding) {
648  continue;
649  }
650  for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
651  llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
653  meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
654  meshAxisAssignmentForLoopIterators);
655  }
656  // Missing trailing split axes means replication on those tensor dimensions.
657  for (unsigned i = sharding.getSplitAxes().size();
658  i < affineMap.getNumResults(); ++i) {
660  {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
661  }
662  }
663 
664  ShardingArray res;
665  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
666  [](std::optional<SmallVector<MeshAxis>> &axes) {
667  if (!axes) {
668  return SmallVector<MeshAxis>();
669  };
670  return std::move(*axes);
671  });
672  return res;
673 }
674 
676  ArrayRef<utils::IteratorType> loopIteratorTypes,
677  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
678  for (auto [loopIteratorType, meshAxisAssignment] :
679  llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
680  if (loopIteratorType == utils::IteratorType::reduction &&
681  !meshAxisAssignment.empty()) {
682  return true;
683  }
684  }
685  return false;
686 }
687 
689  ArrayRef<utils::IteratorType> loopIteratorTypes,
690  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
691  SmallVector<MeshAxis> meshAxes;
692  for (auto [loopIteratorType, meshAxisAssignment] :
693  llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
694  if (loopIteratorType == utils::IteratorType::reduction) {
695  llvm::append_range(meshAxes, meshAxisAssignment);
696  }
697  }
698  return meshAxes;
699 }
700 
702  Operation &op, ArrayRef<Value> spmdizedOperands,
703  ArrayRef<MeshSharding> operandShardings,
704  ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
705  SymbolTableCollection &symbolTable, OpBuilder &builder) {
706  // `clone` will populate the mapping of old to new results.
707  Operation *newOp = builder.clone(op, spmdizationMap);
708  // Set the result types to the sharded counterparts.
709  for (auto [oldResult, newResult, sharding] :
710  llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
711  newResult.setType(
712  shardType(newResult.getType(),
713  getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
714  }
715 }
SmallVector< MeshAxesAttr > fromArrayOfVector(MLIRContext *ctxt, const SmallVector< SmallVector< T >> &vec)
static void updateMeshAxisAssignmentForLoopIterators(ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)
static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, MeshShardingRage &&shardings)
static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshSharding sharding)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)
static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
#define DBGS()
static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)
MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:236
unsigned getPosition() const
Definition: AffineExpr.cpp:348
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:595
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:353
This class helps build Operations.
Definition: Builders.h:212
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:567
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumOperands()
Definition: Operation.h:341
operand_type_range getOperandTypes()
Definition: Operation.h:392
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:132
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:63
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:65
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_sizes_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_sizes_={})
Definition: MeshOps.cpp:637
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:66
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
mesh::ReductionKind ReductionKind
mesh::MeshSharding MeshSharding
FailureOr< std::vector< MeshSharding > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:313
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:263
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:122
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
bool isReductionLoop(utils::IteratorType iType)
Definition: MeshOps.h:97
bool isFullReplication(MeshSharding sharding)
Definition: MeshOps.h:108
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
int16_t MeshAxis
Definition: MeshOps.h:25
FailureOr< std::pair< bool, MeshSharding > > getMeshSharding(OpResult result)
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:271
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Definition: MeshOps.h:102
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.