MLIR  21.0.0git
MeshOps.cpp
Go to the documentation of this file.
1 //===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/IR/Value.h"
26 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/ArrayRef.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallSet.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Casting.h"
34 #include <algorithm>
35 #include <functional>
36 #include <iterator>
37 #include <numeric>
38 #include <optional>
39 #include <utility>
40 
41 #define DEBUG_TYPE "mesh-ops"
42 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
43 
44 using namespace mlir;
45 using namespace mlir::mesh;
46 
47 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
48 
49 namespace {
50 
51 struct DimensionSize {
52  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
53  DimensionSize(int64_t val) : val(val) {}
54  int64_t value() const { return val; }
55  operator int64_t() const { return val; }
56  bool isDynamic() const { return ShapedType::isDynamic(val); }
57 
58 private:
59  int64_t val;
60 };
61 
62 } // namespace
63 
64 static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
65  if (lhs.isDynamic() || rhs.isDynamic()) {
66  return DimensionSize::dynamic();
67  }
68  return lhs.value() / rhs.value();
69 }
70 
71 static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
72  if (lhs.isDynamic() || rhs.isDynamic()) {
73  return DimensionSize::dynamic();
74  }
75  return lhs.value() * rhs.value();
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // Inliner
80 //===----------------------------------------------------------------------===//
81 
82 namespace {
83 struct MeshInlinerInterface : public DialectInlinerInterface {
85  // Currently no restrictions are encoded for inlining.
86  bool isLegalToInline(Operation *, Operation *, bool) const final {
87  return true;
88  }
89  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
90  return true;
91  }
92  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
93  return true;
94  }
95 };
96 } // namespace
97 
98 //===----------------------------------------------------------------------===//
99 // Mesh dialect
100 //===----------------------------------------------------------------------===//
101 
102 void MeshDialect::initialize() {
103  addOperations<
104 #define GET_OP_LIST
105 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
106  >();
107  addAttributes<
108 #define GET_ATTRDEF_LIST
109 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
110  >();
111  addTypes<
112 #define GET_TYPEDEF_LIST
113 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
114  >();
115  addInterface<MeshInlinerInterface>();
116 }
117 
119  Type type, Location loc) {
120  return arith::ConstantOp::materialize(builder, value, type, loc);
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // Mesh utilities
125 //===----------------------------------------------------------------------===//
126 
127 static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
128  FlatSymbolRefAttr meshSymbol,
129  SymbolTableCollection &symbolTable) {
130  mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
131  if (!mesh) {
132  return op->emitError() << "Undefined required mesh symbol \""
133  << meshSymbol.getValue() << "\".";
134  }
135 
136  return mesh;
137 }
138 
/// Returns true if no two neighbouring elements of [begin, end) compare
/// equal. Only adjacent elements are compared, so the range has to be sorted
/// for this to mean "all elements are distinct".
template <typename It>
bool isUnique(It begin, It end) {
  return std::adjacent_find(begin, end) == end;
}
155 
156 static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
157  MeshOp mesh) {
158  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
159  llvm::sort(sorted);
160  if (!isUnique(sorted.begin(), sorted.end())) {
161  return emitError(loc) << "Mesh axes contains duplicate elements.";
162  }
163 
164  MeshAxis rank = mesh.getRank();
165  for (auto axis : axes) {
166  if (axis >= rank || axis < 0) {
167  return emitError(loc)
168  << "0-based mesh axis index " << axis
169  << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
170  << "\" is of rank " << rank << ".";
171  }
172  }
173 
174  return success();
175 }
176 
177 template <typename Op>
178 static FailureOr<MeshOp>
180  auto mesh =
181  ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
182  if (failed(mesh)) {
183  return failure();
184  }
185  if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
186  return failure();
187  }
188  return mesh;
189 }
190 
191 template <typename InShape, typename MeshShape, typename SplitAxes,
192  typename OutShape>
193 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
194  const SplitAxes &splitAxes, OutShape &outShape,
195  ArrayRef<int64_t> shardedDimsOffsets = {},
196  ArrayRef<int64_t> haloSizes = {}) {
197  // 0d tensors cannot be sharded and must get replicated
198  if (inShape.empty()) {
199  assert(outShape.empty());
200  return;
201  }
202 
203  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
204  llvm::adl_begin(outShape));
205 
206  if (!shardedDimsOffsets.empty()) {
207  auto isDynShape = ShapedType::isDynamicShape(meshShape);
208  uint64_t pos = 1;
209  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
210  if (!innerSplitAxes.empty()) {
211  auto sz = shardedDimsOffsets[pos];
212  bool same = !isDynShape;
213  if (same) {
214  // Find sharded dims in shardedDimsOffsets with same static size on
215  // all devices. Use kDynamic for dimensions with dynamic or
216  // non-uniform offs in shardedDimsOffsets.
217  uint64_t numShards = 0;
218  for (auto i : innerSplitAxes.asArrayRef()) {
219  numShards += meshShape[i];
220  }
221  for (size_t i = 1; i < numShards; ++i) {
222  if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
223  sz) {
224  same = false;
225  break;
226  }
227  }
228  pos += numShards + 1;
229  }
230  outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
231  }
232  }
233  } else {
234  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
235  outShape[tensorAxis] = shardDimension(
236  inShape[tensorAxis],
237  collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
238  }
239 
240  if (!haloSizes.empty()) {
241  // add halo sizes if requested
242  int haloAxis = 0;
243  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
244  if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
245  !innerSplitAxes.empty()) {
246  if (haloSizes[haloAxis * 2] >= 0 &&
247  haloSizes[haloAxis * 2 + 1] >= 0) {
248  outShape[tensorAxis] +=
249  haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
250  ++haloAxis;
251  } else {
252  outShape[tensorAxis] = ShapedType::kDynamic;
253  }
254  }
255  }
256  }
257  }
258 }
259 
260 ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
261  MeshSharding sharding) {
262  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
263  SmallVector<Dim> resShapeArr(shape.getShape().size());
264  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
265  resShapeArr, sharding.getStaticShardedDimsOffsets(),
266  sharding.getStaticHaloSizes());
267  return shape.clone(resShapeArr);
268 }
269 
270 Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
271  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
272  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
273  return shardShapedType(rankedTensorType, mesh, sharding);
274  }
275  return type;
276 }
277 
279  OpOperand &operand,
280  OpBuilder &builder,
281  ShardOp &newShardOp) {
282  OpBuilder::InsertionGuard insertionGuard(builder);
283  Value operandValue = operand.get();
284  Operation *operandOp = operand.getOwner();
285  builder.setInsertionPointAfterValue(operandValue);
286  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
287  if (shardOp && sharding == shardOp.getSharding() &&
288  !shardOp.getAnnotateForUsers()) {
289  // No need for anything if the correct sharding is already set.
290  if (!newShardOp) {
291  newShardOp = shardOp;
292  }
293  return;
294  }
295 
296  if (!newShardOp) {
297  auto shardingOp =
298  builder.create<ShardingOp>(operandValue.getLoc(), sharding);
299  newShardOp =
300  builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
301  /*annotate_for_users*/ false);
302  }
303  IRRewriter rewriter(builder);
304  rewriter.replaceUsesWithIf(
305  operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
306  return use.getOwner() == operandOp && use.get() == operandValue;
307  });
308 
309  if (!shardOp || shardOp.getAnnotateForUsers()) {
310  return;
311  }
312 
313  auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
314  newShardOp.getSharding(),
315  /*annotate_for_users*/ true);
316  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
317  return;
318 }
319 
321  OpResult result,
322  OpBuilder &builder) {
323  ShardOp newShardOp;
324  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
325  maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
326  }
327 }
328 
330  OpOperand &operand,
331  OpBuilder &builder) {
332  OpBuilder::InsertionGuard insertionGuard(builder);
333  Value operandValue = operand.get();
334  Operation *operandSrcOp = operandValue.getDefiningOp();
335  bool isBlockArg = !operandSrcOp;
336  {
337  [[maybe_unused]] auto opType =
338  dyn_cast<mlir::RankedTensorType>(operandValue.getType());
339  assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
340  }
341  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
342  operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
343  return;
344  }
345 
346  Operation *operandOp = operand.getOwner();
347  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
348 
349  if (shardOp && sharding == shardOp.getSharding() &&
350  shardOp.getAnnotateForUsers()) {
351  // No need for anything if the correct sharding is already set.
352  return;
353  }
354 
355  builder.setInsertionPoint(operandOp);
356  auto shardingOp =
357  builder.create<ShardingOp>(operand.get().getLoc(), sharding);
358  auto newShardOp =
359  builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
360  /*annotate_for_users*/ true);
361  IRRewriter rewriter(builder);
362  rewriter.replaceUsesWithIf(
363  operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
364  return use.getOwner() == operandOp && use.get() == operandValue;
365  });
366 
367  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
368  // No need for resharding.
369  return;
370  }
371 
372  builder.setInsertionPoint(newShardOp);
373  auto newPreceedingShardOp =
374  builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
375  /*annotate_for_users*/ false);
376  rewriter.replaceUsesWithIf(
377  newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
378  return use.getOwner() == newShardOp.getOperation();
379  });
380 }
381 
382 //===----------------------------------------------------------------------===//
383 // mesh.mesh op
384 //===----------------------------------------------------------------------===//
385 
386 LogicalResult MeshOp::verify() {
387  int64_t rank = getRank();
388 
389  if (rank <= 0)
390  return emitOpError("rank of mesh is expected to be a positive integer");
391 
392  for (int64_t dimSize : getShape()) {
393  if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
394  return emitOpError("dimension size of a mesh is expected to be "
395  "non-negative or dynamic");
396  }
397 
398  return success();
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // mesh.mesh_shape op
403 //===----------------------------------------------------------------------===//
404 
405 LogicalResult
406 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
407  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
408  if (failed(mesh)) {
409  return failure();
410  }
411  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
412  return failure();
413  }
414 
415  size_t expectedResultsCount =
416  getAxes().empty() ? mesh->getRank() : getAxes().size();
417  if (getResult().size() != expectedResultsCount) {
418  return emitError() << "Unexpected number of results " << getResult().size()
419  << ". Expected " << expectedResultsCount << ".";
420  }
421 
422  return success();
423 }
424 
425 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
426  MeshOp mesh) {
427  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
428 }
429 
430 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
431  MeshOp mesh, ArrayRef<MeshAxis> axes) {
432  build(odsBuilder, odsState,
433  SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
434  odsBuilder.getIndexType()),
435  mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
436 }
437 
438 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
439  StringRef mesh, ArrayRef<MeshAxis> axes) {
440  assert(!axes.empty());
441  build(odsBuilder, odsState,
442  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
443  MeshAxesAttr::get(odsBuilder.getContext(), axes));
444 }
445 
446 void MeshShapeOp::getAsmResultNames(
447  function_ref<void(Value, StringRef)> setNameFn) {
448  setNameFn(getResults()[0], "mesh_shape");
449 }
450 
451 //===----------------------------------------------------------------------===//
452 // mesh.sharding
453 //===----------------------------------------------------------------------===//
454 
455 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
456  FlatSymbolRefAttr mesh,
457  ArrayRef<MeshAxesAttr> split_axes,
458  ArrayRef<MeshAxis> partial_axes,
459  mesh::ReductionKind partial_type,
460  ArrayRef<int64_t> static_halos,
461  ArrayRef<int64_t> static_offsets) {
462  return build(
463  b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
464  ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
465  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
466  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
467  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
468 }
469 
470 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
471  FlatSymbolRefAttr mesh,
472  ArrayRef<MeshAxesAttr> split_axes) {
473  return build(
474  b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
475  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
476  {}, {}, {}, {});
477 }
478 
479 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
480  llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
481  ArrayRef<int64_t> static_halos,
482  ArrayRef<int64_t> static_offsets) {
483  return build(
484  b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
485  MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
486  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
487  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
488  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
489 }
490 
491 void ShardingOp::build(
492  ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
493  FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
495  ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
496  mlir::SmallVector<int64_t> staticHalos, staticDims;
497  mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
498  dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
499  dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
500  return build(
501  b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
502  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
503  ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
504  ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
505 }
506 
507 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
509 
510  build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
512  from.getPartialAxes().empty()
516  from.getPartialType()),
517  from.getStaticShardedDimsOffsets().empty()
521  from.getStaticHaloSizes().empty()
524  from.getDynamicHaloSizes());
525 }
526 
527 LogicalResult ShardingOp::verify() {
528  llvm::SmallSet<MeshAxis, 4> visitedAxes;
529 
530  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
531  for (MeshAxis axis : axesArray) {
532  if (axis < 0)
533  return emitError() << "mesh axis is expected to be non-negative";
534  if (!visitedAxes.insert(axis).second)
535  return emitError() << "mesh axis duplicated";
536  }
537  return success();
538  };
539 
540  for (auto subAxes : getSplitAxes().getAxes()) {
541  ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
542  if (failed(checkMeshAxis(subAxesArray)))
543  return failure();
544  }
545  if (getPartialAxes().has_value() &&
546  failed(checkMeshAxis(getPartialAxes().value())))
547  return failure();
548 
549  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
550  return emitOpError("halo sizes and shard offsets are mutually exclusive");
551  }
552 
553  if (!getStaticHaloSizes().empty()) {
554  auto numSplitAxes = getSplitAxes().getAxes().size();
555  for (auto splitAxis : getSplitAxes().getAxes()) {
556  if (splitAxis.empty()) {
557  --numSplitAxes;
558  }
559  }
560  if (getStaticHaloSizes().size() != numSplitAxes * 2) {
561  return emitError() << "halo sizes must be specified for all split axes.";
562  }
563  }
564 
565  return success();
566 }
567 
568 void ShardingOp::getAsmResultNames(
569  function_ref<void(Value, StringRef)> setNameFn) {
570  setNameFn(getResult(), "sharding");
571 }
572 
573 LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
574  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
575  if (failed(mesh)) {
576  return failure();
577  }
578  if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
579  getStaticShardedDimsOffsets().size() > 0) {
580  return emitError() << "sharded dims offsets are not allowed for "
581  "devices meshes with dynamic shape.";
582  }
583 
584  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
585  if (!shardedDimsOffsets.empty()) {
586  auto meshShape = mesh.value().getShape();
587  assert(!ShapedType::isDynamicShape(meshShape));
588  uint64_t pos = 0;
589  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
590  if (!innerSplitAxes.empty()) {
591  int64_t numShards = 0, off = 0;
592  for (auto i : innerSplitAxes.asArrayRef()) {
593  numShards += meshShape[i];
594  }
595  for (int64_t i = 0; i <= numShards; ++i) {
596  if (shardedDimsOffsets.size() <= pos + i) {
597  return emitError() << "sharded dims offsets has wrong size.";
598  }
599  if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
600  if (shardedDimsOffsets[pos + i] < off) {
601  return emitError()
602  << "sharded dims offsets must be non-decreasing.";
603  }
604  off = shardedDimsOffsets[pos + i];
605  }
606  }
607  pos += numShards + 1;
608  }
609  }
610  }
611  return success();
612 }
613 
614 namespace {
615 // Sharding annotations "halo sizes" and "sharded dims offsets"
616 // are a mix of attributes and dynamic values. This canonicalization moves
617 // constant values to the respective attribute lists, minimizing the number
618 // of values.
619 // It also removes sharded_dims_offsets and halos if they are effectively "empty".
620 class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
621 public:
623 
624  LogicalResult matchAndRewrite(ShardingOp op,
625  PatternRewriter &b) const override {
626  auto mixedHalos =
627  getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
628  auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
629  op.getDynamicShardedDimsOffsets(), b);
630 
631  // No constant operands were folded, just return;
632  bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
633  succeeded(foldDynamicIndexList(mixedOffs, true));
634 
635  auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
636  auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);
637 
638  if (dynamicHalos.empty() && !staticHalos.empty()) {
639  if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
640  staticHalos.clear();
641  modified = true;
642  }
643  }
644 
645  // Remove sharded dims offsets if they are effectively the default values,
646  // e.g. if they define equi-distance between all neighboring shards.
647  // Requires static-only offsets. Compares the first distance as the
648  // difference between the first two offsets. Only if all consecutive
649  // distances are the same, the offsets are removed.
650  if (dynamicOffs.empty() && !staticOffs.empty()) {
651  assert(staticOffs.size() >= 2);
652  auto diff = staticOffs[1] - staticOffs[0];
653  bool all_same = staticOffs.size() > 2;
654  for (auto i = 2u; i < staticOffs.size(); ++i) {
655  if (staticOffs[i] - staticOffs[i - 1] != diff) {
656  all_same = false;
657  break;
658  }
659  }
660  if (all_same) {
661  staticOffs.clear();
662  modified = true;
663  }
664  }
665 
666  if (!modified) {
667  return failure();
668  }
669 
670  op.setStaticHaloSizes(staticHalos);
671  op.getDynamicHaloSizesMutable().assign(dynamicHalos);
672  op.setStaticShardedDimsOffsets(staticOffs);
673  op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
674 
675  return success();
676  }
677 };
678 } // namespace
679 
680 void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
681  mlir::MLIRContext *context) {
682  results.add<NormalizeSharding>(context);
683 }
684 
685 //===----------------------------------------------------------------------===//
686 // MeshSharding
687 //===----------------------------------------------------------------------===//
688 
690  if (getMesh() != rhs.getMesh()) {
691  return false;
692  }
693 
694  if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
695  (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
696  !llvm::equal(
697  llvm::make_range(getPartialAxes().begin(), getPartialAxes().end()),
698  llvm::make_range(rhs.getPartialAxes().begin(),
699  rhs.getPartialAxes().end()))) {
700  return false;
701  }
702 
703  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
704  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
705  getSplitAxes().begin() + minSize),
706  llvm::make_range(rhs.getSplitAxes().begin(),
707  rhs.getSplitAxes().begin() + minSize))) {
708  return false;
709  }
710 
711  return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
712  getSplitAxes().end()),
713  std::mem_fn(&MeshAxesAttr::empty)) &&
714  llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
715  rhs.getSplitAxes().end()),
716  std::mem_fn(&MeshAxesAttr::empty));
717 }
718 
720  return equalShardSizes(rhs) && equalHaloSizes(rhs);
721 }
722 
724  if (rhs.getStaticShardedDimsOffsets().size() !=
725  getStaticShardedDimsOffsets().size() ||
726  !llvm::equal(llvm::make_range(getStaticShardedDimsOffsets().begin(),
728  llvm::make_range(rhs.getStaticShardedDimsOffsets().begin(),
729  rhs.getStaticShardedDimsOffsets().end()))) {
730  return false;
731  }
732  if (rhs.getDynamicShardedDimsOffsets().size() !=
733  getDynamicShardedDimsOffsets().size() ||
734  !llvm::equal(
735  llvm::make_range(getDynamicShardedDimsOffsets().begin(),
737  llvm::make_range(rhs.getDynamicShardedDimsOffsets().begin(),
738  rhs.getDynamicShardedDimsOffsets().end()))) {
739  return false;
740  }
741  return true;
742 }
743 
745  if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
746  !llvm::equal(llvm::make_range(getStaticHaloSizes().begin(),
747  getStaticHaloSizes().end()),
748  llvm::make_range(rhs.getStaticHaloSizes().begin(),
749  rhs.getStaticHaloSizes().end()))) {
750  return false;
751  }
752  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
753  !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(),
754  getDynamicHaloSizes().end()),
755  llvm::make_range(rhs.getDynamicHaloSizes().begin(),
756  rhs.getDynamicHaloSizes().end()))) {
757  return false;
758  }
759  return true;
760 }
761 
764 }
765 
766 bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
767 
768 bool MeshSharding::operator==(const MeshSharding &rhs) const {
770 }
771 
772 bool MeshSharding::operator!=(const MeshSharding &rhs) const {
773  return !(*this == rhs);
774 }
775 
777 
779  auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
780  assert(shardingOp && "expected sharding op");
781  auto splitAxes = shardingOp.getSplitAxes().getAxes();
782  auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
783  // If splitAxes and partialAxes are empty, use "empty" constructor.
784  if (splitAxes.empty() && partialAxes.empty()) {
785  *this = MeshSharding(shardingOp.getMeshAttr());
786  return;
787  }
788  *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
789  shardingOp.getPartialType().value_or(ReductionKind::Sum),
790  shardingOp.getStaticHaloSizes(),
791  shardingOp.getStaticShardedDimsOffsets(),
792  SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
793  SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
794 }
795 
797  ArrayRef<MeshAxesAttr> split_axes_,
798  ArrayRef<MeshAxis> partial_axes_,
799  ReductionKind partial_type_,
800  ArrayRef<int64_t> static_halo_sizes_,
801  ArrayRef<int64_t> static_sharded_dims_offsets_,
802  ArrayRef<Value> dynamic_halo_sizes_,
803  ArrayRef<Value> dynamic_sharded_dims_offsets_) {
804  MeshSharding res(mesh_);
805  if (split_axes_.empty() && partial_axes_.empty()) {
806  return res;
807  }
808 
809  res.split_axes.resize(split_axes_.size());
810  for (auto [i, axis] : llvm::enumerate(split_axes_)) {
811  res.split_axes[i] =
812  MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
813  }
814 
815  auto clone = [](const auto src, auto &dst) {
816  dst.resize(src.size());
817  llvm::copy(src, dst.begin());
818  };
819 
820  clone(partial_axes_, res.partial_axes);
821  res.partial_type = partial_type_;
822  clone(static_halo_sizes_, res.static_halo_sizes);
823  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
824  clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
825  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
826 
827  return res;
828 }
829 
830 //===----------------------------------------------------------------------===//
831 // mesh.shard_shape
832 //===----------------------------------------------------------------------===//
833 
834 void ShardShapeOp::getAsmResultNames(
835  function_ref<void(Value, StringRef)> setNameFn) {
836  setNameFn(getResult()[0], "shard_shape");
837 }
838 
839 void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
840  ::mlir::OperationState &odsState,
842  ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
843  ::mlir::ValueRange device) {
844  SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
845  build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
846  SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
847 }
848 
849 //===----------------------------------------------------------------------===//
850 // mesh.shard op
851 //===----------------------------------------------------------------------===//
852 
853 void ShardOp::getAsmResultNames(
854  function_ref<void(Value, StringRef)> setNameFn) {
855  setNameFn(getResult(), "sharding_annotated");
856 }
857 
858 namespace {
859 // Determine if the given ShardOp is a duplicate of another ShardOp
860 // on the same value. This can happen if constant values are sharded.
861 class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
862 public:
864 
865  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
866  // Get the use-list of the value being sharded and check if it has more than
867  // one use.
868  Value value = op.getSrc();
869  if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) {
870  return failure();
871  }
872 
873  // Iterate through the uses of the value to find a duplicate ShardOp.
874  for (auto &use : value.getUses()) {
875  if (use.getOwner() != op.getOperation()) {
876  auto otherOp = dyn_cast<ShardOp>(use.getOwner());
877  if (!otherOp || !otherOp->isBeforeInBlock(op)) {
878  return failure();
879  }
880  // Create a MeshSharding object for the current and the other ShardOp
881  // If the two are equal replace current op with the other op.
882  MeshSharding currentSharding(op.getSharding());
883  MeshSharding otherSharding(otherOp.getSharding());
884  if (currentSharding == otherSharding) {
885  b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
886  b.eraseOp(op.getOperation());
887  } else {
888  // use the other sharding as input for op
889  op.getSrcMutable().assign(otherOp.getResult());
890  }
891  return success();
892  }
893  }
894 
895  return failure();
896  }
897 };
898 } // namespace
899 
900 void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
901  mlir::MLIRContext *context) {
902  results.add<FoldDuplicateShardOp>(context);
903 }
904 
905 //===----------------------------------------------------------------------===//
906 // mesh.process_multi_index op
907 //===----------------------------------------------------------------------===//
908 
909 LogicalResult
910 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
911  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
912  if (failed(mesh)) {
913  return failure();
914  }
915  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
916  return failure();
917  }
918 
919  size_t expectedResultsCount =
920  getAxes().empty() ? mesh->getRank() : getAxes().size();
921  if (getResult().size() != expectedResultsCount) {
922  return emitError() << "Unexpected number of results " << getResult().size()
923  << ". Expected " << expectedResultsCount << ".";
924  }
925 
926  return success();
927 }
928 
929 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
930  MeshOp mesh) {
931  build(odsBuilder, odsState,
932  SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
933  mesh.getSymName(), ArrayRef<MeshAxis>());
934 }
935 
936 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
937  StringRef mesh, ArrayRef<MeshAxis> axes) {
938  build(odsBuilder, odsState,
939  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
940  MeshAxesAttr::get(odsBuilder.getContext(), axes));
941 }
942 
943 void ProcessMultiIndexOp::getAsmResultNames(
944  function_ref<void(Value, StringRef)> setNameFn) {
945  setNameFn(getResults()[0], "proc_linear_idx");
946 }
947 
948 //===----------------------------------------------------------------------===//
949 // mesh.process_linear_index op
950 //===----------------------------------------------------------------------===//
951 
952 LogicalResult
953 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
954  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
955  if (failed(mesh)) {
956  return failure();
957  }
958  return success();
959 }
960 
961 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
962  OperationState &odsState, MeshOp mesh) {
963  build(odsBuilder, odsState, mesh.getSymName());
964 }
965 
966 void ProcessLinearIndexOp::getAsmResultNames(
967  function_ref<void(Value, StringRef)> setNameFn) {
968  setNameFn(getResult(), "proc_linear_idx");
969 }
970 
971 //===----------------------------------------------------------------------===//
972 // mesh.neighbors_linear_indices op
973 //===----------------------------------------------------------------------===//
974 
975 LogicalResult
976 NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
977  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
978  if (failed(mesh)) {
979  return failure();
980  }
981  return success();
982 }
983 
984 void NeighborsLinearIndicesOp::getAsmResultNames(
985  function_ref<void(Value, StringRef)> setNameFn) {
986  setNameFn(getNeighborDown(), "down_linear_idx");
987  setNameFn(getNeighborUp(), "up_linear_idx");
988 }
989 
990 //===----------------------------------------------------------------------===//
991 // collective communication ops
992 //===----------------------------------------------------------------------===//
993 
994 namespace {
995 
996 template <typename Op>
997 struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
999  LogicalResult matchAndRewrite(Op op,
1000  PatternRewriter &rewriter) const override {
1001  auto meshAxes = op.getMeshAxes();
1002  if (!meshAxes.empty()) {
1003  return failure();
1004  }
1005  if (op.getInput().getType() != op.getResult().getType()) {
1006  return failure();
1007  }
1008 
1009  rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
1010  rewriter.eraseOp(op.getOperation());
1011  return success();
1012  }
1013 };
1014 
1015 } // namespace
1016 
1017 static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
1018  ArrayRef<int64_t> device,
1019  Operation::operand_range deviceDynamic,
1020  ArrayRef<MeshAxis> meshAxes,
1021  ArrayRef<int64_t> meshShape) {
1022  if (device.size() != meshAxes.size()) {
1023  return emitError(loc) << "In-group device \"" << deviceName
1024  << "\" has unexpected multi-index size "
1025  << device.size() << ". Expected " << meshAxes.size()
1026  << ".";
1027  }
1028 
1029  for (size_t i = 0; i < device.size(); ++i) {
1030  if (!ShapedType::isDynamic(device[i]) &&
1031  !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
1032  meshShape[meshAxes[i]] <= device[i]) {
1033  return emitError(loc)
1034  << "Out of bounds coordinate " << i << " for in-group device \""
1035  << deviceName << "\"."
1036  << " Got " << device[i] << ", but expected value in the range [0, "
1037  << (meshShape[meshAxes[i]] - 1) << "].";
1038  }
1039  }
1040  return success();
1041 }
1042 
/// Returns the product of the elements in [begin, end), computed in the
/// decayed element type of the iterator. An empty range yields 1.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  ElementType result = static_cast<ElementType>(1);
  for (It it = begin; it != end; ++it) {
    const ElementType &factor = *it;
    result = result * factor;
  }
  return result;
}
1049 
/// Range overload: returns the product of all elements of `range` by
/// forwarding its begin/end iterators to the iterator-pair overload.
template <typename R>
static auto product(R &&range) {
  auto first = adl_begin(range);
  auto last = adl_end(range);
  return product(first, last);
}
1054 
1055 static LogicalResult verifyDimensionCompatibility(Location loc,
1056  int64_t expectedDimSize,
1057  int64_t resultDimSize,
1058  int64_t resultAxis) {
1059  if (!ShapedType::isDynamic(resultDimSize) &&
1060  expectedDimSize != resultDimSize) {
1061  return emitError(loc) << "Dimension size mismatch for result axis "
1062  << resultAxis << ". Expected "
1063  << (ShapedType::isDynamic(expectedDimSize)
1064  ? Twine("dynamic")
1065  : Twine(expectedDimSize))
1066  << ", but got " << resultDimSize << ".";
1067  }
1068 
1069  return success();
1070 }
1071 
1073  Value operand, Value result, int64_t gatherAxis,
1074  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1075  auto resultRank = cast<ShapedType>(result.getType()).getRank();
1076  if (gatherAxis < 0 || gatherAxis >= resultRank) {
1077  return emitError(result.getLoc())
1078  << "Gather axis " << gatherAxis << " is out of bounds [0, "
1079  << resultRank << ").";
1080  }
1081 
1082  ShapedType operandType = cast<ShapedType>(operand.getType());
1083  ShapedType resultType = cast<ShapedType>(result.getType());
1084  auto deviceGroupSize =
1085  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1086  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1087  auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
1088  auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
1089  auto expectedResultDimSize =
1090  axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
1091  if (failed(verifyDimensionCompatibility(
1092  result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
1093  return failure();
1094  }
1095  }
1096  return success();
1097 }
1098 
1100  Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
1101  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1102  ShapedType operandType = cast<ShapedType>(operand.getType());
1103  ShapedType resultType = cast<ShapedType>(result.getType());
1104  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1105  if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
1106  if (failed(verifyDimensionCompatibility(
1107  result.getLoc(), operandType.getDimSize(axis),
1108  resultType.getDimSize(axis), axis))) {
1109  return failure();
1110  }
1111  }
1112  }
1113 
1114  if (splitAxis == concatAxis) {
1115  return success();
1116  }
1117 
1118  auto deviceGroupSize =
1119  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1120  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
1121  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
1122  DimensionSize expectedResultConcatDimSize =
1123  operandConcatDimSize * deviceGroupSize;
1124  DimensionSize expectedResultSplitDimSize =
1125  operandSplitDimSize / deviceGroupSize;
1126  if (!expectedResultSplitDimSize.isDynamic() &&
1127  int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
1128  expectedResultSplitDimSize = DimensionSize::dynamic();
1129  }
1130  if (failed(verifyDimensionCompatibility(
1131  result.getLoc(), expectedResultConcatDimSize.value(),
1132  resultType.getDimSize(concatAxis), concatAxis))) {
1133  return failure();
1134  }
1135  if (failed(verifyDimensionCompatibility(
1136  result.getLoc(), expectedResultSplitDimSize.value(),
1137  resultType.getDimSize(splitAxis), splitAxis))) {
1138  return failure();
1139  }
1140 
1141  return success();
1142 }
1143 
1145  Value operand, Value result, int64_t tensorAxis,
1146  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1147  ShapedType operandType = cast<ShapedType>(operand.getType());
1148  ShapedType resultType = cast<ShapedType>(result.getType());
1149  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1150  if (axis != tensorAxis) {
1151  if (failed(verifyDimensionCompatibility(
1152  result.getLoc(), operandType.getDimSize(axis),
1153  resultType.getDimSize(axis), axis))) {
1154  return failure();
1155  }
1156  }
1157  }
1158 
1159  auto deviceGroupSize =
1160  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1161  auto operandScatterDimSize =
1162  DimensionSize(operandType.getDimSize(tensorAxis));
1163  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1164  int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1165  return emitError(result.getLoc())
1166  << "Operand dimension size " << int64_t(operandScatterDimSize)
1167  << " is not divisible by collective device group size "
1168  << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
1169  << ".";
1170  }
1171  DimensionSize expectedResultTensorDimSize =
1172  operandScatterDimSize / deviceGroupSize;
1173  if (failed(verifyDimensionCompatibility(
1174  result.getLoc(), expectedResultTensorDimSize.value(),
1175  resultType.getDimSize(tensorAxis), tensorAxis))) {
1176  return failure();
1177  }
1178 
1179  return success();
1180 }
1181 
1182 static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
1183  ArrayRef<MeshAxis> meshAxes,
1184  int64_t sliceAxis) {
1185  RankedTensorType operandRankedTensorType =
1186  cast<RankedTensorType>(operandType);
1187  DimensionSize operandSliceAxisSize =
1188  operandRankedTensorType.getShape()[sliceAxis];
1189  SmallVector<int64_t> resultShape =
1190  llvm::to_vector(operandRankedTensorType.getShape());
1191 
1192  resultShape[sliceAxis] =
1193  operandSliceAxisSize /
1194  DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
1195  return operandRankedTensorType.clone(resultShape);
1196 }
1197 
1198 //===----------------------------------------------------------------------===//
1199 // mesh.all_gather op
1200 //===----------------------------------------------------------------------===//
1201 
1202 LogicalResult
1203 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1204  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1205  if (failed(mesh)) {
1206  return failure();
1207  }
1208  auto gatherAxis = getGatherAxis().getSExtValue();
1209  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
1210  gatherAxis, getMeshAxes(),
1211  mesh.value().getShape());
1212 }
1213 
1214 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1215  MLIRContext *context) {
1216  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
1217 }
1218 
1219 void AllGatherOp::getAsmResultNames(
1220  function_ref<void(Value, StringRef)> setNameFn) {
1221  setNameFn(getResult(), "all_gather");
1222 }
1223 
1224 //===----------------------------------------------------------------------===//
1225 // mesh.all_reduce op
1226 //===----------------------------------------------------------------------===//
1227 
1228 LogicalResult
1229 AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1230  return getMeshAndVerifyAxes(*this, symbolTable);
1231 }
1232 
1233 void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1234  MLIRContext *context) {
1235  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
1236 }
1237 
1238 void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1239  Value input, StringRef mesh,
1240  ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
1241  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
1242  reduction);
1243 }
1244 
1245 void AllReduceOp::getAsmResultNames(
1246  function_ref<void(Value, StringRef)> setNameFn) {
1247  setNameFn(getResult(), "all_reduce");
1248 }
1249 
1250 //===----------------------------------------------------------------------===//
1251 // mesh.all_slice op
1252 //===----------------------------------------------------------------------===//
1253 
1254 LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1255  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1256  if (failed(mesh)) {
1257  return failure();
1258  }
1260  getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
1261  mesh.value().getShape());
1262 }
1263 
1264 void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1265  MLIRContext *context) {
1266  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
1267 }
1268 
1269 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1270  Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
1271  int64_t sliceAxis) {
1272  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
1273  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
1274  sliceAxis);
1275 }
1276 
1277 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1278  Type resultType, Value input, StringRef mesh,
1279  ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
1280  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
1281  APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1282 }
1283 
1284 void AllSliceOp::getAsmResultNames(
1285  function_ref<void(Value, StringRef)> setNameFn) {
1286  setNameFn(getResult(), "all_slice");
1287 }
1288 
1289 //===----------------------------------------------------------------------===//
1290 // mesh.all_to_all op
1291 //===----------------------------------------------------------------------===//
1292 
1293 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1294  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1295  if (failed(mesh)) {
1296  return failure();
1297  }
1298 
1300  getOperand(), getResult(), getSplitAxis().getSExtValue(),
1301  getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
1302 }
1303 
1304 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1305  MLIRContext *context) {
1306  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
1307 }
1308 
1309 void AllToAllOp::getAsmResultNames(
1310  function_ref<void(Value, StringRef)> setNameFn) {
1311  setNameFn(getResult(), "all_to_all");
1312 }
1313 
1314 //===----------------------------------------------------------------------===//
1315 // mesh.broadcast op
1316 //===----------------------------------------------------------------------===//
1317 
1318 LogicalResult
1319 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1320  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1321  if (failed(mesh)) {
1322  return failure();
1323  }
1324  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1325  getRootDynamic(), getMeshAxes(),
1326  mesh.value().getShape()))) {
1327  return failure();
1328  }
1329 
1330  return success();
1331 }
1332 
1333 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1334  MLIRContext *context) {
1335  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
1336 }
1337 
1338 void BroadcastOp::getAsmResultNames(
1339  function_ref<void(Value, StringRef)> setNameFn) {
1340  setNameFn(getResult(), "broadcast");
1341 }
1342 
1343 //===----------------------------------------------------------------------===//
1344 // mesh.gather op
1345 //===----------------------------------------------------------------------===//
1346 
1347 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1348  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1349  if (failed(mesh)) {
1350  return failure();
1351  }
1352  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1353  getRootDynamic(), getMeshAxes(),
1354  mesh.value().getShape()))) {
1355  return failure();
1356  }
1357 
1358  auto gatherAxis = getGatherAxis().getSExtValue();
1359  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
1360  getMeshAxes(),
1361  mesh.value().getShape());
1362 }
1363 
1364 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1365  MLIRContext *context) {
1366  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
1367 }
1368 
1369 void GatherOp::getAsmResultNames(
1370  function_ref<void(Value, StringRef)> setNameFn) {
1371  setNameFn(getResult(), "gather");
1372 }
1373 
1374 //===----------------------------------------------------------------------===//
1375 // mesh.recv op
1376 //===----------------------------------------------------------------------===//
1377 
1378 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1379  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1380  if (failed(mesh)) {
1381  return failure();
1382  }
1383  if (getSource() &&
1384  failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
1385  getSource().value(), getSourceDynamic(),
1386  getMeshAxes(), mesh.value().getShape()))) {
1387  return failure();
1388  }
1389  return success();
1390 }
1391 
1392 void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1393  MLIRContext *context) {
1394  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
1395 }
1396 
1397 void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1398  setNameFn(getResult(), "recv");
1399 }
1400 
1401 //===----------------------------------------------------------------------===//
1402 // mesh.reduce op
1403 //===----------------------------------------------------------------------===//
1404 
1405 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1406  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1407  if (failed(mesh)) {
1408  return failure();
1409  }
1410  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1411  getRootDynamic(), getMeshAxes(),
1412  mesh.value().getShape()))) {
1413  return failure();
1414  }
1415 
1416  return success();
1417 }
1418 
1419 void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1420  MLIRContext *context) {
1421  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
1422 }
1423 
1424 void ReduceOp::getAsmResultNames(
1425  function_ref<void(Value, StringRef)> setNameFn) {
1426  setNameFn(getResult(), "reduce");
1427 }
1428 
1429 //===----------------------------------------------------------------------===//
1430 // mesh.reduce_scatter op
1431 //===----------------------------------------------------------------------===//
1432 
1433 LogicalResult
1434 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1435  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1436  if (failed(mesh)) {
1437  return failure();
1438  }
1439 
1441  getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
1442  mesh.value().getShape());
1443 }
1444 
1445 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1446  MLIRContext *context) {
1447  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1448 }
1449 
1450 void ReduceScatterOp::getAsmResultNames(
1451  function_ref<void(Value, StringRef)> setNameFn) {
1452  setNameFn(getResult(), "reduce_scatter");
1453 }
1454 
1455 //===----------------------------------------------------------------------===//
1456 // mesh.scatter op
1457 //===----------------------------------------------------------------------===//
1458 
1459 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1460  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1461  if (failed(mesh)) {
1462  return failure();
1463  }
1464  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1465  getRootDynamic(), getMeshAxes(),
1466  mesh.value().getShape()))) {
1467  return failure();
1468  }
1469 
1470  auto scatterAxis = getScatterAxis().getSExtValue();
1471  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
1472  scatterAxis, getMeshAxes(),
1473  mesh.value().getShape());
1474 }
1475 
1476 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1477  MLIRContext *context) {
1478  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
1479 }
1480 
1481 void ScatterOp::getAsmResultNames(
1482  function_ref<void(Value, StringRef)> setNameFn) {
1483  setNameFn(getResult(), "scatter");
1484 }
1485 
1486 //===----------------------------------------------------------------------===//
1487 // mesh.send op
1488 //===----------------------------------------------------------------------===//
1489 
1490 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1491  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1492  if (failed(mesh)) {
1493  return failure();
1494  }
1495  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
1496  getDestination(), getDestinationDynamic(),
1497  getMeshAxes(), mesh.value().getShape()))) {
1498  return failure();
1499  }
1500  return success();
1501 }
1502 
1503 void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1504  MLIRContext *context) {
1505  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
1506 }
1507 
1508 void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1509  setNameFn(getResult(), "send");
1510 }
1511 
1512 //===----------------------------------------------------------------------===//
1513 // mesh.shift op
1514 //===----------------------------------------------------------------------===//
1515 
1516 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1517  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1518  if (failed(mesh)) {
1519  return failure();
1520  }
1521 
1522  auto meshAxes = getMeshAxes();
1523  auto shiftAxis = getShiftAxis().getZExtValue();
1524  if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
1525  return emitError() << "Invalid shift axis " << shiftAxis
1526  << ". It must be one of the grouping mesh axes.";
1527  }
1528 
1529  return success();
1530 }
1531 
1532 void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1533  MLIRContext *context) {
1534  // TODO: remove op when offset is 0 or if it is a rotate with and
1535  // offset % shift_axis_mesh_dim_size == 0.
1536 }
1537 
1538 void ShiftOp::getAsmResultNames(
1539  function_ref<void(Value, StringRef)> setNameFn) {
1540  setNameFn(getResult(), "shift");
1541 }
1542 
1543 //===----------------------------------------------------------------------===//
1544 // mesh.update_halo op
1545 //===----------------------------------------------------------------------===//
1546 
1547 LogicalResult
1548 UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1549  auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
1550  if (failed(mesh)) {
1551  return failure();
1552  }
1553 
1554  return success();
1555 }
1556 
1557 //===----------------------------------------------------------------------===//
1558 // TableGen'd op method definitions
1559 //===----------------------------------------------------------------------===//
1560 
1561 #define GET_OP_CLASSES
1562 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
1563 
1564 #define GET_ATTRDEF_CLASSES
1565 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
1566 
1567 #define GET_TYPEDEF_CLASSES
1568 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
1569 
1570 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
Definition: MeshOps.cpp:1182
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition: MeshOps.cpp:64
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
Definition: MeshOps.cpp:1055
static FailureOr< MeshOp > getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:179
static FailureOr< MeshOp > getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:127
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:1144
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:1072
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:1099
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
Definition: MeshOps.cpp:193
static auto product(It begin, It end)
Definition: MeshOps.cpp:1044
bool isUnique(It begin, It end)
Definition: MeshOps.cpp:140
static LogicalResult verifyMeshAxes(Location loc, ArrayRef< MeshAxis > axes, MeshOp mesh)
Definition: MeshOps.cpp:156
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:1017
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Attributes are known-constant values of operations.
Definition: Attributes.h:25
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
DenseI16ArrayAttr getDenseI16ArrayAttr(ArrayRef< int16_t > values)
Definition: Builders.cpp:155
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:51
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:784
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:419
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:128
This class provides the API for a sub-set of ops that are known to be constant-like.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Definition: OpDefinition.h:111
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
void replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this operation with the provided values if the given callback returns true...
Definition: Operation.h:279
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:865
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:656
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:720
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:689
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition: MeshOps.h:70
bool equalHaloAndShardSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:719
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:64
bool equalHaloSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:744
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:66
bool operator!=(Value rhs) const
Definition: MeshOps.cpp:766
ReductionKind getPartialType() const
Definition: MeshOps.h:68
ArrayRef< Value > getDynamicShardedDimsOffsets() const
Definition: MeshOps.h:74
bool operator==(Value rhs) const
Definition: MeshOps.cpp:762
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:67
ArrayRef< Value > getDynamicHaloSizes() const
Definition: MeshOps.h:73
::llvm::StringRef getMesh() const
Definition: MeshOps.h:65
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: MeshOps.h:69
MeshSharding(::mlir::FlatSymbolRefAttr mesh_=nullptr)
Definition: MeshOps.cpp:776
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition: MeshOps.cpp:796
bool equalShardSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:723
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
mesh::ReductionKind ReductionKind
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
Definition: MeshOps.h:153
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:120
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:329
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:270
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder, ShardOp &newShardOp)
Definition: MeshOps.cpp:278
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:260
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:175
bool isFullReplication(MeshSharding sharding)
Definition: MeshOps.h:112
int16_t MeshAxis
Definition: MeshOps.h:26
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:265
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult...
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:419
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.