MLIR  19.0.0git
MeshShardingInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
23 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/IRMapping.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/SymbolTable.h"
31 #include "mlir/IR/Value.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/SmallVector.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include <iterator>
39 #include <optional>
40 #include <utility>
41 
42 namespace mlir::linalg {
43 
49 
50 // Returns the corresponding mesh reduction kind for the given arith op.
53  // Floating-point operations.
54  .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
55  .Case([](arith::MulFOp op) { return ReductionKind::Product; })
56  // TODO: handle maxnumf and minnumf.
57  .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
58  .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
59  // Integer operations.
60  .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
61  .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
62  .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
63  .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
64  // TODO: handle signless, signed and unsigned types properly.
65  // It is assumed that the element type of the collective operands and
66  // result drive the meaning of the reduction kind, whether it is signed
67  // or unsigned.
68  // The reduction op inside the linalg op may have different result type
69  // from the element type of the linalg op's result.
70  // Also signed and unsigned Arith dialect ops may accept signed, unsigned
71  // or signless operands.
72  // Maybe expand the reduction kinds.
73  .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
74  .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
75  .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
76  .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
77  .Case([](arith::MulIOp op) { return ReductionKind::Product; })
78  .Default([](Operation *op) { return ReductionKind::Generic; });
79 }
80 
81 static std::optional<Operation *> getCombinerOp(LinalgOp op) {
82  SmallVector<Operation *> combinerOps;
83  Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
84  if (!reducedValue || combinerOps.size() != 1) {
85  return std::nullopt;
86  }
87 
88  return combinerOps[0];
89 }
90 
92  std::optional<Operation *> reductionOp = getCombinerOp(op);
93  if (!reductionOp) {
94  return ReductionKind::Generic;
95  }
96  [[maybe_unused]] Type resultElementType =
97  llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
98  // TODO: handle case when result type of the reduction op does not match the
99  // element type of the result tensor.
100  // Would it makes sense at all?
101  assert(resultElementType == reductionOp.value()->getResult(0).getType());
102  return getReductionKind(reductionOp.value());
103 }
104 
106  ArrayRef<MeshShardingAttr> operandShardings,
107  ArrayRef<MeshShardingAttr> resultShardings,
108  SymbolTableCollection &symbolTable) {
109  for (MeshShardingAttr sharding : operandShardings) {
110  if (sharding) {
111  return mesh::getMesh(op, sharding.getMesh(), symbolTable);
112  }
113  }
114 
115  for (MeshShardingAttr sharding : resultShardings) {
116  if (sharding) {
117  return mesh::getMesh(op, sharding.getMesh(), symbolTable);
118  }
119  }
120 
121  assert(false);
122  return nullptr;
123 }
124 
125 // Choose the operand based on the current process index along the reduction
126 // mesh axes.
127 // We need to use the initial value only once to avoid including it in the
128 // reduction multiple times.
129 // In each process group only the leading process with linear index 0 would use
130 // the original operand.
131 // The other processes would use the reduction operation neutral tensor.
133  LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
134  MeshOp meshOp, ImplicitLocOpBuilder &builder) {
135  Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
136  meshOp.getSymName(), reductionMeshAxes, builder);
137  Value zero = builder.create<arith::ConstantIndexOp>(0);
138  Value isLeadProcess = builder.create<arith::CmpIOp>(
139  builder.getI1Type(), arith::CmpIPredicate::eq,
140  processLinearIndexInReductionGroup, zero);
141  scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
142  isLeadProcess, true, true);
143  // Then block.
144  {
145  OpBuilder::InsertionGuard insertionGuard(builder);
146  builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
147  builder.create<scf::YieldOp>(spmdizedOperand);
148  }
149 
150  // Else block.
151  {
152  OpBuilder::InsertionGuard insertionGuard(builder);
153  builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
155  tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
156  PartialReductionOpInterface partialReductionIface =
157  llvm::cast<PartialReductionOpInterface>(op.getOperation());
158  FailureOr<Operation *> reductionNeutralTensorOp =
159  partialReductionIface.generateInitialTensorForPartialReduction(
160  builder, builder.getLoc(), shape, {});
161  assert(succeeded(reductionNeutralTensorOp));
162  builder.create<scf::YieldOp>(
163  reductionNeutralTensorOp.value()->getResult(0));
164  }
165  return ifOp.getResult(0);
166 }
167 
168 // Create the DPS init operands for the spmdized Linalg op.
169 // Return all the new spmdized operands.
171  LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
172  ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
173  ImplicitLocOpBuilder &builder) {
174  // TODO: add support for multiple destination passing style initial value
175  // operands.
176  // PartialReductionOpInterface::generateInitialTensorForPartialReduction
177  // needs to also support multiple DPS initial operands.
178  SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
179  auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
180  Value spmdizedInitOperand =
181  spmdizationMap.lookup(op->getOperands()[operandIdx]);
182  newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
183  op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
184  return newOperands;
185 }
186 
188  Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
189  MeshShardingAttr resultSharding, ReductionKind reductionKind,
190  IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
191  SmallVector<MeshAxis> allReduceMeshAxes;
192  llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
193  [&resultSharding](MeshAxis axis) {
194  return !llvm::is_contained(resultSharding.getPartialAxes(),
195  axis);
196  });
197  if (allReduceMeshAxes.empty()) {
198  return;
199  }
200 
201  Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
202  Value reducedValue = builder.create<mesh::AllReduceOp>(
203  spmdizedLinalgOpResult, resultSharding.getMesh().getValue(),
204  allReduceMeshAxes, reductionKind);
205  spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
206 }
207 
209  LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
210  ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
211  ImplicitLocOpBuilder &builder) {
212  ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
213  for (auto [unshardedLinalgOpResult, resultSharding] :
214  llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
216  unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
217  reductionKind, spmdizationMap, builder);
218  }
219 }
220 
222  LinalgOp op, ArrayRef<Value> spmdizedOperands,
223  ArrayRef<MeshShardingAttr> operandShardings,
224  ArrayRef<MeshShardingAttr> resultShardings,
225  ArrayRef<utils::IteratorType> loopIteratorTypes,
226  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
227  IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
228  ImplicitLocOpBuilder &builder) {
229  MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
231  loopIteratorTypes, meshAxisAssignmentForLoopIterators);
232  SmallVector<Value> spmdizedLinalgOpOperands =
233  createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
234  reductionMeshAxes,
235  spmdizationMap, builder);
236  // We must not change the operand mappings of the original spmdizationMap as
237  // they are the mappings for the whole spmdization blob and may be used by
238  // others.
239  IRMapping internalSpmdizationMap;
240  for (auto [unshardedOperand, spmdizedOperand] :
241  llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
242  internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
243  }
245  *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
246  internalSpmdizationMap, symbolTable, builder);
247  for (Value result : op->getResults()) {
248  spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
249  }
250 
251  // Handle partial shardings.
253  op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
254 }
255 
256 namespace {
257 
258 // ShardingInterface for ops that implement LinalgStructuredInterface.
259 // The supported ops are only those where the indexing maps are projected
260 // permutations.
261 template <typename Op>
262 struct StructuredOpShardingInterface
263  : public mesh::ShardingInterface::ExternalModel<
264  StructuredOpShardingInterface<Op>, Op> {
265  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
266  return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
267  }
268 
269  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
270  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
271  SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
272 
273  // Results must have the same indexing as destination passing style initial
274  // operands.
275  for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
276  res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
277  }
278 
279  return res;
280  }
281 
282  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
283  ArrayRef<MeshShardingAttr> operandShardings,
284  ArrayRef<MeshShardingAttr> resultShardings,
285  IRMapping &spmdizationMap,
286  SymbolTableCollection &symbolTable,
287  OpBuilder &builder) const {
288  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
289 
290  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
291  bool allIndexingMapsAreProjectedPermutation =
292  llvm::all_of(indexingMaps, [](AffineMap map) {
293  return map.isProjectedPermutation();
294  });
295  if (!allIndexingMapsAreProjectedPermutation) {
296  // TODO: handle non-projected permutations.
297  return op->emitOpError()
298  << "supports indexing maps that are only projected permutation.";
299  }
300 
301  SmallVector<utils::IteratorType> loopIteratorTypes =
302  linalgOp.getIteratorTypesArray();
303  ShardingArray meshAxisAssignmentForLoopIterators =
304  getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
305  loopIteratorTypes, indexingMaps);
307  loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
308  ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
310  linalgOp, spmdizedOperands, operandShardings, resultShardings,
311  loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
312  symbolTable, implicitLocBuilder);
313  } else {
314  spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
315  operandShardings, resultShardings,
316  spmdizationMap, symbolTable, builder);
317  }
318 
319  return success();
320  }
321 };
322 
323 } // namespace
324 
325 template <typename OpType>
326 static void registerOne(MLIRContext *ctx) {
327  OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
328 }
329 
330 /// Variadic helper function.
331 template <typename... OpTypes>
332 static void registerAll(MLIRContext *ctx) {
333  (registerOne<OpTypes>(ctx), ...);
334 }
335 
337  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
338  DialectRegistry registry;
339  registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
340  tensor::TensorDialect>();
341  ctx->appendDialectRegistry(registry);
342  for (StringRef name : registry.getDialectNames())
343  ctx->getOrLoadDialect(name);
344 
345  registerOne<linalg::GenericOp>(ctx);
346  registerAll<
347 #define GET_OP_LIST
348 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
349  >(ctx);
350  });
351 }
352 
353 } // namespace mlir::linalg
IntegerType getI1Type()
Definition: Builders.cpp:73
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
auto getDialectNames() const
Return the names of dialects known to this registry.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:97
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
static MeshOp getMesh(Operation *op, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, SymbolTableCollection &symbolTable)
mesh::ReductionKind ReductionKind
static void createAllReduceForResultsWithoutPartialShardings(LinalgOp unshardedOp, ArrayRef< MeshAxis > opReductionMeshAxes, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
static ReductionKind getReductionKind(Operation *op)
static std::optional< Operation * > getCombinerOp(LinalgOp op)
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
mesh::MeshShardingAttr MeshShardingAttr
static void spmdizeLinalgOpWithShardedReduction(LinalgOp op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder)
static Value createDestinationPassingStyleInitOperand(LinalgOp op, Value spmdizedOperand, ArrayRef< MeshAxis > reductionMeshAxes, MeshOp meshOp, ImplicitLocOpBuilder &builder)
static SmallVector< Value > createDestinationPassingStyleInitOperands(LinalgOp op, MeshOp meshOp, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshAxis > reductionMeshAxes, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
mesh::ShardingArray ShardingArray
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op)
static void createAllReduceForResultWithoutPartialSharding(Value unshardedLinalgOpResult, ArrayRef< MeshAxis > opReductionMeshAxes, MeshShardingAttr resultSharding, ReductionKind reductionKind, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< SmallVector< MeshAxis > > ShardingArray
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
TypedValue< IndexType > createProcessLinearIndex(StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
Definition: Transforms.cpp:210
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:57
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
int16_t MeshAxis
Definition: MeshOps.h:25
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:61
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs, and the position of the potential reduction argument within the list, redPos.