MLIR 22.0.0git
ShardingInterfaceImpl.cpp
Go to the documentation of this file.
1//===- ShardingInterfaceImpl.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
23#include "mlir/IR/AffineExpr.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/Operation.h"
29#include "mlir/IR/SymbolTable.h"
30#include "mlir/IR/Value.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include <numeric>
35#include <optional>
36
37namespace mlir::linalg {
38
40using ReductionKind = shard::ReductionKind;
43using GridOp = shard::GridOp;
44
45// Returns the corresponding grid reduction kind for the given arith op.
48 // Floating-point operations.
49 .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
50 .Case([](arith::MulFOp op) { return ReductionKind::Product; })
51 // TODO: handle maxnumf and minnumf.
52 .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
53 .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
54 // Integer operations.
55 .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
56 .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
57 .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
58 .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
59 // TODO: handle signless, signed and unsigned types properly.
60 // It is assumed that the element type of the collective operands and
61 // result drive the meaning of the reduction kind, whether it is signed
62 // or unsigned.
63 // The reduction op inside the linalg op may have different result type
64 // from the element type of the linalg op's result.
65 // Also signed and unsigned Arith dialect ops may accept signed, unsigned
66 // or signless operands.
67 // Maybe expand the reduction kinds.
68 .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
69 .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
70 .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
71 .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
72 .Case([](arith::MulIOp op) { return ReductionKind::Product; })
73 .Default([](Operation *op) { return ReductionKind::Generic; });
74}
75
76static std::optional<Operation *> getCombinerOp(LinalgOp op) {
77 SmallVector<Operation *> combinerOps;
78 Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
79 if (!reducedValue || combinerOps.size() != 1) {
80 return std::nullopt;
81 }
82
83 return combinerOps[0];
84}
85
87 std::optional<Operation *> reductionOp = getCombinerOp(op);
88 if (!reductionOp) {
89 return ReductionKind::Generic;
90 }
91 [[maybe_unused]] Type resultElementType =
92 llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
93 // TODO: handle case when result type of the reduction op does not match the
94 // element type of the result tensor.
95 // Would it makes sense at all?
96 assert(resultElementType == reductionOp.value()->getResult(0).getType());
97 return getReductionKind(reductionOp.value());
98}
99
100static GridOp getGrid(Operation *op, ArrayRef<Sharding> operandShardings,
101 ArrayRef<Sharding> resultShardings,
102 SymbolTableCollection &symbolTable) {
103 for (const Sharding &sharding : operandShardings) {
104 if (sharding) {
105 return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
106 }
107 }
108
109 for (const Sharding &sharding : resultShardings) {
110 if (sharding) {
111 return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
112 }
113 }
114
115 assert(false);
116 return nullptr;
117}
118
119// Choose the operand based on the current process index along the reduction
120// grid axes.
121// We need to use the initial value only once to avoid including it in the
122// reduction multiple times.
123// In each process group only the leading process with linear index 0 would use
124// the original operand.
125// The other processes would use the reduction operation neutral tensor.
127 LinalgOp op, int operandNumber, Value partitionedOperand,
128 ArrayRef<GridAxis> reductionGridAxes, GridOp gridOp,
129 ImplicitLocOpBuilder &builder) {
130 Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex(
131 gridOp.getSymName(), reductionGridAxes, builder);
132 Value zero = arith::ConstantIndexOp::create(builder, 0);
133 Value isLeadProcess = arith::CmpIOp::create(
134 builder, builder.getI1Type(), arith::CmpIPredicate::eq,
135 processLinearIndexInReductionGroup, zero);
136 scf::IfOp ifOp = scf::IfOp::create(builder, partitionedOperand.getType(),
137 isLeadProcess, true, true);
138 // Then block.
139 {
140 OpBuilder::InsertionGuard insertionGuard(builder);
141 builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
142 scf::YieldOp::create(builder, partitionedOperand);
143 }
144
145 // Else block.
146 {
147 OpBuilder::InsertionGuard insertionGuard(builder);
148 builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
150 tensor::getMixedSizes(builder, builder.getLoc(), partitionedOperand);
151
152 SmallVector<Operation *> combinerOps;
153 matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
154 assert(combinerOps.size() == 1);
155 std::optional<TypedAttr> neutralEl =
156 arith::getNeutralElement(combinerOps[0]);
157
158 Value init = tensor::EmptyOp::create(builder, op.getLoc(), shape,
159 neutralEl.value().getType());
160 Value constant =
161 arith::ConstantOp::create(builder, op.getLoc(), neutralEl.value());
162 Value fill = linalg::FillOp::create(builder, op.getLoc(), constant, init)
163 .getResult(0);
164
165 scf::YieldOp::create(builder, fill);
166 }
167 return ifOp.getResult(0);
168}
169
170// Create the DPS init operands for the partitioned Linalg op.
171// Return all the new partitioned operands.
173 LinalgOp op, GridOp gridOp, ArrayRef<Value> partitionedOperands,
174 ArrayRef<GridAxis> reductionGridAxes, IRMapping &partitionMap,
175 ImplicitLocOpBuilder &builder) {
176 // TODO: add support for multiple destination passing style initial value
177 // operands.
178 assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
179 SmallVector<Value> newOperands = llvm::to_vector(partitionedOperands);
180 auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
181 Value partitionedInitOperand =
182 partitionMap.lookup(op->getOperands()[operandIdx]);
183 newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
184 op, 0, partitionedInitOperand, reductionGridAxes, gridOp, builder);
185 return newOperands;
186}
187
189 LinalgOp unshardedOp, ArrayRef<GridAxis> opReductionGridAxes,
190 ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
191 ImplicitLocOpBuilder &builder) {
192 ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
193 for (auto [unshardedLinalgOpResult, resultSharding] :
194 llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
195 Value partitionedLinalgOpResult =
196 partitionMap.lookup(unshardedLinalgOpResult);
197 Value reducedValue = shard::AllReduceOp::create(
198 builder, partitionedLinalgOpResult, resultSharding.getGrid(),
199 opReductionGridAxes, reductionKind);
200 partitionMap.map(unshardedLinalgOpResult, reducedValue);
201 }
202}
203
205 LinalgOp op, ArrayRef<Value> partitionedOperands,
206 ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
207 ArrayRef<utils::IteratorType> loopIteratorTypes,
208 ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators,
209 IRMapping &partitionMap, SymbolTableCollection &symbolTable,
210 ImplicitLocOpBuilder &builder) {
211 GridOp grid = getGrid(op, operandShardings, resultShardings, symbolTable);
213 loopIteratorTypes, gridAxisAssignmentForLoopIterators);
214 SmallVector<Value> partitionedLinalgOpOperands =
215 createDestinationPassingStyleInitOperands(op, grid, partitionedOperands,
216 reductionGridAxes, partitionMap,
217 builder);
218 // We must not change the operand mappings of the original partitionMap as
219 // they are the mappings for the whole partition blob and may be used by
220 // others.
221 IRMapping internalPartitionMap;
222 for (auto [unshardedOperand, partitionedOperand] :
223 llvm::zip_equal(op->getOperands(), partitionedLinalgOpOperands)) {
224 internalPartitionMap.map(unshardedOperand, partitionedOperand);
225 }
226 partitionTriviallyShardableOperation(
227 *op, partitionedLinalgOpOperands, operandShardings, resultShardings,
228 internalPartitionMap, symbolTable, builder);
229 for (Value result : op->getResults()) {
230 partitionMap.map(result, internalPartitionMap.lookup(result));
231 }
232
233 // Handle partial shardings.
235 op, reductionGridAxes, resultShardings, partitionMap, builder);
236}
237
238namespace {
239
240// ShardingInterface for ops that implement LinalgStructuredInterface.
241// The supported ops are only those where the indexing maps are projected
242// permutations.
243template <typename Op>
244struct StructuredOpShardingInterface
245 : public shard::ShardingInterface::ExternalModel<
246 StructuredOpShardingInterface<Op>, Op> {
247 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
248 return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
249 }
250
251 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
252 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
253 SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
254
255 // Results must have the same indexing as destination passing style initial
256 // operands.
257 for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
258 res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
259 }
260
261 return res;
262 }
263
264 SmallVector<ReductionKind>
265 getReductionLoopIteratorKinds(Operation *op) const {
266 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
267 SmallVector<utils::IteratorType> iteratorTypes =
268 linalgOp.getIteratorTypesArray();
269 unsigned reductionItersCount = llvm::accumulate(
270 iteratorTypes, 0u, [](unsigned count, utils::IteratorType iter) {
271 return count + (iter == utils::IteratorType::reduction);
272 });
273 shard::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
274 return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
275 }
276
277 LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
278 ArrayRef<Sharding> operandShardings,
279 ArrayRef<Sharding> resultShardings,
280 IRMapping &partitionMap,
281 SymbolTableCollection &symbolTable,
282 OpBuilder &builder) const {
283 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
284
285 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
286 bool allIndexingMapsAreProjectedPermutation =
287 llvm::all_of(indexingMaps, [](AffineMap map) {
288 return map.isProjectedPermutation();
289 });
290 if (!allIndexingMapsAreProjectedPermutation) {
291 // TODO: handle non-projected permutations.
292 return op->emitOpError()
293 << "supports indexing maps that are only projected permutation.";
294 }
295
296 SmallVector<utils::IteratorType> loopIteratorTypes =
297 linalgOp.getIteratorTypesArray();
298 ShardingArray gridAxisAssignmentForLoopIterators =
299 getGridAxisAssignmentForLoopIterators(operandShardings, resultShardings,
300 loopIteratorTypes, indexingMaps);
302 loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
303 ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
305 linalgOp, partitionedOperands, operandShardings, resultShardings,
306 loopIteratorTypes, gridAxisAssignmentForLoopIterators, partitionMap,
307 symbolTable, implicitLocBuilder);
308 } else {
309 partitionTriviallyShardableOperation(*op, partitionedOperands,
310 operandShardings, resultShardings,
311 partitionMap, symbolTable, builder);
312 }
313
314 return success();
315 }
316};
317
318} // namespace
319
/// Attach the StructuredOpShardingInterface external model to a single op
/// type in the given context.
template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
}
324
/// Variadic helper: attach the sharding interface model to every op type in
/// the parameter pack via a fold expression over registerOne.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
  (registerOne<OpTypes>(ctx), ...);
}
330
332 registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
333 DialectRegistry registry;
334 registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
335 tensor::TensorDialect>();
336 ctx->appendDialectRegistry(registry);
337 for (StringRef name : registry.getDialectNames())
338 ctx->getOrLoadDialect(name);
339
342#define GET_OP_LIST
343#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
344 >(ctx);
345 });
346}
347
348} // namespace mlir::linalg
return success()
IntegerType getI1Type()
Definition Builders.cpp:53
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
auto getDialectNames() const
Return the names of dialects known to this registry.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:663
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
shard::Sharding Sharding
static std::optional< Operation * > getCombinerOp(LinalgOp op)
static void partitionLinalgOpWithShardedReduction(LinalgOp op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators, IRMapping &partitionMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder)
static ReductionKind getReductionKind(Operation *op)
shard::ShardingArray ShardingArray
static Value createDestinationPassingStyleInitOperand(LinalgOp op, int operandNumber, Value partitionedOperand, ArrayRef< GridAxis > reductionGridAxes, GridOp gridOp, ImplicitLocOpBuilder &builder)
static void createAllReduceForResultsWithoutPartialShardings(LinalgOp unshardedOp, ArrayRef< GridAxis > opReductionGridAxes, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, ImplicitLocOpBuilder &builder)
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
shard::ReductionKind ReductionKind
shard::GridAxis GridAxis
static SmallVector< Value > createDestinationPassingStyleInitOperands(LinalgOp op, GridOp gridOp, ArrayRef< Value > partitionedOperands, ArrayRef< GridAxis > reductionGridAxes, IRMapping &partitionMap, ImplicitLocOpBuilder &builder)
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op)
static GridOp getGrid(Operation *op, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, SymbolTableCollection &symbolTable)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators)
int16_t GridAxis
Definition ShardOps.h:26
TypedValue< IndexType > createProcessLinearIndex(StringRef grid, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder)
void partitionTriviallyShardableOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
SmallVector< GridAxis > getReductionGridAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:121
ShardingArray getGridAxisAssignmentForLoopIterators(ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
SmallVector< SmallVector< GridAxis > > ShardingArray
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...