MLIR  22.0.0git
ShardingPropagation.cpp
Go to the documentation of this file.
1 //===- ShardingPropagation.cpp ------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/Verifier.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/iterator_range.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/raw_ostream.h"
20 #include <algorithm>
21 #include <vector>
22 
23 namespace mlir {
24 namespace shard {
25 #define GEN_PASS_DEF_SHARDINGPROPAGATION
26 #include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
27 } // namespace shard
28 } // namespace mlir
29 
30 #define DEBUG_TYPE "sharding-propagation"
31 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
32 
33 using namespace mlir;
34 using namespace mlir::shard;
35 
37  NO_RESHARDING = 0,
40 };
41 
42 #ifdef LLVM_DEBUG
43 
44 template <typename T>
45 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
46  const SmallVector<T> &vec);
47 template <typename... Ts>
48 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
49  const std::tuple<Ts...> &t);
50 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
52 
/// Writes the elements of `range` to `stream` between square brackets.
/// Every element, including the last one, is followed by ", ".
/// Returns the stream reference produced by the final insertion.
template <typename Stream, typename Range>
static Stream &printRange(Stream &stream, Range &&range) {
  stream << "[";
  for (auto &element : range)
    stream << element << ", ";
  return stream << "]";
}
62 
63 template <typename T>
64 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
65  const SmallVector<T> &vec) {
66  return printRange(stream, vec);
67 }
68 
69 [[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
70  const ShardingOption &v) {
71  return stream << "{empty = " << v.empty << ", grid" << v.grid
72  << ", shardingArray = " << v.shardingArray << "}";
73 }
74 
/// Writes the elements of `tuple` to `stream` between curly braces.
/// Every element, including the last one, is followed by ", ".
///
/// The tuple is taken by const reference so that printing never copies its
/// elements. The index sequence selects the elements to print and must have
/// exactly one index per tuple element; empty tuples are rejected at compile
/// time.
///
/// Returns the stream reference produced by the final insertion.
template <typename Stream, typename... Ts, size_t... Is>
static Stream &printTuple(Stream &stream, const std::tuple<Ts...> &tuple,
                          std::index_sequence<Is...>) {
  static_assert(sizeof...(Is) == sizeof...(Ts),
                "Indices must have same number of elements as tuple types!");
  static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");

  stream << "{";
  // Fold over the comma operator: one insertion per tuple element, in order.
  ((stream << std::get<Is>(tuple) << ", "), ...);
  return stream << "}";
}
86 
87 template <typename... Ts>
88 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
89  const std::tuple<Ts...> &t) {
90  return printTuple(stream, t, std::index_sequence_for<Ts...>{});
91 }
92 
93 [[maybe_unused]] static llvm::raw_ostream &
94 operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
95  return stream << static_cast<int>(v);
96 }
97 
98 #endif // LLVM_DEBUG
99 
100 //===----------------------------------------------------------------------===//
101 // Utilities
102 //===----------------------------------------------------------------------===//
103 
104 // This method retrieves all potential sharding attributes, prioritizing
105 // specific shardings. For example, mustShardings = [shard0, None] and
106 // optionalShardings = [None, shard1], the result will be [[shard0, shard1],
107 // [shard0, None]]
110  ArrayRef<Sharding> optionalShardings) {
111  SmallVector<std::vector<Sharding>> allShardingAttrs;
112  std::vector<Sharding> curShardingAttrs;
113 
114  std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
115  if (i == mustShardings.size()) {
116  allShardingAttrs.push_back(std::vector<Sharding>(curShardingAttrs));
117  return;
118  }
119 
120  if (mustShardings[i]) {
121  curShardingAttrs.push_back(mustShardings[i]);
122  dfsCreateShardingAttrs(i + 1);
123  curShardingAttrs.pop_back();
124  return;
125  }
126 
127  if (optionalShardings[i]) {
128  curShardingAttrs.push_back(optionalShardings[i]);
129  dfsCreateShardingAttrs(i + 1);
130  curShardingAttrs.pop_back();
131  curShardingAttrs.push_back({});
132  dfsCreateShardingAttrs(i + 1);
133  curShardingAttrs.pop_back();
134  return;
135  }
136 
137  curShardingAttrs.push_back({});
138  dfsCreateShardingAttrs(i + 1);
139  curShardingAttrs.pop_back();
140  };
141 
142  dfsCreateShardingAttrs(0);
143  return allShardingAttrs;
144 }
145 
146 // The order of preference is from highest to lowest:
147 // 1. No resharding is required (all existing annotations are compatible).
148 // 2. No resharding for operands/results that have annotation specifically
149 // targeting this operation. This means
150 // * operands that are the result of `shard.shard` ops marked with
151 // `annotate_for_users`.
152 // * results that are annotated with `shard.shard` ops without
153 // `annotate_for_users`.
154 // 3. All other cases. Resharding is required for operands/results with
155 // annotations explicitly targeting this operation.
157  Operation *op, const std::vector<Sharding> &operandAndResultShardings) {
159 
160  size_t operandsCount = op->getOperands().size();
161  auto operandShardings =
162  llvm::make_range(operandAndResultShardings.begin(),
163  operandAndResultShardings.begin() + operandsCount);
164  auto resultShardings =
165  llvm::make_range(operandAndResultShardings.begin() + operandsCount,
166  operandAndResultShardings.end());
167 
168  for (auto [operand, sharding] :
169  llvm::zip_equal(op->getOperands(), operandShardings)) {
170  ShardOp shardOp = operand.getDefiningOp<ShardOp>();
171  if (!shardOp) {
172  continue;
173  }
174  bool needsResharding = sharding != shardOp.getSharding();
175  bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
176  if (needsResharding) {
177  if (isExplicitAnnotationForThisOp) {
178  // This is the worst case. No need to continue.
180  }
182  }
183  }
184 
185  for (auto [result, sharding] :
186  llvm::zip_equal(op->getResults(), resultShardings)) {
187  for (auto user : result.getUsers()) {
188  ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
189  if (!shardOp) {
190  continue;
191  }
192  bool needsResharding = sharding != shardOp.getSharding();
193  bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
194  if (needsResharding) {
195  if (isExplicitAnnotationForThisOp) {
196  // This is the worst case. No need to continue.
198  }
200  }
201  }
202  }
203 
204  return res;
205 }
206 
207 // From all the operand and result sharding combinations,
208 // return the one that is most desirable.
209 // The order of preference is:
210 // 1. No resharding with respect to existing sharding annotations.
211 // 2. Resharding for values that have already annotations that do not target
212 // this op.
213 // 3. Resharding of existing explicit sharding annotations for this op.
214 static FailureOr<ShardingOption> selectShardingOption(
215  ShardingInterface shardingOp,
216  ArrayRef<std::vector<Sharding>> possibleOperandShardingAttrs,
217  ArrayRef<std::vector<Sharding>> possibleResultShardingAttrs) {
219  shardingOptionsAndReshardingRequirements;
220 
221  for (ArrayRef<Sharding> resultShardings : possibleResultShardingAttrs) {
222  for (ArrayRef<Sharding> operandShardings : possibleOperandShardingAttrs) {
223  FailureOr<ShardingOption> shardingOption =
224  shardingOp.getShardingOption(operandShardings, resultShardings);
225  if (failed(shardingOption) || shardingOption->empty) {
226  continue;
227  }
228  // These shardings may not be the same as those in operandShardings and
229  // resultShardings.
230  // They may be missing some annotations.
231  // Whatever is returned by getShardingAnnotations is exactly what the op
232  // needs.
233  FailureOr<std::vector<Sharding>> operandAndResultShardings =
234  shardingOp.getShardingAnnotations(*shardingOption);
235  if (failed(operandAndResultShardings)) {
236  return failure();
237  }
238 
239  // LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
240  // << *operandAndResultShardings << "\n";);
241 
242  ReshardingRquirementKind reshardingRquirement =
243  getReshardingRquirementKind(shardingOp, *operandAndResultShardings);
244  if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) {
245  // This is the best case. No need to go on.
246  return *shardingOption;
247  }
248 
249  shardingOptionsAndReshardingRequirements.emplace_back(
250  std::move(*shardingOption), reshardingRquirement);
251  }
252  }
253 
254  if (shardingOptionsAndReshardingRequirements.empty()) {
255  return ShardingOption::makeEmpty();
256  }
257 
258  std::partial_sort(
259  shardingOptionsAndReshardingRequirements.begin(),
260  shardingOptionsAndReshardingRequirements.begin() + 1,
261  shardingOptionsAndReshardingRequirements.end(),
262  [](const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
263  const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
264  return std::get<ReshardingRquirementKind>(a) <
265  std::get<ReshardingRquirementKind>(b);
266  });
267 
268  LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "
269  << shardingOptionsAndReshardingRequirements << "\n";);
270 
271  return std::get<ShardingOption>(
272  shardingOptionsAndReshardingRequirements.front());
273 }
274 
275 // For each operation that implements the ShardingInterface, infer the sharding
276 // option of the operation from its operands and/or results using the
277 // `getShardingOption` method. If the inferred sharding option is not empty, add
278 // a `shard.shard` operation for all remaining operands and results that do not
279 // have sharding annotations.
280 static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
281  ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
282  if (op->hasTrait<OpTrait::IsTerminator>() ||
283  (op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
284  llvm::isa<shard::ShardOp, shard::ShardingOp, shard::GetShardingOp>(op))
285  return success();
286 
287  if (!shardingOp) {
288  op->emitOpError() << "sharding interface is not implemented.";
289  return failure();
290  }
291 
292  // collect Sharding from results
293  std::vector<Sharding> allowConflictsResultShardings;
294  allowConflictsResultShardings.resize(op->getNumResults());
295  std::vector<Sharding> resultMustShardings;
296  resultMustShardings.resize(op->getNumResults());
297  for (OpResult result : op->getResults()) {
298  FailureOr<std::pair<bool, Sharding>> maybeShardAttr = getSharding(result);
299  if (failed(maybeShardAttr))
300  continue;
301  if (!maybeShardAttr->first)
302  resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
303  else
304  allowConflictsResultShardings[result.getResultNumber()] =
305  maybeShardAttr->second;
306  }
307 
308  // collect Sharding from operands
309  std::vector<Sharding> allowConflictsOperandShardings;
310  allowConflictsOperandShardings.resize(op->getNumOperands());
311  std::vector<Sharding> operandMustShardings;
312  operandMustShardings.resize(op->getNumOperands());
313  for (OpOperand &opOperand : op->getOpOperands()) {
314  FailureOr<std::pair<bool, Sharding>> maybeShardAttr =
315  getSharding(opOperand);
316  if (failed(maybeShardAttr))
317  continue;
318 
319  if (maybeShardAttr->first)
320  operandMustShardings[opOperand.getOperandNumber()] =
321  maybeShardAttr->second;
322  else
323  allowConflictsOperandShardings[opOperand.getOperandNumber()] =
324  maybeShardAttr->second;
325  }
326 
327  // try to get the sharding option
328  SmallVector<std::vector<Sharding>> possibleOperandShardingAttrs =
329  getOrderedPossibleShardingAttrs(operandMustShardings,
330  allowConflictsOperandShardings);
331  SmallVector<std::vector<Sharding>> possibleResultShardingAttrs =
332  getOrderedPossibleShardingAttrs(resultMustShardings,
333  allowConflictsResultShardings);
334  FailureOr<ShardingOption> shardingOption = selectShardingOption(
335  shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
336 
337  if (failed(shardingOption)) {
338  op->emitOpError() << "fail to get sharding option.";
339  return failure();
340  }
341 
342  LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");
343 
344  // sharding info is empty, return immediately
345  if (shardingOption->empty)
346  return success();
347 
348  if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
349  op->emitOpError() << "fail to set sharding annotations.";
350  return failure();
351  }
352  return success();
353 }
354 
355 //===----------------------------------------------------------------------===//
356 // ShardingPropagation
357 //===----------------------------------------------------------------------===//
359  : public shard::impl::ShardingPropagationBase<ShardingPropagation> {
360 
361  using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
362 
363  void runOnOperation() override {
364  FunctionOpInterface funcOp = getOperation();
365  MLIRContext *ctx = funcOp.getContext();
366  Region &region = funcOp.getFunctionBody();
367  OpBuilder builder(ctx);
368  if (!region.hasOneBlock()) {
369  funcOp.emitOpError() << "only one block is supported!";
370  return signalPassFailure();
371  }
372  Block &block = region.front();
373 
374  LLVM_DEBUG(
375  DBGS() << "print all the ops' iterator types and indexing maps in the "
376  "block.\n";
377  for (Operation &op : block.getOperations()) {
378  if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
379  shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
380  });
381 
382  auto traverse = [&](auto &&range, OpBuilder &builder,
383  const char *order) -> bool {
384  for (Operation &op : range) {
385  if (failed(visitOp(&op, builder))) {
386  signalPassFailure();
387  return true;
388  }
389  }
390  LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
391  << funcOp << "\n");
392  LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
393  return false;
394  };
395 
396  // 1. Propagate in reversed order.
397  if (traversal == TraversalOrder::Backward ||
398  traversal == TraversalOrder::BackwardForward)
399  traverse(llvm::reverse(block), builder, "backward");
400 
401  // 2. Propagate in original order.
402  if (traversal != TraversalOrder::Backward)
403  traverse(block, builder, "forward");
404 
405  // 3. Propagate in backward order if needed.
406  if (traversal == TraversalOrder::ForwardBackward)
407  traverse(llvm::reverse(block), builder, "backward");
408  }
409 };
ReshardingRquirementKind getReshardingRquirementKind(Operation *op, const std::vector< Sharding > &operandAndResultShardings)
static FailureOr< ShardingOption > selectShardingOption(ShardingInterface shardingOp, ArrayRef< std::vector< Sharding >> possibleOperandShardingAttrs, ArrayRef< std::vector< Sharding >> possibleResultShardingAttrs)
static SmallVector< std::vector< Sharding > > getOrderedPossibleShardingAttrs(ArrayRef< Sharding > mustShardings, ArrayRef< Sharding > optionalShardings)
static LogicalResult visitOp(Operation *op, OpBuilder &builder)
ReshardingRquirementKind
#define DBGS()
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType & getOperations()
Definition: Block.h:137
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
unsigned getNumOperands()
Definition: Operation.h:346
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
@ BackwardForward
Backward then forward traversal.
@ Backward
Backward traversal.
@ ForwardBackward
Forward then backward traversal.
FailureOr< std::pair< bool, Sharding > > getSharding(OpResult result)
Include the generated interface declarations.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
void runOnOperation() override
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static ShardingOption makeEmpty()