MLIR  19.0.0git
MeshOps.h
Go to the documentation of this file.
1 //===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
10 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
11 
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/SymbolTable.h"
21 
namespace mlir::mesh {

// Index of a mesh axis; used to select an entry of a mesh's shape.
using MeshAxis = int16_t;

} // namespace mlir::mesh
30 
31 #include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
32 
33 #define GET_ATTRDEF_CLASSES
34 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
35 
36 #define GET_OP_CLASSES
37 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
38 
39 namespace mlir {
40 namespace mesh {
41 
42 inline bool isReductionLoop(utils::IteratorType iType) {
43  return iType == utils::IteratorType::reduction;
44 }
45 
46 template <typename T>
48  while (!array.empty() && array.back().empty())
49  array.pop_back();
50 }
51 
52 // Is the same tensor replicated on all processes.
54  return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
55 }
56 
58  SymbolTableCollection &symbolTableCollection) {
59  return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
60  op, meshSymbol);
61 }
62 
63 // Get the corresponding mesh op using the standard attribute nomenclature.
64 template <typename Op>
65 mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
66  return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
67 }
68 
69 template <>
71 getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
72  return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
73  symbolTableCollection);
74 }
75 
76 // Get the number of processes that participate in each group
77 // induced by `meshAxes`.
78 template <typename MeshAxesRange, typename MeshShapeRange>
79 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
80  MeshShapeRange &&meshShape) {
81  int64_t res = 1;
82 
83  for (MeshAxis axis : meshAxes) {
84  auto axisSize = *(std::begin(meshShape) + axis);
85  if (ShapedType::isDynamic(axisSize)) {
86  return ShapedType::kDynamic;
87  }
88  res *= axisSize;
89  }
90 
91  return res;
92 }
93 
94 template <typename MeshAxesRange>
95 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
96  return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
97  mesh.getShape());
98 }
99 
100 // Get the size of a sharded dimension.
101 inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
102  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
103  return ShapedType::kDynamic;
104 
105  assert(dimSize % shardCount == 0);
106  return ceilDiv(dimSize, shardCount);
107 }
108 
109 // Get the size of an unsharded dimension.
110 inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
111  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
112  return ShapedType::kDynamic;
113 
114  return dimSize * shardCount;
115 }
116 
117 // Return the sharded shape `shape` according ot sharding `sharding`.
118 // The shape for the tensor on each device in the mesh.
119 // Example:
120 // On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
121 // result in a shape for each shard of ?x2x?.
122 ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
123  MeshShardingAttr sharding);
124 
125 // If ranked tensor type return its sharded counterpart.
126 //
127 // If not ranked tensor type return `type`.
128 // `sharding` in that case must be null.
129 Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
130 
131 } // namespace mesh
132 } // namespace mlir
133 
134 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
A symbol reference with a reference path containing a single element.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
mesh::MeshShardingAttr MeshShardingAttr
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
Definition: MeshOps.h:79
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:110
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:57
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:101
bool isFullReplication(MeshShardingAttr attr)
Definition: MeshOps.h:53
bool isReductionLoop(utils::IteratorType iType)
Definition: MeshOps.h:42
mesh::MeshOp getMesh< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:71
int16_t MeshAxis
Definition: MeshOps.h:25
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:162
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:171
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Definition: MeshOps.h:47
Include the generated interface declarations.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
Definition: MathExtras.h:23
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr