MLIR  20.0.0git
MeshOps.cpp
Go to the documentation of this file.
1 //===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
25 #include "mlir/Support/LLVM.h"
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallSet.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Casting.h"
32 #include <algorithm>
33 #include <functional>
34 #include <iterator>
35 #include <numeric>
36 #include <optional>
37 #include <utility>
38 
39 #define DEBUG_TYPE "mesh-ops"
40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
41 
42 using namespace mlir;
43 using namespace mlir::mesh;
44 
45 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
46 
47 namespace {
48 
49 struct DimensionSize {
50  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
51  DimensionSize(int64_t val) : val(val) {}
52  int64_t value() const { return val; }
53  operator int64_t() const { return val; }
54  bool isDynamic() const { return ShapedType::isDynamic(val); }
55 
56 private:
57  int64_t val;
58 };
59 
60 } // namespace
61 
62 static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
63  if (lhs.isDynamic() || rhs.isDynamic()) {
64  return DimensionSize::dynamic();
65  }
66  return lhs.value() / rhs.value();
67 }
68 
69 static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
70  if (lhs.isDynamic() || rhs.isDynamic()) {
71  return DimensionSize::dynamic();
72  }
73  return lhs.value() * rhs.value();
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // Mesh dialect
78 //===----------------------------------------------------------------------===//
79 
// Registers the dialect's operations and attributes. The template argument
// lists are produced by including the generated .inc files with the matching
// GET_*_LIST macro defined immediately before each include.
80 void MeshDialect::initialize() {
81  addOperations<
// Selects the operation list section of MeshOps.cpp.inc.
82 #define GET_OP_LIST
83 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
84  >();
85  addAttributes<
// Selects the attribute definition list section of MeshAttributes.cpp.inc.
86 #define GET_ATTRDEF_LIST
87 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
88  >();
89 }
90 
// Forwards constant materialization to arith::ConstantOp::materialize using
// the supplied builder, value, type and location.
// NOTE(review): the line holding the function name and its leading parameters
// (listing line 91) is absent from this extract.
92  Type type, Location loc) {
93  return arith::ConstantOp::materialize(builder, value, type, loc);
94 }
95 
96 //===----------------------------------------------------------------------===//
97 // Mesh utilities
98 //===----------------------------------------------------------------------===//
99 
100 static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
101  FlatSymbolRefAttr meshSymbol,
102  SymbolTableCollection &symbolTable) {
103  mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
104  if (!mesh) {
105  return op->emitError() << "Undefined required mesh symbol \""
106  << meshSymbol.getValue() << "\".";
107  }
108 
109  return mesh;
110 }
111 
// Returns true when no two *adjacent* elements of [begin, end) compare equal.
//
// Only neighbouring elements are compared, so the range has to be sorted (or
// otherwise keep equal elements next to each other) for the result to mean
// "every element is distinct". Empty and single-element ranges are unique.
// `It` must be at least a forward iterator.
template <typename It>
bool isUnique(It begin, It end) {
  // std::adjacent_find yields `end` exactly when there is no pair of equal
  // neighbours, which also covers the empty and one-element cases.
  return std::adjacent_find(begin, end) == end;
}
128 
129 static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
130  MeshOp mesh) {
131  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
132  llvm::sort(sorted);
133  if (!isUnique(sorted.begin(), sorted.end())) {
134  return emitError(loc) << "Mesh axes contains duplicate elements.";
135  }
136 
137  MeshAxis rank = mesh.getRank();
138  for (auto axis : axes) {
139  if (axis >= rank || axis < 0) {
140  return emitError(loc)
141  << "0-based mesh axis index " << axis
142  << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
143  << "\" is of rank " << rank << ".";
144  }
145  }
146 
147  return success();
148 }
149 
150 template <typename InShape, typename MeshShape, typename SplitAxes,
151  typename OutShape>
152 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
153  const SplitAxes &splitAxes, OutShape &outShape) {
154  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
155  llvm::adl_begin(outShape));
156  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
157  outShape[tensorAxis] = shardDimension(
158  inShape[tensorAxis],
159  collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
160  }
161 }
162 
163 ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
164  MeshShardingAttr sharding) {
165  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
166  SmallVector<Dim> resShapeArr(shape.getShape().size());
167  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
168  resShapeArr);
169  return shape.clone(resShapeArr);
170 }
171 
// Shards `type` when it is a RankedTensorType by delegating to
// shardShapedType(); any other type is returned unchanged.
// NOTE(review): the signature line (listing line 172) is absent from this
// extract.
173  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
174  if (rankedTensorType) {
175  return shardShapedType(rankedTensorType, mesh, sharding);
176  }
177 
// Types other than ranked tensors are expected to come with a null sharding.
178  assert(!sharding);
179  return type;
180 }
181 
// Makes sure the value consumed through `operand` passes through a ShardOp
// carrying `sharding` with annotate_for_users == false. Nothing is inserted
// when the operand's owner already is such a ShardOp.
// NOTE(review): the line with the function name and the `sharding` parameter
// (listing line 182) is absent from this extract.
183  OpOperand &operand,
184  OpBuilder &builder) {
// Restores the builder's insertion point when this function returns.
185  OpBuilder::InsertionGuard insertionGuard(builder);
186  Value operandValue = operand.get();
187  Operation *operandOp = operand.getOwner();
188  builder.setInsertionPointAfterValue(operandValue);
// Note: this inspects the *owner* of the use, not the value's defining op.
189  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
190  if (shardOp && shardOp.getShard() == sharding &&
191  !shardOp.getAnnotateForUsers()) {
192  // Nothing to do: the correct sharding is already set.
193  return;
194  }
195 
196  auto newShardOp =
197  builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
198  /*annotate_for_users*/ false);
199  IRRewriter rewriter(builder);
// Redirect only the uses of `operandValue` that belong to `operandOp`.
200  rewriter.replaceUsesWithIf(
201  operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
202  return use.getOwner() == operandOp && use.get() == operandValue;
203  });
204 
// Continue only when the owner is a ShardOp that does not annotate for users.
205  if (!shardOp || shardOp.getAnnotateForUsers()) {
206  return;
207  }
208 
// Add a second ShardOp with annotate_for_users == true on top of the new one
// and route every other use of the new annotation through it.
209  auto newShardOp2 = builder.create<ShardOp>(
210  operandValue.getLoc(), newShardOp, sharding, /*annotate_for_users*/ true);
211  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
212 }
213 
// Applies maybeInsertTargetShardingAnnotation() to every use of `result`.
// NOTE(review): the line with the function name and the `sharding` parameter
// (listing line 214) is absent from this extract.
215  OpResult result,
216  OpBuilder &builder) {
// Early-increment iteration: the callee may rewrite the use being visited.
217  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
218  maybeInsertTargetShardingAnnotation(sharding, use, builder);
219  }
220 }
221 
// Makes sure the value consumed through `operand` is produced by a ShardOp
// carrying `sharding` with annotate_for_users == true, inserting one directly
// before the consuming operation when that is not already the case.
// NOTE(review): the line with the function name and the `sharding` parameter
// (listing line 222) is absent from this extract.
223  OpOperand &operand,
224  OpBuilder &builder) {
// Restores the builder's insertion point when this function returns.
225  OpBuilder::InsertionGuard insertionGuard(builder);
226  Value operandValue = operand.get();
227  Operation *operandOp = operand.getOwner();
228  Operation *operandSrcOp = operandValue.getDefiningOp();
// A value without a defining op is treated as a block argument.
229  bool isBlockArg = !operandSrcOp;
230  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
231 
232  if (shardOp && shardOp.getShard() == sharding &&
233  shardOp.getAnnotateForUsers()) {
234  // Nothing to do: the correct sharding is already set.
235  return;
236  }
237 
238  builder.setInsertionPoint(operandOp);
239  auto newShardOp =
240  builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
241  /*annotate_for_users*/ true);
242  IRRewriter rewriter(builder);
// Redirect only the uses of `operandValue` that belong to `operandOp`.
243  rewriter.replaceUsesWithIf(
244  operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
245  return use.getOwner() == operandOp && use.get() == operandValue;
246  });
247 
248  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
249  // No need for resharding.
250  return;
251  }
252 
// The producer is a ShardOp that annotates for users: place an additional
// ShardOp with annotate_for_users == false in front of the new annotation and
// make the new annotation consume it.
253  builder.setInsertionPoint(newShardOp);
254  auto newPreceedingShardOp =
255  builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
256  /*annotate_for_users*/ false);
257  rewriter.replaceUsesWithIf(newShardOp.getOperand(), newPreceedingShardOp,
258  [&newShardOp](OpOperand &use) {
259  return use.getOwner() ==
260  newShardOp.getOperation();
261  });
262 }
263 
264 //===----------------------------------------------------------------------===//
265 // mesh.mesh op
266 //===----------------------------------------------------------------------===//
267 
268 LogicalResult MeshOp::verify() {
269  int64_t rank = getRank();
270 
271  if (rank <= 0)
272  return emitOpError("rank of mesh is expected to be a positive integer");
273 
274  for (int64_t dimSize : getShape()) {
275  if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
276  return emitOpError("dimension size of a mesh is expected to be "
277  "non-negative or dynamic");
278  }
279 
280  return success();
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // mesh.mesh_shape op
285 //===----------------------------------------------------------------------===//
286 
287 LogicalResult
288 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
289  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
290  if (failed(mesh)) {
291  return failure();
292  }
293  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
294  return failure();
295  }
296 
297  size_t expectedResultsCount =
298  getAxes().empty() ? mesh->getRank() : getAxes().size();
299  if (getResult().size() != expectedResultsCount) {
300  return emitError() << "Unexpected number of results " << getResult().size()
301  << ". Expected " << expectedResultsCount << ".";
302  }
303 
304  return success();
305 }
306 
307 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
308  MeshOp mesh) {
309  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
310 }
311 
312 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
313  MeshOp mesh, ArrayRef<MeshAxis> axes) {
314  build(odsBuilder, odsState,
315  SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
316  odsBuilder.getIndexType()),
317  mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
318 }
319 
320 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
321  StringRef mesh, ArrayRef<MeshAxis> axes) {
322  assert(!axes.empty());
323  build(odsBuilder, odsState,
324  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
325  MeshAxesAttr::get(odsBuilder.getContext(), axes));
326 }
327 
328 void MeshShapeOp::getAsmResultNames(
329  function_ref<void(Value, StringRef)> setNameFn) {
330  setNameFn(getResults()[0], "mesh_shape");
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // mesh.shard attr
335 //===----------------------------------------------------------------------===//
336 
// Verifies the axes of a sharding attribute: every mesh axis listed in the
// split axes or in the partial axes must be non-negative and must appear at
// most once across all of these lists.
// NOTE(review): the lines with the function name and its leading parameters
// (listing lines 338-339) are absent from this extract.
337 LogicalResult
340  ArrayRef<MeshAxis> partialAxes, ReductionKind) {
341  // TODO: At present mesh symbol ref is not verified. This is due to the
342  // difficulty in fetching the corresponding symbol op based on an attribute.
343 
// Shared across all checkMeshAxis calls so duplicates are detected between
// different sub-arrays as well as within one.
344  llvm::SmallSet<MeshAxis, 4> visitedAxes;
345 
346  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
347  for (MeshAxis axis : axesArray) {
348  if (axis < 0)
349  return emitError() << "mesh axis is expected to be non-negative";
// insert().second is false when the axis was already recorded.
350  if (!visitedAxes.insert(axis).second)
351  return emitError() << "mesh axis duplicated";
352  }
353  return success();
354  };
355 
356  for (MeshAxesAttr subAxes : splitAxes) {
357  ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
358  if (failed(checkMeshAxis(subAxesArray)))
359  return failure();
360  }
361  if (failed(checkMeshAxis(partialAxes)))
362  return failure();
363  return success();
364 }
365 
366 bool MeshShardingAttr::operator==(Attribute rhs) const {
367  MeshShardingAttr rhsAsMeshShardingAttr =
368  mlir::dyn_cast<MeshShardingAttr>(rhs);
369  return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
370 }
371 
372 bool MeshShardingAttr::operator!=(Attribute rhs) const {
373  return !(*this == rhs);
374 }
375 
// Equality against another sharding (`rhs`). Mesh and partial axes must match;
// the partial type is compared only when partial axes are present. Split axes
// must match on the common prefix, and any extra entries on either side must
// all be empty.
// NOTE(review): the signature line (listing line 376) is absent from this
// extract.
377  if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
378  return false;
379  }
380 
// The partial (reduction) type is only compared when there are partial axes.
381  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
382  return false;
383  }
384 
// Compare the split axes over the length both sides have in common.
385  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
386  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
387  getSplitAxes().begin() + minSize),
388  llvm::make_range(rhs.getSplitAxes().begin(),
389  rhs.getSplitAxes().begin() + minSize))) {
390  return false;
391  }
392 
// Entries beyond the common prefix are tolerated only if they are all empty.
393  return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
394  getSplitAxes().end()),
395  std::mem_fn(&MeshAxesAttr::empty)) &&
396  llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
397  rhs.getSplitAxes().end()),
398  std::mem_fn(&MeshAxesAttr::empty));
399 }
400 
// Inequality defined as the negation of the corresponding equality operator.
// NOTE(review): the signature line (listing line 401) is absent from this
// extract.
402  return !(*this == rhs);
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // mesh.shard op
407 //===----------------------------------------------------------------------===//
408 
409 void ShardOp::getAsmResultNames(
410  function_ref<void(Value, StringRef)> setNameFn) {
411  setNameFn(getResult(), "sharding_annotated");
412 }
413 
414 //===----------------------------------------------------------------------===//
415 // mesh.process_multi_index op
416 //===----------------------------------------------------------------------===//
417 
418 LogicalResult
419 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
420  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
421  if (failed(mesh)) {
422  return failure();
423  }
424  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
425  return failure();
426  }
427 
428  size_t expectedResultsCount =
429  getAxes().empty() ? mesh->getRank() : getAxes().size();
430  if (getResult().size() != expectedResultsCount) {
431  return emitError() << "Unexpected number of results " << getResult().size()
432  << ". Expected " << expectedResultsCount << ".";
433  }
434 
435  return success();
436 }
437 
438 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
439  MeshOp mesh) {
440  build(odsBuilder, odsState,
441  SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
442  mesh.getSymName(), ArrayRef<MeshAxis>());
443 }
444 
445 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
446  StringRef mesh, ArrayRef<MeshAxis> axes) {
447  build(odsBuilder, odsState,
448  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
449  MeshAxesAttr::get(odsBuilder.getContext(), axes));
450 }
451 
452 void ProcessMultiIndexOp::getAsmResultNames(
453  function_ref<void(Value, StringRef)> setNameFn) {
454  setNameFn(getResults()[0], "proc_linear_idx");
455 }
456 
457 //===----------------------------------------------------------------------===//
458 // mesh.process_linear_index op
459 //===----------------------------------------------------------------------===//
460 
461 LogicalResult
462 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
463  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
464  if (failed(mesh)) {
465  return failure();
466  }
467  return success();
468 }
469 
470 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
471  OperationState &odsState, MeshOp mesh) {
472  build(odsBuilder, odsState, mesh.getSymName());
473 }
474 
475 void ProcessLinearIndexOp::getAsmResultNames(
476  function_ref<void(Value, StringRef)> setNameFn) {
477  setNameFn(getResult(), "proc_linear_idx");
478 }
479 
480 //===----------------------------------------------------------------------===//
481 // collective communication ops
482 //===----------------------------------------------------------------------===//
483 
484 namespace {
485 
// Rewrite pattern for collective ops `Op`: when the op lists no mesh axes and
// its input and result types are identical, every use of the result is
// replaced by the input and the op is erased.
// NOTE(review): listing line 488 (between the struct header and
// matchAndRewrite) is absent from this extract.
486 template <typename Op>
487 struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
489  LogicalResult matchAndRewrite(Op op,
490  PatternRewriter &rewriter) const override {
491  auto meshAxes = op.getMeshAxes();
// Only ops that act over no mesh axes are candidates.
492  if (!meshAxes.empty()) {
493  return failure();
494  }
// The result can only stand in for the input when the types agree.
495  if (op.getInput().getType() != op.getResult().getType()) {
496  return failure();
497  }
498 
499  rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
500  rewriter.eraseOp(op.getOperation());
501  return success();
502  }
503 };
504 
505 } // namespace
506 
507 static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
508  ArrayRef<int64_t> device,
509  Operation::operand_range deviceDynamic,
510  ArrayRef<MeshAxis> meshAxes,
511  ArrayRef<int64_t> meshShape) {
512  if (device.size() != meshAxes.size()) {
513  return emitError(loc) << "In-group device \"" << deviceName
514  << "\" has unexpected multi-index size "
515  << device.size() << ". Expected " << meshAxes.size()
516  << ".";
517  }
518 
519  for (size_t i = 0; i < device.size(); ++i) {
520  if (!ShapedType::isDynamic(device[i]) &&
521  !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
522  meshShape[meshAxes[i]] <= device[i]) {
523  return emitError(loc)
524  << "Out of bounds coordinate " << i << " for in-group device \""
525  << deviceName << "\"."
526  << " Got " << device[i] << ", but expected value in the range [0, "
527  << (meshShape[meshAxes[i]] - 1) << "].";
528  }
529  }
530  return success();
531 }
532 
// Resolves the mesh referenced by `op` and verifies the op's mesh axes against
// it; returns the mesh on success and failure otherwise.
// NOTE(review): the line with the function name and parameter list (listing
// line 535) is absent from this extract.
533 template <typename Op>
534 static FailureOr<MeshOp>
536  auto mesh =
537  ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
538  if (failed(mesh)) {
539  return failure();
540  }
541  if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
542  return failure();
543  }
544  return mesh;
545 }
546 
// Multiplies all elements of [begin, end) together, starting from 1 converted
// to the decayed element type. An empty range yields 1.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  ElementType result = static_cast<ElementType>(1);
  for (It cursor = begin; cursor != end; ++cursor)
    result = std::multiplies<ElementType>()(result, *cursor);
  return result;
}
553 
// Range overload: multiplies all elements of `range` via the iterator-pair
// overload.
template <typename R>
static auto product(R &&range) {
  auto first = adl_begin(range);
  auto last = adl_end(range);
  return product(first, last);
}
558 
559 static LogicalResult verifyDimensionCompatibility(Location loc,
560  int64_t expectedDimSize,
561  int64_t resultDimSize,
562  int64_t resultAxis) {
563  if (!ShapedType::isDynamic(resultDimSize) &&
564  expectedDimSize != resultDimSize) {
565  return emitError(loc) << "Dimension size mismatch for result axis "
566  << resultAxis << ". Expected "
567  << (ShapedType::isDynamic(expectedDimSize)
568  ? Twine("dynamic")
569  : Twine(expectedDimSize))
570  << ", but got " << resultDimSize << ".";
571  }
572 
573  return success();
574 }
575 
// Verifies the result shape of a gather-like op: `gatherAxis` must be a valid
// result axis, the result extent along it must be compatible with the operand
// extent multiplied by the process-group size, and all other extents must be
// compatible with the operand's.
// NOTE(review): the line with the return type and function name (listing line
// 576) is absent from this extract.
577  Value operand, Value result, int64_t gatherAxis,
578  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
579  auto resultRank = cast<ShapedType>(result.getType()).getRank();
580  if (gatherAxis < 0 || gatherAxis >= resultRank) {
581  return emitError(result.getLoc())
582  << "Gather axis " << gatherAxis << " is out of bounds [0, "
583  << resultRank << ").";
584  }
585 
586  ShapedType operandType = cast<ShapedType>(operand.getType());
587  ShapedType resultType = cast<ShapedType>(result.getType());
588  auto deviceGroupSize =
589  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
// Walk the operand's axes; the result is indexed with the same axis number.
590  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
591  auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
592  auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
// DimensionSize multiplication yields dynamic if either factor is dynamic.
593  auto expectedResultDimSize =
594  axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
595  if (failed(verifyDimensionCompatibility(
596  result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
597  return failure();
598  }
599  }
600  return success();
601 }
602 
// Verifies the result shape of an op that splits the operand along `splitAxis`
// and concatenates along `concatAxis` across a process group. Axes that are
// neither must keep a compatible extent; the concat extent is expected to grow
// by the group size and the split extent to shrink by it.
// NOTE(review): the line with the return type and function name (listing line
// 603) is absent from this extract.
604  Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
605  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
606  ShapedType operandType = cast<ShapedType>(operand.getType());
607  ShapedType resultType = cast<ShapedType>(result.getType());
// When splitAxis == concatAxis every axis is checked for an unchanged extent.
608  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
609  if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
610  if (failed(verifyDimensionCompatibility(
611  result.getLoc(), operandType.getDimSize(axis),
612  resultType.getDimSize(axis), axis))) {
613  return failure();
614  }
615  }
616  }
617 
618  if (splitAxis == concatAxis) {
619  return success();
620  }
621 
622  auto deviceGroupSize =
623  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
624  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
625  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
626  DimensionSize expectedResultConcatDimSize =
627  operandConcatDimSize * deviceGroupSize;
628  DimensionSize expectedResultSplitDimSize =
629  operandSplitDimSize / deviceGroupSize;
// A split extent that does not divide evenly is treated as dynamic rather
// than rejected.
// NOTE(review): the division above and the modulo below assume a non-zero
// static group size — confirm collectiveProcessGroupSize cannot return 0.
630  if (!expectedResultSplitDimSize.isDynamic() &&
631  int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
632  expectedResultSplitDimSize = DimensionSize::dynamic();
633  }
634  if (failed(verifyDimensionCompatibility(
635  result.getLoc(), expectedResultConcatDimSize.value(),
636  resultType.getDimSize(concatAxis), concatAxis))) {
637  return failure();
638  }
639  if (failed(verifyDimensionCompatibility(
640  result.getLoc(), expectedResultSplitDimSize.value(),
641  resultType.getDimSize(splitAxis), splitAxis))) {
642  return failure();
643  }
644 
645  return success();
646 }
647 
// Verifies the result shape of an op that divides the operand along
// `tensorAxis` across a process group: all other axes must keep a compatible
// extent, the operand extent along `tensorAxis` must be divisible by the group
// size when both are static, and the result extent must be compatible with the
// quotient.
// NOTE(review): the line with the return type and function name (listing line
// 648) is absent from this extract.
649  Value operand, Value result, int64_t tensorAxis,
650  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
651  ShapedType operandType = cast<ShapedType>(operand.getType());
652  ShapedType resultType = cast<ShapedType>(result.getType());
653  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
654  if (axis != tensorAxis) {
655  if (failed(verifyDimensionCompatibility(
656  result.getLoc(), operandType.getDimSize(axis),
657  resultType.getDimSize(axis), axis))) {
658  return failure();
659  }
660  }
661  }
662 
663  auto deviceGroupSize =
664  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
665  auto operandScatterDimSize =
666  DimensionSize(operandType.getDimSize(tensorAxis));
// Divisibility can only be checked when both sizes are static.
// NOTE(review): the modulo and the division below assume a non-zero static
// group size — confirm collectiveProcessGroupSize cannot return 0.
667  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
668  int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
669  return emitError(result.getLoc())
670  << "Operand dimension size " << int64_t(operandScatterDimSize)
671  << " is not divisible by collective device group size "
672  << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
673  << ".";
674  }
675  DimensionSize expectedResultTensorDimSize =
676  operandScatterDimSize / deviceGroupSize;
677  if (failed(verifyDimensionCompatibility(
678  result.getLoc(), expectedResultTensorDimSize.value(),
679  resultType.getDimSize(tensorAxis), tensorAxis))) {
680  return failure();
681  }
682 
683  return success();
684 }
685 
686 static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
687  ArrayRef<MeshAxis> meshAxes,
688  int64_t sliceAxis) {
689  RankedTensorType operandRankedTensorType =
690  cast<RankedTensorType>(operandType);
691  DimensionSize operandSliceAxisSize =
692  operandRankedTensorType.getShape()[sliceAxis];
693  SmallVector<int64_t> resultShape =
694  llvm::to_vector(operandRankedTensorType.getShape());
695 
696  resultShape[sliceAxis] =
697  operandSliceAxisSize /
698  DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
699  return operandRankedTensorType.clone(resultShape);
700 }
701 
702 //===----------------------------------------------------------------------===//
703 // mesh.all_gather op
704 //===----------------------------------------------------------------------===//
705 
706 LogicalResult
707 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
708  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
709  if (failed(mesh)) {
710  return failure();
711  }
712  auto gatherAxis = getGatherAxis().getSExtValue();
713  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
714  gatherAxis, getMeshAxes(),
715  mesh.value().getShape());
716 }
717 
718 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
719  MLIRContext *context) {
720  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
721 }
722 
723 void AllGatherOp::getAsmResultNames(
724  function_ref<void(Value, StringRef)> setNameFn) {
725  setNameFn(getResult(), "all_gather");
726 }
727 
728 //===----------------------------------------------------------------------===//
729 // mesh.all_reduce op
730 //===----------------------------------------------------------------------===//
731 
732 LogicalResult
733 AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
734  return getMeshAndVerifyAxes(*this, symbolTable);
735 }
736 
737 void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
738  MLIRContext *context) {
739  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
740 }
741 
742 void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
743  Value input, StringRef mesh,
744  ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
745  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
746  reduction);
747 }
748 
749 void AllReduceOp::getAsmResultNames(
750  function_ref<void(Value, StringRef)> setNameFn) {
751  setNameFn(getResult(), "all_reduce");
752 }
753 
754 //===----------------------------------------------------------------------===//
755 // mesh.all_slice op
756 //===----------------------------------------------------------------------===//
757 
// Verifies the mesh reference and mesh axes, then passes operand, result,
// slice axis, mesh axes and mesh shape to a shape check.
// NOTE(review): listing line 763, which starts the final return statement and
// names the helper being called, is absent from this extract.
758 LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
759  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
760  if (failed(mesh)) {
761  return failure();
762  }
764  getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
765  mesh.value().getShape());
766 }
767 
768 void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
769  MLIRContext *context) {
770  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
771 }
772 
773 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
774  Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
775  int64_t sliceAxis) {
776  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
777  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
778  sliceAxis);
779 }
780 
781 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
782  Type resultType, Value input, StringRef mesh,
783  ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
784  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
785  APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
786 }
787 
788 void AllSliceOp::getAsmResultNames(
789  function_ref<void(Value, StringRef)> setNameFn) {
790  setNameFn(getResult(), "all_slice");
791 }
792 
793 //===----------------------------------------------------------------------===//
794 // mesh.all_to_all op
795 //===----------------------------------------------------------------------===//
796 
// Verifies the mesh reference and mesh axes, then passes operand, result,
// split axis, concat axis, mesh axes and mesh shape to a shape check.
// NOTE(review): listing line 803, which starts the final return statement and
// names the helper being called, is absent from this extract.
797 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
798  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
799  if (failed(mesh)) {
800  return failure();
801  }
802 
804  getOperand(), getResult(), getSplitAxis().getSExtValue(),
805  getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
806 }
807 
808 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
809  MLIRContext *context) {
810  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
811 }
812 
813 void AllToAllOp::getAsmResultNames(
814  function_ref<void(Value, StringRef)> setNameFn) {
815  setNameFn(getResult(), "all_to_all");
816 }
817 
818 //===----------------------------------------------------------------------===//
819 // mesh.broadcast op
820 //===----------------------------------------------------------------------===//
821 
822 LogicalResult
823 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
824  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
825  if (failed(mesh)) {
826  return failure();
827  }
828  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
829  getRootDynamic(), getMeshAxes(),
830  mesh.value().getShape()))) {
831  return failure();
832  }
833 
834  return success();
835 }
836 
837 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
838  MLIRContext *context) {
839  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
840 }
841 
842 void BroadcastOp::getAsmResultNames(
843  function_ref<void(Value, StringRef)> setNameFn) {
844  setNameFn(getResult(), "broadcast");
845 }
846 
847 //===----------------------------------------------------------------------===//
848 // mesh.gather op
849 //===----------------------------------------------------------------------===//
850 
851 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
852  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
853  if (failed(mesh)) {
854  return failure();
855  }
856  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
857  getRootDynamic(), getMeshAxes(),
858  mesh.value().getShape()))) {
859  return failure();
860  }
861 
862  auto gatherAxis = getGatherAxis().getSExtValue();
863  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
864  getMeshAxes(),
865  mesh.value().getShape());
866 }
867 
868 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
869  MLIRContext *context) {
870  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
871 }
872 
873 void GatherOp::getAsmResultNames(
874  function_ref<void(Value, StringRef)> setNameFn) {
875  setNameFn(getResult(), "gather");
876 }
877 
878 //===----------------------------------------------------------------------===//
879 // mesh.recv op
880 //===----------------------------------------------------------------------===//
881 
882 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
883  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
884  if (failed(mesh)) {
885  return failure();
886  }
887  if (getSource() &&
888  failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
889  getSource().value(), getSourceDynamic(),
890  getMeshAxes(), mesh.value().getShape()))) {
891  return failure();
892  }
893  return success();
894 }
895 
896 void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
897  MLIRContext *context) {
898  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
899 }
900 
901 void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
902  setNameFn(getResult(), "recv");
903 }
904 
905 //===----------------------------------------------------------------------===//
906 // mesh.reduce op
907 //===----------------------------------------------------------------------===//
908 
909 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
910  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
911  if (failed(mesh)) {
912  return failure();
913  }
914  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
915  getRootDynamic(), getMeshAxes(),
916  mesh.value().getShape()))) {
917  return failure();
918  }
919 
920  return success();
921 }
922 
923 void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
924  MLIRContext *context) {
925  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
926 }
927 
928 void ReduceOp::getAsmResultNames(
929  function_ref<void(Value, StringRef)> setNameFn) {
930  setNameFn(getResult(), "reduce");
931 }
932 
933 //===----------------------------------------------------------------------===//
934 // mesh.reduce_scatter op
935 //===----------------------------------------------------------------------===//
936 
937 LogicalResult
938 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
939  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
940  if (failed(mesh)) {
941  return failure();
942  }
943 
945  getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
946  mesh.value().getShape());
947 }
948 
949 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
950  MLIRContext *context) {
951  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
952 }
953 
954 void ReduceScatterOp::getAsmResultNames(
955  function_ref<void(Value, StringRef)> setNameFn) {
956  setNameFn(getResult(), "reduce_scatter");
957 }
958 
959 //===----------------------------------------------------------------------===//
960 // mesh.scatter op
961 //===----------------------------------------------------------------------===//
962 
963 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
964  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
965  if (failed(mesh)) {
966  return failure();
967  }
968  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
969  getRootDynamic(), getMeshAxes(),
970  mesh.value().getShape()))) {
971  return failure();
972  }
973 
974  auto scatterAxis = getScatterAxis().getSExtValue();
975  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
976  scatterAxis, getMeshAxes(),
977  mesh.value().getShape());
978 }
979 
980 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
981  MLIRContext *context) {
982  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
983 }
984 
985 void ScatterOp::getAsmResultNames(
986  function_ref<void(Value, StringRef)> setNameFn) {
987  setNameFn(getResult(), "scatter");
988 }
989 
990 //===----------------------------------------------------------------------===//
991 // mesh.send op
992 //===----------------------------------------------------------------------===//
993 
994 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
995  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
996  if (failed(mesh)) {
997  return failure();
998  }
999  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
1000  getDestination(), getDestinationDynamic(),
1001  getMeshAxes(), mesh.value().getShape()))) {
1002  return failure();
1003  }
1004  return success();
1005 }
1006 
1007 void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1008  MLIRContext *context) {
1009  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
1010 }
1011 
1012 void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1013  setNameFn(getResult(), "send");
1014 }
1015 
1016 //===----------------------------------------------------------------------===//
1017 // mesh.shift op
1018 //===----------------------------------------------------------------------===//
1019 
1020 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1021  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1022  if (failed(mesh)) {
1023  return failure();
1024  }
1025 
1026  auto meshAxes = getMeshAxes();
1027  auto shiftAxis = getShiftAxis().getZExtValue();
1028  if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
1029  return emitError() << "Invalid shift axis " << shiftAxis
1030  << ". It must be one of the grouping mesh axes.";
1031  }
1032 
1033  return success();
1034 }
1035 
1036 void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1037  MLIRContext *context) {
1038  // TODO: remove op when offset is 0 or if it is a rotate with and
1039  // offset % shift_axis_mesh_dim_size == 0.
1040 }
1041 
1042 void ShiftOp::getAsmResultNames(
1043  function_ref<void(Value, StringRef)> setNameFn) {
1044  setNameFn(getResult(), "shift");
1045 }
1046 
1047 //===----------------------------------------------------------------------===//
1048 // TableGen'd op method definitions
1049 //===----------------------------------------------------------------------===//
1050 
1051 #define GET_OP_CLASSES
1052 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
1053 
1054 #define GET_ATTRDEF_CLASSES
1055 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
1056 
1057 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
Definition: MeshOps.cpp:686
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition: MeshOps.cpp:62
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
Definition: MeshOps.cpp:559
static FailureOr< MeshOp > getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:535
static FailureOr< MeshOp > getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:100
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:648
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:576
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:603
static auto product(It begin, It end)
Definition: MeshOps.cpp:548
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape)
Definition: MeshOps.cpp:152
bool isUnique(It begin, It end)
Definition: MeshOps.cpp:113
static LogicalResult verifyMeshAxes(Location loc, ArrayRef< MeshAxis > axes, MeshOp mesh)
Definition: MeshOps.cpp:129
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:507
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:75
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:313
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
This class helps build Operations.
Definition: Builders.h:210
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:424
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
void replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this operation with the provided values if the given callback returns true...
Definition: Operation.h:279
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
mesh::ReductionKind ReductionKind
mesh::MeshShardingAttr MeshShardingAttr
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
Definition: MeshOps.h:90
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:182
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:61
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:67
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:112
int16_t MeshAxis
Definition: MeshOps.h:25
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:163
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:172
void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:222
bool operator==(const Fraction &x, const Fraction &y)
Definition: Fraction.h:90
bool operator!=(const Fraction &x, const Fraction &y)
Definition: Fraction.h:94
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:265
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.