MLIR  19.0.0git
MeshOps.cpp
Go to the documentation of this file.
1 //===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
25 #include "mlir/Support/LLVM.h"
27 #include "llvm/ADT/ArrayRef.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallSet.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/Casting.h"
33 #include <algorithm>
34 #include <functional>
35 #include <iterator>
36 #include <numeric>
37 #include <optional>
38 #include <utility>
39 
40 #define DEBUG_TYPE "mesh-ops"
41 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
42 
43 using namespace mlir;
44 using namespace mlir::mesh;
45 
46 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
47 
48 namespace {
49 
50 struct DimensionSize {
51  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
52  DimensionSize(int64_t val) : val(val) {}
53  int64_t value() const { return val; }
54  operator int64_t() const { return val; }
55  bool isDynamic() const { return ShapedType::isDynamic(val); }
56 
57 private:
58  int64_t val;
59 };
60 
61 } // namespace
62 
63 static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
64  if (lhs.isDynamic() || rhs.isDynamic()) {
65  return DimensionSize::dynamic();
66  }
67  return lhs.value() / rhs.value();
68 }
69 
70 static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
71  if (lhs.isDynamic() || rhs.isDynamic()) {
72  return DimensionSize::dynamic();
73  }
74  return lhs.value() * rhs.value();
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Mesh dialect
79 //===----------------------------------------------------------------------===//
80 
81 void MeshDialect::initialize() {
82  addOperations<
83 #define GET_OP_LIST
84 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
85  >();
86  addAttributes<
87 #define GET_ATTRDEF_LIST
88 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
89  >();
90 }
91 
93  Type type, Location loc) {
94  return arith::ConstantOp::materialize(builder, value, type, loc);
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // Mesh utilities
99 //===----------------------------------------------------------------------===//
100 
102  FlatSymbolRefAttr meshSymbol,
103  SymbolTableCollection &symbolTable) {
104  mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
105  if (!mesh) {
106  return op->emitError() << "Undefined required mesh symbol \""
107  << meshSymbol.getValue() << "\".";
108  }
109 
110  return mesh;
111 }
112 
113 template <typename It>
114 bool isUnique(It begin, It end) {
115  if (begin == end) {
116  return true;
117  }
118  It next = std::next(begin);
119  if (next == end) {
120  return true;
121  }
122  for (; next != end; ++next, ++begin) {
123  if (*begin == *next) {
124  return false;
125  }
126  }
127  return true;
128 }
129 
131  MeshOp mesh) {
132  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
133  llvm::sort(sorted);
134  if (!isUnique(sorted.begin(), sorted.end())) {
135  return emitError(loc) << "Mesh axes contains duplicate elements.";
136  }
137 
138  MeshAxis rank = mesh.getRank();
139  for (auto axis : axes) {
140  if (axis >= rank || axis < 0) {
141  return emitError(loc)
142  << "0-based mesh axis index " << axis
143  << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
144  << "\" is of rank " << rank << ".";
145  }
146  }
147 
148  return success();
149 }
150 
151 template <typename InShape, typename MeshShape, typename SplitAxes,
152  typename OutShape>
153 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
154  const SplitAxes &splitAxes, OutShape &outShape) {
155  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
156  llvm::adl_begin(outShape));
157  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
158  outShape[tensorAxis] = shardDimension(
159  inShape[tensorAxis],
160  collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
161  }
162 }
163 
164 ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
165  MeshShardingAttr sharding) {
166  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
167  SmallVector<Dim> resShapeArr(shape.getShape().size());
168  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
169  resShapeArr);
170  return shape.clone(resShapeArr);
171 }
172 
174  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
175  if (rankedTensorType) {
176  return shardShapedType(rankedTensorType, mesh, sharding);
177  }
178 
179  assert(!sharding);
180  return type;
181 }
182 
184  OpOperand &operand,
185  OpBuilder &builder) {
186  OpBuilder::InsertionGuard insertionGuard(builder);
187  Value operandValue = operand.get();
188  Operation *operandOp = operand.getOwner();
189  builder.setInsertionPointAfterValue(operandValue);
190  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
191  if (shardOp && shardOp.getShard() == sharding &&
192  !shardOp.getAnnotateForUsers()) {
193  // No need for anything the correct sharding is already set.
194  return;
195  }
196 
197  auto newShardOp =
198  builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
199  /*annotate_for_users*/ false);
200  IRRewriter rewriter(builder);
201  rewriter.replaceUsesWithIf(
202  operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
203  return use.getOwner() == operandOp && use.get() == operandValue;
204  });
205 
206  if (!shardOp || shardOp.getAnnotateForUsers()) {
207  return;
208  }
209 
210  auto newShardOp2 = builder.create<ShardOp>(
211  operandValue.getLoc(), newShardOp, sharding, /*annotate_for_users*/ true);
212  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
213 }
214 
216  OpResult result,
217  OpBuilder &builder) {
218  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
219  maybeInsertTargetShardingAnnotation(sharding, use, builder);
220  }
221 }
222 
224  OpOperand &operand,
225  OpBuilder &builder) {
226  OpBuilder::InsertionGuard insertionGuard(builder);
227  Value operandValue = operand.get();
228  Operation *operandOp = operand.getOwner();
229  Operation *operandSrcOp = operandValue.getDefiningOp();
230  bool isBlockArg = !operandSrcOp;
231  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
232 
233  if (shardOp && shardOp.getShard() == sharding &&
234  shardOp.getAnnotateForUsers()) {
235  // No need for anything the correct sharding is already set.
236  return;
237  }
238 
239  builder.setInsertionPoint(operandOp);
240  auto newShardOp =
241  builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
242  /*annotate_for_users*/ true);
243  IRRewriter rewriter(builder);
244  rewriter.replaceUsesWithIf(
245  operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
246  return use.getOwner() == operandOp && use.get() == operandValue;
247  });
248 
249  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
250  // No need for resharding.
251  return;
252  }
253 
254  builder.setInsertionPoint(newShardOp);
255  auto newPreceedingShardOp =
256  builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
257  /*annotate_for_users*/ false);
258  rewriter.replaceUsesWithIf(newShardOp.getOperand(), newPreceedingShardOp,
259  [&newShardOp](OpOperand &use) {
260  return use.getOwner() ==
261  newShardOp.getOperation();
262  });
263 }
264 
265 //===----------------------------------------------------------------------===//
266 // mesh.mesh op
267 //===----------------------------------------------------------------------===//
268 
270  int64_t rank = getRank();
271 
272  if (rank <= 0)
273  return emitOpError("rank of mesh is expected to be a positive integer");
274 
275  for (int64_t dimSize : getShape()) {
276  if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
277  return emitOpError("dimension size of a mesh is expected to be "
278  "non-negative or dynamic");
279  }
280 
281  return success();
282 }
283 
284 //===----------------------------------------------------------------------===//
285 // mesh.mesh_shape op
286 //===----------------------------------------------------------------------===//
287 
289 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
290  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
291  if (failed(mesh)) {
292  return failure();
293  }
294  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
295  return failure();
296  }
297 
298  size_t expectedResultsCount =
299  getAxes().empty() ? mesh->getRank() : getAxes().size();
300  if (getResult().size() != expectedResultsCount) {
301  return emitError() << "Unexpected number of results " << getResult().size()
302  << ". Expected " << expectedResultsCount << ".";
303  }
304 
305  return success();
306 }
307 
308 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
309  MeshOp mesh) {
310  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
311 }
312 
313 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
314  MeshOp mesh, ArrayRef<MeshAxis> axes) {
315  build(odsBuilder, odsState,
316  SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
317  odsBuilder.getIndexType()),
318  mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
319 }
320 
321 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
322  StringRef mesh, ArrayRef<MeshAxis> axes) {
323  assert(!axes.empty());
324  build(odsBuilder, odsState,
325  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
326  MeshAxesAttr::get(odsBuilder.getContext(), axes));
327 }
328 
329 void MeshShapeOp::getAsmResultNames(
330  function_ref<void(Value, StringRef)> setNameFn) {
331  setNameFn(getResults()[0], "mesh_shape");
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // mesh.shard attr
336 //===----------------------------------------------------------------------===//
337 
341  ArrayRef<MeshAxis> partialAxes, ReductionKind) {
342  // TODO: At present mesh symbol ref is not verified. This is due to the
343  // difficulty in fetching the corresponding symbol op based on an attribute.
344 
345  llvm::SmallSet<MeshAxis, 4> visitedAxes;
346 
347  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
348  for (MeshAxis axis : axesArray) {
349  if (axis < 0)
350  return emitError() << "mesh axis is expected to be non-negative";
351  if (!visitedAxes.insert(axis).second)
352  return emitError() << "mesh axis duplicated";
353  }
354  return success();
355  };
356 
357  for (MeshAxesAttr subAxes : splitAxes) {
358  ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
359  if (failed(checkMeshAxis(subAxesArray)))
360  return failure();
361  }
362  if (failed(checkMeshAxis(partialAxes)))
363  return failure();
364  return success();
365 }
366 
367 bool MeshShardingAttr::operator==(Attribute rhs) const {
368  MeshShardingAttr rhsAsMeshShardingAttr =
369  mlir::dyn_cast<MeshShardingAttr>(rhs);
370  return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
371 }
372 
373 bool MeshShardingAttr::operator!=(Attribute rhs) const {
374  return !(*this == rhs);
375 }
376 
378  if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
379  return false;
380  }
381 
382  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
383  return false;
384  }
385 
386  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
387  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
388  getSplitAxes().begin() + minSize),
389  llvm::make_range(rhs.getSplitAxes().begin(),
390  rhs.getSplitAxes().begin() + minSize))) {
391  return false;
392  }
393 
394  return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
395  getSplitAxes().end()),
396  std::mem_fn(&MeshAxesAttr::empty)) &&
397  llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
398  rhs.getSplitAxes().end()),
399  std::mem_fn(&MeshAxesAttr::empty));
400 }
401 
403  return !(*this == rhs);
404 }
405 
406 //===----------------------------------------------------------------------===//
407 // mesh.shard op
408 //===----------------------------------------------------------------------===//
409 
410 void ShardOp::getAsmResultNames(
411  function_ref<void(Value, StringRef)> setNameFn) {
412  setNameFn(getResult(), "sharding_annotated");
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // mesh.process_multi_index op
417 //===----------------------------------------------------------------------===//
418 
420 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
421  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
422  if (failed(mesh)) {
423  return failure();
424  }
425  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
426  return failure();
427  }
428 
429  size_t expectedResultsCount =
430  getAxes().empty() ? mesh->getRank() : getAxes().size();
431  if (getResult().size() != expectedResultsCount) {
432  return emitError() << "Unexpected number of results " << getResult().size()
433  << ". Expected " << expectedResultsCount << ".";
434  }
435 
436  return success();
437 }
438 
439 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
440  MeshOp mesh) {
441  build(odsBuilder, odsState,
442  SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
443  mesh.getSymName(), ArrayRef<MeshAxis>());
444 }
445 
446 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
447  StringRef mesh, ArrayRef<MeshAxis> axes) {
448  build(odsBuilder, odsState,
449  SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
450  MeshAxesAttr::get(odsBuilder.getContext(), axes));
451 }
452 
453 void ProcessMultiIndexOp::getAsmResultNames(
454  function_ref<void(Value, StringRef)> setNameFn) {
455  setNameFn(getResults()[0], "proc_linear_idx");
456 }
457 
458 //===----------------------------------------------------------------------===//
459 // mesh.process_linear_index op
460 //===----------------------------------------------------------------------===//
461 
463 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
464  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
465  if (failed(mesh)) {
466  return failure();
467  }
468  return success();
469 }
470 
471 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
472  OperationState &odsState, MeshOp mesh) {
473  build(odsBuilder, odsState, mesh.getSymName());
474 }
475 
476 void ProcessLinearIndexOp::getAsmResultNames(
477  function_ref<void(Value, StringRef)> setNameFn) {
478  setNameFn(getResult(), "proc_linear_idx");
479 }
480 
481 //===----------------------------------------------------------------------===//
482 // collective communication ops
483 //===----------------------------------------------------------------------===//
484 
485 namespace {
486 
487 template <typename Op>
488 struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
490  LogicalResult matchAndRewrite(Op op,
491  PatternRewriter &rewriter) const override {
492  auto meshAxes = op.getMeshAxes();
493  if (!meshAxes.empty()) {
494  return failure();
495  }
496  if (op.getInput().getType() != op.getResult().getType()) {
497  return failure();
498  }
499 
500  rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
501  rewriter.eraseOp(op.getOperation());
502  return success();
503  }
504 };
505 
506 } // namespace
507 
508 static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
509  ArrayRef<int64_t> device,
510  Operation::operand_range deviceDynamic,
511  ArrayRef<MeshAxis> meshAxes,
512  ArrayRef<int64_t> meshShape) {
513  if (device.size() != meshAxes.size()) {
514  return emitError(loc) << "In-group device \"" << deviceName
515  << "\" has unexpected multi-index size "
516  << device.size() << ". Expected " << meshAxes.size()
517  << ".";
518  }
519 
520  for (size_t i = 0; i < device.size(); ++i) {
521  if (!ShapedType::isDynamic(device[i]) &&
522  !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
523  meshShape[meshAxes[i]] <= device[i]) {
524  return emitError(loc)
525  << "Out of bounds coordinate " << i << " for in-group device \""
526  << deviceName << "\"."
527  << " Got " << device[i] << ", but expected value in the range [0, "
528  << (meshShape[meshAxes[i]] - 1) << "].";
529  }
530  }
531  return success();
532 }
533 
534 template <typename Op>
535 static FailureOr<MeshOp>
537  auto mesh =
538  ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
539  if (failed(mesh)) {
540  return failure();
541  }
542  if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
543  return failure();
544  }
545  return mesh;
546 }
547 
548 template <typename It>
549 static auto product(It begin, It end) {
550  using ElementType = std::decay_t<decltype(*begin)>;
551  return std::accumulate(begin, end, static_cast<ElementType>(1),
552  std::multiplies<ElementType>());
553 }
554 
555 template <typename R>
556 static auto product(R &&range) {
557  return product(adl_begin(range), adl_end(range));
558 }
559 
561  int64_t expectedDimSize,
562  int64_t resultDimSize,
563  int64_t resultAxis) {
564  if (!ShapedType::isDynamic(resultDimSize) &&
565  expectedDimSize != resultDimSize) {
566  return emitError(loc) << "Dimension size mismatch for result axis "
567  << resultAxis << ". Expected "
568  << (ShapedType::isDynamic(expectedDimSize)
569  ? Twine("dynamic")
570  : Twine(expectedDimSize))
571  << ", but got " << resultDimSize << ".";
572  }
573 
574  return success();
575 }
576 
578  Value operand, Value result, int64_t gatherAxis,
579  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
580  auto resultRank = cast<ShapedType>(result.getType()).getRank();
581  if (gatherAxis < 0 || gatherAxis >= resultRank) {
582  return emitError(result.getLoc())
583  << "Gather axis " << gatherAxis << " is out of bounds [0, "
584  << resultRank << ").";
585  }
586 
587  ShapedType operandType = cast<ShapedType>(operand.getType());
588  ShapedType resultType = cast<ShapedType>(result.getType());
589  auto deviceGroupSize =
590  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
591  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
592  auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
593  auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
594  auto expectedResultDimSize =
595  axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
597  result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
598  return failure();
599  }
600  }
601  return success();
602 }
603 
605  Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
606  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
607  ShapedType operandType = cast<ShapedType>(operand.getType());
608  ShapedType resultType = cast<ShapedType>(result.getType());
609  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
610  if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
612  result.getLoc(), operandType.getDimSize(axis),
613  resultType.getDimSize(axis), axis))) {
614  return failure();
615  }
616  }
617  }
618 
619  if (splitAxis == concatAxis) {
620  return success();
621  }
622 
623  auto deviceGroupSize =
624  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
625  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
626  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
627  DimensionSize expectedResultConcatDimSize =
628  operandConcatDimSize * deviceGroupSize;
629  DimensionSize expectedResultSplitDimSize =
630  operandSplitDimSize / deviceGroupSize;
631  if (!expectedResultSplitDimSize.isDynamic() &&
632  int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
633  expectedResultSplitDimSize = DimensionSize::dynamic();
634  }
636  result.getLoc(), expectedResultConcatDimSize.value(),
637  resultType.getDimSize(concatAxis), concatAxis))) {
638  return failure();
639  }
641  result.getLoc(), expectedResultSplitDimSize.value(),
642  resultType.getDimSize(splitAxis), splitAxis))) {
643  return failure();
644  }
645 
646  return success();
647 }
648 
650  Value operand, Value result, int64_t tensorAxis,
651  ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
652  ShapedType operandType = cast<ShapedType>(operand.getType());
653  ShapedType resultType = cast<ShapedType>(result.getType());
654  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
655  if (axis != tensorAxis) {
657  result.getLoc(), operandType.getDimSize(axis),
658  resultType.getDimSize(axis), axis))) {
659  return failure();
660  }
661  }
662  }
663 
664  auto deviceGroupSize =
665  DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
666  auto operandScatterDimSize =
667  DimensionSize(operandType.getDimSize(tensorAxis));
668  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
669  int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
670  return emitError(result.getLoc())
671  << "Operand dimension size " << int64_t(operandScatterDimSize)
672  << " is not divisible by collective device group size "
673  << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
674  << ".";
675  }
676  DimensionSize expectedResultTensorDimSize =
677  operandScatterDimSize / deviceGroupSize;
679  result.getLoc(), expectedResultTensorDimSize.value(),
680  resultType.getDimSize(tensorAxis), tensorAxis))) {
681  return failure();
682  }
683 
684  return success();
685 }
686 
687 static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
688  ArrayRef<MeshAxis> meshAxes,
689  int64_t sliceAxis) {
690  RankedTensorType operandRankedTensorType =
691  cast<RankedTensorType>(operandType);
692  DimensionSize operandSliceAxisSize =
693  operandRankedTensorType.getShape()[sliceAxis];
694  SmallVector<int64_t> resultShape =
695  llvm::to_vector(operandRankedTensorType.getShape());
696 
697  resultShape[sliceAxis] =
698  operandSliceAxisSize /
699  DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
700  return operandRankedTensorType.clone(resultShape);
701 }
702 
703 //===----------------------------------------------------------------------===//
704 // mesh.all_gather op
705 //===----------------------------------------------------------------------===//
706 
708 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
709  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
710  if (failed(mesh)) {
711  return failure();
712  }
713  auto gatherAxis = getGatherAxis().getSExtValue();
714  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
715  gatherAxis, getMeshAxes(),
716  mesh.value().getShape());
717 }
718 
719 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
720  MLIRContext *context) {
721  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
722 }
723 
724 void AllGatherOp::getAsmResultNames(
725  function_ref<void(Value, StringRef)> setNameFn) {
726  setNameFn(getResult(), "all_gather");
727 }
728 
729 //===----------------------------------------------------------------------===//
730 // mesh.all_reduce op
731 //===----------------------------------------------------------------------===//
732 
734 AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
735  return getMeshAndVerifyAxes(*this, symbolTable);
736 }
737 
738 void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
739  MLIRContext *context) {
740  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
741 }
742 
743 void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
744  Value input, StringRef mesh,
745  ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
746  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
747  reduction);
748 }
749 
750 void AllReduceOp::getAsmResultNames(
751  function_ref<void(Value, StringRef)> setNameFn) {
752  setNameFn(getResult(), "all_reduce");
753 }
754 
755 //===----------------------------------------------------------------------===//
756 // mesh.all_slice op
757 //===----------------------------------------------------------------------===//
758 
759 LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
760  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
761  if (failed(mesh)) {
762  return failure();
763  }
765  getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
766  mesh.value().getShape());
767 }
768 
769 void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
770  MLIRContext *context) {
771  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
772 }
773 
774 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
775  Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
776  int64_t sliceAxis) {
777  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
778  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
779  sliceAxis);
780 }
781 
782 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
783  Type resultType, Value input, StringRef mesh,
784  ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
785  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
786  APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
787 }
788 
789 void AllSliceOp::getAsmResultNames(
790  function_ref<void(Value, StringRef)> setNameFn) {
791  setNameFn(getResult(), "all_slice");
792 }
793 
794 //===----------------------------------------------------------------------===//
795 // mesh.all_to_all op
796 //===----------------------------------------------------------------------===//
797 
798 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
799  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
800  if (failed(mesh)) {
801  return failure();
802  }
803 
805  getOperand(), getResult(), getSplitAxis().getSExtValue(),
806  getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
807 }
808 
809 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
810  MLIRContext *context) {
811  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
812 }
813 
814 void AllToAllOp::getAsmResultNames(
815  function_ref<void(Value, StringRef)> setNameFn) {
816  setNameFn(getResult(), "all_to_all");
817 }
818 
819 //===----------------------------------------------------------------------===//
820 // mesh.broadcast op
821 //===----------------------------------------------------------------------===//
822 
824 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
825  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
826  if (failed(mesh)) {
827  return failure();
828  }
829  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
830  getRootDynamic(), getMeshAxes(),
831  mesh.value().getShape()))) {
832  return failure();
833  }
834 
835  return success();
836 }
837 
838 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
839  MLIRContext *context) {
840  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
841 }
842 
843 void BroadcastOp::getAsmResultNames(
844  function_ref<void(Value, StringRef)> setNameFn) {
845  setNameFn(getResult(), "broadcast");
846 }
847 
848 //===----------------------------------------------------------------------===//
849 // mesh.gather op
850 //===----------------------------------------------------------------------===//
851 
852 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
853  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
854  if (failed(mesh)) {
855  return failure();
856  }
857  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
858  getRootDynamic(), getMeshAxes(),
859  mesh.value().getShape()))) {
860  return failure();
861  }
862 
863  auto gatherAxis = getGatherAxis().getSExtValue();
864  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
865  getMeshAxes(),
866  mesh.value().getShape());
867 }
868 
869 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
870  MLIRContext *context) {
871  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
872 }
873 
874 void GatherOp::getAsmResultNames(
875  function_ref<void(Value, StringRef)> setNameFn) {
876  setNameFn(getResult(), "gather");
877 }
878 
879 //===----------------------------------------------------------------------===//
880 // mesh.recv op
881 //===----------------------------------------------------------------------===//
882 
883 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
884  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
885  if (failed(mesh)) {
886  return failure();
887  }
888  if (getSource() &&
889  failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
890  getSource().value(), getSourceDynamic(),
891  getMeshAxes(), mesh.value().getShape()))) {
892  return failure();
893  }
894  return success();
895 }
896 
897 void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
898  MLIRContext *context) {
899  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
900 }
901 
902 void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
903  setNameFn(getResult(), "recv");
904 }
905 
906 //===----------------------------------------------------------------------===//
907 // mesh.reduce op
908 //===----------------------------------------------------------------------===//
909 
910 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
911  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
912  if (failed(mesh)) {
913  return failure();
914  }
915  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
916  getRootDynamic(), getMeshAxes(),
917  mesh.value().getShape()))) {
918  return failure();
919  }
920 
921  return success();
922 }
923 
924 void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
925  MLIRContext *context) {
926  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
927 }
928 
929 void ReduceOp::getAsmResultNames(
930  function_ref<void(Value, StringRef)> setNameFn) {
931  setNameFn(getResult(), "reduce");
932 }
933 
934 //===----------------------------------------------------------------------===//
935 // mesh.reduce_scatter op
936 //===----------------------------------------------------------------------===//
937 
939 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
940  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
941  if (failed(mesh)) {
942  return failure();
943  }
944 
946  getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
947  mesh.value().getShape());
948 }
949 
950 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
951  MLIRContext *context) {
952  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
953 }
954 
955 void ReduceScatterOp::getAsmResultNames(
956  function_ref<void(Value, StringRef)> setNameFn) {
957  setNameFn(getResult(), "reduce_scatter");
958 }
959 
960 //===----------------------------------------------------------------------===//
961 // mesh.scatter op
962 //===----------------------------------------------------------------------===//
963 
964 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
965  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
966  if (failed(mesh)) {
967  return failure();
968  }
969  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
970  getRootDynamic(), getMeshAxes(),
971  mesh.value().getShape()))) {
972  return failure();
973  }
974 
975  auto scatterAxis = getScatterAxis().getSExtValue();
976  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
977  scatterAxis, getMeshAxes(),
978  mesh.value().getShape());
979 }
980 
981 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
982  MLIRContext *context) {
983  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
984 }
985 
986 void ScatterOp::getAsmResultNames(
987  function_ref<void(Value, StringRef)> setNameFn) {
988  setNameFn(getResult(), "scatter");
989 }
990 
991 //===----------------------------------------------------------------------===//
992 // mesh.send op
993 //===----------------------------------------------------------------------===//
994 
995 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
996  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
997  if (failed(mesh)) {
998  return failure();
999  }
1000  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
1001  getDestination(), getDestinationDynamic(),
1002  getMeshAxes(), mesh.value().getShape()))) {
1003  return failure();
1004  }
1005  return success();
1006 }
1007 
1008 void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1009  MLIRContext *context) {
1010  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
1011 }
1012 
1013 void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1014  setNameFn(getResult(), "send");
1015 }
1016 
1017 //===----------------------------------------------------------------------===//
1018 // mesh.shift op
1019 //===----------------------------------------------------------------------===//
1020 
1021 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1022  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1023  if (failed(mesh)) {
1024  return failure();
1025  }
1026 
1027  auto meshAxes = getMeshAxes();
1028  auto shiftAxis = getShiftAxis().getZExtValue();
1029  if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
1030  return emitError() << "Invalid shift axis " << shiftAxis
1031  << ". It must be one of the grouping mesh axes.";
1032  }
1033 
1034  return success();
1035 }
1036 
1037 void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1038  MLIRContext *context) {
1039  // TODO: remove op when offset is 0 or if it is a rotate with and
1040  // offset % shift_axis_mesh_dim_size == 0.
1041 }
1042 
1043 void ShiftOp::getAsmResultNames(
1044  function_ref<void(Value, StringRef)> setNameFn) {
1045  setNameFn(getResult(), "shift");
1046 }
1047 
1048 //===----------------------------------------------------------------------===//
1049 // TableGen'd op method definitions
1050 //===----------------------------------------------------------------------===//
1051 
1052 #define GET_OP_CLASSES
1053 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
1054 
1055 #define GET_ATTRDEF_CLASSES
1056 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
1057 
1058 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
Definition: MeshOps.cpp:687
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
Definition: MeshOps.cpp:63
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
Definition: MeshOps.cpp:560
static FailureOr< MeshOp > getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:536
static FailureOr< MeshOp > getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
Definition: MeshOps.cpp:101
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:649
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:577
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:604
static auto product(It begin, It end)
Definition: MeshOps.cpp:549
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape)
Definition: MeshOps.cpp:153
bool isUnique(It begin, It end)
Definition: MeshOps.cpp:114
static LogicalResult verifyMeshAxes(Location loc, ArrayRef< MeshAxis > axes, MeshOp mesh)
Definition: MeshOps.cpp:130
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
Definition: MeshOps.cpp:508
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:423
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
void replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this operation with the provided values if the given callback returns true...
Definition: Operation.h:279
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
mesh::ReductionKind ReductionKind
mesh::MeshShardingAttr MeshShardingAttr
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
Definition: MeshOps.h:90
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:183
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:61
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:67
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:112
int16_t MeshAxis
Definition: MeshOps.h:25
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:164
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:173
void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:223
bool operator==(const Fraction &x, const Fraction &y)
Definition: Fraction.h:90
bool operator!=(const Fraction &x, const Fraction &y)
Definition: Fraction.h:94
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:266
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.