MLIR  22.0.0git
ShardingInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ShardingInterfaceImpl.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
23 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/IRMapping.h"
26 #include "mlir/IR/MLIRContext.h"
27 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/Operation.h"
29 #include "mlir/IR/SymbolTable.h"
30 #include "mlir/IR/Value.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include <numeric>
35 #include <optional>
36 
37 namespace mlir::linalg {
38 
44 
45 // Returns the corresponding grid reduction kind for the given arith op.
48  // Floating-point operations.
49  .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
50  .Case([](arith::MulFOp op) { return ReductionKind::Product; })
51  // TODO: handle maxnumf and minnumf.
52  .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
53  .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
54  // Integer operations.
55  .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
56  .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
57  .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
58  .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
59  // TODO: handle signless, signed and unsigned types properly.
60  // It is assumed that the element type of the collective operands and
61  // result drive the meaning of the reduction kind, whether it is signed
62  // or unsigned.
63  // The reduction op inside the linalg op may have different result type
64  // from the element type of the linalg op's result.
65  // Also signed and unsigned Arith dialect ops may accept signed, unsigned
66  // or signless operands.
67  // Maybe expand the reduction kinds.
68  .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
69  .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
70  .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
71  .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
72  .Case([](arith::MulIOp op) { return ReductionKind::Product; })
73  .Default([](Operation *op) { return ReductionKind::Generic; });
74 }
75 
76 static std::optional<Operation *> getCombinerOp(LinalgOp op) {
77  SmallVector<Operation *> combinerOps;
78  Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
79  if (!reducedValue || combinerOps.size() != 1) {
80  return std::nullopt;
81  }
82 
83  return combinerOps[0];
84 }
85 
87  std::optional<Operation *> reductionOp = getCombinerOp(op);
88  if (!reductionOp) {
89  return ReductionKind::Generic;
90  }
91  [[maybe_unused]] Type resultElementType =
92  llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
93  // TODO: handle case when result type of the reduction op does not match the
94  // element type of the result tensor.
95  // Would it makes sense at all?
96  assert(resultElementType == reductionOp.value()->getResult(0).getType());
97  return getReductionKind(reductionOp.value());
98 }
99 
100 static GridOp getGrid(Operation *op, ArrayRef<Sharding> operandShardings,
101  ArrayRef<Sharding> resultShardings,
102  SymbolTableCollection &symbolTable) {
103  for (const Sharding &sharding : operandShardings) {
104  if (sharding) {
105  return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
106  }
107  }
108 
109  for (const Sharding &sharding : resultShardings) {
110  if (sharding) {
111  return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
112  }
113  }
114 
115  assert(false);
116  return nullptr;
117 }
118 
119 // Choose the operand based on the current process index along the reduction
120 // grid axes.
121 // We need to use the initial value only once to avoid including it in the
122 // reduction multiple times.
123 // In each process group only the leading process with linear index 0 would use
124 // the original operand.
125 // The other processes would use the reduction operation neutral tensor.
127  LinalgOp op, int operandNumber, Value partitionedOperand,
128  ArrayRef<GridAxis> reductionGridAxes, GridOp gridOp,
129  ImplicitLocOpBuilder &builder) {
130  Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex(
131  gridOp.getSymName(), reductionGridAxes, builder);
132  Value zero = arith::ConstantIndexOp::create(builder, 0);
133  Value isLeadProcess = arith::CmpIOp::create(
134  builder, builder.getI1Type(), arith::CmpIPredicate::eq,
135  processLinearIndexInReductionGroup, zero);
136  scf::IfOp ifOp = scf::IfOp::create(builder, partitionedOperand.getType(),
137  isLeadProcess, true, true);
138  // Then block.
139  {
140  OpBuilder::InsertionGuard insertionGuard(builder);
141  builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
142  scf::YieldOp::create(builder, partitionedOperand);
143  }
144 
145  // Else block.
146  {
147  OpBuilder::InsertionGuard insertionGuard(builder);
148  builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
150  tensor::getMixedSizes(builder, builder.getLoc(), partitionedOperand);
151 
152  SmallVector<Operation *> combinerOps;
153  matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
154  assert(combinerOps.size() == 1);
155  std::optional<TypedAttr> neutralEl =
156  arith::getNeutralElement(combinerOps[0]);
157 
158  Value init = tensor::EmptyOp::create(builder, op.getLoc(), shape,
159  neutralEl.value().getType());
160  Value constant =
161  arith::ConstantOp::create(builder, op.getLoc(), neutralEl.value());
162  Value fill = linalg::FillOp::create(builder, op.getLoc(), constant, init)
163  .getResult(0);
164 
165  scf::YieldOp::create(builder, fill);
166  }
167  return ifOp.getResult(0);
168 }
169 
170 // Create the DPS init operands for the partitioned Linalg op.
171 // Return all the new partitioned operands.
173  LinalgOp op, GridOp gridOp, ArrayRef<Value> partitionedOperands,
174  ArrayRef<GridAxis> reductionGridAxes, IRMapping &partitionMap,
175  ImplicitLocOpBuilder &builder) {
176  // TODO: add support for multiple destination passing style initial value
177  // operands.
178  assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
179  SmallVector<Value> newOperands = llvm::to_vector(partitionedOperands);
180  auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
181  Value partitionedInitOperand =
182  partitionMap.lookup(op->getOperands()[operandIdx]);
183  newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
184  op, 0, partitionedInitOperand, reductionGridAxes, gridOp, builder);
185  return newOperands;
186 }
187 
189  LinalgOp unshardedOp, ArrayRef<GridAxis> opReductionGridAxes,
190  ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
191  ImplicitLocOpBuilder &builder) {
192  ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
193  for (auto [unshardedLinalgOpResult, resultSharding] :
194  llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
195  Value partitionedLinalgOpResult =
196  partitionMap.lookup(unshardedLinalgOpResult);
197  Value reducedValue = shard::AllReduceOp::create(
198  builder, partitionedLinalgOpResult, resultSharding.getGrid(),
199  opReductionGridAxes, reductionKind);
200  partitionMap.map(unshardedLinalgOpResult, reducedValue);
201  }
202 }
203 
205  LinalgOp op, ArrayRef<Value> partitionedOperands,
206  ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
207  ArrayRef<utils::IteratorType> loopIteratorTypes,
208  ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators,
209  IRMapping &partitionMap, SymbolTableCollection &symbolTable,
210  ImplicitLocOpBuilder &builder) {
211  GridOp grid = getGrid(op, operandShardings, resultShardings, symbolTable);
213  loopIteratorTypes, gridAxisAssignmentForLoopIterators);
214  SmallVector<Value> partitionedLinalgOpOperands =
215  createDestinationPassingStyleInitOperands(op, grid, partitionedOperands,
216  reductionGridAxes, partitionMap,
217  builder);
218  // We must not change the operand mappings of the original partitionMap as
219  // they are the mappings for the whole partition blob and may be used by
220  // others.
221  IRMapping internalPartitionMap;
222  for (auto [unshardedOperand, partitionedOperand] :
223  llvm::zip_equal(op->getOperands(), partitionedLinalgOpOperands)) {
224  internalPartitionMap.map(unshardedOperand, partitionedOperand);
225  }
227  *op, partitionedLinalgOpOperands, operandShardings, resultShardings,
228  internalPartitionMap, symbolTable, builder);
229  for (Value result : op->getResults()) {
230  partitionMap.map(result, internalPartitionMap.lookup(result));
231  }
232 
233  // Handle partial shardings.
235  op, reductionGridAxes, resultShardings, partitionMap, builder);
236 }
237 
238 namespace {
239 
240 // ShardingInterface for ops that implement LinalgStructuredInterface.
241 // The supported ops are only those where the indexing maps are projected
242 // permutations.
243 template <typename Op>
244 struct StructuredOpShardingInterface
245  : public shard::ShardingInterface::ExternalModel<
246  StructuredOpShardingInterface<Op>, Op> {
247  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
248  return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
249  }
250 
251  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
252  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
253  SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
254 
255  // Results must have the same indexing as destination passing style initial
256  // operands.
257  for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
258  res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
259  }
260 
261  return res;
262  }
263 
264  SmallVector<ReductionKind>
265  getReductionLoopIteratorKinds(Operation *op) const {
266  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
267  SmallVector<utils::IteratorType> iteratorTypes =
268  linalgOp.getIteratorTypesArray();
269  unsigned reductionItersCount = std::accumulate(
270  iteratorTypes.begin(), iteratorTypes.end(), 0,
271  [](unsigned count, utils::IteratorType iter) {
272  return count + (iter == utils::IteratorType::reduction);
273  });
274  shard::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
275  return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
276  }
277 
278  LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
279  ArrayRef<Sharding> operandShardings,
280  ArrayRef<Sharding> resultShardings,
281  IRMapping &partitionMap,
282  SymbolTableCollection &symbolTable,
283  OpBuilder &builder) const {
284  LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
285 
286  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
287  bool allIndexingMapsAreProjectedPermutation =
288  llvm::all_of(indexingMaps, [](AffineMap map) {
289  return map.isProjectedPermutation();
290  });
291  if (!allIndexingMapsAreProjectedPermutation) {
292  // TODO: handle non-projected permutations.
293  return op->emitOpError()
294  << "supports indexing maps that are only projected permutation.";
295  }
296 
297  SmallVector<utils::IteratorType> loopIteratorTypes =
298  linalgOp.getIteratorTypesArray();
299  ShardingArray gridAxisAssignmentForLoopIterators =
300  getGridAxisAssignmentForLoopIterators(operandShardings, resultShardings,
301  loopIteratorTypes, indexingMaps);
303  loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
304  ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
306  linalgOp, partitionedOperands, operandShardings, resultShardings,
307  loopIteratorTypes, gridAxisAssignmentForLoopIterators, partitionMap,
308  symbolTable, implicitLocBuilder);
309  } else {
310  partitionTriviallyShardableOperation(*op, partitionedOperands,
311  operandShardings, resultShardings,
312  partitionMap, symbolTable, builder);
313  }
314 
315  return success();
316  }
317 };
318 
319 } // namespace
320 
321 template <typename OpType>
322 static void registerOne(MLIRContext *ctx) {
323  OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
324 }
325 
326 /// Variadic helper function.
327 template <typename... OpTypes>
328 static void registerAll(MLIRContext *ctx) {
329  (registerOne<OpTypes>(ctx), ...);
330 }
331 
333  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
334  DialectRegistry registry;
335  registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
336  tensor::TensorDialect>();
337  ctx->appendDialectRegistry(registry);
338  for (StringRef name : registry.getDialectNames())
339  ctx->getOrLoadDialect(name);
340 
341  registerOne<linalg::GenericOp>(ctx);
342  registerAll<
343 #define GET_OP_LIST
344 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
345  >(ctx);
346  });
347 }
348 
349 } // namespace mlir::linalg
IntegerType getI1Type()
Definition: Builders.cpp:52
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
auto getDialectNames() const
Return the names of dialects known to this registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location each time.
Definition: Builders.h:623
Location getLoc() const
Accessors for the implied location.
Definition: Builders.h:656
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:100
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
Definition: ArithOps.cpp:2727
static ReductionKind getReductionKind(Operation *op)
static std::optional< Operation * > getCombinerOp(LinalgOp op)
static Value createDestinationPassingStyleInitOperand(LinalgOp op, int operandNumber, Value partitionedOperand, ArrayRef< GridAxis > reductionGridAxes, GridOp gridOp, ImplicitLocOpBuilder &builder)
shard::Sharding Sharding
static void createAllReduceForResultsWithoutPartialShardings(LinalgOp unshardedOp, ArrayRef< GridAxis > opReductionGridAxes, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, ImplicitLocOpBuilder &builder)
shard::ShardingArray ShardingArray
shard::ReductionKind ReductionKind
shard::GridOp GridOp
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
static SmallVector< Value > createDestinationPassingStyleInitOperands(LinalgOp op, GridOp gridOp, ArrayRef< Value > partitionedOperands, ArrayRef< GridAxis > reductionGridAxes, IRMapping &partitionMap, ImplicitLocOpBuilder &builder)
static void partitionLinalgOpWithShardedReduction(LinalgOp op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators, IRMapping &partitionMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder)
shard::GridAxis GridAxis
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op)
static GridOp getGrid(Operation *op, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, SymbolTableCollection &symbolTable)
int16_t GridAxis
Definition: ShardOps.h:26
SmallVector< GridAxis > getReductionGridAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators)
SmallVector< SmallVector< GridAxis > > ShardingArray
TypedValue< IndexType > createProcessLinearIndex(StringRef grid, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder)
Definition: Transforms.cpp:228
void partitionTriviallyShardableOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition: ShardOps.h:121
ShardingArray getGridAxisAssignmentForLoopIterators(ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...