18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/raw_ostream.h"
28 #define GEN_PASS_DEF_SHARDINGPROPAGATION
29 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
33 #define DEBUG_TYPE "sharding-propagation"
34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
// Forward declarations of the debug-printing stream operators. Declaring them
// ahead of the generic helpers below (printRange, printTuple) lets the
// `stream << ...` expressions inside those templates find these overloads
// regardless of the order in which the definitions appear.
48 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
// Overload for std::tuple; its definition below forwards to printTuple.
50 template <
typename... Ts>
51 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
52 const std::tuple<Ts...> &t);
53 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
// Generic range printer: writes each element of `range` to `stream` via
// llvm::for_each. Returns `Stream &` so calls can be chained; the vector
// overload of operator<< delegates here.
56 template <
typename Stream,
typename Range>
57 static Stream &printRange(Stream &stream,
Range &&range) {
59 llvm::for_each(range, [&stream](
auto &v) {
// Prints the container `vec` by delegating element-wise output to printRange.
// NOTE(review): the parameter declaration is not visible in this view;
// presumably a SmallVector — confirm.
67 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
69 return printRange(stream, vec);
// Debug printer for a sharding option `v` (presumably a ShardingOption —
// confirm). Output starts with "{empty = " followed by the `empty` flag, then
// the `mesh` field. [[maybe_unused]] presumably silences unused-function
// warnings in builds where LLVM_DEBUG(...) compiles away, since the uses
// (e.g. the "Selected sharding option" dump) sit inside LLVM_DEBUG.
[[maybe_unused]]
static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
74 return stream <<
"{empty = " << v.
empty <<
", mesh" << v.
mesh
// Writes every element of `tuple` to `stream`, each followed by ", ", using a
// C++17 fold expression over the index pack `Is`. The static_asserts reject,
// at compile time, an index pack whose size differs from the tuple's arity
// and an empty tuple.
// NOTE(review): `tuple` is taken by value, so every element (e.g. a
// ShardingOption) is copied on each print; `const std::tuple<Ts...> &` would
// avoid that if this ever matters outside debug output.
78 template <
typename Stream,
typename... Ts,
size_t... Is>
79 static Stream &printTuple(Stream &stream, std::tuple<Ts...> tuple,
80 std::index_sequence<Is...>) {
81 static_assert(
sizeof...(Is) ==
sizeof...(Ts),
82 "Indices must have same number of elements as tuple types!");
83 static_assert(
sizeof...(Ts) > 0,
"Cannot insert empty tuple into stream.");
// Fold over the comma operator: streams element 0, ", ", element 1, ", ", ...
86 ((stream << std::get<Is>(tuple) <<
", "), ...);
// Stream operator for std::tuple: forwards to printTuple with an index
// sequence that covers every element of the tuple.
90 template <
typename... Ts>
91 static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream,
92 const std::tuple<Ts...> &t) {
93 return printTuple(stream, t, std::index_sequence_for<Ts...>{});
// Debug printer that emits the enum value `v` as its underlying integer.
// NOTE(review): the parameter type is presumably ReshardingRquirementKind —
// confirm.
96 [[maybe_unused]]
static llvm::raw_ostream &
98 return stream << static_cast<int>(v);
// Depth-first enumeration that picks one MeshSharding per position i.
// `curShardingAttrs` holds the partial choice for positions [0, i); each
// push_back is undone by a matching pop_back after the recursive call
// returns (backtracking).
115 std::vector<MeshSharding> curShardingAttrs;
117 std::function<void(
size_t)> dfsCreateShardingAttrs = [&](
size_t i) {
// Every position has been assigned: record a copy of the complete choice.
118 if (i == mustShardings.size()) {
119 allShardingAttrs.push_back(std::vector<MeshSharding>(curShardingAttrs));
// A must-sharding present at position i is used for that position.
123 if (mustShardings[i]) {
124 curShardingAttrs.push_back(mustShardings[i]);
125 dfsCreateShardingAttrs(i + 1);
126 curShardingAttrs.pop_back();
// An optional sharding is explored first, then the empty MeshSharding, so
// combinations that keep the hint are emitted before those that drop it.
130 if (optionalShardings[i]) {
131 curShardingAttrs.push_back(optionalShardings[i]);
132 dfsCreateShardingAttrs(i + 1);
133 curShardingAttrs.pop_back();
134 curShardingAttrs.push_back({});
135 dfsCreateShardingAttrs(i + 1);
136 curShardingAttrs.pop_back();
// Fallback: the empty MeshSharding at position i.
// NOTE(review): presumably reached only when position i has neither a must
// nor an optional sharding (via early returns above) — confirm.
140 curShardingAttrs.push_back({});
141 dfsCreateShardingAttrs(i + 1);
142 curShardingAttrs.pop_back();
145 dfsCreateShardingAttrs(0);
146 return allShardingAttrs;
160 Operation *op,
const std::vector<MeshSharding> &operandAndResultShardings) {
// Classify how much resharding the candidate `operandAndResultShardings` would
// require relative to the shard ops already attached to `op`. The vector holds
// the shardings for all operands first, then all results; split it at
// operandsCount.
164 auto operandShardings =
165 llvm::make_range(operandAndResultShardings.begin(),
166 operandAndResultShardings.begin() + operandsCount);
167 auto resultShardings =
168 llvm::make_range(operandAndResultShardings.begin() + operandsCount,
169 operandAndResultShardings.end());
// Operands: compare each candidate against the ShardOp (if any) that defines
// the operand. Such a shard op is an explicit annotation for this op when it
// annotates for its users.
171 for (
auto [operand, sharding] :
172 llvm::zip_equal(op->
getOperands(), operandShardings)) {
173 ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
177 bool needsResharding = sharding != shardOp.getSharding();
178 bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
179 if (needsResharding) {
180 if (isExplicitAnnotationForThisOp) {
// Results: compare each candidate against the ShardOp users of the result.
// The polarity is reversed here: a user shard op that does NOT annotate for
// its users is annotating the producer, i.e. this op.
188 for (
auto [result, sharding] :
189 llvm::zip_equal(op->
getResults(), resultShardings)) {
190 for (
auto user : result.getUsers()) {
// A user that is not a ShardOp yields a null shardOp here.
191 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
195 bool needsResharding = sharding != shardOp.getSharding();
196 bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
197 if (needsResharding) {
198 if (isExplicitAnnotationForThisOp) {
218 ShardingInterface shardingOp,
219 ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
220 ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {
// Score every candidate (operand, result) sharding combination by its
// ReshardingRquirementKind and return the option whose kind compares
// smallest. Assumes the enum is ordered from least to most resharding —
// confirm against its declaration.
222 shardingOptionsAndReshardingRequirements;
226 possibleOperandShardingAttrs) {
// Ask the op to derive a sharding option from this candidate. The check
// below filters out options that failed or came back empty.
227 FailureOr<ShardingOption> shardingOption =
228 shardingOp.getShardingOption(operandShardings, resultShardings);
229 if (failed(shardingOption) || shardingOption->empty) {
// Convert the option back into concrete per-operand/per-result shardings so
// its resharding cost against existing shard ops can be assessed.
237 FailureOr<std::vector<MeshSharding>> operandAndResultShardings =
238 shardingOp.getShardingAnnotations(*shardingOption);
239 if (failed(operandAndResultShardings)) {
// Early exit with the current option; presumably taken when it needs no
// resharding at all, which no later candidate can improve on — confirm.
250 return *shardingOption;
253 shardingOptionsAndReshardingRequirements.emplace_back(
254 std::move(*shardingOption), reshardingRquirement);
258 if (shardingOptionsAndReshardingRequirements.empty()) {
// Move an element with the smallest ReshardingRquirementKind to the front.
// NOTE(review): std::partial_sort does not specify the relative order of
// equally ranked elements, so among ties the chosen option is not guaranteed
// to be the earliest-enumerated candidate. std::min_element would keep
// enumeration order if that preference is intended — confirm.
263 shardingOptionsAndReshardingRequirements.begin(),
264 shardingOptionsAndReshardingRequirements.begin() + 1,
265 shardingOptionsAndReshardingRequirements.end(),
266 [](
const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
267 const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
268 return std::get<ReshardingRquirementKind>(a) <
269 std::get<ReshardingRquirementKind>(b);
272 LLVM_DEBUG(
DBGS() <<
"shardingOptionsAndReshardingRequirements = "
273 << shardingOptionsAndReshardingRequirements <<
"\n";);
275 return std::get<ShardingOption>(
276 shardingOptionsAndReshardingRequirements.front());
286 llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
// Only ops implementing ShardingInterface can take part in propagation; any
// other op is reported with an error.
289 ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
291 op->
emitOpError() <<
"sharding interface is not implemented.";
// Collect the shardings already annotated on each result. These vectors are
// indexed by result number, so they must hold one entry per result.
// getMeshSharding yields (flag, sharding): for results, flag == false marks a
// must-sharding for this op; otherwise the sharding is only a hint that a
// conflicting choice may override (allowConflicts).
296 std::vector<MeshSharding> allowConflictsResultShardings;
298 std::vector<MeshSharding> resultMustShardings;
301 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
303 if (failed(maybeShardAttr))
305 if (!maybeShardAttr->first)
306 resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
308 allowConflictsResultShardings[result.getResultNumber()] =
309 maybeShardAttr->second;
// Same collection for operands, indexed by operand number. Note the inverted
// polarity: for operands, flag == true marks a must-sharding. This is
// presumably because the flag mirrors the shard op's annotate_for_users
// setting — confirm against getMeshSharding.
313 std::vector<MeshSharding> allowConflictsOperandShardings;
315 std::vector<MeshSharding> operandMustShardings;
318 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
320 if (failed(maybeShardAttr))
323 if (maybeShardAttr->first)
324 operandMustShardings[opOperand.getOperandNumber()] =
325 maybeShardAttr->second;
327 allowConflictsOperandShardings[opOperand.getOperandNumber()] =
328 maybeShardAttr->second;
334 allowConflictsOperandShardings);
337 allowConflictsResultShardings);
339 shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
// No acceptable sharding option was found among the candidates.
341 if (failed(shardingOption)) {
342 op->
emitOpError() <<
"fail to get sharding option.";
346 LLVM_DEBUG(
DBGS() <<
"Selected sharding option: " << *shardingOption <<
"\n");
// An empty sharding option has nothing to materialize.
349 if (shardingOption->empty)
// Ask the op's ShardingInterface to add annotations for the chosen option
// using `builder`; report an error if it cannot.
352 if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
353 op->
emitOpError() <<
"fail to set sharding annotations.";
363 :
public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
// Propagates shardings across the ops of a function-like op's body. Only
// single-block bodies are handled; visitOp is applied to every op in two
// passes, first bottom-up and then top-down.
365 FunctionOpInterface funcOp = getOperation();
// Fixed: this token had been mis-encoded as `®ion` (the `&reg` sequence was
// HTML-decoded), which is not valid C++ and loses the `region` name.
367 Region &region = funcOp.getFunctionBody();
// Only function bodies consisting of a single block are supported.
370 funcOp.emitOpError() <<
"only one block is supported!";
376 DBGS() <<
"print all the ops' iterator types and indexing maps in the "
// Only ops implementing ShardingInterface can report loop types and indexing
// maps.
380 if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
381 shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
// First pass: visit ops in reverse program order (users before producers).
// make_early_inc_range advances the iterator before each visit, so visitOp
// may modify the block around the current op (e.g. add annotations).
385 for (
Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
386 if (failed(
visitOp(&op, builder)))
387 return signalPassFailure();
389 LLVM_DEBUG(
DBGS() <<
"After reversed order propagation:\n"
// Second pass: visit ops in forward program order (producers before users).
394 for (
Operation &op : llvm::make_early_inc_range(block))
395 if (failed(
visitOp(&op, builder)))
396 return signalPassFailure();
ReshardingRquirementKind getReshardingRquirementKind(Operation *op, const std::vector< MeshSharding > &operandAndResultShardings)
static LogicalResult visitOp(Operation *op, OpBuilder &builder)
static FailureOr< ShardingOption > selectShardingOption(ShardingInterface shardingOp, ArrayRef< std::vector< MeshSharding >> possibleOperandShardingAttrs, ArrayRef< std::vector< MeshSharding >> possibleResultShardingAttrs)
@ RESHARDING_FOR_EXPLICIT_ANNOTATIONS
@ NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS
static SmallVector< std::vector< MeshSharding > > getOrderedPossibleShardingAttrs(ArrayRef< MeshSharding > mustShardings, ArrayRef< MeshSharding > optionalShardings)
Block represents an ordered list of Operations.
OpListType & getOperations()
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
unsigned getNumOperands()
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
FailureOr< std::pair< bool, MeshSharding > > getMeshSharding(OpResult result)
Include the generated interface declarations.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
void runOnOperation() override
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
ShardingArray shardingArray
static ShardingOption makeEmpty()