MLIR 23.0.0git
Partition.cpp
Go to the documentation of this file.
1//===- Partition.cpp --------------------------------------------- C++ --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15#include "mlir/IR/Builders.h"
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
22#include "mlir/IR/MLIRContext.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/IR/Value.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Support/LLVM.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Support/Casting.h"
32#include <array>
33#include <iterator>
34#include <memory>
35#include <optional>
36#include <tuple>
37#include <utility>
38
39namespace mlir::shard {
40
41template <typename SourceAxes, typename TargetAxes>
42static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
43 const TargetAxes &targetAxes) {
44 return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
45 return sourceAxes.contains(targetAxis);
46 });
47}
48
49/// Base class for resharding patterns.
50/// Subclasses implement `tryApply` to detect and apply a specific resharding.
52public:
53 virtual ~ReshardingPattern() = default;
54
55 /// Try to apply this resharding pattern. Returns the resharded value and
56 /// resulting sharding on success, or std::nullopt if the pattern doesn't
57 /// match.
58 virtual std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
59 tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
60 const Sharding &srcSharding, const Sharding &tgtSharding,
61 ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard) = 0;
62
63protected:
64 /// Returns true if either sharding has non-empty static sharded dims offsets.
65 static bool hasStaticOffsets(const Sharding &srcSharding,
66 const Sharding &tgtSharding) {
67 return !srcSharding.getStaticShardedDimsOffsets().empty() ||
68 !tgtSharding.getStaticShardedDimsOffsets().empty();
69 }
70
71 /// Returns true if either sharding has non-empty static sharded dims offsets
72 /// or non-empty static halo sizes.
73 static bool hasStaticOffsetsOrHalos(const Sharding &srcSharding,
74 const Sharding &tgtSharding) {
75 return hasStaticOffsets(srcSharding, tgtSharding) ||
76 !srcSharding.getStaticHaloSizes().empty() ||
77 !tgtSharding.getStaticHaloSizes().empty();
78 }
79};
80
81/// Split a replicated axis: e.g. [[0, 1]] -> [[0, 1, 2]].
83 static Sharding tgtSharding(MLIRContext *ctx, const Sharding &srcSharding,
84 int64_t splitTensorDim, GridAxis splitGridAxis) {
85 SmallVector<GridAxesAttr> tgtShardingSplitAxes =
86 llvm::to_vector(srcSharding.getSplitAxes());
87 while (static_cast<int64_t>(tgtShardingSplitAxes.size()) <=
88 splitTensorDim) {
89 tgtShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
90 }
91 auto tgtSplitAxes =
92 llvm::to_vector(tgtShardingSplitAxes[splitTensorDim].asArrayRef());
93 tgtSplitAxes.push_back(splitGridAxis);
94 tgtShardingSplitAxes[splitTensorDim] = GridAxesAttr::get(ctx, tgtSplitAxes);
95 return Sharding::get(srcSharding.getGridAttr(), tgtShardingSplitAxes);
96 }
97
98 // Split a replicated tensor along a grid axis.
99 // E.g. [[0, 1]] -> [[0, 1, 2]].
100 // Returns the partitioned target value with its sharding.
101 static std::tuple<TypedValue<ShapedType>, Sharding>
102 apply(ImplicitLocOpBuilder &builder, Sharding srcSharding,
103 TypedValue<ShapedType> srcShard, GridOp grid, int64_t splitTensorDim,
104 GridAxis splitGridAxis) {
105 TypedValue<ShapedType> tgtShard =
106 AllSliceOp::create(builder, srcShard, grid,
107 ArrayRef<GridAxis>(splitGridAxis), splitTensorDim)
108 .getResult();
109 Sharding resultSharding =
110 tgtSharding(builder.getContext(), std::move(srcSharding),
111 splitTensorDim, splitGridAxis);
112 return {tgtShard, resultSharding};
113 }
114
115 // Detect if the resharding is of type e.g.
116 // [[0, 1]] -> [[0, 1, 2]].
117 // If detected, returns the corresponding grid axis.
118 // Does not detect insertions like
119 // [[0, 1]] -> [[0, 2, 1]].
120 static std::optional<GridAxis> detect(const Sharding &srcSharding,
121 const Sharding &tgtSharding,
122 int64_t tensorDim) {
123 if (static_cast<size_t>(tensorDim) >= tgtSharding.getSplitAxes().size())
124 return std::nullopt;
125 auto tgtAxes = tgtSharding.getSplitAxes()[tensorDim].asArrayRef();
126 if (srcSharding.getSplitAxes().size() > static_cast<size_t>(tensorDim)) {
127 auto srcAxes = srcSharding.getSplitAxes()[tensorDim].asArrayRef();
128 if (srcAxes.size() + 1 != tgtAxes.size())
129 return std::nullopt;
130 if (!llvm::equal(srcAxes,
131 llvm::make_range(tgtAxes.begin(), tgtAxes.end() - 1)))
132 return std::nullopt;
133 } else {
134 if (tgtAxes.size() != 1)
135 return std::nullopt;
136 }
137 return tgtAxes.back();
138 }
139
140public:
141 std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
142 tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
143 const Sharding &srcSharding, const Sharding &tgtSharding,
144 ShapedType srcUnshardedType,
145 TypedValue<ShapedType> srcShard) override {
146 if (hasStaticOffsetsOrHalos(srcSharding, tgtSharding))
147 return std::nullopt;
148 if (auto gridAxis = detect(srcSharding, tgtSharding, tensorDim))
149 return apply(builder, srcSharding, srcShard, grid, tensorDim,
150 gridAxis.value());
151 return std::nullopt;
152 }
153};
154
155/// Unsplit trailing axes: e.g. [[0, 1, 2]] -> [[0, 1]] or [[0, 1, 2]] -> [].
157 // Detect if the resharding removes trailing split axes along a tensor
158 // dimension, e.g.
159 // [[0, 1, 2]] -> [[0, 1]], [[0, 1, 2]] -> [0] or [[0, 1, 2]] -> [].
160 // If detected, returns the removed trailing split axes (grid axes).
161 static std::optional<SmallVector<GridAxis>>
162 detect(const Sharding &srcSharding, const Sharding &tgtSharding,
163 int64_t tensorDim) {
164 if (static_cast<size_t>(tensorDim) >= srcSharding.getSplitAxes().size())
165 return std::nullopt;
166 size_t dimOff = 0;
167 auto srcSplitAxes = srcSharding.getSplitAxes()[tensorDim].asArrayRef();
168 if (tgtSharding.getSplitAxes().size() > static_cast<size_t>(tensorDim)) {
169 auto tgtSplitAxes = tgtSharding.getSplitAxes()[tensorDim].asArrayRef();
170 // No match if the target sharding does not have less split axes than
171 // the source sharding along the current tensor dimension.
172 if (srcSplitAxes.size() <= tgtSplitAxes.size())
173 return std::nullopt;
174 // No match if the split axes of the target sharding are different from
175 // the first split axes of the source sharding.
176 if (!std::equal(tgtSplitAxes.begin(), tgtSplitAxes.end(),
177 srcSplitAxes.begin()))
178 return std::nullopt;
179 dimOff = tgtSplitAxes.size();
180 } else {
181 // Here the target dimension is replicated; there is nothing to do if
182 // the source dimension is also replicated.
183 if (srcSplitAxes.size() == 0)
184 return std::nullopt;
185 dimOff = 0;
186 }
187 // This is a match. Return the trailing grid axes of the source sharding
188 // along this dimension.
189 ArrayRef<GridAxis> trailingAxes = srcSplitAxes.drop_front(dimOff);
190 SmallVector<GridAxis> unsplitAxes(trailingAxes.begin(), trailingAxes.end());
191 return unsplitAxes;
192 }
193
194 // Return the resulting Sharding if the unsplit last axes resharding is
195 // applied.
196 static Sharding tgtSharding(MLIRContext *ctx, const Sharding &srcSharding,
197 int64_t splitTensorDim, size_t numUnsplitAxes) {
198 SmallVector<GridAxesAttr> resSplitAxes =
199 llvm::to_vector(srcSharding.getSplitAxes());
200 assert(static_cast<int64_t>(resSplitAxes.size()) > splitTensorDim);
201 ArrayRef<GridAxis> srcSplitAxes = resSplitAxes[splitTensorDim].asArrayRef();
202 assert(srcSplitAxes.size() >= numUnsplitAxes);
203 size_t numSplitAxes = srcSplitAxes.size() - numUnsplitAxes;
204 SmallVector<GridAxis> newSplitAxes(srcSplitAxes.begin(),
205 srcSplitAxes.begin() + numSplitAxes);
206 resSplitAxes[splitTensorDim] = GridAxesAttr::get(ctx, newSplitAxes);
207 return Sharding::get(srcSharding.getGridAttr(), resSplitAxes);
208 }
209
210 // Return the resulting Tensor type after applying the unsplit last axes
211 // resharding.
212 static ShapedType allGatherResultType(ShapedType srcType,
213 int64_t splitTensorDim,
214 ArrayRef<int64_t> gridShape,
215 ArrayRef<GridAxis> unsplitAxes) {
216 SmallVector<int64_t> tgtShape = llvm::to_vector(srcType.getShape());
217 for (GridAxis gridAxis : unsplitAxes)
218 tgtShape[splitTensorDim] =
219 gatherDimension(tgtShape[splitTensorDim], gridShape[gridAxis]);
220 return srcType.cloneWith(tgtShape, srcType.getElementType());
221 }
222
223 // Perform the resharding for the unsplit last axes case.
224 // This basically performs an all-gather along the unsplit grid axes.
225 static std::tuple<TypedValue<ShapedType>, Sharding>
226 apply(ImplicitLocOpBuilder &builder, Sharding srcSharding,
227 ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard,
228 GridOp grid, int64_t splitTensorDim, ArrayRef<GridAxis> unsplitAxes) {
229 MLIRContext *ctx = builder.getContext();
230 builder.setInsertionPointAfterValue(srcShard);
231
232 Sharding resultSharding = tgtSharding(ctx, std::move(srcSharding),
233 splitTensorDim, unsplitAxes.size());
234 ShapedType agResultType = allGatherResultType(
235 srcShard.getType(), splitTensorDim, grid.getShape(), unsplitAxes);
236 Value allGatherResult = AllGatherOp::create(
237 builder,
238 RankedTensorType::get(agResultType.getShape(),
239 agResultType.getElementType()),
240 grid.getSymName(), unsplitAxes, srcShard, APInt(64, splitTensorDim));
241 ShapedType tgtType =
242 shardShapedType(srcUnshardedType, grid, resultSharding);
243 TypedValue<ShapedType> tgtShard =
244 tensor::CastOp::create(builder, tgtType, allGatherResult).getResult();
245 return {tgtShard, resultSharding};
246 }
247
248public:
249 std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
250 tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
251 const Sharding &srcSharding, const Sharding &tgtSharding,
252 ShapedType srcUnshardedType,
253 TypedValue<ShapedType> srcShard) override {
254 if (hasStaticOffsetsOrHalos(srcSharding, tgtSharding))
255 return std::nullopt;
256 if (auto gridAxes = detect(srcSharding, tgtSharding, tensorDim))
257 return apply(builder, srcSharding, srcUnshardedType, srcShard, grid,
258 tensorDim, gridAxes.value());
259 return std::nullopt;
260 }
261};
262
263// Compute the result shape of an all-to-all that gathers along srcTensorDim
264// and scatters along tgtTensorDim with the given split count.
265static ShapedType allToAllResultShape(ShapedType srcShape, int64_t splitCount,
266 int64_t srcTensorDim,
267 int64_t tgtTensorDim) {
268 SmallVector<int64_t> tgtShape = llvm::to_vector(srcShape.getShape());
269 tgtShape[srcTensorDim] = gatherDimension(tgtShape[srcTensorDim], splitCount);
270 tgtShape[tgtTensorDim] = shardDimension(tgtShape[tgtTensorDim], splitCount);
271 return srcShape.cloneWith(tgtShape, srcShape.getElementType());
272}
273
274/// Move the last split axis of one tensor dimension to the front of another
275/// tensor dimension's split axes, e.g. [[0], []] -> [[], [0]] or
276/// [[0, 1], [2]] -> [[0], [1, 2]].
278 // Detect if the resharding moves the last grid axis of srcTensorDim to the
279 // front of another tensor dimension's split axes. If detected, returns
280 // (tgtTensorDim, movedGridAxis).
281 //
282 // Pattern: src[srcTensorDim] = [a1,...,a(n-1),an] (n >= 1)
283 // tgt[srcTensorDim] = [a1,...,a(n-1)]
284 // src[tgtTensorDim] = [b1,...,bm] (m >= 0)
285 // tgt[tgtTensorDim] = [an, b1,...,bm]
286 static std::optional<std::tuple<int64_t, GridAxis>>
287 detect(const Sharding &srcSharding, const Sharding &tgtSharding,
288 int64_t srcTensorDim) {
289 if (static_cast<size_t>(srcTensorDim) >= srcSharding.getSplitAxes().size())
290 return std::nullopt;
291 auto srcAxes = srcSharding.getSplitAxes()[srcTensorDim].asArrayRef();
292 // Need at least 1 axis to move.
293 if (srcAxes.empty())
294 return std::nullopt;
295
296 // After the move the source tensor dim should lose its last axis.
297 if (static_cast<size_t>(srcTensorDim) >= tgtSharding.getSplitAxes().size())
298 return std::nullopt;
299 auto tgtSrcAxes = tgtSharding.getSplitAxes()[srcTensorDim].asArrayRef();
300 if (tgtSrcAxes.size() + 1 != srcAxes.size())
301 return std::nullopt;
302 // The remaining axes at srcTensorDim must be the same (prefix of source).
303 if (!llvm::equal(tgtSrcAxes,
304 llvm::make_range(srcAxes.begin(), srcAxes.end() - 1)))
305 return std::nullopt;
306
307 GridAxis movedAxis = srcAxes.back();
308
309 // Find a target tensor dimension whose split axes start with movedAxis
310 // and whose remaining axes match the source sharding at that dimension.
311 for (size_t tgtTensorDim = 0;
312 tgtTensorDim < tgtSharding.getSplitAxes().size(); ++tgtTensorDim) {
313 if (static_cast<int64_t>(tgtTensorDim) == srcTensorDim)
314 continue;
315 auto tgtAxes = tgtSharding.getSplitAxes()[tgtTensorDim].asArrayRef();
316 // The target dimension must start with the moved axis.
317 if (tgtAxes.empty() || tgtAxes.front() != movedAxis)
318 continue;
319 // The remainder of tgtAxes must equal the source sharding at
320 // tgtTensorDim.
321 ArrayRef<GridAxis> srcTgtAxes =
322 static_cast<size_t>(tgtTensorDim) < srcSharding.getSplitAxes().size()
323 ? srcSharding.getSplitAxes()[tgtTensorDim].asArrayRef()
325 if (!llvm::equal(srcTgtAxes,
326 llvm::make_range(tgtAxes.begin() + 1, tgtAxes.end())))
327 continue;
328 return std::make_tuple(static_cast<int64_t>(tgtTensorDim), movedAxis);
329 }
330 return std::nullopt;
331 }
332
333 // Compute the result sharding after moving movedAxis from srcTensorDim
334 // to the front of tgtTensorDim.
335 static Sharding tgtSharding(MLIRContext *ctx, const Sharding &srcSharding,
336 int64_t srcTensorDim, int64_t tgtTensorDim,
337 GridAxis movedAxis) {
338 SmallVector<GridAxesAttr> splitAxes =
339 llvm::to_vector(srcSharding.getSplitAxes());
340 while (static_cast<int64_t>(splitAxes.size()) <= tgtTensorDim)
341 splitAxes.push_back(GridAxesAttr::get(ctx, {}));
342
343 // Remove last axis from srcTensorDim.
344 auto srcSplitAxes = llvm::to_vector(splitAxes[srcTensorDim].asArrayRef());
345 assert(!srcSplitAxes.empty() && srcSplitAxes.back() == movedAxis);
346 srcSplitAxes.pop_back();
347 splitAxes[srcTensorDim] = GridAxesAttr::get(ctx, srcSplitAxes);
348
349 // Prepend movedAxis to tgtTensorDim.
350 auto tgtSplitAxes = llvm::to_vector(splitAxes[tgtTensorDim].asArrayRef());
351 tgtSplitAxes.insert(tgtSplitAxes.begin(), movedAxis);
352 splitAxes[tgtTensorDim] = GridAxesAttr::get(ctx, tgtSplitAxes);
353
354 return Sharding::get(srcSharding.getGridAttr(), splitAxes);
355 }
356
357 static std::tuple<TypedValue<ShapedType>, Sharding>
358 apply(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &srcSharding,
359 ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard,
360 int64_t srcTensorDim, int64_t tgtTensorDim, GridAxis movedAxis) {
361 MLIRContext *ctx = builder.getContext();
362 builder.setInsertionPointAfterValue(srcShard);
363
364 Sharding resultSharding =
365 tgtSharding(ctx, srcSharding, srcTensorDim, tgtTensorDim, movedAxis);
366 ShapedType a2aResultShape =
367 allToAllResultShape(srcShard.getType(), grid.getShape()[movedAxis],
368 srcTensorDim, tgtTensorDim);
369 Value allToAllResult = AllToAllOp::create(
370 builder,
371 RankedTensorType::get(a2aResultShape.getShape(),
372 a2aResultShape.getElementType()),
373 grid.getSymName(), SmallVector<GridAxis>({movedAxis}), srcShard,
374 APInt(64, tgtTensorDim), APInt(64, srcTensorDim));
375 ShapedType tgtShape =
376 shardShapedType(srcUnshardedType, grid, resultSharding);
377 TypedValue<ShapedType> tgtShard =
378 tensor::CastOp::create(builder, tgtShape, allToAllResult).getResult();
379 return {tgtShard, resultSharding};
380 }
381
382public:
383 std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
384 tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
385 const Sharding &srcSharding, const Sharding &tgtSharding,
386 ShapedType srcUnshardedType,
387 TypedValue<ShapedType> srcShard) override {
388 if (hasStaticOffsetsOrHalos(srcSharding, tgtSharding))
389 return std::nullopt;
390 if (auto detectRes = detect(srcSharding, tgtSharding, tensorDim)) {
391 auto [tgtTensorDim, movedAxis] = detectRes.value();
392 return apply(builder, grid, srcSharding, srcUnshardedType, srcShard,
393 tensorDim, tgtTensorDim, movedAxis);
394 }
395 return std::nullopt;
396 }
397};
398
399/// Update halo sizes: handles cases where only the halo sizes differ between
400/// source and target sharding. Requires copying the "core" of the source tensor
401/// into the "core" of the destination tensor followed by an update halo op.
403public:
  std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
  tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
           const Sharding &srcSharding, const Sharding &tgtSharding,
           ShapedType srcUnshardedType,
           TypedValue<ShapedType> srcShard) override {
    // UpdateHaloPattern handles all dimensions at once; only trigger on dim 0.
    if (tensorDim != 0)
      return std::nullopt;
    // Currently handles only cases where halo sizes differ but everything else
    // stays the same (from source to destination sharding).
    if (!srcSharding.equalSplitAxes(tgtSharding) ||
        hasStaticOffsets(srcSharding, tgtSharding) ||
        srcSharding.equalHaloSizes(tgtSharding)) {
      return std::nullopt;
    }

    // Halo sizes are stored with two entries per dimension (indices i*2 and
    // i*2+1 below — presumably the low/high sides; confirm against Sharding).
    auto srcHaloSizes = srcSharding.getStaticHaloSizes();
    auto tgtHaloSizes = tgtSharding.getStaticHaloSizes();
    assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
    assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
            ShapedType::isStaticShape(tgtHaloSizes) &&
            srcShard.getType().hasStaticShape()) &&
           "dynamic shapes/halos are not supported yet for shard-partition");
    auto rank = srcShard.getType().getRank();
    auto splitAxes = srcSharding.getSplitAxes();
    // srcCoreOffs/tgtCoreOffs: where the core starts inside the src/tgt shard;
    // coreShape: the core extent; outShape: core plus the target halos.
    SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
        strides(rank, 1), outShape(srcShard.getType().getShape()),
        coreShape(srcShard.getType().getShape());

    // Determine "core" of source and destination.
    // The core is the local part of the shard excluding halo regions.
    // Only sharded dimensions carry halos; replicated ones are untouched.
    for (auto i = 0u; i < rank; ++i) {
      if (i < splitAxes.size() && !splitAxes[i].empty()) {
        if (!srcHaloSizes.empty()) {
          coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
          srcCoreOffs[i] = srcHaloSizes[i * 2];
        }
        tgtCoreOffs[i] = tgtHaloSizes[i * 2];
        outShape[i] =
            coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
      }
    }

    // Extract core from source and copy into destination core.
    auto noVals = ValueRange{};
    auto initVal = tensor::EmptyOp::create(builder, srcShard.getLoc(), outShape,
                                           srcShard.getType().getElementType());
    auto core = tensor::ExtractSliceOp::create(
        builder, srcShard.getLoc(),
        RankedTensorType::get(coreShape, srcShard.getType().getElementType()),
        srcShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
    auto initOprnd = tensor::InsertSliceOp::create(
        builder, srcShard.getLoc(), core, initVal, noVals, noVals, noVals,
        tgtCoreOffs, coreShape, strides);

    // Finally update the halo: fill the target halo regions with neighbor
    // data via an update_halo op sized to the target halo configuration.
    auto updateHaloResult =
        UpdateHaloOp::create(builder, srcShard.getLoc(),
                             RankedTensorType::get(
                                 outShape, srcShard.getType().getElementType()),
                             initOprnd, grid.getSymName(),
                             GridAxesArrayAttr::get(builder.getContext(),
                                                    srcSharding.getSplitAxes()),
                             tgtSharding.getDynamicHaloSizes(),
                             tgtSharding.getStaticHaloSizes())
            .getResult();
    return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
                           tgtSharding);
  }
