MLIR 22.0.0git
ShardOps.h
Go to the documentation of this file.
1//===- ShardOps.h - Shard Dialect Operations --------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SHARD_IR_SHARDOPS_H
10#define MLIR_DIALECT_SHARD_IR_SHARDOPS_H
11
17#include "mlir/IR/SymbolTable.h"
21#include "llvm/Support/MathExtras.h"
22
23namespace mlir {
24namespace shard {
25
26using GridAxis = int16_t;
30
31} // namespace shard
32} // namespace mlir
33
34#include "mlir/Dialect/Shard/IR/ShardEnums.h.inc"
35
36#define GET_ATTRDEF_CLASSES
37#include "mlir/Dialect/Shard/IR/ShardAttributes.h.inc"
38
39namespace mlir {
40namespace shard {
41
42class Sharding {
43private:
46 SmallVector<int64_t> static_halo_sizes;
47 SmallVector<int64_t> static_sharded_dims_offsets;
48 SmallVector<Value> dynamic_halo_sizes;
49 SmallVector<Value> dynamic_sharded_dims_offsets;
50
51public:
52 Sharding(::mlir::FlatSymbolRefAttr grid_ = nullptr);
55 ArrayRef<GridAxesAttr> split_axes_,
56 ArrayRef<int64_t> static_halo_sizes_ = {},
57 ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
58 ArrayRef<Value> dynamic_halo_sizes_ = {},
59 ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
60 ::mlir::FlatSymbolRefAttr getGridAttr() const { return grid; }
61 ::llvm::StringRef getGrid() const { return grid ? grid.getValue() : ""; }
62 ArrayRef<GridAxesAttr> getSplitAxes() const { return split_axes; }
63 ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
65 return static_sharded_dims_offsets;
66 }
67 ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
69 return dynamic_sharded_dims_offsets;
70 }
71 operator bool() const { return (!grid) == false; }
72 bool operator==(Value rhs) const;
73 bool operator!=(Value rhs) const;
74 bool operator==(const Sharding &rhs) const;
75 bool operator!=(const Sharding &rhs) const;
76 bool equalSplitAxes(const Sharding &rhs) const;
77 bool equalHaloAndShardSizes(const Sharding &rhs) const;
78 bool equalHaloSizes(const Sharding &rhs) const;
79 bool equalShardSizes(const Sharding &rhs) const;
80};
81
82} // namespace shard
83} // namespace mlir
84
85#define GET_TYPEDEF_CLASSES
86#include "mlir/Dialect/Shard/IR/ShardTypes.h.inc"
87
88#define GET_OP_CLASSES
89#include "mlir/Dialect/Shard/IR/ShardOps.h.inc"
90
91namespace mlir {
92namespace shard {
93
94inline bool isReductionLoop(utils::IteratorType iType) {
95 return iType == utils::IteratorType::reduction;
96}
97
// Remove trailing empty subarrays from `array`, stopping once only one
// subarray remains (so the minimum resulting length is one).
99template <typename T>
101 while (array.size() > 1 && array.back().empty())
102 array.pop_back();
103}
104
105// Is the same tensor replicated on all processes.
106inline bool isFullReplication(Sharding sharding) {
107 return llvm::all_of(sharding.getSplitAxes(), [](GridAxesAttr axes) {
108 return axes.asArrayRef().empty();
109 });
110}
111
112inline shard::GridOp
114 SymbolTableCollection &symbolTableCollection) {
115 if (!gridSymbol)
116 return nullptr;
117 return symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
118 op, gridSymbol);
119}
120
121inline shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol,
122 SymbolTableCollection &symbolTableCollection) {
123 shard::GridOp gridOp = getGridOrNull(op, gridSymbol, symbolTableCollection);
124 assert(gridOp);
125 return gridOp;
126}
127
128// Get the corresponding grid op using the standard attribute nomenclature.
129template <typename Op>
130shard::GridOp getGrid(Op op, SymbolTableCollection &symbolTableCollection) {
131 return getGrid(op.getOperation(), op.getGridAttr(), symbolTableCollection);
132}
133
134template <>
135inline shard::GridOp
136getGrid<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
137 return getGrid(
138 op.getOperation(),
139 cast<ShardingOp>(op.getSharding().getDefiningOp()).getGridAttr(),
140 symbolTableCollection);
141}
142
143// Get the number of processes that participate in each group
144// induced by `gridAxes`.
145template <typename GridAxesRange, typename GridShapeRange>
146int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes,
147 GridShapeRange &&gridShape) {
148 int64_t res = 1;
149
150 for (GridAxis axis : gridAxes) {
151 auto axisSize = *(std::begin(gridShape) + axis);
152 if (ShapedType::isDynamic(axisSize)) {
153 return ShapedType::kDynamic;
154 }
155 res *= axisSize;
156 }
157
158 return res;
159}
160
161template <typename GridAxesRange>
162int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridOp grid) {
163 return collectiveProcessGroupSize(std::forward<GridAxesRange>(gridAxes),
164 grid.getShape());
165}
166
167// Get the size of a sharded dimension.
168inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
169 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
170 return ShapedType::kDynamic;
171
172 assert(dimSize % shardCount == 0);
173 return dimSize / shardCount;
174}
175
176// Get the size of an unsharded dimension.
177inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
178 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
179 return ShapedType::kDynamic;
180
181 return dimSize * shardCount;
182}
183
// Return the sharded counterpart of `shape` according to `sharding`, i.e. the
// shape of the tensor held by each device in the grid.
// Example:
// On a 2x4x? grid with split axes = [[0], [1], [2]], the shape ?x5x1 would
// result in a per-shard shape of ?x2x?.
189ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding);
190
// If `type` is a ranked tensor type, return its sharded counterpart.
//
// Otherwise return `type` unchanged; in that case `sharding` must be null.
195Type shardType(Type type, GridOp grid, Sharding sharding);
196
197// Insert shard op if there is not one that already has the same sharding.
198// Use newShardOp if it is not null. Otherwise create a new one.
199// May insert resharding if required.
200// Potentially updates newShardOp.
202 OpBuilder &builder);
203void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand,
204 OpBuilder &builder);
205
206/// Converts a vector of OpFoldResults (ints) into vector of Values of the
207/// provided type.
210 ValueRange dynamics, Type type = Type());
211} // namespace shard
212} // namespace mlir
213
214#endif // MLIR_DIALECT_SHARD_IR_SHARDOPS_H
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
ArrayRef< Value > getDynamicShardedDimsOffsets() const
Definition ShardOps.h:68
bool operator!=(Value rhs) const
Definition ShardOps.cpp:744
bool equalShardSizes(const Sharding &rhs) const
Definition ShardOps.cpp:712
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
Definition ShardOps.cpp:752
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition ShardOps.cpp:770
bool equalSplitAxes(const Sharding &rhs) const
Definition ShardOps.cpp:689
ArrayRef< int64_t > getStaticHaloSizes() const
Definition ShardOps.h:63
::mlir::FlatSymbolRefAttr getGridAttr() const
Definition ShardOps.h:60
::llvm::StringRef getGrid() const
Definition ShardOps.h:61
bool equalHaloAndShardSizes(const Sharding &rhs) const
Definition ShardOps.cpp:708
bool operator==(Value rhs) const
Definition ShardOps.cpp:740
ArrayRef< Value > getDynamicHaloSizes() const
Definition ShardOps.h:67
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition ShardOps.h:64
ArrayRef< GridAxesAttr > getSplitAxes() const
Definition ShardOps.h:62
bool equalHaloSizes(const Sharding &rhs) const
Definition ShardOps.cpp:728
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:281
shard::GridOp getGrid< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:136
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
Definition ShardOps.cpp:338
DenseI16ArrayAttr GridAxesAttr
Definition ShardOps.h:27
void removeTrailingEmptySubArray(SmallVector< SmallVector< T > > &array)
Definition ShardOps.h:100
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:113
bool isFullReplication(Sharding sharding)
Definition ShardOps.h:106
int16_t GridAxis
Definition ShardOps.h:26
DenseI64ArrayAttr HaloSizePairAttr
Definition ShardOps.h:29
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
Definition ShardOps.cpp:352
bool isReductionLoop(utils::IteratorType iType)
Definition ShardOps.h:94
DenseI64ArrayAttr ShardShapeAttr
Definition ShardOps.h:28
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:168
Type shardType(Type type, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:291
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
Definition ShardOps.cpp:77
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Definition ShardOps.h:146
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:121
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:177
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr