MLIR  22.0.0git
ShardOps.cpp
Go to the documentation of this file.
1 //===- ShardOps.cpp - Shard Dialect Operations ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/IR/Value.h"
26 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/ArrayRef.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallSet.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include <algorithm>
34 #include <functional>
35 #include <iterator>
36 #include <numeric>
37 #include <optional>
38 #include <utility>
39 
40 #define DEBUG_TYPE "shard-ops"
41 
42 using namespace mlir;
43 using namespace mlir::shard;
44 
45 #include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
46 
47 namespace {
48 
49 struct DimensionSize {
50  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
51  DimensionSize(int64_t val) : val(val) {}
52  int64_t value() const { return val; }
53  operator int64_t() const { return val; }
54  bool isDynamic() const { return ShapedType::isDynamic(val); }
55 
56 private:
57  int64_t val;
58 };
59 
60 } // namespace
61 
62 static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
63  if (lhs.isDynamic() || rhs.isDynamic()) {
64  return DimensionSize::dynamic();
65  }
66  return lhs.value() / rhs.value();
67 }
68 
69 static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
70  if (lhs.isDynamic() || rhs.isDynamic()) {
71  return DimensionSize::dynamic();
72  }
73  return lhs.value() * rhs.value();
74 }
75 
// Materializes a mixed static/dynamic size list as a flat list of SSA values:
// static entries become arith.constant ops, kDynamic entries consume the next
// value from `dynamics`.
// NOTE(review): the leading signature lines of this definition (return type,
// name, and the `b`, `loc`, `statics` parameters) were lost in extraction;
// the remaining tokens are preserved unchanged.
                              ValueRange dynamics, Type type)
  SmallVector<Value> values;
  auto dyn = dynamics.begin();
  Type i64 = b.getI64Type();
  // Default to i64 when no explicit element type was requested.
  if (!type)
    type = i64;
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an intex type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      // Dynamic entry: take the next dynamic operand.
      values.emplace_back(*(dyn++));
    } else {
      // Static entry: materialize a constant of the requested type.
      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
      values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
    }
  }
  return values;
}
97 
98 //===----------------------------------------------------------------------===//
99 // Inliner
100 //===----------------------------------------------------------------------===//
101 
102 namespace {
// Inliner interface for the Shard dialect: all operations and regions are
// unconditionally legal to inline.
// NOTE(review): a line between the struct header and the first comment
// (presumably the constructor-inheriting using-declaration) was lost in
// extraction.
struct ShardInlinerinterface : public DialectInlinerInterface {
  // Currently no restrictions are encoded for inlining.
  bool isLegalToInline(Operation *, Operation *, bool) const final {
    return true;
  }
  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
    return true;
  }
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
116 } // namespace
117 
118 //===----------------------------------------------------------------------===//
119 // Shard dialect
120 //===----------------------------------------------------------------------===//
121 
// Registers all TableGen-generated ops, attributes and types of the dialect,
// plus the dialect's inliner interface.
void ShardDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
      >();
  // Allow shard ops/regions to participate in inlining (see interface above).
  addInterface<ShardInlinerinterface>();
}
137 
// Dialect hook used by folding: materialize a constant attribute of the given
// type as an op (delegates to arith's constant materialization).
// NOTE(review): the first signature line (return type, name and the `builder`
// parameter) was lost in extraction; remaining tokens preserved unchanged.
                                           Attribute value, Type type,
                                           Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
