MLIR  20.0.0git
MeshOps.cpp
Go to the documentation of this file.
1 //===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/IR/Value.h"
26 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/ArrayRef.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallSet.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Casting.h"
34 #include <algorithm>
35 #include <functional>
36 #include <iterator>
37 #include <numeric>
38 #include <optional>
39 #include <utility>
40 
41 #define DEBUG_TYPE "mesh-ops"
42 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
43 
44 using namespace mlir;
45 using namespace mlir::mesh;
46 
47 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
48 
49 namespace {
50 
51 struct DimensionSize {
52  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
53  DimensionSize(int64_t val) : val(val) {}
54  int64_t value() const { return val; }
55  operator int64_t() const { return val; }
56  bool isDynamic() const { return ShapedType::isDynamic(val); }
57 
58 private:
59  int64_t val;
60 };
61 
62 } // namespace
63 
64 static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
65  if (lhs.isDynamic() || rhs.isDynamic()) {
66  return DimensionSize::dynamic();
67  }
68  return lhs.value() / rhs.value();
69 }
70 
71 static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
72  if (lhs.isDynamic() || rhs.isDynamic()) {
73  return DimensionSize::dynamic();
74  }
75  return lhs.value() * rhs.value();
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // Inliner
80 //===----------------------------------------------------------------------===//
81 
82 namespace {
83 struct MeshInlinerInterface : public DialectInlinerInterface {
85  // Currently no restrictions are encoded for inlining.
86  bool isLegalToInline(Operation *, Operation *, bool) const final {
87  return true;
88  }
89  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
90  return true;
91  }
92  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
93  return true;
94  }
95 };
96 } // namespace
97 
98 //===----------------------------------------------------------------------===//
99 // Mesh dialect
100 //===----------------------------------------------------------------------===//
101 
102 void MeshDialect::initialize() {
103  addOperations<
104 #define GET_OP_LIST
105 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
106  >();
107  addAttributes<
108 #define GET_ATTRDEF_LIST
109 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
110  >();
111  addTypes<
112 #define GET_TYPEDEF_LIST
113 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
114  >();
115  addInterface<MeshInlinerInterface>();
116 }
117 
119  Type type, Location loc) {
120  return arith::ConstantOp::materialize(builder, value, type, loc);
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // Mesh utilities
125 //===----------------------------------------------------------------------===//
126 
127 static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
128  FlatSymbolRefAttr meshSymbol,
129  SymbolTableCollection &symbolTable) {
130  mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
131  if (!mesh) {
132  return op->emitError() << "Undefined required mesh symbol \""
133  << meshSymbol.getValue() << "\".";
134  }
135 
136  return mesh;
137 }
138 
/// Returns true when no two *adjacent* elements of [begin, end) compare equal.
///
/// Only neighbouring elements are compared, so this answers "are all elements
/// distinct?" solely for ranges in which equal elements are contiguous (e.g.
/// sorted ranges). Empty and single-element ranges are trivially unique.
template <typename It>
bool isUnique(It begin, It end) {
  return std::adjacent_find(begin, end) == end;
}
155 
156 static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
157  MeshOp mesh) {
158  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
159  llvm::sort(sorted);
160  if (!isUnique(sorted.begin(), sorted.end())) {
161  return emitError(loc) << "Mesh axes contains duplicate elements.";
162  }
163 
164  MeshAxis rank = mesh.getRank();
165  for (auto axis : axes) {
166  if (axis >= rank || axis < 0) {
167  return emitError(loc)
168  << "0-based mesh axis index " << axis
169  << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
170  << "\" is of rank " << rank << ".";
171  }
172  }
173 
174  return success();
175 }
176 
177 template <typename Op>
178 static FailureOr<MeshOp>
180  auto mesh =
181  ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
182  if (failed(mesh)) {
183  return failure();
184  }
185  if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
186  return failure();
187  }
188  return mesh;
189 }
190 
191 template <typename InShape, typename MeshShape, typename SplitAxes,
192  typename OutShape>
193 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
194  const SplitAxes &splitAxes, OutShape &outShape,
195  ArrayRef<int64_t> shardedDimsOffsets = {},
196  ArrayRef<int64_t> haloSizes = {}) {
197  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
198  llvm::adl_begin(outShape));
199 
200  if (!shardedDimsOffsets.empty()) {
201  auto isDynShape = ShapedType::isDynamicShape(meshShape);
202  uint64_t pos = 1;
203  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
204  if (!innerSplitAxes.empty()) {
205  auto sz = shardedDimsOffsets[pos];
206  bool same = !isDynShape;
207  if (same) {
208  // Find sharded dims in shardedDimsOffsets with same static size on
209  // all devices. Use kDynamic for dimensions with dynamic or
210  // non-uniform offs in shardedDimsOffsets.
211  uint64_t numShards = 0;
212  for (auto i : innerSplitAxes.asArrayRef()) {
213  numShards += meshShape[i];
214  }
215  for (size_t i = 1; i < numShards; ++i) {
216  if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
217  sz) {
218  same = false;
219  break;
220  }
221  }
222  pos += numShards + 1;
223  }
224  outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
225  }
226  }
227  } else {
228  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
229  outShape[tensorAxis] = shardDimension(
230  inShape[tensorAxis],
231  collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
232  }
233 
234  if (!haloSizes.empty()) {
235  // add halo sizes if requested
236  int haloAxis = 0;
237  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
238  if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
239  !innerSplitAxes.empty()) {
240  if (haloSizes[haloAxis * 2] >= 0 &&
241  haloSizes[haloAxis * 2 + 1] >= 0) {
242  outShape[tensorAxis] +=
243  haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
244  ++haloAxis;
245  } else {
246  outShape[tensorAxis] = ShapedType::kDynamic;
247  }
248  }
249  }
250  }
251  }
252 }
253 
254 ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
255  MeshSharding sharding) {
256  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
257  SmallVector<Dim> resShapeArr(shape.getShape().size());
258  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
259  resShapeArr, sharding.getStaticShardedDimsOffsets(),
260  sharding.getStaticHaloSizes());
261  return shape.clone(resShapeArr);
262 }
263 
264 Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
265  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
266  if (rankedTensorType) {
267  return shardShapedType(rankedTensorType, mesh, sharding);
268  }
269  return type;
270 }
271 
273  OpOperand &operand,
274  OpBuilder &builder) {
275  OpBuilder::InsertionGuard insertionGuard(builder);
276  Value operandValue = operand.get();
277  Operation *operandOp = operand.getOwner();
278  builder.setInsertionPointAfterValue(operandValue);
279  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
280  if (shardOp && sharding == shardOp.getSharding() &&
281  !shardOp.getAnnotateForUsers()) {
282 // Nothing to insert: the correct sharding is already set.
283  return;
284  }
285 
286  auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
287  auto newShardOp =
288  builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
289  /*annotate_for_users*/ false);
290  IRRewriter rewriter(builder);
291  rewriter.replaceUsesWithIf(
292  operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
293  return use.getOwner() == operandOp && use.get() == operandValue;
294  });
295 
296  if (!shardOp || shardOp.getAnnotateForUsers()) {
297  return;
298  }
299 
300  auto newShardOp2 =
301  builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
302  /*annotate_for_users*/ true);
303  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
304 }
305 
307  OpResult result,
308  OpBuilder &builder) {
309  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
310  maybeInsertTargetShardingAnnotation(sharding, use, builder);
311  }
312 }
313 
315  OpOperand &operand,
316  OpBuilder &builder) {
317  OpBuilder::InsertionGuard insertionGuard(builder);
318  Value operandValue = operand.get();
319  Operation *operandOp = operand.getOwner();
320  Operation *operandSrcOp = operandValue.getDefiningOp();
321  bool isBlockArg = !operandSrcOp;
322  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
323 
324  if (shardOp && sharding == shardOp.getSharding() &&
325  shardOp.getAnnotateForUsers()) {
326 // Nothing to insert: the correct sharding is already set.
327  return;
328  }
329 
330  builder.setInsertionPoint(operandOp);
331  auto shardingOp =
332  builder.create<ShardingOp>(operand.get().getLoc(), sharding);
333  auto newShardOp =
334  builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
335  /*annotate_for_users*/ true);
336  IRRewriter rewriter(builder);
337  rewriter.replaceUsesWithIf(
338  operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
339  return use.getOwner() == operandOp && use.get() == operandValue;
340  });
341 
342  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
343  // No need for resharding.
344  return;
345  }
346 
347  builder.setInsertionPoint(newShardOp);
348  auto newPreceedingShardOp =
349  builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
350  /*annotate_for_users*/ false);
351  rewriter.replaceUsesWithIf(
352  newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
353  return use.getOwner() == newShardOp.getOperation();
354  });
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // mesh.mesh op
359 //===----------------------------------------------------------------------===//
360 
361 LogicalResult MeshOp::verify() {
362  int64_t rank = getRank();
363 
364  if (rank <= 0)
365  return emitOpError("rank of mesh is expected to be a positive integer");
366 
367  for (int64_t dimSize : getShape()) {
368  if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
369  return emitOpError("dimension size of a mesh is expected to be "
370  "non-negative or dynamic");
371  }
372 
373  return success();
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // mesh.mesh_shape op
378 //===----------------------------------------------------------------------===//
379 
380 LogicalResult
381 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
382  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
383  if (failed(mesh)) {
384  return failure();
385  }
386  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
387  return failure();
388  }
389 
390  size_t expectedResultsCount =
391  getAxes().empty() ? mesh->getRank() : getAxes().size();
392  if (getResult().size() != expectedResultsCount) {
393  return emitError() << "Unexpected number of results " << getResult().size()
394  << ". Expected " << expectedResultsCount << ".";
395  }
396 
397  return success();
398 }
399 
400 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
401  MeshOp mesh) {
402  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
403 }
404 
405 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
406  MeshOp mesh, ArrayRef<MeshAxis> axes) {
407  build(odsBuilder, odsState,
408  SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
409  odsBuilder.getIndexType()),
410  mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
411 }
412 
413 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
414  StringRef mesh, ArrayRef<MeshAxis> axes) {
415  assert(!axes.empty());
416  build(odsBuilder, odsState,
417  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
418  MeshAxesAttr::get(odsBuilder.getContext(), axes));
419 }
420 
421 void MeshShapeOp::getAsmResultNames(
422  function_ref<void(Value, StringRef)> setNameFn) {
423  setNameFn(getResults()[0], "mesh_shape");
424 }
425 
426 //===----------------------------------------------------------------------===//
427 // mesh.sharding
428 //===----------------------------------------------------------------------===//
429 
430 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
431  FlatSymbolRefAttr mesh,
432  ArrayRef<MeshAxesAttr> split_axes,
433  ArrayRef<MeshAxis> partial_axes,
434  mesh::ReductionKind partial_type,
435  ArrayRef<int64_t> static_halo_sizes,
436  ArrayRef<int64_t> static_sharded_dims_offsets) {
437  return build(
438  b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
439  ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
440  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
441  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
443  static_sharded_dims_offsets),
444  {});
445 }
446 
447 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
448  FlatSymbolRefAttr mesh,
449  ArrayRef<MeshAxesAttr> split_axes) {
450  return build(
451  b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
452  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
453  {}, {}, {}, {});
454 }
455 
456 void ShardingOp::build(
457  ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
458  FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
460  ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
461  mlir::SmallVector<int64_t> staticHalos, staticDims;
462  mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
463  dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
464  dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
465  return build(
466  b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
467  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
468  ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
469  ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
470 }
471 
472 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
474 
475  build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
477  from.getPartialAxes().empty()
481  from.getPartialType()),
482  from.getStaticShardedDimsOffsets().empty()
486  from.getStaticHaloSizes().empty()
489  from.getDynamicHaloSizes());
490 }
491 
492 LogicalResult ShardingOp::verify() {
493  llvm::SmallSet<MeshAxis, 4> visitedAxes;
494 
495  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
496  for (MeshAxis axis : axesArray) {
497  if (axis < 0)
498  return emitError() << "mesh axis is expected to be non-negative";
499  if (!visitedAxes.insert(axis).second)
500  return emitError() << "mesh axis duplicated";
501  }
502  return success();
503  };
504 
505  for (auto subAxes : getSplitAxes().getAxes()) {
506  ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
507  if (failed(checkMeshAxis(subAxesArray)))
508  return failure();
509  }
510  if (getPartialAxes().has_value() &&
511  failed(checkMeshAxis(getPartialAxes().value())))
512  return failure();
513 
514  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
515  return emitOpError("halo sizes and shard offsets are mutually exclusive");
516  }
517 
518  if (!getStaticHaloSizes().empty()) {
519  auto numSplitAxes = getSplitAxes().getAxes().size();
520  for (auto splitAxis : getSplitAxes().getAxes()) {
521  if (splitAxis.empty()) {
522  --numSplitAxes;
523  }
524  }
525  if (getStaticHaloSizes().size() != numSplitAxes * 2) {
526  return emitError() << "halo sizes must be specified for all split axes.";
527  }
528  }
529 
530  return success();
531 }
532 
533 void ShardingOp::getAsmResultNames(
534  function_ref<void(Value, StringRef)> setNameFn) {
535  setNameFn(getResult(), "sharding");
536 }
537 
538 LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
539  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
540  if (failed(mesh)) {
541  return failure();
542  }
543  if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
544  getStaticShardedDimsOffsets().size() > 0) {
545  return emitError() << "sharded dims offsets are not allowed for "
546  "devices meshes with dynamic shape.";
547  }
548 
549  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
550  if (!shardedDimsOffsets.empty()) {
551  auto meshShape = mesh.value().getShape();
552  assert(!ShapedType::isDynamicShape(meshShape));
553  uint64_t pos = 0;
554  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
555  if (!innerSplitAxes.empty()) {
556  int64_t numShards = 0, off = 0;
557  for (auto i : innerSplitAxes.asArrayRef()) {
558  numShards += meshShape[i];
559  }
560  for (int64_t i = 0; i <= numShards; ++i) {
561  if (shardedDimsOffsets.size() <= pos + i) {
562  return emitError() << "sharded dims offsets has wrong size.";
563  }
564  if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
565  if (shardedDimsOffsets[pos + i] < off) {
566  return emitError()
567  << "sharded dims offsets must be non-decreasing.";
568  }
569  off = shardedDimsOffsets[pos + i];
570  }
571  }
572  pos += numShards + 1;
573  }
574  }
575  }
576  return success();
577 }
578 
579 namespace {
580 // Sharding annotations "halo sizes" and "sharded dims offsets"
581 // are a mix of attributes and dynamic values. This canonicalization moves
582 // constant values to the respective attribute lists and so minimizes the number
583 // of values.
584 class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
585 public:
587 
588  LogicalResult matchAndRewrite(ShardingOp op,
589  PatternRewriter &b) const override {
590  auto mixedHalos =
591  getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
592  auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
593  op.getDynamicShardedDimsOffsets(), b);
594 
595  // No constant operands were folded, just return;
596  if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
597  failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
598  return failure();
599  }
600 
601  auto halos = decomposeMixedValues(mixedHalos);
602  auto offs = decomposeMixedValues(mixedOffs);
603 
604  op.setStaticHaloSizes(halos.first);
605  op.getDynamicHaloSizesMutable().assign(halos.second);
606  op.setStaticShardedDimsOffsets(offs.first);
607  op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
608 
609  return success();
610  }
611 };
612 } // namespace
613 
614 void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
615  mlir::MLIRContext *context) {
616  results.add<FoldDynamicLists>(context);
617 }
618 
619 //===----------------------------------------------------------------------===//
620 // MeshSharding
621 //===----------------------------------------------------------------------===//
622 
624  if (getMesh() != rhs.getMesh()) {
625  return false;
626  }
627 
628  if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
629  (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
630  !llvm::equal(
631  llvm::make_range(getPartialAxes().begin(), getPartialAxes().end()),
632  llvm::make_range(rhs.getPartialAxes().begin(),
633  rhs.getPartialAxes().end()))) {
634  return false;
635  }
636 
637  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
638  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
639  getSplitAxes().begin() + minSize),
640  llvm::make_range(rhs.getSplitAxes().begin(),
641  rhs.getSplitAxes().begin() + minSize))) {
642  return false;
643  }
644 
645  return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
646  getSplitAxes().end()),
647  std::mem_fn(&MeshAxesAttr::empty)) &&
648  llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
649  rhs.getSplitAxes().end()),
650  std::mem_fn(&MeshAxesAttr::empty));
651 }
652 
654  return equalShardSizes(rhs) && equalHaloSizes(rhs);
655 }
656 
658  if (rhs.getStaticShardedDimsOffsets().size() !=
659  getStaticShardedDimsOffsets().size() ||
660  !llvm::equal(llvm::make_range(getStaticShardedDimsOffsets().begin(),
662  llvm::make_range(rhs.getStaticShardedDimsOffsets().begin(),
663  rhs.getStaticShardedDimsOffsets().end()))) {
664  return false;
665  }
666  if (rhs.getDynamicShardedDimsOffsets().size() !=
667  getDynamicShardedDimsOffsets().size() ||
668  !llvm::equal(
669  llvm::make_range(getDynamicShardedDimsOffsets().begin(),
671  llvm::make_range(rhs.getDynamicShardedDimsOffsets().begin(),
672  rhs.getDynamicShardedDimsOffsets().end()))) {
673  return false;
674  }
675  return true;
676 }
677 
679  if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
680  !llvm::equal(llvm::make_range(getStaticHaloSizes().begin(),
681  getStaticHaloSizes().end()),
682  llvm::make_range(rhs.getStaticHaloSizes().begin(),
683  rhs.getStaticHaloSizes().end()))) {
684  return false;
685  }
686  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
687  !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(),
688  getDynamicHaloSizes().end()),
689  llvm::make_range(rhs.getDynamicHaloSizes().begin(),
690  rhs.getDynamicHaloSizes().end()))) {
691  return false;
692  }
693  return true;
694 }
695 
698 }
699 
700 bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
701 
702 bool MeshSharding::operator==(const MeshSharding &rhs) const {
704 }
705 
706 bool MeshSharding::operator!=(const MeshSharding &rhs) const {
707  return !(*this == rhs);
708 }
709 
711  auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
712  assert(shardingOp && "expected sharding op");
713  *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
714  shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
715  shardingOp.getPartialType().value_or(ReductionKind::Sum),
716  shardingOp.getStaticHaloSizes(),
717  shardingOp.getStaticShardedDimsOffsets(),
718  SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
719  SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
720 }
721 
723  ArrayRef<MeshAxesAttr> split_axes_,
724  ArrayRef<MeshAxis> partial_axes_,
725  ReductionKind partial_type_,
726  ArrayRef<int64_t> static_halo_sizes_,
727  ArrayRef<int64_t> static_sharded_dims_offsets_,
728  ArrayRef<Value> dynamic_halo_sizes_,
729  ArrayRef<Value> dynamic_sharded_dims_offsets_) {
730  MeshSharding res;
731  res.mesh = mesh_;
732  res.split_axes.resize(split_axes_.size());
733  for (auto [i, axis] : llvm::enumerate(split_axes_)) {
734  res.split_axes[i] =
735  MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
736  }
737 
738  auto clone = [](const auto src, auto &dst) {
739  dst.resize(src.size());
740  llvm::copy(src, dst.begin());
741  };
742 
743  clone(partial_axes_, res.partial_axes);
744  res.partial_type = partial_type_;
745  clone(static_halo_sizes_, res.static_halo_sizes);
746  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
747  clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
748  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
749 
750  return res;
751 }
752 
753 //===----------------------------------------------------------------------===//
754 // mesh.shard_shape
755 //===----------------------------------------------------------------------===//
756 
757 void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
758  ::mlir::OperationState &odsState,
759  ::llvm::ArrayRef<int64_t> shape,
760  ::mlir::Value sharding, ::mlir::Value device) {
761  SmallVector<mlir::Type> resType(shape.size(), odsBuilder.getIndexType());
762  build(odsBuilder, odsState, resType, shape, sharding, device);
763 }
764 
765 //===----------------------------------------------------------------------===//
766 // mesh.shard op
767 //===----------------------------------------------------------------------===//
768 
769 void ShardOp::getAsmResultNames(
770  function_ref<void(Value, StringRef)> setNameFn) {
771  setNameFn(getResult(), "sharding_annotated");
772 }
773 
774 //===----------------------------------------------------------------------===//
775 // mesh.process_multi_index op
776 //===----------------------------------------------------------------------===//
777 
778 LogicalResult
779 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
780  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
781  if (failed(mesh)) {
782  return failure();
783  }
784  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
785  return failure();
786  }
787 
788  size_t expectedResultsCount =
789  getAxes().empty() ? mesh->getRank() : getAxes().size();
790  if (getResult().size() != expectedResultsCount) {
791  return emitError() << "Unexpected number of results " << getResult().size()
792  << ". Expected " << expectedResultsCount << ".";
793  }
794 
795  return success();
796 }
797 
798 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
799  MeshOp mesh) {
800  build(odsBuilder, odsState,
801  SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
802  mesh.getSymName(), ArrayRef<MeshAxis>());
803 }
804 
805 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
806  StringRef mesh, ArrayRef<MeshAxis> axes) {
807  build(odsBuilder, odsState,
808  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
809  MeshAxesAttr::get(odsBuilder.getContext(), axes));
810 }
811 
812 void ProcessMultiIndexOp::getAsmResultNames(
813  function_ref<void(Value, StringRef)> setNameFn) {
814  setNameFn(getResults()[0], "proc_linear_idx");
815 }
816 
817 //===----------------------------------------------------------------------===//
818 // mesh.process_linear_index op
819 //===----------------------------------------------------------------------===//
820 
821 LogicalResult
822 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
823  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
824  if (failed(mesh)) {
825  return failure();
826  }
827  return success();
828 }
829 
830 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
831  OperationState &odsState, MeshOp mesh) {
832  build(odsBuilder, odsState, mesh.getSymName());
833 }
834 
835 void ProcessLinearIndexOp::getAsmResultNames(
836  function_ref<void(Value, StringRef)> setNameFn) {
837  setNameFn(getResult(), "proc_linear_idx");
838 }
839 
840 //===----------------------------------------------------------------------===//
841 // collective communication ops
842 //===----------------------------------------------------------------------===//
843 
844 namespace {
845 
846 template <typename Op>
847 struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
849  LogicalResult matchAndRewrite(Op op,
850  PatternRewriter &rewriter) const override {
851  auto meshAxes = op.getMeshAxes();
852  if (!meshAxes.empty()) {
853  return failure();
854  }
855  if (op.getInput().getType() != op.getResult().getType()) {
856  return failure();
857  }
858 
859  rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
860  rewriter.eraseOp(op.getOperation());
861  return success();
862  }
863 };
864 
865 } // namespace
866 
867 static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
868  ArrayRef<int64_t> device,
869  Operation::operand_range deviceDynamic,
870  ArrayRef<MeshAxis> meshAxes,
871  ArrayRef<int64_t> meshShape) {
872  if (device.size() != meshAxes.size()) {
873  return emitError(loc) << "In-group device \"" << deviceName
874  << "\" has unexpected multi-index size "
875  << device.size() << ". Expected " << meshAxes.size()
876  << ".";
877  }
878 
879  for (size_t i = 0; i < device.size(); ++i) {
880  if (!ShapedType::isDynamic(device[i]) &&
881  !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
882  meshShape[meshAxes[i]] <= device[i]) {
883  return emitError(loc)
884  << "Out of bounds coordinate " << i << " for in-group device \""
885  << deviceName << "\"."
886  << " Got " << device[i] << ", but expected value in the range [0, "
887  << (meshShape[meshAxes[i]] - 1) << "].";
888  }
889  }
890  return success();
891 }
892 
/// Multiplies all elements of [begin, end) together, starting from 1 converted
/// to the element type. An empty range therefore yields 1.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  ElementType result = static_cast<ElementType>(1);
  for (; begin != end; ++begin)
    result = static_cast<ElementType>(result * *begin);
  return result;
}
899 
/// Range overload: multiplies all elements of `range` by forwarding its
/// begin/end iterators to the iterator-pair overload.
template <typename R>
static auto product(R &&range) {
  auto first = adl_begin(range);
  auto last = adl_end(range);
  return product(first, last);
}
904 
905 static LogicalResult verifyDimensionCompatibility(Location loc,
906  int64_t expectedDimSize,
907  int64_t resultDimSize,
908  int64_t resultAxis) {
909  if (!ShapedType::isDynamic(resultDimSize) &&
910  expectedDimSize != resultDimSize) {
911  return emitError(loc) << "Dimension size mismatch for result axis "
912  << resultAxis << ". Expected "
913  << (ShapedType::isDynamic(expectedDimSize)
914  ? Twine("dynamic")
915  : Twine(expectedDimSize))
916  << ", but got " << resultDimSize << ".";
917  }
918 
919  return success();
920 }
921 
923  Value operand, Value result, int64_t gatherAxis,
924  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
925  auto resultRank = cast<ShapedType>(result.getType()).getRank();
926  if (gatherAxis < 0 || gatherAxis >= resultRank) {
927  return emitError(result.getLoc())
928  << "Gather axis " << gatherAxis << " is out of bounds [0, "
929  << resultRank << ").";
930  }
931 
932  ShapedType operandType = cast<ShapedType>(operand.getType());
933  ShapedType resultType = cast<ShapedType>(result.getType());
934  auto deviceGroupSize =
935  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
936  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
937  auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
938  auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
939  auto expectedResultDimSize =
940  axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
941  if (failed(verifyDimensionCompatibility(
942  result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
943  return failure();
944  }
945  }
946  return success();
947 }
948 
950  Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
951  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
952  ShapedType operandType = cast<ShapedType>(operand.getType());
953  ShapedType resultType = cast<ShapedType>(result.getType());
954  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
955  if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
956  if (failed(verifyDimensionCompatibility(
957  result.getLoc(), operandType.getDimSize(axis),
958  resultType.getDimSize(axis), axis))) {
959  return failure();
960  }
961  }
962  }
963 
964  if (splitAxis == concatAxis) {
965  return success();
966  }
967 
968  auto deviceGroupSize =
969  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
970  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
971  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
972  DimensionSize expectedResultConcatDimSize =
973  operandConcatDimSize * deviceGroupSize;
974  DimensionSize expectedResultSplitDimSize =
975  operandSplitDimSize / deviceGroupSize;
976  if (!expectedResultSplitDimSize.isDynamic() &&
977  int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
978  expectedResultSplitDimSize = DimensionSize::dynamic();
979  }
980  if (failed(verifyDimensionCompatibility(
981  result.getLoc(), expectedResultConcatDimSize.value(),
982  resultType.getDimSize(concatAxis), concatAxis))) {
983  return failure();
984  }
985  if (failed(verifyDimensionCompatibility(
986  result.getLoc(), expectedResultSplitDimSize.value(),
987  resultType.getDimSize(splitAxis), splitAxis))) {
988  return failure();
989  }
990 
991  return success();
992 }
993 
995  Value operand, Value result, int64_t tensorAxis,
996  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
997  ShapedType operandType = cast<ShapedType>(operand.getType());
998  ShapedType resultType = cast<ShapedType>(result.getType());
999  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1000  if (axis != tensorAxis) {
1001  if (failed(verifyDimensionCompatibility(
1002  result.getLoc(), operandType.getDimSize(axis),
1003  resultType.getDimSize(axis), axis))) {
1004  return failure();
1005  }
1006  }
1007  }
1008 
1009  auto deviceGroupSize =
1010  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1011  auto operandScatterDimSize =
1012  DimensionSize(operandType.getDimSize(tensorAxis));
1013  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1014  int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1015  return emitError(result.getLoc())
1016  << "Operand dimension size " << int64_t(operandScatterDimSize)
1017  << " is not divisible by collective device group size "
1018  << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
1019  << ".";
1020  }
1021  DimensionSize expectedResultTensorDimSize =
1022  operandScatterDimSize / deviceGroupSize;
1023  if (failed(verifyDimensionCompatibility(
1024  result.getLoc(), expectedResultTensorDimSize.value(),
1025  resultType.getDimSize(tensorAxis), tensorAxis))) {
1026  return failure();
1027  }
1028 
1029  return success();
1030 }
1031 
1032 static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
1033  ArrayRef<MeshAxis> meshAxes,
1034  int64_t sliceAxis) {
1035  RankedTensorType operandRankedTensorType =
1036  cast<RankedTensorType>(operandType);
1037  DimensionSize operandSliceAxisSize =
1038  operandRankedTensorType.getShape()[sliceAxis];
1039  SmallVector<int64_t> resultShape =
1040  llvm::to_vector(operandRankedTensorType.getShape());
1041 
1042  resultShape[sliceAxis] =
1043  operandSliceAxisSize /
1044  DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
1045  return operandRankedTensorType.clone(resultShape);
1046 }
1047 
1048 //===----------------------------------------------------------------------===//
1049 // mesh.all_gather op
1050 //===----------------------------------------------------------------------===//
1051 
1052 LogicalResult
1053 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1054  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1055  if (failed(mesh)) {
1056  return failure();
1057  }
1058  auto gatherAxis = getGatherAxis().getSExtValue();
1059  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
1060  gatherAxis, getMeshAxes(),
1061  mesh.value().getShape());
1062 }
1063 
1064 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1065  MLIRContext *context) {
1066  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
1067 }
1068 
1069 void AllGatherOp::getAsmResultNames(
1070  function_ref<void(Value, StringRef)> setNameFn) {
1071  setNameFn(getResult(), "all_gather");
1072 }
1073 
1074 //===----------------------------------------------------------------------===//
1075 // mesh.all_reduce op
1076 //===----------------------------------------------------------------------===//
1077 
1078 LogicalResult
1079 AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1080  return getMeshAndVerifyAxes(*this, symbolTable);
1081 }
1082 
1083 void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1084  MLIRContext *context) {
1085  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
1086 }
1087 
1088 void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1089  Value input, StringRef mesh,
1090  ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
1091  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
1092  reduction);
1093 }
1094 
1095 void AllReduceOp::getAsmResultNames(
1096  function_ref<void(Value, StringRef)> setNameFn) {
1097  setNameFn(getResult(), "all_reduce");
1098 }
1099 
1100 //===----------------------------------------------------------------------===//
1101 // mesh.all_slice op
1102 //===----------------------------------------------------------------------===//
1103 
1104 LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1105  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1106  if (failed(mesh)) {
1107  return failure();
1108  }
1110  getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
1111  mesh.value().getShape());
1112 }
1113 
1114 void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1115  MLIRContext *context) {
1116  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
1117 }
1118 
1119 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1120  Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
1121  int64_t sliceAxis) {
1122  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
1123  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
1124  sliceAxis);
1125 }
1126 
1127 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1128  Type resultType, Value input, StringRef mesh,
1129  ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
1130  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
1131  APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1132 }
1133 
1134 void AllSliceOp::getAsmResultNames(
1135  function_ref<void(Value, StringRef)> setNameFn) {
1136  setNameFn(getResult(), "all_slice");
1137 }
1138 
1139 //===----------------------------------------------------------------------===//
1140 // mesh.all_to_all op
1141 //===----------------------------------------------------------------------===//
1142 
1143 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1144  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1145  if (failed(mesh)) {
1146  return failure();
1147  }
1148 
1150  getOperand(), getResult(), getSplitAxis().getSExtValue(),
1151  getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
1152 }
1153 
1154 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1155  MLIRContext *context) {
1156  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
1157 }
1158 
1159 void AllToAllOp::getAsmResultNames(
1160  function_ref<void(Value, StringRef)> setNameFn) {
1161  setNameFn(getResult(), "all_to_all");
1162 }
1163 
1164 //===----------------------------------------------------------------------===//
1165 // mesh.broadcast op
1166 //===----------------------------------------------------------------------===//
1167 
1168 LogicalResult
1169 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1170  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1171  if (failed(mesh)) {
1172  return failure();
1173  }
1174  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1175  getRootDynamic(), getMeshAxes(),
1176  mesh.value().getShape()))) {
1177  return failure();
1178  }
1179 
1180  return success();
1181 }
1182 
1183 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1184  MLIRContext *context) {
1185  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
1186 }
1187 
1188 void BroadcastOp::getAsmResultNames(
1189  function_ref<void(Value, StringRef)> setNameFn) {
1190  setNameFn(getResult(), "broadcast");
1191 }
1192 
1193 //===----------------------------------------------------------------------===//
1194 // mesh.gather op
1195 //===----------------------------------------------------------------------===//
1196 
1197 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1198  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1199  if (failed(mesh)) {
1200  return failure();
1201  }
1202  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1203  getRootDynamic(), getMeshAxes(),
1204  mesh.value().getShape()))) {
1205  return failure();
1206  }
1207 
1208  auto gatherAxis = getGatherAxis().getSExtValue();
1209  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
1210  getMeshAxes(),
1211  mesh.value().getShape());
1212 }
1213 
1214 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1215  MLIRContext *context) {
1216  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
1217 }
1218 
1219 void GatherOp::getAsmResultNames(
1220  function_ref<void(Value, StringRef)> setNameFn) {
1221  setNameFn(getResult(), "gather");
1222 }
1223 
1224 //===----------------------------------------------------------------------===//
1225 // mesh.recv op
1226 //===----------------------------------------------------------------------===//
1227 
1228 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1229  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1230  if (failed(mesh)) {
1231  return failure();
1232  }
1233  if (getSource() &&
1234  failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
1235  getSource().value(), getSourceDynamic(),
1236  getMeshAxes(), mesh.value().getShape()))) {
1237  return failure();
1238  }
1239  return success();
1240 }
1241 
1242 void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1243  MLIRContext *context) {
1244  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
1245 }
1246 
1247 void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1248  setNameFn(getResult(), "recv");
1249 }
1250 
1251 //===----------------------------------------------------------------------===//
1252 // mesh.reduce op
1253 //===----------------------------------------------------------------------===//
1254 
1255 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1256  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1257  if (failed(mesh)) {
1258  return failure();
1259  }
1260  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1261  getRootDynamic(), getMeshAxes(),
1262  mesh.value().getShape()))) {
1263  return failure();
1264  }
1265 
1266  return success();
1267 }
1268 
1269 void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1270  MLIRContext *context) {
1271  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
1272 }
1273 
1274 void ReduceOp::getAsmResultNames(
1275  function_ref<void(Value, StringRef)> setNameFn) {
1276  setNameFn(getResult(), "reduce");
1277 }
1278 
1279 //===----------------------------------------------------------------------===//
1280 // mesh.reduce_scatter op
1281 //===----------------------------------------------------------------------===//
1282 
1283 LogicalResult
1284 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1285  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1286  if (failed(mesh)) {
1287  return failure();
1288  }
1289 
1291  getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
1292  mesh.value().getShape());
1293 }
1294 
1295 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1296  MLIRContext *context) {
1297  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1298 }
1299 
1300 void ReduceScatterOp::getAsmResultNames(
1301  function_ref<void(Value, StringRef)> setNameFn) {
1302  setNameFn(getResult(), "reduce_scatter");
1303 }
1304 
1305 //===----------------------------------------------------------------------===//
1306 // mesh.scatter op
1307 //===----------------------------------------------------------------------===//
1308 
1309 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1310  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1311  if (failed(mesh)) {
1312  return failure();
1313  }
1314  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1315  getRootDynamic(), getMeshAxes(),
1316  mesh.value().getShape()))) {
1317  return failure();
1318  }
1319 
1320  auto scatterAxis = getScatterAxis().getSExtValue();
1321  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
1322  scatterAxis, getMeshAxes(),
1323  mesh.value().getShape());
1324 }
1325 
1326 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1327  MLIRContext *context) {
1328  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
1329 }
1330 
1331 void ScatterOp::getAsmResultNames(
1332  function_ref<void(Value, StringRef)> setNameFn) {
1333  setNameFn(getResult(), "scatter");
1334 }
1335 
1336 //===----------------------------------------------------------------------===//
1337 // mesh.send op
1338 //===----------------------------------------------------------------------===//
1339 
1340 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1341  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1342  if (failed(mesh)) {
1343  return failure();
1344  }
1345  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
1346  getDestination(), getDestinationDynamic(),
1347  getMeshAxes(), mesh.value().getShape()))) {
1348  return failure();
1349  }
1350  return success();
1351 }
1352 
1353 void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1354  MLIRContext *context) {
1355  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
1356 }
1357 
1358 void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1359  setNameFn(getResult(), "send");
1360 }
1361 
1362 //===----------------------------------------------------------------------===//
1363 // mesh.shift op
1364 //===----------------------------------------------------------------------===//
1365 
1366 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1367  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1368  if (failed(mesh)) {
1369  return failure();
1370  }
1371 
1372  auto meshAxes = getMeshAxes();
1373  auto shiftAxis = getShiftAxis().getZExtValue();
1374  if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
1375  return emitError() << "Invalid shift axis " << shiftAxis
1376  << ". It must be one of the grouping mesh axes.";
1377  }
1378 
1379  return success();
1380 }
1381 
1382 void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1383  MLIRContext *context) {
1384  // TODO: remove op when offset is 0 or if it is a rotate with and
1385  // offset % shift_axis_mesh_dim_size == 0.
1386 }
1387 
1388 void ShiftOp::getAsmResultNames(
1389  function_ref<void(Value, StringRef)> setNameFn) {
1390  setNameFn(getResult(), "shift");
1391 }
1392 
1393 //===----------------------------------------------------------------------===//
1394 // mesh.update_halo op
1395 //===----------------------------------------------------------------------===//
1396 
1397 LogicalResult
1398 UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1399  auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
1400  if (failed(mesh)) {
1401  return failure();
1402  }
1403 
1404  return success();
1405 }
1406 
1407 //===----------------------------------------------------------------------===//
1408 // TableGen'd op method definitions
1409 //===----------------------------------------------------------------------===//
1410 
1411 #define GET_OP_CLASSES
1412 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
1413 
1414 #define GET_ATTRDEF_CLASSES
1415 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
1416 
1417 #define GET_TYPEDEF_CLASSES
1418 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
1419 
1420 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
Definition: MeshOps.cpp:1032
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition: MeshOps.cpp:64
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
Definition: MeshOps.cpp:905
static FailureOr< MeshOp > getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:179
static FailureOr< MeshOp > getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:127
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:994
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:922
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:949
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
Definition: MeshOps.cpp:193
static auto product(It begin, It end)
Definition: MeshOps.cpp:894
bool isUnique(It begin, It end)
Definition: MeshOps.cpp:140
static LogicalResult verifyMeshAxes(Location loc, ArrayRef< MeshAxis > axes, MeshOp mesh)
Definition: MeshOps.cpp:156
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:867
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Attributes are known-constant values of operations.
Definition: Attributes.h:25
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
DenseI16ArrayAttr getDenseI16ArrayAttr(ArrayRef< int16_t > values)
Definition: Builders.cpp:199
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:95
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Definition: OpDefinition.h:108
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this operation with the provided values if the given callback returns true...
Definition: Operation.h:279
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:623
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition: MeshOps.h:70
bool equalHaloAndShardSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:653
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:64
bool equalHaloSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:678
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:66
bool operator!=(Value rhs) const
Definition: MeshOps.cpp:700
ReductionKind getPartialType() const
Definition: MeshOps.h:68
ArrayRef< Value > getDynamicShardedDimsOffsets() const
Definition: MeshOps.h:74
bool operator==(Value rhs) const
Definition: MeshOps.cpp:696
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:67
ArrayRef< Value > getDynamicHaloSizes() const
Definition: MeshOps.h:73
::llvm::StringRef getMesh() const
Definition: MeshOps.h:65
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: MeshOps.h:69
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition: MeshOps.cpp:722
bool equalShardSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:657
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
mesh::ReductionKind ReductionKind
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
Definition: MeshOps.h:151
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:120
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:314
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:264
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:254
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:173
int16_t MeshAxis
Definition: MeshOps.h:26
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:272
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:265
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.