MLIR  20.0.0git
MeshShardingInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
23 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/IRMapping.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/SymbolTable.h"
31 #include "mlir/IR/Value.h"
33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include <iterator>
38 #include <numeric>
39 #include <optional>
40 #include <utility>
41 
42 namespace mlir::linalg {
43 
49 
50 // Returns the corresponding mesh reduction kind for the given arith op.
53  // Floating-point operations.
54  .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
55  .Case([](arith::MulFOp op) { return ReductionKind::Product; })
56  // TODO: handle maxnumf and minnumf.
57  .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
58  .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
59  // Integer operations.
60  .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
61  .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
62  .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
63  .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
64  // TODO: handle signless, signed and unsigned types properly.
65  // It is assumed that the element type of the collective operands and
66  // result drive the meaning of the reduction kind, whether it is signed
67  // or unsigned.
68  // The reduction op inside the linalg op may have different result type
69  // from the element type of the linalg op's result.
70  // Also signed and unsigned Arith dialect ops may accept signed, unsigned
71  // or signless operands.
72  // Maybe expand the reduction kinds.
73  .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
74  .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
75  .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
76  .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
77  .Case([](arith::MulIOp op) { return ReductionKind::Product; })
78  .Default([](Operation *op) { return ReductionKind::Generic; });
79 }
80 
81 static std::optional<Operation *> getCombinerOp(LinalgOp op) {
82  SmallVector<Operation *> combinerOps;
83  Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
84  if (!reducedValue || combinerOps.size() != 1) {
85  return std::nullopt;
86  }
87 
88  return combinerOps[0];
89 }
90 
92  std::optional<Operation *> reductionOp = getCombinerOp(op);
93  if (!reductionOp) {
94  return ReductionKind::Generic;
95  }
96  [[maybe_unused]] Type resultElementType =
97  llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
98  // TODO: handle case when result type of the reduction op does not match the
99  // element type of the result tensor.
100  // Would it makes sense at all?
101  assert(resultElementType == reductionOp.value()->getResult(0).getType());
102  return getReductionKind(reductionOp.value());
103 }
104 
105 static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
106  ArrayRef<MeshSharding> resultShardings,
107  SymbolTableCollection &symbolTable) {
108  for (const MeshSharding &sharding : operandShardings) {
109  if (sharding) {
110  return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
111  }
112  }
113 
114  for (const MeshSharding &sharding : resultShardings) {
115  if (sharding) {
116  return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
117  }
118  }
119 
120  assert(false);
121  return nullptr;
122 }
123 
124 // Choose the operand based on the current process index along the reduction
125 // mesh axes.
126 // We need to use the initial value only once to avoid including it in the
127 // reduction multiple times.
128 // In each process group only the leading process with linear index 0 would use
129 // the original operand.
130 // The other processes would use the reduction operation neutral tensor.
132  LinalgOp op, int operandNumber, Value spmdizedOperand,
133  ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
134  ImplicitLocOpBuilder &builder) {
135  Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
136  meshOp.getSymName(), reductionMeshAxes, builder);
137  Value zero = builder.create<arith::ConstantIndexOp>(0);
138  Value isLeadProcess = builder.create<arith::CmpIOp>(
139  builder.getI1Type(), arith::CmpIPredicate::eq,
140  processLinearIndexInReductionGroup, zero);
141  scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
142  isLeadProcess, true, true);
143  // Then block.
144  {
145  OpBuilder::InsertionGuard insertionGuard(builder);
146  builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
147  builder.create<scf::YieldOp>(spmdizedOperand);
148  }
149 
150  // Else block.
151  {
152  OpBuilder::InsertionGuard insertionGuard(builder);
153  builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
155  tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
156 
157  SmallVector<Operation *> combinerOps;
158  matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
159  assert(combinerOps.size() == 1);
160  std::optional<TypedAttr> neutralEl =
161  arith::getNeutralElement(combinerOps[0]);
162 
163  Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
164  neutralEl.value().getType());
165  Value constant =
166  builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
167  Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
168  .getResult(0);
169 
170  builder.create<scf::YieldOp>(fill);
171  }
172  return ifOp.getResult(0);
173 }
174 
175 // Create the DPS init operands for the spmdized Linalg op.
176 // Return all the new spmdized operands.
178  LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
179  ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
180  ImplicitLocOpBuilder &builder) {
181  // TODO: add support for multiple destination passing style initial value
182  // operands.
183  assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
184  SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
185  auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
186  Value spmdizedInitOperand =
187  spmdizationMap.lookup(op->getOperands()[operandIdx]);
188  newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
189  op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
190  return newOperands;
191 }
192 
194  Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
195  MeshSharding resultSharding, ReductionKind reductionKind,
196  IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
197  SmallVector<MeshAxis> allReduceMeshAxes;
198  llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
199  [&resultSharding](MeshAxis axis) {
200  return !llvm::is_contained(resultSharding.getPartialAxes(),
201  axis);
202  });
203  if (allReduceMeshAxes.empty()) {
204  return;
205  }
206 
207  Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
208  Value reducedValue = builder.create<mesh::AllReduceOp>(
209  spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes,
210  reductionKind);
211  spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
212 }
213 
215  LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
216  ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
217  ImplicitLocOpBuilder &builder) {
218  ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
219  for (auto [unshardedLinalgOpResult, resultSharding] :
220  llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
222  unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
223  reductionKind, spmdizationMap, builder);
224  }
225 }
226 
228  LinalgOp op, ArrayRef<Value> spmdizedOperands,
229  ArrayRef<MeshSharding> operandShardings,
230  ArrayRef<MeshSharding> resultShardings,
231  ArrayRef<utils::IteratorType> loopIteratorTypes,
232  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
233  IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
234  ImplicitLocOpBuilder &builder) {
235  MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
237  loopIteratorTypes, meshAxisAssignmentForLoopIterators);
238  SmallVector<Value> spmdizedLinalgOpOperands =
239  createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
240  reductionMeshAxes,
241  spmdizationMap, builder);
242  // We must not change the operand mappings of the original spmdizationMap as
243  // they are the mappings for the whole spmdization blob and may be used by
244  // others.
245  IRMapping internalSpmdizationMap;
246  for (auto [unshardedOperand, spmdizedOperand] :
247  llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
248  internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
249  }
251  *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
252  internalSpmdizationMap, symbolTable, builder);
253  for (Value result : op->getResults()) {
254  spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
255  }
256 
257  // Handle partial shardings.
259  op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
260 }
261 
262 namespace {
263 
264 // ShardingInterface for ops that implement LinalgStructuredInterface.
265 // The supported ops are only those where the indexing maps are projected
266 // permutations.
267 template <typename Op>
268 struct StructuredOpShardingInterface
269  : public mesh::ShardingInterface::ExternalModel<
270  StructuredOpShardingInterface<Op>, Op> {
271  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
272  return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
273  }
274 
275  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
276  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
277  SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
278 
279  // Results must have the same indexing as destination passing style initial
280  // operands.
281  for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
282  res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
283  }
284 
285  return res;
286  }
287 
288  SmallVector<ReductionKind>
289  getReductionLoopIteratorKinds(Operation *op) const {
290  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
291  SmallVector<utils::IteratorType> iteratorTypes =
292  linalgOp.getIteratorTypesArray();
293  unsigned reductionItersCount = std::accumulate(
294  iteratorTypes.begin(), iteratorTypes.end(), 0,
295  [](unsigned count, utils::IteratorType iter) {
296  return count + (iter == utils::IteratorType::reduction);
297  });
298  mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
299  return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
300  }
301 
302  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
303  ArrayRef<MeshSharding> operandShardings,
304  ArrayRef<MeshSharding> resultShardings,
305  IRMapping &spmdizationMap,
306  SymbolTableCollection &symbolTable,
307  OpBuilder &builder) const {
308  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
309 
310  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
311  bool allIndexingMapsAreProjectedPermutation =
312  llvm::all_of(indexingMaps, [](AffineMap map) {
313  return map.isProjectedPermutation();
314  });
315  if (!allIndexingMapsAreProjectedPermutation) {
316  // TODO: handle non-projected permutations.
317  return op->emitOpError()
318  << "supports indexing maps that are only projected permutation.";
319  }
320 
321  SmallVector<utils::IteratorType> loopIteratorTypes =
322  linalgOp.getIteratorTypesArray();
323  ShardingArray meshAxisAssignmentForLoopIterators =
324  getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
325  loopIteratorTypes, indexingMaps);
327  loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
328  ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
330  linalgOp, spmdizedOperands, operandShardings, resultShardings,
331  loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
332  symbolTable, implicitLocBuilder);
333  } else {
334  spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
335  operandShardings, resultShardings,
336  spmdizationMap, symbolTable, builder);
337  }
338 
339  return success();
340  }
341 };
342 
343 } // namespace
344 
345 template <typename OpType>
346 static void registerOne(MLIRContext *ctx) {
347  OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
348 }
349 
350 /// Variadic helper function.
351 template <typename... OpTypes>
352 static void registerAll(MLIRContext *ctx) {
353  (registerOne<OpTypes>(ctx), ...);
354 }
355 
357  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
358  DialectRegistry registry;
359  registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
360  tensor::TensorDialect>();
361  ctx->appendDialectRegistry(registry);
362  for (StringRef name : registry.getDialectNames())
363  ctx->getOrLoadDialect(name);
364 
365  registerOne<linalg::GenericOp>(ctx);
366  registerAll<
367 #define GET_OP_LIST
368 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
369  >(ctx);
370  });
371 }
372 
373 } // namespace mlir::linalg
IntegerType getI1Type()
Definition: Builders.cpp:97
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
auto getDialectNames() const
Return the names of dialects known to this registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:97
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:445
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:67
::llvm::StringRef getMesh() const
Definition: MeshOps.h:65
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
Definition: ArithOps.cpp:2550
static MeshOp getMesh(Operation *op, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, SymbolTableCollection &symbolTable)
static Value createDestinationPassingStyleInitOperand(LinalgOp op, int operandNumber, Value spmdizedOperand, ArrayRef< MeshAxis > reductionMeshAxes, MeshOp meshOp, ImplicitLocOpBuilder &builder)
mesh::ReductionKind ReductionKind
static ReductionKind getReductionKind(Operation *op)
static std::optional< Operation * > getCombinerOp(LinalgOp op)
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
static void createAllReduceForResultWithoutPartialSharding(Value unshardedLinalgOpResult, ArrayRef< MeshAxis > opReductionMeshAxes, MeshSharding resultSharding, ReductionKind reductionKind, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
mesh::MeshSharding MeshSharding
static SmallVector< Value > createDestinationPassingStyleInitOperands(LinalgOp op, MeshOp meshOp, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshAxis > reductionMeshAxes, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
static void spmdizeLinalgOpWithShardedReduction(LinalgOp op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder)
mesh::ShardingArray ShardingArray
static void createAllReduceForResultsWithoutPartialShardings(LinalgOp unshardedOp, ArrayRef< MeshAxis > opReductionMeshAxes, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op)
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< SmallVector< MeshAxis > > ShardingArray
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
TypedValue< IndexType > createProcessLinearIndex(StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
Definition: Transforms.cpp:210
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:126
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
int16_t MeshAxis
Definition: MeshOps.h:26
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:66
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs, and the position of the potential reduction argument within the list, redPos.