MLIR  20.0.0git
MeshOps.cpp
Go to the documentation of this file.
1 //===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/IR/Value.h"
26 #include "mlir/Support/LLVM.h"
27 #include "llvm/ADT/ArrayRef.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallSet.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/Casting.h"
33 #include <algorithm>
34 #include <functional>
35 #include <iterator>
36 #include <numeric>
37 #include <optional>
38 #include <utility>
39 
40 #define DEBUG_TYPE "mesh-ops"
41 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
42 
43 using namespace mlir;
44 using namespace mlir::mesh;
45 
46 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
47 
48 namespace {
49 
50 struct DimensionSize {
51  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
52  DimensionSize(int64_t val) : val(val) {}
53  int64_t value() const { return val; }
54  operator int64_t() const { return val; }
55  bool isDynamic() const { return ShapedType::isDynamic(val); }
56 
57 private:
58  int64_t val;
59 };
60 
61 } // namespace
62 
63 static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
64  if (lhs.isDynamic() || rhs.isDynamic()) {
65  return DimensionSize::dynamic();
66  }
67  return lhs.value() / rhs.value();
68 }
69 
70 static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
71  if (lhs.isDynamic() || rhs.isDynamic()) {
72  return DimensionSize::dynamic();
73  }
74  return lhs.value() * rhs.value();
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Mesh dialect
79 //===----------------------------------------------------------------------===//
80 
/// Registers the dialect's operations, attributes and types.
///
/// Each list is pulled in from the corresponding generated `.cpp.inc` file;
/// the `GET_*_LIST` macro defined right before each include selects the
/// list section of that file.
void MeshDialect::initialize() {
  // Operations.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
      >();
  // Attributes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
      >();
  // Types.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
      >();
}
95 
97  Type type, Location loc) {
98  return arith::ConstantOp::materialize(builder, value, type, loc);
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // Mesh utilities
103 //===----------------------------------------------------------------------===//
104 
105 static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
106  FlatSymbolRefAttr meshSymbol,
107  SymbolTableCollection &symbolTable) {
108  mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
109  if (!mesh) {
110  return op->emitError() << "Undefined required mesh symbol \""
111  << meshSymbol.getValue() << "\".";
112  }
113 
114  return mesh;
115 }
116 
/// Returns true when no two *adjacent* elements of `[begin, end)` compare
/// equal. For a sorted range this means the range has no duplicates; empty
/// and single-element ranges are trivially unique.
template <typename It>
bool isUnique(It begin, It end) {
  return std::adjacent_find(begin, end) == end;
}
133 
134 static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
135  MeshOp mesh) {
136  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
137  llvm::sort(sorted);
138  if (!isUnique(sorted.begin(), sorted.end())) {
139  return emitError(loc) << "Mesh axes contains duplicate elements.";
140  }
141 
142  MeshAxis rank = mesh.getRank();
143  for (auto axis : axes) {
144  if (axis >= rank || axis < 0) {
145  return emitError(loc)
146  << "0-based mesh axis index " << axis
147  << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
148  << "\" is of rank " << rank << ".";
149  }
150  }
151 
152  return success();
153 }
154 
155 template <typename Op>
156 static FailureOr<MeshOp>
158  auto mesh =
159  ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
160  if (failed(mesh)) {
161  return failure();
162  }
163  if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
164  return failure();
165  }
166  return mesh;
167 }
168 
169 template <typename InShape, typename MeshShape, typename SplitAxes,
170  typename OutShape>
171 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
172  const SplitAxes &splitAxes, OutShape &outShape,
173  ArrayRef<int64_t> shardedDimsSizes = {},
174  ArrayRef<int64_t> haloSizes = {}) {
175  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
176  llvm::adl_begin(outShape));
177 
178  if (!shardedDimsSizes.empty()) {
179  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
180  if (innerSplitAxes.empty()) {
181 #ifndef NDEBUG
182  for (auto dimSz : shardedDimsSizes) {
183  auto inAxis = dimSz % inShape.size();
184  assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
185  inShape[inAxis] == ShapedType::kDynamic);
186  }
187 #endif // NDEBUG
188  } else {
189  // find sharded dims in sharded_dims_sizes with same static size on
190  // all devices. Use kDynamic for dimensions with dynamic or non-uniform
191  // sizes in sharded_dims_sizes.
192  auto sz = shardedDimsSizes[tensorAxis];
193  bool same = true;
194  for (size_t i = tensorAxis + inShape.size();
195  i < shardedDimsSizes.size(); i += inShape.size()) {
196  if (shardedDimsSizes[i] != sz) {
197  same = false;
198  break;
199  }
200  }
201  outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
202  }
203  }
204  } else {
205  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
206  outShape[tensorAxis] = shardDimension(
207  inShape[tensorAxis],
208  collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
209  }
210 
211  if (!haloSizes.empty()) {
212  // add halo sizes if requested
213  int haloAxis = 0;
214  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
215  if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
216  !innerSplitAxes.empty()) {
217  if (haloSizes[haloAxis * 2] >= 0 &&
218  haloSizes[haloAxis * 2 + 1] >= 0) {
219  outShape[tensorAxis] +=
220  haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
221  ++haloAxis;
222  } else {
223  outShape[tensorAxis] = ShapedType::kDynamic;
224  }
225  }
226  }
227  }
228  }
229 }
230 
231 ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
232  MeshSharding sharding) {
233  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
234  SmallVector<Dim> resShapeArr(shape.getShape().size());
235  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
236  resShapeArr, sharding.getStaticShardedDimsSizes(),
237  sharding.getStaticHaloSizes());
238  return shape.clone(resShapeArr);
239 }
240 
241 Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
242  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
243  if (rankedTensorType) {
244  return shardShapedType(rankedTensorType, mesh, sharding);
245  }
246  return type;
247 }
248 
250  OpOperand &operand,
251  OpBuilder &builder) {
252  OpBuilder::InsertionGuard insertionGuard(builder);
253  Value operandValue = operand.get();
254  Operation *operandOp = operand.getOwner();
255  builder.setInsertionPointAfterValue(operandValue);
256  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
257  if (shardOp && sharding == shardOp.getSharding() &&
258  !shardOp.getAnnotateForUsers()) {
259  // No need for anything the correct sharding is already set.
260  return;
261  }
262 
263  auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
264  auto newShardOp =
265  builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
266  /*annotate_for_users*/ false);
267  IRRewriter rewriter(builder);
268  rewriter.replaceUsesWithIf(
269  operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
270  return use.getOwner() == operandOp && use.get() == operandValue;
271  });
272 
273  if (!shardOp || shardOp.getAnnotateForUsers()) {
274  return;
275  }
276 
277  auto newShardOp2 =
278  builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
279  /*annotate_for_users*/ true);
280  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
281 }
282 
284  OpResult result,
285  OpBuilder &builder) {
286  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
287  maybeInsertTargetShardingAnnotation(sharding, use, builder);
288  }
289 }
290 
292  OpOperand &operand,
293  OpBuilder &builder) {
294  OpBuilder::InsertionGuard insertionGuard(builder);
295  Value operandValue = operand.get();
296  Operation *operandOp = operand.getOwner();
297  Operation *operandSrcOp = operandValue.getDefiningOp();
298  bool isBlockArg = !operandSrcOp;
299  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
300 
301  if (shardOp && sharding == shardOp.getSharding() &&
302  shardOp.getAnnotateForUsers()) {
303  // No need for anything the correct sharding is already set.
304  return;
305  }
306 
307  builder.setInsertionPoint(operandOp);
308  auto shardingOp =
309  builder.create<ShardingOp>(operand.get().getLoc(), sharding);
310  auto newShardOp =
311  builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
312  /*annotate_for_users*/ true);
313  IRRewriter rewriter(builder);
314  rewriter.replaceUsesWithIf(
315  operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
316  return use.getOwner() == operandOp && use.get() == operandValue;
317  });
318 
319  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
320  // No need for resharding.
321  return;
322  }
323 
324  builder.setInsertionPoint(newShardOp);
325  auto newPreceedingShardOp =
326  builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
327  /*annotate_for_users*/ false);
328  rewriter.replaceUsesWithIf(
329  newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
330  return use.getOwner() == newShardOp.getOperation();
331  });
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // mesh.mesh op
336 //===----------------------------------------------------------------------===//
337 
338 LogicalResult MeshOp::verify() {
339  int64_t rank = getRank();
340 
341  if (rank <= 0)
342  return emitOpError("rank of mesh is expected to be a positive integer");
343 
344  for (int64_t dimSize : getShape()) {
345  if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
346  return emitOpError("dimension size of a mesh is expected to be "
347  "non-negative or dynamic");
348  }
349 
350  return success();
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // mesh.mesh_shape op
355 //===----------------------------------------------------------------------===//
356 
357 LogicalResult
358 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
359  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
360  if (failed(mesh)) {
361  return failure();
362  }
363  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
364  return failure();
365  }
366 
367  size_t expectedResultsCount =
368  getAxes().empty() ? mesh->getRank() : getAxes().size();
369  if (getResult().size() != expectedResultsCount) {
370  return emitError() << "Unexpected number of results " << getResult().size()
371  << ". Expected " << expectedResultsCount << ".";
372  }
373 
374  return success();
375 }
376 
377 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
378  MeshOp mesh) {
379  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
380 }
381 
382 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
383  MeshOp mesh, ArrayRef<MeshAxis> axes) {
384  build(odsBuilder, odsState,
385  SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
386  odsBuilder.getIndexType()),
387  mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
388 }
389 
390 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
391  StringRef mesh, ArrayRef<MeshAxis> axes) {
392  assert(!axes.empty());
393  build(odsBuilder, odsState,
394  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
395  MeshAxesAttr::get(odsBuilder.getContext(), axes));
396 }
397 
398 void MeshShapeOp::getAsmResultNames(
399  function_ref<void(Value, StringRef)> setNameFn) {
400  setNameFn(getResults()[0], "mesh_shape");
401 }
402 
403 //===----------------------------------------------------------------------===//
404 // mesh.sharding
405 //===----------------------------------------------------------------------===//
406 
407 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
408  FlatSymbolRefAttr mesh,
409  ArrayRef<MeshAxesAttr> split_axes,
410  ArrayRef<MeshAxis> partial_axes,
411  mesh::ReductionKind partial_type,
412  ArrayRef<int64_t> static_halo_sizes,
413  ArrayRef<int64_t> static_sharded_dims_sizes) {
414  return build(
415  b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
416  ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
417  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
418  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
419  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_sharded_dims_sizes),
420  {});
421 }
422 
423 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
424  FlatSymbolRefAttr mesh,
425  ArrayRef<MeshAxesAttr> split_axes) {
426  return build(
427  b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
428  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
429  {}, {}, {}, {});
430 }
431 
432 void ShardingOp::build(
433  ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
434  FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
436  ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {
437  mlir::SmallVector<int64_t> staticHalos, staticDims;
438  mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
439  dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
440  dispatchIndexOpFoldResults(sharded_dims_sizes, dynamicDims, staticDims);
441  return build(
442  b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
443  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
444  ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
445  ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
446 }
447 
448 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
450 
451  build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
453  from.getPartialAxes().empty()
457  from.getPartialType()),
458  from.getStaticShardedDimsSizes().empty()
462  from.getStaticHaloSizes().empty()
465  from.getDynamicHaloSizes());
466 }
467 
468 LogicalResult ShardingOp::verify() {
469  llvm::SmallSet<MeshAxis, 4> visitedAxes;
470 
471  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
472  for (MeshAxis axis : axesArray) {
473  if (axis < 0)
474  return emitError() << "mesh axis is expected to be non-negative";
475  if (!visitedAxes.insert(axis).second)
476  return emitError() << "mesh axis duplicated";
477  }
478  return success();
479  };
480 
481  for (auto subAxes : getSplitAxes().getAxes()) {
482  ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
483  if (failed(checkMeshAxis(subAxesArray)))
484  return failure();
485  }
486  if (getPartialAxes().has_value() &&
487  failed(checkMeshAxis(getPartialAxes().value())))
488  return failure();
489 
490  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsSizes().empty()) {
491  return emitOpError("halo sizes and shard shapes are mutually exclusive");
492  }
493 
494  if (!getStaticHaloSizes().empty()) {
495  auto numSplitAxes = getSplitAxes().getAxes().size();
496  for (auto splitAxis : getSplitAxes().getAxes()) {
497  if (splitAxis.empty()) {
498  --numSplitAxes;
499  }
500  }
501  if (getStaticHaloSizes().size() != numSplitAxes * 2) {
502  return emitError() << "halo sizes must be specified for all split axes.";
503  }
504  }
505 
506  return success();
507 }
508 
509 void ShardingOp::getAsmResultNames(
510  function_ref<void(Value, StringRef)> setNameFn) {
511  setNameFn(getResult(), "sharding");
512 }
513 
514 LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
515  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
516  if (failed(mesh)) {
517  return failure();
518  }
519  if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
520  getStaticShardedDimsSizes().size() > 0) {
521  return emitError() << "sharded dims sizes are not allowed for "
522  "devices meshes with dynamic shape.";
523  }
524  return success();
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // MeshSharding
529 //===----------------------------------------------------------------------===//
530 
532  if (getMesh() != rhs.getMesh()) {
533  return false;
534  }
535 
536  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
537  return false;
538  }
539 
540  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
541  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
542  getSplitAxes().begin() + minSize),
543  llvm::make_range(rhs.getSplitAxes().begin(),
544  rhs.getSplitAxes().begin() + minSize))) {
545  return false;
546  }
547 
548  return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
549  getSplitAxes().end()),
550  std::mem_fn(&MeshAxesAttr::empty)) &&
551  llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
552  rhs.getSplitAxes().end()),
553  std::mem_fn(&MeshAxesAttr::empty));
554 }
555 
557  if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
558  !llvm::equal(llvm::make_range(getStaticHaloSizes().begin(),
559  getStaticHaloSizes().end()),
560  llvm::make_range(rhs.getStaticHaloSizes().begin(),
561  rhs.getStaticHaloSizes().end()))) {
562  return false;
563  }
564  if (rhs.getStaticShardedDimsSizes().size() != getDynamicHaloSizes().size() ||
565  !llvm::equal(llvm::make_range(getStaticShardedDimsSizes().begin(),
566  getStaticShardedDimsSizes().end()),
567  llvm::make_range(rhs.getStaticShardedDimsSizes().begin(),
568  rhs.getStaticShardedDimsSizes().end()))) {
569  return false;
570  }
571  if (rhs.getDynamicHaloSizes().size() != getStaticShardedDimsSizes().size() ||
572  !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(),
573  getDynamicHaloSizes().end()),
574  llvm::make_range(rhs.getDynamicHaloSizes().begin(),
575  rhs.getDynamicHaloSizes().end()))) {
576  return false;
577  }
578  if (rhs.getDynamicShardedDimsSizes().size() !=
579  getDynamicShardedDimsSizes().size() ||
580  !llvm::equal(llvm::make_range(getDynamicShardedDimsSizes().begin(),
582  llvm::make_range(rhs.getDynamicShardedDimsSizes().begin(),
583  rhs.getDynamicShardedDimsSizes().end()))) {
584  return false;
585  }
586  return true;
587 }
588 
591 }
592 
593 bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
594 
595 bool MeshSharding::operator==(const MeshSharding &rhs) const {
597 }
598 
599 bool MeshSharding::operator!=(const MeshSharding &rhs) const {
600  return !(*this == rhs);
601 }
602 
604  auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
605  assert(shardingOp && "expected sharding op");
606  *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
607  shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
608  shardingOp.getPartialType().value_or(ReductionKind::Sum),
609  shardingOp.getStaticHaloSizes(),
610  shardingOp.getStaticShardedDimsSizes(),
611  SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
612  SmallVector<Value>(shardingOp.getDynamicShardedDimsSizes()));
613 }
614 
616  ArrayRef<MeshAxesAttr> split_axes_,
617  ArrayRef<MeshAxis> partial_axes_,
618  ReductionKind partial_type_,
619  ArrayRef<int64_t> static_halo_sizes_,
620  ArrayRef<int64_t> static_sharded_dims_sizes_,
621  ArrayRef<Value> dynamic_halo_sizes_,
622  ArrayRef<Value> dynamic_sharded_dims_sizes_) {
623  MeshSharding res;
624  res.mesh = mesh_;
625  res.split_axes.resize(split_axes_.size());
626  for (auto [i, axis] : llvm::enumerate(split_axes_)) {
627  res.split_axes[i] =
628  MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
629  }
630 
631  auto clone = [](const auto src, auto &dst) {
632  dst.resize(src.size());
633  llvm::copy(src, dst.begin());
634  };
635 
636  clone(partial_axes_, res.partial_axes);
637  res.partial_type = partial_type_;
638  clone(static_halo_sizes_, res.static_halo_sizes);
639  clone(static_sharded_dims_sizes_, res.static_sharded_dims_sizes);
640  clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
641  clone(dynamic_sharded_dims_sizes_, res.dynamic_sharded_dims_sizes);
642 
643  return res;
644 }
645 
646 //===----------------------------------------------------------------------===//
647 // mesh.shard_shape
648 //===----------------------------------------------------------------------===//
649 
650 void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
651  ::mlir::OperationState &odsState,
652  ::llvm::ArrayRef<int64_t> shape,
653  ::mlir::Value sharding, ::mlir::Value device) {
654  SmallVector<mlir::Type> resType(shape.size(), odsBuilder.getIndexType());
655  build(odsBuilder, odsState, resType, shape, sharding, device);
656 }
657 
658 //===----------------------------------------------------------------------===//
659 // mesh.shard op
660 //===----------------------------------------------------------------------===//
661 
662 void ShardOp::getAsmResultNames(
663  function_ref<void(Value, StringRef)> setNameFn) {
664  setNameFn(getResult(), "sharding_annotated");
665 }
666 
667 //===----------------------------------------------------------------------===//
668 // mesh.process_multi_index op
669 //===----------------------------------------------------------------------===//
670 
671 LogicalResult
672 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
673  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
674  if (failed(mesh)) {
675  return failure();
676  }
677  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
678  return failure();
679  }
680 
681  size_t expectedResultsCount =
682  getAxes().empty() ? mesh->getRank() : getAxes().size();
683  if (getResult().size() != expectedResultsCount) {
684  return emitError() << "Unexpected number of results " << getResult().size()
685  << ". Expected " << expectedResultsCount << ".";
686  }
687 
688  return success();
689 }
690 
691 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
692  MeshOp mesh) {
693  build(odsBuilder, odsState,
694  SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
695  mesh.getSymName(), ArrayRef<MeshAxis>());
696 }
697 
698 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
699  StringRef mesh, ArrayRef<MeshAxis> axes) {
700  build(odsBuilder, odsState,
701  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
702  MeshAxesAttr::get(odsBuilder.getContext(), axes));
703 }
704 
705 void ProcessMultiIndexOp::getAsmResultNames(
706  function_ref<void(Value, StringRef)> setNameFn) {
707  setNameFn(getResults()[0], "proc_linear_idx");
708 }
709 
710 //===----------------------------------------------------------------------===//
711 // mesh.process_linear_index op
712 //===----------------------------------------------------------------------===//
713 
714 LogicalResult
715 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
716  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
717  if (failed(mesh)) {
718  return failure();
719  }
720  return success();
721 }
722 
723 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
724  OperationState &odsState, MeshOp mesh) {
725  build(odsBuilder, odsState, mesh.getSymName());
726 }
727 
728 void ProcessLinearIndexOp::getAsmResultNames(
729  function_ref<void(Value, StringRef)> setNameFn) {
730  setNameFn(getResult(), "proc_linear_idx");
731 }
732 
733 //===----------------------------------------------------------------------===//
734 // collective communication ops
735 //===----------------------------------------------------------------------===//
736 
737 namespace {
738 
739 template <typename Op>
740 struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
742  LogicalResult matchAndRewrite(Op op,
743  PatternRewriter &rewriter) const override {
744  auto meshAxes = op.getMeshAxes();
745  if (!meshAxes.empty()) {
746  return failure();
747  }
748  if (op.getInput().getType() != op.getResult().getType()) {
749  return failure();
750  }
751 
752  rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
753  rewriter.eraseOp(op.getOperation());
754  return success();
755  }
756 };
757 
758 } // namespace
759 
760 static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
761  ArrayRef<int64_t> device,
762  Operation::operand_range deviceDynamic,
763  ArrayRef<MeshAxis> meshAxes,
764  ArrayRef<int64_t> meshShape) {
765  if (device.size() != meshAxes.size()) {
766  return emitError(loc) << "In-group device \"" << deviceName
767  << "\" has unexpected multi-index size "
768  << device.size() << ". Expected " << meshAxes.size()
769  << ".";
770  }
771 
772  for (size_t i = 0; i < device.size(); ++i) {
773  if (!ShapedType::isDynamic(device[i]) &&
774  !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
775  meshShape[meshAxes[i]] <= device[i]) {
776  return emitError(loc)
777  << "Out of bounds coordinate " << i << " for in-group device \""
778  << deviceName << "\"."
779  << " Got " << device[i] << ", but expected value in the range [0, "
780  << (meshShape[meshAxes[i]] - 1) << "].";
781  }
782  }
783  return success();
784 }
785 
/// Returns the product of all elements in `[begin, end)`, computed in the
/// (decayed) element type. An empty range yields 1.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  ElementType result = static_cast<ElementType>(1);
  for (It it = begin; it != end; ++it)
    result = result * static_cast<ElementType>(*it);
  return result;
}
792 
/// Range overload: returns the product of all elements of `range` by
/// forwarding its ADL-resolved begin/end iterators to the iterator overload.
template <typename R>
static auto product(R &&range) {
  auto first = adl_begin(range);
  auto last = adl_end(range);
  return product(first, last);
}
797 
798 static LogicalResult verifyDimensionCompatibility(Location loc,
799  int64_t expectedDimSize,
800  int64_t resultDimSize,
801  int64_t resultAxis) {
802  if (!ShapedType::isDynamic(resultDimSize) &&
803  expectedDimSize != resultDimSize) {
804  return emitError(loc) << "Dimension size mismatch for result axis "
805  << resultAxis << ". Expected "
806  << (ShapedType::isDynamic(expectedDimSize)
807  ? Twine("dynamic")
808  : Twine(expectedDimSize))
809  << ", but got " << resultDimSize << ".";
810  }
811 
812  return success();
813 }
814 
816  Value operand, Value result, int64_t gatherAxis,
817  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
818  auto resultRank = cast<ShapedType>(result.getType()).getRank();
819  if (gatherAxis < 0 || gatherAxis >= resultRank) {
820  return emitError(result.getLoc())
821  << "Gather axis " << gatherAxis << " is out of bounds [0, "
822  << resultRank << ").";
823  }
824 
825  ShapedType operandType = cast<ShapedType>(operand.getType());
826  ShapedType resultType = cast<ShapedType>(result.getType());
827  auto deviceGroupSize =
828  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
829  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
830  auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
831  auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
832  auto expectedResultDimSize =
833  axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
834  if (failed(verifyDimensionCompatibility(
835  result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
836  return failure();
837  }
838  }
839  return success();
840 }
841 
843  Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
844  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
845  ShapedType operandType = cast<ShapedType>(operand.getType());
846  ShapedType resultType = cast<ShapedType>(result.getType());
847  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
848  if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
849  if (failed(verifyDimensionCompatibility(
850  result.getLoc(), operandType.getDimSize(axis),
851  resultType.getDimSize(axis), axis))) {
852  return failure();
853  }
854  }
855  }
856 
857  if (splitAxis == concatAxis) {
858  return success();
859  }
860 
861  auto deviceGroupSize =
862  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
863  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
864  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
865  DimensionSize expectedResultConcatDimSize =
866  operandConcatDimSize * deviceGroupSize;
867  DimensionSize expectedResultSplitDimSize =
868  operandSplitDimSize / deviceGroupSize;
869  if (!expectedResultSplitDimSize.isDynamic() &&
870  int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
871  expectedResultSplitDimSize = DimensionSize::dynamic();
872  }
873  if (failed(verifyDimensionCompatibility(
874  result.getLoc(), expectedResultConcatDimSize.value(),
875  resultType.getDimSize(concatAxis), concatAxis))) {
876  return failure();
877  }
878  if (failed(verifyDimensionCompatibility(
879  result.getLoc(), expectedResultSplitDimSize.value(),
880  resultType.getDimSize(splitAxis), splitAxis))) {
881  return failure();
882  }
883 
884  return success();
885 }
886 
888  Value operand, Value result, int64_t tensorAxis,
889  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
890  ShapedType operandType = cast<ShapedType>(operand.getType());
891  ShapedType resultType = cast<ShapedType>(result.getType());
892  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
893  if (axis != tensorAxis) {
894  if (failed(verifyDimensionCompatibility(
895  result.getLoc(), operandType.getDimSize(axis),
896  resultType.getDimSize(axis), axis))) {
897  return failure();
898  }
899  }
900  }
901 
902  auto deviceGroupSize =
903  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
904  auto operandScatterDimSize =
905  DimensionSize(operandType.getDimSize(tensorAxis));
906  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
907  int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
908  return emitError(result.getLoc())
909  << "Operand dimension size " << int64_t(operandScatterDimSize)
910  << " is not divisible by collective device group size "
911  << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
912  << ".";
913  }
914  DimensionSize expectedResultTensorDimSize =
915  operandScatterDimSize / deviceGroupSize;
916  if (failed(verifyDimensionCompatibility(
917  result.getLoc(), expectedResultTensorDimSize.value(),
918  resultType.getDimSize(tensorAxis), tensorAxis))) {
919  return failure();
920  }
921 
922  return success();
923 }
924 
// Computes the result type of slicing a tensor of type `operandType` along
// `sliceAxis` over the device group given by `meshAxes` on `mesh`.
//
// `operandType` is cast to RankedTensorType. The returned type is a clone of
// the operand type whose `sliceAxis` dimension is divided by the process
// group size; all other dimensions are copied unchanged. The division goes
// through DimensionSize, so a dynamic operand or group size yields a dynamic
// result dimension. No divisibility check is performed here.
925 static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
926  ArrayRef<MeshAxis> meshAxes,
927  int64_t sliceAxis) {
928  RankedTensorType operandRankedTensorType =
929  cast<RankedTensorType>(operandType);
930  DimensionSize operandSliceAxisSize =
931  operandRankedTensorType.getShape()[sliceAxis];
     // Start from a copy of the operand shape; only the sliced axis changes.
932  SmallVector<int64_t> resultShape =
933  llvm::to_vector(operandRankedTensorType.getShape());
934 
935  resultShape[sliceAxis] =
936  operandSliceAxisSize /
937  DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
938  return operandRankedTensorType.clone(resultShape);
939 }
940 
941 //===----------------------------------------------------------------------===//
942 // mesh.all_gather op
943 //===----------------------------------------------------------------------===//
944 
// Symbol-use verifier for mesh.all_gather.
// Resolves the referenced mesh and verifies the op's mesh axes via
// getMeshAndVerifyAxes, then checks operand/result shapes along the gather
// axis with verifyGatherOperandAndResultShape. Returns failure if either step
// fails.
945 LogicalResult
946 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
947  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
948  if (failed(mesh)) {
949  return failure();
950  }
     // The gather axis attribute is read as a sign-extended integer.
