17 #include "llvm/Support/Debug.h"
22 #define GEN_PASS_DEF_SHARDINGPROPAGATION
23 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
// Debug-type name for this file's debug output; DBGS() below embeds it in its
// prefix. NOTE(review): per LLVM convention this define should stay after all
// #includes so headers do not pick it up — confirm if includes are reordered.
27 #define DEBUG_TYPE "sharding-propagation"
// Returns llvm::dbgs() after writing the prefix "[sharding-propagation]: ".
// "[" and DEBUG_TYPE are adjacent string literals, so they are concatenated at
// compile time; the closing "]: " is streamed as a separate literal.
28 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
47 std::function<void(
size_t)> dfsCreateShardingAttrs = [&](
size_t i) {
48 if (i == mustShardings.size()) {
49 allShardingAttrs.push_back(
54 if (mustShardings[i]) {
55 curShardingAttrs.push_back(mustShardings[i]);
56 dfsCreateShardingAttrs(i + 1);
57 curShardingAttrs.pop_back();
61 if (optionalShardings[i]) {
62 curShardingAttrs.push_back(optionalShardings[i]);
63 dfsCreateShardingAttrs(i + 1);
64 curShardingAttrs.pop_back();
65 curShardingAttrs.push_back(
nullptr);
66 dfsCreateShardingAttrs(i + 1);
67 curShardingAttrs.pop_back();
71 curShardingAttrs.push_back(
nullptr);
72 dfsCreateShardingAttrs(i + 1);
73 curShardingAttrs.pop_back();
76 dfsCreateShardingAttrs(0);
77 return allShardingAttrs;
89 ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
91 op->
emitOpError() <<
"sharding interface is not implemented.";
103 if (
failed(maybeShardAttr))
105 if (!maybeShardAttr->first)
106 resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
108 allowConflictsResultShardings[result.getResultNumber()] =
109 maybeShardAttr->second;
120 if (
failed(maybeShardAttr))
123 if (maybeShardAttr->first)
124 operandMustShardings[opOperand.getOperandNumber()] =
125 maybeShardAttr->second;
127 allowConflictsOperandShardings[opOperand.getOperandNumber()] =
128 maybeShardAttr->second;
134 allowConflictsOperandShardings);
137 allowConflictsResultShardings);
140 possibleResultShardingAttrs) {
144 possibleOperandShardingAttrs) {
146 shardingOp.getShardingOption(operandShardings, resultShardings);
148 finalShardingOption = shardingOption;
154 if (
failed(finalShardingOption)) {
155 op->
emitOpError() <<
"fail to get sharding option.";
159 if (finalShardingOption->empty)
163 shardingOp.addShardingAnnotations(builder, *finalShardingOption))) {
164 op->
emitOpError() <<
"fail to set sharding annotations.";
174 :
public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
176 FunctionOpInterface funcOp = getOperation();
178 Region ®ion = funcOp.getFunctionBody();
181 funcOp.emitOpError() <<
"only one block is supported!";
187 DBGS() <<
"print all the ops' iterator types and indexing maps in the "
191 if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
192 shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
196 for (
Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
198 return signalPassFailure();
200 LLVM_DEBUG(
DBGS() <<
"After reversed order propagation:\n"
204 for (
Operation &op : llvm::make_early_inc_range(block))
206 return signalPassFailure();
static SmallVector< SmallVector< MeshShardingAttr > > getOrderedPossibleShardingAttrs(ArrayRef< MeshShardingAttr > mustShardings, ArrayRef< MeshShardingAttr > optionalShardings)
static LogicalResult visitOp(Operation *op, OpBuilder &builder)
Block represents an ordered list of Operations.
OpListType & getOperations()
This class provides support for representing a failure result, or a valid value of type T.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
unsigned getNumOperands()
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
FailureOr< std::pair< bool, MeshShardingAttr > > getMeshShardingAttr(OpResult result)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
void runOnOperation() override
This class represents an efficient way to signal success or failure.