MLIR  19.0.0git
ShardingPropagation.cpp
Go to the documentation of this file.
1 //===- ShardingPropagation.cpp ------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/Verifier.h"
17 #include "mlir/Pass/Pass.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/iterator_range.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/raw_ostream.h"
24 #include <algorithm>
25 #include <vector>
26 
27 namespace mlir {
28 namespace mesh {
29 #define GEN_PASS_DEF_SHARDINGPROPAGATION
30 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
31 } // namespace mesh
32 } // namespace mlir
33 
34 #define DEBUG_TYPE "sharding-propagation"
35 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
36 
37 using namespace mlir;
38 using namespace mlir::mesh;
39 
// Classifies how much resharding a candidate sharding would force on values
// that already carry sharding annotations. Candidates are ranked by comparing
// these values with `<`, so smaller enumerators are preferred.
// NOTE(review): the declaration line and the two enumerators after
// NO_RESHARDING were missing from this listing; their names are reconstructed
// and must be confirmed against the uses in getReshardingRquirementKind.
enum class ReshardingRquirementKind {
  NO_RESHARDING = 0,
  NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS,
  RESHARDING_FOR_EXPLICIT_ANNOTATIONS
};
45 
46 #ifdef LLVM_DEBUG
47 
48 template <typename T>
49 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
50  const SmallVector<T> &vec);
51 template <typename... Ts>
52 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
53  const std::tuple<Ts...> &t);
54 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
56 
// Streams every element of `range` into `stream`, formatted as "[a, b, c]".
// The separator is written only between elements, so an empty range prints
// "[]" and no dangling ", " follows the last element.
// Returns `stream` to allow chaining.
template <typename Stream, typename Range>
static Stream &printRange(Stream &stream, Range &&range) {
  stream << "[";
  bool isFirst = true;
  for (const auto &element : range) {
    if (!isFirst)
      stream << ", ";
    stream << element;
    isFirst = false;
  }
  return stream << "]";
}
66 
67 template <typename T>
68 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
69  const SmallVector<T> &vec) {
70  return printRange(stream, vec);
71 }
72 
73 [[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
74  const ShardingOption &v) {
75  return stream << "{empty = " << v.empty << ", mesh" << v.mesh
76  << ", shardingArray = " << v.shardingArray << "}";
77 }
78 
// Streams the elements of `tuple` into `stream`, formatted as "{a, b, c}".
// `Is...` must be the index sequence 0..sizeof...(Ts)-1; it drives the fold
// over std::get. The tuple is taken by const reference to avoid copying it,
// and the separator is written only between elements.
// Returns `stream` to allow chaining.
template <typename Stream, typename... Ts, size_t... Is>
static Stream &printTuple(Stream &stream, const std::tuple<Ts...> &tuple,
                          std::index_sequence<Is...>) {
  static_assert(sizeof...(Is) == sizeof...(Ts),
                "Indices must have same number of elements as tuple types!");
  static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");

  stream << "{";
  // Element 0 gets no leading separator; every later element gets ", ".
  ((stream << (Is == 0 ? "" : ", ") << std::get<Is>(tuple)), ...);
  return stream << "}";
}
90 
91 template <typename... Ts>
92 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
93  const std::tuple<Ts...> &t) {
94  return printTuple(stream, t, std::index_sequence_for<Ts...>{});
95 }
96 
97 [[maybe_unused]] static llvm::raw_ostream &
98 operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
99  return stream << static_cast<int>(v);
100 }
101 
102 #endif // LLVM_DEBUG
103 
104 //===----------------------------------------------------------------------===//
105 // Utilities
106 //===----------------------------------------------------------------------===//
107 
108 // This method retrieves all potential sharding attributes, prioritizing
109 // specific shardings. For example, mustShardings = [shard0, None] and
110 // optionalShardings = [None, shard1], the result will be [[shard0, shard1],
111 // [shard0, None]]
114  ArrayRef<MeshShardingAttr> optionalShardings) {
116  SmallVector<MeshShardingAttr> curShardingAttrs;
117 
118  std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
119  if (i == mustShardings.size()) {
120  allShardingAttrs.push_back(
121  SmallVector<MeshShardingAttr>(curShardingAttrs));
122  return;
123  }
124 
125  if (mustShardings[i]) {
126  curShardingAttrs.push_back(mustShardings[i]);
127  dfsCreateShardingAttrs(i + 1);
128  curShardingAttrs.pop_back();
129  return;
130  }
131 
132  if (optionalShardings[i]) {
133  curShardingAttrs.push_back(optionalShardings[i]);
134  dfsCreateShardingAttrs(i + 1);
135  curShardingAttrs.pop_back();
136  curShardingAttrs.push_back(nullptr);
137  dfsCreateShardingAttrs(i + 1);
138  curShardingAttrs.pop_back();
139  return;
140  }
141 
142  curShardingAttrs.push_back(nullptr);
143  dfsCreateShardingAttrs(i + 1);
144  curShardingAttrs.pop_back();
145  };
146 
147  dfsCreateShardingAttrs(0);
148  return allShardingAttrs;
149 }
150 
151 // The order of preference is form highest to lowest:
152 // 1. No resharding is required (all existing annotations are compatible).
153 // 2. No resharding for operands/results that have annotation specifically
154 // targeting this operation. This means
155 // * operands that are the result of `mesh.shard` ops marked with
156 // `annotate_for_users`.
157 // * results that are annotated with `mesh.shard` ops without
158 // `annotate_for_users`.
159 // 3. All other cases. Resharding is required for operands/results with
160 // annotation targeting explicitly this operation.
162  Operation *op,
163  const SmallVector<MeshShardingAttr> &operandAndResultShardings) {
165 
166  size_t operandsCount = op->getOperands().size();
167  auto operandShardings =
168  llvm::make_range(operandAndResultShardings.begin(),
169  operandAndResultShardings.begin() + operandsCount);
170  auto resultShardings =
171  llvm::make_range(operandAndResultShardings.begin() + operandsCount,
172  operandAndResultShardings.end());
173 
174  for (auto [operand, sharding] :
175  llvm::zip_equal(op->getOperands(), operandShardings)) {
176  ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
177  if (!shardOp) {
178  continue;
179  }
180  bool needsResharding = shardOp.getShardAttr() != sharding;
181  bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
182  if (needsResharding) {
183  if (isExplicitAnnotationForThisOp) {
184  // This is the worst case. No need to continue.
186  }
188  }
189  }
190 
191  for (auto [result, sharding] :
192  llvm::zip_equal(op->getResults(), resultShardings)) {
193  for (auto user : result.getUsers()) {
194  ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
195  if (!shardOp) {
196  continue;
197  }
198  bool needsResharding = shardOp.getShardAttr() != sharding;
199  bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
200  if (needsResharding) {
201  if (isExplicitAnnotationForThisOp) {
202  // This is the worst case. No need to continue.
204  }
206  }
207  }
208  }
209 
210  return res;
211 }
212 
213 // From all the operand and result sharding combinations,
214 // return the one that is most desirable.
215 // The order of preference is:
216 // 1. No resharding with respect to existing sharding annotations.
217 // 2. Resharding for values that have already annotations that do not target
218 // this op.
219 // 3. Resharding of existing explicit sharding annotations for this op.
221  ShardingInterface shardingOp,
222  ArrayRef<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs,
223  ArrayRef<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs) {
225  shardingOptionsAndReshardingRequirements;
226 
227  for (ArrayRef<MeshShardingAttr> resultShardings :
228  possibleResultShardingAttrs) {
229  for (ArrayRef<MeshShardingAttr> operandShardings :
230  possibleOperandShardingAttrs) {
231  FailureOr<ShardingOption> shardingOption =
232  shardingOp.getShardingOption(operandShardings, resultShardings);
233  if (failed(shardingOption) || shardingOption->empty) {
234  continue;
235  }
236  // These shardings may not be the same as those in operandShardings and
237  // resultShardings.
238  // They may be missing some annotations.
239  // Whatever is returned by getShardingAnnotations is exactly what the op
240  // needs.
241  FailureOr<SmallVector<MeshShardingAttr>> operandAndResultShardings =
242  shardingOp.getShardingAnnotations(*shardingOption);
243  if (failed(operandAndResultShardings)) {
244  return failure();
245  }
246 
247  LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
248  << *operandAndResultShardings << "\n";);
249 
250  ReshardingRquirementKind reshardingRquirement =
251  getReshardingRquirementKind(shardingOp, *operandAndResultShardings);
252  if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) {
253  // This is the best case. No need to go on.
254  return *shardingOption;
255  }
256 
257  shardingOptionsAndReshardingRequirements.emplace_back(
258  std::move(*shardingOption), reshardingRquirement);
259  }
260  }
261 
262  if (shardingOptionsAndReshardingRequirements.empty()) {
263  return ShardingOption::makeEmpty();
264  }
265 
266  std::partial_sort(
267  shardingOptionsAndReshardingRequirements.begin(),
268  shardingOptionsAndReshardingRequirements.begin() + 1,
269  shardingOptionsAndReshardingRequirements.end(),
270  [](const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
271  const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
272  return std::get<ReshardingRquirementKind>(a) <
273  std::get<ReshardingRquirementKind>(b);
274  });
275 
276  LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "
277  << shardingOptionsAndReshardingRequirements << "\n";);
278 
279  return std::get<ShardingOption>(
280  shardingOptionsAndReshardingRequirements.front());
281 }
282 
283 // For each operation that implements the ShardingInterface, infer the sharding
284 // option of the operation from its operands and/or results using the
285 // `getShardingOption` method. If the inferred sharding option is not empty, add
286 // a `mesh.shard` operation for all remaining operands and results that do not
287 // have sharding annotations.
288 static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
289  if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op))
290  return success();
291 
292  ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
293  if (!shardingOp) {
294  op->emitOpError() << "sharding interface is not implemented.";
295  return failure();
296  }
297 
298  // collect MeshShardingAttr from results
299  SmallVector<MeshShardingAttr> allowConflictsResultShardings;
300  allowConflictsResultShardings.resize(op->getNumResults());
301  SmallVector<MeshShardingAttr> resultMustShardings;
302  resultMustShardings.resize(op->getNumResults());
303  for (OpResult result : op->getResults()) {
305  getMeshShardingAttr(result);
306  if (failed(maybeShardAttr))
307  continue;
308  if (!maybeShardAttr->first)
309  resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
310  else
311  allowConflictsResultShardings[result.getResultNumber()] =
312  maybeShardAttr->second;
313  }
314 
315  // collect MeshShardingAttr from operands
316  SmallVector<MeshShardingAttr> allowConflictsOperandShardings;
317  allowConflictsOperandShardings.resize(op->getNumOperands());
318  SmallVector<MeshShardingAttr> operandMustShardings;
319  operandMustShardings.resize(op->getNumOperands());
320  for (OpOperand &opOperand : op->getOpOperands()) {
322  getMeshShardingAttr(opOperand);
323  if (failed(maybeShardAttr))
324  continue;
325 
326  if (maybeShardAttr->first)
327  operandMustShardings[opOperand.getOperandNumber()] =
328  maybeShardAttr->second;
329  else
330  allowConflictsOperandShardings[opOperand.getOperandNumber()] =
331  maybeShardAttr->second;
332  }
333 
334  // try to get the sharding option
335  SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
336  getOrderedPossibleShardingAttrs(operandMustShardings,
337  allowConflictsOperandShardings);
338  SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs =
339  getOrderedPossibleShardingAttrs(resultMustShardings,
340  allowConflictsResultShardings);
342  shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
343 
344  if (failed(shardingOption)) {
345  op->emitOpError() << "fail to get sharding option.";
346  return failure();
347  }
348 
349  LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");
350 
351  // sharding info is empty, return immediately
352  if (shardingOption->empty)
353  return success();
354 
355  if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
356  op->emitOpError() << "fail to set sharding annotations.";
357  return failure();
358  }
359  return success();
360 }
361 
362 //===----------------------------------------------------------------------===//
363 // ShardingPropagation
364 //===----------------------------------------------------------------------===//
366  : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
367  void runOnOperation() override {
368  FunctionOpInterface funcOp = getOperation();
369  MLIRContext *ctx = funcOp.getContext();
370  Region &region = funcOp.getFunctionBody();
371  OpBuilder builder(ctx);
372  if (!region.hasOneBlock()) {
373  funcOp.emitOpError() << "only one block is supported!";
374  signalPassFailure();
375  }
376  Block &block = region.front();
377 
378  LLVM_DEBUG(
379  DBGS() << "print all the ops' iterator types and indexing maps in the "
380  "block.\n";
381  for (Operation &op
382  : block.getOperations()) {
383  if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
384  shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
385  });
386 
387  // 1. propagate in reversed order
388  for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
389  if (failed(visitOp(&op, builder)))
390  return signalPassFailure();
391 
392  LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
393  << funcOp << "\n");
394  LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
395 
396  // 2. propagate in original order
397  for (Operation &op : llvm::make_early_inc_range(block))
398  if (failed(visitOp(&op, builder)))
399  return signalPassFailure();
400  }
401 };
static FailureOr< ShardingOption > selectShardingOption(ShardingInterface shardingOp, ArrayRef< SmallVector< MeshShardingAttr >> possibleOperandShardingAttrs, ArrayRef< SmallVector< MeshShardingAttr >> possibleResultShardingAttrs)
static SmallVector< SmallVector< MeshShardingAttr > > getOrderedPossibleShardingAttrs(ArrayRef< MeshShardingAttr > mustShardings, ArrayRef< MeshShardingAttr > optionalShardings)
static LogicalResult visitOp(Operation *op, OpBuilder &builder)
ReshardingRquirementKind
ReshardingRquirementKind getReshardingRquirementKind(Operation *op, const SmallVector< MeshShardingAttr > &operandAndResultShardings)
#define DBGS()
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType & getOperations()
Definition: Block.h:135
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
unsigned getNumOperands()
Definition: Operation.h:341
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
FailureOr< std::pair< bool, MeshShardingAttr > > getMeshShardingAttr(OpResult result)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
void runOnOperation() override
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static ShardingOption makeEmpty()