951  auto gatherAxis = getGatherAxis().getSExtValue();
952  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
953  gatherAxis, getMeshAxes(),
954  mesh.value().getShape());
955 }
956 
// Registers the canonicalization patterns for mesh.all_gather: a single
// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
// (presumably simplifies the op when its mesh axes list is empty -- see the
// pattern's definition elsewhere in this file.)
957 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
958  MLIRContext *context) {
959  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
960 }
961 
// Passes the name "all_gather" for this op's result to `setNameFn`
// (presumably used as the suggested SSA value name in printed IR).
962 void AllGatherOp::getAsmResultNames(
963  function_ref<void(Value, StringRef)> setNameFn) {
964  setNameFn(getResult(), "all_gather");
965 }
966 
967 //===----------------------------------------------------------------------===//
968 // mesh.all_reduce op
969 //===----------------------------------------------------------------------===//
970 
// Symbol-use verifier for mesh.all_reduce.
// Only resolves the referenced mesh and verifies the op's mesh axes; the
// outcome of getMeshAndVerifyAxes is returned directly as the LogicalResult.
971 LogicalResult
972 AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
973  return getMeshAndVerifyAxes(*this, symbolTable);
974 }
975 
// Registers the canonicalization patterns for mesh.all_reduce: a single
// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
976 void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
977  MLIRContext *context) {
978  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
979 }
980 
// Convenience builder for mesh.all_reduce that does not take a result type:
// the result type is taken to be the type of `input`, and construction is
// forwarded to the build overload that accepts an explicit result type.
981 void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
982  Value input, StringRef mesh,
983  ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
984  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
985  reduction);
986 }
987 
// Passes the name "all_reduce" for this op's result to `setNameFn`.
988 void AllReduceOp::getAsmResultNames(
989  function_ref<void(Value, StringRef)> setNameFn) {
990  setNameFn(getResult(), "all_reduce");
991 }
992 
993 //===----------------------------------------------------------------------===//
994 // mesh.all_slice op
995 //===----------------------------------------------------------------------===//
996 
// Symbol-use verifier for mesh.all_slice.
// Resolves the referenced mesh and verifies the op's mesh axes, then checks
// that the operand/result shapes are consistent with slicing along the slice
// axis (read as a sign-extended integer) over the device group.
//
// NOTE: the line starting the return statement (listing line 1002) was missing
// from this listing; it is restored to match the identical call in
// ScatterOp::verifySymbolUses and the argument list that follows.
997 LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
998  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
999  if (failed(mesh)) {
1000  return failure();
1001  }
1002  return verifyScatterOrSliceOperandAndResultShape(
1003  getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
1004  mesh.value().getShape());
1005 }
1006 
// Registers the canonicalization patterns for mesh.all_slice: a single
// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
1007 void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1008  MLIRContext *context) {
1009  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
1010 }
1011 
// Convenience builder for mesh.all_slice taking the mesh as a MeshOp.
// The result type is inferred with sliceResultType from the input type, the
// mesh, the mesh axes and the slice axis; construction is then forwarded to
// the overload that takes an explicit result type and the mesh symbol name.
1012 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1013  Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
1014  int64_t sliceAxis) {
1015  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
1016  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
1017  sliceAxis);
1018 }
1019 
// Builder for mesh.all_slice taking an explicit result type and the mesh
// symbol name. Wraps `sliceAxis` in an APInt whose bit width equals the width
// of `sliceAxis` (sizeof * CHAR_BIT) and forwards to another build overload.
// NOTE(review): the APInt is constructed without an explicit signedness
// argument -- confirm this is the intended handling for the axis value.
1020 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1021  Type resultType, Value input, StringRef mesh,
1022  ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
1023  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
1024  APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1025 }
1026 
// Passes the name "all_slice" for this op's result to `setNameFn`.
1027 void AllSliceOp::getAsmResultNames(
1028  function_ref<void(Value, StringRef)> setNameFn) {
1029  setNameFn(getResult(), "all_slice");
1030 }
1031 
1032 //===----------------------------------------------------------------------===//
1033 // mesh.all_to_all op
1034 //===----------------------------------------------------------------------===//
1035 
// Symbol-use verifier for mesh.all_to_all.
// Resolves the referenced mesh and verifies the op's mesh axes, then checks
// the operand/result shapes against the split and concat axes (both read as
// sign-extended integers) with verifyAllToAllOperandAndResultShape.
//
// NOTE: the line starting the return statement (listing line 1042) was missing
// from this listing; it is restored from the file's cross-reference entry for
// verifyAllToAllOperandAndResultShape, whose parameter list matches the
// arguments that follow.
1036 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1037  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1038  if (failed(mesh)) {
1039  return failure();
1040  }
1041 
1042  return verifyAllToAllOperandAndResultShape(
1043  getOperand(), getResult(), getSplitAxis().getSExtValue(),
1044  getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
1045 }
1046 
// Registers the canonicalization patterns for mesh.all_to_all: a single
// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
1047 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1048  MLIRContext *context) {
1049  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
1050 }
1051 
// Passes the name "all_to_all" for this op's result to `setNameFn`.
1052 void AllToAllOp::getAsmResultNames(
1053  function_ref<void(Value, StringRef)> setNameFn) {
1054  setNameFn(getResult(), "all_to_all");
1055 }
1056 
1057 //===----------------------------------------------------------------------===//
1058 // mesh.broadcast op
1059 //===----------------------------------------------------------------------===//
1060 
// Symbol-use verifier for mesh.broadcast.
// Resolves the referenced mesh and verifies the op's mesh axes, then runs
// verifyInGroupDevice on the root device (static `root` plus `root_dynamic`
// operands) against the mesh axes and mesh shape.
// (presumably this checks that the root is a valid in-group device index --
// see verifyInGroupDevice.)
1061 LogicalResult
1062 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1063  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1064  if (failed(mesh)) {
1065  return failure();
1066  }
1067  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1068  getRootDynamic(), getMeshAxes(),
1069  mesh.value().getShape()))) {
1070  return failure();
1071  }
1072 
1073  return success();
1074 }
1075 
// Registers the canonicalization patterns for mesh.broadcast: a single
// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
1076 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1077  MLIRContext *context) {
1078  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
1079 }
1080 
// Passes the name "broadcast" for this op's result to `setNameFn`.
1081 void BroadcastOp::getAsmResultNames(
1082  function_ref<void(Value, StringRef)> setNameFn) {
1083  setNameFn(getResult(), "broadcast");
1084 }
1085 
1086 //===----------------------------------------------------------------------===//
1087 // mesh.gather op
1088 //===----------------------------------------------------------------------===//
1089 
// Symbol-use verifier for mesh.gather.
// Performs three checks in order, returning failure at the first that fails:
//  1. resolve the referenced mesh and verify the op's mesh axes;
//  2. run verifyInGroupDevice on the root device;
//  3. check input/result shapes along the gather axis with
//     verifyGatherOperandAndResultShape.
1090 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1091  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1092  if (failed(mesh)) {
1093  return failure();
1094  }
1095  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1096  getRootDynamic(), getMeshAxes(),
1097  mesh.value().getShape()))) {
1098  return failure();
1099  }
1100 
      // The gather axis attribute is read as a sign-extended integer.
