MLIR 23.0.0git
ShardOps.cpp
Go to the documentation of this file.
1//===- ShardOps.cpp - Shard Dialect Operations ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
14#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
24#include "mlir/IR/Value.h"
26#include "mlir/Support/LLVM.h"
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include <algorithm>
34#include <functional>
35#include <iterator>
36#include <numeric>
37#include <optional>
38#include <utility>
39
40#define DEBUG_TYPE "shard-ops"
41
42using namespace mlir;
43using namespace mlir::shard;
44
45#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
46
47namespace {
48
49struct DimensionSize {
50 static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
51 DimensionSize(int64_t val) : val(val) {}
52 int64_t value() const { return val; }
53 operator int64_t() const { return val; }
54 bool isDynamic() const { return ShapedType::isDynamic(val); }
55
56private:
57 int64_t val;
58};
59
60} // namespace
61
62static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
63 if (lhs.isDynamic() || rhs.isDynamic()) {
64 return DimensionSize::dynamic();
65 }
66 return lhs.value() / rhs.value();
67}
68
69static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
70 if (lhs.isDynamic() || rhs.isDynamic()) {
71 return DimensionSize::dynamic();
72 }
73 return lhs.value() * rhs.value();
74}
75
79 ValueRange dynamics, Type type) {
80 SmallVector<Value> values;
81 auto dyn = dynamics.begin();
82 Type i64 = b.getI64Type();
83 if (!type)
84 type = i64;
85 assert((i64 == type || b.getIndexType() == type) &&
86 "expected an i64 or an intex type");
87 for (auto s : statics) {
88 if (s == ShapedType::kDynamic) {
89 values.emplace_back(*(dyn++));
90 } else {
91 TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
92 values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
93 }
94 }
95 return values;
96}
97
98//===----------------------------------------------------------------------===//
99// Inliner
100//===----------------------------------------------------------------------===//
101
102namespace {
/// Dialect inliner interface for the shard dialect: every hook
/// unconditionally returns true, i.e. all ops and regions may be inlined.
// NOTE(review): the lowercase 'i' in "Shardinlinerinterface" looks like a
// typo, but the name is referenced in ShardDialect::initialize, so renaming
// would have to happen in both places at once.
struct ShardInlinerinterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;
  // Currently no restrictions are encoded for inlining.
  bool isLegalToInline(Operation *, Operation *, bool) const final {
    return true;
  }
  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
    return true;
  }
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
116} // namespace
117
118//===----------------------------------------------------------------------===//
119// Shard dialect
120//===----------------------------------------------------------------------===//
121
/// Registers all ODS-generated ops, attributes and types of the shard
/// dialect, plus the dialect's inliner interface.
void ShardDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
      >();
  addInterface<ShardInlinerinterface>();
}
137
/// Folder hook: materialize a constant attribute as an operation by
/// delegating to arith::ConstantOp::materialize.
Operation *ShardDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
143
144//===----------------------------------------------------------------------===//
145// Shard utilities
146//===----------------------------------------------------------------------===//
147
148static FailureOr<GridOp> getGridAndVerify(Operation *op,
149 FlatSymbolRefAttr gridSymbol,
150 SymbolTableCollection &symbolTable) {
151 shard::GridOp grid = getGridOrNull(op, gridSymbol, symbolTable);
152 if (!grid) {
153 return op->emitError() << "Undefined required grid symbol \""
154 << gridSymbol.getValue() << "\".";
155 }
156
157 return grid;
158}
159
/// Returns true if the (sorted) range [begin, end) contains no two equal
/// adjacent elements, i.e. all elements of a sorted range are unique.
/// Empty and single-element ranges are trivially unique.
template <typename It>
static bool isUnique(It begin, It end) {
  // std::adjacent_find returns `end` iff no two consecutive elements compare
  // equal — exactly the hand-rolled loop this replaces.
  return std::adjacent_find(begin, end) == end;
}
176
177static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes,
178 GridOp grid) {
179 SmallVector<GridAxis> sorted = llvm::to_vector(axes);
180 llvm::sort(sorted);
181 if (!isUnique(sorted.begin(), sorted.end())) {
182 return emitError(loc) << "Grid axes contains duplicate elements.";
183 }
184
185 GridAxis rank = grid.getRank();
186 for (auto axis : axes) {
187 if (axis >= rank || axis < 0) {
188 return emitError(loc)
189 << "0-based grid axis index " << axis
190 << " is out of bounds. The referenced grid \"" << grid.getSymName()
191 << "\" is of rank " << rank << ".";
192 }
193 }
194
195 return success();
196}
197
198template <typename Op>
199static FailureOr<GridOp>
201 auto grid =
202 ::getGridAndVerify(op.getOperation(), op.getGridAttr(), symbolTable);
203 if (failed(grid)) {
204 return failure();
205 }
206 if (failed(verifyGridAxes(op.getLoc(), op.getGridAxes(), grid.value()))) {
207 return failure();
208 }
209 return grid;
210}
211
/// Computes the per-shard (local) shape `outShape` from the global shape
/// `inShape`, the device-grid shape, and the tensor-axis -> grid-axes split
/// mapping. Optionally honors explicit per-shard offsets or halo sizes.
/// Axes whose local size cannot be determined statically are set to
/// ShapedType::kDynamic.
template <typename InShape, typename GridShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const GridShape &gridShape,
                       const SplitAxes &splitAxes, OutShape &outShape,
                       ArrayRef<int64_t> shardedDimsOffsets = {},
                       ArrayRef<int64_t> haloSizes = {}) {
  // 0d tensors cannot be sharded and must get replicated
  if (inShape.empty()) {
    assert(outShape.empty());
    return;
  }

  // Start from the global shape; split axes are overwritten below.
  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));

  if (!shardedDimsOffsets.empty()) {
    // Explicit, possibly non-uniform per-shard offsets were given.
    auto isDynShape = ShapedType::isDynamicShape(gridShape);
    // Cursor into the flat offsets table. Starts at 1; entry 0 presumably
    // holds the leading offset of the first sharded dim — TODO confirm.
    uint64_t pos = 1;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (!innerSplitAxes.empty()) {
        auto sz = shardedDimsOffsets[pos];
        bool same = !isDynShape;
        if (same) {
          // Find sharded dims in shardedDimsOffsets with same static size on
          // all devices. Use kDynamic for dimensions with dynamic or
          // non-uniform offs in shardedDimsOffsets.
          // NOTE(review): numShards SUMS (not multiplies) the grid dims of
          // the split axes; this matches the offsets-table layout walked by
          // ShardingOp::verifySymbolUses — confirm for multi-axis splits.
          uint64_t numShards = 0;
          for (auto i : innerSplitAxes.asArrayRef()) {
            numShards += gridShape[i];
          }
          for (size_t i = 1; i < numShards; ++i) {
            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
                sz) {
              same = false;
              break;
            }
          }
          pos += numShards + 1;
        }
        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
      }
    }
  } else {
    // Uniform sharding: divide each split axis by its process-group size.
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      outShape[tensorAxis] = shardDimension(
          inShape[tensorAxis],
          collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), gridShape));
    }

    if (!haloSizes.empty()) {
      // add halo sizes if requested
      int haloAxis = 0;
      for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
        if (ShapedType::isStatic(outShape[tensorAxis]) &&
            !innerSplitAxes.empty()) {
          if (haloSizes[haloAxis * 2] >= 0 &&
              haloSizes[haloAxis * 2 + 1] >= 0) {
            // Static lower/upper halos extend the local size.
            outShape[tensorAxis] +=
                haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
            ++haloAxis;
          } else {
            // A negative halo entry marks a dynamic halo: size unknown.
            outShape[tensorAxis] = ShapedType::kDynamic;
          }
        }
      }
    }
  }
}
280
281ShapedType shard::shardShapedType(ShapedType shape, GridOp grid,
282 Sharding sharding) {
283 using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
284 SmallVector<Dim> resShapeArr(shape.getShape().size());
285 shardShape(shape.getShape(), grid.getShape(), sharding.getSplitAxes(),
286 resShapeArr, sharding.getStaticShardedDimsOffsets(),
287 sharding.getStaticHaloSizes());
288 return shape.clone(resShapeArr);
289}
290
291Type shard::shardType(Type type, GridOp grid, Sharding sharding) {
292 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
293 if (rankedTensorType && !rankedTensorType.getShape().empty()) {
294 return shardShapedType(rankedTensorType, grid, sharding);
295 }
296 return type;
297}
298
300 Value &operandValue,
301 Operation *operandOp,
302 OpBuilder &builder,
303 ShardOp &newShardOp) {
304 OpBuilder::InsertionGuard insertionGuard(builder);
305 builder.setInsertionPointAfterValue(operandValue);
306 ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
307 if (shardOp && sharding == shardOp.getSharding() &&
308 !shardOp.getAnnotateForUsers()) {
309 // No need for anything if the correct sharding is already set.
310 if (!newShardOp) {
311 newShardOp = shardOp;
312 }
313 return;
314 }
315
316 if (!newShardOp) {
317 auto shardingOp =
318 ShardingOp::create(builder, operandValue.getLoc(), sharding);
319 newShardOp = ShardOp::create(builder, operandValue.getLoc(), operandValue,
320 shardingOp,
321 /*annotate_for_users*/ false);
322 }
323 operandValue.replaceUsesWithIf(
324 newShardOp, [operandOp, operandValue](OpOperand &use) {
325 return use.getOwner() == operandOp && use.get() == operandValue;
326 });
327
328 if (!shardOp || shardOp.getAnnotateForUsers()) {
329 return;
330 }
331
332 auto newShardOp2 = ShardOp::create(builder, operandValue.getLoc(), newShardOp,
333 newShardOp.getSharding(),
334 /*annotate_for_users*/ true);
335 newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
336}
337
340 OpBuilder &builder) {
341 ShardOp newShardOp;
343 for (auto &use : result.getUses()) {
344 uses.emplace_back(use.get(), use.getOwner());
345 }
346 for (auto &[operandValue, operandOp] : uses) {
347 maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
348 builder, newShardOp);
349 }
350}
351
353 OpOperand &operand,
354 OpBuilder &builder) {
355 OpBuilder::InsertionGuard insertionGuard(builder);
356 Value operandValue = operand.get();
357 Operation *operandSrcOp = operandValue.getDefiningOp();
358 bool isBlockArg = !operandSrcOp;
359 {
360 [[maybe_unused]] auto opType =
361 dyn_cast<mlir::RankedTensorType>(operandValue.getType());
362 assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
363 }
364 if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
365 operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
366 return;
367 }
368
369 Operation *operandOp = operand.getOwner();
370 ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
371
372 if (shardOp && sharding == shardOp.getSharding() &&
373 shardOp.getAnnotateForUsers()) {
374 // No need for anything the correct sharding is already set.
375 return;
376 }
377
378 builder.setInsertionPoint(operandOp);
379 auto shardingOp =
380 ShardingOp::create(builder, operand.get().getLoc(), sharding);
381 auto newShardOp =
382 ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
383 /*annotate_for_users*/ true);
384 IRRewriter rewriter(builder);
385 rewriter.replaceUsesWithIf(
386 operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
387 return use.getOwner() == operandOp && use.get() == operandValue;
388 });
389
390 if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
391 // No need for resharding.
392 return;
393 }
394
395 builder.setInsertionPoint(newShardOp);
396 auto newPreceedingShardOp =
397 ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
398 /*annotate_for_users*/ false);
399 rewriter.replaceUsesWithIf(
400 newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
401 return use.getOwner() == newShardOp.getOperation();
402 });
403}
404
405//===----------------------------------------------------------------------===//
406// shard.grid op
407//===----------------------------------------------------------------------===//
408
409LogicalResult GridOp::verify() {
410 int64_t rank = getRank();
411
412 if (rank <= 0)
413 return emitOpError("rank of grid is expected to be a positive integer");
414
415 for (int64_t dimSize : getShape()) {
416 if (dimSize < 0 && ShapedType::isStatic(dimSize))
417 return emitOpError("dimension size of a grid is expected to be "
418 "non-negative or dynamic");
419 }
420
421 return success();
422}
423
424//===----------------------------------------------------------------------===//
425// shard.grid_shape op
426//===----------------------------------------------------------------------===//
427
428LogicalResult
429GridShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
430 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
431 if (failed(grid)) {
432 return failure();
433 }
434 if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
435 return failure();
436 }
437
438 size_t expectedResultsCount =
439 getAxes().empty() ? grid->getRank() : getAxes().size();
440 if (getResult().size() != expectedResultsCount) {
441 return emitError() << "Unexpected number of results " << getResult().size()
442 << ". Expected " << expectedResultsCount << ".";
443 }
444
445 return success();
446}
447
/// Convenience builder: query the full shape of `grid` (an empty axes list
/// selects all axes).
void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        GridOp grid) {
  build(odsBuilder, odsState, grid, SmallVector<GridAxis>());
}
452
453void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
454 GridOp grid, ArrayRef<GridAxis> axes) {
455 build(odsBuilder, odsState,
456 SmallVector<Type>(axes.empty() ? grid.getRank() : axes.size(),
457 odsBuilder.getIndexType()),
458 grid.getSymName(), GridAxesAttr::get(odsBuilder.getContext(), axes));
459}
460
461void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
462 StringRef grid, ArrayRef<GridAxis> axes) {
463 assert(!axes.empty());
464 build(odsBuilder, odsState,
465 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
466 GridAxesAttr::get(odsBuilder.getContext(), axes));
467}
468
/// Uses "grid_shape" as the base name for this op's results in printed IR.
void GridShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "grid_shape");
}
473
474//===----------------------------------------------------------------------===//
475// shard.sharding
476//===----------------------------------------------------------------------===//
477
478void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
480 ArrayRef<int64_t> staticHalos,
481 ArrayRef<int64_t> staticOffsets) {
482 return build(
483 b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
484 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
485 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {});
486}
487
488void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
489 llvm::StringRef grid, ArrayRef<GridAxesAttr> splitAxes,
490 ArrayRef<int64_t> staticHalos,
491 ArrayRef<int64_t> staticOffsets) {
492 return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
493 GridAxesArrayAttr::get(b.getContext(), splitAxes),
494 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
495 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets),
496 {});
497}
498
499void ShardingOp::build(
503 ::mlir::ArrayRef<::mlir::OpFoldResult> shardedDimsOffsets) {
504 mlir::SmallVector<int64_t> staticHalos, staticDims;
505 mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
506 dispatchIndexOpFoldResults(haloSizes, dynamicHalos, staticHalos);
507 dispatchIndexOpFoldResults(shardedDimsOffsets, dynamicDims, staticDims);
508 return build(
509 b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
510 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
511 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
512}
513
514void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
516
517 build(b, odsState, ShardingType::get(b.getContext()), from.getGridAttr(),
518 GridAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
519 from.getStaticShardedDimsOffsets().empty()
521 : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
523 from.getStaticHaloSizes().empty()
525 : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
526 from.getDynamicHaloSizes());
527}
528
529LogicalResult ShardingOp::verify() {
530 llvm::SmallSet<GridAxis, 4> visitedAxes;
531
532 auto checkGridAxis = [&](ArrayRef<GridAxis> axesArray) -> LogicalResult {
533 for (GridAxis axis : axesArray) {
534 if (axis < 0)
535 return emitError() << "grid axis is expected to be non-negative";
536 if (!visitedAxes.insert(axis).second)
537 return emitError() << "grid axis duplicated";
538 }
539 return success();
540 };
541
542 for (auto subAxes : getSplitAxes().getAxes()) {
543 ArrayRef<GridAxis> subAxesArray = subAxes.asArrayRef();
544 if (failed(checkGridAxis(subAxesArray)))
545 return failure();
546 }
547
548 if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
549 return emitOpError("halo sizes and shard offsets are mutually exclusive");
550 }
551
552 if (!getStaticHaloSizes().empty()) {
553 auto numSplitAxes = getSplitAxes().getAxes().size();
554 for (auto splitAxis : getSplitAxes().getAxes()) {
555 if (splitAxis.empty()) {
556 --numSplitAxes;
557 }
558 }
559 if (getStaticHaloSizes().size() != numSplitAxes * 2) {
560 return emitError() << "halo sizes must be specified for all split axes.";
561 }
562 }
563
564 return success();
565}
566
/// Names the sharding result "sharding" in printed IR.
void ShardingOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding");
}
571
/// Symbol-level verification: resolves the grid and, when sharded-dims
/// offsets are present, checks the flat offsets table has the right length
/// and each dimension's offsets are non-decreasing.
LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
  if (failed(grid)) {
    return failure();
  }
  // Offsets-based sharding requires a fully static grid shape: the table
  // length below depends on the concrete grid dimension sizes.
  if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
      !getStaticShardedDimsOffsets().empty()) {
    return emitError() << "sharded dims offsets are not allowed for "
                          "device grids with dynamic shape.";
  }

  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
  if (!shardedDimsOffsets.empty()) {
    auto gridShape = grid.value().getShape();
    assert(ShapedType::isStaticShape(gridShape));
    uint64_t pos = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
      if (!innerSplitAxes.empty()) {
        // NOTE(review): shard count SUMS the split grid dims; this matches
        // shardShape's consumption of the offsets table — confirm for
        // multi-axis splits.
        int64_t numShards = 0, off = 0;
        for (auto i : innerSplitAxes.asArrayRef()) {
          numShards += gridShape[i];
        }
        // Each sharded tensor axis contributes numShards + 1 offsets.
        for (int64_t i = 0; i <= numShards; ++i) {
          if (shardedDimsOffsets.size() <= pos + i) {
            return emitError() << "sharded dims offsets has wrong size.";
          }
          // Dynamic entries are skipped; static ones must be monotone.
          if (ShapedType::isStatic(shardedDimsOffsets[pos + i])) {
            if (shardedDimsOffsets[pos + i] < off) {
              return emitError()
                     << "sharded dims offsets must be non-decreasing.";
            }
            off = shardedDimsOffsets[pos + i];
          }
        }
        pos += numShards + 1;
      }
    }
  }
  return success();
}
612
613namespace {
614// Sharding annotations "halo sizes" and "sharded dims offsets"
615// are a mix of attributes and dynamic values. This canonicalization moves
616// constant values to the respective attribute lists, minimizing the number
617// of values.
618// It also removes sharded_dims_sizes and halos if they are effectively "empty".
/// Canonicalizes a ShardingOp's halo-size and sharded-dims-offset lists:
/// folds constant dynamic values into the static attribute lists and drops
/// lists that carry no information (all-zero halos, equidistant offsets).
class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
public:
  using OpRewritePattern<ShardingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardingOp op,
                                PatternRewriter &b) const override {
    // Merge the static and dynamic operands into one mixed list per kind.
    auto mixedHalos =
        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
                                    op.getDynamicShardedDimsOffsets(), b);

    // Fold dynamic values that are actually constants into the static lists,
    // tracking whether anything changed.
    bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
                    succeeded(foldDynamicIndexList(mixedOffs, true));

    auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
    auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);

    // All-zero static halos carry no information; drop them entirely.
    if (dynamicHalos.empty() && !staticHalos.empty()) {
      if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
        staticHalos.clear();
        modified = true;
      }
    }

    // Remove sharded dims offsets if they are effectively the default values,
    // e.g. if they define equi-distance between all neighboring shards.
    // Requires static-only offsets. Compares the first distance as the
    // difference between the first two offsets. Only if all consecutive
    // distances are the same, the offsets are removed.
    if (dynamicOffs.empty() && !staticOffs.empty()) {
      assert(staticOffs.size() >= 2);
      auto diff = staticOffs[1] - staticOffs[0];
      // With exactly 2 offsets there is a single shard; keep them as-is.
      bool allSame = staticOffs.size() > 2;
      for (auto i = 2u; i < staticOffs.size(); ++i) {
        if (staticOffs[i] - staticOffs[i - 1] != diff) {
          allSame = false;
          break;
        }
      }
      if (allSame) {
        staticOffs.clear();
        modified = true;
      }
    }

    if (!modified) {
      return failure();
    }

    // Write the normalized lists back onto the op in place.
    op.setStaticHaloSizes(staticHalos);
    op.getDynamicHaloSizesMutable().assign(dynamicHalos);
    op.setStaticShardedDimsOffsets(staticOffs);
    op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);

    return success();
  }
};
677} // namespace
678
/// Registers the NormalizeSharding canonicalization pattern (defined above).
void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                             mlir::MLIRContext *context) {
  results.add<NormalizeSharding>(context);
}
683
684//===----------------------------------------------------------------------===//
685// Sharding
686//===----------------------------------------------------------------------===//
687
689 if (getGrid() != rhs.getGrid()) {
690 return false;
691 }
692
693 auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
694 if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
695 getSplitAxes().begin() + minSize),
696 llvm::make_range(rhs.getSplitAxes().begin(),
697 rhs.getSplitAxes().begin() + minSize))) {
698 return false;
699 }
700
701 return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
702 std::mem_fn(&GridAxesAttr::empty)) &&
703 llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
704 std::mem_fn(&GridAxesAttr::empty));
705}
706
710
712 if (rhs.getStaticShardedDimsOffsets().size() !=
714 !llvm::equal(getStaticShardedDimsOffsets(),
715 rhs.getStaticShardedDimsOffsets())) {
716 return false;
717 }
718 if (rhs.getDynamicShardedDimsOffsets().size() !=
720 !llvm::equal(getDynamicShardedDimsOffsets(),
721 rhs.getDynamicShardedDimsOffsets())) {
722 return false;
723 }
724 return true;
725}
726
728 if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
729 !llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes())) {
730 return false;
731 }
732 if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
733 !llvm::equal(getDynamicHaloSizes(), rhs.getDynamicHaloSizes())) {
734 return false;
735 }
736 return true;
737}
738
742
743bool Sharding::operator!=(Value rhs) const { return !(*this == rhs); }
744
748
749bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }
750
/// Debug printer: renders a Sharding as
/// Sharding<grid=..., split_axes=[[..],..][, halo_sizes=[..]]
/// [, sharded_dims_offsets=[..]]>. Only the static halo/offset lists are
/// printed; dynamic values are omitted.
llvm::raw_ostream &mlir::shard::operator<<(llvm::raw_ostream &os,
                                           const Sharding &sharding) {
  os << "Sharding<grid=" << sharding.getGrid() << ", split_axes=[";
  llvm::interleaveComma(sharding.getSplitAxes(), os, [&](GridAxesAttr axes) {
    os << "[";
    llvm::interleaveComma(axes.asArrayRef(), os);
    os << "]";
  });
  os << "]";
  if (!sharding.getStaticHaloSizes().empty()) {
    os << ", halo_sizes=[";
    llvm::interleaveComma(sharding.getStaticHaloSizes(), os);
    os << "]";
  }
  if (!sharding.getStaticShardedDimsOffsets().empty()) {
    os << ", sharded_dims_offsets=[";
    llvm::interleaveComma(sharding.getStaticShardedDimsOffsets(), os);
    os << "]";
  }
  os << ">";
  return os;
}
773
775
777 auto shardingOp = rhs.getDefiningOp<ShardingOp>();
778 assert(shardingOp && "expected sharding op");
779 auto splitAxes = shardingOp.getSplitAxes().getAxes();
780 // If splitAxes are empty, use "empty" constructor.
781 if (splitAxes.empty()) {
782 *this = Sharding(shardingOp.getGridAttr());
783 return;
784 }
785 *this =
786 get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
787 shardingOp.getStaticShardedDimsOffsets(),
788 SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
789 SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
790}
791
793 ArrayRef<GridAxesAttr> splitAxes,
794 ArrayRef<int64_t> staticHaloSizes,
795 ArrayRef<int64_t> staticShardedDimsOffsets,
796 ArrayRef<Value> dynamicHaloSizes,
797 ArrayRef<Value> dynamicShardedDimsOffsets) {
798 Sharding res(grid);
799 if (splitAxes.empty()) {
800 return res;
801 }
802
803 res.split_axes.resize(splitAxes.size());
804 for (auto [i, axis] : llvm::enumerate(splitAxes)) {
805 res.split_axes[i] = GridAxesAttr::get(grid.getContext(), axis.asArrayRef());
806 }
807
808 auto clone = [](const auto src, auto &dst) {
809 dst.resize(src.size());
810 llvm::copy(src, dst.begin());
811 };
812
813 clone(staticHaloSizes, res.static_halo_sizes);
814 clone(staticShardedDimsOffsets, res.static_sharded_dims_offsets);
815 clone(dynamicHaloSizes, res.dynamic_halo_sizes);
816 clone(dynamicShardedDimsOffsets, res.dynamic_sharded_dims_offsets);
817
818 return res;
819}
820
821//===----------------------------------------------------------------------===//
822// shard.shard_shape
823//===----------------------------------------------------------------------===//
824
/// Uses "shard_shape" as the base name for this op's results in printed IR.
void ShardShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult()[0], "shard_shape");
}
829
830void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
831 ::mlir::OperationState &odsState,
833 ArrayRef<Value> dimsDyn, ::mlir::Value sharding,
834 ::mlir::ValueRange device) {
835 SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
836 build(odsBuilder, odsState, resType, dims, dimsDyn, sharding,
837 SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
838}
839
840//===----------------------------------------------------------------------===//
841// shard.shard op
842//===----------------------------------------------------------------------===//
843
/// Names the annotated result "sharding_annotated" in printed IR.
void ShardOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding_annotated");
}
848
849namespace {
850// Determine if the given ShardOp is a duplicate of another ShardOp
851// on the same value. This can happen if constant values are sharded.
/// Folds a ShardOp that duplicates an earlier ShardOp on the same source
/// value (which can happen when constants are sharded): equal shardings are
/// merged, differing ones are chained.
class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
public:
  using OpRewritePattern<ShardOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
    // Get the use-list of the value being sharded and check if it has more than
    // one use. Values already produced by a ShardOp are left alone.
    Value value = op.getSrc();
    if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) {
      return failure();
    }

    // Iterate through the uses of the value to find a duplicate ShardOp.
    for (auto &use : value.getUses()) {
      if (use.getOwner() != op.getOperation()) {
        // Only fold against another ShardOp that dominates `op` in the block.
        auto otherOp = dyn_cast<ShardOp>(use.getOwner());
        if (!otherOp || !otherOp->isBeforeInBlock(op)) {
          return failure();
        }
        // Create a Sharding object for the current and the other ShardOp
        // If the two are equal replace current op with the other op.
        Sharding currentSharding(op.getSharding());
        Sharding otherSharding(otherOp.getSharding());
        if (currentSharding == otherSharding) {
          b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
          b.eraseOp(op.getOperation());
        } else {
          // use the other sharding as input for op
          op.getSrcMutable().assign(otherOp.getResult());
        }
        return success();
      }
    }

    return failure();
  }
};
889} // namespace
890
/// Registers the FoldDuplicateShardOp canonicalization pattern.
void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                          mlir::MLIRContext *context) {
  results.add<FoldDuplicateShardOp>(context);
}
895
896//===----------------------------------------------------------------------===//
897// shard.process_multi_index op
898//===----------------------------------------------------------------------===//
899
900LogicalResult
901ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
902 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
903 if (failed(grid)) {
904 return failure();
905 }
906 if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
907 return failure();
908 }
909
910 size_t expectedResultsCount =
911 getAxes().empty() ? grid->getRank() : getAxes().size();
912 if (getResult().size() != expectedResultsCount) {
913 return emitError() << "Unexpected number of results " << getResult().size()
914 << ". Expected " << expectedResultsCount << ".";
915 }
916
917 return success();
918}
919
920void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
921 GridOp grid) {
922 build(odsBuilder, odsState,
923 SmallVector<Type>(grid.getRank(), odsBuilder.getIndexType()),
924 grid.getSymName(), ArrayRef<GridAxis>());
925}
926
927void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
928 StringRef grid, ArrayRef<GridAxis> axes) {
929 build(odsBuilder, odsState,
930 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
931 GridAxesAttr::get(odsBuilder.getContext(), axes));
932}
933
/// Uses "proc_linear_idx" as the base name for the results in printed IR.
// NOTE(review): the name says "linear" although these are multi-index
// results — looks copied from ProcessLinearIndexOp; confirm intent.
void ProcessMultiIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResults()[0], "proc_linear_idx");
}
938
939//===----------------------------------------------------------------------===//
940// shard.process_linear_index op
941//===----------------------------------------------------------------------===//
942
943LogicalResult
944ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
945 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
946 if (failed(grid)) {
947 return failure();
948 }
949 return success();
950}
951
/// Convenience builder that forwards to the symbol-name overload.
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, GridOp grid) {
  build(odsBuilder, odsState, grid.getSymName());
}
956
/// Names the result "proc_linear_idx" in printed IR.
void ProcessLinearIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "proc_linear_idx");
}
961
962//===----------------------------------------------------------------------===//
963// shard.neighbors_linear_indices op
964//===----------------------------------------------------------------------===//
965
966LogicalResult
967NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
968 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
969 if (failed(grid)) {
970 return failure();
971 }
972 return success();
973}
974
/// Names the two results "down_linear_idx" and "up_linear_idx" in printed
/// IR.
void NeighborsLinearIndicesOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getNeighborDown(), "down_linear_idx");
  setNameFn(getNeighborUp(), "up_linear_idx");
}
980
981//===----------------------------------------------------------------------===//
982// collective communication ops
983//===----------------------------------------------------------------------===//
984
985namespace {
986
987template <typename Op>
988struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
989 using OpRewritePattern<Op>::OpRewritePattern;
990 LogicalResult matchAndRewrite(Op op,
991 PatternRewriter &rewriter) const override {
992 auto gridAxes = op.getGridAxes();
993 if (!gridAxes.empty()) {
994 return failure();
995 }
996 if (op.getInput().getType() != op.getResult().getType()) {
997 return failure();
998 }
999
1000 rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
1001 rewriter.eraseOp(op.getOperation());
1002 return success();
1003 }
1004};
1005
1006} // namespace
1007
1008static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
1009 ArrayRef<int64_t> device,
1010 Operation::operand_range deviceDynamic,
1011 ArrayRef<GridAxis> gridAxes,
1012 ArrayRef<int64_t> gridShape) {
1013 if (device.size() != gridAxes.size()) {
1014 return emitError(loc) << "In-group device \"" << deviceName
1015 << "\" has unexpected multi-index size "
1016 << device.size() << ". Expected " << gridAxes.size()
1017 << ".";
1018 }
1019
1020 for (size_t i = 0; i < device.size(); ++i) {
1021 if (ShapedType::isStatic(device[i]) &&
1022 ShapedType::isStatic(gridShape[gridAxes[i]]) &&
1023 gridShape[gridAxes[i]] <= device[i]) {
1024 return emitError(loc)
1025 << "Out of bounds coordinate " << i << " for in-group device \""
1026 << deviceName << "\"."
1027 << " Got " << device[i] << ", but expected value in the range [0, "
1028 << (gridShape[gridAxes[i]] - 1) << "].";
1029 }
1030 }
1031 return success();
1032}
1033
1035 int64_t expectedDimSize,
1036 int64_t resultDimSize,
1037 int64_t resultAxis) {
1038 if (ShapedType::isStatic(resultDimSize) && expectedDimSize != resultDimSize) {
1039 return emitError(loc) << "Dimension size mismatch for result axis "
1040 << resultAxis << ". Expected "
1041 << (ShapedType::isDynamic(expectedDimSize)
1042 ? Twine("dynamic")
1043 : Twine(expectedDimSize))
1044 << ", but got " << resultDimSize << ".";
1045 }
1046
1047 return success();
1048}
1049
1051 Value operand, Value result, int64_t gatherAxis,
1052 ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1053 auto resultRank = cast<ShapedType>(result.getType()).getRank();
1054 if (gatherAxis < 0 || gatherAxis >= resultRank) {
1055 return emitError(result.getLoc())
1056 << "Gather axis " << gatherAxis << " is out of bounds [0, "
1057 << resultRank << ").";
1058 }
1059
1060 ShapedType operandType = cast<ShapedType>(operand.getType());
1061 ShapedType resultType = cast<ShapedType>(result.getType());
1062 auto deviceGroupSize =
1063 DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1064 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1065 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
1066 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
1067 auto expectedResultDimSize =
1068 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
1070 result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
1071 return failure();
1072 }
1073 }
1074 return success();
1075}
1076
1078 Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
1079 ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1080 ShapedType operandType = cast<ShapedType>(operand.getType());
1081 ShapedType resultType = cast<ShapedType>(result.getType());
1082 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1083 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
1085 result.getLoc(), operandType.getDimSize(axis),
1086 resultType.getDimSize(axis), axis))) {
1087 return failure();
1088 }
1089 }
1090 }
1091
1092 if (splitAxis == concatAxis) {
1093 return success();
1094 }
1095
1096 auto deviceGroupSize =
1097 DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1098 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
1099 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
1100 DimensionSize expectedResultConcatDimSize =
1101 operandConcatDimSize * deviceGroupSize;
1102 DimensionSize expectedResultSplitDimSize =
1103 operandSplitDimSize / deviceGroupSize;
1104 if (!expectedResultSplitDimSize.isDynamic() &&
1105 int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
1106 expectedResultSplitDimSize = DimensionSize::dynamic();
1107 }
1109 result.getLoc(), expectedResultConcatDimSize.value(),
1110 resultType.getDimSize(concatAxis), concatAxis))) {
1111 return failure();
1112 }
1114 result.getLoc(), expectedResultSplitDimSize.value(),
1115 resultType.getDimSize(splitAxis), splitAxis))) {
1116 return failure();
1117 }
1118
1119 return success();
1120}
1121
1123 Value operand, Value result, int64_t tensorAxis,
1124 ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1125 ShapedType operandType = cast<ShapedType>(operand.getType());
1126 ShapedType resultType = cast<ShapedType>(result.getType());
1127 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1128 if (axis != tensorAxis) {
1130 result.getLoc(), operandType.getDimSize(axis),
1131 resultType.getDimSize(axis), axis))) {
1132 return failure();
1133 }
1134 }
1135 }
1136
1137 auto deviceGroupSize =
1138 DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1139 auto operandScatterDimSize =
1140 DimensionSize(operandType.getDimSize(tensorAxis));
1141 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1142 int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1143 return emitError(result.getLoc())
1144 << "Operand dimension size " << int64_t(operandScatterDimSize)
1145 << " is not divisible by collective device group size "
1146 << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
1147 << ".";
1148 }
1149 DimensionSize expectedResultTensorDimSize =
1150 operandScatterDimSize / deviceGroupSize;
1152 result.getLoc(), expectedResultTensorDimSize.value(),
1153 resultType.getDimSize(tensorAxis), tensorAxis))) {
1154 return failure();
1155 }
1156
1157 return success();
1158}
1159
1160static RankedTensorType sliceResultType(Type operandType, GridOp grid,
1161 ArrayRef<GridAxis> gridAxes,
1162 int64_t sliceAxis) {
1163 RankedTensorType operandRankedTensorType =
1164 cast<RankedTensorType>(operandType);
1165 DimensionSize operandSliceAxisSize =
1166 operandRankedTensorType.getShape()[sliceAxis];
1167 SmallVector<int64_t> resultShape =
1168 llvm::to_vector(operandRankedTensorType.getShape());
1169
1170 resultShape[sliceAxis] =
1171 operandSliceAxisSize /
1172 DimensionSize(collectiveProcessGroupSize(gridAxes, grid));
1173 return operandRankedTensorType.clone(resultShape);
1174}
1175
1176//===----------------------------------------------------------------------===//
1177// shard.all_gather op
1178//===----------------------------------------------------------------------===//
1179
1180LogicalResult
1181AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1182 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1183 if (failed(grid)) {
1184 return failure();
1185 }
1186 auto gatherAxis = getGatherAxis().getSExtValue();
1187 return verifyGatherOperandAndResultShape(getOperand(), getResult(),
1188 gatherAxis, getGridAxes(),
1189 grid.value().getShape());
1190}
1191
/// Register the pattern that folds away an all_gather with no grid axes.
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
}
1196
/// Suggest a readable SSA name for the result in printed IR.
void AllGatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_gather");
}
1201
1202//===----------------------------------------------------------------------===//
1203// shard.all_reduce op
1204//===----------------------------------------------------------------------===//
1205
/// Verify that the grid symbol resolves and the grid axes are valid for it.
LogicalResult
AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // FailureOr<GridOp> converts to LogicalResult; the grid itself is unused.
  return getGridAndVerifyAxes(*this, symbolTable);
}
1210
/// Register the pattern that folds away an all_reduce with no grid axes.
void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
}
1215
/// Convenience builder: the result type of an all_reduce equals the input
/// type, so derive it instead of requiring the caller to pass it.
void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        Value input, StringRef grid,
                        ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
  build(odsBuilder, odsState, input.getType(), grid, gridAxes, input,
        reduction);
}
1222
/// Suggest a readable SSA name for the result in printed IR.
void AllReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_reduce");
}
1227
1228//===----------------------------------------------------------------------===//
1229// shard.all_slice op
1230//===----------------------------------------------------------------------===//
1231
1232LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1233 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1234 if (failed(grid)) {
1235 return failure();
1236 }
1238 getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
1239 grid.value().getShape());
1240}
1241
/// Register the pattern that folds away an all_slice with no grid axes.
void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
}
1246
/// Convenience builder: infer the sliced result type from the input type,
/// the grid, and the slice axis, then delegate to the symbol-name builder.
void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
                       int64_t sliceAxis) {
  Type resultType = sliceResultType(input.getType(), grid, gridAxes, sliceAxis);
  build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
        sliceAxis);
}
1254
/// Builder taking an explicit result type and grid symbol name.
void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Type resultType, Value input, StringRef grid,
                       ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
  // Wrap the axis in an APInt of the native word width, as expected by the
  // generated builder.
  build(odsBuilder, odsState, resultType, grid, gridAxes, input,
        APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}
1261
/// Suggest a readable SSA name for the result in printed IR.
void AllSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_slice");
}
1266
1267//===----------------------------------------------------------------------===//
1268// shard.all_to_all op
1269//===----------------------------------------------------------------------===//
1270
1271LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1272 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1273 if (failed(grid)) {
1274 return failure();
1275 }
1276
1278 getOperand(), getResult(), getSplitAxis().getSExtValue(),
1279 getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
1280}
1281
/// Register the pattern that folds away an all_to_all with no grid axes.
void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
}
1286
/// Suggest a readable SSA name for the result in printed IR.
void AllToAllOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_to_all");
}
1291
1292//===----------------------------------------------------------------------===//
1293// shard.broadcast op
1294//===----------------------------------------------------------------------===//
1295
1296LogicalResult
1297BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1298 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1299 if (failed(grid)) {
1300 return failure();
1301 }
1302 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1303 getRootDynamic(), getGridAxes(),
1304 grid.value().getShape()))) {
1305 return failure();
1306 }
1307
1308 return success();
1309}
1310
/// Register the pattern that folds away a broadcast with no grid axes.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
}
1315
/// Suggest a readable SSA name for the result in printed IR.
void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "broadcast");
}
1320
1321//===----------------------------------------------------------------------===//
1322// shard.gather op
1323//===----------------------------------------------------------------------===//
1324
1325LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1326 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1327 if (failed(grid)) {
1328 return failure();
1329 }
1330 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1331 getRootDynamic(), getGridAxes(),
1332 grid.value().getShape()))) {
1333 return failure();
1334 }
1335
1336 auto gatherAxis = getGatherAxis().getSExtValue();
1337 return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
1338 getGridAxes(),
1339 grid.value().getShape());
1340}
1341
/// Register the pattern that folds away a gather with no grid axes.
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
}
1346
/// Suggest a readable SSA name for the result in printed IR.
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
1351
1352//===----------------------------------------------------------------------===//
1353// shard.recv op
1354//===----------------------------------------------------------------------===//
1355
1356LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1357 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1358 if (failed(grid)) {
1359 return failure();
1360 }
1361 if (getSource() &&
1362 failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
1363 getSource().value(), getSourceDynamic(),
1364 getGridAxes(), grid.value().getShape()))) {
1365 return failure();
1366 }
1367 return success();
1368}
1369
/// Register the pattern that folds away a recv with no grid axes.
void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
}
1374
/// Suggest a readable SSA name for the result in printed IR.
void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "recv");
}
1378
1379//===----------------------------------------------------------------------===//
1380// shard.reduce op
1381//===----------------------------------------------------------------------===//
1382
1383LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1384 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1385 if (failed(grid)) {
1386 return failure();
1387 }
1388 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1389 getRootDynamic(), getGridAxes(),
1390 grid.value().getShape()))) {
1391 return failure();
1392 }
1393
1394 return success();
1395}
1396
/// Register the pattern that folds away a reduce with no grid axes.
void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
}
1401
/// Suggest a readable SSA name for the result in printed IR.
void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce");
}
1406
1407//===----------------------------------------------------------------------===//
1408// shard.reduce_scatter op
1409//===----------------------------------------------------------------------===//
1410
1411LogicalResult
1412ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1413 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1414 if (failed(grid)) {
1415 return failure();
1416 }
1417
1419 getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
1420 grid.value().getShape());
1421}
1422
/// Register the pattern that folds away a reduce_scatter with no grid axes.
void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}
1427
/// Suggest a readable SSA name for the result in printed IR.
void ReduceScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce_scatter");
}
1432
1433//===----------------------------------------------------------------------===//
1434// shard.scatter op
1435//===----------------------------------------------------------------------===//
1436
1437LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1438 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1439 if (failed(grid)) {
1440 return failure();
1441 }
1442 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1443 getRootDynamic(), getGridAxes(),
1444 grid.value().getShape()))) {
1445 return failure();
1446 }
1447
1448 auto scatterAxis = getScatterAxis().getSExtValue();
1449 return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
1450 scatterAxis, getGridAxes(),
1451 grid.value().getShape());
1452}
1453
/// Register the pattern that folds away a scatter with no grid axes.
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
}
1458
/// Suggest a readable SSA name for the result in printed IR.
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
1463
1464//===----------------------------------------------------------------------===//
1465// shard.send op
1466//===----------------------------------------------------------------------===//
1467
1468LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1469 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1470 if (failed(grid)) {
1471 return failure();
1472 }
1473 if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
1474 getDestination(), getDestinationDynamic(),
1475 getGridAxes(), grid.value().getShape()))) {
1476 return failure();
1477 }
1478 return success();
1479}
1480
/// Register the pattern that folds away a send with no grid axes.
void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);
}
1485
/// Suggest a readable SSA name for the result in printed IR.
void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "send");
}
1489
1490//===----------------------------------------------------------------------===//
1491// shard.shift op
1492//===----------------------------------------------------------------------===//
1493
1494LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1495 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1496 if (failed(grid)) {
1497 return failure();
1498 }
1499
1500 auto gridAxes = getGridAxes();
1501 auto shiftAxis = getShiftAxis().getZExtValue();
1502 if (!llvm::is_contained(gridAxes, shiftAxis)) {
1503 return emitError() << "Invalid shift axis " << shiftAxis
1504 << ". It must be one of the grouping grid axes.";
1505 }
1506
1507 return success();
1508}
1509
/// No canonicalization patterns are registered for shift yet.
void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                          MLIRContext *context) {
  // TODO: remove the op when the offset is 0, or if it is a rotate with an
  // offset % shift_axis_grid_dim_size == 0.
}
1515
/// Suggest a readable SSA name for the result in printed IR.
void ShiftOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "shift");
}
1520
1521//===----------------------------------------------------------------------===//
1522// shard.update_halo op
1523//===----------------------------------------------------------------------===//
1524
1525LogicalResult
1526UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1527 auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
1528 if (failed(grid)) {
1529 return failure();
1530 }
1531
1532 return success();
1533}
1534
1535//===----------------------------------------------------------------------===//
1536// TableGen'd op method definitions
1537//===----------------------------------------------------------------------===//
1538
1539#define GET_OP_CLASSES
1540#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
1541
1542#define GET_ATTRDEF_CLASSES
1543#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
1544
1545#define GET_TYPEDEF_CLASSES
1546#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
1547
1548#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition ShardOps.cpp:62
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static FailureOr< GridOp > getGridAndVerify(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable)
Definition ShardOps.cpp:148
static LogicalResult verifyGridAxes(Location loc, ArrayRef< GridAxis > axes, GridOp grid)
Definition ShardOps.cpp:177
static FailureOr< GridOp > getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition ShardOps.cpp:200
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
Definition ShardOps.cpp:214
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
Definition ShardOps.cpp:299
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:423
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
Location getLoc()
The source location the operation was defined or derived from.
This class provides the API for a sub-set of ops that are known to be constant-like.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition Value.cpp:91
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
ArrayRef< Value > getDynamicShardedDimsOffsets() const
Definition ShardOps.h:69
bool operator!=(Value rhs) const
Definition ShardOps.cpp:743
bool equalShardSizes(const Sharding &rhs) const
Definition ShardOps.cpp:711
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
Definition ShardOps.cpp:774
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition ShardOps.cpp:792
bool equalSplitAxes(const Sharding &rhs) const
Definition ShardOps.cpp:688
ArrayRef< int64_t > getStaticHaloSizes() const
Definition ShardOps.h:64
::mlir::FlatSymbolRefAttr getGridAttr() const
Definition ShardOps.h:61
::llvm::StringRef getGrid() const
Definition ShardOps.h:62
bool equalHaloAndShardSizes(const Sharding &rhs) const
Definition ShardOps.cpp:707
bool operator==(Value rhs) const
Definition ShardOps.cpp:739
ArrayRef< Value > getDynamicHaloSizes() const
Definition ShardOps.h:68
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition ShardOps.h:65
ArrayRef< GridAxesAttr > getSplitAxes() const
Definition ShardOps.h:63
bool equalHaloSizes(const Sharding &rhs) const
Definition ShardOps.cpp:727
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
shard::Sharding Sharding
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const Sharding &sharding)
Definition ShardOps.cpp:751
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:281
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
Definition ShardOps.cpp:338
DenseI16ArrayAttr GridAxesAttr
Definition ShardOps.h:28
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:123
bool isFullReplication(Sharding sharding)
Definition ShardOps.h:116
int16_t GridAxis
Definition ShardOps.h:27
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
Definition ShardOps.cpp:352
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:178
Type shardType(Type type, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:291
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
Definition ShardOps.cpp:77
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Definition ShardOps.h:156
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition AffineExpr.h:252
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.