MLIR 23.0.0git
ShardOps.cpp
Go to the documentation of this file.
1//===- ShardOps.cpp - Shard Dialect Operations ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
14#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
24#include "mlir/IR/Value.h"
26#include "mlir/Support/LLVM.h"
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include <algorithm>
34#include <functional>
35#include <iterator>
36#include <numeric>
37#include <optional>
38#include <utility>
39
40#define DEBUG_TYPE "shard-ops"
41
42using namespace mlir;
43using namespace mlir::shard;
44
45#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
46
47namespace {
48
49struct DimensionSize {
50 static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
51 DimensionSize(int64_t val) : val(val) {}
52 int64_t value() const { return val; }
53 operator int64_t() const { return val; }
54 bool isDynamic() const { return ShapedType::isDynamic(val); }
55
56private:
57 int64_t val;
58};
59
60} // namespace
61
62static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
63 if (lhs.isDynamic() || rhs.isDynamic()) {
64 return DimensionSize::dynamic();
65 }
66 return lhs.value() / rhs.value();
67}
68
69static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
70 if (lhs.isDynamic() || rhs.isDynamic()) {
71 return DimensionSize::dynamic();
72 }
73 return lhs.value() * rhs.value();
74}
75
79 ValueRange dynamics, Type type) {
80 SmallVector<Value> values;
81 auto dyn = dynamics.begin();
82 Type i64 = b.getI64Type();
83 if (!type)
84 type = i64;
85 assert((i64 == type || b.getIndexType() == type) &&
86 "expected an i64 or an intex type");
87 for (auto s : statics) {
88 if (s == ShapedType::kDynamic) {
89 values.emplace_back(*(dyn++));
90 } else {
91 TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
92 values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
93 }
94 }
95 return values;
96}
97
98//===----------------------------------------------------------------------===//
99// Inliner
100//===----------------------------------------------------------------------===//
101
102namespace {
103struct ShardInlinerinterface : public DialectInlinerInterface {
104 using DialectInlinerInterface::DialectInlinerInterface;
105 // Currently no restrictions are encoded for inlining.
106 bool isLegalToInline(Operation *, Operation *, bool) const final {
107 return true;
108 }
109 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
110 return true;
111 }
112 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
113 return true;
114 }
115};
116} // namespace
117
118//===----------------------------------------------------------------------===//
119// Shard dialect
120//===----------------------------------------------------------------------===//
121
/// Registers all TableGen-generated operations, attributes and types of the
/// Shard dialect, plus the dialect's inliner interface.
void ShardDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
      >();
  addInterface<ShardInlinerinterface>();
}
137
/// Dialect hook used by folding to materialize a single constant operation;
/// delegates to the arith dialect's constant materializer.
Operation *ShardDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
143
144//===----------------------------------------------------------------------===//
145// Shard utilities
146//===----------------------------------------------------------------------===//
147
148static FailureOr<GridOp> getGridAndVerify(Operation *op,
149 FlatSymbolRefAttr gridSymbol,
150 SymbolTableCollection &symbolTable) {
151 shard::GridOp grid = getGridOrNull(op, gridSymbol, symbolTable);
152 if (!grid) {
153 return op->emitError() << "Undefined required grid symbol \""
154 << gridSymbol.getValue() << "\".";
155 }
156
157 return grid;
158}
159
/// Returns true if [begin, end) contains no adjacent equal elements.
/// Callers sort the range first, so for them this is a full uniqueness check.
/// Replaces the hand-rolled pairwise loop with the equivalent standard
/// algorithm (<algorithm> is already included).
template <typename It>
static bool isUnique(It begin, It end) {
  // adjacent_find returns the first position with *it == *(it+1); reaching
  // `end` means no duplicate neighbors exist (trivially true for empty and
  // single-element ranges).
  return std::adjacent_find(begin, end) == end;
}
176
177static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes,
178 GridOp grid) {
179 SmallVector<GridAxis> sorted = llvm::to_vector(axes);
180 llvm::sort(sorted);
181 if (!isUnique(sorted.begin(), sorted.end())) {
182 return emitError(loc) << "Grid axes contains duplicate elements.";
183 }
184
185 GridAxis rank = grid.getRank();
186 for (auto axis : axes) {
187 if (axis >= rank || axis < 0) {
188 return emitError(loc)
189 << "0-based grid axis index " << axis
190 << " is out of bounds. The referenced grid \"" << grid.getSymName()
191 << "\" is of rank " << rank << ".";
192 }
193 }
194
195 return success();
196}
197
198template <typename Op>
199static FailureOr<GridOp>
201 auto grid =
202 ::getGridAndVerify(op.getOperation(), op.getGridAttr(), symbolTable);
203 if (failed(grid)) {
204 return failure();
205 }
206 if (failed(verifyGridAxes(op.getLoc(), op.getGridAxes(), grid.value()))) {
207 return failure();
208 }
209 return grid;
210}
211
/// Computes the per-process (sharded) shape `outShape` from the global
/// `inShape`, given the device-grid shape, the per-tensor-axis split axes,
/// and optional sharded-dims offsets or halo sizes.
/// `shardedDimsOffsets` and `haloSizes` are mutually exclusive (enforced by
/// ShardingOp::verify); offsets take precedence here when non-empty.
template <typename InShape, typename GridShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const GridShape &gridShape,
                       const SplitAxes &splitAxes, OutShape &outShape,
                       ArrayRef<int64_t> shardedDimsOffsets = {},
                       ArrayRef<int64_t> haloSizes = {}) {
  // 0d tensors cannot be sharded and must get replicated
  if (inShape.empty()) {
    assert(outShape.empty());
    return;
  }

  // Start from the unsharded shape; sharded axes are overwritten below.
  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));

  if (!shardedDimsOffsets.empty()) {
    auto isDynShape = ShapedType::isDynamicShape(gridShape);
    // `pos` walks the flattened offsets list; starts at 1 so that
    // shardedDimsOffsets[pos] is the first shard's end offset.
    // Assumes each per-dim offset sub-list starts at 0 — TODO confirm with
    // the producer of sharded_dims_offsets.
    uint64_t pos = 1;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (!innerSplitAxes.empty()) {
        auto sz = shardedDimsOffsets[pos];
        bool same = !isDynShape;
        if (same) {
          // Find sharded dims in shardedDimsOffsets with same static size on
          // all devices. Use kDynamic for dimensions with dynamic or
          // non-uniform offs in shardedDimsOffsets.
          uint64_t numShards = 0;
          for (auto i : innerSplitAxes.asArrayRef()) {
            numShards += gridShape[i];
          }
          for (size_t i = 1; i < numShards; ++i) {
            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
                sz) {
              same = false;
              break;
            }
          }
          // Skip this dim's numShards+1 offsets (including the leading 0).
          // NOTE(review): with a dynamic grid shape `pos` is never advanced,
          // but then every sharded dim becomes kDynamic anyway and offsets
          // are rejected by ShardingOp::verifySymbolUses — confirm.
          pos += numShards + 1;
        }
        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
      }
    }
  } else {
    // Uniform sharding: divide each split tensor axis by its group size.
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      outShape[tensorAxis] = shardDimension(
          inShape[tensorAxis],
          collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), gridShape));
    }

    if (!haloSizes.empty()) {
      // add halo sizes if requested
      int haloAxis = 0;
      for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
        if (ShapedType::isStatic(outShape[tensorAxis]) &&
            !innerSplitAxes.empty()) {
          if (haloSizes[haloAxis * 2] >= 0 &&
              haloSizes[haloAxis * 2 + 1] >= 0) {
            // Both low and high halo are known: grow the dim by their sum.
            outShape[tensorAxis] +=
                haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
            ++haloAxis;
          } else {
            // Negative (unknown) halo makes the extent dynamic.
            outShape[tensorAxis] = ShapedType::kDynamic;
          }
        }
      }
    }
  }
}
280
281ShapedType shard::shardShapedType(ShapedType shape, GridOp grid,
282 Sharding sharding) {
283 using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
284 SmallVector<Dim> resShapeArr(shape.getShape().size());
285 shardShape(shape.getShape(), grid.getShape(), sharding.getSplitAxes(),
286 resShapeArr, sharding.getStaticShardedDimsOffsets(),
287 sharding.getStaticHaloSizes());
288 return shape.clone(resShapeArr);
289}
290
291Type shard::shardType(Type type, GridOp grid, Sharding sharding) {
292 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
293 if (rankedTensorType && !rankedTensorType.getShape().empty()) {
294 return shardShapedType(rankedTensorType, grid, sharding);
295 }
296 return type;
297}
298
300 Value &operandValue,
301 Operation *operandOp,
302 OpBuilder &builder,
303 ShardOp &newShardOp) {
304 OpBuilder::InsertionGuard insertionGuard(builder);
305 builder.setInsertionPointAfterValue(operandValue);
306 ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
307 if (shardOp && sharding == shardOp.getSharding() &&
308 !shardOp.getAnnotateForUsers()) {
309 // No need for anything if the correct sharding is already set.
310 if (!newShardOp) {
311 newShardOp = shardOp;
312 }
313 return;
314 }
315
316 if (!newShardOp) {
317 auto shardingOp =
318 ShardingOp::create(builder, operandValue.getLoc(), sharding);
319 newShardOp = ShardOp::create(builder, operandValue.getLoc(), operandValue,
320 shardingOp,
321 /*annotate_for_users*/ false);
322 }
323 operandValue.replaceUsesWithIf(
324 newShardOp, [operandOp, operandValue](OpOperand &use) {
325 return use.getOwner() == operandOp && use.get() == operandValue;
326 });
327
328 if (!shardOp || shardOp.getAnnotateForUsers()) {
329 return;
330 }
331
332 auto newShardOp2 = ShardOp::create(builder, operandValue.getLoc(), newShardOp,
333 newShardOp.getSharding(),
334 /*annotate_for_users*/ true);
335 newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
336}
337
340 OpBuilder &builder) {
341 ShardOp newShardOp;
343 for (auto &use : result.getUses()) {
344 uses.emplace_back(use.get(), use.getOwner());
345 }
346 for (auto &[operandValue, operandOp] : uses) {
347 maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
348 builder, newShardOp);
349 }
350}
351
353 OpOperand &operand,
354 OpBuilder &builder) {
355 OpBuilder::InsertionGuard insertionGuard(builder);
356 Value operandValue = operand.get();
357 Operation *operandSrcOp = operandValue.getDefiningOp();
358 bool isBlockArg = !operandSrcOp;
359 {
360 [[maybe_unused]] auto opType =
361 dyn_cast<mlir::RankedTensorType>(operandValue.getType());
362 assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
363 }
364 if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
365 operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
366 return;
367 }
368
369 Operation *operandOp = operand.getOwner();
370 ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
371
372 if (shardOp && sharding == shardOp.getSharding() &&
373 shardOp.getAnnotateForUsers()) {
374 // No need for anything the correct sharding is already set.
375 return;
376 }
377
378 builder.setInsertionPoint(operandOp);
379 auto shardingOp =
380 ShardingOp::create(builder, operand.get().getLoc(), sharding);
381 auto newShardOp =
382 ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
383 /*annotate_for_users*/ true);
384 IRRewriter rewriter(builder);
385 rewriter.replaceUsesWithIf(
386 operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
387 return use.getOwner() == operandOp && use.get() == operandValue;
388 });
389
390 if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
391 // No need for resharding.
392 return;
393 }
394
395 builder.setInsertionPoint(newShardOp);
396 auto newPreceedingShardOp =
397 ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
398 /*annotate_for_users*/ false);
399 rewriter.replaceUsesWithIf(
400 newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
401 return use.getOwner() == newShardOp.getOperation();
402 });
403}
404
405//===----------------------------------------------------------------------===//
406// shard.grid op
407//===----------------------------------------------------------------------===//
408
409LogicalResult GridOp::verify() {
410 int64_t rank = getRank();
411
412 if (rank <= 0)
413 return emitOpError("rank of grid is expected to be a positive integer");
414
415 for (int64_t dimSize : getShape()) {
416 if (dimSize < 0 && ShapedType::isStatic(dimSize))
417 return emitOpError("dimension size of a grid is expected to be "
418 "non-negative or dynamic");
419 }
420
421 return success();
422}
423
424//===----------------------------------------------------------------------===//
425// shard.grid_shape op
426//===----------------------------------------------------------------------===//
427
428LogicalResult
429GridShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
430 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
431 if (failed(grid)) {
432 return failure();
433 }
434 if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
435 return failure();
436 }
437
438 size_t expectedResultsCount =
439 getAxes().empty() ? grid->getRank() : getAxes().size();
440 if (getResult().size() != expectedResultsCount) {
441 return emitError() << "Unexpected number of results " << getResult().size()
442 << ". Expected " << expectedResultsCount << ".";
443 }
444
445 return success();
446}
447
448void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
449 GridOp grid) {
450 build(odsBuilder, odsState, grid, SmallVector<GridAxis>());
451}
452
453void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
454 GridOp grid, ArrayRef<GridAxis> axes) {
455 build(odsBuilder, odsState,
456 SmallVector<Type>(axes.empty() ? grid.getRank() : axes.size(),
457 odsBuilder.getIndexType()),
458 grid.getSymName(), GridAxesAttr::get(odsBuilder.getContext(), axes));
459}
460
461void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
462 StringRef grid, ArrayRef<GridAxis> axes) {
463 assert(!axes.empty());
464 build(odsBuilder, odsState,
465 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
466 GridAxesAttr::get(odsBuilder.getContext(), axes));
467}
468
/// Names the op's results "grid_shape" in printed IR.
void GridShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "grid_shape");
}
473
474//===----------------------------------------------------------------------===//
475// shard.sharding
476//===----------------------------------------------------------------------===//
477
478void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
480 ArrayRef<int64_t> staticHalos,
481 ArrayRef<int64_t> staticOffsets) {
482 return build(
483 b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
484 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
485 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {});
486}
487
/// Convenience builder taking the grid by name; forwards to the symbol-attr
/// overload with static-only halos and offsets.
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                       llvm::StringRef grid, ArrayRef<GridAxesAttr> splitAxes,
                       ArrayRef<int64_t> staticHalos,
                       ArrayRef<int64_t> staticOffsets) {
  return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
               GridAxesArrayAttr::get(b.getContext(), splitAxes),
               ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
               ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets),
               {});
}
498
499void ShardingOp::build(
503 ::mlir::ArrayRef<::mlir::OpFoldResult> shardedDimsOffsets) {
504 mlir::SmallVector<int64_t> staticHalos, staticDims;
505 mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
506 dispatchIndexOpFoldResults(haloSizes, dynamicHalos, staticHalos);
507 dispatchIndexOpFoldResults(shardedDimsOffsets, dynamicDims, staticDims);
508 return build(
509 b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
510 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
511 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
512}
513
514void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
516
517 build(b, odsState, ShardingType::get(b.getContext()), from.getGridAttr(),
518 GridAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
519 from.getStaticShardedDimsOffsets().empty()
521 : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
523 from.getStaticHaloSizes().empty()
525 : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
526 from.getDynamicHaloSizes());
527}
528
529LogicalResult ShardingOp::verify() {
530 llvm::SmallSet<GridAxis, 4> visitedAxes;
531
532 auto checkGridAxis = [&](ArrayRef<GridAxis> axesArray) -> LogicalResult {
533 for (GridAxis axis : axesArray) {
534 if (axis < 0)
535 return emitError() << "grid axis is expected to be non-negative";
536 if (!visitedAxes.insert(axis).second)
537 return emitError() << "grid axis duplicated";
538 }
539 return success();
540 };
541
542 for (auto subAxes : getSplitAxes().getAxes()) {
543 ArrayRef<GridAxis> subAxesArray = subAxes.asArrayRef();
544 if (failed(checkGridAxis(subAxesArray)))
545 return failure();
546 }
547
548 if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
549 return emitOpError("halo sizes and shard offsets are mutually exclusive");
550 }
551
552 if (!getStaticHaloSizes().empty()) {
553 auto numSplitAxes = getSplitAxes().getAxes().size();
554 for (auto splitAxis : getSplitAxes().getAxes()) {
555 if (splitAxis.empty()) {
556 --numSplitAxes;
557 }
558 }
559 if (getStaticHaloSizes().size() != numSplitAxes * 2) {
560 return emitError() << "halo sizes must be specified for all split axes.";
561 }
562 }
563
564 return success();
565}
566
/// Names the op's result "sharding" in printed IR.
void ShardingOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding");
}
571
/// Validates the referenced grid and, when static sharded-dims offsets are
/// present, that the grid is fully static and the offsets are well-formed
/// (right count per sharded dim, non-decreasing).
LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
  if (failed(grid)) {
    return failure();
  }
  // The offsets' layout depends on the number of shards per axis, which is
  // unknown for dynamic grid shapes.
  if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
      !getStaticShardedDimsOffsets().empty()) {
    return emitError() << "sharded dims offsets are not allowed for "
                          "device grids with dynamic shape.";
  }

  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
  if (!shardedDimsOffsets.empty()) {
    auto gridShape = grid.value().getShape();
    assert(ShapedType::isStaticShape(gridShape));
    // `pos` is the start of the current dim's offset sub-list within the
    // flattened offsets array.
    uint64_t pos = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
      if (!innerSplitAxes.empty()) {
        int64_t numShards = 0, off = 0;
        // NOTE(review): shard count is accumulated additively across split
        // axes here (and identically in shardShape) — for multi-axis splits
        // confirm this matches the offsets producer.
        for (auto i : innerSplitAxes.asArrayRef()) {
          numShards += gridShape[i];
        }
        // Each sharded dim contributes numShards+1 boundary offsets.
        for (int64_t i = 0; i <= numShards; ++i) {
          if (shardedDimsOffsets.size() <= pos + i) {
            return emitError() << "sharded dims offsets has wrong size.";
          }
          // Dynamic entries are skipped; static entries must not decrease.
          if (ShapedType::isStatic(shardedDimsOffsets[pos + i])) {
            if (shardedDimsOffsets[pos + i] < off) {
              return emitError()
                     << "sharded dims offsets must be non-decreasing.";
            }
            off = shardedDimsOffsets[pos + i];
          }
        }
        pos += numShards + 1;
      }
    }
  }
  return success();
}
612
613namespace {
614// Sharding annotations "halo sizes" and "sharded dims offsets"
615// are a mix of attributes and dynamic values. This canonicalization moves
616// constant values to the respective attribute lists, minimizing the number
617// of values.
618// It also removes sharded_dims_sizes and halos if they are effectively "empty".
/// Canonicalizes a shard.sharding op: folds constant dynamic operands into
/// the static attribute lists and drops halo/offset lists that are
/// effectively empty (all-zero halos, equi-distant offsets).
class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
public:
  using OpRewritePattern<ShardingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardingOp op,
                                PatternRewriter &b) const override {
    auto mixedHalos =
        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
                                    op.getDynamicShardedDimsOffsets(), b);

    // No constant operands were folded, just return;
    // NOTE(review): `||` short-circuits, so offsets are not folded in the
    // same application when halos folded — appears benign since the pattern
    // re-runs until fixpoint, but confirm whether `|` was intended.
    bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
                    succeeded(foldDynamicIndexList(mixedOffs, true));

    auto decomposedHalos = decomposeMixedValues(mixedHalos);
    auto staticHalos = decomposedHalos.first;
    auto dynamicHalos = decomposedHalos.second;
    auto decomposedOffs = decomposeMixedValues(mixedOffs);
    auto staticOffs = decomposedOffs.first;
    auto dynamicOffs = decomposedOffs.second;

    // All-static, all-zero halo sizes carry no information: drop them.
    if (dynamicHalos.empty() && !staticHalos.empty()) {
      if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
        staticHalos.clear();
        modified = true;
      }
    }

    // Remove sharded dims offsets if they are effectively the default values,
    // e.g. if they define equi-distance between all neighboring shards.
    // Requires static-only offsets. Compares the first distance as the
    // difference between the first two offsets. Only if all consecutive
    // distances are the same, the offsets are removed.
    if (dynamicOffs.empty() && !staticOffs.empty()) {
      assert(staticOffs.size() >= 2);
      auto diff = staticOffs[1] - staticOffs[0];
      bool allSame = staticOffs.size() > 2;
      for (auto i = 2u; i < staticOffs.size(); ++i) {
        if (staticOffs[i] - staticOffs[i - 1] != diff) {
          allSame = false;
          break;
        }
      }
      if (allSame) {
        staticOffs.clear();
        modified = true;
      }
    }

    if (!modified) {
      return failure();
    }

    // Write back the normalized static/dynamic lists in place.
    b.modifyOpInPlace(op, [&]() {
      op.setStaticHaloSizes(staticHalos);
      op.getDynamicHaloSizesMutable().assign(dynamicHalos);
      op.setStaticShardedDimsOffsets(staticOffs);
      op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
    });
    return success();
  }
};
682} // namespace
683
/// Registers the NormalizeSharding canonicalization for shard.sharding.
void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                             mlir::MLIRContext *context) {
  results.add<NormalizeSharding>(context);
}
688
689//===----------------------------------------------------------------------===//
690// Sharding
691//===----------------------------------------------------------------------===//
692
694 if (getGrid() != rhs.getGrid()) {
695 return false;
696 }
697
698 auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
699 if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
700 getSplitAxes().begin() + minSize),
701 llvm::make_range(rhs.getSplitAxes().begin(),
702 rhs.getSplitAxes().begin() + minSize))) {
703 return false;
704 }
705
706 return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
707 std::mem_fn(&GridAxesAttr::empty)) &&
708 llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
709 std::mem_fn(&GridAxesAttr::empty));
710}
711
715
717 if (rhs.getStaticShardedDimsOffsets().size() !=
719 !llvm::equal(getStaticShardedDimsOffsets(),
720 rhs.getStaticShardedDimsOffsets())) {
721 return false;
722 }
723 if (rhs.getDynamicShardedDimsOffsets().size() !=
725 !llvm::equal(getDynamicShardedDimsOffsets(),
726 rhs.getDynamicShardedDimsOffsets())) {
727 return false;
728 }
729 return true;
730}
731
733 if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
734 !llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes())) {
735 return false;
736 }
737 if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
738 !llvm::equal(getDynamicHaloSizes(), rhs.getDynamicHaloSizes())) {
739 return false;
740 }
741 return true;
742}
743
747
748bool Sharding::operator!=(Value rhs) const { return !(*this == rhs); }
749
753
754bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }
755
/// Debug printer for Sharding: grid symbol, split axes, and (when present)
/// static halo sizes and sharded-dims offsets.
llvm::raw_ostream &mlir::shard::operator<<(llvm::raw_ostream &os,
                                           const Sharding &sharding) {
  os << "Sharding<grid=" << sharding.getGrid() << ", split_axes=[";
  llvm::interleaveComma(sharding.getSplitAxes(), os, [&](GridAxesAttr axes) {
    os << "[";
    llvm::interleaveComma(axes.asArrayRef(), os);
    os << "]";
  });
  os << "]";
  if (!sharding.getStaticHaloSizes().empty()) {
    os << ", halo_sizes=[";
    llvm::interleaveComma(sharding.getStaticHaloSizes(), os);
    os << "]";
  }
  if (!sharding.getStaticShardedDimsOffsets().empty()) {
    os << ", sharded_dims_offsets=[";
    llvm::interleaveComma(sharding.getStaticShardedDimsOffsets(), os);
    os << "]";
  }
  os << ">";
  return os;
}
778
780
782 auto shardingOp = rhs.getDefiningOp<ShardingOp>();
783 assert(shardingOp && "expected sharding op");
784 auto splitAxes = shardingOp.getSplitAxes().getAxes();
785 // If splitAxes are empty, use "empty" constructor.
786 if (splitAxes.empty()) {
787 *this = Sharding(shardingOp.getGridAttr());
788 return;
789 }
790 *this =
791 get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
792 shardingOp.getStaticShardedDimsOffsets(),
793 SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
794 SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
795}
796
798 ArrayRef<GridAxesAttr> splitAxes,
799 ArrayRef<int64_t> staticHaloSizes,
800 ArrayRef<int64_t> staticShardedDimsOffsets,
801 ArrayRef<Value> dynamicHaloSizes,
802 ArrayRef<Value> dynamicShardedDimsOffsets) {
803 Sharding res(grid);
804 if (splitAxes.empty()) {
805 return res;
806 }
807
808 res.split_axes.resize(splitAxes.size());
809 for (auto [i, axis] : llvm::enumerate(splitAxes)) {
810 res.split_axes[i] = GridAxesAttr::get(grid.getContext(), axis.asArrayRef());
811 }
812
813 auto clone = [](const auto src, auto &dst) {
814 dst.resize(src.size());
815 llvm::copy(src, dst.begin());
816 };
817
818 clone(staticHaloSizes, res.static_halo_sizes);
819 clone(staticShardedDimsOffsets, res.static_sharded_dims_offsets);
820 clone(dynamicHaloSizes, res.dynamic_halo_sizes);
821 clone(dynamicShardedDimsOffsets, res.dynamic_sharded_dims_offsets);
822
823 return res;
824}
825
826//===----------------------------------------------------------------------===//
827// shard.shard_shape
828//===----------------------------------------------------------------------===//
829
/// Names the op's results "shard_shape" in printed IR.
void ShardShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult()[0], "shard_shape");
}
834
835void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
836 ::mlir::OperationState &odsState,
838 ArrayRef<Value> dimsDyn, ::mlir::Value sharding,
839 ::mlir::ValueRange device) {
840 SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
841 build(odsBuilder, odsState, resType, dims, dimsDyn, sharding,
842 SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
843}
844
845//===----------------------------------------------------------------------===//
846// shard.shard op
847//===----------------------------------------------------------------------===//
848
/// Names the op's result "sharding_annotated" in printed IR.
void ShardOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding_annotated");
}
853
854namespace {
855// Determine if the given ShardOp is a duplicate of another ShardOp
856// on the same value. This can happen if constant values are sharded.
/// Folds a ShardOp that duplicates an earlier ShardOp on the same value:
/// identical shardings are merged; otherwise this op is rewired to consume
/// the earlier op's result (chaining the annotations).
class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
public:
  using OpRewritePattern<ShardOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
    // Get the use-list of the value being sharded and check if it has more than
    // one use. A value defined by a ShardOp is skipped — only original
    // (e.g. constant) values are deduplicated here.
    Value value = op.getSrc();
    if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) {
      return failure();
    }

    // Iterate through the uses of the value to find a duplicate ShardOp.
    for (auto &use : value.getUses()) {
      if (use.getOwner() != op.getOperation()) {
        auto otherOp = dyn_cast<ShardOp>(use.getOwner());
        // Only fold into an op that dominates this one (same block, earlier).
        if (!otherOp || !otherOp->isBeforeInBlock(op)) {
          return failure();
        }
        // Create a Sharding object for the current and the other ShardOp
        // If the two are equal replace current op with the other op.
        Sharding currentSharding(op.getSharding());
        Sharding otherSharding(otherOp.getSharding());
        if (currentSharding == otherSharding) {
          b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
          b.eraseOp(op.getOperation());
        } else {
          // use the other sharding as input for op
          b.modifyOpInPlace(
              op, [&]() { op.getSrcMutable().assign(otherOp.getResult()); });
        }
        return success();
      }
    }

    return failure();
  }
};
895} // namespace
896
/// Registers the FoldDuplicateShardOp canonicalization for shard.shard.
void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                          mlir::MLIRContext *context) {
  results.add<FoldDuplicateShardOp>(context);
}
901
902//===----------------------------------------------------------------------===//
903// shard.process_multi_index op
904//===----------------------------------------------------------------------===//
905
906LogicalResult
907ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
908 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
909 if (failed(grid)) {
910 return failure();
911 }
912 if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
913 return failure();
914 }
915
916 size_t expectedResultsCount =
917 getAxes().empty() ? grid->getRank() : getAxes().size();
918 if (getResult().size() != expectedResultsCount) {
919 return emitError() << "Unexpected number of results " << getResult().size()
920 << ". Expected " << expectedResultsCount << ".";
921 }
922
923 return success();
924}
925
926void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
927 GridOp grid) {
928 build(odsBuilder, odsState,
929 SmallVector<Type>(grid.getRank(), odsBuilder.getIndexType()),
930 grid.getSymName(), ArrayRef<GridAxis>());
931}
932
933void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
934 StringRef grid, ArrayRef<GridAxis> axes) {
935 build(odsBuilder, odsState,
936 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
937 GridAxesAttr::get(odsBuilder.getContext(), axes));
938}
939
/// Names the op's results "proc_linear_idx" in printed IR.
void ProcessMultiIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "proc_linear_idx");
}
944
945//===----------------------------------------------------------------------===//
946// shard.process_linear_index op
947//===----------------------------------------------------------------------===//
948
949LogicalResult
950ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
951 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
952 if (failed(grid)) {
953 return failure();
954 }
955 return success();
956}
957
/// Builder taking the grid op directly; forwards its symbol name.
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, GridOp grid) {
  build(odsBuilder, odsState, grid.getSymName());
}
962
/// Names the op's result "proc_linear_idx" in printed IR.
void ProcessLinearIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "proc_linear_idx");
}
967
968//===----------------------------------------------------------------------===//
969// shard.neighbors_linear_indices op
970//===----------------------------------------------------------------------===//
971
972LogicalResult
973NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
974 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
975 if (failed(grid)) {
976 return failure();
977 }
978 return success();
979}
980
/// Names the two results "down_linear_idx" / "up_linear_idx" in printed IR.
void NeighborsLinearIndicesOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getNeighborDown(), "down_linear_idx");
  setNameFn(getNeighborUp(), "up_linear_idx");
}
986
987//===----------------------------------------------------------------------===//
988// collective communication ops
989//===----------------------------------------------------------------------===//
990
991namespace {
992
993template <typename Op>
994struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
995 using OpRewritePattern<Op>::OpRewritePattern;
996 LogicalResult matchAndRewrite(Op op,
997 PatternRewriter &rewriter) const override {
998 auto gridAxes = op.getGridAxes();
999 if (!gridAxes.empty()) {
1000 return failure();
1001 }
1002 if (op.getInput().getType() != op.getResult().getType()) {
1003 return failure();
1004 }
1005
1006 rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
1007 rewriter.eraseOp(op.getOperation());
1008 return success();
1009 }
1010};
1011
1012} // namespace
1013
1014static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
1015 ArrayRef<int64_t> device,
1016 Operation::operand_range deviceDynamic,
1017 ArrayRef<GridAxis> gridAxes,
1018 ArrayRef<int64_t> gridShape) {
1019 if (device.size() != gridAxes.size()) {
1020 return emitError(loc) << "In-group device \"" << deviceName
1021 << "\" has unexpected multi-index size "
1022 << device.size() << ". Expected " << gridAxes.size()
1023 << ".";
1024 }
1025
1026 for (size_t i = 0; i < device.size(); ++i) {
1027 if (ShapedType::isStatic(device[i]) &&
1028 ShapedType::isStatic(gridShape[gridAxes[i]]) &&
1029 gridShape[gridAxes[i]] <= device[i]) {
1030 return emitError(loc)
1031 << "Out of bounds coordinate " << i << " for in-group device \""
1032 << deviceName << "\"."
1033 << " Got " << device[i] << ", but expected value in the range [0, "
1034 << (gridShape[gridAxes[i]] - 1) << "].";
1035 }
1036 }
1037 return success();
1038}
1039
1041 int64_t expectedDimSize,
1042 int64_t resultDimSize,
1043 int64_t resultAxis) {
1044 if (ShapedType::isStatic(resultDimSize) && expectedDimSize != resultDimSize) {
1045 return emitError(loc) << "Dimension size mismatch for result axis "
1046 << resultAxis << ". Expected "
1047 << (ShapedType::isDynamic(expectedDimSize)
1048 ? Twine("dynamic")
1049 : Twine(expectedDimSize))
1050 << ", but got " << resultDimSize << ".";
1051 }
1052
1053 return success();
1054}
1055
1057 Value operand, Value result, int64_t gatherAxis,
1058 ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1059 auto resultRank = cast<ShapedType>(result.getType()).getRank();
1060 if (gatherAxis < 0 || gatherAxis >= resultRank) {
1061 return emitError(result.getLoc())
1062 << "Gather axis " << gatherAxis << " is out of bounds [0, "
1063 << resultRank << ").";
1064 }
1065
1066 ShapedType operandType = cast<ShapedType>(operand.getType());
1067 ShapedType resultType = cast<ShapedType>(result.getType());
1068 auto deviceGroupSize =
1069 DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1070 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1071 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
1072 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
1073 auto expectedResultDimSize =
1074 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
1076 result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
1077 return failure();
1078 }
1079 }
1080 return success();
1081}
1082
1084 Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
1085 ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1086 ShapedType operandType = cast<ShapedType>(operand.getType());
1087 ShapedType resultType = cast<ShapedType>(result.getType());
1088 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1089 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
1091 result.getLoc(), operandType.getDimSize(axis),
1092 resultType.getDimSize(axis), axis))) {
1093 return failure();
1094 }
1095 }
1096 }
1097
1098 if (splitAxis == concatAxis) {
1099 return success();
1100 }
1101
1102 auto deviceGroupSize =
1103 DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1104 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
1105 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
1106 DimensionSize expectedResultConcatDimSize =
1107 operandConcatDimSize * deviceGroupSize;
1108 DimensionSize expectedResultSplitDimSize =
1109 operandSplitDimSize / deviceGroupSize;
1110 if (!expectedResultSplitDimSize.isDynamic() &&
1111 int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
1112 expectedResultSplitDimSize = DimensionSize::dynamic();
1113 }
1115 result.getLoc(), expectedResultConcatDimSize.value(),
1116 resultType.getDimSize(concatAxis), concatAxis))) {
1117 return failure();
1118 }
1120 result.getLoc(), expectedResultSplitDimSize.value(),
1121 resultType.getDimSize(splitAxis), splitAxis))) {
1122 return failure();
1123 }
1124
1125 return success();
1126}
1127
1129 Value operand, Value result, int64_t tensorAxis,
1130 ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1131 ShapedType operandType = cast<ShapedType>(operand.getType());
1132 ShapedType resultType = cast<ShapedType>(result.getType());
1133 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1134 if (axis != tensorAxis) {
1136 result.getLoc(), operandType.getDimSize(axis),
1137 resultType.getDimSize(axis), axis))) {
1138 return failure();
1139 }
1140 }
1141 }
1142
1143 auto deviceGroupSize =
1144 DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1145 auto operandScatterDimSize =
1146 DimensionSize(operandType.getDimSize(tensorAxis));
1147 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1148 int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1149 return emitError(result.getLoc())
1150 << "Operand dimension size " << int64_t(operandScatterDimSize)
1151 << " is not divisible by collective device group size "
1152 << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
1153 << ".";
1154 }
1155 DimensionSize expectedResultTensorDimSize =
1156 operandScatterDimSize / deviceGroupSize;
1158 result.getLoc(), expectedResultTensorDimSize.value(),
1159 resultType.getDimSize(tensorAxis), tensorAxis))) {
1160 return failure();
1161 }
1162
1163 return success();
1164}
1165
1166static RankedTensorType sliceResultType(Type operandType, GridOp grid,
1167 ArrayRef<GridAxis> gridAxes,
1168 int64_t sliceAxis) {
1169 RankedTensorType operandRankedTensorType =
1170 cast<RankedTensorType>(operandType);
1171 DimensionSize operandSliceAxisSize =
1172 operandRankedTensorType.getShape()[sliceAxis];
1173 SmallVector<int64_t> resultShape =
1174 llvm::to_vector(operandRankedTensorType.getShape());
1175
1176 resultShape[sliceAxis] =
1177 operandSliceAxisSize /
1178 DimensionSize(collectiveProcessGroupSize(gridAxes, grid));
1179 return operandRankedTensorType.clone(resultShape);
1180}
1181
1182//===----------------------------------------------------------------------===//
1183// shard.all_gather op
1184//===----------------------------------------------------------------------===//
1185
1186LogicalResult
1187AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1188 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1189 if (failed(grid)) {
1190 return failure();
1191 }
1192 auto gatherAxis = getGatherAxis().getSExtValue();
1193 return verifyGatherOperandAndResultShape(getOperand(), getResult(),
1194 gatherAxis, getGridAxes(),
1195 grid.value().getShape());
1196}
1197
1198void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1199 MLIRContext *context) {
1200 patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
1201}
1202
1203void AllGatherOp::getAsmResultNames(
1204 function_ref<void(Value, StringRef)> setNameFn) {
1205 setNameFn(getResult(), "all_gather");
1206}
1207
1208//===----------------------------------------------------------------------===//
1209// shard.all_reduce op
1210//===----------------------------------------------------------------------===//
1211
1212LogicalResult
1213AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1214 return getGridAndVerifyAxes(*this, symbolTable);
1215}
1216
1217void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1218 MLIRContext *context) {
1219 patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
1220}
1221
1222void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1223 Value input, StringRef grid,
1224 ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
1225 build(odsBuilder, odsState, input.getType(), grid, gridAxes, input,
1226 reduction);
1227}
1228
1229void AllReduceOp::getAsmResultNames(
1230 function_ref<void(Value, StringRef)> setNameFn) {
1231 setNameFn(getResult(), "all_reduce");
1232}
1233
1234//===----------------------------------------------------------------------===//
1235// shard.all_slice op
1236//===----------------------------------------------------------------------===//
1237
1238LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1239 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1240 if (failed(grid)) {
1241 return failure();
1242 }
1244 getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
1245 grid.value().getShape());
1246}
1247
1248void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1249 MLIRContext *context) {
1250 patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
1251}
1252
1253void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1254 Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
1255 int64_t sliceAxis) {
1256 Type resultType = sliceResultType(input.getType(), grid, gridAxes, sliceAxis);
1257 build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
1258 sliceAxis);
1259}
1260
1261void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1262 Type resultType, Value input, StringRef grid,
1263 ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
1264 build(odsBuilder, odsState, resultType, grid, gridAxes, input,
1265 APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1266}
1267
1268void AllSliceOp::getAsmResultNames(
1269 function_ref<void(Value, StringRef)> setNameFn) {
1270 setNameFn(getResult(), "all_slice");
1271}
1272
1273//===----------------------------------------------------------------------===//
1274// shard.all_to_all op
1275//===----------------------------------------------------------------------===//
1276
1277LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1278 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1279 if (failed(grid)) {
1280 return failure();
1281 }
1282
1284 getOperand(), getResult(), getSplitAxis().getSExtValue(),
1285 getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
1286}
1287
1288void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1289 MLIRContext *context) {
1290 patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
1291}
1292
1293void AllToAllOp::getAsmResultNames(
1294 function_ref<void(Value, StringRef)> setNameFn) {
1295 setNameFn(getResult(), "all_to_all");
1296}
1297
1298//===----------------------------------------------------------------------===//
1299// shard.broadcast op
1300//===----------------------------------------------------------------------===//
1301
1302LogicalResult
1303BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1304 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1305 if (failed(grid)) {
1306 return failure();
1307 }
1308 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1309 getRootDynamic(), getGridAxes(),
1310 grid.value().getShape()))) {
1311 return failure();
1312 }
1313
1314 return success();
1315}
1316
1317void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1318 MLIRContext *context) {
1319 patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
1320}
1321
1322void BroadcastOp::getAsmResultNames(
1323 function_ref<void(Value, StringRef)> setNameFn) {
1324 setNameFn(getResult(), "broadcast");
1325}
1326
1327//===----------------------------------------------------------------------===//
1328// shard.gather op
1329//===----------------------------------------------------------------------===//
1330
1331LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1332 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1333 if (failed(grid)) {
1334 return failure();
1335 }
1336 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1337 getRootDynamic(), getGridAxes(),
1338 grid.value().getShape()))) {
1339 return failure();
1340 }
1341
1342 auto gatherAxis = getGatherAxis().getSExtValue();
1343 return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
1344 getGridAxes(),
1345 grid.value().getShape());
1346}
1347
1348void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1349 MLIRContext *context) {
1350 patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
1351}
1352
1353void GatherOp::getAsmResultNames(
1354 function_ref<void(Value, StringRef)> setNameFn) {
1355 setNameFn(getResult(), "gather");
1356}
1357
1358//===----------------------------------------------------------------------===//
1359// shard.recv op
1360//===----------------------------------------------------------------------===//
1361
1362LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1363 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1364 if (failed(grid)) {
1365 return failure();
1366 }
1367 if (getSource() &&
1368 failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
1369 getSource().value(), getSourceDynamic(),
1370 getGridAxes(), grid.value().getShape()))) {
1371 return failure();
1372 }
1373 return success();
1374}
1375
1376void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1377 MLIRContext *context) {
1378 patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
1379}
1380
1381void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1382 setNameFn(getResult(), "recv");
1383}
1384
1385//===----------------------------------------------------------------------===//
1386// shard.reduce op
1387//===----------------------------------------------------------------------===//
1388
1389LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1390 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1391 if (failed(grid)) {
1392 return failure();
1393 }
1394 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1395 getRootDynamic(), getGridAxes(),
1396 grid.value().getShape()))) {
1397 return failure();
1398 }
1399
1400 return success();
1401}
1402
1403void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1404 MLIRContext *context) {
1405 patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
1406}
1407
1408void ReduceOp::getAsmResultNames(
1409 function_ref<void(Value, StringRef)> setNameFn) {
1410 setNameFn(getResult(), "reduce");
1411}
1412
1413//===----------------------------------------------------------------------===//
1414// shard.reduce_scatter op
1415//===----------------------------------------------------------------------===//
1416
1417LogicalResult
1418ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1419 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1420 if (failed(grid)) {
1421 return failure();
1422 }
1423
1425 getOperand(), getResult(), getScatterDim().getSExtValue(), getGridAxes(),
1426 grid.value().getShape());
1427}
1428
1429void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1430 MLIRContext *context) {
1431 patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1432}
1433
1434void ReduceScatterOp::getAsmResultNames(
1435 function_ref<void(Value, StringRef)> setNameFn) {
1436 setNameFn(getResult(), "reduce_scatter");
1437}
1438
1439//===----------------------------------------------------------------------===//
1440// shard.scatter op
1441//===----------------------------------------------------------------------===//
1442
1443LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1444 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1445 if (failed(grid)) {
1446 return failure();
1447 }
1448 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1449 getRootDynamic(), getGridAxes(),
1450 grid.value().getShape()))) {
1451 return failure();
1452 }
1453
1454 auto scatterDim = getScatterDim().getSExtValue();
1455 return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
1456 scatterDim, getGridAxes(),
1457 grid.value().getShape());
1458}
1459
1460void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1461 MLIRContext *context) {
1462 patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
1463}
1464
1465void ScatterOp::getAsmResultNames(
1466 function_ref<void(Value, StringRef)> setNameFn) {
1467 setNameFn(getResult(), "scatter");
1468}
1469
1470//===----------------------------------------------------------------------===//
1471// shard.send op
1472//===----------------------------------------------------------------------===//
1473
1474LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1475 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1476 if (failed(grid)) {
1477 return failure();
1478 }
1479 if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
1480 getDestination(), getDestinationDynamic(),
1481 getGridAxes(), grid.value().getShape()))) {
1482 return failure();
1483 }
1484 return success();
1485}
1486
1487void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1488 MLIRContext *context) {
1489 patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);
1490}
1491
1492void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1493 setNameFn(getResult(), "send");
1494}
1495
1496//===----------------------------------------------------------------------===//
1497// shard.shift op
1498//===----------------------------------------------------------------------===//
1499
1500LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1501 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1502 if (failed(grid)) {
1503 return failure();
1504 }
1505
1506 auto gridAxes = getGridAxes();
1507 auto shiftAxis = getShiftAxis().getZExtValue();
1508 if (!llvm::is_contained(gridAxes, shiftAxis)) {
1509 return emitError() << "Invalid shift axis " << shiftAxis
1510 << ". It must be one of the grouping grid axes.";
1511 }
1512
1513 return success();
1514}
1515
1516void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1517 MLIRContext *context) {
1518 // TODO: remove op when offset is 0 or if it is a rotate with and
1519 // offset % shift_axis_grid_dim_size == 0.
1520}
1521
1522void ShiftOp::getAsmResultNames(
1523 function_ref<void(Value, StringRef)> setNameFn) {
1524 setNameFn(getResult(), "shift");
1525}
1526
1527//===----------------------------------------------------------------------===//
1528// shard.update_halo op
1529//===----------------------------------------------------------------------===//
1530
1531LogicalResult
1532UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1533 auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
1534 if (failed(grid)) {
1535 return failure();
1536 }
1537
1538 return success();
1539}
1540
1541//===----------------------------------------------------------------------===//
1542// TableGen'd op method definitions
1543//===----------------------------------------------------------------------===//
1544
1545#define GET_OP_CLASSES
1546#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
1547
1548#define GET_ATTRDEF_CLASSES
1549#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
1550
1551#define GET_TYPEDEF_CLASSES
1552#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
1553
1554#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition ShardOps.cpp:62
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static FailureOr< GridOp > getGridAndVerify(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable)
Definition ShardOps.cpp:148
static LogicalResult verifyGridAxes(Location loc, ArrayRef< GridAxis > axes, GridOp grid)
Definition ShardOps.cpp:177
static FailureOr< GridOp > getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition ShardOps.cpp:200
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
Definition ShardOps.cpp:214
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
Definition ShardOps.cpp:299
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:423
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
Location getLoc()
The source location the operation was defined or derived from.
This class provides the API for a sub-set of ops that are known to be constant-like.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
OperandRange operand_range
Definition Operation.h:397
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition Value.cpp:91
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
ArrayRef< Value > getDynamicShardedDimsOffsets() const
Definition ShardOps.h:69
bool operator!=(Value rhs) const
Definition ShardOps.cpp:748
bool equalShardSizes(const Sharding &rhs) const
Definition ShardOps.cpp:716
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
Definition ShardOps.cpp:779
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition ShardOps.cpp:797
bool equalSplitAxes(const Sharding &rhs) const
Definition ShardOps.cpp:693
ArrayRef< int64_t > getStaticHaloSizes() const
Definition ShardOps.h:64
::mlir::FlatSymbolRefAttr getGridAttr() const
Definition ShardOps.h:61
::llvm::StringRef getGrid() const
Definition ShardOps.h:62
bool equalHaloAndShardSizes(const Sharding &rhs) const
Definition ShardOps.cpp:712
bool operator==(Value rhs) const
Definition ShardOps.cpp:744
ArrayRef< Value > getDynamicHaloSizes() const
Definition ShardOps.h:68
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition ShardOps.h:65
ArrayRef< GridAxesAttr > getSplitAxes() const
Definition ShardOps.h:63
bool equalHaloSizes(const Sharding &rhs) const
Definition ShardOps.cpp:732
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
shard::Sharding Sharding
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const Sharding &sharding)
Definition ShardOps.cpp:756
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:281
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
Definition ShardOps.cpp:338
DenseI16ArrayAttr GridAxesAttr
Definition ShardOps.h:28
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:123
bool isFullReplication(Sharding sharding)
Definition ShardOps.h:116
int16_t GridAxis
Definition ShardOps.h:27
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
Definition ShardOps.cpp:352
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:178
Type shardType(Type type, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:291
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
Definition ShardOps.cpp:77
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Definition ShardOps.h:156
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition AffineExpr.h:252
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.