1101  auto gatherAxis = getGatherAxis().getSExtValue();
1102  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
1103  getMeshAxes(),
1104  mesh.value().getShape());
1105 }
1106 
// Registers the canonicalization patterns for mesh.gather: a single
// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
1107 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1108  MLIRContext *context) {
1109  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
1110 }
1111 
// Passes the name "gather" for this op's result to `setNameFn`.
1112 void GatherOp::getAsmResultNames(
1113  function_ref<void(Value, StringRef)> setNameFn) {
1114  setNameFn(getResult(), "gather");
1115 }
1116 
1117 //===----------------------------------------------------------------------===//
1118 // mesh.recv op
1119 //===----------------------------------------------------------------------===//
1120 
// Symbol-use verifier for mesh.recv.
// Resolves the referenced mesh and verifies the op's mesh axes. The source
// device is optional: verifyInGroupDevice is only run when `getSource()`
// holds a value.
1121 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1122  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1123  if (failed(mesh)) {
1124  return failure();
1125  }
      // Short-circuit: no source attribute means nothing further to verify.
1126  if (getSource() &&
1127  failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
1128  getSource().value(), getSourceDynamic(),
1129  getMeshAxes(), mesh.value().getShape()))) {
1130  return failure();
1131  }
1132  return success();
1133 }
1134 
// Registers the canonicalization patterns for mesh.recv: a single
// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
1135 void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1136  MLIRContext *context) {
1137  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
1138 }
1139 
// Passes the name "recv" for this op's result to `setNameFn`.
1140 void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1141  setNameFn(getResult(), "recv");
1142 }
1143 
1144 //===----------------------------------------------------------------------===//
1145 // mesh.reduce op
1146 //===----------------------------------------------------------------------===//
1147 
// Symbol-use verifier for mesh.reduce.
// Resolves the referenced mesh and verifies the op's mesh axes, then runs
// verifyInGroupDevice on the root device against the mesh axes and shape.
1148 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1149  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1150  if (failed(mesh)) {
1151  return failure();
1152  }
1153  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1154  getRootDynamic(), getMeshAxes(),
1155  mesh.value().getShape()))) {
1156  return failure();
1157  }
1158 
1159  return success();
1160 }
1161 
// Registers the canonicalization patterns for mesh.reduce: a single
// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
1162 void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1163  MLIRContext *context) {
1164  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
1165 }
1166 
// Passes the name "reduce" for this op's result to `setNameFn`.
1167 void ReduceOp::getAsmResultNames(
1168  function_ref<void(Value, StringRef)> setNameFn) {
1169  setNameFn(getResult(), "reduce");
1170 }
1171 
1172 //===----------------------------------------------------------------------===//
1173 // mesh.reduce_scatter op
1174 //===----------------------------------------------------------------------===//
1175 
// Symbol-use verifier for mesh.reduce_scatter.
// Resolves the referenced mesh and verifies the op's mesh axes, then checks
// that the operand/result shapes are consistent with scattering along the
// scatter axis (read as a sign-extended integer) over the device group.
//
// NOTE: the line starting the return statement (listing line 1183) was missing
// from this listing; it is restored to match the identical call in
// ScatterOp::verifySymbolUses and the argument list that follows.
1176 LogicalResult
1177 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1178  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1179  if (failed(mesh)) {
1180  return failure();
1181  }
1182 
1183  return verifyScatterOrSliceOperandAndResultShape(
1184  getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
1185  mesh.value().getShape());
1186 }
1187 
// Registers the canonicalization patterns for mesh.reduce_scatter: a single
// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
1188 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1189  MLIRContext *context) {
1190  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1191 }
1192 
// Passes the name "reduce_scatter" for this op's result to `setNameFn`.
1193 void ReduceScatterOp::getAsmResultNames(
1194  function_ref<void(Value, StringRef)> setNameFn) {
1195  setNameFn(getResult(), "reduce_scatter");
1196 }
1197 
1198 //===----------------------------------------------------------------------===//
1199 // mesh.scatter op
1200 //===----------------------------------------------------------------------===//
1201 
// Symbol-use verifier for mesh.scatter.
// Performs three checks in order, returning failure at the first that fails:
//  1. resolve the referenced mesh and verify the op's mesh axes;
//  2. run verifyInGroupDevice on the root device;
//  3. check input/result shapes along the scatter axis with
//     verifyScatterOrSliceOperandAndResultShape.
1202 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1203  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1204  if (failed(mesh)) {
1205  return failure();
1206  }
1207  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1208  getRootDynamic(), getMeshAxes(),
1209  mesh.value().getShape()))) {
1210  return failure();
1211  }
1212 
      // The scatter axis attribute is read as a sign-extended integer.
