18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/raw_ostream.h"
28 #define GEN_PASS_DEF_SHARDINGPROPAGATION
29 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
33 #define DEBUG_TYPE "sharding-propagation"
34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
// Forward declarations of stream-insertion overloads used for debug output.
// Only the std::tuple overload's full parameter list is visible here; the
// other two declarations are cut off after the stream parameter.
48 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
50 template <
typename... Ts>
51 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
52 const std::tuple<Ts...> &t);
53 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
// Generic helper that writes the contents of `range` to `stream`.
// It visits every element with llvm::for_each, using a lambda that captures
// the stream by reference. Declared to return a reference to the stream type.
// NOTE(review): the lambda body (per-element formatting) and the return
// statement are not visible in this excerpt.
56 template <
typename Stream,
typename Range>
57 static Stream &printRange(Stream &stream,
Range &&range) {
59 llvm::for_each(range, [&stream](
auto &v) {
// Stream-insertion overload that forwards `vec` to printRange and returns
// whatever printRange returns.
// NOTE(review): the declaration of `vec` (second parameter) is not visible
// in this excerpt.
67 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
69 return printRange(stream, vec);
// Debug printer for a value `v` that exposes `empty` and `mesh` members
// (presumably ShardingOption -- the parameter line is not visible here).
// Each field is written as "<label> = <value>"; the `mesh` label now carries
// the same " = " separator as `empty` so the label and value no longer run
// together in the debug output.
// [[maybe_unused]]: the overload may be unreferenced in some builds
// (presumably when debug output is compiled out -- confirm).
72 [[maybe_unused]]
static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
74 return stream <<
"{empty = " << v.
empty <<
", mesh = " << v.
mesh
// Writes the elements of `tuple` to `stream`, each one followed by ", ".
//
// `Is...` is the index pack used to address the tuple elements; the third
// (unnamed) parameter exists only so that the pack can be deduced.
// NOTE(review): `tuple` is taken by value, so the whole tuple is copied on
// every call; a const reference would avoid that.
78 template <
typename Stream,
typename... Ts,
size_t... Is>
79 static Stream &printTuple(Stream &stream, std::tuple<Ts...> tuple,
80 std::index_sequence<Is...>) {
// The index pack must address every tuple element exactly once.
81 static_assert(
sizeof...(Is) ==
sizeof...(Ts),
82 "Indices must have same number of elements as tuple types!");
// Empty tuples are rejected at compile time.
83 static_assert(
sizeof...(Ts) > 0,
"Cannot insert empty tuple into stream.");
// Comma fold expression: streams element Is followed by ", " for each index,
// in pack order.
86 ((stream << std::get<Is>(tuple) <<
", "), ...);
// Stream-insertion overload for std::tuple: delegates to printTuple with an
// index sequence generated from the tuple's element types.
90 template <
typename... Ts>
91 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
92 const std::tuple<Ts...> &t) {
93 return printTuple(stream, t, std::index_sequence_for<Ts...>{});
// Stream-insertion overload that prints `v` as its integer value via
// static_cast<int>.
// NOTE(review): the parameter list is not visible in this excerpt; `v` is
// presumably an enum value (e.g. ReshardingRquirementKind) -- confirm.
96 [[maybe_unused]]
static llvm::raw_ostream &
98 return stream << static_cast<int>(v);
// Combination currently being built by the depth-first enumeration below;
// one entry is pushed per index and popped again after the recursive call.
115 std::vector<MeshSharding> curShardingAttrs;
// Recursive enumerator over index `i`. A std::function is used so the lambda
// can refer to itself by name.
// NOTE(review): the lines between the visible branches (e.g. any `return` or
// `else`) are not shown in this excerpt, so the exact control flow between
// the cases below should be confirmed against the full file.
117 std::function<void(
size_t)> dfsCreateShardingAttrs = [&](
size_t i) {
// All indices handled: record a copy of the current combination.
118 if (i == mustShardings.size()) {
119 allShardingAttrs.push_back(std::vector<MeshSharding>(curShardingAttrs));
// A "must" sharding exists for this index: use it.
123 if (mustShardings[i]) {
124 curShardingAttrs.push_back(mustShardings[i]);
125 dfsCreateShardingAttrs(i + 1);
126 curShardingAttrs.pop_back();
// An optional sharding exists for this index: explore the combination that
// uses it first, then the combination with a default-constructed entry.
130 if (optionalShardings[i]) {
131 curShardingAttrs.push_back(optionalShardings[i]);
132 dfsCreateShardingAttrs(i + 1);
133 curShardingAttrs.pop_back();
134 curShardingAttrs.push_back({});
135 dfsCreateShardingAttrs(i + 1);
136 curShardingAttrs.pop_back();
// Remaining case: only a default-constructed entry is explored.
140 curShardingAttrs.push_back({});
141 dfsCreateShardingAttrs(i + 1);
142 curShardingAttrs.pop_back();
// Start the enumeration at index 0 and return everything collected.
145 dfsCreateShardingAttrs(0);
146 return allShardingAttrs;
// Parameters: `op` is the operation being examined;
// `operandAndResultShardings` holds one sharding per operand followed by one
// per result (see the split at `operandsCount` below).
160 Operation *op,
const std::vector<MeshSharding> &operandAndResultShardings) {
// Split the flat list: the first `operandsCount` entries are operand
// shardings, the remainder are result shardings.
164 auto operandShardings =
165 llvm::make_range(operandAndResultShardings.begin(),
166 operandAndResultShardings.begin() + operandsCount);
167 auto resultShardings =
168 llvm::make_range(operandAndResultShardings.begin() + operandsCount,
169 operandAndResultShardings.end());
// Operands: compare each candidate sharding with the ShardOp (if any) that
// defines the operand.
171 for (
auto [operand, sharding] :
172 llvm::zip_equal(op->
getOperands(), operandShardings)) {
// dyn_cast_or_null yields a null ShardOp when the operand has no defining op
// or the defining op is not a ShardOp.
// NOTE(review): `shardOp` is dereferenced below; a null check is presumably
// on the lines not visible in this excerpt -- confirm.
173 ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
177 bool needsResharding = sharding != shardOp.getSharding();
// On the operand side the annotation counts as explicit for this op when
// the ShardOp is marked annotate-for-users.
178 bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
179 if (needsResharding) {
180 if (isExplicitAnnotationForThisOp) {
// Results: compare each candidate sharding with every ShardOp user of the
// result.
188 for (
auto [result, sharding] :
189 llvm::zip_equal(op->
getResults(), resultShardings)) {
190 for (
auto user : result.getUsers()) {
// NOTE(review): dyn_cast yields a null ShardOp for non-ShardOp users; the
// handling of that case is not visible in this excerpt.
191 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
195 bool needsResharding = sharding != shardOp.getSharding();
// On the result side the test is inverted: the annotation counts as explicit
// for this op when the ShardOp is NOT marked annotate-for-users.
196 bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
197 if (needsResharding) {
198 if (isExplicitAnnotationForThisOp) {
// Parameters: the op's sharding interface plus the candidate operand and
// result sharding lists to try.
218 ShardingInterface shardingOp,
219 ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
220 ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {
// Collected (sharding option, resharding requirement) pairs.
222 shardingOptionsAndReshardingRequirements;
226 possibleOperandShardingAttrs) {
// Ask the op for a sharding option matching this candidate combination.
227 FailureOr<ShardingOption> shardingOption =
228 shardingOp.getShardingOption(operandShardings, resultShardings);
// Failed or empty options get special handling (body not visible here;
// presumably the candidate is skipped -- confirm).
229 if (failed(shardingOption) || shardingOption->empty) {
// Derive the operand/result shardings this option would produce.
237 FailureOr<std::vector<MeshSharding>> operandAndResultShardings =
238 shardingOp.getShardingAnnotations(*shardingOption);
239 if (failed(operandAndResultShardings)) {
// Early exit with this option; the guarding condition is not visible in this
// excerpt.
250 return *shardingOption;
// Otherwise remember the option together with its resharding requirement.
253 shardingOptionsAndReshardingRequirements.emplace_back(
254 std::move(*shardingOption), reshardingRquirement);
258 if (shardingOptionsAndReshardingRequirements.empty()) {
// Order candidates by ReshardingRquirementKind over [begin, begin + 1, end).
// NOTE(review): the algorithm name is on a line not visible here; the
// three-iterator form suggests a partial sort that places the smallest
// element first -- confirm.
263 shardingOptionsAndReshardingRequirements.begin(),
264 shardingOptionsAndReshardingRequirements.begin() + 1,
265 shardingOptionsAndReshardingRequirements.end(),
266 [](
const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
267 const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
268 return std::get<ReshardingRquirementKind>(a) <
269 std::get<ReshardingRquirementKind>(b);
272 LLVM_DEBUG(
DBGS() <<
"shardingOptionsAndReshardingRequirements = "
273 << shardingOptionsAndReshardingRequirements <<
"\n";);
// Return the option stored in the first collected entry.
275 return std::get<ShardingOption>(
276 shardingOptionsAndReshardingRequirements.front());
// Null when `op` does not implement ShardingInterface.
285 ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
// The mesh sharding ops themselves are checked for separately (the rest of
// this condition and its consequence are not visible in this excerpt).
288 llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))
292 op->
emitOpError() <<
"sharding interface is not implemented.";
// Result shardings, indexed by result number and split into two buckets
// depending on the bool returned alongside each sharding.
// NOTE(review): the sizing of these vectors is not visible in this excerpt;
// they are indexed directly below, so they must be resized first -- confirm.
297 std::vector<MeshSharding> allowConflictsResultShardings;
299 std::vector<MeshSharding> resultMustShardings;
302 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
304 if (failed(maybeShardAttr))
// For results, a false flag puts the sharding in the "must" bucket.
306 if (!maybeShardAttr->first)
307 resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
309 allowConflictsResultShardings[result.getResultNumber()] =
310 maybeShardAttr->second;
// Operand shardings, indexed by operand number, with the same two buckets.
314 std::vector<MeshSharding> allowConflictsOperandShardings;
316 std::vector<MeshSharding> operandMustShardings;
319 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
321 if (failed(maybeShardAttr))
// For operands the test is the opposite of the result case: a true flag
// puts the sharding in the "must" bucket.
324 if (maybeShardAttr->first)
325 operandMustShardings[opOperand.getOperandNumber()] =
326 maybeShardAttr->second;
328 allowConflictsOperandShardings[opOperand.getOperandNumber()] =
329 maybeShardAttr->second;
// Candidate lists built from the buckets above are handed to the
// sharding-option selection.
335 allowConflictsOperandShardings);
338 allowConflictsResultShardings);
340 shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
342 if (failed(shardingOption)) {
343 op->
emitOpError() <<
"fail to get sharding option.";
347 LLVM_DEBUG(
DBGS() <<
"Selected sharding option: " << *shardingOption <<
"\n");
// An empty option is handled separately (consequence not visible here).
350 if (shardingOption->empty)
// Apply the selected option to the op; report an error if that fails.
353 if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
354 op->
emitOpError() <<
"fail to set sharding annotations.";
// Sharding propagation pass, derived from the generated
// ShardingPropagationBase. The visible part of runOnOperation():
//   1. rejects function bodies that fail the block check below,
//   2. dumps loop types and indexing maps of ShardingInterface ops to the
//      debug stream,
//   3. visits all ops in reverse order, then again in forward order,
//      signalling pass failure as soon as visitOp fails.
364 :
public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
366 FunctionOpInterface funcOp = getOperation();
// Fixed mis-encoded text: "&region" had been turned into the (R) sign.
368 Region &region = funcOp.getFunctionBody();
371 funcOp.emitOpError() <<
"only one block is supported!";
372 return signalPassFailure();
377 DBGS() <<
"print all the ops' iterator types and indexing maps in the "
381 if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
382 shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
// Backward sweep. make_early_inc_range advances the iterator before the loop
// body runs (presumably so visitOp may insert ops next to the current one --
// confirm).
386 for (
Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
387 if (failed(
visitOp(&op, builder)))
388 return signalPassFailure();
390 LLVM_DEBUG(
DBGS() <<
"After reversed order propagation:\n"
// Forward sweep over the same block.
395 for (
Operation &op : llvm::make_early_inc_range(block))
396 if (failed(
visitOp(&op, builder)))
397 return signalPassFailure();
ReshardingRquirementKind getReshardingRquirementKind(Operation *op, const std::vector< MeshSharding > &operandAndResultShardings)
static LogicalResult visitOp(Operation *op, OpBuilder &builder)
static FailureOr< ShardingOption > selectShardingOption(ShardingInterface shardingOp, ArrayRef< std::vector< MeshSharding >> possibleOperandShardingAttrs, ArrayRef< std::vector< MeshSharding >> possibleResultShardingAttrs)
@ RESHARDING_FOR_EXPLICIT_ANNOTATIONS
@ NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS
static SmallVector< std::vector< MeshSharding > > getOrderedPossibleShardingAttrs(ArrayRef< MeshSharding > mustShardings, ArrayRef< MeshSharding > optionalShardings)
Block represents an ordered list of Operations.
OpListType & getOperations()
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
unsigned getNumOperands()
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
FailureOr< std::pair< bool, MeshSharding > > getMeshSharding(OpResult result)
Include the generated interface declarations.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
void runOnOperation() override
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
ShardingArray shardingArray
static ShardingOption makeEmpty()