MLIR  19.0.0git
ShardingInterface.cpp
Go to the documentation of this file.
1 //===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/Debug.h"
20 
21 #include <utility>
22 
23 #define DEBUG_TYPE "sharding-interface"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
25 
26 using namespace mlir;
27 using namespace mlir::mesh;
28 
29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
30 
31 //===----------------------------------------------------------------------===//
32 // common util functions
33 //===----------------------------------------------------------------------===//
34 
35 static LogicalResult
37  SmallVectorImpl<bool> &seenIds) {
38  switch (expr.getKind()) {
39  case AffineExprKind::Add: {
40  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
41  AffineExpr lhs = binOpExpr.getLHS();
42  AffineExpr rhs = binOpExpr.getRHS();
43  if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
44  return failure();
45  if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
46  return failure();
47  return success();
48  }
49  case AffineExprKind::Mul: {
50  auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
51  AffineExpr lhs = binOpExpr.getLHS();
52  AffineExpr rhs = binOpExpr.getRHS();
53  AffineExpr dimExpr;
54  if (lhs.getKind() == AffineExprKind::DimId &&
56  dimExpr = lhs;
57  } else if (rhs.getKind() == AffineExprKind::DimId &&
59  dimExpr = rhs;
60  } else
61  return failure();
62  unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
63  if ((size_t)position >= seenIds.size() || seenIds[position])
64  return failure();
65  seenIds[position] = true;
66  return success();
67  }
68  case AffineExprKind::DimId: {
69  unsigned position = cast<AffineDimExpr>(expr).getPosition();
70  if ((size_t)position >= seenIds.size() || seenIds[position])
71  return failure();
72  seenIds[position] = true;
73  return success();
74  }
75  default:
76  return failure();
77  }
78 }
79 
81 checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
82  SmallVector<bool> seenIds(numDims, false);
83  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
84  return failure();
85 
86  llvm::SmallSet<unsigned, 2> positions;
87  for (auto it : llvm::enumerate(seenIds)) {
88  if (it.value())
89  positions.insert((unsigned)it.index());
90  }
91  return positions;
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // mesh::getMeshShardingAttr
96 //===----------------------------------------------------------------------===//
97 
100  Value val = cast<Value>(result);
101  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
102  auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
103  if (!shardOp)
104  return false;
105  return !shardOp.getAnnotateForUsers();
106  });
107 
108  if (anyShardedForDef) {
109  // expected to have exact one use if it has a use of `mesh.shard` without
110  // unit attr annotate_for_users
111  if (!val.hasOneUse())
112  return failure();
113  auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
114  return std::make_pair(false, shardOp.getShard());
115  }
116 
117  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
118  auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
119  if (!shardOp)
120  return false;
121  return shardOp.getAnnotateForUsers();
122  });
123  if (anyShardedForUsers) {
124  SmallVector<ShardOp> shardOps;
125  for (Operation *user : val.getUsers()) {
126  ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
127  if (shardOp)
128  shardOps.push_back(shardOp);
129  }
130  MeshShardingAttr shardForDef = shardOps[0].getShard();
131  for (size_t i = 1; i < shardOps.size(); ++i) {
132  // TODO: Deduce a reasonable mesh sharding attr for def when they are
133  // different
134  assert(shardOps[i].getShard() == shardForDef &&
135  "only support all shard ops have the same mesh sharding attr");
136  }
137  return std::make_pair(true, shardForDef);
138  }
139  return failure();
140 }
141 
144  Value val = opOperand.get();
145  if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
146  return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
147 
148  return failure();
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // ShardingInterface::verifyShardingInterfaceImpl
153 //===----------------------------------------------------------------------===//
154 
155 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
156  Operation *op = getOperation();
157 
158  // check operands and results type
159  for (Type type : op->getOperandTypes())
160  if (!llvm::isa<RankedTensorType>(type))
161  return failure();
162  for (Type type : op->getResultTypes())
163  if (!llvm::isa<RankedTensorType>(type))
164  return failure();
165 
166  // check loop types
167  SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
168  if (loopTypes.size() == 0)
169  return failure();
170 
171  // check maps
172  SmallVector<AffineMap> maps = getIndexingMaps();
173  if (maps.size() == 0)
174  return failure();
175  unsigned numOperands = op->getNumOperands();
176  unsigned numResults = op->getNumResults();
177  if (numOperands + numResults != maps.size())
178  return failure();
179 
180  for (OpResult result : op->getResults()) {
181  auto resultType = dyn_cast<RankedTensorType>(result.getType());
182  if (!resultType)
183  return failure();
184  AffineMap map = maps[numOperands + result.getResultNumber()];
185  if (!map.isProjectedPermutation()) {
186  return failure();
187  }
188  }
189 
190  return success();
191 }
192 
193 //===----------------------------------------------------------------------===//
194 // ShardingInterface::printLoopTypesAndIndexingMaps
195 //===----------------------------------------------------------------------===//
196 
197 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
198  os << "print loop types and indexing maps for: \n";
199  getOperation()->print(os);
200  os << "\n";
201  os << "loop types: [";
202  for (utils::IteratorType type : getLoopIteratorTypes()) {
203  os << stringifyEnum(type) << " ";
204  }
205  os << "]\n";
206  os << "indexing maps: \n";
207  for (AffineMap map : getIndexingMaps())
208  os << map << "\n";
209  os << "\n";
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // detail::defaultGetShardingOption
214 //===----------------------------------------------------------------------===//
215 
216 namespace {
217 
218 // Update the given `shardingOption` according to `meshAxes` and `loopIdx`
219 static LogicalResult fillShardingOption(Operation *op,
220  ShardingOption &shardingOption,
221  FlatSymbolRefAttr mesh,
222  ArrayRef<MeshAxis> meshAxes,
223  unsigned loopIdx) {
224  if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
225  (!shardingOption.shardingArray[loopIdx].empty() &&
226  shardingOption.shardingArray[loopIdx] != meshAxes)) {
227  LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
228  << loopIdx << "\n");
229  return failure();
230  }
231  for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
232  if (i == loopIdx)
233  continue;
234 
235  for (MeshAxis axis : meshAxes) {
236  if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
237  LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
238  << axis << " duplicate");
239  return failure();
240  }
241  }
242  }
243  if (mesh)
244  shardingOption.mesh = mesh;
245  if (shardingOption.shardingArray[loopIdx].empty())
246  shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
247  meshAxes.end());
248  return success();
249 }
250 
251 } // namespace
252 
// Derives a ShardingOption for `op`: a mesh plus, for each loop iterator, the
// mesh axes it is sharded over. The option is built from the shardings
// annotated on the op's results and operands; null entries in
// `resultShardings`/`operandShardings` are skipped.
//   1. Each result's split axes are mapped onto loops through the result's
//      indexing map. Partial axes are kept for step 3, and all reduction loops
//      are marked visited.
//   2. An operand dimension whose expression involves exactly one loop is
//      mapped onto that loop. A dimension involving several loops must
//      include at least one loop already visited by an earlier annotation.
//   3. If partial axes were seen and no reduction loop is sharded yet, they
//      are assigned to the first reduction loop.
// Emits an op error for:
//   - an invalid interface implementation;
//   - partial axes on more than one result;
//   - unsupported operand expressions;
//   - unresolvable multi-loop operand dimensions;
//   - a missing reduction loop for partial axes.
// Returns a plain failure() when fillShardingOption reports a conflict. The
// returned option has `empty` set when no result or operand is sharded.
//
// NOTE(review): this listing dropped original line 253, the start of the
// signature. It is presumably
// `FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(`;
// restore it from upstream.
254  Operation *op, ArrayRef<MeshShardingAttr> operandShardings,
255  ArrayRef<MeshShardingAttr> resultShardings) {
256  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
257  ShardingOption shardingOption;
258 
259  if (failed(shardingOp.verifyShardingInterfaceImpl()))
260  return op->emitOpError() << "invalid sharding interface implementation";
// NOTE(review): original line 261 is missing here; presumably
// `SmallVector<utils::IteratorType> loopTypes =` as in
// verifyShardingInterfaceImpl.
262  shardingOp.getLoopIteratorTypes();
263  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
264  unsigned numOperands = op->getNumOperands();
// One (initially empty) list of mesh axes per loop iterator.
265  shardingOption.shardingArray.resize(loopTypes.size());
266  llvm::SmallVector<MeshAxis> partialMeshAxes;
// Loops fixed by some annotation; consulted in step 2 for multi-loop dims.
267  llvm::SmallSet<unsigned, 4> visitedLoopIndices;
268  bool anyShardingInResultsOrOperands = false;
269 
270  // 1. Fill sharding option based on op results
271  for (auto shardingIt : llvm::enumerate(resultShardings)) {
272  MeshShardingAttr shardAttr = shardingIt.value();
273  if (!shardAttr)
274  continue;
// Result maps follow the operand maps in `maps`.
275  AffineMap map = maps[numOperands + shardingIt.index()];
276  anyShardingInResultsOrOperands = true;
277  // Handle the split axes: calculate the corresponding loop index for each
278  // split axes sub-array, and then store the sub-array to
279  // shardingOption[index]
280  for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
281  AffineExpr expr = std::get<0>(it);
282  ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
// The cast is safe because verifyShardingInterfaceImpl requires result maps
// to be projected permutations.
283  auto dim = cast<AffineDimExpr>(expr);
284  unsigned index = dim.getPosition();
285  visitedLoopIndices.insert(index);
286  if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
287  axes, index)))
288  return failure();
289  }
290 
291  // Handle the partial axes: at this stage, the exact loop index/indices
292  // cannot be decided because there could be multiple reduction loops.
293  ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
294  if (!partialAxes.empty()) {
295  if (!partialMeshAxes.empty())
296  return op->emitOpError() << "at most one result with partial axes is "
297  "supported at present";
298  partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
299  // Add all the reduction loop indices to `visitedLoopIndices` if
300  // `partialAxes` is not empty
301  for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
302  if (isReductionLoop(loopTypes[loopIdx]))
303  visitedLoopIndices.insert(loopIdx);
304  }
305  }
306  }
307 
308  // 2. Fill sharding option based on operands
309  for (auto shardingIt : llvm::enumerate(operandShardings)) {
310  MeshShardingAttr shardAttr = shardingIt.value();
311  if (!shardAttr)
312  continue;
313 
314  anyShardingInResultsOrOperands = true;
315  AffineMap map = maps[shardingIt.index()];
316  unsigned numDims = map.getNumDims();
317 
318  // Handle the split axes. Partial axes don't need to be handled because they
319  // only affect the defining op of the operand.
320  //
321  // TODO: Change to process the operands with single loop index first and
322  // then the operands with multiple loop indices.
323  for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
324  AffineExpr expr = std::get<0>(it);
325  ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
// NOTE(review): original line 326 is missing here; presumably
// `FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =`.
327  checkOperandAffineExpr(expr, numDims);
328  if (failed(loopIndices))
// NOTE(review): `const_j + dim_j` in this message presumably means
// `const_j * dim_j`. The message text is left unchanged.
329  return op->emitOpError()
330  << "operand's affine expression is restricted to const_i * "
331  "dim_i + const_j + dim_j + ...";
// Dimensions that do not depend on any loop carry no loop sharding.
332  if (loopIndices->empty())
333  continue;
334  if (loopIndices->size() == 1) {
335  unsigned loopIdx = *loopIndices->begin();
336  visitedLoopIndices.insert(loopIdx);
337  if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
338  axes, loopIdx)))
339  return failure();
340  }
341  // If multiple loop indices correspond to a dimension of an operand, it is
342  // difficult to infer which loop indices are responsible for sharding.
343  // Therefore, the exact loop index must be specified by others.
344  if (loopIndices->size() > 1) {
345  bool seenLoopIndices = false;
346  for (unsigned loopIdx : *loopIndices) {
347  if (visitedLoopIndices.contains(loopIdx)) {
348  seenLoopIndices = true;
349  break;
350  }
351  }
352  if (!seenLoopIndices)
353  return op->emitOpError()
354  << "the operand " << shardingIt.index()
355  << " has multiple loop indices in a dimension, but none of "
356  "them could be found in the exactly specified annotation "
357  "of op results or operands.";
358  }
359  }
360  }
361 
362  // 3. Finalize sharding option
363  if (!partialMeshAxes.empty()) {
364  bool anyNonEmptyReductionLoop = llvm::any_of(
365  llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
366  SmallVector<MeshAxis> &subArray = it.value();
367  int64_t idx = it.index();
368  return isReductionLoop(loopTypes[idx]) && !subArray.empty();
369  });
370  if (!anyNonEmptyReductionLoop) {
371  bool filled = false;
372  for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
373  if (isReductionLoop(loopTypes[idx])) {
// NOTE(review): the result is discarded and `filled` is set regardless.
// A conflict, such as a partial axis already used by another loop, therefore
// goes unreported. Confirm this is intended.
374  std::ignore = fillShardingOption(op, shardingOption, nullptr,
375  partialMeshAxes, idx);
376  filled = true;
377  break;
378  }
379  }
380  if (!filled)
381  return op->emitOpError() << "no matched reduction loop found for the "
382  "result's partial type";
383  }
384  }
// NOTE(review): original line 385 is missing here, and its content cannot be
// recovered from this listing. It may trim `shardingOption.shardingArray`:
// both addShardOp overloads bound-check loop indices against its size.
386  if (!anyShardingInResultsOrOperands)
387  shardingOption.empty = true;
388  return shardingOption;
389 }
390 
391 //===----------------------------------------------------------------------===//
392 // detail::defaultAddShardingAnnotations
393 //===----------------------------------------------------------------------===//
394 
395 // To add a `mesh.shard` op for the given result, based on the details provided
396 // in `shardingOption`, `map`, and `loopTypes`.
398  const ShardingOption &shardingOption,
399  AffineMap map,
401  ArrayRef<ReductionKind> reductionLoopKinds) {
403  getMeshShardingAttr(result);
404  if (succeeded(maybeSharding) && !maybeSharding->first)
405  return success();
406 
407  auto resultType = cast<RankedTensorType>(result.getType());
408  SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
409  SmallVector<MeshAxis> partialAxes;
410 
411  // process the split axes
412  for (auto it : llvm::enumerate(map.getResults())) {
413  AffineExpr expr = it.value();
414  // `expr` must be an `AffineDimExpr` because `map` is verified by
415  // isProjectedPermutation
416  auto dim = cast<AffineDimExpr>(expr);
417  unsigned loopIdx = dim.getPosition();
418  if (loopIdx < shardingOption.shardingArray.size())
419  splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
420  }
421 
422  // process the partial axes
423  // partialType will be ignored if partialAxes is empty
424  ReductionKind partialType = ReductionKind::Sum;
425  size_t reductionLoopKindsIdx = 0;
426  for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
427  utils::IteratorType iType = std::get<0>(it);
428  if (isReductionLoop(iType)) {
429  ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
430  ++reductionLoopKindsIdx;
431  if (!partialAxes.empty())
432  assert(partialType == curPartialType &&
433  "Only one reduction type is supported");
434  partialType = curPartialType;
435  const SmallVector<MeshAxis> &axis = std::get<1>(it);
436  partialAxes.append(axis);
437  }
438  }
439 
440  removeTrailingEmptySubArray(splitAxes);
442  b.getContext(), shardingOption.mesh, splitAxes, partialAxes, partialType);
443  OpBuilder::InsertionGuard guard(b);
444  b.setInsertionPointAfterValue(result);
445  auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
446  shardAttr, /*annotate_for_users*/ false);
447  result.replaceAllUsesExcept(shardOp, shardOp);
448  return success();
449 }
450 
451 // To add a `mesh.shard` op for the given operand, based on the details provided
452 // in `shardingOption`, `map`, and `loopTypes`.
454  const ShardingOption &shardingOption,
455  AffineMap map) {
456  auto maybeShardingAttr = getMeshShardingAttr(opOperand);
457  if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
458  return success();
459  Value operand = opOperand.get();
460  auto operandType = cast<RankedTensorType>(operand.getType());
461  SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
462  unsigned numDims = map.getNumDims();
463  for (auto it : llvm::enumerate(map.getResults())) {
464  int64_t idx = it.index();
465  AffineExpr expr = it.value();
467  checkOperandAffineExpr(expr, numDims);
468  if (failed(loopIndices))
469  return failure();
470  SmallVector<unsigned> shardedLoopIndices;
471  for (unsigned loopIdx : *loopIndices) {
472  if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
473  !shardingOption.shardingArray[loopIdx].empty())
474  shardedLoopIndices.push_back(loopIdx);
475  }
476  // mostly one sharded loop index is accepted
477  if (shardedLoopIndices.size() > 1)
478  return failure();
479  if (shardedLoopIndices.size() == 1) {
480  splitAxes[idx].append(
481  shardingOption.shardingArray[shardedLoopIndices[0]]);
482  }
483  }
484 
485  removeTrailingEmptySubArray(splitAxes);
486  MeshShardingAttr shardAttr =
487  MeshShardingAttr::get(b.getContext(), shardingOption.mesh, splitAxes);
488  OpBuilder::InsertionGuard guard(b);
489  b.setInsertionPoint(opOperand.getOwner());
490  auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
491  shardAttr, true);
492  opOperand.set(shardOp);
493 
494  return success();
495 }
496 
498  Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
499  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
501  shardingOp.getLoopIteratorTypes();
502  SmallVector<ReductionKind> reductionKinds =
503  shardingOp.getReductionLoopIteratorKinds();
504  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
505  unsigned numOperands = op->getNumOperands();
506 
507  // 1. add mesh.shard ops for all op results
508  for (OpResult result : op->getResults()) {
509  if (failed(addShardOp(b, result, shardingOption,
510  maps[numOperands + result.getResultNumber()],
511  loopTypes, reductionKinds)))
512  return failure();
513  }
514 
515  // 2. add mesh.shard ops for all operands
516  for (OpOperand &opOperand : op->getOpOperands()) {
517  if (failed(addShardOp(b, opOperand, shardingOption,
518  maps[opOperand.getOperandNumber()])))
519  return failure();
520  }
521 
522  return success();
523 }
524 
525 #ifndef NDEBUG
526 static bool
528  MeshShardingAttr sharding) {
529  if (isa<RankedTensorType>(value.getType())) {
530  return sharding && isFullReplication(sharding);
531  }
532 
533  return !sharding;
534 }
535 
536 template <typename ValueRange, typename MeshShardingAttrRage>
538  ValueRange &&values, MeshShardingAttrRage &&shardings) {
539  if (std::size(values) != std::size(shardings)) {
540  return false;
541  }
542  return llvm::all_of(llvm::zip_equal(
543  std::forward<ValueRange>(values),
544  std::forward<MeshShardingAttrRage>(shardings)),
545  [](auto valueAndSharding) {
547  std::get<0>(valueAndSharding),
548  std::get<1>(valueAndSharding));
549  });
550 }
551 #endif // NDEBUG
552 
554  Operation &op, ArrayRef<Value> spmdizedOperands,
555  ArrayRef<MeshShardingAttr> operandShardings,
556  ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
557  SymbolTableCollection &symbolTable, OpBuilder &builder) {
558  assert(spmdizedOperands.size() == operandShardings.size());
560  operandShardings));
562  resultShardings));
563  // `clone` will populate the mapping of old to new results.
564  builder.clone(op, spmdizationMap);
565 }
566 
568  ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
569  SmallVector<std::optional<SmallVector<MeshAxis>>>
570  &meshAxesAssignmentForLoopIterators) {
571  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
572  unsigned loopIteratorIdx = affineDimExpr.getPosition();
573  if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
574  assert(llvm::equal(meshAxesAssignmentForTensorAxis,
575  *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
576  } else {
577  meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
578  llvm::to_vector(meshAxesAssignmentForTensorAxis);
579  }
580 }
581 
583  ArrayRef<MeshShardingAttr> operandShardings,
584  ArrayRef<MeshShardingAttr> resultShardings,
585  ArrayRef<utils::IteratorType> loopIteratorTypes,
586  ArrayRef<AffineMap> indexingMaps) {
588  meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
589  SmallVector<MeshShardingAttr> operatorAndResultShardings;
590  operatorAndResultShardings.reserve(operandShardings.size() +
591  resultShardings.size());
592  llvm::append_range(operatorAndResultShardings, operandShardings);
593  for (auto [sharding, affineMap] :
594  llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
595  if (!sharding) {
596  continue;
597  }
598  for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
599  llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
601  meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
602  meshAxisAssignmentForLoopIterators);
603  }
604  // Missing trailing split axes means replication on those tensor dimensions.
605  for (unsigned i = sharding.getSplitAxes().size();
606  i < affineMap.getNumResults(); ++i) {
608  {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
609  }
610  }
611 
612  ShardingArray res;
613  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
614  [](std::optional<SmallVector<MeshAxis>> &axes) {
615  if (!axes) {
616  return SmallVector<MeshAxis>();
617  };
618  return std::move(*axes);
619  });
620  return res;
621 }
622 
624  ArrayRef<utils::IteratorType> loopIteratorTypes,
625  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
626  for (auto [loopIteratorType, meshAxisAssignment] :
627  llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
628  if (loopIteratorType == utils::IteratorType::reduction &&
629  !meshAxisAssignment.empty()) {
630  return true;
631  }
632  }
633  return false;
634 }
635 
637  ArrayRef<utils::IteratorType> loopIteratorTypes,
638  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
639  SmallVector<MeshAxis> meshAxes;
640  for (auto [loopIteratorType, meshAxisAssignment] :
641  llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
642  if (loopIteratorType == utils::IteratorType::reduction) {
643  llvm::append_range(meshAxes, meshAxisAssignment);
644  }
645  }
646  return meshAxes;
647 }
648 
650  Operation &op, ArrayRef<Value> spmdizedOperands,
651  ArrayRef<MeshShardingAttr> operandShardings,
652  ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
653  SymbolTableCollection &symbolTable, OpBuilder &builder) {
654  // `clone` will populate the mapping of old to new results.
655  Operation *newOp = builder.clone(op, spmdizationMap);
656  // Set the result types to the sharded counterparts.
657  for (auto [oldResult, newResult, sharding] :
658  llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
659  newResult.setType(shardType(newResult.getType(),
660  getMesh(&op, sharding.getMesh(), symbolTable),
661  sharding));
662  }
663 }
static void updateMeshAxisAssignmentForLoopIterators(ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)
static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshShardingAttr sharding)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)
static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, MeshShardingAttrRage &&shardings)
#define DBGS()
static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:237
unsigned getPosition() const
Definition: AffineExpr.cpp:340
Base type for affine expression.
Definition: AffineExpr.h:69
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:27
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:579
unsigned getNumDims() const
Definition: AffineMap.cpp:378
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:391
MLIRContext * getContext() const
Definition: Builders.h:55
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:423
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumOperands()
Definition: Operation.h:341
operand_type_range getOperandTypes()
Definition: Operation.h:392
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void replaceAllUsesExcept(Value newValue, const SmallPtrSetImpl< Operation * > &exceptions)
Replace all uses of 'this' value with 'newValue', updating anything in the IR that uses 'this' to use...
Definition: Value.cpp:61
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
mesh::ReductionKind ReductionKind
mesh::MeshShardingAttr MeshShardingAttr
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
FailureOr< std::pair< bool, MeshShardingAttr > > getMeshShardingAttr(OpResult result)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:57
bool isFullReplication(MeshShardingAttr attr)
Definition: MeshOps.h:53
bool isReductionLoop(utils::IteratorType iType)
Definition: MeshOps.h:42
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
int16_t MeshAxis
Definition: MeshOps.h:25
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:171
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Definition: MeshOps.h:47
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26