MLIR 22.0.0git
ShardingPropagation.cpp
Go to the documentation of this file.
1//===- ShardingPropagation.cpp ------------------------------------- C++ --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
14#include "mlir/IR/Verifier.h"
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/iterator_range.h"
18#include "llvm/Support/Debug.h"
19#include "llvm/Support/raw_ostream.h"
20#include <algorithm>
21#include <vector>
22
23namespace mlir {
24namespace shard {
25#define GEN_PASS_DEF_SHARDINGPROPAGATION
26#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
27} // namespace shard
28} // namespace mlir
29
30#define DEBUG_TYPE "sharding-propagation"
31#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
32
33using namespace mlir;
34using namespace mlir::shard;
35
41
42#ifdef LLVM_DEBUG
43
44template <typename T>
45static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
46 const SmallVector<T> &vec);
47template <typename... Ts>
48static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
49 const std::tuple<Ts...> &t);
50static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
52
/// Prints the elements of `range` to `stream` as a bracketed, comma-separated
/// list, e.g. "[a, b, c]". An empty range prints "[]".
/// The separator is emitted only between elements (no dangling ", " before the
/// closing bracket), and `stream` itself is returned so that this also works
/// with stream types whose `operator<<` returns a base-class reference.
template <typename Stream, typename Range>
static Stream &printRange(Stream &stream, Range &&range) {
  stream << "[";
  bool first = true;
  for (const auto &v : range) {
    if (!first)
      stream << ", ";
    first = false;
    stream << v;
  }
  stream << "]";
  return stream;
}
62
63template <typename T>
64static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
65 const SmallVector<T> &vec) {
66 return printRange(stream, vec);
67}
68
69[[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
70 const ShardingOption &v) {
71 return stream << "{empty = " << v.empty << ", grid" << v.grid
72 << ", shardingArray = " << v.shardingArray << "}";
73}
74
/// Prints the elements of `tuple` selected by `Is...` to `stream` as a braced,
/// comma-separated list, e.g. "{a, b}".
/// The tuple is taken by const reference so that printing does not copy
/// potentially large elements (e.g. the ShardingOption entries printed during
/// option selection). The separator is only emitted between elements.
template <typename Stream, typename... Ts, size_t... Is>
static Stream &printTuple(Stream &stream, const std::tuple<Ts...> &tuple,
                          std::index_sequence<Is...>) {
  static_assert(sizeof...(Is) == sizeof...(Ts),
                "Indices must have same number of elements as tuple types!");
  static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");

  stream << "{";
  bool first = true;
  auto printElement = [&](const auto &element) {
    if (!first)
      stream << ", ";
    first = false;
    stream << element;
  };
  (printElement(std::get<Is>(tuple)), ...);
  stream << "}";
  return stream;
}
86
87template <typename... Ts>
88static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
89 const std::tuple<Ts...> &t) {
90 return printTuple(stream, t, std::index_sequence_for<Ts...>{});
91}
92
93[[maybe_unused]] static llvm::raw_ostream &
94operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
95 return stream << static_cast<int>(v);
96}
97
98#endif // LLVM_DEBUG
99
100//===----------------------------------------------------------------------===//
101// Utilities
102//===----------------------------------------------------------------------===//
103
104// This method retrieves all potential sharding attributes, prioritizing
105// specific shardings. For example, mustShardings = [shard0, None] and
106// optionalShardings = [None, shard1], the result will be [[shard0, shard1],
107// [shard0, None]]
110 ArrayRef<Sharding> optionalShardings) {
112 std::vector<Sharding> curShardingAttrs;
114 std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
115 if (i == mustShardings.size()) {
116 allShardingAttrs.push_back(std::vector<Sharding>(curShardingAttrs));
117 return;
119
120 if (mustShardings[i]) {
121 curShardingAttrs.push_back(mustShardings[i]);
122 dfsCreateShardingAttrs(i + 1);
123 curShardingAttrs.pop_back();
124 return;
125 }
127 if (optionalShardings[i]) {
128 curShardingAttrs.push_back(optionalShardings[i]);
129 dfsCreateShardingAttrs(i + 1);
130 curShardingAttrs.pop_back();
131 curShardingAttrs.emplace_back();
132 dfsCreateShardingAttrs(i + 1);
133 curShardingAttrs.pop_back();
134 return;
136
137 curShardingAttrs.emplace_back();
138 dfsCreateShardingAttrs(i + 1);
139 curShardingAttrs.pop_back();
140 };
141
142 dfsCreateShardingAttrs(0);
143 return allShardingAttrs;
144}
// The order of preference is from highest to lowest:
147// 1. No resharding is required (all existing annotations are compatible).
148// 2. No resharding for operands/results that have annotation specifically
149// targeting this operation. This means
150// * operands that are the result of `shard.shard` ops marked with
151// `annotate_for_users`.
152// * results that are annotated with `shard.shard` ops without
153// `annotate_for_users`.
154// 3. All other cases. Resharding is required for operands/results with
155// annotation targeting explicitly this operation.
157 Operation *op, const std::vector<Sharding> &operandAndResultShardings) {
159
160 size_t operandsCount = op->getOperands().size();
161 auto operandShardings =
162 llvm::make_range(operandAndResultShardings.begin(),
163 operandAndResultShardings.begin() + operandsCount);
164 auto resultShardings =
165 llvm::make_range(operandAndResultShardings.begin() + operandsCount,
166 operandAndResultShardings.end());
167
168 for (auto [operand, sharding] :
169 llvm::zip_equal(op->getOperands(), operandShardings)) {
170 ShardOp shardOp = operand.getDefiningOp<ShardOp>();
171 if (!shardOp) {
172 continue;
173 }
174 bool needsResharding = sharding != shardOp.getSharding();
175 bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
176 if (needsResharding) {
177 if (isExplicitAnnotationForThisOp) {
178 // This is the worst case. No need to continue.
183 }
185 for (auto [result, sharding] :
186 llvm::zip_equal(op->getResults(), resultShardings)) {
187 for (auto user : result.getUsers()) {
188 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
189 if (!shardOp) {
190 continue;
191 }
192 bool needsResharding = sharding != shardOp.getSharding();
193 bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
194 if (needsResharding) {
195 if (isExplicitAnnotationForThisOp) {
196 // This is the worst case. No need to continue.
198 }
200 }
201 }
202 }
203
204 return res;
205}
206
207// From all the operand and result sharding combinations,
208// return the one that is most desirable.
209// The order of preference is:
210// 1. No resharding with respect to existing sharding annotations.
// 2. Resharding for values that already have annotations that do not target
// this op.
213// 3. Resharding of existing explicit sharding annotations for this op.
214static FailureOr<ShardingOption> selectShardingOption(
215 ShardingInterface shardingOp,
216 ArrayRef<std::vector<Sharding>> possibleOperandShardingAttrs,
217 ArrayRef<std::vector<Sharding>> possibleResultShardingAttrs) {
219 shardingOptionsAndReshardingRequirements;
220
221 for (ArrayRef<Sharding> resultShardings : possibleResultShardingAttrs) {
222 for (ArrayRef<Sharding> operandShardings : possibleOperandShardingAttrs) {
223 FailureOr<ShardingOption> shardingOption =
224 shardingOp.getShardingOption(operandShardings, resultShardings);
225 if (failed(shardingOption) || shardingOption->empty) {
226 continue;
227 }
228 // These shardings may not be the same as those in operandShardings and
229 // resultShardings.
230 // They may be missing some annotations.
231 // Whatever is returned by getShardingAnnotations is exactly what the op
232 // needs.
233 FailureOr<std::vector<Sharding>> operandAndResultShardings =
234 shardingOp.getShardingAnnotations(*shardingOption);
235 if (failed(operandAndResultShardings)) {
236 return failure();
237 }
238
239 // LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
240 // << *operandAndResultShardings << "\n";);
241
242 ReshardingRquirementKind reshardingRquirement =
243 getReshardingRquirementKind(shardingOp, *operandAndResultShardings);
244 if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) {
245 // This is the best case. No need to go on.
246 return *shardingOption;
247 }
248
249 shardingOptionsAndReshardingRequirements.emplace_back(
250 std::move(*shardingOption), reshardingRquirement);
251 }
252 }
253
254 if (shardingOptionsAndReshardingRequirements.empty()) {
256 }
257
258 std::partial_sort(
259 shardingOptionsAndReshardingRequirements.begin(),
260 shardingOptionsAndReshardingRequirements.begin() + 1,
261 shardingOptionsAndReshardingRequirements.end(),
262 [](const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
263 const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
264 return std::get<ReshardingRquirementKind>(a) <
265 std::get<ReshardingRquirementKind>(b);
266 });
267
268 LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "
269 << shardingOptionsAndReshardingRequirements << "\n";);
270
271 return std::get<ShardingOption>(
272 shardingOptionsAndReshardingRequirements.front());
273}
274
275// For each operation that implements the ShardingInterface, infer the sharding
276// option of the operation from its operands and/or results using the
277// `getShardingOption` method. If the inferred sharding option is not empty, add
278// a `shard.shard` operation for all remaining operands and results that do not
279// have sharding annotations.
280static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
281 ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
282 if (op->hasTrait<OpTrait::IsTerminator>() ||
283 (op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
284 llvm::isa<shard::ShardOp, shard::ShardingOp, shard::GetShardingOp>(op))
285 return success();
286
287 if (!shardingOp) {
288 op->emitOpError() << "sharding interface is not implemented.";
289 return failure();
290 }
291
292 // collect Sharding from results
293 std::vector<Sharding> allowConflictsResultShardings;
294 allowConflictsResultShardings.resize(op->getNumResults());
295 std::vector<Sharding> resultMustShardings;
296 resultMustShardings.resize(op->getNumResults());
297 for (OpResult result : op->getResults()) {
298 FailureOr<std::pair<bool, Sharding>> maybeShardAttr = getSharding(result);
299 if (failed(maybeShardAttr))
300 continue;
301 if (!maybeShardAttr->first)
302 resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
303 else
304 allowConflictsResultShardings[result.getResultNumber()] =
305 maybeShardAttr->second;
306 }
307
308 // collect Sharding from operands
309 std::vector<Sharding> allowConflictsOperandShardings;
310 allowConflictsOperandShardings.resize(op->getNumOperands());
311 std::vector<Sharding> operandMustShardings;
312 operandMustShardings.resize(op->getNumOperands());
313 for (OpOperand &opOperand : op->getOpOperands()) {
314 FailureOr<std::pair<bool, Sharding>> maybeShardAttr =
315 getSharding(opOperand);
316 if (failed(maybeShardAttr))
317 continue;
318
319 if (maybeShardAttr->first)
320 operandMustShardings[opOperand.getOperandNumber()] =
321 maybeShardAttr->second;
322 else
323 allowConflictsOperandShardings[opOperand.getOperandNumber()] =
324 maybeShardAttr->second;
325 }
326
327 // try to get the sharding option
328 SmallVector<std::vector<Sharding>> possibleOperandShardingAttrs =
329 getOrderedPossibleShardingAttrs(operandMustShardings,
330 allowConflictsOperandShardings);
331 SmallVector<std::vector<Sharding>> possibleResultShardingAttrs =
332 getOrderedPossibleShardingAttrs(resultMustShardings,
333 allowConflictsResultShardings);
334 FailureOr<ShardingOption> shardingOption = selectShardingOption(
335 shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
336
337 if (failed(shardingOption)) {
338 op->emitOpError() << "fail to get sharding option.";
339 return failure();
340 }
341
342 LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");
343
344 // sharding info is empty, return immediately
345 if (shardingOption->empty)
346 return success();
347
348 if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
349 op->emitOpError() << "fail to set sharding annotations.";
350 return failure();
351 }
352 return success();
353}
354
355//===----------------------------------------------------------------------===//
356// ShardingPropagation
357//===----------------------------------------------------------------------===//
359 : public shard::impl::ShardingPropagationBase<ShardingPropagation> {
360
362
363 void runOnOperation() override {
364 FunctionOpInterface funcOp = getOperation();
365 MLIRContext *ctx = funcOp.getContext();
366 Region &region = funcOp.getFunctionBody();
367 OpBuilder builder(ctx);
368 if (!region.hasOneBlock()) {
369 funcOp.emitOpError() << "only one block is supported!";
370 return signalPassFailure();
371 }
372 Block &block = region.front();
373
374 LLVM_DEBUG(
375 DBGS() << "print all the ops' iterator types and indexing maps in the "
376 "block.\n";
377 for (Operation &op : block.getOperations()) {
378 if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
379 shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
380 });
381
382 auto traverse = [&](auto &&range, OpBuilder &builder,
383 const char *order) -> bool {
384 for (Operation &op : range) {
385 if (failed(visitOp(&op, builder))) {
387 return true;
388 }
389 }
390 LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
391 << funcOp << "\n");
392 LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
393 return false;
394 };
395
396 // 1. Propagate in reversed order.
399 traverse(llvm::reverse(block), builder, "backward");
400
401 // 2. Propagate in original order.
403 traverse(block, builder, "forward");
404
405 // 3. Propagate in backward order if needed.
407 traverse(llvm::reverse(block), builder, "backward");
408 }
409};
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ReshardingRquirementKind getReshardingRquirementKind(Operation *op, const std::vector< Sharding > &operandAndResultShardings)
static FailureOr< ShardingOption > selectShardingOption(ShardingInterface shardingOp, ArrayRef< std::vector< Sharding > > possibleOperandShardingAttrs, ArrayRef< std::vector< Sharding > > possibleResultShardingAttrs)
static LogicalResult visitOp(Operation *op, OpBuilder &builder)
ReshardingRquirementKind
static SmallVector< std::vector< Sharding > > getOrderedPossibleShardingAttrs(ArrayRef< Sharding > mustShardings, ArrayRef< Sharding > optionalShardings)
#define DBGS()
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType & getOperations()
Definition Block.h:137
mlir::FunctionOpInterface getOperation()
Definition Pass.h:444
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
unsigned getNumOperands()
Definition Operation.h:346
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
void signalPassFailure()
Signal that some invariant was broken when running.
Definition Pass.h:218
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
::mlir::Pass::Option< mlir::shard::TraversalOrder > traversal
@ BackwardForward
Backward then forward traversal.
Definition Passes.h:31
@ Backward
Backward traversal.
Definition Passes.h:27
@ ForwardBackward
Forward then backward traversal.
Definition Passes.h:29
FailureOr< std::pair< bool, Sharding > > getSharding(OpResult result)
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:423
void runOnOperation() override
The polymorphic API that runs the pass over the currently held operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static ShardingOption makeEmpty()