MLIR  19.0.0git
MeshOps.cpp
Go to the documentation of this file.
1 //===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/Support/LLVM.h"
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallSet.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include <algorithm>
32 #include <functional>
33 #include <iterator>
34 #include <numeric>
35 #include <optional>
36 #include <utility>
37 
38 #define DEBUG_TYPE "mesh-ops"
39 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
40 
41 using namespace mlir;
42 using namespace mlir::mesh;
43 
44 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
45 
46 namespace {
47 
48 struct DimensionSize {
49  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
50  DimensionSize(int64_t val) : val(val) {}
51  int64_t value() const { return val; }
52  operator int64_t() const { return val; }
53  bool isDynamic() const { return ShapedType::isDynamic(val); }
54 
55 private:
56  int64_t val;
57 };
58 
59 } // namespace
60 
61 static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
62  if (lhs.isDynamic() || rhs.isDynamic()) {
63  return DimensionSize::dynamic();
64  }
65  return lhs.value() / rhs.value();
66 }
67 
68 static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
69  if (lhs.isDynamic() || rhs.isDynamic()) {
70  return DimensionSize::dynamic();
71  }
72  return lhs.value() * rhs.value();
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // Mesh dialect
77 //===----------------------------------------------------------------------===//
78 
79 void MeshDialect::initialize() {
80  addOperations<
81 #define GET_OP_LIST
82 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
83  >();
84  addAttributes<
85 #define GET_ATTRDEF_LIST
86 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
87  >();
88 }
89 
91  Type type, Location loc) {
92  return arith::ConstantOp::materialize(builder, value, type, loc);
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // Mesh utilities
97 //===----------------------------------------------------------------------===//
98 
100  FlatSymbolRefAttr meshSymbol,
101  SymbolTableCollection &symbolTable) {
102  mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
103  if (!mesh) {
104  return op->emitError() << "Undefined required mesh symbol \""
105  << meshSymbol.getValue() << "\".";
106  }
107 
108  return mesh;
109 }
110 
/// Returns true when no two *adjacent* elements of [begin, end) compare
/// equal. Empty and single-element ranges are trivially unique. On a sorted
/// range (as verifyMeshAxes provides) this detects any duplicate.
template <typename It>
bool isUnique(It begin, It end) {
  return std::adjacent_find(begin, end) == end;
}
127 
129  MeshOp mesh) {
130  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
131  llvm::sort(sorted);
132  if (!isUnique(sorted.begin(), sorted.end())) {
133  return emitError(loc) << "Mesh axes contains duplicate elements.";
134  }
135 
136  MeshAxis rank = mesh.getRank();
137  for (auto axis : axes) {
138  if (axis >= rank || axis < 0) {
139  return emitError(loc)
140  << "0-based mesh axis index " << axis
141  << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
142  << "\" is of rank " << rank << ".";
143  }
144  }
145 
146  return success();
147 }
148 
149 template <typename InShape, typename MeshShape, typename SplitAxes,
150  typename OutShape>
151 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
152  const SplitAxes &splitAxes, OutShape &outShape) {
153  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
154  llvm::adl_begin(outShape));
155  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
156  outShape[tensorAxis] = shardDimension(
157  inShape[tensorAxis],
158  collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
159  }
160 }
161 
162 ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
163  MeshShardingAttr sharding) {
164  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
165  SmallVector<Dim> resShapeArr(shape.getShape().size());
166  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
167  resShapeArr);
168  return shape.clone(resShapeArr);
169 }
170 
172  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
173  if (rankedTensorType) {
174  return shardShapedType(rankedTensorType, mesh, sharding);
175  }
176 
177  assert(!sharding);
178  return type;
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // mesh.mesh op
183 //===----------------------------------------------------------------------===//
184 
186  int64_t rank = getRank();
187 
188  if (rank <= 0)
189  return emitOpError("rank of mesh is expected to be a positive integer");
190 
191  for (int64_t dimSize : getShape()) {
192  if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
193  return emitOpError("dimension size of a mesh is expected to be "
194  "non-negative or dynamic");
195  }
196 
197  return success();
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // mesh.mesh_shape op
202 //===----------------------------------------------------------------------===//
203 
205 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
206  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
207  if (failed(mesh)) {
208  return failure();
209  }
210  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
211  return failure();
212  }
213 
214  size_t expectedResultsCount =
215  getAxes().empty() ? mesh->getRank() : getAxes().size();
216  if (getResult().size() != expectedResultsCount) {
217  return emitError() << "Unexpected number of results " << getResult().size()
218  << ". Expected " << expectedResultsCount << ".";
219  }
220 
221  return success();
222 }
223 
224 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
225  MeshOp mesh) {
226  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
227 }
228 
229 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
230  MeshOp mesh, ArrayRef<MeshAxis> axes) {
231  build(odsBuilder, odsState,
232  SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
233  odsBuilder.getIndexType()),
234  mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
235 }
236 
237 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
238  StringRef mesh, ArrayRef<MeshAxis> axes) {
239  assert(!axes.empty());
240  build(odsBuilder, odsState,
241  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
242  MeshAxesAttr::get(odsBuilder.getContext(), axes));
243 }
244 
245 void MeshShapeOp::getAsmResultNames(
246  function_ref<void(Value, StringRef)> setNameFn) {
247  setNameFn(getResults()[0], "mesh_shape");
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // mesh.shard attr
252 //===----------------------------------------------------------------------===//
253 
257  ArrayRef<MeshAxis> partialAxes, ReductionKind) {
258  // TODO: At present mesh symbol ref is not verified. This is due to the
259  // difficulty in fetching the corresponding symbol op based on an attribute.
260 
261  llvm::SmallSet<MeshAxis, 4> visitedAxes;
262 
263  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
264  for (MeshAxis axis : axesArray) {
265  if (axis < 0)
266  return emitError() << "mesh axis is expected to be non-negative";
267  if (!visitedAxes.insert(axis).second)
268  return emitError() << "mesh axis duplicated";
269  }
270  return success();
271  };
272 
273  for (MeshAxesAttr subAxes : splitAxes) {
274  ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
275  if (failed(checkMeshAxis(subAxesArray)))
276  return failure();
277  }
278  if (failed(checkMeshAxis(partialAxes)))
279  return failure();
280  return success();
281 }
282 
283 bool MeshShardingAttr::operator==(Attribute rhs) const {
284  MeshShardingAttr rhsAsMeshShardingAttr =
285  mlir::dyn_cast<MeshShardingAttr>(rhs);
286  return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
287 }
288 
290  if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
291  return false;
292  }
293 
294  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
295  return false;
296  }
297 
298  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
299  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
300  getSplitAxes().begin() + minSize),
301  llvm::make_range(rhs.getSplitAxes().begin(),
302  rhs.getSplitAxes().begin() + minSize))) {
303  return false;
304  }
305 
306  return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
307  getSplitAxes().end()),
308  std::mem_fn(&MeshAxesAttr::empty)) &&
309  llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
310  rhs.getSplitAxes().end()),
311  std::mem_fn(&MeshAxesAttr::empty));
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // mesh.shard op
316 //===----------------------------------------------------------------------===//
317 
318 void ShardOp::getAsmResultNames(
319  function_ref<void(Value, StringRef)> setNameFn) {
320  setNameFn(getResult(), "sharding_annotated");
321 }
322 
323 //===----------------------------------------------------------------------===//
324 // mesh.process_multi_index op
325 //===----------------------------------------------------------------------===//
326 
328 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
329  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
330  if (failed(mesh)) {
331  return failure();
332  }
333  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
334  return failure();
335  }
336 
337  size_t expectedResultsCount =
338  getAxes().empty() ? mesh->getRank() : getAxes().size();
339  if (getResult().size() != expectedResultsCount) {
340  return emitError() << "Unexpected number of results " << getResult().size()
341  << ". Expected " << expectedResultsCount << ".";
342  }
343 
344  return success();
345 }
346 
347 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
348  MeshOp mesh) {
349  build(odsBuilder, odsState,
350  SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
351  mesh.getSymName(), ArrayRef<MeshAxis>());
352 }
353 
354 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
355  StringRef mesh, ArrayRef<MeshAxis> axes) {
356  build(odsBuilder, odsState,
357  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
358  MeshAxesAttr::get(odsBuilder.getContext(), axes));
359 }
360 
361 void ProcessMultiIndexOp::getAsmResultNames(
362  function_ref<void(Value, StringRef)> setNameFn) {
363  setNameFn(getResults()[0], "proc_linear_idx");
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // mesh.process_linear_index op
368 //===----------------------------------------------------------------------===//
369 
371 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
372  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
373  if (failed(mesh)) {
374  return failure();
375  }
376  return success();
377 }
378 
379 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
380  OperationState &odsState, MeshOp mesh) {
381  build(odsBuilder, odsState, mesh.getSymName());
382 }
383 
384 void ProcessLinearIndexOp::getAsmResultNames(
385  function_ref<void(Value, StringRef)> setNameFn) {
386  setNameFn(getResult(), "proc_linear_idx");
387 }
388 
389 //===----------------------------------------------------------------------===//
390 // collective communication ops
391 //===----------------------------------------------------------------------===//
392 
393 namespace {
394 
395 template <typename Op>
396 struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
398  LogicalResult matchAndRewrite(Op op,
399  PatternRewriter &rewriter) const override {
400  auto meshAxes = op.getMeshAxes();
401  if (!meshAxes.empty()) {
402  return failure();
403  }
404  if (op.getInput().getType() != op.getResult().getType()) {
405  return failure();
406  }
407 
408  rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
409  rewriter.eraseOp(op.getOperation());
410  return success();
411  }
412 };
413 
414 } // namespace
415 
416 static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
417  ArrayRef<int64_t> device,
418  Operation::operand_range deviceDynamic,
419  ArrayRef<MeshAxis> meshAxes,
420  ArrayRef<int64_t> meshShape) {
421  if (device.size() != meshAxes.size()) {
422  return emitError(loc) << "In-group device \"" << deviceName
423  << "\" has unexpected multi-index size "
424  << device.size() << ". Expected " << meshAxes.size()
425  << ".";
426  }
427 
428  for (size_t i = 0; i < device.size(); ++i) {
429  if (!ShapedType::isDynamic(device[i]) &&
430  !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
431  meshShape[meshAxes[i]] <= device[i]) {
432  return emitError(loc)
433  << "Out of bounds coordinate " << i << " for in-group device \""
434  << deviceName << "\"."
435  << " Got " << device[i] << ", but expected value in the range [0, "
436  << (meshShape[meshAxes[i]] - 1) << "].";
437  }
438  }
439  return success();
440 }
441 
442 template <typename Op>
443 static FailureOr<MeshOp>
445  auto mesh =
446  ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
447  if (failed(mesh)) {
448  return failure();
449  }
450  if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
451  return failure();
452  }
453  return mesh;
454 }
455 
/// Returns the product of the elements in [begin, end), computed in the
/// decayed element type. An empty range yields 1.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  ElementType result = static_cast<ElementType>(1);
  for (; begin != end; ++begin)
    result = result * *begin;
  return result;
}

