MLIR  19.0.0git
MeshOps.cpp
Go to the documentation of this file.
1 //===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/Support/LLVM.h"
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallSet.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include <algorithm>
32 #include <functional>
33 #include <iterator>
34 #include <numeric>
35 #include <optional>
36 #include <utility>
37 
38 #define DEBUG_TYPE "mesh-ops"
39 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
40 
41 using namespace mlir;
42 using namespace mlir::mesh;
43 
44 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
45 
46 namespace {
47 
48 struct DimensionSize {
49  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
50  DimensionSize(int64_t val) : val(val) {}
51  int64_t value() const { return val; }
52  operator int64_t() const { return val; }
53  bool isDynamic() const { return ShapedType::isDynamic(val); }
54 
55 private:
56  int64_t val;
57 };
58 
59 } // namespace
60 
61 static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
62  if (lhs.isDynamic() || rhs.isDynamic()) {
63  return DimensionSize::dynamic();
64  }
65  return lhs.value() / rhs.value();
66 }
67 
68 static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
69  if (lhs.isDynamic() || rhs.isDynamic()) {
70  return DimensionSize::dynamic();
71  }
72  return lhs.value() * rhs.value();
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // Mesh dialect
77 //===----------------------------------------------------------------------===//
78 
/// Registers the dialect's operations and attributes. Both lists are pulled
/// in from the generated .inc files, selected by the GET_*_LIST macros.
void MeshDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
      >();
}
89 
91  Type type, Location loc) {
92  return arith::ConstantOp::materialize(builder, value, type, loc);
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // Mesh utilities
97 //===----------------------------------------------------------------------===//
98 
100  FlatSymbolRefAttr meshSymbol,
101  SymbolTableCollection &symbolTable) {
102  mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
103  if (!mesh) {
104  return op->emitError() << "Undefined required mesh symbol \""
105  << meshSymbol.getValue() << "\".";
106  }
107 
108  return mesh;
109 }
110 
/// Returns true when no two neighbouring elements of [begin, end) compare
/// equal. On a sorted range this means that all elements are distinct.
/// Empty and single-element ranges are trivially unique.
template <typename It>
bool isUnique(It begin, It end) {
  return std::adjacent_find(begin, end) == end;
}
127 
129  MeshOp mesh) {
130  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
131  llvm::sort(sorted);
132  if (!isUnique(sorted.begin(), sorted.end())) {
133  return emitError(loc) << "Mesh axes contains duplicate elements.";
134  }
135 
136  MeshAxis rank = mesh.getRank();
137  for (auto axis : axes) {
138  if (axis >= rank || axis < 0) {
139  return emitError(loc)
140  << "0-based mesh axis index " << axis
141  << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
142  << "\" is of rank " << rank << ".";
143  }
144  }
145 
146  return success();
147 }
148 
149 template <typename InShape, typename MeshShape, typename SplitAxes,
150  typename OutShape>
151 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
152  const SplitAxes &splitAxes, OutShape &outShape) {
153  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
154  llvm::adl_begin(outShape));
155  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
156  outShape[tensorAxis] = shardDimension(
157  inShape[tensorAxis],
158  collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
159  }
160 }
161 
162 ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
163  MeshShardingAttr sharding) {
164  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
165  SmallVector<Dim> resShapeArr(shape.getShape().size());
166  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
167  resShapeArr);
168  return shape.clone(resShapeArr);
169 }
170 
171 Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
172  RankedTensorType rankedTensorType = type.dyn_cast<RankedTensorType>();
173  if (rankedTensorType) {
174  return shardShapedType(rankedTensorType, mesh, sharding);
175  }
176 
177  assert(!sharding);
178  return type;
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // mesh.mesh op
183 //===----------------------------------------------------------------------===//
184 
186  int64_t rank = getRank();
187 
188  if (rank <= 0)
189  return emitOpError("rank of mesh is expected to be a positive integer");
190 
191  for (int64_t dimSize : getShape()) {
192  if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
193  return emitOpError("dimension size of a mesh is expected to be "
194  "non-negative or dynamic");
195  }
196 
197  return success();
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // mesh.mesh_shape op
202 //===----------------------------------------------------------------------===//
203 
205 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
206  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
207  if (failed(mesh)) {
208  return failure();
209  }
210  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
211  return failure();
212  }
213 
214  size_t expectedResultsCount =
215  getAxes().empty() ? mesh->getRank() : getAxes().size();
216  if (getResult().size() != expectedResultsCount) {
217  return emitError() << "Unexpected number of results " << getResult().size()
218  << ". Expected " << expectedResultsCount << ".";
219  }
220 
221  return success();
222 }
223 
/// Builds the op for `mesh` with an empty axis list by delegating to the
/// (mesh, axes) builder.
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        MeshOp mesh) {
  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
}
228 
/// Builds the op for `mesh` and `axes`. The number of index-typed results is
/// `axes.size()`, or `mesh.getRank()` when `axes` is empty.
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        MeshOp mesh, ArrayRef<MeshAxis> axes) {
  build(odsBuilder, odsState,
        SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
                          odsBuilder.getIndexType()),
        mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
236 
/// Builds the op from a mesh symbol name. Only the name is available here, so
/// the result count is taken from `axes`, which must be non-empty (asserted).
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        StringRef mesh, ArrayRef<MeshAxis> axes) {
  assert(!axes.empty());
  build(odsBuilder, odsState,
        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
        MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
244 
245 void MeshShapeOp::getAsmResultNames(
246  function_ref<void(Value, StringRef)> setNameFn) {
247  setNameFn(getResults()[0], "mesh_shape");
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // mesh.shard attr
252 //===----------------------------------------------------------------------===//
253 
257  ArrayRef<MeshAxis> partialAxes, ReductionKind) {
258  // TODO: At present mesh symbol ref is not verified. This is due to the
259  // difficulty in fetching the corresponding symbol op based on an attribute.
260 
261  llvm::SmallSet<MeshAxis, 4> visitedAxes;
262 
263  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
264  for (MeshAxis axis : axesArray) {
265  if (axis < 0)
266  return emitError() << "mesh axis is expected to be non-negative";
267  if (!visitedAxes.insert(axis).second)
268  return emitError() << "mesh axis duplicated";
269  }
270  return success();
271  };
272 
273  for (MeshAxesAttr subAxes : splitAxes) {
274  ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
275  if (failed(checkMeshAxis(subAxesArray)))
276  return failure();
277  }
278  if (failed(checkMeshAxis(partialAxes)))
279  return failure();
280  return success();
281 }
282 
283 bool MeshShardingAttr::operator==(Attribute rhs) const {
284  MeshShardingAttr rhsAsMeshShardingAttr = rhs.dyn_cast<MeshShardingAttr>();
285  return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
286 }
287 
288 bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
289  if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
290  return false;
291  }
292 
293  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
294  return false;
295  }
296 
297  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
298  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
299  getSplitAxes().begin() + minSize),
300  llvm::make_range(rhs.getSplitAxes().begin(),
301  rhs.getSplitAxes().begin() + minSize))) {
302  return false;
303  }
304 
305  return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
306  getSplitAxes().end()),
307  std::mem_fn(&MeshAxesAttr::empty)) &&
308  llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
309  rhs.getSplitAxes().end()),
310  std::mem_fn(&MeshAxesAttr::empty));
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // mesh.shard op
315 //===----------------------------------------------------------------------===//
316 
317 void ShardOp::getAsmResultNames(
318  function_ref<void(Value, StringRef)> setNameFn) {
319  setNameFn(getResult(), "sharding_annotated");
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // mesh.process_multi_index op
324 //===----------------------------------------------------------------------===//
325 
327 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
328  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
329  if (failed(mesh)) {
330  return failure();
331  }
332  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
333  return failure();
334  }
335 
336  size_t expectedResultsCount =
337  getAxes().empty() ? mesh->getRank() : getAxes().size();
338  if (getResult().size() != expectedResultsCount) {
339  return emitError() << "Unexpected number of results " << getResult().size()
340  << ". Expected " << expectedResultsCount << ".";
341  }
342 
343  return success();
344 }
345 
/// Builds the op for all axes of `mesh`: `mesh.getRank()` index-typed results
/// and an empty axis list.
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                MeshOp mesh) {
  build(odsBuilder, odsState,
        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
        mesh.getSymName(), ArrayRef<MeshAxis>());
}
352 
/// Builds the op from a mesh symbol name with one index-typed result per
/// entry of `axes`.
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                StringRef mesh, ArrayRef<MeshAxis> axes) {
  build(odsBuilder, odsState,
        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
        MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
359 
360 void ProcessMultiIndexOp::getAsmResultNames(
361  function_ref<void(Value, StringRef)> setNameFn) {
362  setNameFn(getResults()[0], "proc_linear_idx");
363 }
364 
365 //===----------------------------------------------------------------------===//
366 // mesh.process_linear_index op
367 //===----------------------------------------------------------------------===//
368 
370 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
371  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
372  if (failed(mesh)) {
373  return failure();
374  }
375  return success();
376 }
377 
/// Builds the op for `mesh` by forwarding its symbol name.
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, MeshOp mesh) {
  build(odsBuilder, odsState, mesh.getSymName());
}
382 
383 void ProcessLinearIndexOp::getAsmResultNames(
384  function_ref<void(Value, StringRef)> setNameFn) {
385  setNameFn(getResult(), "proc_linear_idx");
386 }
387 
388 //===----------------------------------------------------------------------===//
389 // collective communication ops
390 //===----------------------------------------------------------------------===//
391 
392 namespace {
393 
394 template <typename Op>
395 struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
397  LogicalResult matchAndRewrite(Op op,
398  PatternRewriter &rewriter) const override {
399  auto meshAxes = op.getMeshAxes();
400  if (!meshAxes.empty()) {
401  return failure();
402  }
403  if (op.getInput().getType() != op.getResult().getType()) {
404  return failure();
405  }
406 
407  rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
408  rewriter.eraseOp(op.getOperation());
409  return success();
410  }
411 };
412 
413 } // namespace
414 
415 static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
416  ArrayRef<int64_t> device,
417  Operation::operand_range deviceDynamic,
418  ArrayRef<MeshAxis> meshAxes,
419  ArrayRef<int64_t> meshShape) {
420  if (device.size() != meshAxes.size()) {
421  return emitError(loc) << "In-group device \"" << deviceName
422  << "\" has unexpected multi-index size "
423  << device.size() << ". Expected " << meshAxes.size()
424  << ".";
425  }
426 
427  for (size_t i = 0; i < device.size(); ++i) {
428  if (!ShapedType::isDynamic(device[i]) &&
429  !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
430  meshShape[meshAxes[i]] <= device[i]) {
431  return emitError(loc)
432  << "Out of bounds coordinate " << i << " for in-group device \""
433  << deviceName << "\"."
434  << " Got " << device[i] << ", but expected value in the range [0, "
435  << (meshShape[meshAxes[i]] - 1) << "].";
436  }
437  }
438  return success();
439 }
440 
441 template <typename Op>
442 static FailureOr<MeshOp>
444  auto mesh =
445  ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
446  if (failed(mesh)) {
447  return failure();
448  }
449  if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
450  return failure();
451  }
452  return mesh;
453 }
454 
/// Returns the product of all elements in [begin, end), computed in the
/// (decayed) element type. An empty range yields 1.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  ElementType result = static_cast<ElementType>(1);
  for (It it = begin; it != end; ++it) {
    result = result * *it;
  }
  return result;
}
461 
/// Range overload: multiplies all elements of `range` by forwarding its
/// begin/end iterators to the iterator overload.
template <typename R>
static auto product(R &&range) {
  auto first = adl_begin(range);
  auto last = adl_end(range);
  return product(first, last);
}
466 
468  int64_t expectedDimSize,
469  int64_t resultDimSize,
470  int64_t resultAxis) {
471  if (!ShapedType::isDynamic(resultDimSize) &&
472  expectedDimSize != resultDimSize) {
473  return emitError(loc) << "Dimension size mismatch for result axis "
474  << resultAxis << ". Expected "
475  << (ShapedType::isDynamic(expectedDimSize)
476  ? Twine("dynamic")
477  : Twine(expectedDimSize))
478  << ", but got " << resultDimSize << ".";
479  }
480 
481  return success();
482 }
483 
485  Value operand, Value result, int64_t gatherAxis,
486  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
487  auto resultRank = result.getType().template cast<ShapedType>().getRank();
488  if (gatherAxis < 0 || gatherAxis >= resultRank) {
489  return emitError(result.getLoc())
490  << "Gather axis " << gatherAxis << " is out of bounds [0, "
491  << resultRank << ").";
492  }
493 
494  ShapedType operandType = operand.getType().cast<ShapedType>();
495  ShapedType resultType = result.getType().cast<ShapedType>();
496  auto deviceGroupSize =
497  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
498  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
499  auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
500  auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
501  auto expectedResultDimSize =
502  axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
504  result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
505  return failure();
506  }
507  }
508  return success();
509 }
510 
512  Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
513  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
514  ShapedType operandType = operand.getType().cast<ShapedType>();
515  ShapedType resultType = result.getType().cast<ShapedType>();
516  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
517  if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
519  result.getLoc(), operandType.getDimSize(axis),
520  resultType.getDimSize(axis), axis))) {
521  return failure();
522  }
523  }
524  }
525 
526  if (splitAxis == concatAxis) {
527  return success();
528  }
529 
530  auto deviceGroupSize =
531  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
532  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
533  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
534  DimensionSize expectedResultConcatDimSize =
535  operandConcatDimSize * deviceGroupSize;
536  DimensionSize expectedResultSplitDimSize =
537  operandSplitDimSize / deviceGroupSize;
538  if (!expectedResultSplitDimSize.isDynamic() &&
539  int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
540  expectedResultSplitDimSize = DimensionSize::dynamic();
541  }
543  result.getLoc(), expectedResultConcatDimSize.value(),
544  resultType.getDimSize(concatAxis), concatAxis))) {
545  return failure();
546  }
548  result.getLoc(), expectedResultSplitDimSize.value(),
549  resultType.getDimSize(splitAxis), splitAxis))) {
550  return failure();
551  }
552 
553  return success();
554 }
555 
557  Value operand, Value result, int64_t tensorAxis,
558  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
559  ShapedType operandType = operand.getType().cast<ShapedType>();
560  ShapedType resultType = result.getType().cast<ShapedType>();
561  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
562  if (axis != tensorAxis) {
564  result.getLoc(), operandType.getDimSize(axis),
565  resultType.getDimSize(axis), axis))) {
566  return failure();
567  }
568  }
569  }
570 
571  auto deviceGroupSize =
572  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
573  auto operandScatterDimSize =
574  DimensionSize(operandType.getDimSize(tensorAxis));
575  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
576  int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
577  return emitError(result.getLoc())
578  << "Operand dimension size " << int64_t(operandScatterDimSize)
579  << " is not divisible by collective device group size "
580  << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
581  << ".";
582  }
583  DimensionSize expectedResultTensorDimSize =
584  operandScatterDimSize / deviceGroupSize;
586  result.getLoc(), expectedResultTensorDimSize.value(),
587  resultType.getDimSize(tensorAxis), tensorAxis))) {
588  return failure();
589  }
590 
591  return success();
592 }
593 
594 static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
595  ArrayRef<MeshAxis> meshAxes,
596  int64_t sliceAxis) {
597  RankedTensorType operandRankedTensorType =
598  cast<RankedTensorType>(operandType);
599  DimensionSize operandSliceAxisSize =
600  operandRankedTensorType.getShape()[sliceAxis];
601  SmallVector<int64_t> resultShape =
602  llvm::to_vector(operandRankedTensorType.getShape());
603 
604  resultShape[sliceAxis] =
605  operandSliceAxisSize /
606  DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
607  return operandRankedTensorType.clone(resultShape);
608 }
609 
610 //===----------------------------------------------------------------------===//
611 // mesh.all_gather op
612 //===----------------------------------------------------------------------===//
613 
615 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
616  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
617  if (failed(mesh)) {
618  return failure();
619  }
620  auto gatherAxis = getGatherAxis().getSExtValue();
621  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
622  gatherAxis, getMeshAxes(),
623  mesh.value().getShape());
624 }
625 
626 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
627  MLIRContext *context) {
628  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
629 }
630 
631 void AllGatherOp::getAsmResultNames(
632  function_ref<void(Value, StringRef)> setNameFn) {
633  setNameFn(getResult(), "all_gather");
634 }
635 
636 //===----------------------------------------------------------------------===//
637 // mesh.all_reduce op
638 //===----------------------------------------------------------------------===//
639 
641 AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
642  return getMeshAndVerifyAxes(*this, symbolTable);
643 }
644 
645 void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
646  MLIRContext *context) {
647  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
648 }
649 
650 void AllReduceOp::getAsmResultNames(
651  function_ref<void(Value, StringRef)> setNameFn) {
652  setNameFn(getResult(), "all_reduce");
653 }
654 
655 //===----------------------------------------------------------------------===//
656 // mesh.all_slice op
657 //===----------------------------------------------------------------------===//
658 
659 LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
660  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
661  if (failed(mesh)) {
662  return failure();
663  }
665  getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
666  mesh.value().getShape());
667 }
668 
669 void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
670  MLIRContext *context) {
671  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
672 }
673 
/// Builds the op for `mesh`, deriving the result type from the input type via
/// `sliceResultType`, then delegates to the explicit-result-type builder.
void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
                       int64_t sliceAxis) {
  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
        sliceAxis);
}
681 
/// Builds the op with an explicit result type. The slice axis is wrapped in
/// an APInt as wide as `sliceAxis` itself (sizeof * CHAR_BIT bits) before
/// being forwarded; note that the forwarded argument order differs from this
/// overload's parameter order.
void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Type resultType, Value input, StringRef mesh,
                       ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
        APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}
