MLIR  20.0.0git
MeshShardingInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
23 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/IRMapping.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/SymbolTable.h"
31 #include "mlir/IR/Value.h"
33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include <iterator>
38 #include <numeric>
39 #include <optional>
40 #include <utility>
41 
42 namespace mlir::linalg {
43 
49 
50 // Returns the corresponding mesh reduction kind for the given arith op.
53  // Floating-point operations.
54  .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
55  .Case([](arith::MulFOp op) { return ReductionKind::Product; })
56  // TODO: handle maxnumf and minnumf.
57  .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
58  .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
59  // Integer operations.
60  .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
61  .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
62  .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
63  .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
64  // TODO: handle signless, signed and unsigned types properly.
65  // It is assumed that the element type of the collective operands and
66  // result drive the meaning of the reduction kind, whether it is signed
67  // or unsigned.
68  // The reduction op inside the linalg op may have different result type
69  // from the element type of the linalg op's result.
70  // Also signed and unsigned Arith dialect ops may accept signed, unsigned
71  // or signless operands.
72  // Maybe expand the reduction kinds.
73  .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
74  .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
75  .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
76  .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
77  .Case([](arith::MulIOp op) { return ReductionKind::Product; })
78  .Default([](Operation *op) { return ReductionKind::Generic; });
79 }
80 
81 static std::optional<Operation *> getCombinerOp(LinalgOp op) {
82  SmallVector<Operation *> combinerOps;
83  Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
84  if (!reducedValue || combinerOps.size() != 1) {
85  return std::nullopt;
86  }
87 
88  return combinerOps[0];
89 }
90 
92  std::optional<Operation *> reductionOp = getCombinerOp(op);
93  if (!reductionOp) {
94  return ReductionKind::Generic;
95  }
96  [[maybe_unused]] Type resultElementType =
97  llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
98  // TODO: handle case when result type of the reduction op does not match the
99  // element type of the result tensor.
100  // Would it makes sense at all?
101  assert(resultElementType == reductionOp.value()->getResult(0).getType());
102  return getReductionKind(reductionOp.value());
103 }
104 
105 static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
106  ArrayRef<MeshSharding> resultShardings,
107  SymbolTableCollection &symbolTable) {
108  for (const MeshSharding& sharding : operandShardings) {
109  if (sharding) {
110  return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
111  }
112  }
113 
114  for (const MeshSharding& sharding : resultShardings) {
115  if (sharding) {
116  return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
117  }
118  }
119 
120  assert(false);
121  return nullptr;
122 }
123 
124 // Choose the operand based on the current process index along the reduction
125 // mesh axes.
126 // We need to use the initial value only once to avoid including it in the
127 // reduction multiple times.
128 // In each process group only the leading process with linear index 0 would use
129 // the original operand.
130 // The other processes would use the reduction operation neutral tensor.
132  LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
133  MeshOp meshOp, ImplicitLocOpBuilder &builder) {
134  Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
135  meshOp.getSymName(), reductionMeshAxes, builder);
136  Value zero = builder.create<arith::ConstantIndexOp>(0);
137  Value isLeadProcess = builder.create<arith::CmpIOp>(
138  builder.getI1Type(), arith::CmpIPredicate::eq,
139  processLinearIndexInReductionGroup, zero);
140  scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
141  isLeadProcess, true, true);
142  // Then block.
143  {
144  OpBuilder::InsertionGuard insertionGuard(builder);
145  builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
146  builder.create<scf::YieldOp>(spmdizedOperand);
147  }
148 
149  // Else block.
150  {
151  OpBuilder::InsertionGuard insertionGuard(builder);
152  builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
154  tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
155  PartialReductionOpInterface partialReductionIface =
156  llvm::cast<PartialReductionOpInterface>(op.getOperation());
157  assert(op->getNumResults() == 1 && "Multiple results not supported.");
158  FailureOr<SmallVector<Value>> reductionNeutralTensor =
159  partialReductionIface.generateInitialTensorForPartialReduction(
160  builder, builder.getLoc(), shape, {});
161  assert(succeeded(reductionNeutralTensor));
162  builder.create<scf::YieldOp>(reductionNeutralTensor.value());
163  }
164  return ifOp.getResult(0);
165 }
166 
167 // Create the DPS init operands for the spmdized Linalg op.
168 // Return all the new spmdized operands.
170  LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
171  ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
172  ImplicitLocOpBuilder &builder) {
173  // TODO: add support for multiple destination passing style initial value
174  // operands.
175  assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
176  SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
177  auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
178  Value spmdizedInitOperand =
179  spmdizationMap.lookup(op->getOperands()[operandIdx]);
180  newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
181  op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
182  return newOperands;
183 }
184 
186  Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
187  MeshSharding resultSharding, ReductionKind reductionKind,
188  IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
189  SmallVector<MeshAxis> allReduceMeshAxes;
190  llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
191  [&resultSharding](MeshAxis axis) {
192  return !llvm::is_contained(resultSharding.getPartialAxes(),
193  axis);
194  });
195  if (allReduceMeshAxes.empty()) {
196  return;
197  }
198 
199  Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
200  Value reducedValue = builder.create<mesh::AllReduceOp>(
201  spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes,
202  reductionKind);
203  spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
204 }
205 
207  LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
208  ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
209  ImplicitLocOpBuilder &builder) {
210  ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
211  for (auto [unshardedLinalgOpResult, resultSharding] :
212  llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
214  unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
215  reductionKind, spmdizationMap, builder);
216  }
217 }
218 
220  LinalgOp op, ArrayRef<Value> spmdizedOperands,
221  ArrayRef<MeshSharding> operandShardings,
222  ArrayRef<MeshSharding> resultShardings,
223  ArrayRef<utils::IteratorType> loopIteratorTypes,
224  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
225  IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
226  ImplicitLocOpBuilder &builder) {
227  MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
229  loopIteratorTypes, meshAxisAssignmentForLoopIterators);
230  SmallVector<Value> spmdizedLinalgOpOperands =
231  createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
232  reductionMeshAxes,
233  spmdizationMap, builder);
234  // We must not change the operand mappings of the original spmdizationMap as
235  // they are the mappings for the whole spmdization blob and may be used by
236  // others.
237  IRMapping internalSpmdizationMap;
238  for (auto [unshardedOperand, spmdizedOperand] :
239  llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
240  internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
241  }
243  *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
244  internalSpmdizationMap, symbolTable, builder);
245  for (Value result : op->getResults()) {
246  spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
247  }
248 
249  // Handle partial shardings.
251  op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
252 }
253 
254 namespace {
255 
256 // ShardingInterface for ops that implement LinalgStructuredInterface.
257 // The supported ops are only those where the indexing maps are projected
258 // permutations.
259 template <typename Op>
260 struct StructuredOpShardingInterface
261  : public mesh::ShardingInterface::ExternalModel<
262  StructuredOpShardingInterface<Op>, Op> {
263  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
264  return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
265  }
266 
267  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
268  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
269  SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
270 
271  // Results must have the same indexing as destination passing style initial
272  // operands.
273  for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
274  res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
275  }
276 
277  return res;
278  }
279 
280  SmallVector<ReductionKind>
281  getReductionLoopIteratorKinds(Operation *op) const {
282  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
283  SmallVector<utils::IteratorType> iteratorTypes =
284  linalgOp.getIteratorTypesArray();
285  unsigned reductionItersCount = std::accumulate(
286  iteratorTypes.begin(), iteratorTypes.end(), 0,
287  [](unsigned count, utils::IteratorType iter) {
288  return count + (iter == utils::IteratorType::reduction);
289  });
290  mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
291  return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
292  }
293 
294  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
295  ArrayRef<MeshSharding> operandShardings,
296  ArrayRef<MeshSharding> resultShardings,
297  IRMapping &spmdizationMap,
298  SymbolTableCollection &symbolTable,
299  OpBuilder &builder) const {
300  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
301 
302  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
303  bool allIndexingMapsAreProjectedPermutation =
304  llvm::all_of(indexingMaps, [](AffineMap map) {
305  return map.isProjectedPermutation();
306  });
307  if (!allIndexingMapsAreProjectedPermutation) {
308  // TODO: handle non-projected permutations.
309  return op->emitOpError()
310  << "supports indexing maps that are only projected permutation.";
311  }
312 
313  SmallVector<utils::IteratorType> loopIteratorTypes =
314  linalgOp.getIteratorTypesArray();
315  ShardingArray meshAxisAssignmentForLoopIterators =
316  getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
317  loopIteratorTypes, indexingMaps);
319  loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
320  ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
322  linalgOp, spmdizedOperands, operandShardings, resultShardings,
323  loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
324  symbolTable, implicitLocBuilder);
325  } else {
326  spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
327  operandShardings, resultShardings,
328  spmdizationMap, symbolTable, builder);
329  }
330 
331  return success();
332  }
333 };
334 
335 } // namespace
336 
337 template <typename OpType>
338 static void registerOne(MLIRContext *ctx) {
339  OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
340 }
341 
342 /// Variadic helper function.
343 template <typename... OpTypes>
344 static void registerAll(MLIRContext *ctx) {
345  (registerOne<OpTypes>(ctx), ...);
346 }
347 
349  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
350  DialectRegistry registry;
351  registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
352  tensor::TensorDialect>();
353  ctx->appendDialectRegistry(registry);
354  for (StringRef name : registry.getDialectNames())
355  ctx->getOrLoadDialect(name);
356 
357  registerOne<linalg::GenericOp>(ctx);
358  registerAll<
359 #define GET_OP_LIST
360 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
361  >(ctx);
362  });
363 }
364 
365 } // namespace mlir::linalg
IntegerType getI1Type()
Definition: Builders.cpp:97
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
auto getDialectNames() const
Return the names of dialects known to this registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:97
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:445
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:67
::llvm::StringRef getMesh() const
Definition: MeshOps.h:65
static MeshOp getMesh(Operation *op, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, SymbolTableCollection &symbolTable)
mesh::ReductionKind ReductionKind
static ReductionKind getReductionKind(Operation *op)
static std::optional< Operation * > getCombinerOp(LinalgOp op)
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
static void createAllReduceForResultWithoutPartialSharding(Value unshardedLinalgOpResult, ArrayRef< MeshAxis > opReductionMeshAxes, MeshSharding resultSharding, ReductionKind reductionKind, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
mesh::MeshSharding MeshSharding
static Value createDestinationPassingStyleInitOperand(LinalgOp op, Value spmdizedOperand, ArrayRef< MeshAxis > reductionMeshAxes, MeshOp meshOp, ImplicitLocOpBuilder &builder)
static SmallVector< Value > createDestinationPassingStyleInitOperands(LinalgOp op, MeshOp meshOp, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshAxis > reductionMeshAxes, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
static void spmdizeLinalgOpWithShardedReduction(LinalgOp op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder)
mesh::ShardingArray ShardingArray
static void createAllReduceForResultsWithoutPartialShardings(LinalgOp unshardedOp, ArrayRef< MeshAxis > opReductionMeshAxes, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op)
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< SmallVector< MeshAxis > > ShardingArray
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
TypedValue< IndexType > createProcessLinearIndex(StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
Definition: Transforms.cpp:210
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:126
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
int16_t MeshAxis
Definition: MeshOps.h:26
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:66
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs, and the position of the potential reduction argument within the list, redPos.