473};
474
475// In most cases the sharded tensor axes must be exactly divisible by the single
476// grid axis size. Only halo size changes can deal with non-divisible cases.
478 GridOp grid, const Sharding &srcSharding,
479 const Sharding &tgtSharding,
480 TypedValue<ShapedType> unshardedSrc,
481 TypedValue<ShapedType> shardedSrc) {
482 // If source and destination sharding are the same, no need to do anything.
483 if (srcSharding == tgtSharding ||
484 (isFullReplication(srcSharding) && isFullReplication(tgtSharding))) {
485 return shardedSrc;
486 }
487
488 assert(shardedSrc.getType() ==
489 shardShapedType(unshardedSrc.getType(), grid, srcSharding));
490 [[maybe_unused]] ShapedType tgtShardType =
491 shardShapedType(unshardedSrc.getType(), grid, tgtSharding);
492 assert(shardedSrc.getType().getRank() == tgtShardType.getRank());
493 assert(unshardedSrc.getType().getRank() == tgtShardType.getRank());
494
495 // Each pattern's tryApply checks its own applicability preconditions.
496 static UpdateHaloPattern updateHaloPattern;
497 static MoveLastSplitAxisPattern moveLastSplitAxisPattern;
498 static SplitLastAxisPattern splitLastAxisPattern;
499 static UnsplitLastAxesPattern unsplitLastAxesPattern;
500 static ReshardingPattern *patterns[] = {
501 &updateHaloPattern, &moveLastSplitAxisPattern, &splitLastAxisPattern,
502 &unsplitLastAxesPattern};
503 TypedValue<ShapedType> currentShard = shardedSrc;
504 Sharding currentSharding = srcSharding;
505 for (int64_t dim = 0;
506 dim < tgtShardType.getRank() && currentSharding != tgtSharding; ++dim) {
507 for (auto &pattern : patterns) {
508 if (auto tryRes = pattern->tryApply(builder, grid, dim, currentSharding,
509 tgtSharding, unshardedSrc.getType(),
510 currentShard)) {
511 std::tie(currentShard, currentSharding) = tryRes.value();
512 break;
513 }
514 }
515 }
516
517 if (currentSharding != tgtSharding ||
518 currentShard.getType() != tgtShardType) {
519 builder.emitError()
520 << "Failed to reshard; probably hitting an unknown resharding pattern:"
521 << " got " << currentSharding << " expected " << tgtSharding
522 << " got type " << currentShard.getType() << " expected "
523 << tgtShardType;
524 return TypedValue<ShapedType>();
525 }
526 return currentShard;
527}
528
530 ShardOp srcShardOp, ShardOp tgtShardOp,
531 TypedValue<ShapedType> shardedSrc) {
532 assert(srcShardOp.getResult() == tgtShardOp.getSrc());
533 auto srcSharding = srcShardOp.getSharding();
534 auto tgtSharding = tgtShardOp.getSharding();
535 ImplicitLocOpBuilder implicitLocOpBuilder(tgtShardOp->getLoc(), builder);
536 return reshard(implicitLocOpBuilder, grid, srcSharding, tgtSharding,
537 srcShardOp.getSrc(), shardedSrc);
538}
539
540TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp srcShardOp,
541 ShardOp tgtShardOp,
542 TypedValue<ShapedType> shardedSrc,
543 SymbolTableCollection &symbolTableCollection) {
544 GridOp srcGrid = getGrid(srcShardOp, symbolTableCollection);
545 assert(srcGrid && srcGrid == getGrid(tgtShardOp, symbolTableCollection));
546 return reshard(builder, srcGrid, srcShardOp, tgtShardOp, shardedSrc);
547}
548
550 registry.insert<shard::ShardDialect, tensor::TensorDialect>();
551}
552
553#define GEN_PASS_DEF_PARTITION
554#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
555
557
558// Get the types of block arguments for an partitioned block.
559// Reads the sharding annotations of the arguments to deduce the sharded types.
560// Types that are not ranked tensors are left unchanged.
563 SymbolTableCollection &symbolTableCollection) {
565 llvm::transform(
566 block.getArguments(), std::back_inserter(res),
567 [&symbolTableCollection](BlockArgument arg) {
568 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
569 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
570 rankedTensorArg.use_empty()) {
571 return arg.getType();
572 }
573
574 assert(rankedTensorArg.hasOneUse());
575 Operation *useOp = *rankedTensorArg.getUsers().begin();
576 ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
577 assert(shardOp);
578 GridOp grid = getGrid(shardOp, symbolTableCollection);
579 return cast<Type>(shardShapedType(rankedTensorArg.getType(), grid,
580 shardOp.getSharding()));
581 });
582 return res;
583}
584
585static LogicalResult
587 ArrayRef<Sharding> operandShardings,
588 ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
589 SymbolTableCollection &symbolTableCollection,
590 OpBuilder &builder) {
591 ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
592 if (!shardingInterface) {
593 // If there is no sharding interface we are conservative and assume that
594 // the op should be fully replicated no all devices.
595 partitionFullyReplicatedOperation(op, partitionedOperands, operandShardings,
596 resultShardings, partitionMap,
597 symbolTableCollection, builder);
598 } else {
599 if (failed(shardingInterface.partition(
600 partitionedOperands, operandShardings, resultShardings,
601 partitionMap, symbolTableCollection, builder))) {
602 return failure();
603 }
604 }
605
606 assert(llvm::all_of(op.getResults(), [&partitionMap](OpResult result) {
607 return partitionMap.contains(result);
608 }));
609
610 return success();
611}
612
613// Retrieve the sharding annotations for the operands of the given operation.
614// If the type is not a ranked tensor it is not require to have an annotation.
615static std::vector<Sharding> getOperandShardings(Operation &op) {
616 std::vector<Sharding> res;
617 res.reserve(op.getNumOperands());
618 llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
619 TypedValue<RankedTensorType> rankedTensor =
620 dyn_cast<TypedValue<RankedTensorType>>(operand);
621 if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
622 return Sharding();
623 }
624
625 Operation *definingOp = operand.getDefiningOp();
626 assert(definingOp);
627 ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
628 return Sharding(shardOp.getSharding());
629 });
630 return res;
631}
632
633// Retrieve the sharding annotations for the results of the given operation.
634// If the type is not a ranked tensor it is not require to have an annotation.
635static std::vector<Sharding> getResultShardings(Operation &op) {
636 std::vector<Sharding> res;
637 res.reserve(op.getNumResults());
638 llvm::transform(
639 op.getResults(), std::back_inserter(res), [&op](OpResult result) {
640 if (!result.hasOneUse() || result.use_empty()) {
641 return Sharding();
642 }
643 TypedValue<RankedTensorType> rankedTensor =
645 if (!rankedTensor) {
646 return Sharding();
647 }
648 Operation *userOp = *result.getUsers().begin();
649 ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
650 if (shardOp) {
651 return Sharding(shardOp.getSharding());
652 }
653 if (rankedTensor.getType().getRank() == 0) {
654 // This is a 0d tensor result without explicit sharding.
655 // Find grid symbol from operands, if any.
656 // Shardings without grid are not always fully supported yet.
657 for (auto operand : op.getOperands()) {
658 if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
659 return Sharding(sharding.getGridAttr());
660 }
661 }
662 }
663 return Sharding();
664 });
665 return res;
666}
667
668static LogicalResult
669partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
670 SymbolTableCollection &symbolTableCollection,
671 OpBuilder &builder) {
672 Value tgtPartitionValue;
673
674 // Check if 2 shard ops are chained. If not there is no need for resharding
675 // as the source and target shared the same sharding.
676 ShardOp srcShardOp = shardOp.getSrc().getDefiningOp<ShardOp>();
677 if (!srcShardOp) {
678 tgtPartitionValue = partitionMap.lookup(shardOp.getSrc());
679 } else {
680 // Insert resharding.
681 TypedValue<ShapedType> shardedSrc =
682 cast<TypedValue<ShapedType>>(partitionMap.lookup(srcShardOp));
683 tgtPartitionValue = reshard(builder, srcShardOp, shardOp, shardedSrc,
684 symbolTableCollection);
685 if (!tgtPartitionValue) {
686 return shardOp.emitError()
687 << "Failed to reshard from " << srcShardOp.getSharding() << " to "
688 << shardOp.getSharding();
689 }
690 }
691
692 assert(!partitionMap.contains(shardOp.getResult()));
693 partitionMap.map(shardOp.getResult(), tgtPartitionValue);
694 return success();
695}
696
697// Check if the block args are correctly annotated with sharding information:
698// - non-tensor, 0d-tensor and unused args are ignored
699// - each tensor arg must have exactly one use, which must be a shard.shard
700// operation
701static LogicalResult checkFullyAnnotated(Block &block) {
702 for (const BlockArgument &arg : block.getArguments()) {
703 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
704 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
705 rankedTensorArg.use_empty())
706 continue;
707
708 if (!rankedTensorArg.hasOneUse())
709 return emitError(block.getParent()->getLoc())
710 << "Cannot partition: expected a single use for block argument "
711 << arg.getArgNumber() << " in block "
712 << block.computeBlockNumber();
713
714 Operation *useOp = *rankedTensorArg.getUsers().begin();
715 auto shardOp = dyn_cast<ShardOp>(useOp);
716 if (!shardOp)
717 return emitError(block.getParent()->getLoc())
718 << "Cannot partition: expected a shard.shard op for block "
719 << "argument " << arg.getArgNumber() << " in block "
720 << block.computeBlockNumber();
721 }
722 return success();
723}
724
725// Check if the operation is correctly and fully annotated with sharding
726// information:
727// - Operation results must have exactly one use (e.g. the shard operation).
728// - All operands and all results must be annotated, e.g. they must be
729// produced by/consumed by a shard.shard operation.
730// - Result annotations must not include the 'annotate_for_users' attribute.
731// - Operand annotations must include the 'annotate_for_users' attribute.
732// raises an error if the operation is not correctly and fully annotated.
733static LogicalResult checkFullyAnnotated(Operation *op) {
734 // constant ops do not need to have sharding annotations
736 return success();
737
738 for (OpOperand &operand : op->getOpOperands()) {
739 // non-tensor and 0d-tensor operands are ignored
740 auto rankedTT = dyn_cast<RankedTensorType>(operand.get().getType());
741 if (!rankedTT || rankedTT.getRank() == 0)
742 continue;
743
744 auto shard = operand.get().getDefiningOp<ShardOp>();
745 if (!shard)
746 return op->emitError() << "Cannot partition: tensor operand "
747 << operand.getOperandNumber()
748 << " must be defined by a shard.shard operation.";
749 if (!shard.getAnnotateForUsers())
750 return op->emitError()
751 << "Cannot partition: shard.shard for operand "
752 << operand.getOperandNumber() << " must set 'annotate_for_users'.";
753 }
754 for (const OpResult &result : op->getResults()) {
755 if (!result.hasOneUse())
756 return op->emitError()
757 << "Cannot partition: result " << result.getResultNumber()
758 << " must have exactly one use.";
759 auto shard = dyn_cast<ShardOp>(*result.user_begin());
760 if (!shard)
761 return op->emitError()
762 << "Cannot partition: user of result " << result.getResultNumber()
763 << " must be shard.shard operation.";
764 if (shard.getAnnotateForUsers())
765 return op->emitError() << "Cannot partition: shard.shard for result "
766 << result.getResultNumber()
767 << " must not set 'annotate_for_users'.";
768 }
769 return success();
770}
771
772static LogicalResult
774 SymbolTableCollection &symbolTableCollection,
775 OpBuilder &builder) {
776 if (isa<ShardingOp>(op)) {
777 return success();
778 }
779
780 if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
781 auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
782 if (!shardOp) {
783 return op.emitError("expected a shard op as source of get_sharding");
784 }
785 auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
786 partitionMap.map(op.getResult(0), newSharding->getResult(0));
787 return success();
788 }
789
790 ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
791 if (shardOp) {
792 return partitionOperation(shardOp, partitionMap, symbolTableCollection,
793 builder);
794 }
795
796 // Check if operation is correctly and fully annotated.
797 if (failed(checkFullyAnnotated(&op)))
798 return failure();
799
800 SmallVector<Value> partitionedOperands;
801 llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
802 [&partitionMap](Value operand) {
803 assert(partitionMap.contains(operand));
804 return partitionMap.lookup(operand);
805 });
806 return partitionOperation(op, partitionedOperands, getOperandShardings(op),
807 getResultShardings(op), partitionMap,
808 symbolTableCollection, builder);
809}
810
811static LogicalResult
812partitionBlock(Block &block, IRMapping &partitionMap,
813 SymbolTableCollection &symbolTableCollection,
814 OpBuilder &builder) {
815
816 if (failed(checkFullyAnnotated(block)))
817 return failure();
818
819 SmallVector<Location> argLocations;
820 llvm::transform(block.getArguments(), std::back_inserter(argLocations),
821 [](BlockArgument arg) { return arg.getLoc(); });
822 Block *newBlock = builder.createBlock(
823 block.getParent(), {},
824 shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
825 for (auto [unshardedBlockArg, partitionedBlockArg] :
826 llvm::zip(block.getArguments(), newBlock->getArguments())) {
827 partitionMap.map(unshardedBlockArg, partitionedBlockArg);
828 }
829
830 OpBuilder::InsertionGuard insertionGuard(builder);
831 builder.setInsertionPointToEnd(newBlock);
832 for (Operation &op : block.getOperations()) {
833 if (failed(partitionOperation(op, partitionMap, symbolTableCollection,
834 builder))) {
835 return failure();
836 }
837 }
838
839 return success();
840}
841
842static LogicalResult
843partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap,
844 SymbolTableCollection &symbolTableCollection) {
845 OpBuilder builder(op.getFunctionBody());
846
847 // Snapshot the original blocks to not mess up the iteration when adding new
848 // blocks.
849 SmallVector<Block *> originalBlocks;
850 for (Block &b : op.getBlocks()) {
851 if (llvm::any_of(b.getOperations(),
852 [](Operation &op) { return isa<ShardOp>(op); })) {
853 originalBlocks.push_back(&b);
854 }
855 }
856
857 for (Block *block : originalBlocks) {
858 if (failed(partitionBlock(*block, partitionMap, symbolTableCollection,
859 builder))) {
860 return failure();
861 }
862 }
863
864 for (Block *block : originalBlocks) {
865 block->erase();
866 }
867
868 // Find a return op and change the function results signature to its operands
869 // signature.
870 Operation *returnOp = nullptr;
871 for (Block &block : op.getFunctionBody()) {
872 if (block.empty()) {
873 continue;
874 }
875
876 if (block.back().hasTrait<OpTrait::ReturnLike>()) {
877 returnOp = &block.back();
878 break;
879 }
880 }
881 if (returnOp) {
882 op.setType(FunctionType::get(
883 op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
884 returnOp->getOperandTypes()));
885 }
886
887 return success();
888}
889
890namespace {
891
892struct Partition : public impl::PartitionBase<Partition> {
893 void runOnOperation() override {
894 IRMapping partitionMap;
895 SymbolTableCollection symbolTableCollection;
896 if (failed(partitionFuncOp(getOperation(), partitionMap,
897 symbolTableCollection))) {
898 return signalPassFailure();
899 }
900 }
901
902 void getDependentDialects(DialectRegistry &registry) const override {
904 registry.insert<shard::ShardDialect>();
905 }
906};
907
908} // namespace
909
910} // namespace mlir::shard
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
BlockArgListType getArguments()
Definition Block.h:97
unsigned computeBlockNumber()
Compute the position of this block within its parent region using an O(N) linear scan.
Definition Block.cpp:144
MLIRContext * getContext() const
Definition Builders.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition IRMapping.h:51
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:665
mlir::InFlightDiagnostic emitError(const llvm::Twine &message=llvm::Twine())
This builder can also be used to emit diagnostics to the current location.
Definition Builders.h:703
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:423
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
This class provides the API for a sub-set of ops that are known to be constant-like.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
unsigned getNumOperands()
Definition Operation.h:372
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
Definition Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
user_range getUsers()
Returns a range of all users.
Definition Operation.h:899
result_range getResults()
Definition Operation.h:441
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
Location getLoc()
Return a location for this region.
Definition Region.cpp:31
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
Move the last split axis of one tensor dimension to the front of another tensor dimension's split axe...
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
Base class for resharding patterns.
Definition Partition.cpp:51
virtual std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard)=0
Try to apply this resharding pattern.
static bool hasStaticOffsetsOrHalos(const Sharding &srcSharding, const Sharding &tgtSharding)
Returns true if either sharding has non-empty static sharded dims offsets or non-empty static halo si...
Definition Partition.cpp:73
static bool hasStaticOffsets(const Sharding &srcSharding, const Sharding &tgtSharding)
Returns true if either sharding has non-empty static sharded dims offsets.
Definition Partition.cpp:65
virtual ~ReshardingPattern()=default
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition ShardOps.cpp:797
bool equalSplitAxes(const Sharding &rhs) const
Definition ShardOps.cpp:693
ArrayRef< int64_t > getStaticHaloSizes() const
Definition ShardOps.h:64
::mlir::FlatSymbolRefAttr getGridAttr() const
Definition ShardOps.h:61
ArrayRef< Value > getDynamicHaloSizes() const
Definition ShardOps.h:68
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition ShardOps.h:65
ArrayRef< GridAxesAttr > getSplitAxes() const
Definition ShardOps.h:63
bool equalHaloSizes(const Sharding &rhs) const
Definition ShardOps.cpp:732
Split a replicated axis: e.g. [[0, 1]] -> [[0, 1, 2]].
Definition Partition.cpp:82
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
Unsplit trailing axes: e.g. [[0, 1, 2]] -> [[0, 1]] or [[0, 1, 2]] -> [].
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
Update halo sizes: handles cases where only the halo sizes differ between source and target sharding.
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:281
void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static LogicalResult partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)
static ShapedType allToAllResultShape(ShapedType srcShape, int64_t splitCount, int64_t srcTensorDim, int64_t tgtTensorDim)
static LogicalResult checkFullyAnnotated(Block &block)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
Definition Partition.cpp:42
bool isFullReplication(Sharding sharding)
Definition ShardOps.h:116
int16_t GridAxis
Definition ShardOps.h:27
static LogicalResult partitionBlock(Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::vector< Sharding > getOperandShardings(Operation &op)
DenseMap< Value, Value > UnshardedToShardedValueMap
static std::vector< Sharding > getResultShardings(Operation &op)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:178
TypedValue< ShapedType > reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
void reshardingRegisterDependentDialects(DialectRegistry &registry)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:131
static LogicalResult partitionOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:187
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
This trait indicates that a terminator operation is "return-like".