16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/iterator_range.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/raw_ostream.h"
25 #define GEN_PASS_DEF_SHARDINGPROPAGATION
26 #include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
30 #define DEBUG_TYPE "sharding-propagation"
31 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
45 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
// Forward declaration of the debug stream-insertion operator for std::tuple.
// The definition appears later in this file and forwards to printTuple with
// std::index_sequence_for<Ts...> to print every tuple element.
// Declaring it up front makes the overload visible to the printing helpers
// defined before that definition (printRange and the vector operator<<).
// NOTE(review): presumably needed so that debug output of containers of
// tuples (e.g. shardingOptionsAndReshardingRequirements in LLVM_DEBUG)
// resolves to this overload — confirm against printRange's loop body.
47 template <
typename... Ts>
48 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
49 const std::tuple<Ts...> &t);
50 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
53 template <
typename Stream,
typename Range>
54 static Stream &printRange(Stream &stream,
Range &&range) {
56 for (
auto &v : range) {
64 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
66 return printRange(stream, vec);
69 [[maybe_unused]]
static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
71 return stream <<
"{empty = " << v.
empty <<
", grid" << v.
grid
75 template <
typename Stream,
typename... Ts,
size_t... Is>
76 static Stream &printTuple(Stream &stream, std::tuple<Ts...> tuple,
77 std::index_sequence<Is...>) {
78 static_assert(
sizeof...(Is) ==
sizeof...(Ts),
79 "Indices must have same number of elements as tuple types!");
80 static_assert(
sizeof...(Ts) > 0,
"Cannot insert empty tuple into stream.");
83 ((stream << std::get<Is>(tuple) <<
", "), ...);
87 template <
typename... Ts>
88 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
89 const std::tuple<Ts...> &t) {
90 return printTuple(stream, t, std::index_sequence_for<Ts...>{});
93 [[maybe_unused]]
static llvm::raw_ostream &
95 return stream << static_cast<int>(v);
112 std::vector<Sharding> curShardingAttrs;
114 std::function<void(
size_t)> dfsCreateShardingAttrs = [&](
size_t i) {
115 if (i == mustShardings.size()) {
116 allShardingAttrs.push_back(std::vector<Sharding>(curShardingAttrs));
120 if (mustShardings[i]) {
121 curShardingAttrs.push_back(mustShardings[i]);
122 dfsCreateShardingAttrs(i + 1);
123 curShardingAttrs.pop_back();
127 if (optionalShardings[i]) {
128 curShardingAttrs.push_back(optionalShardings[i]);
129 dfsCreateShardingAttrs(i + 1);
130 curShardingAttrs.pop_back();
131 curShardingAttrs.push_back({});
132 dfsCreateShardingAttrs(i + 1);
133 curShardingAttrs.pop_back();
137 curShardingAttrs.push_back({});
138 dfsCreateShardingAttrs(i + 1);
139 curShardingAttrs.pop_back();
142 dfsCreateShardingAttrs(0);
143 return allShardingAttrs;
157 Operation *op,
const std::vector<Sharding> &operandAndResultShardings) {
161 auto operandShardings =
162 llvm::make_range(operandAndResultShardings.begin(),
163 operandAndResultShardings.begin() + operandsCount);
164 auto resultShardings =
165 llvm::make_range(operandAndResultShardings.begin() + operandsCount,
166 operandAndResultShardings.end());
168 for (
auto [operand, sharding] :
169 llvm::zip_equal(op->
getOperands(), operandShardings)) {
170 ShardOp shardOp = operand.getDefiningOp<ShardOp>();
174 bool needsResharding = sharding != shardOp.getSharding();
175 bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
176 if (needsResharding) {
177 if (isExplicitAnnotationForThisOp) {
185 for (
auto [result, sharding] :
186 llvm::zip_equal(op->
getResults(), resultShardings)) {
187 for (
auto user : result.getUsers()) {
188 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
192 bool needsResharding = sharding != shardOp.getSharding();
193 bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
194 if (needsResharding) {
195 if (isExplicitAnnotationForThisOp) {
215 ShardingInterface shardingOp,
216 ArrayRef<std::vector<Sharding>> possibleOperandShardingAttrs,
217 ArrayRef<std::vector<Sharding>> possibleResultShardingAttrs) {
219 shardingOptionsAndReshardingRequirements;
223 FailureOr<ShardingOption> shardingOption =
224 shardingOp.getShardingOption(operandShardings, resultShardings);
225 if (
failed(shardingOption) || shardingOption->empty) {
233 FailureOr<std::vector<Sharding>> operandAndResultShardings =
234 shardingOp.getShardingAnnotations(*shardingOption);
235 if (
failed(operandAndResultShardings)) {
246 return *shardingOption;
249 shardingOptionsAndReshardingRequirements.emplace_back(
250 std::move(*shardingOption), reshardingRquirement);
254 if (shardingOptionsAndReshardingRequirements.empty()) {
259 shardingOptionsAndReshardingRequirements.begin(),
260 shardingOptionsAndReshardingRequirements.begin() + 1,
261 shardingOptionsAndReshardingRequirements.end(),
262 [](
const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
263 const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
264 return std::get<ReshardingRquirementKind>(a) <
265 std::get<ReshardingRquirementKind>(b);
268 LLVM_DEBUG(
DBGS() <<
"shardingOptionsAndReshardingRequirements = "
269 << shardingOptionsAndReshardingRequirements <<
"\n";);
271 return std::get<ShardingOption>(
272 shardingOptionsAndReshardingRequirements.front());
281 ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
284 llvm::isa<shard::ShardOp, shard::ShardingOp, shard::GetShardingOp>(op))
288 op->
emitOpError() <<
"sharding interface is not implemented.";
293 std::vector<Sharding> allowConflictsResultShardings;
295 std::vector<Sharding> resultMustShardings;
298 FailureOr<std::pair<bool, Sharding>> maybeShardAttr =
getSharding(result);
299 if (
failed(maybeShardAttr))
301 if (!maybeShardAttr->first)
302 resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
304 allowConflictsResultShardings[result.getResultNumber()] =
305 maybeShardAttr->second;
309 std::vector<Sharding> allowConflictsOperandShardings;
311 std::vector<Sharding> operandMustShardings;
314 FailureOr<std::pair<bool, Sharding>> maybeShardAttr =
316 if (
failed(maybeShardAttr))
319 if (maybeShardAttr->first)
320 operandMustShardings[opOperand.getOperandNumber()] =
321 maybeShardAttr->second;
323 allowConflictsOperandShardings[opOperand.getOperandNumber()] =
324 maybeShardAttr->second;
330 allowConflictsOperandShardings);
333 allowConflictsResultShardings);
335 shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
337 if (
failed(shardingOption)) {
338 op->
emitOpError() <<
"fail to get sharding option.";
342 LLVM_DEBUG(
DBGS() <<
"Selected sharding option: " << *shardingOption <<
"\n");
345 if (shardingOption->empty)
348 if (
failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
349 op->
emitOpError() <<
"fail to set sharding annotations.";
359 :
public shard::impl::ShardingPropagationBase<ShardingPropagation> {
361 using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
364 FunctionOpInterface funcOp = getOperation();
366 Region ®ion = funcOp.getFunctionBody();
369 funcOp.emitOpError() <<
"only one block is supported!";
370 return signalPassFailure();
375 DBGS() <<
"print all the ops' iterator types and indexing maps in the "
378 if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
379 shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
382 auto traverse = [&](
auto &&range,
OpBuilder &builder,
383 const char *order) ->
bool {
390 LLVM_DEBUG(
DBGS() <<
"After " << order <<
" order propagation:\n"
399 traverse(llvm::reverse(block), builder,
"backward");
403 traverse(block, builder,
"forward");
407 traverse(llvm::reverse(block), builder,
"backward");
ReshardingRquirementKind getReshardingRquirementKind(Operation *op, const std::vector< Sharding > &operandAndResultShardings)
static FailureOr< ShardingOption > selectShardingOption(ShardingInterface shardingOp, ArrayRef< std::vector< Sharding >> possibleOperandShardingAttrs, ArrayRef< std::vector< Sharding >> possibleResultShardingAttrs)
static SmallVector< std::vector< Sharding > > getOrderedPossibleShardingAttrs(ArrayRef< Sharding > mustShardings, ArrayRef< Sharding > optionalShardings)
static LogicalResult visitOp(Operation *op, OpBuilder &builder)
@ RESHARDING_FOR_EXPLICIT_ANNOTATIONS
@ NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS
Block represents an ordered list of Operations.
OpListType & getOperations()
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
unsigned getNumOperands()
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
@ BackwardForward
Backward then forward traversal.
@ Backward
Backward traversal.
@ ForwardBackward
Forward then backward traversal.
FailureOr< std::pair< bool, Sharding > > getSharding(OpResult result)
Include the generated interface declarations.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
void runOnOperation() override
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
ShardingArray shardingArray
static ShardingOption makeEmpty()