688 
689 void AllSliceOp::getAsmResultNames(
690  function_ref<void(Value, StringRef)> setNameFn) {
691  setNameFn(getResult(), "all_slice");
692 }
693 
694 //===----------------------------------------------------------------------===//
695 // mesh.all_to_all op
696 //===----------------------------------------------------------------------===//
697 
698 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
699  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
700  if (failed(mesh)) {
701  return failure();
702  }
703 
705  getOperand(), getResult(), getSplitAxis().getSExtValue(),
706  getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
707 }
708 
709 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
710  MLIRContext *context) {
711  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
712 }
713 
714 void AllToAllOp::getAsmResultNames(
715  function_ref<void(Value, StringRef)> setNameFn) {
716  setNameFn(getResult(), "all_to_all");
717 }
718 
719 //===----------------------------------------------------------------------===//
720 // mesh.broadcast op
721 //===----------------------------------------------------------------------===//
722 
724 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
725  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
726  if (failed(mesh)) {
727  return failure();
728  }
729  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
730  getRootDynamic(), getMeshAxes(),
731  mesh.value().getShape()))) {
732  return failure();
733  }
734 
735  return success();
736 }
737 
738 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
739  MLIRContext *context) {
740  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
741 }
742 
743 void BroadcastOp::getAsmResultNames(
744  function_ref<void(Value, StringRef)> setNameFn) {
745  setNameFn(getResult(), "broadcast");
746 }
747 
748 //===----------------------------------------------------------------------===//
749 // mesh.gather op
750 //===----------------------------------------------------------------------===//
751 
752 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
753  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
754  if (failed(mesh)) {
755  return failure();
756  }
757  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
758  getRootDynamic(), getMeshAxes(),
759  mesh.value().getShape()))) {
760  return failure();
761  }
762 
763  auto gatherAxis = getGatherAxis().getSExtValue();
764  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
765  getMeshAxes(),
766  mesh.value().getShape());
767 }
768 
769 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
770  MLIRContext *context) {
771  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
772 }
773 
774 void GatherOp::getAsmResultNames(
775  function_ref<void(Value, StringRef)> setNameFn) {
776  setNameFn(getResult(), "gather");
777 }
778 
779 //===----------------------------------------------------------------------===//
780 // mesh.recv op
781 //===----------------------------------------------------------------------===//
782 
783 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
784  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
785  if (failed(mesh)) {
786  return failure();
787  }
788  if (getSource() &&
789  failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
790  getSource().value(), getSourceDynamic(),
791  getMeshAxes(), mesh.value().getShape()))) {
792  return failure();
793  }
794  return success();
795 }
796 
797 void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
798  MLIRContext *context) {
799  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
800 }
801 
802 void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
803  setNameFn(getResult(), "recv");
804 }
805 
806 //===----------------------------------------------------------------------===//
807 // mesh.reduce op
808 //===----------------------------------------------------------------------===//
809 
810 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
811  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
812  if (failed(mesh)) {
813  return failure();
814  }
815  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
816  getRootDynamic(), getMeshAxes(),
817  mesh.value().getShape()))) {
818  return failure();
819  }
820 
821  return success();
822 }
823 
824 void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
825  MLIRContext *context) {
826  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
827 }
828 
829 void ReduceOp::getAsmResultNames(
830  function_ref<void(Value, StringRef)> setNameFn) {
831  setNameFn(getResult(), "reduce");
832 }
833 
834 //===----------------------------------------------------------------------===//
835 // mesh.reduce_scatter op
836 //===----------------------------------------------------------------------===//
837 
839 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
840  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
841  if (failed(mesh)) {
842  return failure();
843  }
844 
846  getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
847  mesh.value().getShape());
848 }
849 
850 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
851  MLIRContext *context) {
852  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
853 }
854 
855 void ReduceScatterOp::getAsmResultNames(
856  function_ref<void(Value, StringRef)> setNameFn) {
857  setNameFn(getResult(), "reduce_scatter");
858 }
859 
860 //===----------------------------------------------------------------------===//
861 // mesh.scatter op
862 //===----------------------------------------------------------------------===//
863 
864 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
865  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
866  if (failed(mesh)) {
867  return failure();
868  }
869  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
870  getRootDynamic(), getMeshAxes(),
871  mesh.value().getShape()))) {
872  return failure();
873  }
874 
875  auto scatterAxis = getScatterAxis().getSExtValue();
876  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
877  scatterAxis, getMeshAxes(),
878  mesh.value().getShape());
879 }
880 
881 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
882  MLIRContext *context) {
883  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
884 }
885 
886 void ScatterOp::getAsmResultNames(
887  function_ref<void(Value, StringRef)> setNameFn) {
888  setNameFn(getResult(), "scatter");
889 }
890 
891 //===----------------------------------------------------------------------===//
892 // mesh.send op
893 //===----------------------------------------------------------------------===//
894 
895 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
896  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
897  if (failed(mesh)) {
898  return failure();
899  }
900  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
901  getDestination(), getDestinationDynamic(),
902  getMeshAxes(), mesh.value().getShape()))) {
903  return failure();
904  }
905  return success();
906 }
907 
908 void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
909  MLIRContext *context) {
910  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
911 }
912 
913 void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
914  setNameFn(getResult(), "send");
915 }
916 
917 //===----------------------------------------------------------------------===//
918 // mesh.shift op
919 //===----------------------------------------------------------------------===//
920 
921 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
922  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
923  if (failed(mesh)) {
924  return failure();
925  }
926 
927  auto meshAxes = getMeshAxes();
928  auto shiftAxis = getShiftAxis().getZExtValue();
929  if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
930  return emitError() << "Invalid shift axis " << shiftAxis
931  << ". It must be one of the grouping mesh axes.";
932  }
933 
934  return success();
935 }
936 
/// Currently registers no canonicalization patterns.
void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                          MLIRContext *context) {
  // TODO: remove the op when the offset is 0, or when it is a rotate with an
  // offset for which offset % shift_axis_mesh_dim_size == 0.
}
942 
943 void ShiftOp::getAsmResultNames(
944  function_ref<void(Value, StringRef)> setNameFn) {
945  setNameFn(getResult(), "shift");
946 }
947 
948 //===----------------------------------------------------------------------===//
949 // TableGen'd op method definitions
950 //===----------------------------------------------------------------------===//
951 
952 #define GET_OP_CLASSES
953 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
954 
955 #define GET_ATTRDEF_CLASSES
956 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
957 
958 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
Definition: MeshOps.cpp:594
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition: MeshOps.cpp:61
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
Definition: MeshOps.cpp:467
static FailureOr< MeshOp > getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:443
static FailureOr< MeshOp > getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:99
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:556
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:484
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:511
static auto product(It begin, It end)
Definition: MeshOps.cpp:456
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape)
Definition: MeshOps.cpp:151
bool isUnique(It begin, It end)
Definition: MeshOps.cpp:112
static LogicalResult verifyMeshAxes(Location loc, ArrayRef< MeshAxis > axes, MeshOp mesh)
Definition: MeshOps.cpp:128
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:415
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:179
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:748
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:809
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:640
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
U cast() const
Definition: Types.h:339
U dyn_cast() const
Definition: Types.h:329
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
Definition: MeshOps.h:79
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:57
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:101
int16_t MeshAxis
Definition: MeshOps.h:25
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:162
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:171
bool operator==(const Fraction &x, const Fraction &y)
Definition: Fraction.h:88
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:266
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:357
This represents an operation in an abstracted form, suitable for use with the builder APIs.