/// Range convenience overload of product(It, It).
template <typename R>
static auto product(R &&range) {
  auto first = adl_begin(range);
  auto last = adl_end(range);
  return product(first, last);
}
467 
469  int64_t expectedDimSize,
470  int64_t resultDimSize,
471  int64_t resultAxis) {
472  if (!ShapedType::isDynamic(resultDimSize) &&
473  expectedDimSize != resultDimSize) {
474  return emitError(loc) << "Dimension size mismatch for result axis "
475  << resultAxis << ". Expected "
476  << (ShapedType::isDynamic(expectedDimSize)
477  ? Twine("dynamic")
478  : Twine(expectedDimSize))
479  << ", but got " << resultDimSize << ".";
480  }
481 
482  return success();
483 }
484 
486  Value operand, Value result, int64_t gatherAxis,
487  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
488  auto resultRank = cast<ShapedType>(result.getType()).getRank();
489  if (gatherAxis < 0 || gatherAxis >= resultRank) {
490  return emitError(result.getLoc())
491  << "Gather axis " << gatherAxis << " is out of bounds [0, "
492  << resultRank << ").";
493  }
494 
495  ShapedType operandType = cast<ShapedType>(operand.getType());
496  ShapedType resultType = cast<ShapedType>(result.getType());
497  auto deviceGroupSize =
498  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
499  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
500  auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
501  auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
502  auto expectedResultDimSize =
503  axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
505  result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
506  return failure();
507  }
508  }
509  return success();
510 }
511 
513  Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
514  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
515  ShapedType operandType = cast<ShapedType>(operand.getType());
516  ShapedType resultType = cast<ShapedType>(result.getType());
517  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
518  if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
520  result.getLoc(), operandType.getDimSize(axis),
521  resultType.getDimSize(axis), axis))) {
522  return failure();
523  }
524  }
525  }
526 
527  if (splitAxis == concatAxis) {
528  return success();
529  }
530 
531  auto deviceGroupSize =
532  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
533  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
534  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
535  DimensionSize expectedResultConcatDimSize =
536  operandConcatDimSize * deviceGroupSize;
537  DimensionSize expectedResultSplitDimSize =
538  operandSplitDimSize / deviceGroupSize;
539  if (!expectedResultSplitDimSize.isDynamic() &&
540  int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
541  expectedResultSplitDimSize = DimensionSize::dynamic();
542  }
544  result.getLoc(), expectedResultConcatDimSize.value(),
545  resultType.getDimSize(concatAxis), concatAxis))) {
546  return failure();
547  }
549  result.getLoc(), expectedResultSplitDimSize.value(),
550  resultType.getDimSize(splitAxis), splitAxis))) {
551  return failure();
552  }
553 
554  return success();
555 }
556 
558  Value operand, Value result, int64_t tensorAxis,
559  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
560  ShapedType operandType = cast<ShapedType>(operand.getType());
561  ShapedType resultType = cast<ShapedType>(result.getType());
562  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
563  if (axis != tensorAxis) {
565  result.getLoc(), operandType.getDimSize(axis),
566  resultType.getDimSize(axis), axis))) {
567  return failure();
568  }
569  }
570  }
571 
572  auto deviceGroupSize =
573  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
574  auto operandScatterDimSize =
575  DimensionSize(operandType.getDimSize(tensorAxis));
576  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
577  int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
578  return emitError(result.getLoc())
579  << "Operand dimension size " << int64_t(operandScatterDimSize)
580  << " is not divisible by collective device group size "
581  << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
582  << ".";
583  }
584  DimensionSize expectedResultTensorDimSize =
585  operandScatterDimSize / deviceGroupSize;
587  result.getLoc(), expectedResultTensorDimSize.value(),
588  resultType.getDimSize(tensorAxis), tensorAxis))) {
589  return failure();
590  }
591 
592  return success();
593 }
594 
595 static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
596  ArrayRef<MeshAxis> meshAxes,
597  int64_t sliceAxis) {
598  RankedTensorType operandRankedTensorType =
599  cast<RankedTensorType>(operandType);
600  DimensionSize operandSliceAxisSize =
601  operandRankedTensorType.getShape()[sliceAxis];
602  SmallVector<int64_t> resultShape =
603  llvm::to_vector(operandRankedTensorType.getShape());
604 
605  resultShape[sliceAxis] =
606  operandSliceAxisSize /
607  DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
608  return operandRankedTensorType.clone(resultShape);
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // mesh.all_gather op
613 //===----------------------------------------------------------------------===//
614 
616 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
617  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
618  if (failed(mesh)) {
619  return failure();
620  }
621  auto gatherAxis = getGatherAxis().getSExtValue();
622  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
623  gatherAxis, getMeshAxes(),
624  mesh.value().getShape());
625 }
626 
627 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
628  MLIRContext *context) {
629  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
630 }
631 
632 void AllGatherOp::getAsmResultNames(
633  function_ref<void(Value, StringRef)> setNameFn) {
634  setNameFn(getResult(), "all_gather");
635 }
636 
637 //===----------------------------------------------------------------------===//
638 // mesh.all_reduce op
639 //===----------------------------------------------------------------------===//
640 
642 AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
643  return getMeshAndVerifyAxes(*this, symbolTable);
644 }
645 
646 void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
647  MLIRContext *context) {
648  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
649 }
650 
651 void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
652  Value input, StringRef mesh,
653  ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
654  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
655  reduction);
656 }
657 
658 void AllReduceOp::getAsmResultNames(
659  function_ref<void(Value, StringRef)> setNameFn) {
660  setNameFn(getResult(), "all_reduce");
661 }
662 
663 //===----------------------------------------------------------------------===//
664 // mesh.all_slice op
665 //===----------------------------------------------------------------------===//
666 
667 LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
668  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
669  if (failed(mesh)) {
670  return failure();
671  }
673  getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
674  mesh.value().getShape());
675 }
676 
677 void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
678  MLIRContext *context) {
679  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
680 }
681 
682 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
683  Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
684  int64_t sliceAxis) {
685  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
686  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
687  sliceAxis);
688 }
689 
690 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
691  Type resultType, Value input, StringRef mesh,
692  ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
693  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
694  APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
695 }
696 
697 void AllSliceOp::getAsmResultNames(
698  function_ref<void(Value, StringRef)> setNameFn) {
699  setNameFn(getResult(), "all_slice");
700 }
701 
702 //===----------------------------------------------------------------------===//
703 // mesh.all_to_all op
704 //===----------------------------------------------------------------------===//
705 
706 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
707  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
708  if (failed(mesh)) {
709  return failure();
710  }
711 
713  getOperand(), getResult(), getSplitAxis().getSExtValue(),
714  getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
715 }
716 
717 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
718  MLIRContext *context) {
719  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
720 }
721 
722 void AllToAllOp::getAsmResultNames(
723  function_ref<void(Value, StringRef)> setNameFn) {
724  setNameFn(getResult(), "all_to_all");
725 }
726 
727 //===----------------------------------------------------------------------===//
728 // mesh.broadcast op
729 //===----------------------------------------------------------------------===//
730 
732 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
733  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
734  if (failed(mesh)) {
735  return failure();
736  }
737  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
738  getRootDynamic(), getMeshAxes(),
739  mesh.value().getShape()))) {
740  return failure();
741  }
742 
743  return success();
744 }
745 
746 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
747  MLIRContext *context) {
748  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
749 }
750 
751 void BroadcastOp::getAsmResultNames(
752  function_ref<void(Value, StringRef)> setNameFn) {
753  setNameFn(getResult(), "broadcast");
754 }
755 
756 //===----------------------------------------------------------------------===//
757 // mesh.gather op
758 //===----------------------------------------------------------------------===//
759 
760 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
761  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
762  if (failed(mesh)) {
763  return failure();
764  }
765  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
766  getRootDynamic(), getMeshAxes(),
767  mesh.value().getShape()))) {
768  return failure();
769  }
770 
771  auto gatherAxis = getGatherAxis().getSExtValue();
772  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
773  getMeshAxes(),
774  mesh.value().getShape());
775 }
776 
777 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
778  MLIRContext *context) {
779  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
780 }
781 
782 void GatherOp::getAsmResultNames(
783  function_ref<void(Value, StringRef)> setNameFn) {
784  setNameFn(getResult(), "gather");
785 }
786 
787 //===----------------------------------------------------------------------===//
788 // mesh.recv op
789 //===----------------------------------------------------------------------===//
790 
791 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
792  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
793  if (failed(mesh)) {
794  return failure();
795  }
796  if (getSource() &&
797  failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
798  getSource().value(), getSourceDynamic(),
799  getMeshAxes(), mesh.value().getShape()))) {
800  return failure();
801  }
802  return success();
803 }
804 
805 void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
806  MLIRContext *context) {
807  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
808 }
809 
810 void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
811  setNameFn(getResult(), "recv");
812 }
813 
814 //===----------------------------------------------------------------------===//
815 // mesh.reduce op
816 //===----------------------------------------------------------------------===//
817 
818 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
819  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
820  if (failed(mesh)) {
821  return failure();
822  }
823  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
824  getRootDynamic(), getMeshAxes(),
825  mesh.value().getShape()))) {
826  return failure();
827  }
828 
829  return success();
830 }
831 
832 void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
833  MLIRContext *context) {
834  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
835 }
836 
837 void ReduceOp::getAsmResultNames(
838  function_ref<void(Value, StringRef)> setNameFn) {
839  setNameFn(getResult(), "reduce");
840 }
841 
842 //===----------------------------------------------------------------------===//
843 // mesh.reduce_scatter op
844 //===----------------------------------------------------------------------===//
845 
847 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
848  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
849  if (failed(mesh)) {
850  return failure();
851  }
852 
854  getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
855  mesh.value().getShape());
856 }
857 
858 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
859  MLIRContext *context) {
860  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
861 }
862 
863 void ReduceScatterOp::getAsmResultNames(
864  function_ref<void(Value, StringRef)> setNameFn) {
865  setNameFn(getResult(), "reduce_scatter");
866 }
867 
868 //===----------------------------------------------------------------------===//
869 // mesh.scatter op
870 //===----------------------------------------------------------------------===//
871 
872 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
873  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
874  if (failed(mesh)) {
875  return failure();
876  }
877  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
878  getRootDynamic(), getMeshAxes(),
879  mesh.value().getShape()))) {
880  return failure();
881  }
882 
883  auto scatterAxis = getScatterAxis().getSExtValue();
884  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
885  scatterAxis, getMeshAxes(),
886  mesh.value().getShape());
887 }
888 
889 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
890  MLIRContext *context) {
891  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
892 }
893 
894 void ScatterOp::getAsmResultNames(
895  function_ref<void(Value, StringRef)> setNameFn) {
896  setNameFn(getResult(), "scatter");
897 }
898 
899 //===----------------------------------------------------------------------===//
900 // mesh.send op
901 //===----------------------------------------------------------------------===//
902 
903 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
904  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
905  if (failed(mesh)) {
906  return failure();
907  }
908  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
909  getDestination(), getDestinationDynamic(),
910  getMeshAxes(), mesh.value().getShape()))) {
911  return failure();
912  }
913  return success();
914 }
915 
916 void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
917  MLIRContext *context) {
918  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
919 }
920 
921 void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
922  setNameFn(getResult(), "send");
923 }
924 
925 //===----------------------------------------------------------------------===//
926 // mesh.shift op
927 //===----------------------------------------------------------------------===//
928 
929 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
930  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
931  if (failed(mesh)) {
932  return failure();
933  }
934 
935  auto meshAxes = getMeshAxes();
936  auto shiftAxis = getShiftAxis().getZExtValue();
937  if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
938  return emitError() << "Invalid shift axis " << shiftAxis
939  << ". It must be one of the grouping mesh axes.";
940  }
941 
942  return success();
943 }
944 
945 void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
946  MLIRContext *context) {
947  // TODO: remove op when offset is 0 or if it is a rotate with and
948  // offset % shift_axis_mesh_dim_size == 0.
949 }
950 
951 void ShiftOp::getAsmResultNames(
952  function_ref<void(Value, StringRef)> setNameFn) {
953  setNameFn(getResult(), "shift");
954 }
955 
956 //===----------------------------------------------------------------------===//
957 // TableGen'd op method definitions
958 //===----------------------------------------------------------------------===//
959 
960 #define GET_OP_CLASSES
961 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
962 
963 #define GET_ATTRDEF_CLASSES
964 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
965 
966 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
Definition: MeshOps.cpp:595
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition: MeshOps.cpp:61
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
Definition: MeshOps.cpp:468
static FailureOr< MeshOp > getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:444
static FailureOr< MeshOp > getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:99
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:557
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:485
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:512
static auto product(It begin, It end)
Definition: MeshOps.cpp:457
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape)
Definition: MeshOps.cpp:151
bool isUnique(It begin, It end)
Definition: MeshOps.cpp:112
static LogicalResult verifyMeshAxes(Location loc, ArrayRef< MeshAxis > axes, MeshOp mesh)
Definition: MeshOps.cpp:128
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:416
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
mesh::ReductionKind ReductionKind
mesh::MeshShardingAttr MeshShardingAttr
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
Definition: MeshOps.h:79
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:57
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:101
int16_t MeshAxis
Definition: MeshOps.h:25
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:162
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:171
bool operator==(const Fraction &x, const Fraction &y)
Definition: Fraction.h:88
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:266
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.