MLIR  20.0.0git
ShardingPropagation.cpp
Go to the documentation of this file.
1 //===- ShardingPropagation.cpp ------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/Verifier.h"
17 #include "mlir/Pass/Pass.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include <algorithm>
24 #include <vector>
25 
26 namespace mlir {
27 namespace mesh {
28 #define GEN_PASS_DEF_SHARDINGPROPAGATION
29 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
30 } // namespace mesh
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "sharding-propagation"
34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
35 
36 using namespace mlir;
37 using namespace mlir::mesh;
38 
40  NO_RESHARDING = 0,
43 };
44 
45 #ifdef LLVM_DEBUG
46 
47 template <typename T>
48 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
49  const SmallVector<T> &vec);
50 template <typename... Ts>
51 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
52  const std::tuple<Ts...> &t);
53 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
55 
// Writes `range` to `stream` as "[e0, e1, ...]" and returns `stream`.
//
// Every element is printed with `stream << element`, so an `operator<<` for
// the element type must be visible. The separator is emitted only between
// elements; an empty range prints as "[]".
template <typename Stream, typename Range>
static Stream &printRange(Stream &stream, Range &&range) {
  stream << "[";
  bool isFirst = true;
  for (const auto &element : range) {
    // Separate from the previous element rather than trailing after the last.
    if (!isFirst)
      stream << ", ";
    stream << element;
    isFirst = false;
  }
  stream << "]";
  return stream;
}
65 
66 template <typename T>
67 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
68  const SmallVector<T> &vec) {
69  return printRange(stream, vec);
70 }
71 
72 [[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
73  const ShardingOption &v) {
74  return stream << "{empty = " << v.empty << ", mesh" << v.mesh
75  << ", shardingArray = " << v.shardingArray << "}";
76 }
77 
// Writes the elements of `tuple` selected by `Is...` to `stream` as
// "{e0, e1, ...}" and returns `stream`.
//
// The tuple is taken by const reference so that printing does not copy its
// elements. The index pack must cover the whole tuple, and empty tuples are
// rejected at compile time.
template <typename Stream, typename... Ts, size_t... Is>
static Stream &printTuple(Stream &stream, const std::tuple<Ts...> &tuple,
                          std::index_sequence<Is...>) {
  static_assert(sizeof...(Is) == sizeof...(Ts),
                "Indices must have same number of elements as tuple types!");
  static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");

  stream << "{";
  // The separator precedes every element except the first, so nothing
  // trails after the last one.
  ((stream << (Is == 0 ? "" : ", ") << std::get<Is>(tuple)), ...);
  stream << "}";
  return stream;
}
89 
90 template <typename... Ts>
91 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
92  const std::tuple<Ts...> &t) {
93  return printTuple(stream, t, std::index_sequence_for<Ts...>{});
94 }
95 
96 [[maybe_unused]] static llvm::raw_ostream &
97 operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
98  return stream << static_cast<int>(v);
99 }
100 
101 #endif // LLVM_DEBUG
102 
103 //===----------------------------------------------------------------------===//
104 // Utilities
105 //===----------------------------------------------------------------------===//
106 
107 // This method retrieves all potential sharding attributes, prioritizing
108 // specific shardings. For example, mustShardings = [shard0, None] and
109 // optionalShardings = [None, shard1], the result will be [[shard0, shard1],
110 // [shard0, None]]
113  ArrayRef<MeshShardingAttr> optionalShardings) {
115  SmallVector<MeshShardingAttr> curShardingAttrs;
116 
117  std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
118  if (i == mustShardings.size()) {
119  allShardingAttrs.push_back(
120  SmallVector<MeshShardingAttr>(curShardingAttrs));
121  return;
122  }
123 
124  if (mustShardings[i]) {
125  curShardingAttrs.push_back(mustShardings[i]);
126  dfsCreateShardingAttrs(i + 1);
127  curShardingAttrs.pop_back();
128  return;
129  }
130 
131  if (optionalShardings[i]) {
132  curShardingAttrs.push_back(optionalShardings[i]);
133  dfsCreateShardingAttrs(i + 1);
134  curShardingAttrs.pop_back();
135  curShardingAttrs.push_back(nullptr);
136  dfsCreateShardingAttrs(i + 1);
137  curShardingAttrs.pop_back();
138  return;
139  }
140 
141  curShardingAttrs.push_back(nullptr);
142  dfsCreateShardingAttrs(i + 1);
143  curShardingAttrs.pop_back();
144  };
145 
146  dfsCreateShardingAttrs(0);
147  return allShardingAttrs;
148 }
149 
150 // The order of preference is from highest to lowest:
151 // 1. No resharding is required (all existing annotations are compatible).
152 // 2. No resharding for operands/results that have annotation specifically
153 // targeting this operation. This means
154 // * operands that are the result of `mesh.shard` ops marked with
155 // `annotate_for_users`.
156 // * results that are annotated with `mesh.shard` ops without
157 // `annotate_for_users`.
158 // 3. All other cases. Resharding is required for operands/results with
159 // annotation targeting explicitly this operation.
// Classifies how much resharding the given operand/result shardings would
// require relative to the `mesh.shard` ops already attached to `op`.
// `operandAndResultShardings` holds the operand shardings first, followed by
// the result shardings.
// NOTE(review): this listing is missing several original lines (the start of
// the signature, the declaration of `res`, and the statements inside both
// `needsResharding` branches). Confirm against the upstream file before
// relying on the details below.
161  Operation *op,
162  const SmallVector<MeshShardingAttr> &operandAndResultShardings) {
164 

// Split the flat list into its operand part and its result part.
165  size_t operandsCount = op->getOperands().size();
166  auto operandShardings =
167  llvm::make_range(operandAndResultShardings.begin(),
168  operandAndResultShardings.begin() + operandsCount);
169  auto resultShardings =
170  llvm::make_range(operandAndResultShardings.begin() + operandsCount,
171  operandAndResultShardings.end());
172 

// Operands: only those defined by a ShardOp carry an existing annotation.
173  for (auto [operand, sharding] :
174  llvm::zip_equal(op->getOperands(), operandShardings)) {
175  ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
176  if (!shardOp) {
177  continue;
178  }
179  bool needsResharding = shardOp.getShardAttr() != sharding;
// For an operand, the annotation is treated as explicit for this op when the
// ShardOp is marked annotate_for_users.
180  bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
181  if (needsResharding) {
182  if (isExplicitAnnotationForThisOp) {
183  // This is the worst case. No need to continue.
// NOTE(review): the statement that followed here is missing from this
// listing; presumably an early return of the worst kind - confirm.
185  }
// NOTE(review): a statement is missing here too; presumably it records the
// weaker resharding kind in `res` - confirm.
187  }
188  }
189 

// Results: inspect every ShardOp that consumes a result.
190  for (auto [result, sharding] :
191  llvm::zip_equal(op->getResults(), resultShardings)) {
192  for (auto user : result.getUsers()) {
193  ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
194  if (!shardOp) {
195  continue;
196  }
197  bool needsResharding = shardOp.getShardAttr() != sharding;
// For a result the test is inverted: the annotation is treated as explicit
// for this op when the ShardOp is NOT marked annotate_for_users.
198  bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
199  if (needsResharding) {
200  if (isExplicitAnnotationForThisOp) {
201  // This is the worst case. No need to continue.
// NOTE(review): missing statement, as in the operand loop - confirm.
203  }
// NOTE(review): missing statement, as in the operand loop - confirm.
205  }
206  }
207  }
208 
209  return res;
210 }
211 
212 // From all the operand and result sharding combinations,
213 // return the one that is most desirable.
214 // The order of preference is:
215 // 1. No resharding with respect to existing sharding annotations.
216 // 2. Resharding for values that have already annotations that do not target
217 // this op.
218 // 3. Resharding of existing explicit sharding annotations for this op.
219 static FailureOr<ShardingOption> selectShardingOption(
220  ShardingInterface shardingOp,
221  ArrayRef<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs,
222  ArrayRef<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs) {
224  shardingOptionsAndReshardingRequirements;
225 
226  for (ArrayRef<MeshShardingAttr> resultShardings :
227  possibleResultShardingAttrs) {
228  for (ArrayRef<MeshShardingAttr> operandShardings :
229  possibleOperandShardingAttrs) {
230  FailureOr<ShardingOption> shardingOption =
231  shardingOp.getShardingOption(operandShardings, resultShardings);
232  if (failed(shardingOption) || shardingOption->empty) {
233  continue;
234  }
235  // These shardings may not be the same as those in operandShardings and
236  // resultShardings.
237  // They may be missing some annotations.
238  // Whatever is returned by getShardingAnnotations is exactly what the op
239  // needs.
240  FailureOr<SmallVector<MeshShardingAttr>> operandAndResultShardings =
241  shardingOp.getShardingAnnotations(*shardingOption);
242  if (failed(operandAndResultShardings)) {
243  return failure();
244  }
245 
246  LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
247  << *operandAndResultShardings << "\n";);
248 
249  ReshardingRquirementKind reshardingRquirement =
250  getReshardingRquirementKind(shardingOp, *operandAndResultShardings);
251  if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) {
252  // This is the best case. No need to go on.
253  return *shardingOption;
254  }
255 
256  shardingOptionsAndReshardingRequirements.emplace_back(
257  std::move(*shardingOption), reshardingRquirement);
258  }
259  }
260 
261  if (shardingOptionsAndReshardingRequirements.empty()) {
262  return ShardingOption::makeEmpty();
263  }
264 
265  std::partial_sort(
266  shardingOptionsAndReshardingRequirements.begin(),
267  shardingOptionsAndReshardingRequirements.begin() + 1,
268  shardingOptionsAndReshardingRequirements.end(),
269  [](const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
270  const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
271  return std::get<ReshardingRquirementKind>(a) <
272  std::get<ReshardingRquirementKind>(b);
273  });
274 
275  LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "
276  << shardingOptionsAndReshardingRequirements << "\n";);
277 
278  return std::get<ShardingOption>(
279  shardingOptionsAndReshardingRequirements.front());
280 }
281 
282 // For each operation that implements the ShardingInterface, infer the sharding
283 // option of the operation from its operands and/or results using the
284 // `getShardingOption` method. If the inferred sharding option is not empty, add
285 // a `mesh.shard` operation for all remaining operands and results that do not
286 // have sharding annotations.
287 static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
288  if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op))
289  return success();
290 
291  ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
292  if (!shardingOp) {
293  op->emitOpError() << "sharding interface is not implemented.";
294  return failure();
295  }
296 
297  // collect MeshShardingAttr from results
298  SmallVector<MeshShardingAttr> allowConflictsResultShardings;
299  allowConflictsResultShardings.resize(op->getNumResults());
300  SmallVector<MeshShardingAttr> resultMustShardings;
301  resultMustShardings.resize(op->getNumResults());
302  for (OpResult result : op->getResults()) {
303  FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
304  getMeshShardingAttr(result);
305  if (failed(maybeShardAttr))
306  continue;
307  if (!maybeShardAttr->first)
308  resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
309  else
310  allowConflictsResultShardings[result.getResultNumber()] =
311  maybeShardAttr->second;
312  }
313 
314  // collect MeshShardingAttr from operands
315  SmallVector<MeshShardingAttr> allowConflictsOperandShardings;
316  allowConflictsOperandShardings.resize(op->getNumOperands());
317  SmallVector<MeshShardingAttr> operandMustShardings;
318  operandMustShardings.resize(op->getNumOperands());
319  for (OpOperand &opOperand : op->getOpOperands()) {
320  FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
321  getMeshShardingAttr(opOperand);
322  if (failed(maybeShardAttr))
323  continue;
324 
325  if (maybeShardAttr->first)
326  operandMustShardings[opOperand.getOperandNumber()] =
327  maybeShardAttr->second;
328  else
329  allowConflictsOperandShardings[opOperand.getOperandNumber()] =
330  maybeShardAttr->second;
331  }
332 
333  // try to get the sharding option
334  SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
335  getOrderedPossibleShardingAttrs(operandMustShardings,
336  allowConflictsOperandShardings);
337  SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs =
338  getOrderedPossibleShardingAttrs(resultMustShardings,
339  allowConflictsResultShardings);
340  FailureOr<ShardingOption> shardingOption = selectShardingOption(
341  shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
342 
343  if (failed(shardingOption)) {
344  op->emitOpError() << "fail to get sharding option.";
345  return failure();
346  }
347 
348  LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");
349 
350  // sharding info is empty, return immediately
351  if (shardingOption->empty)
352  return success();
353 
354  if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
355  op->emitOpError() << "fail to set sharding annotations.";
356  return failure();
357  }
358  return success();
359 }
360 
361 //===----------------------------------------------------------------------===//
362 // ShardingPropagation
363 //===----------------------------------------------------------------------===//
365  : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
366  void runOnOperation() override {
367  FunctionOpInterface funcOp = getOperation();
368  MLIRContext *ctx = funcOp.getContext();
369  Region &region = funcOp.getFunctionBody();
370  OpBuilder builder(ctx);
371  if (!region.hasOneBlock()) {
372  funcOp.emitOpError() << "only one block is supported!";
373  signalPassFailure();
374  }
375  Block &block = region.front();
376 
377  LLVM_DEBUG(
378  DBGS() << "print all the ops' iterator types and indexing maps in the "
379  "block.\n";
380  for (Operation &op
381  : block.getOperations()) {
382  if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
383  shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
384  });
385 
386  // 1. propagate in reversed order
387  for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
388  if (failed(visitOp(&op, builder)))
389  return signalPassFailure();
390 
391  LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
392  << funcOp << "\n");
393  LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
394 
395  // 2. propagate in original order
396  for (Operation &op : llvm::make_early_inc_range(block))
397  if (failed(visitOp(&op, builder)))
398  return signalPassFailure();
399  }
400 };
static FailureOr< ShardingOption > selectShardingOption(ShardingInterface shardingOp, ArrayRef< SmallVector< MeshShardingAttr >> possibleOperandShardingAttrs, ArrayRef< SmallVector< MeshShardingAttr >> possibleResultShardingAttrs)
static SmallVector< SmallVector< MeshShardingAttr > > getOrderedPossibleShardingAttrs(ArrayRef< MeshShardingAttr > mustShardings, ArrayRef< MeshShardingAttr > optionalShardings)
static LogicalResult visitOp(Operation *op, OpBuilder &builder)
ReshardingRquirementKind
ReshardingRquirementKind getReshardingRquirementKind(Operation *op, const SmallVector< MeshShardingAttr > &operandAndResultShardings)
#define DBGS()
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType & getOperations()
Definition: Block.h:135
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:210
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
unsigned getNumOperands()
Definition: Operation.h:341
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
FailureOr< std::pair< bool, MeshShardingAttr > > getMeshShardingAttr(OpResult result)
Include the generated interface declarations.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
void runOnOperation() override
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static ShardingOption makeEmpty()