MLIR  20.0.0git
MeshShardingInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
23 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/IRMapping.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/SymbolTable.h"
31 #include "mlir/IR/Value.h"
33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include <iterator>
38 #include <numeric>
39 #include <optional>
40 #include <utility>
41 
42 namespace mlir::linalg {
43 
49 
50 // Returns the corresponding mesh reduction kind for the given arith op.
53  // Floating-point operations.
54  .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
55  .Case([](arith::MulFOp op) { return ReductionKind::Product; })
56  // TODO: handle maxnumf and minnumf.
57  .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
58  .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
59  // Integer operations.
60  .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
61  .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
62  .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
63  .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
64  // TODO: handle signless, signed and unsigned types properly.
65  // It is assumed that the element type of the collective operands and
66  // result drive the meaning of the reduction kind, whether it is signed
67  // or unsigned.
68  // The reduction op inside the linalg op may have different result type
69  // from the element type of the linalg op's result.
70  // Also signed and unsigned Arith dialect ops may accept signed, unsigned
71  // or signless operands.
72  // Maybe expand the reduction kinds.
73  .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
74  .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
75  .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
76  .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
77  .Case([](arith::MulIOp op) { return ReductionKind::Product; })
78  .Default([](Operation *op) { return ReductionKind::Generic; });
79 }
80 
81 static std::optional<Operation *> getCombinerOp(LinalgOp op) {
82  SmallVector<Operation *> combinerOps;
83  Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
84  if (!reducedValue || combinerOps.size() != 1) {
85  return std::nullopt;
86  }
87 
88  return combinerOps[0];
89 }
90 
92  std::optional<Operation *> reductionOp = getCombinerOp(op);
93  if (!reductionOp) {
94  return ReductionKind::Generic;
95  }
96  [[maybe_unused]] Type resultElementType =
97  llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
98  // TODO: handle case when result type of the reduction op does not match the
99  // element type of the result tensor.
100  // Would it makes sense at all?
101  assert(resultElementType == reductionOp.value()->getResult(0).getType());
102  return getReductionKind(reductionOp.value());
103 }
104 
106  ArrayRef<MeshShardingAttr> operandShardings,
107  ArrayRef<MeshShardingAttr> resultShardings,
108  SymbolTableCollection &symbolTable) {
109  for (MeshShardingAttr sharding : operandShardings) {
110  if (sharding) {
111  return mesh::getMesh(op, sharding.getMesh(), symbolTable);
112  }
113  }
114 
115  for (MeshShardingAttr sharding : resultShardings) {
116  if (sharding) {
117  return mesh::getMesh(op, sharding.getMesh(), symbolTable);
118  }
119  }
120 
121  assert(false);
122  return nullptr;
123 }
124 
125 // Choose the operand based on the current process index along the reduction
126 // mesh axes.
127 // We need to use the initial value only once to avoid including it in the
128 // reduction multiple times.
129 // In each process group only the leading process with linear index 0 would use
130 // the original operand.
131 // The other processes would use the reduction operation neutral tensor.
133  LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
134  MeshOp meshOp, ImplicitLocOpBuilder &builder) {
135  Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
136  meshOp.getSymName(), reductionMeshAxes, builder);
137  Value zero = builder.create<arith::ConstantIndexOp>(0);
138  Value isLeadProcess = builder.create<arith::CmpIOp>(
139  builder.getI1Type(), arith::CmpIPredicate::eq,
140  processLinearIndexInReductionGroup, zero);
141  scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
142  isLeadProcess, true, true);
143  // Then block.
144  {
145  OpBuilder::InsertionGuard insertionGuard(builder);
146  builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
147  builder.create<scf::YieldOp>(spmdizedOperand);
148  }
149 
150  // Else block.
151  {
152  OpBuilder::InsertionGuard insertionGuard(builder);
153  builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
155  tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
156  PartialReductionOpInterface partialReductionIface =
157  llvm::cast<PartialReductionOpInterface>(op.getOperation());
158  assert(op->getNumResults() == 1 && "Multiple results not supported.");
159  FailureOr<SmallVector<Value>> reductionNeutralTensor =
160  partialReductionIface.generateInitialTensorForPartialReduction(
161  builder, builder.getLoc(), shape, {});
162  assert(succeeded(reductionNeutralTensor));
163  builder.create<scf::YieldOp>(reductionNeutralTensor.value());
164  }
165  return ifOp.getResult(0);
166 }
167 
168 // Create the DPS init operands for the spmdized Linalg op.
169 // Return all the new spmdized operands.
171  LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
172  ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
173  ImplicitLocOpBuilder &builder) {
174  // TODO: add support for multiple destination passing style initial value
175  // operands.
176  assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
177  SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
178  auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
179  Value spmdizedInitOperand =
180  spmdizationMap.lookup(op->getOperands()[operandIdx]);
181  newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
182  op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
183  return newOperands;
184 }
185 
187  Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
188  MeshShardingAttr resultSharding, ReductionKind reductionKind,
189  IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
190  SmallVector<MeshAxis> allReduceMeshAxes;
191  llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
192  [&resultSharding](MeshAxis axis) {
193  return !llvm::is_contained(resultSharding.getPartialAxes(),
194  axis);
195  });
196  if (allReduceMeshAxes.empty()) {
197  return;
198  }
199 
200  Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
201  Value reducedValue = builder.create<mesh::AllReduceOp>(
202  spmdizedLinalgOpResult, resultSharding.getMesh().getValue(),
203  allReduceMeshAxes, reductionKind);
204  spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
205 }
206 
208  LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
209  ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
210  ImplicitLocOpBuilder &builder) {
211  ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
212  for (auto [unshardedLinalgOpResult, resultSharding] :
213  llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
215  unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
216  reductionKind, spmdizationMap, builder);
217  }
218 }
219 
221  LinalgOp op, ArrayRef<Value> spmdizedOperands,
222  ArrayRef<MeshShardingAttr> operandShardings,
223  ArrayRef<MeshShardingAttr> resultShardings,
224  ArrayRef<utils::IteratorType> loopIteratorTypes,
225  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
226  IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
227  ImplicitLocOpBuilder &builder) {
228  MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
230  loopIteratorTypes, meshAxisAssignmentForLoopIterators);
231  SmallVector<Value> spmdizedLinalgOpOperands =
232  createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
233  reductionMeshAxes,
234  spmdizationMap, builder);
235  // We must not change the operand mappings of the original spmdizationMap as
236  // they are the mappings for the whole spmdization blob and may be used by
237  // others.
238  IRMapping internalSpmdizationMap;
239  for (auto [unshardedOperand, spmdizedOperand] :
240  llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
241  internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
242  }
244  *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
245  internalSpmdizationMap, symbolTable, builder);
246  for (Value result : op->getResults()) {
247  spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
248  }
249 
250  // Handle partial shardings.
252  op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
253 }
254 
255 namespace {
256 
257 // ShardingInterface for ops that implement LinalgStructuredInterface.
258 // The supported ops are only those where the indexing maps are projected
259 // permutations.
260 template <typename Op>
261 struct StructuredOpShardingInterface
262  : public mesh::ShardingInterface::ExternalModel<
263  StructuredOpShardingInterface<Op>, Op> {
264  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
265  return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
266  }
267 
268  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
269  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
270  SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
271 
272  // Results must have the same indexing as destination passing style initial
273  // operands.
274  for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
275  res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
276  }
277 
278  return res;
279  }
280 
281  SmallVector<ReductionKind>
282  getReductionLoopIteratorKinds(Operation *op) const {
283  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
284  SmallVector<utils::IteratorType> iteratorTypes =
285  linalgOp.getIteratorTypesArray();
286  unsigned reductionItersCount = std::accumulate(
287  iteratorTypes.begin(), iteratorTypes.end(), 0,
288  [](unsigned count, utils::IteratorType iter) {
289  return count + (iter == utils::IteratorType::reduction);
290  });
291  mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
292  return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
293  }
294 
295  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
296  ArrayRef<MeshShardingAttr> operandShardings,
297  ArrayRef<MeshShardingAttr> resultShardings,
298  IRMapping &spmdizationMap,
299  SymbolTableCollection &symbolTable,
300  OpBuilder &builder) const {
301  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
302 
303  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
304  bool allIndexingMapsAreProjectedPermutation =
305  llvm::all_of(indexingMaps, [](AffineMap map) {
306  return map.isProjectedPermutation();
307  });
308  if (!allIndexingMapsAreProjectedPermutation) {
309  // TODO: handle non-projected permutations.
310  return op->emitOpError()
311  << "supports indexing maps that are only projected permutation.";
312  }
313 
314  SmallVector<utils::IteratorType> loopIteratorTypes =
315  linalgOp.getIteratorTypesArray();
316  ShardingArray meshAxisAssignmentForLoopIterators =
317  getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
318  loopIteratorTypes, indexingMaps);
320  loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
321  ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
323  linalgOp, spmdizedOperands, operandShardings, resultShardings,
324  loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
325  symbolTable, implicitLocBuilder);
326  } else {
327  spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
328  operandShardings, resultShardings,
329  spmdizationMap, symbolTable, builder);
330  }
331 
332  return success();
333  }
334 };
335 
336 } // namespace
337 
338 template <typename OpType>
339 static void registerOne(MLIRContext *ctx) {
340  OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
341 }
342 
343 /// Variadic helper function.
344 template <typename... OpTypes>
345 static void registerAll(MLIRContext *ctx) {
346  (registerOne<OpTypes>(ctx), ...);
347 }
348 
350  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
351  DialectRegistry registry;
352  registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
353  tensor::TensorDialect>();
354  ctx->appendDialectRegistry(registry);
355  for (StringRef name : registry.getDialectNames())
356  ctx->getOrLoadDialect(name);
357 
358  registerOne<linalg::GenericOp>(ctx);
359  registerAll<
360 #define GET_OP_LIST
361 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
362  >(ctx);
363  });
364 }
365 
366 } // namespace mlir::linalg
IntegerType getI1Type()
Definition: Builders.cpp:77
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
auto getDialectNames() const
Return the names of dialects known to this registry.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:97
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:439
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
static MeshOp getMesh(Operation *op, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, SymbolTableCollection &symbolTable)
mesh::ReductionKind ReductionKind
static void createAllReduceForResultsWithoutPartialShardings(LinalgOp unshardedOp, ArrayRef< MeshAxis > opReductionMeshAxes, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
static ReductionKind getReductionKind(Operation *op)
static std::optional< Operation * > getCombinerOp(LinalgOp op)
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
mesh::MeshShardingAttr MeshShardingAttr
static void spmdizeLinalgOpWithShardedReduction(LinalgOp op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder)
static Value createDestinationPassingStyleInitOperand(LinalgOp op, Value spmdizedOperand, ArrayRef< MeshAxis > reductionMeshAxes, MeshOp meshOp, ImplicitLocOpBuilder &builder)
static SmallVector< Value > createDestinationPassingStyleInitOperands(LinalgOp op, MeshOp meshOp, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshAxis > reductionMeshAxes, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
mesh::ShardingArray ShardingArray
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op)
static void createAllReduceForResultWithoutPartialSharding(Value unshardedLinalgOpResult, ArrayRef< MeshAxis > opReductionMeshAxes, MeshShardingAttr resultSharding, ReductionKind reductionKind, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< SmallVector< MeshAxis > > ShardingArray
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
TypedValue< IndexType > createProcessLinearIndex(StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
Definition: Transforms.cpp:210
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:67
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
int16_t MeshAxis
Definition: MeshOps.h:25
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:65
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...