1213  auto scatterAxis = getScatterAxis().getSExtValue();
1214  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
1215  scatterAxis, getMeshAxes(),
1216  mesh.value().getShape());
1217 }
1218 
// Registers the canonicalization patterns for mesh.scatter: a single
// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
1219 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1220  MLIRContext *context) {
1221  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
1222 }
1223 
// Passes the name "scatter" for this op's result to `setNameFn`.
1224 void ScatterOp::getAsmResultNames(
1225  function_ref<void(Value, StringRef)> setNameFn) {
1226  setNameFn(getResult(), "scatter");
1227 }
1228 
1229 //===----------------------------------------------------------------------===//
1230 // mesh.send op
1231 //===----------------------------------------------------------------------===//
1232 
// Symbol-use verifier for mesh.send.
// Resolves the referenced mesh and verifies the op's mesh axes, then runs
// verifyInGroupDevice on the destination device (static `destination` plus
// `destination_dynamic` operands) against the mesh axes and mesh shape.
1233 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1234  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1235  if (failed(mesh)) {
1236  return failure();
1237  }
1238  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
1239  getDestination(), getDestinationDynamic(),
1240  getMeshAxes(), mesh.value().getShape()))) {
1241  return failure();
1242  }
1243  return success();
1244 }
1245 
// Registers the canonicalization patterns for mesh.send: a single
// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
1246 void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1247  MLIRContext *context) {
1248  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
1249 }
1250 
// Passes the name "send" for this op's result to `setNameFn`.
1251 void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1252  setNameFn(getResult(), "send");
1253 }
1254 
1255 //===----------------------------------------------------------------------===//
1256 // mesh.shift op
1257 //===----------------------------------------------------------------------===//
1258 
// Symbol-use verifier for mesh.shift.
// Resolves the referenced mesh and verifies the op's mesh axes, then requires
// the shift axis to be one of the op's mesh axes; otherwise an
// "Invalid shift axis" error is emitted on the op.
1259 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1260  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1261  if (failed(mesh)) {
1262  return failure();
1263  }
1264 
1265  auto meshAxes = getMeshAxes();
      // Note: the shift axis is read zero-extended here, unlike the
      // sign-extended axis reads in the other collective verifiers.
