MLIR  20.0.0git
ShardingInterface.cpp
Go to the documentation of this file.
1 //===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/Debug.h"
20 
21 #include <utility>
22 
23 #define DEBUG_TYPE "sharding-interface"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
25 
26 using namespace mlir;
27 using namespace mlir::mesh;
28 
29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
30 
31 //===----------------------------------------------------------------------===//
32 // common util functions
33 //===----------------------------------------------------------------------===//
34 
35 static LogicalResult
37  SmallVectorImpl<bool> &seenIds) {
38  switch (expr.getKind()) {
39  case AffineExprKind::Add: {
40  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
41  AffineExpr lhs = binOpExpr.getLHS();
42  AffineExpr rhs = binOpExpr.getRHS();
43  if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
44  return failure();
45  if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
46  return failure();
47  return success();
48  }
49  case AffineExprKind::Mul: {
50  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
51  AffineExpr lhs = binOpExpr.getLHS();
52  AffineExpr rhs = binOpExpr.getRHS();
53  AffineExpr dimExpr;
54  if (lhs.getKind() == AffineExprKind::DimId &&
56  dimExpr = lhs;
57  } else if (rhs.getKind() == AffineExprKind::DimId &&
59  dimExpr = rhs;
60  } else
61  return failure();
62  unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
63  if ((size_t)position >= seenIds.size() || seenIds[position])
64  return failure();
65  seenIds[position] = true;
66  return success();
67  }
68  case AffineExprKind::DimId: {
69  unsigned position = cast<AffineDimExpr>(expr).getPosition();
70  if ((size_t)position >= seenIds.size() || seenIds[position])
71  return failure();
72  seenIds[position] = true;
73  return success();
74  }
75  default:
76  return failure();
77  }
78 }
79 
80 static FailureOr<llvm::SmallSet<unsigned, 2>>
81 checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
82  SmallVector<bool> seenIds(numDims, false);
83  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
84  return failure();
85 
86  llvm::SmallSet<unsigned, 2> positions;
87  for (auto it : llvm::enumerate(seenIds)) {
88  if (it.value())
89  positions.insert((unsigned)it.index());
90  }
91  return positions;
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // mesh::getMeshShardingAttr
96 //===----------------------------------------------------------------------===//
97 
98 FailureOr<std::pair<bool, MeshShardingAttr>>
100  Value val = cast<Value>(result);
101  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
102  auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
103  if (!shardOp)
104  return false;
105  return !shardOp.getAnnotateForUsers();
106  });
107 
108  if (anyShardedForDef) {
109  // expected to have exact one use if it has a use of `mesh.shard` without
110  // unit attr annotate_for_users
111  if (!val.hasOneUse())
112  return failure();
113  auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
114  return std::make_pair(false, shardOp.getShard());
115  }
116 
117  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
118  auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
119  if (!shardOp)
120  return false;
121  return shardOp.getAnnotateForUsers();
122  });
123  if (anyShardedForUsers) {
124  SmallVector<ShardOp> shardOps;
125  for (Operation *user : val.getUsers()) {
126  ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
127  if (shardOp)
128  shardOps.push_back(shardOp);
129  }
130  MeshShardingAttr shardForDef = shardOps[0].getShard();
131  for (size_t i = 1; i < shardOps.size(); ++i) {
132  // TODO: Deduce a reasonable mesh sharding attr for def when they are
133  // different
134  assert(shardOps[i].getShard() == shardForDef &&
135  "only support all shard ops have the same mesh sharding attr");
136  }
137  return std::make_pair(true, shardForDef);
138  }
139  return failure();
140 }
141 
142 FailureOr<std::pair<bool, MeshShardingAttr>>
144  Value val = opOperand.get();
145  if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
146  return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
147 
148  return failure();
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // ShardingInterface::verifyShardingInterfaceImpl
153 //===----------------------------------------------------------------------===//
154 
155 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
156  Operation *op = getOperation();
157 
158  // check operands and results type
159  for (Type type : op->getOperandTypes())
160  if (!llvm::isa<RankedTensorType>(type))
161  return failure();
162  for (Type type : op->getResultTypes())
163  if (!llvm::isa<RankedTensorType>(type))
164  return failure();
165 
166  // check loop types
167  SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
168  if (loopTypes.empty())
169  return failure();
170 
171  // check maps
172  SmallVector<AffineMap> maps = getIndexingMaps();
173  if (maps.empty())
174  return failure();
175  unsigned numOperands = op->getNumOperands();
176  unsigned numResults = op->getNumResults();
177  if (numOperands + numResults != maps.size())
178  return failure();
179 
180  for (OpResult result : op->getResults()) {
181  auto resultType = dyn_cast<RankedTensorType>(result.getType());
182  if (!resultType)
183  return failure();
184  AffineMap map = maps[numOperands + result.getResultNumber()];
185  if (!map.isProjectedPermutation()) {
186  return failure();
187  }
188  }
189 
190  return success();
191 }
192 
193 //===----------------------------------------------------------------------===//
194 // ShardingInterface::printLoopTypesAndIndexingMaps
195 //===----------------------------------------------------------------------===//
196 
197 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
198  os << "print loop types and indexing maps for: \n";
199  getOperation()->print(os);
200  os << "\n";
201  os << "loop types: [";
202  for (utils::IteratorType type : getLoopIteratorTypes()) {
203  os << stringifyEnum(type) << " ";
204  }
205  os << "]\n";
206  os << "indexing maps: \n";
207  for (AffineMap map : getIndexingMaps())
208  os << map << "\n";
209  os << "\n";
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // detail::defaultGetShardingOption
214 //===----------------------------------------------------------------------===//
215 
216 namespace {
217 
218 // Update the given `shardingOption` according to `meshAxes` and `loopIdx`
219 static LogicalResult fillShardingOption(Operation *op,
220  ShardingOption &shardingOption,
221  FlatSymbolRefAttr mesh,
222  ArrayRef<MeshAxis> meshAxes,
223  unsigned loopIdx) {
224  if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
225  (!shardingOption.shardingArray[loopIdx].empty() &&
226  shardingOption.shardingArray[loopIdx] != meshAxes)) {
227  LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
228  << loopIdx << "\n");
229  return failure();
230  }
231  for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
232  if (i == loopIdx)
233  continue;
234 
235  for (MeshAxis axis : meshAxes) {
236  if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
237  LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
238  << axis << " duplicate");
239  return failure();
240  }
241  }
242  }
243  if (mesh)
244  shardingOption.mesh = mesh;
245  if (shardingOption.shardingArray[loopIdx].empty())
246  shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
247  meshAxes.end());
248  return success();
249 }
250 
251 } // namespace
252 
253 FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
254  Operation *op, ArrayRef<MeshShardingAttr> operandShardings,
255  ArrayRef<MeshShardingAttr> resultShardings) {
256  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
257  ShardingOption shardingOption;
258 
259  if (failed(shardingOp.verifyShardingInterfaceImpl()))
260  return op->emitOpError() << "invalid sharding interface implementation";
262  shardingOp.getLoopIteratorTypes();
263  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
264  unsigned numOperands = op->getNumOperands();
265  shardingOption.shardingArray.resize(loopTypes.size());
266  llvm::SmallVector<MeshAxis> partialMeshAxes;
267  llvm::SmallSet<unsigned, 4> visitedLoopIndices;
268  bool anyShardingInResultsOrOperands = false;
269 
270  // 1. Fill sharding option based on op results
271  for (auto shardingIt : llvm::enumerate(resultShardings)) {
272  MeshShardingAttr shardAttr = shardingIt.value();
273  if (!shardAttr)
274  continue;
275  AffineMap map = maps[numOperands + shardingIt.index()];
276  anyShardingInResultsOrOperands = true;
277  // Handle the split axes: calculate the corresponding loop index for each
278  // split axes sub-array, and then store the sub-array to
279  // shardingOption[index]
280  for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
281  AffineExpr expr = std::get<0>(it);
282  ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
283  auto dim = cast<AffineDimExpr>(expr);
284  unsigned index = dim.getPosition();
285  visitedLoopIndices.insert(index);
286  if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
287  axes, index)))
288  return failure();
289  }
290 
291  // Handle the partial axes: at this stage, the exact loop index/indices
292  // cannot be decided because there could be multiple reduction loops.
293  ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
294  if (!partialAxes.empty()) {
295  if (!partialMeshAxes.empty())
296  return op->emitOpError() << "at most one result with partial axes is "
297  "supported at present";
298  partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
299  // Add all the reduction loop indices to `visitedLoopIndices` if
300  // `partialAxes` is not empty
301  for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
302  if (isReductionLoop(loopTypes[loopIdx]))
303  visitedLoopIndices.insert(loopIdx);
304  }
305  }
306  }
307 
308  // 2. Fill sharding option based on operands
309  for (auto shardingIt : llvm::enumerate(operandShardings)) {
310  MeshShardingAttr shardAttr = shardingIt.value();
311  if (!shardAttr)
312  continue;
313 
314  anyShardingInResultsOrOperands = true;
315  AffineMap map = maps[shardingIt.index()];
316  unsigned numDims = map.getNumDims();
317 
318  // Handle the split axes. Partial axes don't need to be handled because they
319  // only affect the defining op of the operand.
320  //
321  // TODO: Change to process the operands with single loop index first and
322  // then the operands with multiple loop indices.
323  for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
324  AffineExpr expr = std::get<0>(it);
325  ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
326  FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
327  checkOperandAffineExpr(expr, numDims);
328  if (failed(loopIndices))
329  return op->emitOpError()
330  << "operand's affine expression is restricted to const_i * "
331  "dim_i + const_j + dim_j + ...";
332  if (loopIndices->empty())
333  continue;
334  if (loopIndices->size() == 1) {
335  unsigned loopIdx = *loopIndices->begin();
336  visitedLoopIndices.insert(loopIdx);
337  if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
338  axes, loopIdx)))
339  return failure();
340  }
341  // If multiple loop indices correspond to a dimension of an operand, it is
342  // difficult to infer which loop indices are responsible for sharding.
343  // Therefore, the exact loop index must be specified by others.
344  if (loopIndices->size() > 1) {
345  bool seenLoopIndices = false;
346  for (unsigned loopIdx : *loopIndices) {
347  if (visitedLoopIndices.contains(loopIdx)) {
348  seenLoopIndices = true;
349  break;
350  }
351  }
352  if (!seenLoopIndices)
353  return op->emitOpError()
354  << "the operand " << shardingIt.index()
355  << " has multiple loop indices in a dimension, but none of "
356  "them could be found in the exactly specified annotation "
357  "of op results or operands.";
358  }
359  }
360  }
361 
362  // 3. Finalize sharding option
363  if (!partialMeshAxes.empty()) {
364  bool anyNonEmptyReductionLoop = llvm::any_of(
365  llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
366  SmallVector<MeshAxis> &subArray = it.value();
367  int64_t idx = it.index();
368  return isReductionLoop(loopTypes[idx]) && !subArray.empty();
369  });
370  if (!anyNonEmptyReductionLoop) {
371  bool filled = false;
372  for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
373  if (isReductionLoop(loopTypes[idx])) {
374  std::ignore = fillShardingOption(op, shardingOption, nullptr,
375  partialMeshAxes, idx);
376  filled = true;
377  break;
378  }
379  }
380  if (!filled)
381  return op->emitOpError() << "no matched reduction loop found for the "
382  "result's partial type";
383  }
384  }
386  if (!anyShardingInResultsOrOperands)
387  shardingOption.empty = true;
388  return shardingOption;
389 }
390 
391 // Get the sharding attributed for the given result and sharding option.
393 getShardingAttribute(OpResult result, const ShardingOption &shardingOption,
395  ArrayRef<ReductionKind> reductionLoopKinds) {
396  auto resultType = cast<RankedTensorType>(result.getType());
397  SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
398  SmallVector<MeshAxis> partialAxes;
399 
400  // process the split axes
401  for (auto it : llvm::enumerate(map.getResults())) {
402  AffineExpr expr = it.value();
403  // `expr` must be an `AffineDimExpr` because `map` is verified by
404  // isProjectedPermutation
405  auto dim = cast<AffineDimExpr>(expr);
406  unsigned loopIdx = dim.getPosition();
407  if (loopIdx < shardingOption.shardingArray.size())
408  splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
409  }
410 
411  // process the partial axes
412  // partialType will be ignored if partialAxes is empty
413  ReductionKind partialType = ReductionKind::Sum;
414  size_t reductionLoopKindsIdx = 0;
415  for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
416  utils::IteratorType iType = std::get<0>(it);
417  if (isReductionLoop(iType)) {
418  ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
419  ++reductionLoopKindsIdx;
420  if (!partialAxes.empty())
421  assert(partialType == curPartialType &&
422  "Only one reduction type is supported");
423  partialType = curPartialType;
424  const SmallVector<MeshAxis> &axis = std::get<1>(it);
425  partialAxes.append(axis);
426  }
427  }
428 
429  removeTrailingEmptySubArray(splitAxes);
430  return MeshShardingAttr::get(result.getContext(), shardingOption.mesh,
431  splitAxes, partialAxes, partialType);
432 }
433 
434 static FailureOr<MeshShardingAttr>
435 getShardingAttribute(OpOperand &opOperand, const ShardingOption &shardingOption,
436  AffineMap map) {
437  Value operandValue = opOperand.get();
438  auto operandType = cast<RankedTensorType>(operandValue.getType());
439  SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
440  unsigned numDims = map.getNumDims();
441  for (auto it : llvm::enumerate(map.getResults())) {
442  int64_t idx = it.index();
443  AffineExpr expr = it.value();
444  FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
445  checkOperandAffineExpr(expr, numDims);
446  if (failed(loopIndices))
447  return failure();
448  SmallVector<unsigned> shardedLoopIndices;
449  for (unsigned loopIdx : *loopIndices) {
450  if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
451  !shardingOption.shardingArray[loopIdx].empty())
452  shardedLoopIndices.push_back(loopIdx);
453  }
454  // mostly one sharded loop index is accepted
455  if (shardedLoopIndices.size() > 1)
456  return failure();
457  if (shardedLoopIndices.size() == 1) {
458  splitAxes[idx].append(
459  shardingOption.shardingArray[shardedLoopIndices[0]]);
460  }
461  }
462 
463  removeTrailingEmptySubArray(splitAxes);
464  return MeshShardingAttr::get(opOperand.get().getContext(),
465  shardingOption.mesh, splitAxes);
466 }
467 
468 FailureOr<SmallVector<MeshShardingAttr>>
470  Operation *op, const ShardingOption &shardingOption) {
472 
473  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
475  shardingOp.getLoopIteratorTypes();
476  SmallVector<ReductionKind> reductionKinds =
477  shardingOp.getReductionLoopIteratorKinds();
478  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
479  unsigned numOperands = op->getNumOperands();
480 
481  for (OpOperand &opOperand : op->getOpOperands()) {
482  FailureOr<MeshShardingAttr> shardingAttr = getShardingAttribute(
483  opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
484  if (failed(shardingAttr))
485  return failure();
486  res.push_back(*shardingAttr);
487  }
488 
489  for (OpResult result : op->getResults()) {
490  res.push_back(getShardingAttribute(
491  result, shardingOption, maps[numOperands + result.getResultNumber()],
492  loopTypes, reductionKinds));
493  }
494 
495  return res;
496 }
497 
498 //===----------------------------------------------------------------------===//
499 // detail::defaultAddShardingAnnotations
500 //===----------------------------------------------------------------------===//
501 
502 // To add a `mesh.shard` op for the given result, based on the details provided
503 // in `shardingOption`, `map`, and `loopTypes`.
504 static LogicalResult addShardOp(OpBuilder &b, OpResult result,
505  const ShardingOption &shardingOption,
506  AffineMap map,
508  ArrayRef<ReductionKind> reductionLoopKinds) {
510  result, shardingOption, map, loopTypes, reductionLoopKinds);
511  maybeInsertTargetShardingAnnotation(shardAttr, result, b);
512 
513  return success();
514 }
515 
516 // To add a `mesh.shard` op for the given operand, based on the details provided
517 // in `shardingOption`, `map`, and `loopTypes`.
518 static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
519  const ShardingOption &shardingOption,
520  AffineMap map) {
521 
522  FailureOr<MeshShardingAttr> shardAttr =
523  getShardingAttribute(opOperand, shardingOption, map);
524  if (failed(shardAttr)) {
525  return failure();
526  }
527  OpBuilder::InsertionGuard guard(b);
528  maybeInsertSourceShardingAnnotation(*shardAttr, opOperand, b);
529 
530  return success();
531 }
532 
534  Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
535  assert(!shardingOption.empty && shardingOption.mesh);
536 
537  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
539  shardingOp.getLoopIteratorTypes();
540  SmallVector<ReductionKind> reductionKinds =
541  shardingOp.getReductionLoopIteratorKinds();
542  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
543  unsigned numOperands = op->getNumOperands();
544 
545  // 1. add mesh.shard ops for all op results
546  for (OpResult result : op->getResults()) {
547  if (failed(addShardOp(b, result, shardingOption,
548  maps[numOperands + result.getResultNumber()],
549  loopTypes, reductionKinds)))
550  return failure();
551  }
552 
553  // 2. add mesh.shard ops for all operands
554  for (OpOperand &opOperand : op->getOpOperands()) {
555  if (failed(addShardOp(b, opOperand, shardingOption,
556  maps[opOperand.getOperandNumber()])))
557  return failure();
558  }
559 
560  return success();
561 }
562 
563 #ifndef NDEBUG
564 static bool
566  MeshShardingAttr sharding) {
567  if (isa<RankedTensorType>(value.getType())) {
568  return sharding && isFullReplication(sharding);
569  }
570 
571  return !sharding;
572 }
573 
574 template <typename ValueRange, typename MeshShardingAttrRage>
576  ValueRange &&values, MeshShardingAttrRage &&shardings) {
577  if (std::size(values) != std::size(shardings)) {
578  return false;
579  }
580  return llvm::all_of(llvm::zip_equal(
581  std::forward<ValueRange>(values),
582  std::forward<MeshShardingAttrRage>(shardings)),
583  [](auto valueAndSharding) {
585  std::get<0>(valueAndSharding),
586  std::get<1>(valueAndSharding));
587  });
588 }
589 #endif // NDEBUG
590 
592  Operation &op, ArrayRef<Value> spmdizedOperands,
593  ArrayRef<MeshShardingAttr> operandShardings,
594  ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
595  SymbolTableCollection &symbolTable, OpBuilder &builder) {
596  assert(spmdizedOperands.size() == operandShardings.size());
598  operandShardings));
600  resultShardings));
601  // `clone` will populate the mapping of old to new results.
602  builder.clone(op, spmdizationMap);
603 }
604 
606  ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
607  SmallVector<std::optional<SmallVector<MeshAxis>>>
608  &meshAxesAssignmentForLoopIterators) {
609  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
610  unsigned loopIteratorIdx = affineDimExpr.getPosition();
611  if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
612  assert(llvm::equal(meshAxesAssignmentForTensorAxis,
613  *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
614  } else {
615  meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
616  llvm::to_vector(meshAxesAssignmentForTensorAxis);
617  }
618 }
619 
621  ArrayRef<MeshShardingAttr> operandShardings,
622  ArrayRef<MeshShardingAttr> resultShardings,
623  ArrayRef<utils::IteratorType> loopIteratorTypes,
624  ArrayRef<AffineMap> indexingMaps) {
626  meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
627  SmallVector<MeshShardingAttr> operatorAndResultShardings;
628  operatorAndResultShardings.reserve(operandShardings.size() +
629  resultShardings.size());
630  llvm::append_range(operatorAndResultShardings, operandShardings);
631  for (auto [sharding, affineMap] :
632  llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
633  if (!sharding) {
634  continue;
635  }
636  for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
637  llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
639  meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
640  meshAxisAssignmentForLoopIterators);
641  }
642  // Missing trailing split axes means replication on those tensor dimensions.
643  for (unsigned i = sharding.getSplitAxes().size();
644  i < affineMap.getNumResults(); ++i) {
646  {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
647  }
648  }
649 
650  ShardingArray res;
651  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
652  [](std::optional<SmallVector<MeshAxis>> &axes) {
653  if (!axes) {
654  return SmallVector<MeshAxis>();
655  };
656  return std::move(*axes);
657  });
658  return res;
659 }
660 
662  ArrayRef<utils::IteratorType> loopIteratorTypes,
663  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
664  for (auto [loopIteratorType, meshAxisAssignment] :
665  llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
666  if (loopIteratorType == utils::IteratorType::reduction &&
667  !meshAxisAssignment.empty()) {
668  return true;
669  }
670  }
671  return false;
672 }
673 
675  ArrayRef<utils::IteratorType> loopIteratorTypes,
676  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
677  SmallVector<MeshAxis> meshAxes;
678  for (auto [loopIteratorType, meshAxisAssignment] :
679  llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
680  if (loopIteratorType == utils::IteratorType::reduction) {
681  llvm::append_range(meshAxes, meshAxisAssignment);
682  }
683  }
684  return meshAxes;
685 }
686 
688  Operation &op, ArrayRef<Value> spmdizedOperands,
689  ArrayRef<MeshShardingAttr> operandShardings,
690  ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
691  SymbolTableCollection &symbolTable, OpBuilder &builder) {
692  // `clone` will populate the mapping of old to new results.
693  Operation *newOp = builder.clone(op, spmdizationMap);
694  // Set the result types to the sharded counterparts.
695  for (auto [oldResult, newResult, sharding] :
696  llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
697  newResult.setType(shardType(newResult.getType(),
698  getMesh(&op, sharding.getMesh(), symbolTable),
699  sharding));
700  }
701 }
static void updateMeshAxisAssignmentForLoopIterators(ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)
static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshShardingAttr sharding)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)
static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, MeshShardingAttrRage &&shardings)
#define DBGS()
static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)
MeshShardingAttr getShardingAttribute(OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:236
unsigned getPosition() const
Definition: AffineExpr.cpp:348
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:595
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
This class helps build Operations.
Definition: Builders.h:210
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:559
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumOperands()
Definition: Operation.h:341
operand_type_range getOperandTypes()
Definition: Operation.h:392
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:132
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
mesh::ReductionKind ReductionKind
mesh::MeshShardingAttr MeshShardingAttr
FailureOr< SmallVector< MeshShardingAttr > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
FailureOr< std::pair< bool, MeshShardingAttr > > getMeshShardingAttr(OpResult result)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:182
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:67
bool isFullReplication(MeshShardingAttr attr)
Definition: MeshOps.h:53
bool isReductionLoop(utils::IteratorType iType)
Definition: MeshOps.h:42
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
int16_t MeshAxis
Definition: MeshOps.h:25
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:172
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Definition: MeshOps.h:47
void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:222
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...