MLIR 22.0.0git
ShardOps.cpp
Go to the documentation of this file.
1//===- ShardOps.cpp - Shard Dialect Operations ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
14#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
24#include "mlir/IR/Value.h"
26#include "mlir/Support/LLVM.h"
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include <algorithm>
34#include <functional>
35#include <iterator>
36#include <numeric>
37#include <optional>
38#include <utility>
39
40#define DEBUG_TYPE "shard-ops"
41
42using namespace mlir;
43using namespace mlir::shard;
44
45#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
46
47namespace {
48
49struct DimensionSize {
50 static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
51 DimensionSize(int64_t val) : val(val) {}
52 int64_t value() const { return val; }
53 operator int64_t() const { return val; }
54 bool isDynamic() const { return ShapedType::isDynamic(val); }
55
56private:
57 int64_t val;
58};
59
60} // namespace
61
62static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
63 if (lhs.isDynamic() || rhs.isDynamic()) {
64 return DimensionSize::dynamic();
65 }
66 return lhs.value() / rhs.value();
67}
68
69static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
70 if (lhs.isDynamic() || rhs.isDynamic()) {
71 return DimensionSize::dynamic();
72 }
73 return lhs.value() * rhs.value();
74}
75
// NOTE(review): the listing capture dropped this function's leading signature
// lines (original lines 76-78); the surviving code is kept verbatim below.
// The helper materializes a mixed static/dynamic size list as SSA values:
// each static entry becomes an arith constant of `type` (i64 when `type` is
// null), each kDynamic entry consumes the next value from `dynamics`.
// Fix: the assert message said "intex" instead of "index".
79                                        ValueRange dynamics, Type type) {
80  SmallVector<Value> values;
81  auto dyn = dynamics.begin();
82  Type i64 = b.getI64Type();
83  if (!type)
84    type = i64;
85  assert((i64 == type || b.getIndexType() == type) &&
86         "expected an i64 or an index type");
87  for (auto s : statics) {
88    if (s == ShapedType::kDynamic) {
89      values.emplace_back(*(dyn++));
90    } else {
91      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
92      values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
93    }
94  }
95  return values;
96}
97
98//===----------------------------------------------------------------------===//
99// Inliner
100//===----------------------------------------------------------------------===//
101
102namespace {
// Inliner interface for the shard dialect: all inlining is permitted.
103struct ShardInlinerinterface : public DialectInlinerInterface {
// NOTE(review): the listing capture dropped original line 104; by MLIR
// convention it is `using DialectInlinerInterface::DialectInlinerInterface;`
// — confirm against the upstream source.
105  // Currently no restrictions are encoded for inlining.
106  bool isLegalToInline(Operation *, Operation *, bool) const final {
107    return true;
108  }
109  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
110    return true;
111  }
112  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
113    return true;
114  }
115};
116} // namespace
117
118//===----------------------------------------------------------------------===//
119// Shard dialect
120//===----------------------------------------------------------------------===//
121
122void ShardDialect::initialize() {
// Register the TableGen-generated operations, attributes, and types.
123  addOperations<
124#define GET_OP_LIST
125#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
126      >();
127  addAttributes<
128#define GET_ATTRDEF_LIST
129#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
130      >();
131  addTypes<
132#define GET_TYPEDEF_LIST
133#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
134      >();
// Permit the inliner to inline shard dialect constructs (see interface above).
135  addInterface<ShardInlinerinterface>();
136}
137
138Operation *ShardDialect::materializeConstant(OpBuilder &builder,
139 Attribute value, Type type,
140 Location loc) {
141 return arith::ConstantOp::materialize(builder, value, type, loc);
142}
143
144//===----------------------------------------------------------------------===//
145// Shard utilities
146//===----------------------------------------------------------------------===//
147
148static FailureOr<GridOp> getGridAndVerify(Operation *op,
149 FlatSymbolRefAttr gridSymbol,
150 SymbolTableCollection &symbolTable) {
151 shard::GridOp grid = getGridOrNull(op, gridSymbol, symbolTable);
152 if (!grid) {
153 return op->emitError() << "Undefined required grid symbol \""
154 << gridSymbol.getValue() << "\".";
155 }
156
157 return grid;
158}
159
/// Returns true iff no two *consecutive* elements in [begin, end) compare
/// equal; for a sorted range this means all elements are distinct.
/// Empty and single-element ranges are trivially unique.
template <typename It>
static bool isUnique(It begin, It end) {
  // std::adjacent_find locates the first pair of equal neighbors; absence of
  // such a pair is exactly the property the hand-rolled loop checked.
  return std::adjacent_find(begin, end) == end;
}
176
177static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes,
178 GridOp grid) {
179 SmallVector<GridAxis> sorted = llvm::to_vector(axes);
180 llvm::sort(sorted);
181 if (!isUnique(sorted.begin(), sorted.end())) {
182 return emitError(loc) << "Grid axes contains duplicate elements.";
183 }
184
185 GridAxis rank = grid.getRank();
186 for (auto axis : axes) {
187 if (axis >= rank || axis < 0) {
188 return emitError(loc)
189 << "0-based grid axis index " << axis
190 << " is out of bounds. The referenced grid \"" << grid.getSymName()
191 << "\" is of rank " << rank << ".";
192 }
193 }
194
195 return success();
196}
197
198template <typename Op>
199static FailureOr<GridOp>
// NOTE(review): the listing capture dropped original line 200, which carries
// this helper's name and parameter list (an op exposing getGridAttr()/
// getGridAxes() plus a SymbolTableCollection) — confirm against upstream.
// Resolves the op's grid symbol and validates its grid-axes attribute.
201  auto grid =
202      ::getGridAndVerify(op.getOperation(), op.getGridAttr(), symbolTable);
203  if (failed(grid)) {
204    return failure();
205  }
206  if (failed(verifyGridAxes(op.getLoc(), op.getGridAxes(), grid.value()))) {
207    return failure();
208  }
209  return grid;
210}
211
// Computes the per-shard shape `outShape` from the global `inShape` given the
// grid shape, the per-tensor-axis split axes, and optionally either explicit
// per-shard offsets (`shardedDimsOffsets`) or halo sizes (`haloSizes`).
212template <typename InShape, typename GridShape, typename SplitAxes,
213          typename OutShape>
214static void shardShape(const InShape &inShape, const GridShape &gridShape,
215                       const SplitAxes &splitAxes, OutShape &outShape,
216                       ArrayRef<int64_t> shardedDimsOffsets = {},
217                       ArrayRef<int64_t> haloSizes = {}) {
218  // 0d tensors cannot be sharded and must get replicated
219  if (inShape.empty()) {
220    assert(outShape.empty());
221    return;
222  }
223
// Start from the unsharded shape; sharded axes are overwritten below.
224  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
225            llvm::adl_begin(outShape));
226
227  if (!shardedDimsOffsets.empty()) {
228    auto isDynShape = ShapedType::isDynamicShape(gridShape);
// `pos` walks the flattened offsets table; each sharded tensor axis consumes
// (numShards + 1) entries (fencepost offsets).
229    uint64_t pos = 1;
230    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
231      if (!innerSplitAxes.empty()) {
232        auto sz = shardedDimsOffsets[pos];
233        bool same = !isDynShape;
234        if (same) {
235          // Find sharded dims in shardedDimsOffsets with same static size on
236          // all devices. Use kDynamic for dimensions with dynamic or
237          // non-uniform offs in shardedDimsOffsets.
238          uint64_t numShards = 0;
239          for (auto i : innerSplitAxes.asArrayRef()) {
240            numShards += gridShape[i];
241          }
242          for (size_t i = 1; i < numShards; ++i) {
243            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
244                sz) {
245              same = false;
246              break;
247            }
248          }
249          pos += numShards + 1;
250        }
251        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
252      }
253    }
254  } else {
// Uniform sharding: divide each sharded axis by its process-group size.
255    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
256      outShape[tensorAxis] = shardDimension(
257          inShape[tensorAxis],
258          collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), gridShape));
259    }
260
261    if (!haloSizes.empty()) {
262      // add halo sizes if requested
// `haloAxis` indexes (lo, hi) pairs, one per sharded tensor axis.
263      int haloAxis = 0;
264      for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
265        if (ShapedType::isStatic(outShape[tensorAxis]) &&
266            !innerSplitAxes.empty()) {
267          if (haloSizes[haloAxis * 2] >= 0 &&
268              haloSizes[haloAxis * 2 + 1] >= 0) {
269            outShape[tensorAxis] +=
270                haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
271            ++haloAxis;
272          } else {
// Negative halo entries make the resulting extent unknown statically.
273            outShape[tensorAxis] = ShapedType::kDynamic;
274          }
275        }
276      }
277    }
278  }
279}
280
281ShapedType shard::shardShapedType(ShapedType shape, GridOp grid,
282 Sharding sharding) {
283 using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
284 SmallVector<Dim> resShapeArr(shape.getShape().size());
285 shardShape(shape.getShape(), grid.getShape(), sharding.getSplitAxes(),
286 resShapeArr, sharding.getStaticShardedDimsOffsets(),
287 sharding.getStaticHaloSizes());
288 return shape.clone(resShapeArr);
289}
290
291Type shard::shardType(Type type, GridOp grid, Sharding sharding) {
292 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
293 if (rankedTensorType && !rankedTensorType.getShape().empty()) {
294 return shardShapedType(rankedTensorType, grid, sharding);
295 }
296 return type;
297}
298
// NOTE(review): the listing capture dropped original line 299 with this
// helper's name and first parameter (presumably
// `static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding,`)
// — confirm against upstream. It annotates one use of `operandValue` with
// `sharding`, reusing/creating shard ops as needed.
300                                                 Value &operandValue,
301                                                 Operation *operandOp,
302                                                 OpBuilder &builder,
303                                                 ShardOp &newShardOp) {
304  OpBuilder::InsertionGuard insertionGuard(builder);
305  builder.setInsertionPointAfterValue(operandValue);
306  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
307  if (shardOp && sharding == shardOp.getSharding() &&
308      !shardOp.getAnnotateForUsers()) {
309    // No need for anything if the correct sharding is already set.
310    if (!newShardOp) {
311      newShardOp = shardOp;
312    }
313    return;
314  }
315
// Lazily create one shard op shared across all uses handled by the caller.
316  if (!newShardOp) {
317    auto shardingOp =
318        ShardingOp::create(builder, operandValue.getLoc(), sharding);
319    newShardOp = ShardOp::create(builder, operandValue.getLoc(), operandValue,
320                                 shardingOp,
321                                 /*annotate_for_users*/ false);
322  }
// Redirect only this specific (owner, operand) use to the new shard op.
323  operandValue.replaceUsesWithIf(
324      newShardOp, [operandOp, operandValue](OpOperand &use) {
325        return use.getOwner() == operandOp && use.get() == operandValue;
326      });
327
328  if (!shardOp || shardOp.getAnnotateForUsers()) {
329    return;
330  }
331
// Chain a user-facing annotation after the producer-facing one.
332  auto newShardOp2 = ShardOp::create(builder, operandValue.getLoc(), newShardOp,
333                                     newShardOp.getSharding(),
334                                     /*annotate_for_users*/ true);
335  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
336}
337
// NOTE(review): the listing capture dropped original lines 338-339 (this
// function's signature, taking a Sharding and an OpResult) and line 342
// (declaring the `uses` vector of (value, owner-op) pairs) — confirm
// against upstream. It annotates every use of `result` with the sharding.
340                                                OpBuilder &builder) {
341  ShardOp newShardOp;
// Snapshot the uses first: the impl call below rewrites the use list.
343  for (auto &use : result.getUses()) {
344    uses.emplace_back(use.get(), use.getOwner());
345  }
346  for (auto &[operandValue, operandOp] : uses) {
347    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
348                                            builder, newShardOp);
349  }
350}
351
// NOTE(review): the listing capture dropped original line 352 with this
// function's name and first parameter (presumably
// `void mlir::shard::maybeInsertSourceShardingAnnotation(Sharding sharding,`)
// — confirm against upstream. It inserts a user-facing shard annotation on
// `operand`, plus a producer-facing one when resharding is required.
353                                                OpOperand &operand,
354                                                OpBuilder &builder) {
355  OpBuilder::InsertionGuard insertionGuard(builder);
356  Value operandValue = operand.get();
357  Operation *operandSrcOp = operandValue.getDefiningOp();
358  bool isBlockArg = !operandSrcOp;
359  {
// Only 0-rank tensors may carry a full-replication sharding here.
360    [[maybe_unused]] auto opType =
361        dyn_cast<mlir::RankedTensorType>(operandValue.getType());
362    assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
363  }
// Non-tensor constants are not annotated.
364  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
365      operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
366    return;
367  }
368
369  Operation *operandOp = operand.getOwner();
370  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
371
372  if (shardOp && sharding == shardOp.getSharding() &&
373      shardOp.getAnnotateForUsers()) {
374    // No need for anything; the correct sharding is already set.
375    return;
376  }
377
378  builder.setInsertionPoint(operandOp);
379  auto shardingOp =
380      ShardingOp::create(builder, operand.get().getLoc(), sharding);
381  auto newShardOp =
382      ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
383                      /*annotate_for_users*/ true);
// Redirect only this specific use to the new user-facing annotation.
384  IRRewriter rewriter(builder);
385  rewriter.replaceUsesWithIf(
386      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
387        return use.getOwner() == operandOp && use.get() == operandValue;
388      });
389
390  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
391    // No need for resharding.
392    return;
393  }
394
// Resharding: feed the user-facing annotation from a fresh producer-facing one.
395  builder.setInsertionPoint(newShardOp);
396  auto newPreceedingShardOp =
397      ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
398                      /*annotate_for_users*/ false);
399  rewriter.replaceUsesWithIf(
400      newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
401        return use.getOwner() == newShardOp.getOperation();
402      });
403}
404
405//===----------------------------------------------------------------------===//
406// shard.grid op
407//===----------------------------------------------------------------------===//
408
409LogicalResult GridOp::verify() {
410 int64_t rank = getRank();
411
412 if (rank <= 0)
413 return emitOpError("rank of grid is expected to be a positive integer");
414
415 for (int64_t dimSize : getShape()) {
416 if (dimSize < 0 && ShapedType::isStatic(dimSize))
417 return emitOpError("dimension size of a grid is expected to be "
418 "non-negative or dynamic");
419 }
420
421 return success();
422}
423
424//===----------------------------------------------------------------------===//
425// shard.grid_shape op
426//===----------------------------------------------------------------------===//
427
428LogicalResult
429GridShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
430 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
431 if (failed(grid)) {
432 return failure();
433 }
434 if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
435 return failure();
436 }
437
438 size_t expectedResultsCount =
439 getAxes().empty() ? grid->getRank() : getAxes().size();
440 if (getResult().size() != expectedResultsCount) {
441 return emitError() << "Unexpected number of results " << getResult().size()
442 << ". Expected " << expectedResultsCount << ".";
443 }
444
445 return success();
446}
447
448void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
449 GridOp grid) {
450 build(odsBuilder, odsState, grid, SmallVector<GridAxis>());
451}
452
453void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
454 GridOp grid, ArrayRef<GridAxis> axes) {
455 build(odsBuilder, odsState,
456 SmallVector<Type>(axes.empty() ? grid.getRank() : axes.size(),
457 odsBuilder.getIndexType()),
458 grid.getSymName(), GridAxesAttr::get(odsBuilder.getContext(), axes));
459}
460
461void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
462 StringRef grid, ArrayRef<GridAxis> axes) {
463 assert(!axes.empty());
464 build(odsBuilder, odsState,
465 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
466 GridAxesAttr::get(odsBuilder.getContext(), axes));
467}
468
469void GridShapeOp::getAsmResultNames(
470 function_ref<void(Value, StringRef)> setNameFn) {
471 setNameFn(getResults()[0], "grid_shape");
472}
473
474//===----------------------------------------------------------------------===//
475// shard.sharding
476//===----------------------------------------------------------------------===//
477
// Builds a sharding with static-only halo sizes and sharded-dims offsets.
478void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
// NOTE(review): the listing capture dropped original line 479, the grid
// parameter (presumably `FlatSymbolRefAttr grid,`) — confirm against upstream.
480                       ArrayRef<GridAxesAttr> split_axes,
481                       ArrayRef<int64_t> static_halos,
482                       ArrayRef<int64_t> static_offsets) {
// The `{}` arguments are the (empty) dynamic halo/offset operand lists.
483  return build(
484      b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
485      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
486      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
487}
488
489void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
490 llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes,
491 ArrayRef<int64_t> static_halos,
492 ArrayRef<int64_t> static_offsets) {
493 return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
494 GridAxesArrayAttr::get(b.getContext(), split_axes),
495 ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
496 ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
497 {});
498}
499
// Builds a sharding from mixed (attribute-or-value) halo sizes and offsets.
500void ShardingOp::build(
// NOTE(review): the listing capture dropped original lines 501-503 — the
// builder/state, grid, split_axes, and `halo_sizes` parameters referenced
// below — confirm against upstream.
504    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
// Split each mixed list into a static attribute part and dynamic operands.
505  mlir::SmallVector<int64_t> staticHalos, staticDims;
506  mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
507  dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
508  dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
509  return build(
510      b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
511      ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
512      ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
513}
514
// Builds a sharding op replicating the contents of a `Sharding` value.
515void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
// NOTE(review): the listing capture dropped original lines 516 (the `from`
// parameter), 521, 523, and 525 (the `? DenseI64ArrayAttr()` /
// dynamic-offsets arms of the conditional arguments) — confirm against
// upstream before relying on this listing.
517
518  build(b, odsState, ShardingType::get(b.getContext()), from.getGridAttr(),
519        GridAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
// Empty static lists are encoded as null attributes rather than empty arrays.
520        from.getStaticShardedDimsOffsets().empty()
522            : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
524        from.getStaticHaloSizes().empty()
526            : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
527        from.getDynamicHaloSizes());
528}
529
530LogicalResult ShardingOp::verify() {
531 llvm::SmallSet<GridAxis, 4> visitedAxes;
532
533 auto checkGridAxis = [&](ArrayRef<GridAxis> axesArray) -> LogicalResult {
534 for (GridAxis axis : axesArray) {
535 if (axis < 0)
536 return emitError() << "grid axis is expected to be non-negative";
537 if (!visitedAxes.insert(axis).second)
538 return emitError() << "grid axis duplicated";
539 }
540 return success();
541 };
542
543 for (auto subAxes : getSplitAxes().getAxes()) {
544 ArrayRef<GridAxis> subAxesArray = subAxes.asArrayRef();
545 if (failed(checkGridAxis(subAxesArray)))
546 return failure();
547 }
548
549 if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
550 return emitOpError("halo sizes and shard offsets are mutually exclusive");
551 }
552
553 if (!getStaticHaloSizes().empty()) {
554 auto numSplitAxes = getSplitAxes().getAxes().size();
555 for (auto splitAxis : getSplitAxes().getAxes()) {
556 if (splitAxis.empty()) {
557 --numSplitAxes;
558 }
559 }
560 if (getStaticHaloSizes().size() != numSplitAxes * 2) {
561 return emitError() << "halo sizes must be specified for all split axes.";
562 }
563 }
564
565 return success();
566}
567
568void ShardingOp::getAsmResultNames(
569 function_ref<void(Value, StringRef)> setNameFn) {
570 setNameFn(getResult(), "sharding");
571}
572
// Validates the referenced grid and, if present, the sharded-dims offsets
// table (requires a fully static grid; offsets must be in-bounds and
// non-decreasing per sharded tensor axis).
573LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
574  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
575  if (failed(grid)) {
576    return failure();
577  }
578  if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
579      !getStaticShardedDimsOffsets().empty()) {
580    return emitError() << "sharded dims offsets are not allowed for "
581                          "device grids with dynamic shape.";
582  }
583
584  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
585  if (!shardedDimsOffsets.empty()) {
586    auto gridShape = grid.value().getShape();
587    assert(ShapedType::isStaticShape(gridShape));
// `pos` walks the flattened table; each sharded tensor axis uses
// (numShards + 1) fencepost offsets.
588    uint64_t pos = 0;
589    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
590      if (!innerSplitAxes.empty()) {
591        int64_t numShards = 0, off = 0;
592        for (auto i : innerSplitAxes.asArrayRef()) {
593          numShards += gridShape[i];
594        }
595        for (int64_t i = 0; i <= numShards; ++i) {
596          if (shardedDimsOffsets.size() <= pos + i) {
597            return emitError() << "sharded dims offsets has wrong size.";
598          }
// Only static offsets can be checked for monotonicity.
599          if (ShapedType::isStatic(shardedDimsOffsets[pos + i])) {
600            if (shardedDimsOffsets[pos + i] < off) {
601              return emitError()
602                     << "sharded dims offsets must be non-decreasing.";
603            }
604            off = shardedDimsOffsets[pos + i];
605          }
606        }
607        pos += numShards + 1;
608      }
609    }
610  }
611  return success();
}
613
614namespace {
615// Sharding annotations "halo sizes" and "sharded dims offsets"
616// are a mix of attributes and dynamic values. This canonicalization moves
617// constant values to the respective attribute lists, minimizing the number
618// of values.
619// It also removes sharded_dims_sizes and halos if they are effectively "empty".
620class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
621public:
622  using OpRewritePattern<ShardingOp>::OpRewritePattern;
623
624  LogicalResult matchAndRewrite(ShardingOp op,
625                                PatternRewriter &b) const override {
// Combine static attributes and dynamic operands into single mixed lists.
626    auto mixedHalos =
627        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
628    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
629                                    op.getDynamicShardedDimsOffsets(), b);
630
631    // Fold constant dynamic entries into the static side; `modified` records
631    // whether either list changed.
632    bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
633                    succeeded(foldDynamicIndexList(mixedOffs, true));
634
635    auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
636    auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);
637
// An all-zero, static-only halo list carries no information: drop it.
638    if (dynamicHalos.empty() && !staticHalos.empty()) {
639      if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
640        staticHalos.clear();
641        modified = true;
642      }
643    }
644
645    // Remove sharded dims offsets if they are effectively the default values,
646    // e.g. if they define equi-distance between all neighboring shards.
647    // Requires static-only offsets. Compares the first distance as the
648    // difference between the first two offsets. Only if all consecutive
649    // distances are the same, the offsets are removed.
650    if (dynamicOffs.empty() && !staticOffs.empty()) {
651      assert(staticOffs.size() >= 2);
652      auto diff = staticOffs[1] - staticOffs[0];
653      bool all_same = staticOffs.size() > 2;
654      for (auto i = 2u; i < staticOffs.size(); ++i) {
655        if (staticOffs[i] - staticOffs[i - 1] != diff) {
656          all_same = false;
657          break;
658        }
659      }
660      if (all_same) {
661        staticOffs.clear();
662        modified = true;
663      }
664    }
665
666    if (!modified) {
667      return failure();
668    }
669
// Write the normalized lists back onto the op in place.
670    op.setStaticHaloSizes(staticHalos);
671    op.getDynamicHaloSizesMutable().assign(dynamicHalos);
672    op.setStaticShardedDimsOffsets(staticOffs);
673    op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
674
675    return success();
676  }
677};
678} // namespace
679
680void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
681 mlir::MLIRContext *context) {
682 results.add<NormalizeSharding>(context);
683}
684
685//===----------------------------------------------------------------------===//
686// Sharding
687//===----------------------------------------------------------------------===//
688
// NOTE(review): the listing capture dropped original line 689, this method's
// signature (presumably `bool Sharding::equalSplitAxes(const Sharding &rhs)
// const {`) — confirm against upstream. Two shardings have equal split axes
// when they reference the same grid, agree on the common prefix of split
// groups, and any trailing groups on either side are empty.
690  if (getGrid() != rhs.getGrid()) {
691    return false;
692  }
693
694  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
695  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
696                                    getSplitAxes().begin() + minSize),
697                   llvm::make_range(rhs.getSplitAxes().begin(),
698                                    rhs.getSplitAxes().begin() + minSize))) {
699    return false;
700  }
701
// Trailing groups beyond the common prefix must all be empty (replicated).
702  return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
703                      std::mem_fn(&GridAxesAttr::empty)) &&
704         llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
705                      std::mem_fn(&GridAxesAttr::empty));
706}
707
711
// NOTE(review): the listing capture dropped original lines 708-710 and 712
// around here (by structure: the body of `equalHaloAndShardSizes` and the
// signature of `bool Sharding::equalShardSizes(const Sharding &rhs) const`),
// plus the continuation lines 714/720 of the two size comparisons — confirm
// against upstream. Compares static and dynamic sharded-dims offsets.
713  if (rhs.getStaticShardedDimsOffsets().size() !=
715      !llvm::equal(getStaticShardedDimsOffsets(),
716                   rhs.getStaticShardedDimsOffsets())) {
717    return false;
718  }
719  if (rhs.getDynamicShardedDimsOffsets().size() !=
721      !llvm::equal(getDynamicShardedDimsOffsets(),
722                   rhs.getDynamicShardedDimsOffsets())) {
723    return false;
724  }
725  return true;
726}
727
// NOTE(review): the listing capture dropped original line 728, this method's
// signature (presumably `bool Sharding::equalHaloSizes(const Sharding &rhs)
// const {`) — confirm against upstream. Compares both static and dynamic
// halo-size lists element-wise.
729  if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
730      !llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes())) {
731    return false;
732  }
733  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
734      !llvm::equal(getDynamicHaloSizes(), rhs.getDynamicHaloSizes())) {
735    return false;
736  }
737  return true;
738}
739
743
// Negations of the corresponding operator== overloads.
// NOTE(review): the listing capture dropped the operator==(Value) and
// operator==(const Sharding &) definitions (original lines around 740-742 and
// 746-748) — confirm against upstream.
744bool Sharding::operator!=(Value rhs) const { return !(*this == rhs); }