1266  auto shiftAxis = getShiftAxis().getZExtValue();
1267  if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
1268  return emitError() << "Invalid shift axis " << shiftAxis
1269  << ". It must be one of the grouping mesh axes.";
1270  }
1271 
1272  return success();
1273 }
1274 
// Canonicalization hook for mesh.shift. Currently registers no patterns;
// both parameters are unused.
1275 void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1276  MLIRContext *context) {
1277  // TODO: remove op when offset is 0 or if it is a rotate with an
1278  // offset % shift_axis_mesh_dim_size == 0.
1279 }
1280 
// Passes the name "shift" for this op's result to `setNameFn`.
1281 void ShiftOp::getAsmResultNames(
1282  function_ref<void(Value, StringRef)> setNameFn) {
1283  setNameFn(getResult(), "shift");
1284 }
1285 
1286 //===----------------------------------------------------------------------===//
1287 // mesh.update_halo op
1288 //===----------------------------------------------------------------------===//
1289 
// Symbol-use verifier for mesh.update_halo.
// Only resolves and verifies the referenced mesh symbol through
// getMeshAndVerify (the variant that takes the mesh attribute directly rather
// than the axes-verifying getMeshAndVerifyAxes used by the collective ops).
1290 LogicalResult
1291 UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1292  auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
1293  if (failed(mesh)) {
1294  return failure();
1295  }
1296 
1297  return success();
1298 }
1299 
1300 //===----------------------------------------------------------------------===//
1301 // TableGen'd op method definitions
1302 //===----------------------------------------------------------------------===//
1303 
1304 #define GET_OP_CLASSES
1305 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
1306 
1307 #define GET_ATTRDEF_CLASSES
1308 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
1309 
1310 #define GET_TYPEDEF_CLASSES
1311 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
1312 
1313 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
Definition: MeshOps.cpp:925
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition: MeshOps.cpp:63
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
Definition: MeshOps.cpp:798
static FailureOr< MeshOp > getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:157
static FailureOr< MeshOp > getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:105
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:887
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:815
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:842
static auto product(It begin, It end)
Definition: MeshOps.cpp:787
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsSizes={}, ArrayRef< int64_t > haloSizes={})
Definition: MeshOps.cpp:171
bool isUnique(It begin, It end)
Definition: MeshOps.cpp:118
static LogicalResult verifyMeshAxes(Location loc, ArrayRef< MeshAxis > axes, MeshOp mesh)
Definition: MeshOps.cpp:134
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:760
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Attributes are known-constant values of operations.
Definition: Attributes.h:25
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:195
DenseI16ArrayAttr getDenseI16ArrayAttr(ArrayRef< int16_t > values)
Definition: Builders.cpp:187
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:83
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:353
This class helps build Operations.
Definition: Builders.h:212
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:403
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
void replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this operation with the provided values if the given callback returns true...
Definition: Operation.h:279
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:531
bool equalHaloAndShardSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:556
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:63
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:65
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_sizes_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_sizes_={})
Definition: MeshOps.cpp:615
bool operator!=(Value rhs) const
Definition: MeshOps.cpp:593
ReductionKind getPartialType() const
Definition: MeshOps.h:67
ArrayRef< Value > getDynamicShardedDimsSizes() const
Definition: MeshOps.h:73
ArrayRef< int64_t > getStaticShardedDimsSizes() const
Definition: MeshOps.h:69
bool operator==(Value rhs) const
Definition: MeshOps.cpp:589
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:66
ArrayRef< Value > getDynamicHaloSizes() const
Definition: MeshOps.h:72
::llvm::StringRef getMesh() const
Definition: MeshOps.h:64
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: MeshOps.h:68
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
mesh::ReductionKind ReductionKind
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
Definition: MeshOps.h:147
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:116
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:291
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:241
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:231
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:169
int16_t MeshAxis
Definition: MeshOps.h:25
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:249
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:265
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.