MLIR 23.0.0git
Partition.cpp
Go to the documentation of this file.
1//===- Partition.cpp --------------------------------------------- C++ --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15#include "mlir/IR/Builders.h"
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
22#include "mlir/IR/MLIRContext.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/IR/Value.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Support/LLVM.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Support/Casting.h"
32#include <array>
33#include <iterator>
34#include <memory>
35#include <optional>
36#include <tuple>
37#include <utility>
38
39namespace mlir::shard {
40
/// Returns true when every axis in `targetAxes` is already present in
/// `sourceAxes`, i.e. the target partial axes form a subset of the source
/// partial axes. `SourceAxes` must provide a `contains` member.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  return std::all_of(targetAxes.begin(), targetAxes.end(),
                     [&sourceAxes](const auto &targetAxis) {
                       return sourceAxes.contains(targetAxis);
                     });
}
48
49/// Base class for resharding patterns.
50/// Subclasses implement `tryApply` to detect and apply a specific resharding.
52public:
53 virtual ~ReshardingPattern() = default;
54
55 /// Try to apply this resharding pattern. Returns the resharded value and
56 /// resulting sharding on success, or std::nullopt if the pattern doesn't
57 /// match.
58 virtual std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
59 tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
60 const Sharding &srcSharding, const Sharding &tgtSharding,
61 ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard) = 0;
62
63protected:
64 /// Returns true if either sharding has non-empty static sharded dims offsets.
65 static bool hasStaticOffsets(const Sharding &srcSharding,
66 const Sharding &tgtSharding) {
67 return !srcSharding.getStaticShardedDimsOffsets().empty() ||
68 !tgtSharding.getStaticShardedDimsOffsets().empty();
69 }
70
71 /// Returns true if either sharding has non-empty static sharded dims offsets
72 /// or non-empty static halo sizes.
73 static bool hasStaticOffsetsOrHalos(const Sharding &srcSharding,
74 const Sharding &tgtSharding) {
75 return hasStaticOffsets(srcSharding, tgtSharding) ||
76 !srcSharding.getStaticHaloSizes().empty() ||
77 !tgtSharding.getStaticHaloSizes().empty();
78 }
79};
80
81/// Split a replicated axis: e.g. [[0, 1]] -> [[0, 1, 2]].
83 static Sharding tgtSharding(MLIRContext *ctx, const Sharding &srcSharding,
84 int64_t splitTensorDim, GridAxis splitGridAxis) {
85 SmallVector<GridAxesAttr> tgtShardingSplitAxes =
86 llvm::to_vector(srcSharding.getSplitAxes());
87 while (static_cast<int64_t>(tgtShardingSplitAxes.size()) <=
88 splitTensorDim) {
89 tgtShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
90 }
91 auto tgtSplitAxes =
92 llvm::to_vector(tgtShardingSplitAxes[splitTensorDim].asArrayRef());
93 tgtSplitAxes.push_back(splitGridAxis);
94 tgtShardingSplitAxes[splitTensorDim] = GridAxesAttr::get(ctx, tgtSplitAxes);
95 return Sharding::get(srcSharding.getGridAttr(), tgtShardingSplitAxes);
96 }
97
98 // Split a replicated tensor along a grid axis.
99 // E.g. [[0, 1]] -> [[0, 1, 2]].
100 // Returns the partitioned target value with its sharding.
101 static std::tuple<TypedValue<ShapedType>, Sharding>
102 apply(ImplicitLocOpBuilder &builder, Sharding srcSharding,
103 TypedValue<ShapedType> srcShard, GridOp grid, int64_t splitTensorDim,
104 GridAxis splitGridAxis) {
105 TypedValue<ShapedType> tgtShard =
106 AllSliceOp::create(builder, srcShard, grid,
107 ArrayRef<GridAxis>(splitGridAxis), splitTensorDim)
108 .getResult();
109 Sharding resultSharding =
110 tgtSharding(builder.getContext(), std::move(srcSharding),
111 splitTensorDim, splitGridAxis);
112 return {tgtShard, resultSharding};
113 }
114
115 // Detect if the resharding is of type e.g.
116 // [[0, 1]] -> [[0, 1, 2]].
117 // If detected, returns the corresponding grid axis.
118 // Does not detect insertions like
119 // [[0, 1]] -> [[0, 2, 1]].
120 static std::optional<GridAxis> detect(const Sharding &srcSharding,
121 const Sharding &tgtSharding,
122 int64_t tensorDim) {
123 if (static_cast<size_t>(tensorDim) >= tgtSharding.getSplitAxes().size())
124 return std::nullopt;
125 auto tgtAxes = tgtSharding.getSplitAxes()[tensorDim].asArrayRef();
126 if (srcSharding.getSplitAxes().size() > static_cast<size_t>(tensorDim)) {
127 auto srcAxes = srcSharding.getSplitAxes()[tensorDim].asArrayRef();
128 if (srcAxes.size() + 1 != tgtAxes.size())
129 return std::nullopt;
130 if (!llvm::equal(srcAxes,
131 llvm::make_range(tgtAxes.begin(), tgtAxes.end() - 1)))
132 return std::nullopt;
133 } else {
134 if (tgtAxes.size() != 1)
135 return std::nullopt;
136 }
137 return tgtAxes.back();
138 }
139
140public:
141 std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
142 tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
143 const Sharding &srcSharding, const Sharding &tgtSharding,
144 ShapedType srcUnshardedType,
145 TypedValue<ShapedType> srcShard) override {
146 if (hasStaticOffsetsOrHalos(srcSharding, tgtSharding))
147 return std::nullopt;
148 if (auto gridAxis = detect(srcSharding, tgtSharding, tensorDim))
149 return apply(builder, srcSharding, srcShard, grid, tensorDim,
150 gridAxis.value());
151 return std::nullopt;
152 }
153};
154
155/// Unsplit trailing axes: e.g. [[0, 1, 2]] -> [[0, 1]] or [[0, 1, 2]] -> [].
157 // Detect if the resharding removes trailing split axes along a tensor
158 // dimension, e.g.
159 // [[0, 1, 2]] -> [[0, 1]], [[0, 1, 2]] -> [0] or [[0, 1, 2]] -> [].
160 // If detected, returns the removed trailing split axes (grid axes).
161 static std::optional<SmallVector<GridAxis>>
162 detect(const Sharding &srcSharding, const Sharding &tgtSharding,
163 int64_t tensorDim) {
164 if (static_cast<size_t>(tensorDim) >= srcSharding.getSplitAxes().size())
165 return std::nullopt;
166 size_t dimOff = 0;
167 auto srcSplitAxes = srcSharding.getSplitAxes()[tensorDim].asArrayRef();
168 if (tgtSharding.getSplitAxes().size() > static_cast<size_t>(tensorDim)) {
169 auto tgtSplitAxes = tgtSharding.getSplitAxes()[tensorDim].asArrayRef();
170 // No match if the target sharding does not have less split axes than
171 // the source sharding along the current tensor dimension.
172 if (srcSplitAxes.size() <= tgtSplitAxes.size())
173 return std::nullopt;
174 // No match if the split axes of the target sharding are different from
175 // the first split axes of the source sharding.
176 if (!std::equal(tgtSplitAxes.begin(), tgtSplitAxes.end(),
177 srcSplitAxes.begin()))
178 return std::nullopt;
179 dimOff = tgtSplitAxes.size();
180 } else {
181 // Here the target dimension is replicated; there is nothing to do if
182 // the source dimension is also replicated.
183 if (srcSplitAxes.size() == 0)
184 return std::nullopt;
185 dimOff = 0;
186 }
187 // This is a match. Return the trailing grid axes of the source sharding
188 // along this dimension.
189 ArrayRef<GridAxis> trailingAxes = srcSplitAxes.drop_front(dimOff);
190 SmallVector<GridAxis> unsplitAxes(trailingAxes.begin(), trailingAxes.end());
191 return unsplitAxes;
192 }
193
194 // Return the resulting Sharding if the unsplit last axes resharding is
195 // applied.
196 static Sharding tgtSharding(MLIRContext *ctx, const Sharding &srcSharding,
197 int64_t splitTensorDim, size_t numUnsplitAxes) {
198 SmallVector<GridAxesAttr> resSplitAxes =
199 llvm::to_vector(srcSharding.getSplitAxes());
200 assert(static_cast<int64_t>(resSplitAxes.size()) > splitTensorDim);
201 ArrayRef<GridAxis> srcSplitAxes = resSplitAxes[splitTensorDim].asArrayRef();
202 assert(srcSplitAxes.size() >= numUnsplitAxes);
203 size_t numSplitAxes = srcSplitAxes.size() - numUnsplitAxes;
204 SmallVector<GridAxis> newSplitAxes(srcSplitAxes.begin(),
205 srcSplitAxes.begin() + numSplitAxes);
206 resSplitAxes[splitTensorDim] = GridAxesAttr::get(ctx, newSplitAxes);
207 return Sharding::get(srcSharding.getGridAttr(), resSplitAxes);
208 }
209
210 // Return the resulting Tensor type after applying the unsplit last axes
211 // resharding.
212 static ShapedType allGatherResultType(ShapedType srcType,
213 int64_t splitTensorDim,
214 ArrayRef<int64_t> gridShape,
215 ArrayRef<GridAxis> unsplitAxes) {
216 SmallVector<int64_t> tgtShape = llvm::to_vector(srcType.getShape());
217 for (GridAxis gridAxis : unsplitAxes)
218 tgtShape[splitTensorDim] =
219 gatherDimension(tgtShape[splitTensorDim], gridShape[gridAxis]);
220 return srcType.cloneWith(tgtShape, srcType.getElementType());
221 }
222
223 // Perform the resharding for the unsplit last axes case.
224 // This basically performs an all-gather along the unsplit grid axes.
225 static std::tuple<TypedValue<ShapedType>, Sharding>
226 apply(ImplicitLocOpBuilder &builder, Sharding srcSharding,
227 ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard,
228 GridOp grid, int64_t splitTensorDim, ArrayRef<GridAxis> unsplitAxes) {
229 MLIRContext *ctx = builder.getContext();
230 builder.setInsertionPointAfterValue(srcShard);
231
232 Sharding resultSharding = tgtSharding(ctx, std::move(srcSharding),
233 splitTensorDim, unsplitAxes.size());
234 ShapedType agResultType = allGatherResultType(
235 srcShard.getType(), splitTensorDim, grid.getShape(), unsplitAxes);
236 Value allGatherResult = AllGatherOp::create(
237 builder,
238 RankedTensorType::get(agResultType.getShape(),
239 agResultType.getElementType()),
240 grid.getSymName(), unsplitAxes, srcShard, APInt(64, splitTensorDim));
241 ShapedType tgtType =
242 shardShapedType(srcUnshardedType, grid, resultSharding);
243 TypedValue<ShapedType> tgtShard =
244 tensor::CastOp::create(builder, tgtType, allGatherResult).getResult();
245 return {tgtShard, resultSharding};
246 }
247
248public:
249 std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
250 tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
251 const Sharding &srcSharding, const Sharding &tgtSharding,
252 ShapedType srcUnshardedType,
253 TypedValue<ShapedType> srcShard) override {
254 if (hasStaticOffsetsOrHalos(srcSharding, tgtSharding))
255 return std::nullopt;
256 if (auto gridAxes = detect(srcSharding, tgtSharding, tensorDim))
257 return apply(builder, srcSharding, srcUnshardedType, srcShard, grid,
258 tensorDim, gridAxes.value());
259 return std::nullopt;
260 }
261};
262
263/// Move a split axis between tensor dimensions:
264/// e.g. [[0], []] -> [[], [0]].
266 // Detect if the resharding moves a single split axis from one tensor
267 // dimension to another tensor dimension. If detected, returns the
268 // corresponding (tgt_tensor_dim, grid_axis) pair.
269 static std::optional<std::tuple<int64_t, GridAxis>>
270 detect(const Sharding &srcSharding, const Sharding &tgtSharding,
271 int64_t srcTensorDim) {
272 if (static_cast<size_t>(srcTensorDim) >= srcSharding.getSplitAxes().size())
273 return std::nullopt;
274 auto srcAxes = srcSharding.getSplitAxes()[srcTensorDim].asArrayRef();
275 if (srcAxes.size() != 1)
276 return std::nullopt;
277 for (size_t tgtTensorDim = 0;
278 tgtTensorDim < tgtSharding.getSplitAxes().size(); ++tgtTensorDim) {
279 if (static_cast<int64_t>(tgtTensorDim) == srcTensorDim)
280 continue;
281 auto tgtAxes = tgtSharding.getSplitAxes()[tgtTensorDim].asArrayRef();
282 if (tgtAxes.size() != 1 || srcAxes.front() != tgtAxes.front())
283 continue;
284 return std::make_tuple(static_cast<int64_t>(tgtTensorDim),
285 srcAxes.front());
286 }
287 return std::nullopt;
288 }
289
290 static Sharding tgtSharding(MLIRContext *ctx, const Sharding &srcSharding,
291 int64_t srcTensorDim, int64_t tgtTensorDim) {
292 SmallVector<GridAxesAttr> tgtShardingSplitAxes =
293 llvm::to_vector(srcSharding.getSplitAxes());
294 while (static_cast<int64_t>(tgtShardingSplitAxes.size()) <= tgtTensorDim) {
295 tgtShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
296 }
297
298 auto srcSplitAxes =
299 llvm::to_vector(tgtShardingSplitAxes[srcTensorDim].asArrayRef());
300 assert(srcSplitAxes.size() == 1);
301 auto gridAxis = srcSplitAxes.back();
302 srcSplitAxes.pop_back();
303 tgtShardingSplitAxes[srcTensorDim] = GridAxesAttr::get(ctx, srcSplitAxes);
304
305 auto tgtSplitAxes =
306 llvm::to_vector(tgtShardingSplitAxes[tgtTensorDim].asArrayRef());
307 tgtSplitAxes.push_back(gridAxis);
308 tgtShardingSplitAxes[tgtTensorDim] = GridAxesAttr::get(ctx, tgtSplitAxes);
309
310 return Sharding::get(srcSharding.getGridAttr(), tgtShardingSplitAxes);
311 }
312
313 static ShapedType allToAllResultShape(ShapedType srcShape, int64_t splitCount,
314 int64_t srcTensorDim,
315 int64_t tgtTensorDim) {
316 SmallVector<int64_t> tgtShape = llvm::to_vector(srcShape.getShape());
317 tgtShape[srcTensorDim] =
318 gatherDimension(tgtShape[srcTensorDim], splitCount);
319 tgtShape[tgtTensorDim] = shardDimension(tgtShape[tgtTensorDim], splitCount);
320 return srcShape.cloneWith(tgtShape, srcShape.getElementType());
321 }
322
323 static std::tuple<TypedValue<ShapedType>, Sharding>
324 apply(ImplicitLocOpBuilder &builder, GridOp grid, Sharding srcSharding,
325 ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard,
326 int64_t srcTensorDim, int64_t tgtTensorDim, GridAxis gridAxis) {
327 MLIRContext *ctx = builder.getContext();
328 builder.setInsertionPointAfterValue(srcShard);
329
330 Sharding resultSharding =
331 tgtSharding(ctx, std::move(srcSharding), srcTensorDim, tgtTensorDim);
332 ShapedType a2aResultShape =
333 allToAllResultShape(srcShard.getType(), grid.getShape()[gridAxis],
334 srcTensorDim, tgtTensorDim);
335 Value allToAllResult = AllToAllOp::create(
336 builder,
337 RankedTensorType::get(a2aResultShape.getShape(),
338 a2aResultShape.getElementType()),
339 grid.getSymName(), SmallVector<GridAxis>({gridAxis}), srcShard,
340 APInt(64, tgtTensorDim), APInt(64, srcTensorDim));
341 ShapedType tgtShape =
342 shardShapedType(srcUnshardedType, grid, resultSharding);
343 TypedValue<ShapedType> tgtShard =
344 tensor::CastOp::create(builder, tgtShape, allToAllResult).getResult();
345 return {tgtShard, resultSharding};
346 }
347
348public:
349 std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
350 tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
351 const Sharding &srcSharding, const Sharding &tgtSharding,
352 ShapedType srcUnshardedType,
353 TypedValue<ShapedType> srcShard) override {
354 if (hasStaticOffsetsOrHalos(srcSharding, tgtSharding))
355 return std::nullopt;
356 if (auto detectRes = detect(srcSharding, tgtSharding, tensorDim)) {
357 auto [tgtTensorDim, gridAxis] = detectRes.value();
358 return apply(builder, grid, srcSharding, srcUnshardedType, srcShard,
359 tensorDim, tgtTensorDim, gridAxis);
360 }
361 return std::nullopt;
362 }
363};
364
365/// Update halo sizes: handles cases where only the halo sizes differ between
366/// source and target sharding. Requires copying the "core" of the source tensor
367/// into the "core" of the destination tensor followed by an update halo op.
369public:
370 std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
371 tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
372 const Sharding &srcSharding, const Sharding &tgtSharding,
373 ShapedType srcUnshardedType,
374 TypedValue<ShapedType> srcShard) override {
375 // UpdateHaloPattern handles all dimensions at once; only trigger on dim 0.
376 if (tensorDim != 0)
377 return std::nullopt;
378 // Currently handles only cases where halo sizes differ but everything else
379 // stays the same (from source to destination sharding).
380 if (!srcSharding.equalSplitAxes(tgtSharding) ||
381 hasStaticOffsets(srcSharding, tgtSharding) ||
382 srcSharding.equalHaloSizes(tgtSharding)) {
383 return std::nullopt;
384 }
385
386 auto srcHaloSizes = srcSharding.getStaticHaloSizes();
387 auto tgtHaloSizes = tgtSharding.getStaticHaloSizes();
388 assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
389 assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
390 ShapedType::isStaticShape(tgtHaloSizes) &&
391 srcShard.getType().hasStaticShape()) &&
392 "dynamic shapes/halos are not supported yet for shard-partition");
393 auto rank = srcShard.getType().getRank();
394 auto splitAxes = srcSharding.getSplitAxes();
395 SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
396 strides(rank, 1), outShape(srcShard.getType().getShape()),
397 coreShape(srcShard.getType().getShape());
398
399 // Determine "core" of source and destination.
400 // The core is the local part of the shard excluding halo regions.
401 for (auto i = 0u; i < rank; ++i) {
402 if (i < splitAxes.size() && !splitAxes[i].empty()) {
403 if (!srcHaloSizes.empty()) {
404 coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
405 srcCoreOffs[i] = srcHaloSizes[i * 2];
406 }
407 tgtCoreOffs[i] = tgtHaloSizes[i * 2];
408 outShape[i] =
409 coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
410 }
411 }
412
413 // Extract core from source and copy into destination core.
414 auto noVals = ValueRange{};
415 auto initVal = tensor::EmptyOp::create(builder, srcShard.getLoc(), outShape,
416 srcShard.getType().getElementType());
417 auto core = tensor::ExtractSliceOp::create(
418 builder, srcShard.getLoc(),
419 RankedTensorType::get(coreShape, srcShard.getType().getElementType()),
420 srcShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
421 auto initOprnd = tensor::InsertSliceOp::create(
422 builder, srcShard.getLoc(), core, initVal, noVals, noVals, noVals,
423 tgtCoreOffs, coreShape, strides);
424
425 // Finally update the halo.
426 auto updateHaloResult =
427 UpdateHaloOp::create(builder, srcShard.getLoc(),
428 RankedTensorType::get(
429 outShape, srcShard.getType().getElementType()),
430 initOprnd, grid.getSymName(),
431 GridAxesArrayAttr::get(builder.getContext(),
432 srcSharding.getSplitAxes()),
433 tgtSharding.getDynamicHaloSizes(),
434 tgtSharding.getStaticHaloSizes())
435 .getResult();
436 return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
437 tgtSharding);
438 }
439};
440
441// In most cases the sharded tensor axes must be exactly divisible by the single
442// grid axis size. Only halo size changes can deal with non-divisible cases.
444 GridOp grid, const Sharding &srcSharding,
445 const Sharding &tgtSharding,
446 TypedValue<ShapedType> unshardedSrc,
447 TypedValue<ShapedType> shardedSrc) {
448 // If source and destination sharding are the same, no need to do anything.
449 if (srcSharding == tgtSharding ||
450 (isFullReplication(srcSharding) && isFullReplication(tgtSharding))) {
451 return shardedSrc;
452 }
453
454 assert(shardedSrc.getType() ==
455 shardShapedType(unshardedSrc.getType(), grid, srcSharding));
456 [[maybe_unused]] ShapedType tgtShardType =
457 shardShapedType(unshardedSrc.getType(), grid, tgtSharding);
458 assert(shardedSrc.getType().getRank() == tgtShardType.getRank());
459 assert(unshardedSrc.getType().getRank() == tgtShardType.getRank());
460
461 // Each pattern's tryApply checks its own applicability preconditions.
462 static UpdateHaloPattern updateHaloPattern;
463 static MoveSplitAxisPattern moveSplitAxisPattern;
464 static SplitLastAxisPattern splitLastAxisPattern;
465 static UnsplitLastAxesPattern unsplitLastAxesPattern;
466 static ReshardingPattern *patterns[] = {
467 &updateHaloPattern, &moveSplitAxisPattern, &splitLastAxisPattern,
468 &unsplitLastAxesPattern};
469 TypedValue<ShapedType> currentShard = shardedSrc;
470 Sharding currentSharding = srcSharding;
471 for (int64_t dim = 0;
472 dim < tgtShardType.getRank() && currentSharding != tgtSharding; ++dim) {
473 for (auto &pattern : patterns) {
474 if (auto tryRes = pattern->tryApply(builder, grid, dim, currentSharding,
475 tgtSharding, unshardedSrc.getType(),
476 currentShard)) {
477 std::tie(currentShard, currentSharding) = tryRes.value();
478 break;
479 }
480 }
481 }
482
483 if (currentSharding != tgtSharding ||
484 currentShard.getType() != tgtShardType) {
485 builder.emitError()
486 << "Failed to reshard; probably hitting an unknown resharding pattern:"
487 << " got " << currentSharding << " expected " << tgtSharding
488 << " got type " << currentShard.getType() << " expected "
489 << tgtShardType;
490 return TypedValue<ShapedType>();
491 }
492 return currentShard;
493}
494
496 ShardOp srcShardOp, ShardOp tgtShardOp,
497 TypedValue<ShapedType> shardedSrc) {
498 assert(srcShardOp.getResult() == tgtShardOp.getSrc());
499 auto srcSharding = srcShardOp.getSharding();
500 auto tgtSharding = tgtShardOp.getSharding();
501 ImplicitLocOpBuilder implicitLocOpBuilder(tgtShardOp->getLoc(), builder);
502 return reshard(implicitLocOpBuilder, grid, srcSharding, tgtSharding,
503 srcShardOp.getSrc(), shardedSrc);
504}
505
506TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp srcShardOp,
507 ShardOp tgtShardOp,
508 TypedValue<ShapedType> shardedSrc,
509 SymbolTableCollection &symbolTableCollection) {
510 GridOp srcGrid = getGrid(srcShardOp, symbolTableCollection);
511 assert(srcGrid && srcGrid == getGrid(tgtShardOp, symbolTableCollection));
512 return reshard(builder, srcGrid, srcShardOp, tgtShardOp, shardedSrc);
513}
514
// Registers the dialects that resharding/partitioning creates operations
// from (shard ops plus tensor.cast / tensor.empty / slice ops used above).
// NOTE(review): the enclosing function's declaration line is not visible in
// this listing.
516 registry.insert<shard::ShardDialect, tensor::TensorDialect>();
517}
518
519#define GEN_PASS_DEF_PARTITION
520#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
521
523
524// Get the types of block arguments for an partitioned block.
525// Reads the sharding annotations of the arguments to deduce the sharded types.
526// Types that are not ranked tensors are left unchanged.
529 SymbolTableCollection &symbolTableCollection) {
531 llvm::transform(
532 block.getArguments(), std::back_inserter(res),
533 [&symbolTableCollection](BlockArgument arg) {
534 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
535 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
536 rankedTensorArg.use_empty()) {
537 return arg.getType();
538 }
539
540 assert(rankedTensorArg.hasOneUse());
541 Operation *useOp = *rankedTensorArg.getUsers().begin();
542 ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
543 assert(shardOp);
544 GridOp grid = getGrid(shardOp, symbolTableCollection);
545 return cast<Type>(shardShapedType(rankedTensorArg.getType(), grid,
546 shardOp.getSharding()));
547 });
548 return res;
549}
550
551static LogicalResult
553 ArrayRef<Sharding> operandShardings,
554 ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
555 SymbolTableCollection &symbolTableCollection,
556 OpBuilder &builder) {
557 ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
558 if (!shardingInterface) {
559 // If there is no sharding interface we are conservative and assume that
560 // the op should be fully replicated no all devices.
561 partitionFullyReplicatedOperation(op, partitionedOperands, operandShardings,
562 resultShardings, partitionMap,
563 symbolTableCollection, builder);
564 } else {
565 if (failed(shardingInterface.partition(
566 partitionedOperands, operandShardings, resultShardings,
567 partitionMap, symbolTableCollection, builder))) {
568 return failure();
569 }
570 }
571
572 assert(llvm::all_of(op.getResults(), [&partitionMap](OpResult result) {
573 return partitionMap.contains(result);
574 }));
575
576 return success();
577}
578
579// Retrieve the sharding annotations for the operands of the given operation.
580// If the type is not a ranked tensor it is not require to have an annotation.
581static std::vector<Sharding> getOperandShardings(Operation &op) {
582 std::vector<Sharding> res;
583 res.reserve(op.getNumOperands());
584 llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
585 TypedValue<RankedTensorType> rankedTensor =
586 dyn_cast<TypedValue<RankedTensorType>>(operand);
587 if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
588 return Sharding();
589 }
590
591 Operation *definingOp = operand.getDefiningOp();
592 assert(definingOp);
593 ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
594 return Sharding(shardOp.getSharding());
595 });
596 return res;
597}
598
599// Retrieve the sharding annotations for the results of the given operation.
600// If the type is not a ranked tensor it is not require to have an annotation.
601static std::vector<Sharding> getResultShardings(Operation &op) {
602 std::vector<Sharding> res;
603 res.reserve(op.getNumResults());
604 llvm::transform(
605 op.getResults(), std::back_inserter(res), [&op](OpResult result) {
606 if (!result.hasOneUse() || result.use_empty()) {
607 return Sharding();
608 }
609 TypedValue<RankedTensorType> rankedTensor =
611 if (!rankedTensor) {
612 return Sharding();
613 }
614 Operation *userOp = *result.getUsers().begin();
615 ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
616 if (shardOp) {
617 return Sharding(shardOp.getSharding());
618 }
619 if (rankedTensor.getType().getRank() == 0) {
620 // This is a 0d tensor result without explicit sharding.
621 // Find grid symbol from operands, if any.
622 // Shardings without grid are not always fully supported yet.
623 for (auto operand : op.getOperands()) {
624 if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
625 return Sharding(sharding.getGridAttr());
626 }
627 }
628 }
629 return Sharding();
630 });
631 return res;
632}
633
634static LogicalResult
635partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
636 SymbolTableCollection &symbolTableCollection,
637 OpBuilder &builder) {
638 Value tgtPartitionValue;
639
640 // Check if 2 shard ops are chained. If not there is no need for resharding
641 // as the source and target shared the same sharding.
642 ShardOp srcShardOp = shardOp.getSrc().getDefiningOp<ShardOp>();
643 if (!srcShardOp) {
644 tgtPartitionValue = partitionMap.lookup(shardOp.getSrc());
645 } else {
646 // Insert resharding.
647 TypedValue<ShapedType> shardedSrc =
648 cast<TypedValue<ShapedType>>(partitionMap.lookup(srcShardOp));
649 tgtPartitionValue = reshard(builder, srcShardOp, shardOp, shardedSrc,
650 symbolTableCollection);
651 if (!tgtPartitionValue) {
652 return shardOp.emitError()
653 << "Failed to reshard from " << srcShardOp.getSharding() << " to "
654 << shardOp.getSharding();
655 }
656 }
657
658 assert(!partitionMap.contains(shardOp.getResult()));
659 partitionMap.map(shardOp.getResult(), tgtPartitionValue);
660 return success();
661}
662
663// Check if the block args are correctly annotated with sharding information:
664// - non-tensor, 0d-tensor and unused args are ignored
665// - each tensor arg must have exactly one use, which must be a shard.shard
666// operation
667static LogicalResult checkFullyAnnotated(Block &block) {
668 for (const BlockArgument &arg : block.getArguments()) {
669 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
670 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
671 rankedTensorArg.use_empty())
672 continue;
673
674 if (!rankedTensorArg.hasOneUse())
675 return emitError(block.getParent()->getLoc())
676 << "Cannot partition: expected a single use for block argument "
677 << arg.getArgNumber() << " in block "
678 << block.computeBlockNumber();
679
680 Operation *useOp = *rankedTensorArg.getUsers().begin();
681 auto shardOp = dyn_cast<ShardOp>(useOp);
682 if (!shardOp)
683 return emitError(block.getParent()->getLoc())
684 << "Cannot partition: expected a shard.shard op for block "
685 << "argument " << arg.getArgNumber() << " in block "
686 << block.computeBlockNumber();
687 }
688 return success();
689}
690
691// Check if the operation is correctly and fully annotated with sharding
692// information:
693// - Operation results must have exactly one use (e.g. the shard operation).
694// - All operands and all results must be annotated, e.g. they must be
695// produced by/consumed by a shard.shard operation.
696// - Result annotations must not include the 'annotate_for_users' attribute.
697// - Operand annotations must include the 'annotate_for_users' attribute.
698// raises an error if the operation is not correctly and fully annotated.
699static LogicalResult checkFullyAnnotated(Operation *op) {
700 // constant ops do not need to have sharding annotations
702 return success();
703
704 for (OpOperand &operand : op->getOpOperands()) {
705 // non-tensor and 0d-tensor operands are ignored
706 auto rankedTT = dyn_cast<RankedTensorType>(operand.get().getType());
707 if (!rankedTT || rankedTT.getRank() == 0)
708 continue;
709
710 auto shard = operand.get().getDefiningOp<ShardOp>();
711 if (!shard)
712 return op->emitError() << "Cannot partition: tensor operand "
713 << operand.getOperandNumber()
714 << " must be defined by a shard.shard operation.";
715 if (!shard.getAnnotateForUsers())
716 return op->emitError()
717 << "Cannot partition: shard.shard for operand "
718 << operand.getOperandNumber() << " must set 'annotate_for_users'.";
719 }
720 for (const OpResult &result : op->getResults()) {
721 if (!result.hasOneUse())
722 return op->emitError()
723 << "Cannot partition: result " << result.getResultNumber()
724 << " must have exactly one use.";
725 auto shard = dyn_cast<ShardOp>(*result.user_begin());
726 if (!shard)
727 return op->emitError()
728 << "Cannot partition: user of result " << result.getResultNumber()
729 << " must be shard.shard operation.";
730 if (shard.getAnnotateForUsers())
731 return op->emitError() << "Cannot partition: shard.shard for result "
732 << result.getResultNumber()
733 << " must not set 'annotate_for_users'.";
734 }
735 return success();
736}
737
738static LogicalResult
740 SymbolTableCollection &symbolTableCollection,
741 OpBuilder &builder) {
742 if (isa<ShardingOp>(op)) {
743 return success();
744 }
745
746 if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
747 auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
748 if (!shardOp) {
749 return op.emitError("expected a shard op as source of get_sharding");
750 }
751 auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
752 partitionMap.map(op.getResult(0), newSharding->getResult(0));
753 return success();
754 }
755
756 ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
757 if (shardOp) {
758 return partitionOperation(shardOp, partitionMap, symbolTableCollection,
759 builder);
760 }
761
762 // Check if operation is correctly and fully annotated.
763 if (failed(checkFullyAnnotated(&op)))
764 return failure();
765
766 SmallVector<Value> partitionedOperands;
767 llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
768 [&partitionMap](Value operand) {
769 assert(partitionMap.contains(operand));
770 return partitionMap.lookup(operand);
771 });
772 return partitionOperation(op, partitionedOperands, getOperandShardings(op),
773 getResultShardings(op), partitionMap,
774 symbolTableCollection, builder);
775}
776
777static LogicalResult
778partitionBlock(Block &block, IRMapping &partitionMap,
779 SymbolTableCollection &symbolTableCollection,
780 OpBuilder &builder) {
781
782 if (failed(checkFullyAnnotated(block)))
783 return failure();
784
785 SmallVector<Location> argLocations;
786 llvm::transform(block.getArguments(), std::back_inserter(argLocations),
787 [](BlockArgument arg) { return arg.getLoc(); });
788 Block *newBlock = builder.createBlock(
789 block.getParent(), {},
790 shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
791 for (auto [unshardedBlockArg, partitionedBlockArg] :
792 llvm::zip(block.getArguments(), newBlock->getArguments())) {
793 partitionMap.map(unshardedBlockArg, partitionedBlockArg);
794 }
795
796 OpBuilder::InsertionGuard insertionGuard(builder);
797 builder.setInsertionPointToEnd(newBlock);
798 for (Operation &op : block.getOperations()) {
799 if (failed(partitionOperation(op, partitionMap, symbolTableCollection,
800 builder))) {
801 return failure();
802 }
803 }
804
805 return success();
806}
807
808static LogicalResult
809partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap,
810 SymbolTableCollection &symbolTableCollection) {
811 OpBuilder builder(op.getFunctionBody());
812
813 // Snapshot the original blocks to not mess up the iteration when adding new
814 // blocks.
815 SmallVector<Block *> originalBlocks;
816 for (Block &b : op.getBlocks()) {
817 if (llvm::any_of(b.getOperations(),
818 [](Operation &op) { return isa<ShardOp>(op); })) {
819 originalBlocks.push_back(&b);
820 }
821 }
822
823 for (Block *block : originalBlocks) {
824 if (failed(partitionBlock(*block, partitionMap, symbolTableCollection,
825 builder))) {
826 return failure();
827 }
828 }
829
830 for (Block *block : originalBlocks) {
831 block->erase();
832 }
833
834 // Find a return op and change the function results signature to its operands
835 // signature.
836 Operation *returnOp = nullptr;
837 for (Block &block : op.getFunctionBody()) {
838 if (block.empty()) {
839 continue;
840 }
841
842 if (block.back().hasTrait<OpTrait::ReturnLike>()) {
843 returnOp = &block.back();
844 break;
845 }
846 }
847 if (returnOp) {
848 op.setType(FunctionType::get(
849 op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
850 returnOp->getOperandTypes()));
851 }
852
853 return success();
854}
855
856namespace {

// Pass implementation: partitions each function in the module according to
// its shard.shard annotations via partitionFuncOp.
858struct Partition : public impl::PartitionBase<Partition> {
// Run partitioning on the function; any failure aborts the pass.
859 void runOnOperation() override {
860 IRMapping partitionMap;
861 SymbolTableCollection symbolTableCollection;
862 if (failed(partitionFuncOp(getOperation(), partitionMap,
863 symbolTableCollection))) {
864 return signalPassFailure();
865 }
866 }
867
868 void getDependentDialects(DialectRegistry &registry) const override {
// NOTE(review): this listing appears to drop original line 869 here —
// presumably a call registering further dependent dialects (e.g. the
// resharding registration helper defined above); verify against upstream.
870 registry.insert<shard::ShardDialect>();
871 }
872};

874} // namespace
875
876} // namespace mlir::shard
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
BlockArgListType getArguments()
Definition Block.h:97
unsigned computeBlockNumber()
Compute the position of this block within its parent region using an O(N) linear scan.
Definition Block.cpp:144
MLIRContext * getContext() const
Definition Builders.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition IRMapping.h:51
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:665
mlir::InFlightDiagnostic emitError(const llvm::Twine &message=llvm::Twine())
This builder can also be used to emit diagnostics to the current location.
Definition Builders.h:703
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:423
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for a sub-set of ops that are known to be constant-like.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:778
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:412
unsigned getNumOperands()
Definition Operation.h:375
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
Definition Operation.h:426
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
user_range getUsers()
Returns a range of all users.
Definition Operation.h:902
result_range getResults()
Definition Operation.h:444
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
Location getLoc()
Return a location for this region.
Definition Region.cpp:31
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
Move a split axis between tensor dimensions: e.g.
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
Base class for resharding patterns.
Definition Partition.cpp:51
virtual std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard)=0
Try to apply this resharding pattern.
static bool hasStaticOffsetsOrHalos(const Sharding &srcSharding, const Sharding &tgtSharding)
Returns true if either sharding has non-empty static sharded dims offsets or non-empty static halo si...
Definition Partition.cpp:73
static bool hasStaticOffsets(const Sharding &srcSharding, const Sharding &tgtSharding)
Returns true if either sharding has non-empty static sharded dims offsets.
Definition Partition.cpp:65
virtual ~ReshardingPattern()=default
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition ShardOps.cpp:792
bool equalSplitAxes(const Sharding &rhs) const
Definition ShardOps.cpp:688
ArrayRef< int64_t > getStaticHaloSizes() const
Definition ShardOps.h:64
::mlir::FlatSymbolRefAttr getGridAttr() const
Definition ShardOps.h:61
ArrayRef< Value > getDynamicHaloSizes() const
Definition ShardOps.h:68
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition ShardOps.h:65
ArrayRef< GridAxesAttr > getSplitAxes() const
Definition ShardOps.h:63
bool equalHaloSizes(const Sharding &rhs) const
Definition ShardOps.cpp:727
Split a replicated axis: e.g. [[0, 1]] -> [[0, 1, 2]].
Definition Partition.cpp:82
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
Unsplit trailing axes: e.g. [[0, 1, 2]] -> [[0, 1]] or [[0, 1, 2]] -> [].
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
Update halo sizes: handles cases where only the halo sizes differ between source and target sharding.
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:281
void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static LogicalResult partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)
static LogicalResult checkFullyAnnotated(Block &block)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
Definition Partition.cpp:42
bool isFullReplication(Sharding sharding)
Definition ShardOps.h:116
int16_t GridAxis
Definition ShardOps.h:27
static LogicalResult partitionBlock(Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::vector< Sharding > getOperandShardings(Operation &op)
DenseMap< Value, Value > UnshardedToShardedValueMap
static std::vector< Sharding > getResultShardings(Operation &op)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:178
TypedValue< ShapedType > reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
void reshardingRegisterDependentDialects(DialectRegistry &registry)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:131
static LogicalResult partitionOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:187
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
This trait indicates that a terminator operation is "return-like".