749
750bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }
751
753
// NOTE(review): the listing capture dropped original line 754, this
// constructor's signature (presumably `Sharding::Sharding(Value rhs) {`) —
// confirm against upstream. Builds a Sharding from the ShardingOp that
// defines `rhs`.
755  auto shardingOp = rhs.getDefiningOp<ShardingOp>();
756  assert(shardingOp && "expected sharding op");
757  auto splitAxes = shardingOp.getSplitAxes().getAxes();
758  // If splitAxes are empty, use "empty" constructor.
759  if (splitAxes.empty()) {
760    *this = Sharding(shardingOp.getGridAttr());
761    return;
762  }
// Otherwise copy all static and dynamic halo/offset state from the op.
763  *this =
764      get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
765          shardingOp.getStaticShardedDimsOffsets(),
766          SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
767          SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
768}
769
// NOTE(review): the listing capture dropped original line 770, this factory's
// leading signature line (presumably
// `Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,`) — confirm
// against upstream. Assembles a Sharding value from its components; an empty
// split-axes list yields the "fully replicated" sharding for the grid.
771                       ArrayRef<GridAxesAttr> split_axes_,
772                       ArrayRef<int64_t> static_halo_sizes_,
773                       ArrayRef<int64_t> static_sharded_dims_offsets_,
774                       ArrayRef<Value> dynamic_halo_sizes_,
775                       ArrayRef<Value> dynamic_sharded_dims_offsets_) {
776  Sharding res(grid_);
777  if (split_axes_.empty()) {
778    return res;
779  }
780
// Re-intern the split-axes attributes in the grid's context.
781  res.split_axes.resize(split_axes_.size());
782  for (auto [i, axis] : llvm::enumerate(split_axes_)) {
783    res.split_axes[i] =
784        GridAxesAttr::get(grid_.getContext(), axis.asArrayRef());
785  }
786
// Deep-copy the remaining lists into the result's owned storage.
787  auto clone = [](const auto src, auto &dst) {
788    dst.resize(src.size());
789    llvm::copy(src, dst.begin());
790  };
791
792  clone(static_halo_sizes_, res.static_halo_sizes);
793  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
794  clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
795  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
796
797  return res;
798}
799
800//===----------------------------------------------------------------------===//
801// shard.shard_shape
802//===----------------------------------------------------------------------===//
803
804void ShardShapeOp::getAsmResultNames(
805 function_ref<void(Value, StringRef)> setNameFn) {
806 setNameFn(getResult()[0], "shard_shape");
807}
808
// Builds a shard_shape op: one index-typed result per dimension of `dims`.
809void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
810                         ::mlir::OperationState &odsState,
// NOTE(review): the listing capture dropped original line 811, the static
// `dims` parameter referenced below (presumably
// `::llvm::ArrayRef<int64_t> dims,`) — confirm against upstream.
812                         ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
813                         ::mlir::ValueRange device) {
814  SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
// The device multi-index is passed as all-dynamic entries.
815  build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
816        SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
817}
818
819//===----------------------------------------------------------------------===//
820// shard.shard op
821//===----------------------------------------------------------------------===//
822
823void ShardOp::getAsmResultNames(
824 function_ref<void(Value, StringRef)> setNameFn) {
825 setNameFn(getResult(), "sharding_annotated");
826}
827
828namespace {
829// Determine if the given ShardOp is a duplicate of another ShardOp
830// on the same value. This can happen if constant values are sharded.
831class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
832public:
833  using OpRewritePattern<ShardOp>::OpRewritePattern;
834
835  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
836    // Get the use-list of the value being sharded and check if it has more than
837    // one use.
// Bail out when the source has a single use or is itself a ShardOp result.
838    Value value = op.getSrc();
839    if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) {
840      return failure();
841    }
842
843    // Iterate through the uses of the value to find a duplicate ShardOp.
844    for (auto &use : value.getUses()) {
845      if (use.getOwner() != op.getOperation()) {
// Only fold against an earlier ShardOp in the same block.
846        auto otherOp = dyn_cast<ShardOp>(use.getOwner());
847        if (!otherOp || !otherOp->isBeforeInBlock(op)) {
848          return failure();
849        }
850        // Create a Sharding object for the current and the other ShardOp
851        // If the two are equal replace current op with the other op.
852        Sharding currentSharding(op.getSharding());
853        Sharding otherSharding(otherOp.getSharding());
854        if (currentSharding == otherSharding) {
855          b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
856          b.eraseOp(op.getOperation());
857        } else {
858          // use the other sharding as input for op
859          op.getSrcMutable().assign(otherOp.getResult());
860        }
861        return success();
862      }
863    }
864
865    return failure();
866  }
867};
868} // namespace
869
870void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
871 mlir::MLIRContext *context) {
872 results.add<FoldDuplicateShardOp>(context);
873}
874
875//===----------------------------------------------------------------------===//
876// shard.process_multi_index op
877//===----------------------------------------------------------------------===//
878
879LogicalResult
880ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
881 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
882 if (failed(grid)) {
883 return failure();
884 }
885 if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
886 return failure();
887 }
888
889 size_t expectedResultsCount =
890 getAxes().empty() ? grid->getRank() : getAxes().size();
891 if (getResult().size() != expectedResultsCount) {
892 return emitError() << "Unexpected number of results " << getResult().size()
893 << ". Expected " << expectedResultsCount << ".";
894 }
895
896 return success();
897}
898
899void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
900 GridOp grid) {
901 build(odsBuilder, odsState,
902 SmallVector<Type>(grid.getRank(), odsBuilder.getIndexType()),
903 grid.getSymName(), ArrayRef<GridAxis>());
904}
905
906void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
907 StringRef grid, ArrayRef<GridAxis> axes) {
908 build(odsBuilder, odsState,
909 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
910 GridAxesAttr::get(odsBuilder.getContext(), axes));
911}
912
913void ProcessMultiIndexOp::getAsmResultNames(
914 function_ref<void(Value, StringRef)> setNameFn) {
915 setNameFn(getResults()[0], "proc_linear_idx");
916}
917
918//===----------------------------------------------------------------------===//
919// shard.process_linear_index op
920//===----------------------------------------------------------------------===//
921
922LogicalResult
923ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
924 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
925 if (failed(grid)) {
926 return failure();
927 }
928 return success();
929}
930
931void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
932 OperationState &odsState, GridOp grid) {
933 build(odsBuilder, odsState, grid.getSymName());
934}
935
936void ProcessLinearIndexOp::getAsmResultNames(
937 function_ref<void(Value, StringRef)> setNameFn) {
938 setNameFn(getResult(), "proc_linear_idx");
939}
940
941//===----------------------------------------------------------------------===//
942// shard.neighbors_linear_indices op
943//===----------------------------------------------------------------------===//
944
945LogicalResult
946NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
947 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
948 if (failed(grid)) {
949 return failure();
950 }
951 return success();
952}
953
954void NeighborsLinearIndicesOp::getAsmResultNames(
955 function_ref<void(Value, StringRef)> setNameFn) {
956 setNameFn(getNeighborDown(), "down_linear_idx");
957 setNameFn(getNeighborUp(), "up_linear_idx");
958}
959
960//===----------------------------------------------------------------------===//
961// collective communication ops
962//===----------------------------------------------------------------------===//
963
964namespace {
965
966template <typename Op>
967struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
968 using OpRewritePattern<Op>::OpRewritePattern;
969 LogicalResult matchAndRewrite(Op op,
970 PatternRewriter &rewriter) const override {
971 auto gridAxes = op.getGridAxes();
972 if (!gridAxes.empty()) {
973 return failure();
974 }
975 if (op.getInput().getType() != op.getResult().getType()) {
976 return failure();
977 }
978
979 rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
980 rewriter.eraseOp(op.getOperation());
981 return success();
982 }
983};
984
985} // namespace
986
987static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
988 ArrayRef<int64_t> device,
989 Operation::operand_range deviceDynamic,
990 ArrayRef<GridAxis> gridAxes,
991 ArrayRef<int64_t> gridShape) {
992 if (device.size() != gridAxes.size()) {
993 return emitError(loc) << "In-group device \"" << deviceName
994 << "\" has unexpected multi-index size "
995 << device.size() << ". Expected " << gridAxes.size()
996 << ".";
997 }
998
999 for (size_t i = 0; i < device.size(); ++i) {
1000 if (ShapedType::isStatic(device[i]) &&
1001 ShapedType::isStatic(gridShape[gridAxes[i]]) &&
1002 gridShape[gridAxes[i]] <= device[i]) {
1003 return emitError(loc)
1004 << "Out of bounds coordinate " << i << " for in-group device \""
1005 << deviceName << "\"."
1006 << " Got " << device[i] << ", but expected value in the range [0, "
1007 << (gridShape[gridAxes[i]] - 1) << "].";
1008 }
1009 }
1010 return success();
1011}
1012
1014 int64_t expectedDimSize,
1015 int64_t resultDimSize,
1016 int64_t resultAxis) {
1017 if (ShapedType::isStatic(resultDimSize) && expectedDimSize != resultDimSize) {
1018 return emitError(loc) << "Dimension size mismatch for result axis "
1019 << resultAxis << ". Expected "
1020 << (ShapedType::isDynamic(expectedDimSize)
1021 ? Twine("dynamic")
1022 : Twine(expectedDimSize))
1023 << ", but got " << resultDimSize << ".";
1024 }
1025
1026 return success();
1027}
1028
1030 Value operand, Value result, int64_t gatherAxis,
1031 ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1032 auto resultRank = cast<ShapedType>(result.getType()).getRank();
1033 if (gatherAxis < 0 || gatherAxis >= resultRank) {
1034 return emitError(result.getLoc())
1035 << "Gather axis " << gatherAxis << " is out of bounds [0, "
1036 << resultRank << ").";
1037 }
1038
1039 ShapedType operandType = cast<ShapedType>(operand.getType());
1040 ShapedType resultType = cast<ShapedType>(result.getType());
1041 auto deviceGroupSize =
1042 DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1043 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1044 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
1045 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
1046 auto expectedResultDimSize =
1047 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
1049 result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
1050 return failure();
1051 }
1052 }
1053 return success();
1054}
1055
1057 Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
1058 ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1059 ShapedType operandType = cast<ShapedType>(operand.getType());
1060 ShapedType resultType = cast<ShapedType>(result.getType());
1061 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1062 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
1064 result.getLoc(), operandType.getDimSize(axis),
1065 resultType.getDimSize(axis), axis))) {
1066 return failure();
1067 }
1068 }
1069 }
1070
1071 if (splitAxis == concatAxis) {
1072 return success();
1073 }
1074
1075 auto deviceGroupSize =
1076 DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1077 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
1078 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
1079 DimensionSize expectedResultConcatDimSize =
1080 operandConcatDimSize * deviceGroupSize;
1081 DimensionSize expectedResultSplitDimSize =
1082 operandSplitDimSize / deviceGroupSize;
1083 if (!expectedResultSplitDimSize.isDynamic() &&
1084 int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
1085 expectedResultSplitDimSize = DimensionSize::dynamic();
1086 }
1088 result.getLoc(), expectedResultConcatDimSize.value(),
1089 resultType.getDimSize(concatAxis), concatAxis))) {
1090 return failure();
1091 }
1093 result.getLoc(), expectedResultSplitDimSize.value(),
1094 resultType.getDimSize(splitAxis), splitAxis))) {
1095 return failure();
1096 }
1097
1098 return success();
1099}
1100
1102 Value operand, Value result, int64_t tensorAxis,
1103 ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1104 ShapedType operandType = cast<ShapedType>(operand.getType());
1105 ShapedType resultType = cast<ShapedType>(result.getType());
1106 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1107 if (axis != tensorAxis) {
1109 result.getLoc(), operandType.getDimSize(axis),
1110 resultType.getDimSize(axis), axis))) {
1111 return failure();
1112 }
1113 }
1114 }
1115
1116 auto deviceGroupSize =
1117 DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1118 auto operandScatterDimSize =
1119 DimensionSize(operandType.getDimSize(tensorAxis));
1120 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1121 int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1122 return emitError(result.getLoc())
1123 << "Operand dimension size " << int64_t(operandScatterDimSize)
1124 << " is not divisible by collective device group size "
1125 << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
1126 << ".";
1127 }
1128 DimensionSize expectedResultTensorDimSize =
1129 operandScatterDimSize / deviceGroupSize;
1131 result.getLoc(), expectedResultTensorDimSize.value(),
1132 resultType.getDimSize(tensorAxis), tensorAxis))) {
1133 return failure();
1134 }
1135
1136 return success();
1137}
1138
1139static RankedTensorType sliceResultType(Type operandType, GridOp grid,
1140 ArrayRef<GridAxis> gridAxes,
1141 int64_t sliceAxis) {
1142 RankedTensorType operandRankedTensorType =
1143 cast<RankedTensorType>(operandType);
1144 DimensionSize operandSliceAxisSize =
1145 operandRankedTensorType.getShape()[sliceAxis];
1146 SmallVector<int64_t> resultShape =
1147 llvm::to_vector(operandRankedTensorType.getShape());
1148
1149 resultShape[sliceAxis] =
1150 operandSliceAxisSize /
1151 DimensionSize(collectiveProcessGroupSize(gridAxes, grid));
1152 return operandRankedTensorType.clone(resultShape);
1153}
1154
1155//===----------------------------------------------------------------------===//
1156// shard.all_gather op
1157//===----------------------------------------------------------------------===//
1158
1159LogicalResult
1160AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1161 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1162 if (failed(grid)) {
1163 return failure();
1164 }
1165 auto gatherAxis = getGatherAxis().getSExtValue();
1166 return verifyGatherOperandAndResultShape(getOperand(), getResult(),
1167 gatherAxis, getGridAxes(),
1168 grid.value().getShape());
1169}
1170
1171void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1172 MLIRContext *context) {
1173 patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
1174}
1175
1176void AllGatherOp::getAsmResultNames(
1177 function_ref<void(Value, StringRef)> setNameFn) {
1178 setNameFn(getResult(), "all_gather");
1179}
1180
1181//===----------------------------------------------------------------------===//
1182// shard.all_reduce op
1183//===----------------------------------------------------------------------===//
1184
1185LogicalResult
1186AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1187 return getGridAndVerifyAxes(*this, symbolTable);
1188}
1189
1190void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1191 MLIRContext *context) {
1192 patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
1193}
1194
1195void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1196 Value input, StringRef grid,
1197 ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
1198 build(odsBuilder, odsState, input.getType(), grid, gridAxes, input,
1199 reduction);
1200}
1201
1202void AllReduceOp::getAsmResultNames(
1203 function_ref<void(Value, StringRef)> setNameFn) {
1204 setNameFn(getResult(), "all_reduce");
1205}
1206
1207//===----------------------------------------------------------------------===//
1208// shard.all_slice op
1209//===----------------------------------------------------------------------===//
1210
1211LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1212 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1213 if (failed(grid)) {
1214 return failure();
1215 }
1217 getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
1218 grid.value().getShape());
1219}
1220
1221void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1222 MLIRContext *context) {
1223 patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
1224}
1225
1226void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1227 Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
1228 int64_t sliceAxis) {
1229 Type resultType = sliceResultType(input.getType(), grid, gridAxes, sliceAxis);
1230 build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
1231 sliceAxis);
1232}
1233
1234void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1235 Type resultType, Value input, StringRef grid,
1236 ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
1237 build(odsBuilder, odsState, resultType, grid, gridAxes, input,
1238 APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1239}
1240
1241void AllSliceOp::getAsmResultNames(
1242 function_ref<void(Value, StringRef)> setNameFn) {
1243 setNameFn(getResult(), "all_slice");
1244}
1245
1246//===----------------------------------------------------------------------===//
1247// shard.all_to_all op
1248//===----------------------------------------------------------------------===//
1249
1250LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1251 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1252 if (failed(grid)) {
1253 return failure();
1254 }
1255
1257 getOperand(), getResult(), getSplitAxis().getSExtValue(),
1258 getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
1259}
1260
1261void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1262 MLIRContext *context) {
1263 patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
1264}
1265
1266void AllToAllOp::getAsmResultNames(
1267 function_ref<void(Value, StringRef)> setNameFn) {
1268 setNameFn(getResult(), "all_to_all");
1269}
1270
1271//===----------------------------------------------------------------------===//
1272// shard.broadcast op
1273//===----------------------------------------------------------------------===//
1274
1275LogicalResult
1276BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1277 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1278 if (failed(grid)) {
1279 return failure();
1280 }
1281 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1282 getRootDynamic(), getGridAxes(),
1283 grid.value().getShape()))) {
1284 return failure();
1285 }
1286
1287 return success();
1288}
1289
1290void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1291 MLIRContext *context) {
1292 patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
1293}
1294
1295void BroadcastOp::getAsmResultNames(
1296 function_ref<void(Value, StringRef)> setNameFn) {
1297 setNameFn(getResult(), "broadcast");
1298}
1299
1300//===----------------------------------------------------------------------===//
1301// shard.gather op
1302//===----------------------------------------------------------------------===//
1303
1304LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1305 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1306 if (failed(grid)) {
1307 return failure();
1308 }
1309 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1310 getRootDynamic(), getGridAxes(),
1311 grid.value().getShape()))) {
1312 return failure();
1313 }
1314
1315 auto gatherAxis = getGatherAxis().getSExtValue();
1316 return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
1317 getGridAxes(),
1318 grid.value().getShape());
1319}
1320
1321void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1322 MLIRContext *context) {
1323 patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
1324}
1325
1326void GatherOp::getAsmResultNames(
1327 function_ref<void(Value, StringRef)> setNameFn) {
1328 setNameFn(getResult(), "gather");
1329}
1330
1331//===----------------------------------------------------------------------===//
1332// shard.recv op
1333//===----------------------------------------------------------------------===//
1334
1335LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1336 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1337 if (failed(grid)) {
1338 return failure();
1339 }
1340 if (getSource() &&
1341 failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
1342 getSource().value(), getSourceDynamic(),
1343 getGridAxes(), grid.value().getShape()))) {
1344 return failure();
1345 }
1346 return success();
1347}
1348
1349void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1350 MLIRContext *context) {
1351 patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
1352}
1353
1354void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1355 setNameFn(getResult(), "recv");
1356}
1357
1358//===----------------------------------------------------------------------===//
1359// shard.reduce op
1360//===----------------------------------------------------------------------===//
1361
1362LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1363 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1364 if (failed(grid)) {
1365 return failure();
1366 }
1367 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1368 getRootDynamic(), getGridAxes(),
1369 grid.value().getShape()))) {
1370 return failure();
1371 }
1372
1373 return success();
1374}
1375
1376void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1377 MLIRContext *context) {
1378 patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
1379}
1380
1381void ReduceOp::getAsmResultNames(
1382 function_ref<void(Value, StringRef)> setNameFn) {
1383 setNameFn(getResult(), "reduce");
1384}
1385
1386//===----------------------------------------------------------------------===//
1387// shard.reduce_scatter op
1388//===----------------------------------------------------------------------===//
1389
1390LogicalResult
1391ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1392 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1393 if (failed(grid)) {
1394 return failure();
1395 }
1396
1398 getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
1399 grid.value().getShape());
1400}
1401
1402void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1403 MLIRContext *context) {
1404 patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1405}
1406
1407void ReduceScatterOp::getAsmResultNames(
1408 function_ref<void(Value, StringRef)> setNameFn) {
1409 setNameFn(getResult(), "reduce_scatter");
1410}
1411
1412//===----------------------------------------------------------------------===//
1413// shard.scatter op
1414//===----------------------------------------------------------------------===//
1415
1416LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1417 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1418 if (failed(grid)) {
1419 return failure();
1420 }
1421 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1422 getRootDynamic(), getGridAxes(),
1423 grid.value().getShape()))) {
1424 return failure();
1425 }
1426
1427 auto scatterAxis = getScatterAxis().getSExtValue();
1428 return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
1429 scatterAxis, getGridAxes(),
1430 grid.value().getShape());
1431}
1432
1433void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1434 MLIRContext *context) {
1435 patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
1436}
1437
1438void ScatterOp::getAsmResultNames(
1439 function_ref<void(Value, StringRef)> setNameFn) {
1440 setNameFn(getResult(), "scatter");
1441}
1442
1443//===----------------------------------------------------------------------===//
1444// shard.send op
1445//===----------------------------------------------------------------------===//
1446
1447LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1448 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1449 if (failed(grid)) {
1450 return failure();
1451 }
1452 if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
1453 getDestination(), getDestinationDynamic(),
1454 getGridAxes(), grid.value().getShape()))) {
1455 return failure();
1456 }
1457 return success();
1458}
1459
1460void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1461 MLIRContext *context) {
1462 patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);
1463}
1464
1465void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1466 setNameFn(getResult(), "send");
1467}
1468
1469//===----------------------------------------------------------------------===//
1470// shard.shift op
1471//===----------------------------------------------------------------------===//
1472
1473LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1474 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1475 if (failed(grid)) {
1476 return failure();
1477 }
1478
1479 auto gridAxes = getGridAxes();
1480 auto shiftAxis = getShiftAxis().getZExtValue();
1481 if (!llvm::is_contained(gridAxes, shiftAxis)) {
1482 return emitError() << "Invalid shift axis " << shiftAxis
1483 << ". It must be one of the grouping grid axes.";
1484 }
1485
1486 return success();
1487}
1488
1489void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1490 MLIRContext *context) {
1491 // TODO: remove op when offset is 0 or if it is a rotate with and
1492 // offset % shift_axis_grid_dim_size == 0.
1493}
1494
1495void ShiftOp::getAsmResultNames(
1496 function_ref<void(Value, StringRef)> setNameFn) {
1497 setNameFn(getResult(), "shift");
1498}
1499
1500//===----------------------------------------------------------------------===//
1501// shard.update_halo op
1502//===----------------------------------------------------------------------===//
1503
1504LogicalResult
1505UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1506 auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
1507 if (failed(grid)) {
1508 return failure();
1509 }
1510
1511 return success();
1512}
1513
1514//===----------------------------------------------------------------------===//
1515// TableGen'd op method definitions
1516//===----------------------------------------------------------------------===//
1517
1518#define GET_OP_CLASSES
1519#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
1520
1521#define GET_ATTRDEF_CLASSES
1522#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
1523
1524#define GET_TYPEDEF_CLASSES
1525#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
1526
1527#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition ShardOps.cpp:62
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static FailureOr< GridOp > getGridAndVerify(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable)
Definition ShardOps.cpp:148
static LogicalResult verifyGridAxes(Location loc, ArrayRef< GridAxis > axes, GridOp grid)
Definition ShardOps.cpp:177
static FailureOr< GridOp > getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition ShardOps.cpp:200
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
Definition ShardOps.cpp:214
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
Definition ShardOps.cpp:987
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
Definition ShardOps.cpp:299
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:421
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
Location getLoc()
The source location the operation was defined or derived from.
This class provides the API for a sub-set of ops that are known to be constant-like.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition Value.cpp:91
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
ArrayRef< Value > getDynamicShardedDimsOffsets() const
Definition ShardOps.h:68
bool operator!=(Value rhs) const
Definition ShardOps.cpp:744
bool equalShardSizes(const Sharding &rhs) const
Definition ShardOps.cpp:712
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
Definition ShardOps.cpp:752
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition ShardOps.cpp:770
bool equalSplitAxes(const Sharding &rhs) const
Definition ShardOps.cpp:689
ArrayRef< int64_t > getStaticHaloSizes() const
Definition ShardOps.h:63
::mlir::FlatSymbolRefAttr getGridAttr() const
Definition ShardOps.h:60
::llvm::StringRef getGrid() const
Definition ShardOps.h:61
bool equalHaloAndShardSizes(const Sharding &rhs) const
Definition ShardOps.cpp:708
bool operator==(Value rhs) const
Definition ShardOps.cpp:740
ArrayRef< Value > getDynamicHaloSizes() const
Definition ShardOps.h:67
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition ShardOps.h:64
ArrayRef< GridAxesAttr > getSplitAxes() const
Definition ShardOps.h:62
bool equalHaloSizes(const Sharding &rhs) const
Definition ShardOps.cpp:728
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
shard::Sharding Sharding
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:281
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
Definition ShardOps.cpp:338
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:113
bool isFullReplication(Sharding sharding)
Definition ShardOps.h:106
int16_t GridAxis
Definition ShardOps.h:26
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
Definition ShardOps.cpp:352
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:168
Type shardType(Type type, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:291
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
Definition ShardOps.cpp:77
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Definition ShardOps.h:146
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition AffineExpr.h:252
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.