MLIR  21.0.0git
MeshOps.cpp
Go to the documentation of this file.
1 //===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/IR/Value.h"
26 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/ArrayRef.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallSet.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Casting.h"
34 #include <algorithm>
35 #include <functional>
36 #include <iterator>
37 #include <numeric>
38 #include <optional>
39 #include <utility>
40 
41 #define DEBUG_TYPE "mesh-ops"
42 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
43 
44 using namespace mlir;
45 using namespace mlir::mesh;
46 
47 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
48 
49 namespace {
50 
51 struct DimensionSize {
52  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
53  DimensionSize(int64_t val) : val(val) {}
54  int64_t value() const { return val; }
55  operator int64_t() const { return val; }
56  bool isDynamic() const { return ShapedType::isDynamic(val); }
57 
58 private:
59  int64_t val;
60 };
61 
62 } // namespace
63 
64 static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
65  if (lhs.isDynamic() || rhs.isDynamic()) {
66  return DimensionSize::dynamic();
67  }
68  return lhs.value() / rhs.value();
69 }
70 
71 static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
72  if (lhs.isDynamic() || rhs.isDynamic()) {
73  return DimensionSize::dynamic();
74  }
75  return lhs.value() * rhs.value();
76 }
77 
79  const Location &loc,
81  ValueRange dynamics,
82  Type type) {
83  SmallVector<Value> values;
84  auto dyn = dynamics.begin();
85  Type i64 = b.getI64Type();
86  if (!type)
87  type = i64;
88  assert((i64 == type || b.getIndexType() == type) &&
89  "expected an i64 or an intex type");
90  for (auto s : statics) {
91  if (s == ShapedType::kDynamic) {
92  values.emplace_back(*(dyn++));
93  } else {
94  TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
95  values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
96  }
97  }
98  return values;
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // Inliner
103 //===----------------------------------------------------------------------===//
104 
105 namespace {
106 struct MeshInlinerInterface : public DialectInlinerInterface {
108  // Currently no restrictions are encoded for inlining.
109  bool isLegalToInline(Operation *, Operation *, bool) const final {
110  return true;
111  }
112  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
113  return true;
114  }
115  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
116  return true;
117  }
118 };
119 } // namespace
120 
121 //===----------------------------------------------------------------------===//
122 // Mesh dialect
123 //===----------------------------------------------------------------------===//
124 
// Dialect initialization hook: registers the operations, attributes and types
// listed by the included .cpp.inc files, then attaches the inliner interface.
125 void MeshDialect::initialize() {
// Each GET_*_LIST macro is defined right before its include so that the
// included file contributes the corresponding list as template arguments.
126  addOperations<
127 #define GET_OP_LIST
128 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
129  >();
130  addAttributes<
131 #define GET_ATTRDEF_LIST
132 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
133  >();
134  addTypes<
135 #define GET_TYPEDEF_LIST
136 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
137  >();
// MeshInlinerInterface (defined above) reports every inlining as legal.
138  addInterface<MeshInlinerInterface>();
139 }
140 
142  Type type, Location loc) {
143  return arith::ConstantOp::materialize(builder, value, type, loc);
144 }
145 
146 //===----------------------------------------------------------------------===//
147 // Mesh utilities
148 //===----------------------------------------------------------------------===//
149 
150 static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
151  FlatSymbolRefAttr meshSymbol,
152  SymbolTableCollection &symbolTable) {
153  mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
154  if (!mesh) {
155  return op->emitError() << "Undefined required mesh symbol \""
156  << meshSymbol.getValue() << "\".";
157  }
158 
159  return mesh;
160 }
161 
/// Returns true when no two *neighbouring* elements of [begin, end) compare
/// equal.
///
/// Only adjacent elements are compared, so this is a true uniqueness test
/// only for ranges in which equal elements are contiguous (e.g. sorted
/// ranges). Empty and single-element ranges are trivially unique.
template <typename It>
bool isUnique(It begin, It end) {
  // std::adjacent_find returns `end` when no adjacent pair is equal; it also
  // covers the empty and single-element cases.
  return std::adjacent_find(begin, end) == end;
}
178 
179 static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
180  MeshOp mesh) {
181  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
182  llvm::sort(sorted);
183  if (!isUnique(sorted.begin(), sorted.end())) {
184  return emitError(loc) << "Mesh axes contains duplicate elements.";
185  }
186 
187  MeshAxis rank = mesh.getRank();
188  for (auto axis : axes) {
189  if (axis >= rank || axis < 0) {
190  return emitError(loc)
191  << "0-based mesh axis index " << axis
192  << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
193  << "\" is of rank " << rank << ".";
194  }
195  }
196 
197  return success();
198 }
199 
200 template <typename Op>
201 static FailureOr<MeshOp>
203  auto mesh =
204  ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
205  if (failed(mesh)) {
206  return failure();
207  }
208  if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
209  return failure();
210  }
211  return mesh;
212 }
213 
// Computes the shape of one shard and writes it into `outShape`.
//
// `inShape` is the unsharded shape, `meshShape` the extents of the mesh and
// `splitAxes` holds, per tensor axis, the mesh axes that axis is split over.
// `outShape` must already have as many elements as `inShape`; the function
// only writes into it.
//
// When `shardedDimsOffsets` is non-empty the per-axis shard size is derived
// from those offsets; otherwise it is computed with `shardDimension` and, if
// `haloSizes` is non-empty, enlarged by the halo sizes.
214 template <typename InShape, typename MeshShape, typename SplitAxes,
215  typename OutShape>
216 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
217  const SplitAxes &splitAxes, OutShape &outShape,
218  ArrayRef<int64_t> shardedDimsOffsets = {},
219  ArrayRef<int64_t> haloSizes = {}) {
220  // 0d tensors cannot be sharded and must get replicated
221  if (inShape.empty()) {
222  assert(outShape.empty());
223  return;
224  }
225 
// Start from the unsharded shape; only split axes are overwritten below.
226  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
227  llvm::adl_begin(outShape));
228 
229  if (!shardedDimsOffsets.empty()) {
230  auto isDynShape = ShapedType::isDynamicShape(meshShape);
// `pos` indexes the second offset of the current axis' offset group; `sz`
// read from it is used as the candidate uniform shard size.
// NOTE(review): this presumes each group starts with offset 0 - confirm.
231  uint64_t pos = 1;
232  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
233  if (!innerSplitAxes.empty()) {
234  auto sz = shardedDimsOffsets[pos];
// With a dynamic mesh shape the shard size is never treated as uniform.
235  bool same = !isDynShape;
236  if (same) {
237  // Find sharded dims in shardedDimsOffsets with same static size on
238  // all devices. Use kDynamic for dimensions with dynamic or
239  // non-uniform offs in shardedDimsOffsets.
// NOTE(review): numShards is the *sum* of the mesh extents of the split
// axes, matching ShardingOp::verifySymbolUses - confirm this is intended.
240  uint64_t numShards = 0;
241  for (auto i : innerSplitAxes.asArrayRef()) {
242  numShards += meshShape[i];
243  }
// Every consecutive offset difference must equal `sz` to count as uniform.
244  for (size_t i = 1; i < numShards; ++i) {
245  if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
246  sz) {
247  same = false;
248  break;
249  }
250  }
// Advance to the next axis' group (numShards + 1 offsets per group).
251  pos += numShards + 1;
252  }
253  outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
254  }
255  }
256  } else {
// No explicit offsets: derive each extent from the size of the process group
// the axis is split over.
257  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
258  outShape[tensorAxis] = shardDimension(
259  inShape[tensorAxis],
260  collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
261  }
262 
263  if (!haloSizes.empty()) {
264  // add halo sizes if requested
// haloSizes holds two entries per halo axis, at 2*haloAxis and 2*haloAxis+1.
// NOTE(review): haloAxis only advances when both entries are non-negative
// and the extent is static - confirm this is intended.
265  int haloAxis = 0;
266  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
267  if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
268  !innerSplitAxes.empty()) {
269  if (haloSizes[haloAxis * 2] >= 0 &&
270  haloSizes[haloAxis * 2 + 1] >= 0) {
271  outShape[tensorAxis] +=
272  haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
273  ++haloAxis;
274  } else {
// A negative halo entry makes the resulting extent dynamic.
275  outShape[tensorAxis] = ShapedType::kDynamic;
276  }
277  }
278  }
279  }
280  }
281 }
282 
283 ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
284  MeshSharding sharding) {
285  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
286  SmallVector<Dim> resShapeArr(shape.getShape().size());
287  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
288  resShapeArr, sharding.getStaticShardedDimsOffsets(),
289  sharding.getStaticHaloSizes());
290  return shape.clone(resShapeArr);
291 }
292 
293 Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
294  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
295  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
296  return shardShapedType(rankedTensorType, mesh, sharding);
297  }
298  return type;
299 }
300 
302  Value &operandValue,
303  Operation *operandOp,
304  OpBuilder &builder,
305  ShardOp &newShardOp) {
306  OpBuilder::InsertionGuard insertionGuard(builder);
307  builder.setInsertionPointAfterValue(operandValue);
308  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
309  if (shardOp && sharding == shardOp.getSharding() &&
310  !shardOp.getAnnotateForUsers()) {
311  // No need for anything if the correct sharding is already set.
312  if (!newShardOp) {
313  newShardOp = shardOp;
314  }
315  return;
316  }
317 
318  if (!newShardOp) {
319  auto shardingOp =
320  builder.create<ShardingOp>(operandValue.getLoc(), sharding);
321  newShardOp =
322  builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
323  /*annotate_for_users*/ false);
324  }
325  operandValue.replaceUsesWithIf(
326  newShardOp, [operandOp, operandValue](OpOperand &use) {
327  return use.getOwner() == operandOp && use.get() == operandValue;
328  });
329 
330  if (!shardOp || shardOp.getAnnotateForUsers()) {
331  return;
332  }
333 
334  auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
335  newShardOp.getSharding(),
336  /*annotate_for_users*/ true);
337  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
338 }
339 
341  OpResult result,
342  OpBuilder &builder) {
343  ShardOp newShardOp;
345  for (auto &use : result.getUses()) {
346  uses.emplace_back(use.get(), use.getOwner());
347  }
348  for (auto &[operandValue, operandOp] : uses) {
349  maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
350  builder, newShardOp);
351  }
352 }
353 
355  OpOperand &operand,
356  OpBuilder &builder) {
357  OpBuilder::InsertionGuard insertionGuard(builder);
358  Value operandValue = operand.get();
359  Operation *operandSrcOp = operandValue.getDefiningOp();
360  bool isBlockArg = !operandSrcOp;
361  {
362  [[maybe_unused]] auto opType =
363  dyn_cast<mlir::RankedTensorType>(operandValue.getType());
364  assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
365  }
366  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
367  operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
368  return;
369  }
370 
371  Operation *operandOp = operand.getOwner();
372  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
373 
374  if (shardOp && sharding == shardOp.getSharding() &&
375  shardOp.getAnnotateForUsers()) {
376  // No need for anything if the correct sharding is already set.
377  return;
378  }
379 
380  builder.setInsertionPoint(operandOp);
381  auto shardingOp =
382  builder.create<ShardingOp>(operand.get().getLoc(), sharding);
383  auto newShardOp =
384  builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
385  /*annotate_for_users*/ true);
386  IRRewriter rewriter(builder);
387  rewriter.replaceUsesWithIf(
388  operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
389  return use.getOwner() == operandOp && use.get() == operandValue;
390  });
391 
392  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
393  // No need for resharding.
394  return;
395  }
396 
397  builder.setInsertionPoint(newShardOp);
398  auto newPreceedingShardOp =
399  builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
400  /*annotate_for_users*/ false);
401  rewriter.replaceUsesWithIf(
402  newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
403  return use.getOwner() == newShardOp.getOperation();
404  });
405 }
406 
407 //===----------------------------------------------------------------------===//
408 // mesh.mesh op
409 //===----------------------------------------------------------------------===//
410 
411 LogicalResult MeshOp::verify() {
412  int64_t rank = getRank();
413 
414  if (rank <= 0)
415  return emitOpError("rank of mesh is expected to be a positive integer");
416 
417  for (int64_t dimSize : getShape()) {
418  if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
419  return emitOpError("dimension size of a mesh is expected to be "
420  "non-negative or dynamic");
421  }
422 
423  return success();
424 }
425 
426 //===----------------------------------------------------------------------===//
427 // mesh.mesh_shape op
428 //===----------------------------------------------------------------------===//
429 
430 LogicalResult
431 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
432  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
433  if (failed(mesh)) {
434  return failure();
435  }
436  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
437  return failure();
438  }
439 
440  size_t expectedResultsCount =
441  getAxes().empty() ? mesh->getRank() : getAxes().size();
442  if (getResult().size() != expectedResultsCount) {
443  return emitError() << "Unexpected number of results " << getResult().size()
444  << ". Expected " << expectedResultsCount << ".";
445  }
446 
447  return success();
448 }
449 
450 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
451  MeshOp mesh) {
452  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
453 }
454 
455 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
456  MeshOp mesh, ArrayRef<MeshAxis> axes) {
457  build(odsBuilder, odsState,
458  SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
459  odsBuilder.getIndexType()),
460  mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
461 }
462 
463 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
464  StringRef mesh, ArrayRef<MeshAxis> axes) {
465  assert(!axes.empty());
466  build(odsBuilder, odsState,
467  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
468  MeshAxesAttr::get(odsBuilder.getContext(), axes));
469 }
470 
471 void MeshShapeOp::getAsmResultNames(
472  function_ref<void(Value, StringRef)> setNameFn) {
473  setNameFn(getResults()[0], "mesh_shape");
474 }
475 
476 //===----------------------------------------------------------------------===//
477 // mesh.sharding
478 //===----------------------------------------------------------------------===//
479 
480 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
481  FlatSymbolRefAttr mesh,
482  ArrayRef<MeshAxesAttr> split_axes,
483  ArrayRef<MeshAxis> partial_axes,
484  mesh::ReductionKind partial_type,
485  ArrayRef<int64_t> static_halos,
486  ArrayRef<int64_t> static_offsets) {
487  return build(
488  b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
489  ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
490  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
491  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
492  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
493 }
494 
495 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
496  FlatSymbolRefAttr mesh,
497  ArrayRef<MeshAxesAttr> split_axes) {
498  return build(
499  b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
500  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
501  {}, {}, {}, {});
502 }
503 
504 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
505  llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
506  ArrayRef<int64_t> static_halos,
507  ArrayRef<int64_t> static_offsets) {
508  return build(
509  b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
510  MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
511  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
512  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
513  ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
514 }
515 
516 void ShardingOp::build(
517  ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
518  FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
520  ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
521  mlir::SmallVector<int64_t> staticHalos, staticDims;
522  mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
523  dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
524  dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
525  return build(
526  b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
527  ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
528  ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
529  ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
530 }
531 
532 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
534 
535  build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
537  from.getPartialAxes().empty()
541  from.getPartialType()),
542  from.getStaticShardedDimsOffsets().empty()
546  from.getStaticHaloSizes().empty()
549  from.getDynamicHaloSizes());
550 }
551 
552 LogicalResult ShardingOp::verify() {
553  llvm::SmallSet<MeshAxis, 4> visitedAxes;
554 
555  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
556  for (MeshAxis axis : axesArray) {
557  if (axis < 0)
558  return emitError() << "mesh axis is expected to be non-negative";
559  if (!visitedAxes.insert(axis).second)
560  return emitError() << "mesh axis duplicated";
561  }
562  return success();
563  };
564 
565  for (auto subAxes : getSplitAxes().getAxes()) {
566  ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
567  if (failed(checkMeshAxis(subAxesArray)))
568  return failure();
569  }
570  if (getPartialAxes().has_value() &&
571  failed(checkMeshAxis(getPartialAxes().value())))
572  return failure();
573 
574  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
575  return emitOpError("halo sizes and shard offsets are mutually exclusive");
576  }
577 
578  if (!getStaticHaloSizes().empty()) {
579  auto numSplitAxes = getSplitAxes().getAxes().size();
580  for (auto splitAxis : getSplitAxes().getAxes()) {
581  if (splitAxis.empty()) {
582  --numSplitAxes;
583  }
584  }
585  if (getStaticHaloSizes().size() != numSplitAxes * 2) {
586  return emitError() << "halo sizes must be specified for all split axes.";
587  }
588  }
589 
590  return success();
591 }
592 
593 void ShardingOp::getAsmResultNames(
594  function_ref<void(Value, StringRef)> setNameFn) {
595  setNameFn(getResult(), "sharding");
596 }
597 
598 LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
599  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
600  if (failed(mesh)) {
601  return failure();
602  }
603  if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
604  getStaticShardedDimsOffsets().size() > 0) {
605  return emitError() << "sharded dims offsets are not allowed for "
606  "devices meshes with dynamic shape.";
607  }
608 
609  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
610  if (!shardedDimsOffsets.empty()) {
611  auto meshShape = mesh.value().getShape();
612  assert(!ShapedType::isDynamicShape(meshShape));
613  uint64_t pos = 0;
614  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
615  if (!innerSplitAxes.empty()) {
616  int64_t numShards = 0, off = 0;
617  for (auto i : innerSplitAxes.asArrayRef()) {
618  numShards += meshShape[i];
619  }
620  for (int64_t i = 0; i <= numShards; ++i) {
621  if (shardedDimsOffsets.size() <= pos + i) {
622  return emitError() << "sharded dims offsets has wrong size.";
623  }
624  if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
625  if (shardedDimsOffsets[pos + i] < off) {
626  return emitError()
627  << "sharded dims offsets must be non-decreasing.";
628  }
629  off = shardedDimsOffsets[pos + i];
630  }
631  }
632  pos += numShards + 1;
633  }
634  }
635  }
636  return success();
637 }
638 
639 namespace {
640 // Sharding annotations "halo sizes" and "sharded dims offsets"
641 // are a mix of attributes and dynamic values. This canonicalization moves
642 // constant values to the respective attribute lists, minimizing the number
643 // of values.
644 // It also removes sharded_dims_offsets and halos if they are effectively "empty".
645 class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
646 public:
648 
649  LogicalResult matchAndRewrite(ShardingOp op,
650  PatternRewriter &b) const override {
651  auto mixedHalos =
652  getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
653  auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
654  op.getDynamicShardedDimsOffsets(), b);
655 
656  // No constant operands were folded, just return;
657  bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
658  succeeded(foldDynamicIndexList(mixedOffs, true));
659 
660  auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
661  auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);
662 
663  if (dynamicHalos.empty() && !staticHalos.empty()) {
664  if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
665  staticHalos.clear();
666  modified = true;
667  }
668  }
669 
670  // Remove sharded dims offsets if they are effectively the default values,
671  // e.g. if they define equi-distance between all neighboring shards.
672  // Requires static-only offsets. Compares the first distance as the
673  // difference between the first two offsets. Only if all consecutive
674  // distances are the same, the offsets are removed.
675  if (dynamicOffs.empty() && !staticOffs.empty()) {
676  assert(staticOffs.size() >= 2);
677  auto diff = staticOffs[1] - staticOffs[0];
678  bool all_same = staticOffs.size() > 2;
679  for (auto i = 2u; i < staticOffs.size(); ++i) {
680  if (staticOffs[i] - staticOffs[i - 1] != diff) {
681  all_same = false;
682  break;
683  }
684  }
685  if (all_same) {
686  staticOffs.clear();
687  modified = true;
688  }
689  }
690 
691  if (!modified) {
692  return failure();
693  }
694 
695  op.setStaticHaloSizes(staticHalos);
696  op.getDynamicHaloSizesMutable().assign(dynamicHalos);
697  op.setStaticShardedDimsOffsets(staticOffs);
698  op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
699 
700  return success();
701  }
702 };
703 } // namespace
704 
705 void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
706  mlir::MLIRContext *context) {
707  results.add<NormalizeSharding>(context);
708 }
709 
710 //===----------------------------------------------------------------------===//
711 // MeshSharding
712 //===----------------------------------------------------------------------===//
713 
715  if (getMesh() != rhs.getMesh()) {
716  return false;
717  }
718 
719  if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
720  (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
721  !llvm::equal(getPartialAxes(), rhs.getPartialAxes())) {
722  return false;
723  }
724 
725  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
726  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
727  getSplitAxes().begin() + minSize),
728  llvm::make_range(rhs.getSplitAxes().begin(),
729  rhs.getSplitAxes().begin() + minSize))) {
730  return false;
731  }
732 
733  return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
734  std::mem_fn(&MeshAxesAttr::empty)) &&
735  llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
736  std::mem_fn(&MeshAxesAttr::empty));
737 }
738 
740  return equalShardSizes(rhs) && equalHaloSizes(rhs);
741 }
742 
744  if (rhs.getStaticShardedDimsOffsets().size() !=
745  getStaticShardedDimsOffsets().size() ||
746  !llvm::equal(getStaticShardedDimsOffsets(),
748  return false;
749  }
750  if (rhs.getDynamicShardedDimsOffsets().size() !=
751  getDynamicShardedDimsOffsets().size() ||
752  !llvm::equal(getDynamicShardedDimsOffsets(),
754  return false;
755  }
756  return true;
757 }
758 
760  if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
761  !llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes())) {
762  return false;
763  }
764  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
765  !llvm::equal(getDynamicHaloSizes(), rhs.getDynamicHaloSizes())) {
766  return false;
767  }
768  return true;
769 }
770 
773 }
774 
775 bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
776 
777 bool MeshSharding::operator==(const MeshSharding &rhs) const {
779 }
780 
781 bool MeshSharding::operator!=(const MeshSharding &rhs) const {
782  return !(*this == rhs);
783 }
784 
786 
788  auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
789  assert(shardingOp && "expected sharding op");
790  auto splitAxes = shardingOp.getSplitAxes().getAxes();
791  auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
792  // If splitAxes and partialAxes are empty, use "empty" constructor.
793  if (splitAxes.empty() && partialAxes.empty()) {
794  *this = MeshSharding(shardingOp.getMeshAttr());
795  return;
796  }
797  *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
798  shardingOp.getPartialType().value_or(ReductionKind::Sum),
799  shardingOp.getStaticHaloSizes(),
800  shardingOp.getStaticShardedDimsOffsets(),
801  SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
802  SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
803 }
804 
806  ArrayRef<MeshAxesAttr> split_axes_,
807  ArrayRef<MeshAxis> partial_axes_,
808  ReductionKind partial_type_,
809  ArrayRef<int64_t> static_halo_sizes_,
810  ArrayRef<int64_t> static_sharded_dims_offsets_,
811  ArrayRef<Value> dynamic_halo_sizes_,
812  ArrayRef<Value> dynamic_sharded_dims_offsets_) {
813  MeshSharding res(mesh_);
814  if (split_axes_.empty() && partial_axes_.empty()) {
815  return res;
816  }
817 
818  res.split_axes.resize(split_axes_.size());
819  for (auto [i, axis] : llvm::enumerate(split_axes_)) {
820  res.split_axes[i] =
821  MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
822  }
823 
824  auto clone = [](const auto src, auto &dst) {
825  dst.resize(src.size());
826  llvm::copy(src, dst.begin());
827  };
828 
829  clone(partial_axes_, res.partial_axes);
830  res.partial_type = partial_type_;
831  clone(static_halo_sizes_, res.static_halo_sizes);
832  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
833  clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
834  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
835 
836  return res;
837 }
838 
839 //===----------------------------------------------------------------------===//
840 // mesh.shard_shape
841 //===----------------------------------------------------------------------===//
842 
843 void ShardShapeOp::getAsmResultNames(
844  function_ref<void(Value, StringRef)> setNameFn) {
845  setNameFn(getResult()[0], "shard_shape");
846 }
847 
848 void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
849  ::mlir::OperationState &odsState,
851  ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
852  ::mlir::ValueRange device) {
853  SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
854  build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
855  SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
856 }
857 
858 //===----------------------------------------------------------------------===//
859 // mesh.shard op
860 //===----------------------------------------------------------------------===//
861 
862 void ShardOp::getAsmResultNames(
863  function_ref<void(Value, StringRef)> setNameFn) {
864  setNameFn(getResult(), "sharding_annotated");
865 }
866 
867 namespace {
868 // Determine if the given ShardOp is a duplicate of another ShardOp
869 // on the same value. This can happen if constant values are sharded.
870 class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
871 public:
873 
874  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
875  // Get the use-list of the value being sharded and check if it has more than
876  // one use.
877  Value value = op.getSrc();
878  if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) {
879  return failure();
880  }
881 
882  // Iterate through the uses of the value to find a duplicate ShardOp.
883  for (auto &use : value.getUses()) {
884  if (use.getOwner() != op.getOperation()) {
885  auto otherOp = dyn_cast<ShardOp>(use.getOwner());
886  if (!otherOp || !otherOp->isBeforeInBlock(op)) {
887  return failure();
888  }
889  // Create a MeshSharding object for the current and the other ShardOp
890  // If the two are equal replace current op with the other op.
891  MeshSharding currentSharding(op.getSharding());
892  MeshSharding otherSharding(otherOp.getSharding());
893  if (currentSharding == otherSharding) {
894  b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
895  b.eraseOp(op.getOperation());
896  } else {
897  // use the other sharding as input for op
898  op.getSrcMutable().assign(otherOp.getResult());
899  }
900  return success();
901  }
902  }
903 
904  return failure();
905  }
906 };
907 } // namespace
908 
909 void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
910  mlir::MLIRContext *context) {
911  results.add<FoldDuplicateShardOp>(context);
912 }
913 
914 //===----------------------------------------------------------------------===//
915 // mesh.process_multi_index op
916 //===----------------------------------------------------------------------===//
917 
918 LogicalResult
919 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
920  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
921  if (failed(mesh)) {
922  return failure();
923  }
924  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
925  return failure();
926  }
927 
928  size_t expectedResultsCount =
929  getAxes().empty() ? mesh->getRank() : getAxes().size();
930  if (getResult().size() != expectedResultsCount) {
931  return emitError() << "Unexpected number of results " << getResult().size()
932  << ". Expected " << expectedResultsCount << ".";
933  }
934 
935  return success();
936 }
937 
938 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
939  MeshOp mesh) {
940  build(odsBuilder, odsState,
941  SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
942  mesh.getSymName(), ArrayRef<MeshAxis>());
943 }
944 
945 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
946  StringRef mesh, ArrayRef<MeshAxis> axes) {
947  build(odsBuilder, odsState,
948  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
949  MeshAxesAttr::get(odsBuilder.getContext(), axes));
950 }
951 
952 void ProcessMultiIndexOp::getAsmResultNames(
953  function_ref<void(Value, StringRef)> setNameFn) {
954  setNameFn(getResults()[0], "proc_linear_idx");
955 }
956 
957 //===----------------------------------------------------------------------===//
958 // mesh.process_linear_index op
959 //===----------------------------------------------------------------------===//
960 
961 LogicalResult
962 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
963  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
964  if (failed(mesh)) {
965  return failure();
966  }
967  return success();
968 }
969 
970 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
971  OperationState &odsState, MeshOp mesh) {
972  build(odsBuilder, odsState, mesh.getSymName());
973 }
974 
975 void ProcessLinearIndexOp::getAsmResultNames(
976  function_ref<void(Value, StringRef)> setNameFn) {
977  setNameFn(getResult(), "proc_linear_idx");
978 }
979 
980 //===----------------------------------------------------------------------===//
981 // mesh.neighbors_linear_indices op
982 //===----------------------------------------------------------------------===//
983 
984 LogicalResult
985 NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
986  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
987  if (failed(mesh)) {
988  return failure();
989  }
990  return success();
991 }
992 
993 void NeighborsLinearIndicesOp::getAsmResultNames(
994  function_ref<void(Value, StringRef)> setNameFn) {
995  setNameFn(getNeighborDown(), "down_linear_idx");
996  setNameFn(getNeighborUp(), "up_linear_idx");
997 }
998 
999 //===----------------------------------------------------------------------===//
1000 // collective communication ops
1001 //===----------------------------------------------------------------------===//
1002 
1003 namespace {
1004 
1005 template <typename Op>
1006 struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
1008  LogicalResult matchAndRewrite(Op op,
1009  PatternRewriter &rewriter) const override {
1010  auto meshAxes = op.getMeshAxes();
1011  if (!meshAxes.empty()) {
1012  return failure();
1013  }
1014  if (op.getInput().getType() != op.getResult().getType()) {
1015  return failure();
1016  }
1017 
1018  rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
1019  rewriter.eraseOp(op.getOperation());
1020  return success();
1021  }
1022 };
1023 
1024 } // namespace
1025 
1026 static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
1027  ArrayRef<int64_t> device,
1028  Operation::operand_range deviceDynamic,
1029  ArrayRef<MeshAxis> meshAxes,
1030  ArrayRef<int64_t> meshShape) {
1031  if (device.size() != meshAxes.size()) {
1032  return emitError(loc) << "In-group device \"" << deviceName
1033  << "\" has unexpected multi-index size "
1034  << device.size() << ". Expected " << meshAxes.size()
1035  << ".";
1036  }
1037 
1038  for (size_t i = 0; i < device.size(); ++i) {
1039  if (!ShapedType::isDynamic(device[i]) &&
1040  !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
1041  meshShape[meshAxes[i]] <= device[i]) {
1042  return emitError(loc)
1043  << "Out of bounds coordinate " << i << " for in-group device \""
1044  << deviceName << "\"."
1045  << " Got " << device[i] << ", but expected value in the range [0, "
1046  << (meshShape[meshAxes[i]] - 1) << "].";
1047  }
1048  }
1049  return success();
1050 }
1051 
/// Returns the product of the elements in `[begin, end)`, accumulated in the
/// decayed element type. An empty range yields 1.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  ElementType result = static_cast<ElementType>(1);
  for (; begin != end; ++begin) {
    result = result * static_cast<ElementType>(*begin);
  }
  return result;
}
1058 
/// Range overload: returns the product of all elements of `range` by
/// delegating to the iterator-pair overload.
template <typename R>
static auto product(R &&range) {
  auto first = adl_begin(range);
  auto last = adl_end(range);
  return product(first, last);
}
1063 
1064 static LogicalResult verifyDimensionCompatibility(Location loc,
1065  int64_t expectedDimSize,
1066  int64_t resultDimSize,
1067  int64_t resultAxis) {
1068  if (!ShapedType::isDynamic(resultDimSize) &&
1069  expectedDimSize != resultDimSize) {
1070  return emitError(loc) << "Dimension size mismatch for result axis "
1071  << resultAxis << ". Expected "
1072  << (ShapedType::isDynamic(expectedDimSize)
1073  ? Twine("dynamic")
1074  : Twine(expectedDimSize))
1075  << ", but got " << resultDimSize << ".";
1076  }
1077 
1078  return success();
1079 }
1080 
1082  Value operand, Value result, int64_t gatherAxis,
1083  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1084  auto resultRank = cast<ShapedType>(result.getType()).getRank();
1085  if (gatherAxis < 0 || gatherAxis >= resultRank) {
1086  return emitError(result.getLoc())
1087  << "Gather axis " << gatherAxis << " is out of bounds [0, "
1088  << resultRank << ").";
1089  }
1090 
1091  ShapedType operandType = cast<ShapedType>(operand.getType());
1092  ShapedType resultType = cast<ShapedType>(result.getType());
1093  auto deviceGroupSize =
1094  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1095  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1096  auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
1097  auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
1098  auto expectedResultDimSize =
1099  axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
1100  if (failed(verifyDimensionCompatibility(
1101  result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
1102  return failure();
1103  }
1104  }
1105  return success();
1106 }
1107 
1109  Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
1110  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1111  ShapedType operandType = cast<ShapedType>(operand.getType());
1112  ShapedType resultType = cast<ShapedType>(result.getType());
1113  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1114  if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
1115  if (failed(verifyDimensionCompatibility(
1116  result.getLoc(), operandType.getDimSize(axis),
1117  resultType.getDimSize(axis), axis))) {
1118  return failure();
1119  }
1120  }
1121  }
1122 
1123  if (splitAxis == concatAxis) {
1124  return success();
1125  }
1126 
1127  auto deviceGroupSize =
1128  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1129  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
1130  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
1131  DimensionSize expectedResultConcatDimSize =
1132  operandConcatDimSize * deviceGroupSize;
1133  DimensionSize expectedResultSplitDimSize =
1134  operandSplitDimSize / deviceGroupSize;
1135  if (!expectedResultSplitDimSize.isDynamic() &&
1136  int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
1137  expectedResultSplitDimSize = DimensionSize::dynamic();
1138  }
1139  if (failed(verifyDimensionCompatibility(
1140  result.getLoc(), expectedResultConcatDimSize.value(),
1141  resultType.getDimSize(concatAxis), concatAxis))) {
1142  return failure();
1143  }
1144  if (failed(verifyDimensionCompatibility(
1145  result.getLoc(), expectedResultSplitDimSize.value(),
1146  resultType.getDimSize(splitAxis), splitAxis))) {
1147  return failure();
1148  }
1149 
1150  return success();
1151 }
1152 
1154  Value operand, Value result, int64_t tensorAxis,
1155  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1156  ShapedType operandType = cast<ShapedType>(operand.getType());
1157  ShapedType resultType = cast<ShapedType>(result.getType());
1158  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1159  if (axis != tensorAxis) {
1160  if (failed(verifyDimensionCompatibility(
1161  result.getLoc(), operandType.getDimSize(axis),
1162  resultType.getDimSize(axis), axis))) {
1163  return failure();
1164  }
1165  }
1166  }
1167 
1168  auto deviceGroupSize =
1169  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1170  auto operandScatterDimSize =
1171  DimensionSize(operandType.getDimSize(tensorAxis));
1172  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1173  int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1174  return emitError(result.getLoc())
1175  << "Operand dimension size " << int64_t(operandScatterDimSize)
1176  << " is not divisible by collective device group size "
1177  << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
1178  << ".";
1179  }
1180  DimensionSize expectedResultTensorDimSize =
1181  operandScatterDimSize / deviceGroupSize;
1182  if (failed(verifyDimensionCompatibility(
1183  result.getLoc(), expectedResultTensorDimSize.value(),
1184  resultType.getDimSize(tensorAxis), tensorAxis))) {
1185  return failure();
1186  }
1187 
1188  return success();
1189 }
1190 
1191 static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
1192  ArrayRef<MeshAxis> meshAxes,
1193  int64_t sliceAxis) {
1194  RankedTensorType operandRankedTensorType =
1195  cast<RankedTensorType>(operandType);
1196  DimensionSize operandSliceAxisSize =
1197  operandRankedTensorType.getShape()[sliceAxis];
1198  SmallVector<int64_t> resultShape =
1199  llvm::to_vector(operandRankedTensorType.getShape());
1200 
1201  resultShape[sliceAxis] =
1202  operandSliceAxisSize /
1203  DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
1204  return operandRankedTensorType.clone(resultShape);
1205 }
1206 
1207 //===----------------------------------------------------------------------===//
1208 // mesh.all_gather op
1209 //===----------------------------------------------------------------------===//
1210 
1211 LogicalResult
1212 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1213  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1214  if (failed(mesh)) {
1215  return failure();
1216  }
1217  auto gatherAxis = getGatherAxis().getSExtValue();
1218  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
1219  gatherAxis, getMeshAxes(),
1220  mesh.value().getShape());
1221 }
1222 
1223 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1224  MLIRContext *context) {
1225  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
1226 }
1227 
1228 void AllGatherOp::getAsmResultNames(
1229  function_ref<void(Value, StringRef)> setNameFn) {
1230  setNameFn(getResult(), "all_gather");
1231 }
1232 
1233 //===----------------------------------------------------------------------===//
1234 // mesh.all_reduce op
1235 //===----------------------------------------------------------------------===//
1236 
1237 LogicalResult
1238 AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1239  return getMeshAndVerifyAxes(*this, symbolTable);
1240 }
1241 
1242 void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1243  MLIRContext *context) {
1244  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
1245 }
1246 
1247 void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1248  Value input, StringRef mesh,
1249  ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
1250  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
1251  reduction);
1252 }
1253 
1254 void AllReduceOp::getAsmResultNames(
1255  function_ref<void(Value, StringRef)> setNameFn) {
1256  setNameFn(getResult(), "all_reduce");
1257 }
1258 
1259 //===----------------------------------------------------------------------===//
1260 // mesh.all_slice op
1261 //===----------------------------------------------------------------------===//
1262 
1263 LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1264  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1265  if (failed(mesh)) {
1266  return failure();
1267  }
1269  getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
1270  mesh.value().getShape());
1271 }
1272 
1273 void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1274  MLIRContext *context) {
1275  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
1276 }
1277 
1278 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1279  Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
1280  int64_t sliceAxis) {
1281  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
1282  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
1283  sliceAxis);
1284 }
1285 
1286 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1287  Type resultType, Value input, StringRef mesh,
1288  ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
1289  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
1290  APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1291 }
1292 
1293 void AllSliceOp::getAsmResultNames(
1294  function_ref<void(Value, StringRef)> setNameFn) {
1295  setNameFn(getResult(), "all_slice");
1296 }
1297 
1298 //===----------------------------------------------------------------------===//
1299 // mesh.all_to_all op
1300 //===----------------------------------------------------------------------===//
1301 
1302 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1303  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1304  if (failed(mesh)) {
1305  return failure();
1306  }
1307 
1309  getOperand(), getResult(), getSplitAxis().getSExtValue(),
1310  getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
1311 }
1312 
1313 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1314  MLIRContext *context) {
1315  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
1316 }
1317 
1318 void AllToAllOp::getAsmResultNames(
1319  function_ref<void(Value, StringRef)> setNameFn) {
1320  setNameFn(getResult(), "all_to_all");
1321 }
1322 
1323 //===----------------------------------------------------------------------===//
1324 // mesh.broadcast op
1325 //===----------------------------------------------------------------------===//
1326 
1327 LogicalResult
1328 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1329  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1330  if (failed(mesh)) {
1331  return failure();
1332  }
1333  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1334  getRootDynamic(), getMeshAxes(),
1335  mesh.value().getShape()))) {
1336  return failure();
1337  }
1338 
1339  return success();
1340 }
1341 
1342 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1343  MLIRContext *context) {
1344  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
1345 }
1346 
1347 void BroadcastOp::getAsmResultNames(
1348  function_ref<void(Value, StringRef)> setNameFn) {
1349  setNameFn(getResult(), "broadcast");
1350 }
1351 
1352 //===----------------------------------------------------------------------===//
1353 // mesh.gather op
1354 //===----------------------------------------------------------------------===//
1355 
1356 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1357  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1358  if (failed(mesh)) {
1359  return failure();
1360  }
1361  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1362  getRootDynamic(), getMeshAxes(),
1363  mesh.value().getShape()))) {
1364  return failure();
1365  }
1366 
1367  auto gatherAxis = getGatherAxis().getSExtValue();
1368  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
1369  getMeshAxes(),
1370  mesh.value().getShape());
1371 }
1372 
1373 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1374  MLIRContext *context) {
1375  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
1376 }
1377 
1378 void GatherOp::getAsmResultNames(
1379  function_ref<void(Value, StringRef)> setNameFn) {
1380  setNameFn(getResult(), "gather");
1381 }
1382 
1383 //===----------------------------------------------------------------------===//
1384 // mesh.recv op
1385 //===----------------------------------------------------------------------===//
1386 
1387 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1388  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1389  if (failed(mesh)) {
1390  return failure();
1391  }
1392  if (getSource() &&
1393  failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
1394  getSource().value(), getSourceDynamic(),
1395  getMeshAxes(), mesh.value().getShape()))) {
1396  return failure();
1397  }
1398  return success();
1399 }
1400 
1401 void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1402  MLIRContext *context) {
1403  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
1404 }
1405 
1406 void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1407  setNameFn(getResult(), "recv");
1408 }
1409 
1410 //===----------------------------------------------------------------------===//
1411 // mesh.reduce op
1412 //===----------------------------------------------------------------------===//
1413 
1414 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1415  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1416  if (failed(mesh)) {
1417  return failure();
1418  }
1419  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1420  getRootDynamic(), getMeshAxes(),
1421  mesh.value().getShape()))) {
1422  return failure();
1423  }
1424 
1425  return success();
1426 }
1427 
1428 void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1429  MLIRContext *context) {
1430  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
1431 }
1432 
1433 void ReduceOp::getAsmResultNames(
1434  function_ref<void(Value, StringRef)> setNameFn) {
1435  setNameFn(getResult(), "reduce");
1436 }
1437 
1438 //===----------------------------------------------------------------------===//
1439 // mesh.reduce_scatter op
1440 //===----------------------------------------------------------------------===//
1441 
1442 LogicalResult
1443 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1444  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1445  if (failed(mesh)) {
1446  return failure();
1447  }
1448 
1450  getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
1451  mesh.value().getShape());
1452 }
1453 
1454 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1455  MLIRContext *context) {
1456  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1457 }
1458 
1459 void ReduceScatterOp::getAsmResultNames(
1460  function_ref<void(Value, StringRef)> setNameFn) {
1461  setNameFn(getResult(), "reduce_scatter");
1462 }
1463 
1464 //===----------------------------------------------------------------------===//
1465 // mesh.scatter op
1466 //===----------------------------------------------------------------------===//
1467 
1468 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1469  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1470  if (failed(mesh)) {
1471  return failure();
1472  }
1473  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1474  getRootDynamic(), getMeshAxes(),
1475  mesh.value().getShape()))) {
1476  return failure();
1477  }
1478 
1479  auto scatterAxis = getScatterAxis().getSExtValue();
1480  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
1481  scatterAxis, getMeshAxes(),
1482  mesh.value().getShape());
1483 }
1484 
1485 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1486  MLIRContext *context) {
1487  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
1488 }
1489 
1490 void ScatterOp::getAsmResultNames(
1491  function_ref<void(Value, StringRef)> setNameFn) {
1492  setNameFn(getResult(), "scatter");
1493 }
1494 
1495 //===----------------------------------------------------------------------===//
1496 // mesh.send op
1497 //===----------------------------------------------------------------------===//
1498 
1499 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1500  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1501  if (failed(mesh)) {
1502  return failure();
1503  }
1504  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
1505  getDestination(), getDestinationDynamic(),
1506  getMeshAxes(), mesh.value().getShape()))) {
1507  return failure();
1508  }
1509  return success();
1510 }
1511 
1512 void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1513  MLIRContext *context) {
1514  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
1515 }
1516 
1517 void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1518  setNameFn(getResult(), "send");
1519 }
1520 
1521 //===----------------------------------------------------------------------===//
1522 // mesh.shift op
1523 //===----------------------------------------------------------------------===//
1524 
1525 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1526  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1527  if (failed(mesh)) {
1528  return failure();
1529  }
1530 
1531  auto meshAxes = getMeshAxes();
1532  auto shiftAxis = getShiftAxis().getZExtValue();
1533  if (!llvm::is_contained(meshAxes, shiftAxis)) {
1534  return emitError() << "Invalid shift axis " << shiftAxis
1535  << ". It must be one of the grouping mesh axes.";
1536  }
1537 
1538  return success();
1539 }
1540 
1541 void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1542  MLIRContext *context) {
1543  // TODO: remove op when offset is 0 or if it is a rotate with and
1544  // offset % shift_axis_mesh_dim_size == 0.
1545 }
1546 
1547 void ShiftOp::getAsmResultNames(
1548  function_ref<void(Value, StringRef)> setNameFn) {
1549  setNameFn(getResult(), "shift");
1550 }
1551 
1552 //===----------------------------------------------------------------------===//
1553 // mesh.update_halo op
1554 //===----------------------------------------------------------------------===//
1555 
1556 LogicalResult
1557 UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1558  auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
1559  if (failed(mesh)) {
1560  return failure();
1561  }
1562 
1563  return success();
1564 }
1565 
1566 //===----------------------------------------------------------------------===//
1567 // TableGen'd op method definitions
1568 //===----------------------------------------------------------------------===//
1569 
1570 #define GET_OP_CLASSES
1571 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
1572 
1573 #define GET_ATTRDEF_CLASSES
1574 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
1575 
1576 #define GET_TYPEDEF_CLASSES
1577 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
1578 
1579 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
Definition: MeshOps.cpp:1191
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition: MeshOps.cpp:64
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
Definition: MeshOps.cpp:1064
static FailureOr< MeshOp > getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:202
static FailureOr< MeshOp > getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:150
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:1153
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:1081
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:1108
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
Definition: MeshOps.cpp:216
static auto product(It begin, It end)
Definition: MeshOps.cpp:1053
static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, ShardOp &newShardOp)
Definition: MeshOps.cpp:301
bool isUnique(It begin, It end)
Definition: MeshOps.cpp:163
static LogicalResult verifyMeshAxes(Location loc, ArrayRef< MeshAxis > axes, MeshOp mesh)
Definition: MeshOps.cpp:179
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:1026
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:162
DenseI16ArrayAttr getDenseI16ArrayAttr(ArrayRef< int16_t > values)
Definition: Builders.cpp:154
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:107
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:50
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringRef getValue() const
Returns the name of the held symbol reference.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:729
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:419
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:128
This class provides the API for a sub-set of ops that are known to be constant-like.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Definition: OpDefinition.h:111
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
void replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this operation with the provided values if the given callback returns true...
Definition: Operation.h:279
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:810
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:601
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition: Value.cpp:93
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
void replaceAllUsesExcept(Value newValue, const SmallPtrSetImpl< Operation * > &exceptions)
Replace all uses of 'this' value with 'newValue', updating anything in the IR that uses 'this' to use...
Definition: Value.cpp:73
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:714
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition: MeshOps.h:70
bool equalHaloAndShardSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:739
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:64
bool equalHaloSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:759
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:66
bool operator!=(Value rhs) const
Definition: MeshOps.cpp:775
ReductionKind getPartialType() const
Definition: MeshOps.h:68
ArrayRef< Value > getDynamicShardedDimsOffsets() const
Definition: MeshOps.h:74
bool operator==(Value rhs) const
Definition: MeshOps.cpp:771
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:67
ArrayRef< Value > getDynamicHaloSizes() const
Definition: MeshOps.h:73
::llvm::StringRef getMesh() const
Definition: MeshOps.h:65
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: MeshOps.h:69
MeshSharding(::mlir::FlatSymbolRefAttr mesh_=nullptr)
Definition: MeshOps.cpp:785
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition: MeshOps.cpp:805
bool equalShardSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:743
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
mesh::ReductionKind ReductionKind
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
Definition: MeshOps.h:153
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result, OpBuilder &builder)
Definition: MeshOps.cpp:340
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:120
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:354
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:293
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:283
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:175
bool isFullReplication(MeshSharding sharding)
Definition: MeshOps.h:112
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type.
Definition: MeshOps.cpp:78
int16_t MeshAxis
Definition: MeshOps.h:26
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:252
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.