//===- ShardingInterface.cpp ------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"

#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

#include <utility>

#define DEBUG_TYPE "sharding-interface"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::shard;

#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc"

//===----------------------------------------------------------------------===//
// common util functions
//===----------------------------------------------------------------------===//

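// Recursively check that `expr` is a sum of terms, where each term is either a
// dimension identifier or a dimension identifier multiplied by a constant, and
// that no dimension appears more than once. Visited dimensions are recorded in
// `seenIds`.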
static LogicalResult
checkOperandAffineExprRecursively(AffineExpr expr,
                                  SmallVectorImpl<bool> &seenIds) {
  switch (expr.getKind()) {
  case AffineExprKind::Add: {
    auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
    AffineExpr lhs = binOpExpr.getLHS();
    AffineExpr rhs = binOpExpr.getRHS();
    if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
      return failure();
    if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
      return failure();
    return success();
  }
  case AffineExprKind::Mul: {
    auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
    AffineExpr lhs = binOpExpr.getLHS();
    AffineExpr rhs = binOpExpr.getRHS();
    AffineExpr dimExpr;
    if (lhs.getKind() == AffineExprKind::DimId &&
        rhs.getKind() == AffineExprKind::Constant) {
      dimExpr = lhs;
    } else if (rhs.getKind() == AffineExprKind::DimId &&
               lhs.getKind() == AffineExprKind::Constant) {
      dimExpr = rhs;
    } else {
      return failure();
    }
    unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
    if ((size_t)position >= seenIds.size() || seenIds[position])
      return failure();
    seenIds[position] = true;
    return success();
  }
  case AffineExprKind::DimId: {
    unsigned position = cast<AffineDimExpr>(expr).getPosition();
    if ((size_t)position >= seenIds.size() || seenIds[position])
      return failure();
    seenIds[position] = true;
    return success();
  }
  default:
    return failure();
  }
}

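// Wrapper around checkOperandAffineExprRecursively that returns the set of
// loop (dimension) positions appearing in `expr`, or failure if the expression
// does not have the supported form.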
static FailureOr<llvm::SmallSet<unsigned, 2>>
checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
  SmallVector<bool> seenIds(numDims, false);
  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
    return failure();

  llvm::SmallSet<unsigned, 2> positions;
  for (auto it : llvm::enumerate(seenIds)) {
    if (it.value())
      positions.insert((unsigned)it.index());
  }
  return positions;
}

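// Convert an array of grid-axis vectors into the corresponding list of
// GridAxesAttr, as expected by Sharding::get below.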
template <typename T>
SmallVector<GridAxesAttr>
fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) {
  SmallVector<GridAxesAttr> res;
  for (const auto &v : vec) {
    res.emplace_back(GridAxesAttr::get(ctxt, v));
  }
  return res;
}

//===----------------------------------------------------------------------===//
// shard::getSharding
//===----------------------------------------------------------------------===//

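// Return the sharding annotated on `result` through its `shard.shard` users.
// The boolean in the returned pair is true when the annotation is meant for
// the users (annotate_for_users) rather than for the defining op.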
FailureOr<std::pair<bool, Sharding>> shard::getSharding(OpResult result) {
  Value val = cast<Value>(result);
  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<shard::ShardOp>(user);
    if (!shardOp)
      return false;
    return !shardOp.getAnnotateForUsers();
  });

  if (anyShardedForDef) {
    // Expected to have exactly one use if it has a `shard.shard` user without
    // the unit attribute annotate_for_users.
    if (!val.hasOneUse())
      return failure();
    auto shardOp = llvm::cast<shard::ShardOp>(*val.getUsers().begin());
    return std::make_pair(false, Sharding(shardOp.getSharding()));
  }

  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<shard::ShardOp>(user);
    if (!shardOp)
      return false;
    return shardOp.getAnnotateForUsers();
  });
  if (anyShardedForUsers) {
    SmallVector<ShardOp> shardOps;
    for (Operation *user : val.getUsers()) {
      ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
      if (shardOp)
        shardOps.push_back(shardOp);
    }
    Sharding shardForDef = shardOps[0].getSharding();
    for (size_t i = 1; i < shardOps.size(); ++i) {
      // TODO: Deduce a reasonable grid sharding attr for def when they are
      // different.
      assert(shardForDef == shardOps[i].getSharding() &&
             "expected all shard ops to have the same grid sharding attr");
    }
    return std::make_pair(true, shardForDef);
  }
  return failure();
}

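// Return the sharding annotated on `opOperand`, i.e. the sharding of a
// defining `shard.shard` op, together with its annotate_for_users flag.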
FailureOr<std::pair<bool, Sharding>> shard::getSharding(OpOperand &opOperand) {
  Value val = opOperand.get();
  if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
    return std::make_pair(shardOp.getAnnotateForUsers(),
                          Sharding(shardOp.getSharding()));

  return failure();
}

//===----------------------------------------------------------------------===//
// ShardingInterface::verifyShardingInterfaceImpl
//===----------------------------------------------------------------------===//

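// Sanity-check an op implementing the interface: operand and result types must
// be ranked tensors or scalars, the number of indexing maps must equal the
// number of operands plus results, and every result map must be a projected
// permutation.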
LogicalResult shard::ShardingInterface::verifyShardingInterfaceImpl() {
  Operation *op = getOperation();

  // Check that operand and result types are ranked tensors or scalars.
  for (Type type : op->getOperandTypes())
    if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
      return failure();
  for (Type type : op->getResultTypes())
    if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
      return failure();

  // Check the indexing maps.
  SmallVector<AffineMap> maps = getIndexingMaps();
  if (maps.empty())
    return failure();
  unsigned numOperands = op->getNumOperands();
  unsigned numResults = op->getNumResults();
  if (numOperands + numResults != maps.size())
    return failure();

  for (OpResult result : op->getResults()) {
    auto resultType = dyn_cast<RankedTensorType>(result.getType());
    if (!resultType)
      return failure();
    AffineMap map = maps[numOperands + result.getResultNumber()];
    if (!map.isProjectedPermutation()) {
      return failure();
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ShardingInterface::printLoopTypesAndIndexingMaps
//===----------------------------------------------------------------------===//

void shard::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
  os << "print loop types and indexing maps for: \n";
  getOperation()->print(os);
  os << "\n";
  os << "loop types: [";
  for (utils::IteratorType type : getLoopIteratorTypes()) {
    os << stringifyEnum(type) << " ";
  }
  os << "]\n";
  os << "indexing maps: \n";
  for (AffineMap map : getIndexingMaps())
    os << map << "\n";
  os << "\n";
}

//===----------------------------------------------------------------------===//
// detail::defaultGetShardingOption
//===----------------------------------------------------------------------===//

namespace {

// Update the given `shardingOption` according to `gridAxes` and `loopIdx`.
static LogicalResult fillShardingOption(Operation *op,
                                        ShardingOption &shardingOption,
                                        FlatSymbolRefAttr grid,
                                        ArrayRef<GridAxis> gridAxes,
                                        unsigned loopIdx) {
  if ((shardingOption.grid && grid && shardingOption.grid != grid) ||
      (!shardingOption.shardingArray[loopIdx].empty() &&
       shardingOption.shardingArray[loopIdx] != gridAxes)) {
    LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
                      << loopIdx << "\n");
    return failure();
  }
  for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
    if (i == loopIdx)
      continue;

    for (GridAxis axis : gridAxes) {
      if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
        LLVM_DEBUG(DBGS() << "sharding option conflicts because grid axes "
                          << axis << " duplicate\n");
        return failure();
      }
    }
  }
  if (grid)
    shardingOption.grid = grid;
  if (shardingOption.shardingArray[loopIdx].empty())
    shardingOption.shardingArray[loopIdx].append(gridAxes.begin(),
                                                 gridAxes.end());
  return success();
}

} // namespace

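// Derive a ShardingOption for `op` by mapping the split axes of the annotated
// results and operands back to the op's loop iterators through its indexing
// maps.
//
// Illustrative example (hypothetical values, not taken from the source): for a
// matmul-like op with indexing maps (d0, d2), (d2, d1), (d0, d1) and a result
// annotated with split axes [[0], [1]], loop d0 is assigned grid axis 0 and
// loop d1 grid axis 1, giving shardingArray = [[0], [1], []]; the trailing
// empty entry for the reduction loop d2 is trimmed in step 3 below.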
FailureOr<ShardingOption>
shard::detail::defaultGetShardingOption(Operation *op,
                                        ArrayRef<Sharding> operandShardings,
                                        ArrayRef<Sharding> resultShardings) {
  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  ShardingOption shardingOption;

  if (failed(shardingOp.verifyShardingInterfaceImpl()))
    return op->emitOpError() << "invalid sharding interface implementation";
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();
  shardingOption.shardingArray.resize(loopTypes.size());
  llvm::SmallSet<unsigned, 4> visitedLoopIndices;
  bool anyShardingInResultsOrOperands = false;

  // 1. Fill sharding option based on op results
  for (auto shardingIt : llvm::enumerate(resultShardings)) {
    Sharding shardAttr = shardingIt.value();
    if (!shardAttr)
      continue;
    AffineMap map = maps[numOperands + shardingIt.index()];
    anyShardingInResultsOrOperands = true;
    if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) {
      shardingOption.grid = shardAttr.getGridAttr();
    } else {
      // Handle the split axes: calculate the corresponding loop index for each
      // split axes sub-array, and then store the sub-array to
      // shardingOption[index]
      for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
        AffineExpr expr = std::get<0>(it);
        ArrayRef<GridAxis> axes = std::get<1>(it).asArrayRef();
        auto dim = cast<AffineDimExpr>(expr);
        unsigned index = dim.getPosition();
        visitedLoopIndices.insert(index);
        if (failed(fillShardingOption(op, shardingOption,
                                      shardAttr.getGridAttr(), axes, index)))
          return failure();
      }
    }
  }

  // 2. Fill sharding option based on operands
  for (auto shardingIt : llvm::enumerate(operandShardings)) {
    Sharding shardAttr = shardingIt.value();
    if (!shardAttr)
      continue;

    anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty();
    AffineMap map = maps[shardingIt.index()];
    unsigned numDims = map.getNumDims();

    // Handle the split axes.
    //
    // TODO: Change to process the operands with single loop index first and
    // then the operands with multiple loop indices.
    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
      AffineExpr expr = std::get<0>(it);
      ArrayRef<GridAxis> axes = std::get<1>(it).asArrayRef();
      FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
          checkOperandAffineExpr(expr, numDims);
      if (failed(loopIndices))
        return op->emitOpError()
               << "operand's affine expression is restricted to const_i * "
                  "dim_i + const_j * dim_j + ...";
      if (loopIndices->empty())
        continue;
      if (loopIndices->size() == 1) {
        unsigned loopIdx = *loopIndices->begin();
        visitedLoopIndices.insert(loopIdx);
        if (failed(fillShardingOption(op, shardingOption,
                                      shardAttr.getGridAttr(), axes, loopIdx)))
          return failure();
      }
      // If multiple loop indices correspond to a dimension of an operand, it
      // is difficult to infer which loop indices are responsible for sharding.
      // Therefore, the exact loop index must be specified by others.
      if (loopIndices->size() > 1) {
        bool seenLoopIndices = false;
        for (unsigned loopIdx : *loopIndices) {
          if (visitedLoopIndices.contains(loopIdx)) {
            seenLoopIndices = true;
            break;
          }
        }
        if (!seenLoopIndices)
          return op->emitOpError()
                 << "the operand " << shardingIt.index()
                 << " has multiple loop indices in a dimension, but none of "
                    "them could be found in the exactly specified annotation "
                    "of op results or operands.";
      }
    }
  }

  // 3. Finalize sharding option
  removeTrailingEmptySubArray(shardingOption.shardingArray);
  if (!anyShardingInResultsOrOperands)
    shardingOption.empty = true;
  return shardingOption;
}

// Get the sharding attribute for the given result, based on the sharding
// option and the result's indexing map.
static Sharding getSharding(OpResult result,
                            const ShardingOption &shardingOption, AffineMap map,
                            ArrayRef<utils::IteratorType> loopTypes) {
  auto resultType = cast<RankedTensorType>(result.getType());
  SmallVector<SmallVector<GridAxis>> splitAxes(resultType.getRank());

  // Process the split axes.
  for (auto it : llvm::enumerate(map.getResults())) {
    AffineExpr expr = it.value();
    // `expr` must be an `AffineDimExpr` because `map` is verified by
    // isProjectedPermutation.
    auto dim = cast<AffineDimExpr>(expr);
    unsigned loopIdx = dim.getPosition();
    if (loopIdx < shardingOption.shardingArray.size())
      splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
  }

  removeTrailingEmptySubArray(splitAxes);
  return Sharding::get(shardingOption.grid,
                       fromArrayOfVector(result.getContext(), splitAxes));
}

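// Get the sharding for the given operand from the sharding option, using the
// operand's indexing map to translate sharded loop iterators back to tensor
// dimensions. Scalar operands get an empty sharding and 0-d tensors are
// replicated.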
static FailureOr<Sharding> getSharding(OpOperand &opOperand,
                                       const ShardingOption &shardingOption,
                                       AffineMap map) {
  Value operandValue = opOperand.get();
  auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
  if (!operandType) {
    if (operandValue.getType().isIntOrIndexOrFloat())
      return Sharding();
    return failure();
  }
  // 0d tensors cannot be sharded and must get replicated.
  if (operandType.getRank() == 0) {
    return Sharding(shardingOption.grid);
  }
  SmallVector<SmallVector<GridAxis>> splitAxes(operandType.getRank());
  unsigned numDims = map.getNumDims();
  for (auto it : llvm::enumerate(map.getResults())) {
    int64_t idx = it.index();
    AffineExpr expr = it.value();
    FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
        checkOperandAffineExpr(expr, numDims);
    if (failed(loopIndices))
      return failure();
    SmallVector<unsigned> shardedLoopIndices;
    for (unsigned loopIdx : *loopIndices) {
      if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
          !shardingOption.shardingArray[loopIdx].empty())
        shardedLoopIndices.push_back(loopIdx);
    }
    // At most one sharded loop index is accepted.
    if (shardedLoopIndices.size() > 1)
      return failure();
    if (shardedLoopIndices.size() == 1) {
      splitAxes[idx].append(
          shardingOption.shardingArray[shardedLoopIndices[0]]);
    }
  }

  removeTrailingEmptySubArray(splitAxes);
  return Sharding::get(
      shardingOption.grid,
      fromArrayOfVector(opOperand.get().getContext(), splitAxes));
}

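// Compute the sharding annotation for every operand and result of `op` from
// the given sharding option; operand shardings come first, followed by result
// shardings, matching the order of the op's indexing maps.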
FailureOr<std::vector<Sharding>> shard::detail::defaultGetShardingAnnotations(
    Operation *op, const ShardingOption &shardingOption) {
  std::vector<Sharding> res;

  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();

  for (OpOperand &opOperand : op->getOpOperands()) {
    FailureOr<Sharding> shardingAttr = ::getSharding(
        opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
    if (failed(shardingAttr))
      return failure();
    res.push_back(*shardingAttr);
  }

  for (OpResult result : op->getResults()) {
    res.push_back(::getSharding(result, shardingOption,
                                maps[numOperands + result.getResultNumber()],
                                loopTypes));
  }

  return res;
}

//===----------------------------------------------------------------------===//
// detail::defaultAddShardingAnnotations
//===----------------------------------------------------------------------===//

// Add a `shard.shard` op for the given result, based on the details provided
// in `shardingOption`, `map`, and `loopTypes`.
static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                const ShardingOption &shardingOption,
                                AffineMap map,
                                ArrayRef<utils::IteratorType> loopTypes) {
  Sharding sharding = getSharding(result, shardingOption, map, loopTypes);
  maybeInsertTargetShardingAnnotation(sharding, result, b);

  return success();
}

// Add a `shard.shard` op for the given operand, based on the details provided
// in `shardingOption` and `map`.
static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
                                const ShardingOption &shardingOption,
                                AffineMap map) {

  FailureOr<Sharding> sharding = getSharding(opOperand, shardingOption, map);
  if (failed(sharding)) {
    return failure();
  }
  OpBuilder::InsertionGuard guard(b);
  maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b);

  return success();
}


LogicalResult shard::detail::defaultAddShardingAnnotations(
    Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
  assert(!shardingOption.empty && shardingOption.grid);

  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();

  // 1. Add shard.shard ops for all op results.
  for (OpResult result : op->getResults()) {
    if (failed(addShardOp(b, result, shardingOption,
                          maps[numOperands + result.getResultNumber()],
                          loopTypes)))
      return failure();
  }

  // 2. Add shard.shard ops for all operands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    if (failed(addShardOp(b, opOperand, shardingOption,
                          maps[opOperand.getOperandNumber()])))
      return failure();
  }

  return success();
}

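// Debug-only helpers backing the assertions in
// partitionFullyReplicatedOperation below.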
#ifndef NDEBUG
static bool
isValueCompatibleWithFullReplicationSharding(Value value,
                                             const Sharding &sharding) {
  if (isa<RankedTensorType>(value.getType())) {
    return isFullReplication(sharding);
  }

  return !sharding;
}

template <typename ValueRange, typename ShardingRange>
static bool
areValuesCompatibleWithFullReplicationShardings(ValueRange &&values,
                                                ShardingRange &&shardings) {
  if (std::size(values) != std::size(shardings)) {
    return false;
  }
  return llvm::all_of(llvm::zip_equal(std::forward<ValueRange>(values),
                                      std::forward<ShardingRange>(shardings)),
                      [](auto valueAndSharding) {
                        return isValueCompatibleWithFullReplicationSharding(
                            std::get<0>(valueAndSharding),
                            std::get<1>(valueAndSharding));
                      });
}
#endif // NDEBUG

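// Partition an op whose operands and results are all fully replicated by
// simply cloning it; the assertions check that the provided annotations really
// describe full replication.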
void shard::partitionFullyReplicatedOperation(
    Operation &op, ArrayRef<Value> partitionedOperands,
    ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
    IRMapping &partitionMap, SymbolTableCollection &symbolTable,
    OpBuilder &builder) {
  assert(partitionedOperands.size() == operandShardings.size());
  assert(areValuesCompatibleWithFullReplicationShardings(partitionedOperands,
                                                         operandShardings));
  assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(),
                                                         resultShardings));
  // `clone` will populate the mapping of old to new results.
  builder.clone(op, partitionMap);
}

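// Record the grid axes sharding one tensor dimension against the loop iterator
// that indexes it, asserting that repeated assignments agree.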
static void updateGridAxisAssignmentForLoopIterators(
    ArrayRef<GridAxis> gridAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
    SmallVector<std::optional<SmallVector<GridAxis>>>
        &gridAxesAssignmentForLoopIterators) {
  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
  unsigned loopIteratorIdx = affineDimExpr.getPosition();
  if (gridAxesAssignmentForLoopIterators[loopIteratorIdx]) {
    assert(llvm::equal(gridAxesAssignmentForTensorAxis,
                       *gridAxesAssignmentForLoopIterators[loopIteratorIdx]));
  } else {
    gridAxesAssignmentForLoopIterators[loopIteratorIdx] =
        llvm::to_vector(gridAxesAssignmentForTensorAxis);
  }
}

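// Collect, for every loop iterator, the grid axes it is sharded over by
// walking the provided shardings together with their indexing maps. Iterators
// with no assignment end up with an empty axis list.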
ShardingArray shard::getGridAxisAssignmentForLoopIterators(
    ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<AffineMap> indexingMaps) {
  SmallVector<std::optional<SmallVector<GridAxis>>>
      gridAxisAssignmentForLoopIterators(loopIteratorTypes.size());
  std::vector<Sharding> operatorAndResultShardings;
  operatorAndResultShardings.reserve(operandShardings.size() +
                                     resultShardings.size());
  llvm::append_range(operatorAndResultShardings, operandShardings);
  for (auto [sharding, affineMap] :
       llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
    if (!sharding) {
      continue;
    }
    for (auto [gridAxesAssignmentForTensorAxis, indexingExpr] :
         llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
      updateGridAxisAssignmentForLoopIterators(
          gridAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
          gridAxisAssignmentForLoopIterators);
    }
    // Missing trailing split axes means replication on those tensor
    // dimensions.
    for (unsigned i = sharding.getSplitAxes().size();
         i < affineMap.getNumResults(); ++i) {
      updateGridAxisAssignmentForLoopIterators(
          {}, affineMap.getResults()[i], gridAxisAssignmentForLoopIterators);
    }
  }

  ShardingArray res;
  llvm::transform(gridAxisAssignmentForLoopIterators, std::back_inserter(res),
                  [](std::optional<SmallVector<GridAxis>> &axes) {
                    if (!axes) {
                      return SmallVector<GridAxis>();
                    }
                    return std::move(*axes);
                  });
  return res;
}

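// Return true if any reduction loop iterator has a non-empty grid-axis
// assignment, i.e. the reduction itself is distributed across the grid.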
bool shard::isAtLeastOneReductionIteratorSharded(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators) {
  for (auto [loopIteratorType, gridAxisAssignment] :
       llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
    if (loopIteratorType == utils::IteratorType::reduction &&
        !gridAxisAssignment.empty()) {
      return true;
    }
  }
  return false;
}

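// Return the concatenation of the grid axes assigned to all reduction loop
// iterators, in iterator order.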
SmallVector<GridAxis> shard::getReductionGridAxes(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators) {
  SmallVector<GridAxis> gridAxes;
  for (auto [loopIteratorType, gridAxisAssignment] :
       llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
    if (loopIteratorType == utils::IteratorType::reduction) {
      llvm::append_range(gridAxes, gridAxisAssignment);
    }
  }
  return gridAxes;
}

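// Partition a trivially shardable op by cloning it and narrowing each result
// type to its shard-local counterpart.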
void shard::partitionTriviallyShardableOperation(
    Operation &op, ArrayRef<Value> partitionedOperands,
    ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
    IRMapping &partitionMap, SymbolTableCollection &symbolTable,
    OpBuilder &builder) {
  // `clone` will populate the mapping of old to new results.
  Operation *newOp = builder.clone(op, partitionMap);
  // Set the result types to the sharded counterparts.
  for (auto [oldResult, newResult, sharding] :
       llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
    newResult.setType(shardType(
        newResult.getType(),
        getGridOrNull(&op, sharding.getGridAttr(), symbolTable), sharding));
  }
}