MLIR  22.0.0git
ShardOps.h
Go to the documentation of this file.
1 //===- ShardOps.h - Shard Dialect Operations --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SHARD_IR_SHARDOPS_H
10 #define MLIR_DIALECT_SHARD_IR_SHARDOPS_H
11 
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/SymbolTable.h"
21 #include "llvm/Support/MathExtras.h"
22 
23 namespace mlir {
24 namespace shard {
25 
26 using GridAxis = int16_t;
30 
31 } // namespace shard
32 } // namespace mlir
33 
34 #include "mlir/Dialect/Shard/IR/ShardEnums.h.inc"
35 
36 #define GET_ATTRDEF_CLASSES
37 #include "mlir/Dialect/Shard/IR/ShardAttributes.h.inc"
38 
39 namespace mlir {
40 namespace shard {
41 
42 class Sharding {
43 private:
45  SmallVector<GridAxesAttr> split_axes;
46  SmallVector<int64_t> static_halo_sizes;
47  SmallVector<int64_t> static_sharded_dims_offsets;
48  SmallVector<Value> dynamic_halo_sizes;
49  SmallVector<Value> dynamic_sharded_dims_offsets;
50 
51 public:
52  Sharding(::mlir::FlatSymbolRefAttr grid_ = nullptr);
53  Sharding(Value rhs);
54  static Sharding get(::mlir::FlatSymbolRefAttr grid_,
55  ArrayRef<GridAxesAttr> split_axes_,
56  ArrayRef<int64_t> static_halo_sizes_ = {},
57  ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
58  ArrayRef<Value> dynamic_halo_sizes_ = {},
59  ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
60  ::mlir::FlatSymbolRefAttr getGridAttr() const { return grid; }
61  ::llvm::StringRef getGrid() const { return grid ? grid.getValue() : ""; }
62  ArrayRef<GridAxesAttr> getSplitAxes() const { return split_axes; }
63  ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
65  return static_sharded_dims_offsets;
66  }
67  ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
69  return dynamic_sharded_dims_offsets;
70  }
71  operator bool() const { return (!grid) == false; }
72  bool operator==(Value rhs) const;
73  bool operator!=(Value rhs) const;
74  bool operator==(const Sharding &rhs) const;
75  bool operator!=(const Sharding &rhs) const;
76  bool equalSplitAxes(const Sharding &rhs) const;
77  bool equalHaloAndShardSizes(const Sharding &rhs) const;
78  bool equalHaloSizes(const Sharding &rhs) const;
79  bool equalShardSizes(const Sharding &rhs) const;
80 };
81 
82 } // namespace shard
83 } // namespace mlir
84 
85 #define GET_TYPEDEF_CLASSES
86 #include "mlir/Dialect/Shard/IR/ShardTypes.h.inc"
87 
88 #define GET_OP_CLASSES
89 #include "mlir/Dialect/Shard/IR/ShardOps.h.inc"
90 
91 namespace mlir {
92 namespace shard {
93 
// Returns true if `iType` is `utils::IteratorType::reduction`, false for any
// other iterator type.
94 inline bool isReductionLoop(utils::IteratorType iType) {
95  return iType == utils::IteratorType::reduction;
96 }
97 
98 // Remove empty subarrays of `array` until a minimum length of one is reached.
99 template <typename T>
101  while (array.size() > 1 && array.back().empty())
102  array.pop_back();
103 }
104 
105 // Is the same tensor replicated on all processes.
// Returns true when every entry of `sharding.getSplitAxes()` is an empty axis
// list. A sharding with no split-axes entries at all also yields true.
106 inline bool isFullReplication(Sharding sharding) {
107  return llvm::all_of(sharding.getSplitAxes(), [](GridAxesAttr axes) {
108  return axes.asArrayRef().empty();
109  });
110 }
111 
112 inline shard::GridOp
114  SymbolTableCollection &symbolTableCollection) {
115  if (!gridSymbol)
116  return nullptr;
117  return symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
118  op, gridSymbol);
119 }
120 
122  SymbolTableCollection &symbolTableCollection) {
123  shard::GridOp gridOp = getGridOrNull(op, gridSymbol, symbolTableCollection);
124  assert(gridOp);
125  return gridOp;
126 }
127 
128 // Get the corresponding grid op using the standard attribute nomenclature.
// `Op` must provide `getOperation()` and `getGridAttr()`; their results are
// forwarded, together with `symbolTableCollection`, to the `getGrid` overload
// that takes an operation and a grid symbol.
129 template <typename Op>
130 shard::GridOp getGrid(Op op, SymbolTableCollection &symbolTableCollection) {
131  return getGrid(op.getOperation(), op.getGridAttr(), symbolTableCollection);
132 }
133 
134 template <>
136 getGrid<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
137  return getGrid(
138  op.getOperation(),
139  cast<ShardingOp>(op.getSharding().getDefiningOp()).getGridAttr(),
140  symbolTableCollection);
141 }
142 
143 // Get the number of processes that participate in each group
144 // induced by `gridAxes`.
// Computed as the product of the `gridShape` entries selected by each axis in
// `gridAxes`. Returns 1 when `gridAxes` is empty and `ShapedType::kDynamic` as
// soon as any selected entry is dynamic.
// NOTE(review): each axis is used as an offset from `std::begin(gridShape)`
// without a bounds check, so `gridShape` must support `begin + n` and every
// axis is presumably expected to be a valid index into it — confirm.
145 template <typename GridAxesRange, typename GridShapeRange>
146 int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes,
147  GridShapeRange &&gridShape) {
148  int64_t res = 1;
149 
150  for (GridAxis axis : gridAxes) {
151  auto axisSize = *(std::begin(gridShape) + axis);
// One dynamic axis size makes the whole group size dynamic.
152  if (ShapedType::isDynamic(axisSize)) {
153  return ShapedType::kDynamic;
154  }
155  res *= axisSize;
156  }
157 
158  return res;
159 }
160 
// Convenience overload: same computation as the range/shape overload, using
// `grid.getShape()` as the grid shape. `gridAxes` is perfectly forwarded.
161 template <typename GridAxesRange>
162 int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridOp grid) {
163  return collectiveProcessGroupSize(std::forward<GridAxesRange>(gridAxes),
164  grid.getShape());
165 }
166 
167 // Get the size of a sharded dimension.
// Returns `ShapedType::kDynamic` if either `dimSize` or `shardCount` is
// dynamic; otherwise returns `dimSize / shardCount`. Asserts that `dimSize`
// is evenly divisible by `shardCount`.
// NOTE(review): a `shardCount` of 0 is not guarded against — presumably
// callers never pass it; confirm.
168 inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
169  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
170  return ShapedType::kDynamic;
171 
172  assert(dimSize % shardCount == 0);
173  return dimSize / shardCount;
174 }
175 
176 // Get the size of an unsharded dimension.
// Returns `ShapedType::kDynamic` if either `dimSize` or `shardCount` is
// dynamic; otherwise returns `dimSize * shardCount`.
// NOTE(review): the multiplication is not checked for overflow.
177 inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
178  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
179  return ShapedType::kDynamic;
180 
181  return dimSize * shardCount;
182 }
183 
184 // Return the sharded shape `shape` according to sharding `sharding`.
185 // The shape for the tensor on each device in the grid.
186 // Example:
187 // On a 2x4x? grid with split axes = [[0], [1], [2]] the shape ?x5x1 would
188 // result in a shape for each shard of ?x2x?.
189 ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding);
190 
191 // If ranked tensor type return its sharded counterpart.
192 //
193 // If not ranked tensor type return `type`.
194 // `sharding` in that case must be null.
195 Type shardType(Type type, GridOp grid, Sharding sharding);
196 
197 // Insert shard op if there is not one that already has the same sharding.
198 // Use newShardOp if it is not null. Otherwise create a new one.
199 // May insert resharding if required.
200 // Potentially updates newShardOp.
202  OpBuilder &builder);
204  OpBuilder &builder);
205 
206 /// Converts a vector of OpFoldResults (ints) into vector of Values of the
207 /// provided type.
209  llvm::ArrayRef<int64_t> statics,
210  ValueRange dynamics, Type type = Type());
211 } // namespace shard
212 } // namespace mlir
213 
214 #endif // MLIR_DIALECT_SHARD_IR_SHARDOPS_H
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Definition: OpDefinition.h:112
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
bool operator!=(Value rhs) const
Definition: ShardOps.cpp:744
bool equalShardSizes(const Sharding &rhs) const
Definition: ShardOps.cpp:712
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
Definition: ShardOps.cpp:752
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition: ShardOps.cpp:770
bool equalSplitAxes(const Sharding &rhs) const
Definition: ShardOps.cpp:689
::mlir::FlatSymbolRefAttr getGridAttr() const
Definition: ShardOps.h:60
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition: ShardOps.h:64
::llvm::StringRef getGrid() const
Definition: ShardOps.h:61
bool equalHaloAndShardSizes(const Sharding &rhs) const
Definition: ShardOps.cpp:708
bool operator==(Value rhs) const
Definition: ShardOps.cpp:740
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: ShardOps.h:63
ArrayRef< Value > getDynamicShardedDimsOffsets() const
Definition: ShardOps.h:68
ArrayRef< Value > getDynamicHaloSizes() const
Definition: ShardOps.h:67
ArrayRef< GridAxesAttr > getSplitAxes() const
Definition: ShardOps.h:62
bool equalHaloSizes(const Sharding &rhs) const
Definition: ShardOps.cpp:728
shard::Sharding Sharding
shard::GridOp GridOp
int16_t GridAxis
Definition: ShardOps.h:26
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
Definition: ShardOps.cpp:281
shard::GridOp getGrid< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
Definition: ShardOps.h:136
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
Definition: ShardOps.cpp:338
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition: ShardOps.h:113
bool isFullReplication(Sharding sharding)
Definition: ShardOps.h:106
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: ShardOps.cpp:352
bool isReductionLoop(utils::IteratorType iType)
Definition: ShardOps.h:94
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Definition: ShardOps.h:100
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: ShardOps.h:168
Type shardType(Type type, GridOp grid, Sharding sharding)
Definition: ShardOps.cpp:291
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
Definition: ShardOps.cpp:77
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Definition: ShardOps.h:146
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition: ShardOps.h:121
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: ShardOps.h:177
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr