MLIR 22.0.0git
ShardOps.cpp
Go to the documentation of this file.
1//===- ShardOps.cpp - Shard Dialect Operations ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
14#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
24#include "mlir/IR/Value.h"
26#include "mlir/Support/LLVM.h"
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include <algorithm>
34#include <functional>
35#include <iterator>
36#include <numeric>
37#include <optional>
38#include <utility>
39
40#define DEBUG_TYPE "shard-ops"
41
42using namespace mlir;
43using namespace mlir::shard;
44
45#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
46
47namespace {
48
49struct DimensionSize {
50 static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
51 DimensionSize(int64_t val) : val(val) {}
52 int64_t value() const { return val; }
53 operator int64_t() const { return val; }
54 bool isDynamic() const { return ShapedType::isDynamic(val); }
55
56private:
57 int64_t val;
58};
59
60} // namespace
61
62static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
63 if (lhs.isDynamic() || rhs.isDynamic()) {
64 return DimensionSize::dynamic();
65 }
66 return lhs.value() / rhs.value();
67}
68
69static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
70 if (lhs.isDynamic() || rhs.isDynamic()) {
71 return DimensionSize::dynamic();
72 }
73 return lhs.value() * rhs.value();
74}
75
79 ValueRange dynamics, Type type) {
80 SmallVector<Value> values;
81 auto dyn = dynamics.begin();
82 Type i64 = b.getI64Type();
83 if (!type)
84 type = i64;
85 assert((i64 == type || b.getIndexType() == type) &&
86 "expected an i64 or an intex type");
87 for (auto s : statics) {
88 if (s == ShapedType::kDynamic) {
89 values.emplace_back(*(dyn++));
90 } else {
91 TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
92 values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
93 }
94 }
95 return values;
96}
97
98//===----------------------------------------------------------------------===//
99// Inliner
100//===----------------------------------------------------------------------===//
101
102namespace {
103struct ShardInlinerinterface : public DialectInlinerInterface {
104 using DialectInlinerInterface::DialectInlinerInterface;
105 // Currently no restrictions are encoded for inlining.
106 bool isLegalToInline(Operation *, Operation *, bool) const final {
107 return true;
108 }
109 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
110 return true;
111 }
112 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
113 return true;
114 }
115};
116} // namespace
117
118//===----------------------------------------------------------------------===//
119// Shard dialect
120//===----------------------------------------------------------------------===//
121
/// Registers all tablegen-generated operations, attributes and types of the
/// Shard dialect, plus the dialect's inliner interface.
void ShardDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
      >();
  addInterface<ShardInlinerinterface>();
}
137
/// Materializes constants for folding results by delegating to
/// arith::ConstantOp (presumably returns nullptr for unsupported
/// attribute/type combinations — behavior defined by arith).
Operation *ShardDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
143
144//===----------------------------------------------------------------------===//
145// Shard utilities
146//===----------------------------------------------------------------------===//
147
148static FailureOr<GridOp> getGridAndVerify(Operation *op,
149 FlatSymbolRefAttr gridSymbol,
150 SymbolTableCollection &symbolTable) {
151 shard::GridOp grid = getGridOrNull(op, gridSymbol, symbolTable);
152 if (!grid) {
153 return op->emitError() << "Undefined required grid symbol \""
154 << gridSymbol.getValue() << "\".";
155 }
156
157 return grid;
158}
159
/// Returns true if no two adjacent elements in [begin, end) compare equal.
/// For a sorted range (as passed by verifyGridAxes) this means all elements
/// are unique. Empty and single-element ranges are trivially unique.
template <typename It>
static bool isUnique(It begin, It end) {
  // std::adjacent_find returns `end` iff no neighboring pair compares equal;
  // this replaces the previous hand-rolled pairwise loop.
  return std::adjacent_find(begin, end) == end;
}
176
177static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes,
178 GridOp grid) {
179 SmallVector<GridAxis> sorted = llvm::to_vector(axes);
180 llvm::sort(sorted);
181 if (!isUnique(sorted.begin(), sorted.end())) {
182 return emitError(loc) << "Grid axes contains duplicate elements.";
183 }
184
185 GridAxis rank = grid.getRank();
186 for (auto axis : axes) {
187 if (axis >= rank || axis < 0) {
188 return emitError(loc)
189 << "0-based grid axis index " << axis
190 << " is out of bounds. The referenced grid \"" << grid.getSymName()
191 << "\" is of rank " << rank << ".";
192 }
193 }
194
195 return success();
196}
197
198template <typename Op>
199static FailureOr<GridOp>
201 auto grid =
202 ::getGridAndVerify(op.getOperation(), op.getGridAttr(), symbolTable);
203 if (failed(grid)) {
204 return failure();
205 }
206 if (failed(verifyGridAxes(op.getLoc(), op.getGridAxes(), grid.value()))) {
207 return failure();
208 }
209 return grid;
210}
211
/// Computes the per-shard shape `outShape` from the global shape `inShape`,
/// given the grid shape and the per-tensor-axis split axes.
/// When `shardedDimsOffsets` (a flattened table of per-shard offsets) is
/// non-empty it takes precedence over the uniform-split computation.
/// `haloSizes` holds (low, high) pairs, one pair per split axis, that enlarge
/// the resulting static extents. `inShape` and `outShape` must have the same
/// rank.
template <typename InShape, typename GridShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const GridShape &gridShape,
                       const SplitAxes &splitAxes, OutShape &outShape,
                       ArrayRef<int64_t> shardedDimsOffsets = {},
                       ArrayRef<int64_t> haloSizes = {}) {
  // 0d tensors cannot be sharded and must get replicated
  if (inShape.empty()) {
    assert(outShape.empty());
    return;
  }

  // Start from the global shape; split dimensions are overwritten below.
  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));

  if (!shardedDimsOffsets.empty()) {
    auto isDynShape = ShapedType::isDynamicShape(gridShape);
    // `pos` walks the flattened offsets table starting at entry 1; each split
    // tensor axis occupies (numShards + 1) consecutive entries.
    uint64_t pos = 1;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (!innerSplitAxes.empty()) {
        auto sz = shardedDimsOffsets[pos];
        bool same = !isDynShape;
        if (same) {
          // Find sharded dims in shardedDimsOffsets with same static size on
          // all devices. Use kDynamic for dimensions with dynamic or
          // non-uniform offs in shardedDimsOffsets.
          // NOTE(review): the shard count is accumulated by *summing* the
          // grid dimensions of the split axes, and `pos` only advances when
          // the grid shape is fully static — confirm both are intended.
          uint64_t numShards = 0;
          for (auto i : innerSplitAxes.asArrayRef()) {
            numShards += gridShape[i];
          }
          for (size_t i = 1; i < numShards; ++i) {
            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
                sz) {
              same = false;
              break;
            }
          }
          pos += numShards + 1;
        }
        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
      }
    }
  } else {
    // Uniform split: delegate the per-dimension division to shardDimension /
    // collectiveProcessGroupSize (dynamic inputs stay dynamic).
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      outShape[tensorAxis] = shardDimension(
          inShape[tensorAxis],
          collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), gridShape));
    }

    if (!haloSizes.empty()) {
      // add halo sizes if requested
      int haloAxis = 0;
      for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
        if (ShapedType::isStatic(outShape[tensorAxis]) &&
            !innerSplitAxes.empty()) {
          // Negative halo sizes make the extent unknown -> dynamic.
          // NOTE(review): `haloAxis` is not advanced when the extent is set
          // to dynamic — confirm the pairing with later split axes is still
          // correct in that case.
          if (haloSizes[haloAxis * 2] >= 0 &&
              haloSizes[haloAxis * 2 + 1] >= 0) {
            outShape[tensorAxis] +=
                haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
            ++haloAxis;
          } else {
            outShape[tensorAxis] = ShapedType::kDynamic;
          }
        }
      }
    }
  }
}
280
281ShapedType shard::shardShapedType(ShapedType shape, GridOp grid,
282 Sharding sharding) {
283 using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
284 SmallVector<Dim> resShapeArr(shape.getShape().size());
285 shardShape(shape.getShape(), grid.getShape(), sharding.getSplitAxes(),
286 resShapeArr, sharding.getStaticShardedDimsOffsets(),
287 sharding.getStaticHaloSizes());
288 return shape.clone(resShapeArr);
289}
290
291Type shard::shardType(Type type, GridOp grid, Sharding sharding) {
292 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
293 if (rankedTensorType && !rankedTensorType.getShape().empty()) {
294 return shardShapedType(rankedTensorType, grid, sharding);
295 }
296 return type;
297}
298
300 Value &operandValue,
301 Operation *operandOp,
302 OpBuilder &builder,
303 ShardOp &newShardOp) {
304 OpBuilder::InsertionGuard insertionGuard(builder);
305 builder.setInsertionPointAfterValue(operandValue);
306 ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
307 if (shardOp && sharding == shardOp.getSharding() &&
308 !shardOp.getAnnotateForUsers()) {
309 // No need for anything if the correct sharding is already set.
310 if (!newShardOp) {
311 newShardOp = shardOp;
312 }
313 return;
314 }
315
316 if (!newShardOp) {
317 auto shardingOp =
318 ShardingOp::create(builder, operandValue.getLoc(), sharding);
319 newShardOp = ShardOp::create(builder, operandValue.getLoc(), operandValue,
320 shardingOp,
321 /*annotate_for_users*/ false);
322 }
323 operandValue.replaceUsesWithIf(
324 newShardOp, [operandOp, operandValue](OpOperand &use) {
325 return use.getOwner() == operandOp && use.get() == operandValue;
326 });
327
328 if (!shardOp || shardOp.getAnnotateForUsers()) {
329 return;
330 }
331
332 auto newShardOp2 = ShardOp::create(builder, operandValue.getLoc(), newShardOp,
333 newShardOp.getSharding(),
334 /*annotate_for_users*/ true);
335 newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
336}
337
340 OpBuilder &builder) {
341 ShardOp newShardOp;
343 for (auto &use : result.getUses()) {
344 uses.emplace_back(use.get(), use.getOwner());
345 }
346 for (auto &[operandValue, operandOp] : uses) {
347 maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
348 builder, newShardOp);
349 }
350}
351
353 OpOperand &operand,
354 OpBuilder &builder) {
355 OpBuilder::InsertionGuard insertionGuard(builder);
356 Value operandValue = operand.get();
357 Operation *operandSrcOp = operandValue.getDefiningOp();
358 bool isBlockArg = !operandSrcOp;
359 {
360 [[maybe_unused]] auto opType =
361 dyn_cast<mlir::RankedTensorType>(operandValue.getType());
362 assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
363 }
364 if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
365 operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
366 return;
367 }
368
369 Operation *operandOp = operand.getOwner();
370 ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
371
372 if (shardOp && sharding == shardOp.getSharding() &&
373 shardOp.getAnnotateForUsers()) {
374 // No need for anything the correct sharding is already set.
375 return;
376 }
377
378 builder.setInsertionPoint(operandOp);
379 auto shardingOp =
380 ShardingOp::create(builder, operand.get().getLoc(), sharding);
381 auto newShardOp =
382 ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
383 /*annotate_for_users*/ true);
384 IRRewriter rewriter(builder);
385 rewriter.replaceUsesWithIf(
386 operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
387 return use.getOwner() == operandOp && use.get() == operandValue;
388 });
389
390 if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
391 // No need for resharding.
392 return;
393 }
394
395 builder.setInsertionPoint(newShardOp);
396 auto newPreceedingShardOp =
397 ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
398 /*annotate_for_users*/ false);
399 rewriter.replaceUsesWithIf(
400 newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
401 return use.getOwner() == newShardOp.getOperation();
402 });
403}
404
405//===----------------------------------------------------------------------===//
406// shard.grid op
407//===----------------------------------------------------------------------===//
408
409LogicalResult GridOp::verify() {
410 int64_t rank = getRank();
411
412 if (rank <= 0)
413 return emitOpError("rank of grid is expected to be a positive integer");
414
415 for (int64_t dimSize : getShape()) {
416 if (dimSize < 0 && ShapedType::isStatic(dimSize))
417 return emitOpError("dimension size of a grid is expected to be "
418 "non-negative or dynamic");
419 }
420
421 return success();
422}
423
424//===----------------------------------------------------------------------===//
425// shard.grid_shape op
426//===----------------------------------------------------------------------===//
427
428LogicalResult
429GridShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
430 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
431 if (failed(grid)) {
432 return failure();
433 }
434 if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
435 return failure();
436 }
437
438 size_t expectedResultsCount =
439 getAxes().empty() ? grid->getRank() : getAxes().size();
440 if (getResult().size() != expectedResultsCount) {
441 return emitError() << "Unexpected number of results " << getResult().size()
442 << ". Expected " << expectedResultsCount << ".";
443 }
444
445 return success();
446}
447
448void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
449 GridOp grid) {
450 build(odsBuilder, odsState, grid, SmallVector<GridAxis>());
451}
452
453void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
454 GridOp grid, ArrayRef<GridAxis> axes) {
455 build(odsBuilder, odsState,
456 SmallVector<Type>(axes.empty() ? grid.getRank() : axes.size(),
457 odsBuilder.getIndexType()),
458 grid.getSymName(), GridAxesAttr::get(odsBuilder.getContext(), axes));
459}
460
461void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
462 StringRef grid, ArrayRef<GridAxis> axes) {
463 assert(!axes.empty());
464 build(odsBuilder, odsState,
465 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
466 GridAxesAttr::get(odsBuilder.getContext(), axes));
467}
468
469void GridShapeOp::getAsmResultNames(
470 function_ref<void(Value, StringRef)> setNameFn) {
471 setNameFn(getResults()[0], "grid_shape");
472}
473
474//===----------------------------------------------------------------------===//
475// shard.sharding
476//===----------------------------------------------------------------------===//
477
478void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
480 ArrayRef<int64_t> staticHalos,
481 ArrayRef<int64_t> staticOffsets) {
482 return build(
483 b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
484 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
485 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {});
486}
487
488void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
489 llvm::StringRef grid, ArrayRef<GridAxesAttr> splitAxes,
490 ArrayRef<int64_t> staticHalos,
491 ArrayRef<int64_t> staticOffsets) {
492 return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
493 GridAxesArrayAttr::get(b.getContext(), splitAxes),
494 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
495 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets),
496 {});
497}
498
499void ShardingOp::build(
503 ::mlir::ArrayRef<::mlir::OpFoldResult> shardedDimsOffsets) {
504 mlir::SmallVector<int64_t> staticHalos, staticDims;
505 mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
506 dispatchIndexOpFoldResults(haloSizes, dynamicHalos, staticHalos);
507 dispatchIndexOpFoldResults(shardedDimsOffsets, dynamicDims, staticDims);
508 return build(
509 b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
510 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
511 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
512}
513
514void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
516
517 build(b, odsState, ShardingType::get(b.getContext()), from.getGridAttr(),
518 GridAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
519 from.getStaticShardedDimsOffsets().empty()
521 : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
523 from.getStaticHaloSizes().empty()
525 : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
526 from.getDynamicHaloSizes());
527}
528
529LogicalResult ShardingOp::verify() {
530 llvm::SmallSet<GridAxis, 4> visitedAxes;
531
532 auto checkGridAxis = [&](ArrayRef<GridAxis> axesArray) -> LogicalResult {
533 for (GridAxis axis : axesArray) {
534 if (axis < 0)
535 return emitError() << "grid axis is expected to be non-negative";
536 if (!visitedAxes.insert(axis).second)
537 return emitError() << "grid axis duplicated";
538 }
539 return success();
540 };
541
542 for (auto subAxes : getSplitAxes().getAxes()) {
543 ArrayRef<GridAxis> subAxesArray = subAxes.asArrayRef();
544 if (failed(checkGridAxis(subAxesArray)))
545 return failure();
546 }
547
548 if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
549 return emitOpError("halo sizes and shard offsets are mutually exclusive");
550 }
551
552 if (!getStaticHaloSizes().empty()) {
553 auto numSplitAxes = getSplitAxes().getAxes().size();
554 for (auto splitAxis : getSplitAxes().getAxes()) {
555 if (splitAxis.empty()) {
556 --numSplitAxes;
557 }
558 }
559 if (getStaticHaloSizes().size() != numSplitAxes * 2) {
560 return emitError() << "halo sizes must be specified for all split axes.";
561 }
562 }
563
564 return success();
565}
566
/// Suggest "sharding" as the SSA name for the result value.
void ShardingOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding");
}
571
/// Verifies the referenced grid symbol and, when static sharded-dims offsets
/// are present, that the flattened offsets table is correctly sized and
/// non-decreasing for each split tensor axis.
LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
  if (failed(grid)) {
    return failure();
  }
  // Static offsets require statically known shard counts, which a dynamic
  // grid shape cannot provide.
  if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
      !getStaticShardedDimsOffsets().empty()) {
    return emitError() << "sharded dims offsets are not allowed for "
                          "device grids with dynamic shape.";
  }

  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
  if (!shardedDimsOffsets.empty()) {
    auto gridShape = grid.value().getShape();
    assert(ShapedType::isStaticShape(gridShape));
    // `pos` walks the flattened offsets table; each split tensor axis
    // occupies (numShards + 1) consecutive entries.
    uint64_t pos = 0;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
      if (!innerSplitAxes.empty()) {
        int64_t numShards = 0, off = 0;
        // NOTE(review): shard count is accumulated by summing the grid
        // dimensions of the split axes — confirm this matches the offsets
        // table layout for multi-axis splits.
        for (auto i : innerSplitAxes.asArrayRef()) {
          numShards += gridShape[i];
        }
        for (int64_t i = 0; i <= numShards; ++i) {
          if (shardedDimsOffsets.size() <= pos + i) {
            return emitError() << "sharded dims offsets has wrong size.";
          }
          // Dynamic entries are skipped; static entries must not decrease.
          if (ShapedType::isStatic(shardedDimsOffsets[pos + i])) {
            if (shardedDimsOffsets[pos + i] < off) {
              return emitError()
                     << "sharded dims offsets must be non-decreasing.";
            }
            off = shardedDimsOffsets[pos + i];
          }
        }
        pos += numShards + 1;
      }
    }
  }
  return success();
}
612
613namespace {
614// Sharding annotations "halo sizes" and "sharded dims offsets"
615// are a mix of attributes and dynamic values. This canonicalization moves
616// constant values to the respective attribute lists, minimizing the number
617// of values.
618// It also removes sharded_dims_sizes and halos if they are effectively "empty".
class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
public:
  using OpRewritePattern<ShardingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardingOp op,
                                PatternRewriter &b) const override {
    // View the static attribute lists and dynamic operand lists as single
    // mixed OpFoldResult lists so constants can be folded uniformly.
    auto mixedHalos =
        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
                                    op.getDynamicShardedDimsOffsets(), b);

    // Fold constant dynamic operands into the static side; `modified` records
    // whether anything changed.
    // NOTE(review): `||` short-circuits, so offsets are not folded in the
    // same iteration in which halos fold; the pattern driver re-applies the
    // pattern until fixpoint, which compensates — confirm this is intended.
    bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
                    succeeded(foldDynamicIndexList(mixedOffs, true));

    auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
    auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);

    // Fully static, all-zero halos are equivalent to "no halos": drop them.
    if (dynamicHalos.empty() && !staticHalos.empty()) {
      if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
        staticHalos.clear();
        modified = true;
      }
    }

    // Remove sharded dims offsets if they are effectively the default values,
    // e.g. if they define equi-distance between all neighboring shards.
    // Requires static-only offsets. Compares the first distance as the
    // difference between the first two offsets. Only if all consecutive
    // distances are the same, the offsets are removed.
    if (dynamicOffs.empty() && !staticOffs.empty()) {
      assert(staticOffs.size() >= 2);
      auto diff = staticOffs[1] - staticOffs[0];
      // `allSame` starts false for a two-entry table, so such a table is
      // never removed.
      bool allSame = staticOffs.size() > 2;
      for (auto i = 2u; i < staticOffs.size(); ++i) {
        if (staticOffs[i] - staticOffs[i - 1] != diff) {
          allSame = false;
          break;
        }
      }
      if (allSame) {
        staticOffs.clear();
        modified = true;
      }
    }

    if (!modified) {
      return failure();
    }

    // Write the normalized static/dynamic lists back onto the op in place.
    op.setStaticHaloSizes(staticHalos);
    op.getDynamicHaloSizesMutable().assign(dynamicHalos);
    op.setStaticShardedDimsOffsets(staticOffs);
    op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);

    return success();
  }
};
677} // namespace
678
/// Registers the NormalizeSharding canonicalization pattern (defined above).
void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                             mlir::MLIRContext *context) {
  results.add<NormalizeSharding>(context);
}
683
684//===----------------------------------------------------------------------===//
685// Sharding
686//===----------------------------------------------------------------------===//
687
689 if (getGrid() != rhs.getGrid()) {
690 return false;
691 }
692
693 auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
694 if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
695 getSplitAxes().begin() + minSize),
696 llvm::make_range(rhs.getSplitAxes().begin(),
697 rhs.getSplitAxes().begin() + minSize))) {
698 return false;
699 }
700
701 return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
702 std::mem_fn(&GridAxesAttr::empty)) &&
703 llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
704 std::mem_fn(&GridAxesAttr::empty));
705}
706
710
712 if (rhs.getStaticShardedDimsOffsets().size() !=
714 !llvm::equal(getStaticShardedDimsOffsets(),
715 rhs.getStaticShardedDimsOffsets())) {
716 return false;
717 }
718 if (rhs.getDynamicShardedDimsOffsets().size() !=
720 !llvm::equal(getDynamicShardedDimsOffsets(),
721 rhs.getDynamicShardedDimsOffsets())) {
722 return false;
723 }
724 return true;
725}
726
728 if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
729 !llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes())) {
730 return false;
731 }
732 if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
733 !llvm::equal(getDynamicHaloSizes(), rhs.getDynamicHaloSizes())) {
734 return false;
735 }
736 return true;
737}
738
742
/// Defined as the negation of the corresponding operator==(Value).
bool Sharding::operator!=(Value rhs) const { return !(*this == rhs); }
744
748
/// Defined as the negation of the corresponding operator==(const Sharding &).
bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }
750
752
754 auto shardingOp = rhs.getDefiningOp<ShardingOp>();
755 assert(shardingOp && "expected sharding op");
756 auto splitAxes = shardingOp.getSplitAxes().getAxes();
757 // If splitAxes are empty, use "empty" constructor.
758 if (splitAxes.empty()) {
759 *this = Sharding(shardingOp.getGridAttr());
760 return;
761 }
762 *this =
763 get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
764 shardingOp.getStaticShardedDimsOffsets(),
765 SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
766 SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
767}
768
770 ArrayRef<GridAxesAttr> splitAxes,
771 ArrayRef<int64_t> staticHaloSizes,
772 ArrayRef<int64_t> staticShardedDimsOffsets,
773 ArrayRef<Value> dynamicHaloSizes,
774 ArrayRef<Value> dynamicShardedDimsOffsets) {
775 Sharding res(grid);
776 if (splitAxes.empty()) {
777 return res;
778 }
779
780 res.split_axes.resize(splitAxes.size());
781 for (auto [i, axis] : llvm::enumerate(splitAxes)) {
782 res.split_axes[i] = GridAxesAttr::get(grid.getContext(), axis.asArrayRef());
783 }
784
785 auto clone = [](const auto src, auto &dst) {
786 dst.resize(src.size());
787 llvm::copy(src, dst.begin());
788 };
789
790 clone(staticHaloSizes, res.static_halo_sizes);
791 clone(staticShardedDimsOffsets, res.static_sharded_dims_offsets);
792 clone(dynamicHaloSizes, res.dynamic_halo_sizes);
793 clone(dynamicShardedDimsOffsets, res.dynamic_sharded_dims_offsets);
794
795 return res;
796}
797
798//===----------------------------------------------------------------------===//
799// shard.shard_shape
800//===----------------------------------------------------------------------===//
801
802void ShardShapeOp::getAsmResultNames(
803 function_ref<void(Value, StringRef)> setNameFn) {
804 setNameFn(getResult()[0], "shard_shape");
805}
806
807void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
808 ::mlir::OperationState &odsState,
810 ArrayRef<Value> dimsDyn, ::mlir::Value sharding,
811 ::mlir::ValueRange device) {
812 SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
813 build(odsBuilder, odsState, resType, dims, dimsDyn, sharding,
814 SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
815}
816
817//===----------------------------------------------------------------------===//
818// shard.shard op
819//===----------------------------------------------------------------------===//
820
/// Suggest "sharding_annotated" as the SSA name for the annotated value.
void ShardOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "sharding_annotated");
}
825
826namespace {
827// Determine if the given ShardOp is a duplicate of another ShardOp
828// on the same value. This can happen if constant values are sharded.
class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
public:
  using OpRewritePattern<ShardOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
    // Get the use-list of the value being sharded and check if it has more than
    // one use. Values produced by another ShardOp are left alone.
    Value value = op.getSrc();
    if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) {
      return failure();
    }

    // Iterate through the uses of the value to find a duplicate ShardOp.
    for (auto &use : value.getUses()) {
      if (use.getOwner() != op.getOperation()) {
        // Only fold into a ShardOp that precedes `op` in the same block, so
        // the replacement does not break dominance.
        auto otherOp = dyn_cast<ShardOp>(use.getOwner());
        if (!otherOp || !otherOp->isBeforeInBlock(op)) {
          return failure();
        }
        // Create a Sharding object for the current and the other ShardOp
        // If the two are equal replace current op with the other op.
        Sharding currentSharding(op.getSharding());
        Sharding otherSharding(otherOp.getSharding());
        if (currentSharding == otherSharding) {
          b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
          b.eraseOp(op.getOperation());
        } else {
          // use the other sharding as input for op
          op.getSrcMutable().assign(otherOp.getResult());
        }
        // Only the first other use is examined per application; the pattern
        // driver revisits the op if more duplicates remain.
        return success();
      }
    }

    return failure();
  }
};
866} // namespace
867
/// Registers the FoldDuplicateShardOp canonicalization pattern (defined
/// above).
void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                          mlir::MLIRContext *context) {
  results.add<FoldDuplicateShardOp>(context);
}
872
873//===----------------------------------------------------------------------===//
874// shard.process_multi_index op
875//===----------------------------------------------------------------------===//
876
877LogicalResult
878ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
879 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
880 if (failed(grid)) {
881 return failure();
882 }
883 if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
884 return failure();
885 }
886
887 size_t expectedResultsCount =
888 getAxes().empty() ? grid->getRank() : getAxes().size();
889 if (getResult().size() != expectedResultsCount) {
890 return emitError() << "Unexpected number of results " << getResult().size()
891 << ". Expected " << expectedResultsCount << ".";
892 }
893
894 return success();
895}
896
897void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
898 GridOp grid) {
899 build(odsBuilder, odsState,
900 SmallVector<Type>(grid.getRank(), odsBuilder.getIndexType()),
901 grid.getSymName(), ArrayRef<GridAxis>());
902}
903
904void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
905 StringRef grid, ArrayRef<GridAxis> axes) {
906 build(odsBuilder, odsState,
907 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
908 GridAxesAttr::get(odsBuilder.getContext(), axes));
909}
910
911void ProcessMultiIndexOp::getAsmResultNames(
912 function_ref<void(Value, StringRef)> setNameFn) {
913 setNameFn(getResults()[0], "proc_linear_idx");
914}
915
916//===----------------------------------------------------------------------===//
917// shard.process_linear_index op
918//===----------------------------------------------------------------------===//
919
920LogicalResult
921ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
922 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
923 if (failed(grid)) {
924 return failure();
925 }
926 return success();
927}
928
/// Convenience builder: derive the grid symbol name from the GridOp.
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, GridOp grid) {
  build(odsBuilder, odsState, grid.getSymName());
}
933
/// Suggest "proc_linear_idx" as the SSA name for the result.
void ProcessLinearIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "proc_linear_idx");
}
938
939//===----------------------------------------------------------------------===//
940// shard.neighbors_linear_indices op
941//===----------------------------------------------------------------------===//
942
943LogicalResult
944NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
945 auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
946 if (failed(grid)) {
947 return failure();
948 }
949 return success();
950}
951
/// Name the two results after the neighbor direction they index.
void NeighborsLinearIndicesOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getNeighborDown(), "down_linear_idx");
  setNameFn(getNeighborUp(), "up_linear_idx");
}
957
958//===----------------------------------------------------------------------===//
959// collective communication ops
960//===----------------------------------------------------------------------===//
961
962namespace {
963
964template <typename Op>
965struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
966 using OpRewritePattern<Op>::OpRewritePattern;
967 LogicalResult matchAndRewrite(Op op,
968 PatternRewriter &rewriter) const override {
969 auto gridAxes = op.getGridAxes();
970 if (!gridAxes.empty()) {
971 return failure();
972 }
973 if (op.getInput().getType() != op.getResult().getType()) {
974 return failure();
975 }
976
977 rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
978 rewriter.eraseOp(op.getOperation());
979 return success();
980 }
981};
982
983} // namespace
984
/// Checks that the static multi-index `device` addresses a process inside the
/// device group spanned by `gridAxes`: one coordinate per axis, each static
/// coordinate below the corresponding static grid dimension.
/// NOTE(review): `deviceDynamic` is accepted but not inspected here — dynamic
/// coordinates cannot be range-checked statically.
/// NOTE(review): negative static coordinates are not rejected by the bound
/// check below — confirm they are validated elsewhere.
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
                                         ArrayRef<int64_t> device,
                                         Operation::operand_range deviceDynamic,
                                         ArrayRef<GridAxis> gridAxes,
                                         ArrayRef<int64_t> gridShape) {
  // The multi-index must have exactly one coordinate per group axis.
  if (device.size() != gridAxes.size()) {
    return emitError(loc) << "In-group device \"" << deviceName
                          << "\" has unexpected multi-index size "
                          << device.size() << ". Expected " << gridAxes.size()
                          << ".";
  }

  for (size_t i = 0; i < device.size(); ++i) {
    // Only pairs where both coordinate and grid extent are static can be
    // bounds-checked.
    if (ShapedType::isStatic(device[i]) &&
        ShapedType::isStatic(gridShape[gridAxes[i]]) &&
        gridShape[gridAxes[i]] <= device[i]) {
      return emitError(loc)
             << "Out of bounds coordinate " << i << " for in-group device \""
             << deviceName << "\"."
             << " Got " << device[i] << ", but expected value in the range [0, "
             << (gridShape[gridAxes[i]] - 1) << "].";
    }
  }
  return success();
}
1010
1012 int64_t expectedDimSize,
1013 int64_t resultDimSize,
1014 int64_t resultAxis) {
1015 if (ShapedType::isStatic(resultDimSize) && expectedDimSize != resultDimSize) {
1016 return emitError(loc) << "Dimension size mismatch for result axis "
1017 << resultAxis << ". Expected "
1018 << (ShapedType::isDynamic(expectedDimSize)
1019 ? Twine("dynamic")
1020 : Twine(expectedDimSize))
1021 << ", but got " << resultDimSize << ".";
1022 }
1023
1024 return success();
1025}
1026
1028 Value operand, Value result, int64_t gatherAxis,
1029 ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1030 auto resultRank = cast<ShapedType>(result.getType()).getRank();
1031 if (gatherAxis < 0 || gatherAxis >= resultRank) {
1032 return emitError(result.getLoc())
1033 << "Gather axis " << gatherAxis << " is out of bounds [0, "
1034 << resultRank << ").";
1035 }
1036
1037 ShapedType operandType = cast<ShapedType>(operand.getType());
1038 ShapedType resultType = cast<ShapedType>(result.getType());
1039 auto deviceGroupSize =
1040 DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1041 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1042 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
1043 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
1044 auto expectedResultDimSize =
1045 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
1047 result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
1048 return failure();
1049 }
1050 }
1051 return success();
1052}
1053
1055 Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
1056 ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1057 ShapedType operandType = cast<ShapedType>(operand.getType());
1058 ShapedType resultType = cast<ShapedType>(result.getType());
1059 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1060 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
1062 result.getLoc(), operandType.getDimSize(axis),
1063 resultType.getDimSize(axis), axis))) {
1064 return failure();
1065 }
1066 }
1067 }
1068
1069 if (splitAxis == concatAxis) {
1070 return success();
1071 }
1072
1073 auto deviceGroupSize =
1074 DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1075 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
1076 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
1077 DimensionSize expectedResultConcatDimSize =
1078 operandConcatDimSize * deviceGroupSize;
1079 DimensionSize expectedResultSplitDimSize =
1080 operandSplitDimSize / deviceGroupSize;
1081 if (!expectedResultSplitDimSize.isDynamic() &&
1082 int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
1083 expectedResultSplitDimSize = DimensionSize::dynamic();
1084 }
1086 result.getLoc(), expectedResultConcatDimSize.value(),
1087 resultType.getDimSize(concatAxis), concatAxis))) {
1088 return failure();
1089 }
1091 result.getLoc(), expectedResultSplitDimSize.value(),
1092 resultType.getDimSize(splitAxis), splitAxis))) {
1093 return failure();
1094 }
1095
1096 return success();
1097}
1098
1100 Value operand, Value result, int64_t tensorAxis,
1101 ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1102 ShapedType operandType = cast<ShapedType>(operand.getType());
1103 ShapedType resultType = cast<ShapedType>(result.getType());
1104 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1105 if (axis != tensorAxis) {
1107 result.getLoc(), operandType.getDimSize(axis),
1108 resultType.getDimSize(axis), axis))) {
1109 return failure();
1110 }
1111 }
1112 }
1113
1114 auto deviceGroupSize =
1115 DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1116 auto operandScatterDimSize =
1117 DimensionSize(operandType.getDimSize(tensorAxis));
1118 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1119 int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1120 return emitError(result.getLoc())
1121 << "Operand dimension size " << int64_t(operandScatterDimSize)
1122 << " is not divisible by collective device group size "
1123 << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
1124 << ".";
1125 }
1126 DimensionSize expectedResultTensorDimSize =
1127 operandScatterDimSize / deviceGroupSize;
1129 result.getLoc(), expectedResultTensorDimSize.value(),
1130 resultType.getDimSize(tensorAxis), tensorAxis))) {
1131 return failure();
1132 }
1133
1134 return success();
1135}
1136
1137static RankedTensorType sliceResultType(Type operandType, GridOp grid,
1138 ArrayRef<GridAxis> gridAxes,
1139 int64_t sliceAxis) {
1140 RankedTensorType operandRankedTensorType =
1141 cast<RankedTensorType>(operandType);
1142 DimensionSize operandSliceAxisSize =
1143 operandRankedTensorType.getShape()[sliceAxis];
1144 SmallVector<int64_t> resultShape =
1145 llvm::to_vector(operandRankedTensorType.getShape());
1146
1147 resultShape[sliceAxis] =
1148 operandSliceAxisSize /
1149 DimensionSize(collectiveProcessGroupSize(gridAxes, grid));
1150 return operandRankedTensorType.clone(resultShape);
1151}
1152
1153//===----------------------------------------------------------------------===//
1154// shard.all_gather op
1155//===----------------------------------------------------------------------===//
1156
1157LogicalResult
1158AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1159 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1160 if (failed(grid)) {
1161 return failure();
1162 }
1163 auto gatherAxis = getGatherAxis().getSExtValue();
1164 return verifyGatherOperandAndResultShape(getOperand(), getResult(),
1165 gatherAxis, getGridAxes(),
1166 grid.value().getShape());
1167}
1168
1169void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1170 MLIRContext *context) {
1171 patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
1172}
1173
1174void AllGatherOp::getAsmResultNames(
1175 function_ref<void(Value, StringRef)> setNameFn) {
1176 setNameFn(getResult(), "all_gather");
1177}
1178
1179//===----------------------------------------------------------------------===//
1180// shard.all_reduce op
1181//===----------------------------------------------------------------------===//
1182
1183LogicalResult
1184AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1185 return getGridAndVerifyAxes(*this, symbolTable);
1186}
1187
1188void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1189 MLIRContext *context) {
1190 patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
1191}
1192
1193void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1194 Value input, StringRef grid,
1195 ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
1196 build(odsBuilder, odsState, input.getType(), grid, gridAxes, input,
1197 reduction);
1198}
1199
1200void AllReduceOp::getAsmResultNames(
1201 function_ref<void(Value, StringRef)> setNameFn) {
1202 setNameFn(getResult(), "all_reduce");
1203}
1204
1205//===----------------------------------------------------------------------===//
1206// shard.all_slice op
1207//===----------------------------------------------------------------------===//
1208
1209LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1210 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1211 if (failed(grid)) {
1212 return failure();
1213 }
1215 getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
1216 grid.value().getShape());
1217}
1218
1219void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1220 MLIRContext *context) {
1221 patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
1222}
1223
1224void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1225 Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
1226 int64_t sliceAxis) {
1227 Type resultType = sliceResultType(input.getType(), grid, gridAxes, sliceAxis);
1228 build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
1229 sliceAxis);
1230}
1231
1232void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1233 Type resultType, Value input, StringRef grid,
1234 ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
1235 build(odsBuilder, odsState, resultType, grid, gridAxes, input,
1236 APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1237}
1238
1239void AllSliceOp::getAsmResultNames(
1240 function_ref<void(Value, StringRef)> setNameFn) {
1241 setNameFn(getResult(), "all_slice");
1242}
1243
1244//===----------------------------------------------------------------------===//
1245// shard.all_to_all op
1246//===----------------------------------------------------------------------===//
1247
1248LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1249 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1250 if (failed(grid)) {
1251 return failure();
1252 }
1253
1255 getOperand(), getResult(), getSplitAxis().getSExtValue(),
1256 getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
1257}
1258
1259void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1260 MLIRContext *context) {
1261 patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
1262}
1263
1264void AllToAllOp::getAsmResultNames(
1265 function_ref<void(Value, StringRef)> setNameFn) {
1266 setNameFn(getResult(), "all_to_all");
1267}
1268
1269//===----------------------------------------------------------------------===//
1270// shard.broadcast op
1271//===----------------------------------------------------------------------===//
1272
1273LogicalResult
1274BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1275 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1276 if (failed(grid)) {
1277 return failure();
1278 }
1279 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1280 getRootDynamic(), getGridAxes(),
1281 grid.value().getShape()))) {
1282 return failure();
1283 }
1284
1285 return success();
1286}
1287
1288void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1289 MLIRContext *context) {
1290 patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
1291}
1292
1293void BroadcastOp::getAsmResultNames(
1294 function_ref<void(Value, StringRef)> setNameFn) {
1295 setNameFn(getResult(), "broadcast");
1296}
1297
1298//===----------------------------------------------------------------------===//
1299// shard.gather op
1300//===----------------------------------------------------------------------===//
1301
1302LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1303 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1304 if (failed(grid)) {
1305 return failure();
1306 }
1307 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1308 getRootDynamic(), getGridAxes(),
1309 grid.value().getShape()))) {
1310 return failure();
1311 }
1312
1313 auto gatherAxis = getGatherAxis().getSExtValue();
1314 return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
1315 getGridAxes(),
1316 grid.value().getShape());
1317}
1318
1319void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1320 MLIRContext *context) {
1321 patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
1322}
1323
1324void GatherOp::getAsmResultNames(
1325 function_ref<void(Value, StringRef)> setNameFn) {
1326 setNameFn(getResult(), "gather");
1327}
1328
1329//===----------------------------------------------------------------------===//
1330// shard.recv op
1331//===----------------------------------------------------------------------===//
1332
1333LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1334 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1335 if (failed(grid)) {
1336 return failure();
1337 }
1338 if (getSource() &&
1339 failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
1340 getSource().value(), getSourceDynamic(),
1341 getGridAxes(), grid.value().getShape()))) {
1342 return failure();
1343 }
1344 return success();
1345}
1346
1347void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1348 MLIRContext *context) {
1349 patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
1350}
1351
1352void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1353 setNameFn(getResult(), "recv");
1354}
1355
1356//===----------------------------------------------------------------------===//
1357// shard.reduce op
1358//===----------------------------------------------------------------------===//
1359
1360LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1361 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1362 if (failed(grid)) {
1363 return failure();
1364 }
1365 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1366 getRootDynamic(), getGridAxes(),
1367 grid.value().getShape()))) {
1368 return failure();
1369 }
1370
1371 return success();
1372}
1373
1374void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1375 MLIRContext *context) {
1376 patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
1377}
1378
1379void ReduceOp::getAsmResultNames(
1380 function_ref<void(Value, StringRef)> setNameFn) {
1381 setNameFn(getResult(), "reduce");
1382}
1383
1384//===----------------------------------------------------------------------===//
1385// shard.reduce_scatter op
1386//===----------------------------------------------------------------------===//
1387
1388LogicalResult
1389ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1390 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1391 if (failed(grid)) {
1392 return failure();
1393 }
1394
1396 getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
1397 grid.value().getShape());
1398}
1399
1400void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1401 MLIRContext *context) {
1402 patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1403}
1404
1405void ReduceScatterOp::getAsmResultNames(
1406 function_ref<void(Value, StringRef)> setNameFn) {
1407 setNameFn(getResult(), "reduce_scatter");
1408}
1409
1410//===----------------------------------------------------------------------===//
1411// shard.scatter op
1412//===----------------------------------------------------------------------===//
1413
1414LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1415 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1416 if (failed(grid)) {
1417 return failure();
1418 }
1419 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1420 getRootDynamic(), getGridAxes(),
1421 grid.value().getShape()))) {
1422 return failure();
1423 }
1424
1425 auto scatterAxis = getScatterAxis().getSExtValue();
1426 return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
1427 scatterAxis, getGridAxes(),
1428 grid.value().getShape());
1429}
1430
1431void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1432 MLIRContext *context) {
1433 patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
1434}
1435
1436void ScatterOp::getAsmResultNames(
1437 function_ref<void(Value, StringRef)> setNameFn) {
1438 setNameFn(getResult(), "scatter");
1439}
1440
1441//===----------------------------------------------------------------------===//
1442// shard.send op
1443//===----------------------------------------------------------------------===//
1444
1445LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1446 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1447 if (failed(grid)) {
1448 return failure();
1449 }
1450 if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
1451 getDestination(), getDestinationDynamic(),
1452 getGridAxes(), grid.value().getShape()))) {
1453 return failure();
1454 }
1455 return success();
1456}
1457
1458void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1459 MLIRContext *context) {
1460 patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);
1461}
1462
1463void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1464 setNameFn(getResult(), "send");
1465}
1466
1467//===----------------------------------------------------------------------===//
1468// shard.shift op
1469//===----------------------------------------------------------------------===//
1470
1471LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1472 auto grid = getGridAndVerifyAxes(*this, symbolTable);
1473 if (failed(grid)) {
1474 return failure();
1475 }
1476
1477 auto gridAxes = getGridAxes();
1478 auto shiftAxis = getShiftAxis().getZExtValue();
1479 if (!llvm::is_contained(gridAxes, shiftAxis)) {
1480 return emitError() << "Invalid shift axis " << shiftAxis
1481 << ". It must be one of the grouping grid axes.";
1482 }
1483
1484 return success();
1485}
1486
void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                          MLIRContext *context) {
  // No patterns registered yet.
  // TODO: remove op when offset is 0 or if it is a rotate with and
  // offset % shift_axis_grid_dim_size == 0.
}
1492
1493void ShiftOp::getAsmResultNames(
1494 function_ref<void(Value, StringRef)> setNameFn) {
1495 setNameFn(getResult(), "shift");
1496}
1497
1498//===----------------------------------------------------------------------===//
1499// shard.update_halo op
1500//===----------------------------------------------------------------------===//
1501
1502LogicalResult
1503UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1504 auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
1505 if (failed(grid)) {
1506 return failure();
1507 }
1508
1509 return success();
1510}
1511
1512//===----------------------------------------------------------------------===//
1513// TableGen'd op method definitions
1514//===----------------------------------------------------------------------===//
1515
1516#define GET_OP_CLASSES
1517#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
1518
1519#define GET_ATTRDEF_CLASSES
1520#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
1521
1522#define GET_TYPEDEF_CLASSES
1523#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
1524
1525#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition ShardOps.cpp:62
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static FailureOr< GridOp > getGridAndVerify(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable)
Definition ShardOps.cpp:148
static LogicalResult verifyGridAxes(Location loc, ArrayRef< GridAxis > axes, GridOp grid)
Definition ShardOps.cpp:177
static FailureOr< GridOp > getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition ShardOps.cpp:200
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
Definition ShardOps.cpp:214
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
Definition ShardOps.cpp:985
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
Definition ShardOps.cpp:299
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:421
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
Location getLoc()
The source location the operation was defined or derived from.
This class provides the API for a sub-set of ops that are known to be constant-like.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition Value.cpp:91
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
ArrayRef< Value > getDynamicShardedDimsOffsets() const
Definition ShardOps.h:68
bool operator!=(Value rhs) const
Definition ShardOps.cpp:743
bool equalShardSizes(const Sharding &rhs) const
Definition ShardOps.cpp:711
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
Definition ShardOps.cpp:751
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition ShardOps.cpp:769
bool equalSplitAxes(const Sharding &rhs) const
Definition ShardOps.cpp:688
ArrayRef< int64_t > getStaticHaloSizes() const
Definition ShardOps.h:63
::mlir::FlatSymbolRefAttr getGridAttr() const
Definition ShardOps.h:60
::llvm::StringRef getGrid() const
Definition ShardOps.h:61
bool equalHaloAndShardSizes(const Sharding &rhs) const
Definition ShardOps.cpp:707
bool operator==(Value rhs) const
Definition ShardOps.cpp:739
ArrayRef< Value > getDynamicHaloSizes() const
Definition ShardOps.h:67
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition ShardOps.h:64
ArrayRef< GridAxesAttr > getSplitAxes() const
Definition ShardOps.h:62
bool equalHaloSizes(const Sharding &rhs) const
Definition ShardOps.cpp:727
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
shard::Sharding Sharding
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:281
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
Definition ShardOps.cpp:338
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:113
bool isFullReplication(Sharding sharding)
Definition ShardOps.h:106
int16_t GridAxis
Definition ShardOps.h:26
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
Definition ShardOps.cpp:352
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:168
Type shardType(Type type, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:291
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
Definition ShardOps.cpp:77
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Definition ShardOps.h:146
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition AffineExpr.h:252
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.