18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/raw_ostream.h"
28 #define GEN_PASS_DEF_SHARDINGPROPAGATION
29 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
// DEBUG_TYPE is the debug category that LLVM_DEBUG (llvm/Support/Debug.h)
// uses for this file, so its debug output can be selected on its own
// (e.g. with -debug-only=sharding-propagation).
// DBGS() starts a debug line on llvm::dbgs() with the prefix
// "[sharding-propagation]: ". The "[" literal and DEBUG_TYPE are joined by
// adjacent string-literal concatenation before streaming.
33 #define DEBUG_TYPE "sharding-propagation"
34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
48 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
// Forward declaration of the std::tuple stream operator. The definition
// appears later in this file and forwards to printTuple.
// Declaring it up front makes it visible to the helper templates defined
// below (e.g. printTuple's `stream << std::get<Is>(tuple)`). Presumably this
// lets tuples that are nested inside ranges or tuples be streamed; confirm
// against the full file.
50 template <
typename... Ts>
51 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
52 const std::tuple<Ts...> &t);
53 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
56 template <
typename Stream,
typename Range>
57 static Stream &printRange(Stream &stream,
Range &&range) {
59 llvm::for_each(range, [&stream](
auto &v) {
67 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
69 return printRange(stream, vec);
72 [[maybe_unused]]
static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
74 return stream <<
"{empty = " << v.
empty <<
", mesh" << v.
mesh
78 template <
typename Stream,
typename... Ts,
size_t... Is>
79 static Stream &printTuple(Stream &stream, std::tuple<Ts...> tuple,
80 std::index_sequence<Is...>) {
81 static_assert(
sizeof...(Is) ==
sizeof...(Ts),
82 "Indices must have same number of elements as tuple types!");
83 static_assert(
sizeof...(Ts) > 0,
"Cannot insert empty tuple into stream.");
86 ((stream << std::get<Is>(tuple) <<
", "), ...);
90 template <
typename... Ts>
91 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
92 const std::tuple<Ts...> &t) {
93 return printTuple(stream, t, std::index_sequence_for<Ts...>{});
96 [[maybe_unused]]
static llvm::raw_ostream &
98 return stream << static_cast<int>(v);
117 std::function<void(
size_t)> dfsCreateShardingAttrs = [&](
size_t i) {
118 if (i == mustShardings.size()) {
119 allShardingAttrs.push_back(
124 if (mustShardings[i]) {
125 curShardingAttrs.push_back(mustShardings[i]);
126 dfsCreateShardingAttrs(i + 1);
127 curShardingAttrs.pop_back();
131 if (optionalShardings[i]) {
132 curShardingAttrs.push_back(optionalShardings[i]);
133 dfsCreateShardingAttrs(i + 1);
134 curShardingAttrs.pop_back();
135 curShardingAttrs.push_back(
nullptr);
136 dfsCreateShardingAttrs(i + 1);
137 curShardingAttrs.pop_back();
141 curShardingAttrs.push_back(
nullptr);
142 dfsCreateShardingAttrs(i + 1);
143 curShardingAttrs.pop_back();
146 dfsCreateShardingAttrs(0);
147 return allShardingAttrs;
166 auto operandShardings =
167 llvm::make_range(operandAndResultShardings.begin(),
168 operandAndResultShardings.begin() + operandsCount);
169 auto resultShardings =
170 llvm::make_range(operandAndResultShardings.begin() + operandsCount,
171 operandAndResultShardings.end());
173 for (
auto [operand, sharding] :
174 llvm::zip_equal(op->
getOperands(), operandShardings)) {
175 ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
179 bool needsResharding = shardOp.getShardAttr() != sharding;
180 bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
181 if (needsResharding) {
182 if (isExplicitAnnotationForThisOp) {
190 for (
auto [result, sharding] :
191 llvm::zip_equal(op->
getResults(), resultShardings)) {
192 for (
auto user : result.getUsers()) {
193 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
197 bool needsResharding = shardOp.getShardAttr() != sharding;
198 bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
199 if (needsResharding) {
200 if (isExplicitAnnotationForThisOp) {
220 ShardingInterface shardingOp,
224 shardingOptionsAndReshardingRequirements;
227 possibleResultShardingAttrs) {
229 possibleOperandShardingAttrs) {
230 FailureOr<ShardingOption> shardingOption =
231 shardingOp.getShardingOption(operandShardings, resultShardings);
232 if (failed(shardingOption) || shardingOption->empty) {
240 FailureOr<SmallVector<MeshShardingAttr>> operandAndResultShardings =
241 shardingOp.getShardingAnnotations(*shardingOption);
242 if (failed(operandAndResultShardings)) {
246 LLVM_DEBUG(
DBGS() <<
"operandAndResultShardings = "
247 << *operandAndResultShardings <<
"\n";);
253 return *shardingOption;
256 shardingOptionsAndReshardingRequirements.emplace_back(
257 std::move(*shardingOption), reshardingRquirement);
261 if (shardingOptionsAndReshardingRequirements.empty()) {
266 shardingOptionsAndReshardingRequirements.begin(),
267 shardingOptionsAndReshardingRequirements.begin() + 1,
268 shardingOptionsAndReshardingRequirements.end(),
269 [](
const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
270 const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
271 return std::get<ReshardingRquirementKind>(a) <
272 std::get<ReshardingRquirementKind>(b);
275 LLVM_DEBUG(
DBGS() <<
"shardingOptionsAndReshardingRequirements = "
276 << shardingOptionsAndReshardingRequirements <<
"\n";);
278 return std::get<ShardingOption>(
279 shardingOptionsAndReshardingRequirements.front());
291 ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
293 op->
emitOpError() <<
"sharding interface is not implemented.";
303 FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
305 if (failed(maybeShardAttr))
307 if (!maybeShardAttr->first)
308 resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
310 allowConflictsResultShardings[result.getResultNumber()] =
311 maybeShardAttr->second;
320 FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
322 if (failed(maybeShardAttr))
325 if (maybeShardAttr->first)
326 operandMustShardings[opOperand.getOperandNumber()] =
327 maybeShardAttr->second;
329 allowConflictsOperandShardings[opOperand.getOperandNumber()] =
330 maybeShardAttr->second;
336 allowConflictsOperandShardings);
339 allowConflictsResultShardings);
341 shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
343 if (failed(shardingOption)) {
344 op->
emitOpError() <<
"fail to get sharding option.";
348 LLVM_DEBUG(
DBGS() <<
"Selected sharding option: " << *shardingOption <<
"\n");
351 if (shardingOption->empty)
354 if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
355 op->
emitOpError() <<
"fail to set sharding annotations.";
365 :
public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
367 FunctionOpInterface funcOp = getOperation();
369 Region ®ion = funcOp.getFunctionBody();
372 funcOp.emitOpError() <<
"only one block is supported!";
378 DBGS() <<
"print all the ops' iterator types and indexing maps in the "
382 if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
383 shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
387 for (
Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
388 if (failed(
visitOp(&op, builder)))
389 return signalPassFailure();
391 LLVM_DEBUG(
DBGS() <<
"After reversed order propagation:\n"
396 for (
Operation &op : llvm::make_early_inc_range(block))
397 if (failed(
visitOp(&op, builder)))
398 return signalPassFailure();
static FailureOr< ShardingOption > selectShardingOption(ShardingInterface shardingOp, ArrayRef< SmallVector< MeshShardingAttr >> possibleOperandShardingAttrs, ArrayRef< SmallVector< MeshShardingAttr >> possibleResultShardingAttrs)
static SmallVector< SmallVector< MeshShardingAttr > > getOrderedPossibleShardingAttrs(ArrayRef< MeshShardingAttr > mustShardings, ArrayRef< MeshShardingAttr > optionalShardings)
static LogicalResult visitOp(Operation *op, OpBuilder &builder)
@ RESHARDING_FOR_EXPLICIT_ANNOTATIONS
@ NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS
ReshardingRquirementKind getReshardingRquirementKind(Operation *op, const SmallVector< MeshShardingAttr > &operandAndResultShardings)
Block represents an ordered list of Operations.
OpListType & getOperations()
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
unsigned getNumOperands()
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
FailureOr< std::pair< bool, MeshShardingAttr > > getMeshShardingAttr(OpResult result)
Include the generated interface declarations.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
void runOnOperation() override
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
ShardingArray shardingArray
static ShardingOption makeEmpty()