MLIR  19.0.0git
ShardingPropagation.cpp
Go to the documentation of this file.
1 //===- ShardingPropagation.cpp ------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/Pass/Pass.h"
17 #include "llvm/Support/Debug.h"
18 #include <vector>
19 
20 namespace mlir {
21 namespace mesh {
22 #define GEN_PASS_DEF_SHARDINGPROPAGATION
23 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
24 } // namespace mesh
25 } // namespace mlir
26 
27 #define DEBUG_TYPE "sharding-propagation"
28 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
29 
30 using namespace mlir;
31 using namespace mlir::mesh;
32 
33 //===----------------------------------------------------------------------===//
34 // Utilities
35 //===----------------------------------------------------------------------===//
36 
37 // This method retrieves all potential sharding attributes, prioritizing
38 // specific shardings. For example, mustShardings = [shard0, None] and
39 // optionalShardings = [None, shard1], the result will be [[shard0, shard1],
40 // [shard0, None]]
43  ArrayRef<MeshShardingAttr> optionalShardings) {
45  SmallVector<MeshShardingAttr> curShardingAttrs;
46 
47  std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
48  if (i == mustShardings.size()) {
49  allShardingAttrs.push_back(
50  SmallVector<MeshShardingAttr>(curShardingAttrs));
51  return;
52  }
53 
54  if (mustShardings[i]) {
55  curShardingAttrs.push_back(mustShardings[i]);
56  dfsCreateShardingAttrs(i + 1);
57  curShardingAttrs.pop_back();
58  return;
59  }
60 
61  if (optionalShardings[i]) {
62  curShardingAttrs.push_back(optionalShardings[i]);
63  dfsCreateShardingAttrs(i + 1);
64  curShardingAttrs.pop_back();
65  curShardingAttrs.push_back(nullptr);
66  dfsCreateShardingAttrs(i + 1);
67  curShardingAttrs.pop_back();
68  return;
69  }
70 
71  curShardingAttrs.push_back(nullptr);
72  dfsCreateShardingAttrs(i + 1);
73  curShardingAttrs.pop_back();
74  };
75 
76  dfsCreateShardingAttrs(0);
77  return allShardingAttrs;
78 }
79 
80 // For each operation that implements the ShardingInterface, infer the sharding
81 // option of the operation from its operands and/or results using the
82 // `getShardingOption` method. If the inferred sharding option is not empty, add
83 // a `mesh.shard` operation for all remaining operands and results that do not
84 // have sharding annotations.
85 static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
86  if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op))
87  return success();
88 
89  ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
90  if (!shardingOp) {
91  op->emitOpError() << "sharding interface is not implemented.";
92  return failure();
93  }
94 
95  // collect MeshShardingAttr from results
96  SmallVector<MeshShardingAttr> allowConflictsResultShardings;
97  allowConflictsResultShardings.resize(op->getNumResults());
98  SmallVector<MeshShardingAttr> resultMustShardings;
99  resultMustShardings.resize(op->getNumResults());
100  for (OpResult result : op->getResults()) {
102  getMeshShardingAttr(result);
103  if (failed(maybeShardAttr))
104  continue;
105  if (!maybeShardAttr->first)
106  resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
107  else
108  allowConflictsResultShardings[result.getResultNumber()] =
109  maybeShardAttr->second;
110  }
111 
112  // collect MeshShardingAttr from operands
113  SmallVector<MeshShardingAttr> allowConflictsOperandShardings;
114  allowConflictsOperandShardings.resize(op->getNumOperands());
115  SmallVector<MeshShardingAttr> operandMustShardings;
116  operandMustShardings.resize(op->getNumOperands());
117  for (OpOperand &opOperand : op->getOpOperands()) {
119  getMeshShardingAttr(opOperand);
120  if (failed(maybeShardAttr))
121  continue;
122 
123  if (maybeShardAttr->first)
124  operandMustShardings[opOperand.getOperandNumber()] =
125  maybeShardAttr->second;
126  else
127  allowConflictsOperandShardings[opOperand.getOperandNumber()] =
128  maybeShardAttr->second;
129  }
130 
131  // try to get the sharding option
132  SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
133  getOrderedPossibleShardingAttrs(operandMustShardings,
134  allowConflictsOperandShardings);
135  SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs =
136  getOrderedPossibleShardingAttrs(resultMustShardings,
137  allowConflictsResultShardings);
138  FailureOr<ShardingOption> finalShardingOption = failure();
139  for (ArrayRef<MeshShardingAttr> resultShardings :
140  possibleResultShardingAttrs) {
141  if (succeeded(finalShardingOption))
142  break;
143  for (ArrayRef<MeshShardingAttr> operandShardings :
144  possibleOperandShardingAttrs) {
145  FailureOr<ShardingOption> shardingOption =
146  shardingOp.getShardingOption(operandShardings, resultShardings);
147  if (succeeded(shardingOption)) {
148  finalShardingOption = shardingOption;
149  break;
150  }
151  }
152  }
153 
154  if (failed(finalShardingOption)) {
155  op->emitOpError() << "fail to get sharding option.";
156  return failure();
157  }
158  // sharding info is empty, return immediately
159  if (finalShardingOption->empty)
160  return success();
161 
162  if (failed(
163  shardingOp.addShardingAnnotations(builder, *finalShardingOption))) {
164  op->emitOpError() << "fail to set sharding annotations.";
165  return failure();
166  }
167  return success();
168 }
169 
170 //===----------------------------------------------------------------------===//
171 // ShardingPropagation
172 //===----------------------------------------------------------------------===//
174  : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
175  void runOnOperation() override {
176  FunctionOpInterface funcOp = getOperation();
177  MLIRContext *ctx = funcOp.getContext();
178  Region &region = funcOp.getFunctionBody();
179  OpBuilder builder(ctx);
180  if (!region.hasOneBlock()) {
181  funcOp.emitOpError() << "only one block is supported!";
182  signalPassFailure();
183  }
184  Block &block = region.front();
185 
186  LLVM_DEBUG(
187  DBGS() << "print all the ops' iterator types and indexing maps in the "
188  "block.\n";
189  for (Operation &op
190  : block.getOperations()) {
191  if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
192  shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
193  });
194 
195  // 1. propagate in reversed order
196  for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
197  if (failed(visitOp(&op, builder)))
198  return signalPassFailure();
199 
200  LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
201  << funcOp << "\n");
202 
203  // 2. propagate in original order
204  for (Operation &op : llvm::make_early_inc_range(block))
205  if (failed(visitOp(&op, builder)))
206  return signalPassFailure();
207  }
208 };
static SmallVector< SmallVector< MeshShardingAttr > > getOrderedPossibleShardingAttrs(ArrayRef< MeshShardingAttr > mustShardings, ArrayRef< MeshShardingAttr > optionalShardings)
static LogicalResult visitOp(Operation *op, OpBuilder &builder)
#define DBGS()
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType & getOperations()
Definition: Block.h:134
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
unsigned getNumOperands()
Definition: Operation.h:341
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
FailureOr< std::pair< bool, MeshShardingAttr > > getMeshShardingAttr(OpResult result)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
void runOnOperation() override
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26