143 
144 //===----------------------------------------------------------------------===//
145 // Shard utilities
146 //===----------------------------------------------------------------------===//
147 
148 static FailureOr<GridOp> getGridAndVerify(Operation *op,
149  FlatSymbolRefAttr gridSymbol,
150  SymbolTableCollection &symbolTable) {
151  shard::GridOp grid = getGridOrNull(op, gridSymbol, symbolTable);
152  if (!grid) {
153  return op->emitError() << "Undefined required grid symbol \""
154  << gridSymbol.getValue() << "\".";
155  }
156 
157  return grid;
158 }
159 
/// Returns true if [begin, end) contains no two equal *adjacent* elements.
/// For a sorted range this means all elements are unique. Replaces the
/// hand-rolled adjacent-pair scan with the standard-library equivalent;
/// trivially handles the empty and single-element ranges.
template <typename It>
bool isUnique(It begin, It end) {
  return std::adjacent_find(begin, end) == end;
}
176 
177 static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes,
178  GridOp grid) {
179  SmallVector<GridAxis> sorted = llvm::to_vector(axes);
180  llvm::sort(sorted);
181  if (!isUnique(sorted.begin(), sorted.end())) {
182  return emitError(loc) << "Grid axes contains duplicate elements.";
183  }
184 
185  GridAxis rank = grid.getRank();
186  for (auto axis : axes) {
187  if (axis >= rank || axis < 0) {
188  return emitError(loc)
189  << "0-based grid axis index " << axis
190  << " is out of bounds. The referenced grid \"" << grid.getSymName()
191  << "\" is of rank " << rank << ".";
192  }
193  }
194 
195  return success();
196 }
197 
// Resolves the op's grid symbol and then validates the op's grid_axes
// against the resolved grid's rank.
// NOTE(review): the signature line carrying the function name and parameters
// (presumably `(Op op, SymbolTableCollection &symbolTable) {`) was lost in
// extraction; remaining tokens preserved unchanged.
template <typename Op>
static FailureOr<GridOp>
  auto grid =
      ::getGridAndVerify(op.getOperation(), op.getGridAttr(), symbolTable);
  if (failed(grid)) {
    return failure();
  }
  if (failed(verifyGridAxes(op.getLoc(), op.getGridAxes(), grid.value()))) {
    return failure();
  }
  return grid;
}
211 
// Computes the per-shard (local) shape `outShape` from the global tensor
// shape `inShape`, given the device grid shape and the tensor-axis ->
// grid-axes mapping `splitAxes`. Optionally accounts for explicit per-shard
// offsets (`shardedDimsOffsets`) or halo sizes (`haloSizes`); the two are
// mutually exclusive (see ShardingOp::verify).
template <typename InShape, typename GridShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const GridShape &gridShape,
                       const SplitAxes &splitAxes, OutShape &outShape,
                       ArrayRef<int64_t> shardedDimsOffsets = {},
                       ArrayRef<int64_t> haloSizes = {}) {
  // 0d tensors cannot be sharded and must get replicated
  if (inShape.empty()) {
    assert(outShape.empty());
    return;
  }

  // Start from the unsharded shape; split dimensions are overwritten below.
  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));

  if (!shardedDimsOffsets.empty()) {
    auto isDynShape = ShapedType::isDynamicShape(gridShape);
    // `pos` walks the flattened offsets table; index 0 of each group is
    // presumably the leading offset, so the first shard size is read at 1.
    uint64_t pos = 1;
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      if (!innerSplitAxes.empty()) {
        auto sz = shardedDimsOffsets[pos];
        bool same = !isDynShape;
        if (same) {
          // Find sharded dims in shardedDimsOffsets with same static size on
          // all devices. Use kDynamic for dimensions with dynamic or
          // non-uniform offs in shardedDimsOffsets.
          uint64_t numShards = 0;
          for (auto i : innerSplitAxes.asArrayRef()) {
            numShards += gridShape[i];
          }
          for (size_t i = 1; i < numShards; ++i) {
            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
                sz) {
              same = false;
              break;
            }
          }
          // Advance past this group's entries (numShards deltas + leading
          // offset). NOTE(review): `pos` only advances when the grid shape is
          // fully static; with a dynamic grid shape every axis yields
          // kDynamic below, so the stale `pos` is harmless.
          pos += numShards + 1;
        }
        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
      }
    }
  } else {
    // Divide each split dimension by its collective process-group size.
    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
      outShape[tensorAxis] = shardDimension(
          inShape[tensorAxis],
          collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), gridShape));
    }

    if (!haloSizes.empty()) {
      // add halo sizes if requested
      int haloAxis = 0;
      for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
        if (ShapedType::isStatic(outShape[tensorAxis]) &&
            !innerSplitAxes.empty()) {
          if (haloSizes[haloAxis * 2] >= 0 &&
              haloSizes[haloAxis * 2 + 1] >= 0) {
            // Both lower and upper halo are static: grow the local extent.
            outShape[tensorAxis] +=
                haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
            ++haloAxis;
          } else {
            // A negative halo entry marks a dynamic halo: extent unknown.
            outShape[tensorAxis] = ShapedType::kDynamic;
          }
        }
      }
    }
  }
}
280 
281 ShapedType shard::shardShapedType(ShapedType shape, GridOp grid,
282  Sharding sharding) {
283  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
284  SmallVector<Dim> resShapeArr(shape.getShape().size());
285  shardShape(shape.getShape(), grid.getShape(), sharding.getSplitAxes(),
286  resShapeArr, sharding.getStaticShardedDimsOffsets(),
287  sharding.getStaticHaloSizes());
288  return shape.clone(resShapeArr);
289 }
290 
291 Type shard::shardType(Type type, GridOp grid, Sharding sharding) {
292  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
293  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
294  return shardShapedType(rankedTensorType, grid, sharding);
295  }
296  return type;
297 }
298 
// Annotates one use (`operandValue` consumed by `operandOp`) with a
// producer-side shard.shard op carrying the given sharding, reusing
// `newShardOp` across calls for the same result.
// NOTE(review): the first signature line (name and the `sharding` parameter)
// was lost in extraction; remaining tokens preserved unchanged.
                                            Value &operandValue,
                                            Operation *operandOp,
                                            OpBuilder &builder,
                                            ShardOp &newShardOp) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointAfterValue(operandValue);
  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      !shardOp.getAnnotateForUsers()) {
    // No need for anything if the correct sharding is already set.
    if (!newShardOp) {
      // Cache the existing annotation so later uses can share it.
      newShardOp = shardOp;
    }
    return;
  }

  if (!newShardOp) {
    // First use needing annotation: create the sharding op and a
    // producer-side shard.shard (annotate_for_users = false).
    auto shardingOp =
        ShardingOp::create(builder, operandValue.getLoc(), sharding);
    newShardOp = ShardOp::create(builder, operandValue.getLoc(), operandValue,
                                 shardingOp,
                                 /*annotate_for_users*/ false);
  }
  // Redirect only this particular use to the annotated value.
  operandValue.replaceUsesWithIf(
      newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (!shardOp || shardOp.getAnnotateForUsers()) {
    return;
  }

  // The consumer was itself a producer-side shard.shard with a different
  // sharding: chain a user-side annotation after the new one.
  auto newShardOp2 = ShardOp::create(builder, operandValue.getLoc(), newShardOp,
                                     newShardOp.getSharding(),
                                     /*annotate_for_users*/ true);
  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
}
337 
// Annotates every use of `result` with the given sharding (delegating to the
// Impl helper above), sharing a single newly created ShardOp across uses.
// NOTE(review): the first signature line (name and `sharding` parameter) and
// the declaration of the local `uses` vector were lost in extraction;
// remaining tokens preserved unchanged.
                                             OpResult result,
                                             OpBuilder &builder) {
  ShardOp newShardOp;
  // Snapshot the uses first: inserting shard ops below mutates the use-list.
  for (auto &use : result.getUses()) {
    uses.emplace_back(use.get(), use.getOwner());
  }
  for (auto &[operandValue, operandOp] : uses) {
    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
                                            builder, newShardOp);
  }
}
351 
// Annotates `operand` with a user-side (annotate_for_users = true) sharding,
// inserting shard.sharding/shard.shard ops right before the consuming op.
// NOTE(review): the first signature line (name and `sharding` parameter) was
// lost in extraction; remaining tokens preserved unchanged.
                                             OpOperand &operand,
                                             OpBuilder &builder) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Value operandValue = operand.get();
  Operation *operandSrcOp = operandValue.getDefiningOp();
  bool isBlockArg = !operandSrcOp;
  {
    // 0-ranked tensors can only carry a full-replication sharding.
    [[maybe_unused]] auto opType =
        dyn_cast<mlir::RankedTensorType>(operandValue.getType());
    assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
  }
  // Constant-like producers of non-tensor values are left unannotated.
  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
      operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
    return;
  }

  Operation *operandOp = operand.getOwner();
  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);

  if (shardOp && sharding == shardOp.getSharding() &&
      shardOp.getAnnotateForUsers()) {
    // No need for anything; the correct sharding is already set.
    return;
  }

  builder.setInsertionPoint(operandOp);
  auto shardingOp =
      ShardingOp::create(builder, operand.get().getLoc(), sharding);
  auto newShardOp =
      ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
                      /*annotate_for_users*/ true);
  IRRewriter rewriter(builder);
  // Redirect only this particular use to the annotated value.
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
    // No need for resharding.
    return;
  }

  // The producer was annotated for users with a different sharding: insert a
  // producer-side shard.shard in between to make the resharding explicit.
  builder.setInsertionPoint(newShardOp);
  auto newPreceedingShardOp =
      ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
                      /*annotate_for_users*/ false);
  rewriter.replaceUsesWithIf(
      newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
        return use.getOwner() == newShardOp.getOperation();
      });
}
404 
405 //===----------------------------------------------------------------------===//
406 // shard.grid op
407 //===----------------------------------------------------------------------===//
408 
409 LogicalResult GridOp::verify() {
410  int64_t rank = getRank();
411 
412  if (rank <= 0)
413  return emitOpError("rank of grid is expected to be a positive integer");
414 
415  for (int64_t dimSize : getShape()) {
416  if (dimSize < 0 && ShapedType::isStatic(dimSize))
417  return emitOpError("dimension size of a grid is expected to be "
418  "non-negative or dynamic");
419  }
420 
421  return success();
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // shard.grid_shape op
426 //===----------------------------------------------------------------------===//
427 
428 LogicalResult
429 GridShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
430  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
431  if (failed(grid)) {
432  return failure();
433  }
434  if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
435  return failure();
436  }
437 
438  size_t expectedResultsCount =
439  getAxes().empty() ? grid->getRank() : getAxes().size();
440  if (getResult().size() != expectedResultsCount) {
441  return emitError() << "Unexpected number of results " << getResult().size()
442  << ". Expected " << expectedResultsCount << ".";
443  }
444 
445  return success();
446 }
447 
448 void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
449  GridOp grid) {
450  build(odsBuilder, odsState, grid, SmallVector<GridAxis>());
451 }
452 
453 void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
454  GridOp grid, ArrayRef<GridAxis> axes) {
455  build(odsBuilder, odsState,
456  SmallVector<Type>(axes.empty() ? grid.getRank() : axes.size(),
457  odsBuilder.getIndexType()),
458  grid.getSymName(), GridAxesAttr::get(odsBuilder.getContext(), axes));
459 }
460 
461 void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
462  StringRef grid, ArrayRef<GridAxis> axes) {
463  assert(!axes.empty());
464  build(odsBuilder, odsState,
465  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
466  GridAxesAttr::get(odsBuilder.getContext(), axes));
467 }
468 
469 void GridShapeOp::getAsmResultNames(
470  function_ref<void(Value, StringRef)> setNameFn) {
471  setNameFn(getResults()[0], "grid_shape");
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // shard.sharding
476 //===----------------------------------------------------------------------===//
477 
478 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
479  FlatSymbolRefAttr grid,
480  ArrayRef<GridAxesAttr> split_axes,
481  ArrayRef<int64_t> static_halos,
482  ArrayRef<int64_t> static_offsets) {
483  return build(
484  b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
485  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
486  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
487 }
488 
489 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
490  llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes,
491  ArrayRef<int64_t> static_halos,
492  ArrayRef<int64_t> static_offsets) {
493  return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
494  GridAxesArrayAttr::get(b.getContext(), split_axes),
495  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
496  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
497  {});
498 }
499 
// Builder taking halos and offsets as OpFoldResults: splits each list into
// static (attribute) and dynamic (SSA value) components.
// NOTE(review): the parameter line declaring `halo_sizes` (used below) was
// lost in extraction; remaining tokens preserved unchanged.
void ShardingOp::build(
    ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
    FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes,
    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
  mlir::SmallVector<int64_t> staticHalos, staticDims;
  mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
  // Constants go to the static lists, SSA values to the dynamic ones.
  dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
  dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
  return build(
      b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
      ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
      ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
}
514 
// Builder from an in-memory Sharding descriptor.
// NOTE(review): several argument lines of this delegated build call (the
// conditional selections between empty attributes and the descriptor's
// offsets/halos) were lost in extraction; remaining tokens preserved
// unchanged.
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                       mlir::shard::Sharding from) {

  build(b, odsState, ShardingType::get(b.getContext()), from.getGridAttr(),
        from.getStaticShardedDimsOffsets().empty()
        from.getStaticHaloSizes().empty()
        from.getDynamicHaloSizes());
}
529 
530 LogicalResult ShardingOp::verify() {
531  llvm::SmallSet<GridAxis, 4> visitedAxes;
532 
533  auto checkGridAxis = [&](ArrayRef<GridAxis> axesArray) -> LogicalResult {
534  for (GridAxis axis : axesArray) {
535  if (axis < 0)
536  return emitError() << "grid axis is expected to be non-negative";
537  if (!visitedAxes.insert(axis).second)
538  return emitError() << "grid axis duplicated";
539  }
540  return success();
541  };
542 
543  for (auto subAxes : getSplitAxes().getAxes()) {
544  ArrayRef<GridAxis> subAxesArray = subAxes.asArrayRef();
545  if (failed(checkGridAxis(subAxesArray)))
546  return failure();
547  }
548 
549  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
550  return emitOpError("halo sizes and shard offsets are mutually exclusive");
551  }
552 
553  if (!getStaticHaloSizes().empty()) {
554  auto numSplitAxes = getSplitAxes().getAxes().size();
555  for (auto splitAxis : getSplitAxes().getAxes()) {
556  if (splitAxis.empty()) {
557  --numSplitAxes;
558  }
559  }
560  if (getStaticHaloSizes().size() != numSplitAxes * 2) {
561  return emitError() << "halo sizes must be specified for all split axes.";
562  }
563  }
564 
565  return success();
566 }
567 
568 void ShardingOp::getAsmResultNames(
569  function_ref<void(Value, StringRef)> setNameFn) {
570  setNameFn(getResult(), "sharding");
571 }
572 
573 LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
574  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
575  if (failed(grid)) {
576  return failure();
577  }
578  if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
579  getStaticShardedDimsOffsets().size() > 0) {
580  return emitError() << "sharded dims offsets are not allowed for "
581  "device grids with dynamic shape.";
582  }
583 
584  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
585  if (!shardedDimsOffsets.empty()) {
586  auto gridShape = grid.value().getShape();
587  assert(ShapedType::isStaticShape(gridShape));
588  uint64_t pos = 0;
589  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
590  if (!innerSplitAxes.empty()) {
591  int64_t numShards = 0, off = 0;
592  for (auto i : innerSplitAxes.asArrayRef()) {
593  numShards += gridShape[i];
594  }
595  for (int64_t i = 0; i <= numShards; ++i) {
596  if (shardedDimsOffsets.size() <= pos + i) {
597  return emitError() << "sharded dims offsets has wrong size.";
598  }
599  if (ShapedType::isStatic(shardedDimsOffsets[pos + i])) {
600  if (shardedDimsOffsets[pos + i] < off) {
601  return emitError()
602  << "sharded dims offsets must be non-decreasing.";
603  }
604  off = shardedDimsOffsets[pos + i];
605  }
606  }
607  pos += numShards + 1;
608  }
609  }
610  }
611  return success();
612 }
613 
614 namespace {
// Sharding annotations "halo sizes" and "sharded dims offsets"
// are a mix of attributes and dynamic values. This canonicalization moves
// constant values to the respective attribute lists, minimizing the number
// of values.
// It also removes sharded_dims_sizes and halos if they are effectively "empty".
class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
public:
  // NOTE(review): the constructor-inheriting using-declaration on this line
  // was lost in extraction.

  LogicalResult matchAndRewrite(ShardingOp op,
                                PatternRewriter &b) const override {
    // Re-join static attributes and dynamic SSA values into unified
    // OpFoldResult lists so constants can be folded uniformly.
    auto mixedHalos =
        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
                                    op.getDynamicShardedDimsOffsets(), b);

    // No constant operands were folded, just return;
    bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
                    succeeded(foldDynamicIndexList(mixedOffs, true));

    auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
    auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);

    // All-zero static halos carry no information: drop them entirely.
    if (dynamicHalos.empty() && !staticHalos.empty()) {
      if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
        staticHalos.clear();
        modified = true;
      }
    }

    // Remove sharded dims offsets if they are effectively the default values,
    // e.g. if they define equi-distance between all neighboring shards.
    // Requires static-only offsets. Compares the first distance as the
    // difference between the first two offsets. Only if all consecutive
    // distances are the same, the offsets are removed.
    if (dynamicOffs.empty() && !staticOffs.empty()) {
      assert(staticOffs.size() >= 2);
      auto diff = staticOffs[1] - staticOffs[0];
      // A bare pair of offsets (single shard group) is kept as-is.
      bool all_same = staticOffs.size() > 2;
      for (auto i = 2u; i < staticOffs.size(); ++i) {
        if (staticOffs[i] - staticOffs[i - 1] != diff) {
          all_same = false;
          break;
        }
      }
      if (all_same) {
        staticOffs.clear();
        modified = true;
      }
    }

    if (!modified) {
      return failure();
    }

    // Write the normalized halo/offset components back onto the op in place.
    op.setStaticHaloSizes(staticHalos);
    op.getDynamicHaloSizesMutable().assign(dynamicHalos);
    op.setStaticShardedDimsOffsets(staticOffs);
    op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);

    return success();
  }
};
678 } // namespace
679 
680 void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
681  mlir::MLIRContext *context) {
682  results.add<NormalizeSharding>(context);
683 }
684 
685 //===----------------------------------------------------------------------===//
686 // Sharding
687 //===----------------------------------------------------------------------===//
688 
689 bool Sharding::equalSplitAxes(const Sharding &rhs) const {
690  if (getGrid() != rhs.getGrid()) {
691  return false;
692  }
693 
694  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
695  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
696  getSplitAxes().begin() + minSize),
697  llvm::make_range(rhs.getSplitAxes().begin(),
698  rhs.getSplitAxes().begin() + minSize))) {
699  return false;
700  }
701 
702  return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
703  std::mem_fn(&GridAxesAttr::empty)) &&
704  llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
705  std::mem_fn(&GridAxesAttr::empty));
706 }
707 
  // Both the offsets table and the halo sizes must match (see the two
  // helpers below).
  // NOTE(review): this definition's signature line was lost in extraction.
  return equalShardSizes(rhs) && equalHaloSizes(rhs);
}
711 
// True when both the static and the dynamic sharded-dims offsets match
// element-wise.
// NOTE(review): the second argument lines of the two llvm::equal calls
// (`rhs.get...Offsets()))`) were lost in extraction; remaining tokens
// preserved unchanged.
bool Sharding::equalShardSizes(const Sharding &rhs) const {
  if (rhs.getStaticShardedDimsOffsets().size() !=
          getStaticShardedDimsOffsets().size() ||
      !llvm::equal(getStaticShardedDimsOffsets(),
    return false;
  }
  if (rhs.getDynamicShardedDimsOffsets().size() !=
          getDynamicShardedDimsOffsets().size() ||
      !llvm::equal(getDynamicShardedDimsOffsets(),
    return false;
  }
  return true;
}
727 
728 bool Sharding::equalHaloSizes(const Sharding &rhs) const {
729  if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
730  !llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes())) {
731  return false;
732  }
733  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
734  !llvm::equal(getDynamicHaloSizes(), rhs.getDynamicHaloSizes())) {
735  return false;
736  }
737  return true;
738 }
739 
740 bool Sharding::operator==(Value rhs) const {
741  return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
742 }
743 
744 bool Sharding::operator!=(Value rhs) const { return !(*this == rhs); }
745 
746 bool Sharding::operator==(const Sharding &rhs) const {
747  return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
748 }
749 
750 bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }
751 
// Constructs a Sharding from the ShardingOp defining a value.
// NOTE(review): the constructor's signature line (presumably
// `Sharding::Sharding(Value rhs)`) was lost in extraction.

  auto shardingOp = rhs.getDefiningOp<ShardingOp>();
  assert(shardingOp && "expected sharding op");
  auto splitAxes = shardingOp.getSplitAxes().getAxes();
  // If splitAxes are empty, use "empty" constructor.
  if (splitAxes.empty()) {
    *this = Sharding(shardingOp.getGridAttr());
    return;
  }
  // Otherwise copy every static and dynamic component off the defining op.
  *this =
      get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
          shardingOp.getStaticShardedDimsOffsets(),
          SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
          SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
}
769 
// Factory assembling a Sharding from raw components; all inputs are copied
// into the result's owned storage.
// NOTE(review): the signature's first line (return type, name and the
// `grid_` parameter) was lost in extraction; remaining tokens preserved
// unchanged.
             ArrayRef<GridAxesAttr> split_axes_,
             ArrayRef<int64_t> static_halo_sizes_,
             ArrayRef<int64_t> static_sharded_dims_offsets_,
             ArrayRef<Value> dynamic_halo_sizes_,
             ArrayRef<Value> dynamic_sharded_dims_offsets_) {
  Sharding res(grid_);
  // Empty split axes means full replication: nothing per-axis to copy.
  if (split_axes_.empty()) {
    return res;
  }

  res.split_axes.resize(split_axes_.size());
  for (auto [i, axis] : llvm::enumerate(split_axes_)) {
    res.split_axes[i] =
        GridAxesAttr::get(grid_.getContext(), axis.asArrayRef());
  }

  // Deep-copy helper: size the destination, then copy all elements.
  auto clone = [](const auto src, auto &dst) {
    dst.resize(src.size());
    llvm::copy(src, dst.begin());
  };

  clone(static_halo_sizes_, res.static_halo_sizes);
  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
  clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);

  return res;
}
799 
800 //===----------------------------------------------------------------------===//
801 // shard.shard_shape
802 //===----------------------------------------------------------------------===//
803 
804 void ShardShapeOp::getAsmResultNames(
805  function_ref<void(Value, StringRef)> setNameFn) {
806  setNameFn(getResult()[0], "shard_shape");
807 }
808 
// Convenience builder: one index-typed result per entry of the static dims
// list; the device multi-index is passed purely dynamically.
// NOTE(review): the parameter line declaring the static `dims` list (used
// below) was lost in extraction; remaining tokens preserved unchanged.
void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
                         ::mlir::OperationState &odsState,
                         ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
                         ::mlir::ValueRange device) {
  SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
  // All device coordinates are marked dynamic (kDynamic placeholders).
  build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
        SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
}
818 
819 //===----------------------------------------------------------------------===//
820 // shard.shard op
821 //===----------------------------------------------------------------------===//
822 
823 void ShardOp::getAsmResultNames(
824  function_ref<void(Value, StringRef)> setNameFn) {
825  setNameFn(getResult(), "sharding_annotated");
826 }
827 
828 namespace {
// Determine if the given ShardOp is a duplicate of another ShardOp
// on the same value. This can happen if constant values are sharded.
class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
public:
  // NOTE(review): the constructor-inheriting using-declaration on this line
  // was lost in extraction.

  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
    // Get the use-list of the value being sharded and check if it has more than
    // one use. Values already produced by a ShardOp are skipped.
    Value value = op.getSrc();
    if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) {
      return failure();
    }

    // Iterate through the uses of the value to find a duplicate ShardOp.
    for (auto &use : value.getUses()) {
      if (use.getOwner() != op.getOperation()) {
        auto otherOp = dyn_cast<ShardOp>(use.getOwner());
        // Only fold against an earlier ShardOp in the same block so the
        // replacement dominates all uses.
        if (!otherOp || !otherOp->isBeforeInBlock(op)) {
          return failure();
        }
        // Create a Sharding object for the current and the other ShardOp
        // If the two are equal replace current op with the other op.
        Sharding currentSharding(op.getSharding());
        Sharding otherSharding(otherOp.getSharding());
        if (currentSharding == otherSharding) {
          b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
          b.eraseOp(op.getOperation());
        } else {
          // use the other sharding as input for op
          op.getSrcMutable().assign(otherOp.getResult());
        }
        return success();
      }
    }

    return failure();
  }
};
868 } // namespace
869 
870 void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
871  mlir::MLIRContext *context) {
872  results.add<FoldDuplicateShardOp>(context);
873 }
874 
875 //===----------------------------------------------------------------------===//
876 // shard.process_multi_index op
877 //===----------------------------------------------------------------------===//
878 
879 LogicalResult
880 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
881  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
882  if (failed(grid)) {
883  return failure();
884  }
885  if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
886  return failure();
887  }
888 
889  size_t expectedResultsCount =
890  getAxes().empty() ? grid->getRank() : getAxes().size();
891  if (getResult().size() != expectedResultsCount) {
892  return emitError() << "Unexpected number of results " << getResult().size()
893  << ". Expected " << expectedResultsCount << ".";
894  }
895 
896  return success();
897 }
898 
899 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
900  GridOp grid) {
901  build(odsBuilder, odsState,
902  SmallVector<Type>(grid.getRank(), odsBuilder.getIndexType()),
903  grid.getSymName(), ArrayRef<GridAxis>());
904 }
905 
906 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
907  StringRef grid, ArrayRef<GridAxis> axes) {
908  build(odsBuilder, odsState,
909  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
910  GridAxesAttr::get(odsBuilder.getContext(), axes));
911 }
912 
913 void ProcessMultiIndexOp::getAsmResultNames(
914  function_ref<void(Value, StringRef)> setNameFn) {
915  setNameFn(getResults()[0], "proc_linear_idx");
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // shard.process_linear_index op
920 //===----------------------------------------------------------------------===//
921 
922 LogicalResult
923 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
924  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
925  if (failed(grid)) {
926  return failure();
927  }
928  return success();
929 }
930 
931 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
932  OperationState &odsState, GridOp grid) {
933  build(odsBuilder, odsState, grid.getSymName());
934 }
935 
936 void ProcessLinearIndexOp::getAsmResultNames(
937  function_ref<void(Value, StringRef)> setNameFn) {
938  setNameFn(getResult(), "proc_linear_idx");
939 }
940 
941 //===----------------------------------------------------------------------===//
942 // shard.neighbors_linear_indices op
943 //===----------------------------------------------------------------------===//
944 
945 LogicalResult
946 NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
947  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
948  if (failed(grid)) {
949  return failure();
950  }
951  return success();
952 }
953 
/// Suggest human-readable SSA names for the two neighbor results.
void NeighborsLinearIndicesOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getNeighborDown(), "down_linear_idx");
  setNameFn(getNeighborUp(), "up_linear_idx");
}
959 
960 //===----------------------------------------------------------------------===//
961 // collective communication ops
962 //===----------------------------------------------------------------------===//
963 
964 namespace {
965 
966 template <typename Op>
967 struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
969  LogicalResult matchAndRewrite(Op op,
970  PatternRewriter &rewriter) const override {
971  auto gridAxes = op.getGridAxes();
972  if (!gridAxes.empty()) {
973  return failure();
974  }
975  if (op.getInput().getType() != op.getResult().getType()) {
976  return failure();
977  }
978 
979  rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
980  rewriter.eraseOp(op.getOperation());
981  return success();
982  }
983 };
984 
985 } // namespace
986 
987 static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
988  ArrayRef<int64_t> device,
989  Operation::operand_range deviceDynamic,
990  ArrayRef<GridAxis> gridAxes,
991  ArrayRef<int64_t> gridShape) {
992  if (device.size() != gridAxes.size()) {
993  return emitError(loc) << "In-group device \"" << deviceName
994  << "\" has unexpected multi-index size "
995  << device.size() << ". Expected " << gridAxes.size()
996  << ".";
997  }
998 
999  for (size_t i = 0; i < device.size(); ++i) {
1000  if (ShapedType::isStatic(device[i]) &&
1001  ShapedType::isStatic(gridShape[gridAxes[i]]) &&
1002  gridShape[gridAxes[i]] <= device[i]) {
1003  return emitError(loc)
1004  << "Out of bounds coordinate " << i << " for in-group device \""
1005  << deviceName << "\"."
1006  << " Got " << device[i] << ", but expected value in the range [0, "
1007  << (gridShape[gridAxes[i]] - 1) << "].";
1008  }
1009  }
1010  return success();
1011 }
1012 
/// Multiply all elements in [begin, end); an empty range yields 1.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  ElementType result = static_cast<ElementType>(1);
  for (It it = begin; it != end; ++it) {
    result = result * *it;
  }
  return result;
}
1019 
/// Range overload of product; uses ADL begin/end so it works with any
/// iterable, including llvm ranges.
template <typename R>
static auto product(R &&range) {
  return product(adl_begin(range), adl_end(range));
}
1024 
1025 static LogicalResult verifyDimensionCompatibility(Location loc,
1026  int64_t expectedDimSize,
1027  int64_t resultDimSize,
1028  int64_t resultAxis) {
1029  if (ShapedType::isStatic(resultDimSize) && expectedDimSize != resultDimSize) {
1030  return emitError(loc) << "Dimension size mismatch for result axis "
1031  << resultAxis << ". Expected "
1032  << (ShapedType::isDynamic(expectedDimSize)
1033  ? Twine("dynamic")
1034  : Twine(expectedDimSize))
1035  << ", but got " << resultDimSize << ".";
1036  }
1037 
1038  return success();
1039 }
1040 
1042  Value operand, Value result, int64_t gatherAxis,
1043  ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1044  auto resultRank = cast<ShapedType>(result.getType()).getRank();
1045  if (gatherAxis < 0 || gatherAxis >= resultRank) {
1046  return emitError(result.getLoc())
1047  << "Gather axis " << gatherAxis << " is out of bounds [0, "
1048  << resultRank << ").";
1049  }
1050 
1051  ShapedType operandType = cast<ShapedType>(operand.getType());
1052  ShapedType resultType = cast<ShapedType>(result.getType());
1053  auto deviceGroupSize =
1054  DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1055  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1056  auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
1057  auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
1058  auto expectedResultDimSize =
1059  axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
1061  result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
1062  return failure();
1063  }
1064  }
1065  return success();
1066 }
1067 
1069  Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
1070  ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1071  ShapedType operandType = cast<ShapedType>(operand.getType());
1072  ShapedType resultType = cast<ShapedType>(result.getType());
1073  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1074  if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
1076  result.getLoc(), operandType.getDimSize(axis),
1077  resultType.getDimSize(axis), axis))) {
1078  return failure();
1079  }
1080  }
1081  }
1082 
1083  if (splitAxis == concatAxis) {
1084  return success();
1085  }
1086 
1087  auto deviceGroupSize =
1088  DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1089  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
1090  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
1091  DimensionSize expectedResultConcatDimSize =
1092  operandConcatDimSize * deviceGroupSize;
1093  DimensionSize expectedResultSplitDimSize =
1094  operandSplitDimSize / deviceGroupSize;
1095  if (!expectedResultSplitDimSize.isDynamic() &&
1096  int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
1097  expectedResultSplitDimSize = DimensionSize::dynamic();
1098  }
1100  result.getLoc(), expectedResultConcatDimSize.value(),
1101  resultType.getDimSize(concatAxis), concatAxis))) {
1102  return failure();
1103  }
1105  result.getLoc(), expectedResultSplitDimSize.value(),
1106  resultType.getDimSize(splitAxis), splitAxis))) {
1107  return failure();
1108  }
1109 
1110  return success();
1111 }
1112 
1114  Value operand, Value result, int64_t tensorAxis,
1115  ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
1116  ShapedType operandType = cast<ShapedType>(operand.getType());
1117  ShapedType resultType = cast<ShapedType>(result.getType());
1118  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1119  if (axis != tensorAxis) {
1121  result.getLoc(), operandType.getDimSize(axis),
1122  resultType.getDimSize(axis), axis))) {
1123  return failure();
1124  }
1125  }
1126  }
1127 
1128  auto deviceGroupSize =
1129  DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
1130  auto operandScatterDimSize =
1131  DimensionSize(operandType.getDimSize(tensorAxis));
1132  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1133  int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1134  return emitError(result.getLoc())
1135  << "Operand dimension size " << int64_t(operandScatterDimSize)
1136  << " is not divisible by collective device group size "
1137  << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
1138  << ".";
1139  }
1140  DimensionSize expectedResultTensorDimSize =
1141  operandScatterDimSize / deviceGroupSize;
1143  result.getLoc(), expectedResultTensorDimSize.value(),
1144  resultType.getDimSize(tensorAxis), tensorAxis))) {
1145  return failure();
1146  }
1147 
1148  return success();
1149 }
1150 
/// Compute the result type of all_slice: same as the operand type, except the
/// slice axis is divided by the collective device group size (the
/// DimensionSize division yields dynamic if either side is dynamic).
static RankedTensorType sliceResultType(Type operandType, GridOp grid,
                                        ArrayRef<GridAxis> gridAxes,
                                        int64_t sliceAxis) {
  RankedTensorType operandRankedTensorType =
      cast<RankedTensorType>(operandType);
  DimensionSize operandSliceAxisSize =
      operandRankedTensorType.getShape()[sliceAxis];
  SmallVector<int64_t> resultShape =
      llvm::to_vector(operandRankedTensorType.getShape());

  resultShape[sliceAxis] =
      operandSliceAxisSize /
      DimensionSize(collectiveProcessGroupSize(gridAxes, grid));
  return operandRankedTensorType.clone(resultShape);
}
1166 
1167 //===----------------------------------------------------------------------===//
1168 // shard.all_gather op
1169 //===----------------------------------------------------------------------===//
1170 
1171 LogicalResult
1172 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1173  auto grid = getGridAndVerifyAxes(*this, symbolTable);
1174  if (failed(grid)) {
1175  return failure();
1176  }
1177  auto gatherAxis = getGatherAxis().getSExtValue();
1178  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
1179  gatherAxis, getGridAxes(),
1180  grid.value().getShape());
1181 }
1182 
/// Register the pattern that folds this op to its input when the grid-axes
/// list is empty.
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
}
1187 
/// Suggest a human-readable SSA name for the result in printed IR.
void AllGatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_gather");
}
1192 
1193 //===----------------------------------------------------------------------===//
1194 // shard.all_reduce op
1195 //===----------------------------------------------------------------------===//
1196 
/// Verify the grid symbol and axes. No operand/result shape relation is
/// checked for all_reduce.
LogicalResult
AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // FailureOr<GridOp> converts implicitly to LogicalResult.
  return getGridAndVerifyAxes(*this, symbolTable);
}
1201 
/// Register the pattern that folds this op to its input when the grid-axes
/// list is empty.
void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
}
1206 
/// Convenience builder: the result type mirrors the input type.
void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        Value input, StringRef grid,
                        ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
  build(odsBuilder, odsState, input.getType(), grid, gridAxes, input,
        reduction);
}
1213 
/// Suggest a human-readable SSA name for the result in printed IR.
void AllReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_reduce");
}
1218 
1219 //===----------------------------------------------------------------------===//
1220 // shard.all_slice op
1221 //===----------------------------------------------------------------------===//
1222 
1223 LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1224  auto grid = getGridAndVerifyAxes(*this, symbolTable);
1225  if (failed(grid)) {
1226  return failure();
1227  }
1229  getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
1230  grid.value().getShape());
1231 }
1232 
/// Register the pattern that folds this op to its input when the grid-axes
/// list is empty.
void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
}
1237 
/// Convenience builder that infers the result type by dividing the slice
/// axis of the input type by the collective device group size.
void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
                       int64_t sliceAxis) {
  Type resultType = sliceResultType(input.getType(), grid, gridAxes, sliceAxis);
  build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
        sliceAxis);
}
1245 
/// Convenience builder with an explicit result type and a grid symbol name.
void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Type resultType, Value input, StringRef grid,
                       ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
  // The axis attribute is an APInt sized to the bit width of int64_t (64).
  build(odsBuilder, odsState, resultType, grid, gridAxes, input,
        APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}
1252 
/// Suggest a human-readable SSA name for the result in printed IR.
void AllSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_slice");
}
1257 
1258 //===----------------------------------------------------------------------===//
1259 // shard.all_to_all op
1260 //===----------------------------------------------------------------------===//
1261 
1262 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1263  auto grid = getGridAndVerifyAxes(*this, symbolTable);
1264  if (failed(grid)) {
1265  return failure();
1266  }
1267 
1269  getOperand(), getResult(), getSplitAxis().getSExtValue(),
1270  getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
1271 }
1272 
/// Register the pattern that folds this op to its input when the grid-axes
/// list is empty.
void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
}
1277 
/// Suggest a human-readable SSA name for the result in printed IR.
void AllToAllOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "all_to_all");
}
1282 
1283 //===----------------------------------------------------------------------===//
1284 // shard.broadcast op
1285 //===----------------------------------------------------------------------===//
1286 
1287 LogicalResult
1288 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1289  auto grid = getGridAndVerifyAxes(*this, symbolTable);
1290  if (failed(grid)) {
1291  return failure();
1292  }
1293  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1294  getRootDynamic(), getGridAxes(),
1295  grid.value().getShape()))) {
1296  return failure();
1297  }
1298 
1299  return success();
1300 }
1301 
/// Register the pattern that folds this op to its input when the grid-axes
/// list is empty.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
}
1306 
/// Suggest a human-readable SSA name for the result in printed IR.
void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "broadcast");
}
1311 
1312 //===----------------------------------------------------------------------===//
1313 // shard.gather op
1314 //===----------------------------------------------------------------------===//
1315 
1316 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1317  auto grid = getGridAndVerifyAxes(*this, symbolTable);
1318  if (failed(grid)) {
1319  return failure();
1320  }
1321  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1322  getRootDynamic(), getGridAxes(),
1323  grid.value().getShape()))) {
1324  return failure();
1325  }
1326 
1327  auto gatherAxis = getGatherAxis().getSExtValue();
1328  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
1329  getGridAxes(),
1330  grid.value().getShape());
1331 }
1332 
/// Register the pattern that folds this op to its input when the grid-axes
/// list is empty.
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
}
1337 
/// Suggest a human-readable SSA name for the result in printed IR.
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
1342 
1343 //===----------------------------------------------------------------------===//
1344 // shard.recv op
1345 //===----------------------------------------------------------------------===//
1346 
1347 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1348  auto grid = getGridAndVerifyAxes(*this, symbolTable);
1349  if (failed(grid)) {
1350  return failure();
1351  }
1352  if (getSource() &&
1353  failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
1354  getSource().value(), getSourceDynamic(),
1355  getGridAxes(), grid.value().getShape()))) {
1356  return failure();
1357  }
1358  return success();
1359 }
1360 
/// Register the pattern that folds this op to its input when the grid-axes
/// list is empty.
void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
}
1365 
/// Suggest a human-readable SSA name for the result in printed IR.
void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "recv");
}
1369 
1370 //===----------------------------------------------------------------------===//
1371 // shard.reduce op
1372 //===----------------------------------------------------------------------===//
1373 
1374 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1375  auto grid = getGridAndVerifyAxes(*this, symbolTable);
1376  if (failed(grid)) {
1377  return failure();
1378  }
1379  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1380  getRootDynamic(), getGridAxes(),
1381  grid.value().getShape()))) {
1382  return failure();
1383  }
1384 
1385  return success();
1386 }
1387 
/// Register the pattern that folds this op to its input when the grid-axes
/// list is empty.
void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
}
1392 
/// Suggest a human-readable SSA name for the result in printed IR.
void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce");
}
1397 
1398 //===----------------------------------------------------------------------===//
1399 // shard.reduce_scatter op
1400 //===----------------------------------------------------------------------===//
1401 
1402 LogicalResult
1403 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1404  auto grid = getGridAndVerifyAxes(*this, symbolTable);
1405  if (failed(grid)) {
1406  return failure();
1407  }
1408 
1410  getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
1411  grid.value().getShape());
1412 }
1413 
/// Register the pattern that folds this op to its input when the grid-axes
/// list is empty.
void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}
1418 
/// Suggest a human-readable SSA name for the result in printed IR.
void ReduceScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reduce_scatter");
}
1423 
1424 //===----------------------------------------------------------------------===//
1425 // shard.scatter op
1426 //===----------------------------------------------------------------------===//
1427 
1428 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1429  auto grid = getGridAndVerifyAxes(*this, symbolTable);
1430  if (failed(grid)) {
1431  return failure();
1432  }
1433  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1434  getRootDynamic(), getGridAxes(),
1435  grid.value().getShape()))) {
1436  return failure();
1437  }
1438 
1439  auto scatterAxis = getScatterAxis().getSExtValue();
1440  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
1441  scatterAxis, getGridAxes(),
1442  grid.value().getShape());
1443 }
1444 
/// Register the pattern that folds this op to its input when the grid-axes
/// list is empty.
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
}
1449 
/// Suggest a human-readable SSA name for the result in printed IR.
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
1454 
1455 //===----------------------------------------------------------------------===//
1456 // shard.send op
1457 //===----------------------------------------------------------------------===//
1458 
1459 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1460  auto grid = getGridAndVerifyAxes(*this, symbolTable);
1461  if (failed(grid)) {
1462  return failure();
1463  }
1464  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
1465  getDestination(), getDestinationDynamic(),
1466  getGridAxes(), grid.value().getShape()))) {
1467  return failure();
1468  }
1469  return success();
1470 }
1471 
/// Register the pattern that folds this op to its input when the grid-axes
/// list is empty.
void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);
}
1476 
/// Suggest a human-readable SSA name for the result in printed IR.
void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "send");
}
1480 
1481 //===----------------------------------------------------------------------===//
1482 // shard.shift op
1483 //===----------------------------------------------------------------------===//
1484 
1485 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1486  auto grid = getGridAndVerifyAxes(*this, symbolTable);
1487  if (failed(grid)) {
1488  return failure();
1489  }
1490 
1491  auto gridAxes = getGridAxes();
1492  auto shiftAxis = getShiftAxis().getZExtValue();
1493  if (!llvm::is_contained(gridAxes, shiftAxis)) {
1494  return emitError() << "Invalid shift axis " << shiftAxis
1495  << ". It must be one of the grouping grid axes.";
1496  }
1497 
1498  return success();
1499 }
1500 
/// No canonicalization patterns registered yet.
void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                          MLIRContext *context) {
  // TODO: remove the op when the offset is 0, or when it is a rotate with
  // offset % shift_axis_grid_dim_size == 0.
}
1506 
/// Suggest a human-readable SSA name for the result in printed IR.
void ShiftOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "shift");
}
1511 
1512 //===----------------------------------------------------------------------===//
1513 // shard.update_halo op
1514 //===----------------------------------------------------------------------===//
1515 
1516 LogicalResult
1517 UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1518  auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
1519  if (failed(grid)) {
1520  return failure();
1521  }
1522 
1523  return success();
1524 }
1525 
1526 //===----------------------------------------------------------------------===//
1527 // TableGen'd op method definitions
1528 //===----------------------------------------------------------------------===//
1529 
1530 #define GET_OP_CLASSES
1531 #include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
1532 
1533 #define GET_ATTRDEF_CLASSES
1534 #include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
1535 
1536 #define GET_TYPEDEF_CLASSES
1537 #include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
1538 
1539 #include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
Definition: ShardOps.cpp:1041
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition: ShardOps.cpp:62
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
Definition: ShardOps.cpp:1025
static LogicalResult verifyGridAxes(Location loc, ArrayRef< GridAxis > axes, GridOp grid)
Definition: ShardOps.cpp:177
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
Definition: ShardOps.cpp:1068
static auto product(It begin, It end)
Definition: ShardOps.cpp:1014
static FailureOr< GridOp > getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition: ShardOps.cpp:200
static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
Definition: ShardOps.cpp:214
static FailureOr< GridOp > getGridAndVerify(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable)
Definition: ShardOps.cpp:148
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
Definition: ShardOps.cpp:987
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
Definition: ShardOps.cpp:1151
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< GridAxis > gridAxes, ArrayRef< int64_t > gridShape)
Definition: ShardOps.cpp:1113
static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
Definition: ShardOps.cpp:299
bool isUnique(It begin, It end)
Definition: ShardOps.cpp:161
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:166
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:111
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:50
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:774
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:421
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:129
This class provides the API for a sub-set of ops that are known to be constant-like.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Definition: OpDefinition.h:112
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:855
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition: Value.cpp:91
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool operator!=(Value rhs) const
Definition: ShardOps.cpp:744
bool equalShardSizes(const Sharding &rhs) const
Definition: ShardOps.cpp:712
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
Definition: ShardOps.cpp:752
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition: ShardOps.cpp:770
bool equalSplitAxes(const Sharding &rhs) const
Definition: ShardOps.cpp:689
::mlir::FlatSymbolRefAttr getGridAttr() const
Definition: ShardOps.h:60
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition: ShardOps.h:64
::llvm::StringRef getGrid() const
Definition: ShardOps.h:61
bool equalHaloAndShardSizes(const Sharding &rhs) const
Definition: ShardOps.cpp:708
bool operator==(Value rhs) const
Definition: ShardOps.cpp:740
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: ShardOps.h:63
ArrayRef< Value > getDynamicShardedDimsOffsets() const
Definition: ShardOps.h:68
ArrayRef< Value > getDynamicHaloSizes() const
Definition: ShardOps.h:67
ArrayRef< GridAxesAttr > getSplitAxes() const
Definition: ShardOps.h:62
bool equalHaloSizes(const Sharding &rhs) const
Definition: ShardOps.cpp:728
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
shard::ReductionKind ReductionKind
shard::GridOp GridOp
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
int16_t GridAxis
Definition: ShardOps.h:26
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
Definition: ShardOps.cpp:281
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
Definition: ShardOps.cpp:338
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition: ShardOps.h:113
bool isFullReplication(Sharding sharding)
Definition: ShardOps.h:106
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: ShardOps.cpp:352
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: ShardOps.h:168
Type shardType(Type type, GridOp grid, Sharding sharding)
Definition: ShardOps.cpp:291
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
Definition: ShardOps.cpp:77
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Definition: ShardOps.h:146
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:252
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.