MLIR 23.0.0git
ShardOps.h
Go to the documentation of this file.
1//===- ShardOps.h - Shard Dialect Operations --------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SHARD_IR_SHARDOPS_H
10#define MLIR_DIALECT_SHARD_IR_SHARDOPS_H
11
15#include "mlir/IR/Diagnostics.h"
18#include "mlir/IR/SymbolTable.h"
22#include "llvm/Support/MathExtras.h"
23
24namespace mlir {
25namespace shard {
26
// Integer type used to identify one axis of a grid; values of this type are
// used as offsets into a grid shape by the helpers further down in this
// header.
// NOTE(review): this listing skips source lines 28-30. The reference section
// at the end of the page attributes the GridAxesAttr, ShardShapeAttr and
// HaloSizePairAttr aliases to those lines.
27using GridAxis = int16_t;
31
32} // namespace shard
33} // namespace mlir
34
35#include "mlir/Dialect/Shard/IR/ShardEnums.h.inc"
36
37#define GET_ATTRDEF_CLASSES
38#include "mlir/Dialect/Shard/IR/ShardAttributes.h.inc"
39
40namespace mlir {
41namespace shard {
42
// Value-type description of how a tensor is sharded: a reference to a grid
// symbol, the split axes, and halo sizes / sharded-dims offsets. The halo
// sizes and the offsets are each stored twice, once as int64_t entries
// (static_*) and once as SSA Values (dynamic_*).
// NOTE(review): the embedded line numbers jump over source lines 45-46,
// 54-55, 65 and 69, so some declarations (e.g. the `grid` and `split_axes`
// members read by the accessors below) are not visible in this listing.
43class Sharding {
44private:
47 SmallVector<int64_t> static_halo_sizes;
48 SmallVector<int64_t> static_sharded_dims_offsets;
49 SmallVector<Value> dynamic_halo_sizes;
50 SmallVector<Value> dynamic_sharded_dims_offsets;
51
52public:
 // Constructs a sharding that refers to `grid_`; by default the grid
 // reference is null.
53 Sharding(::mlir::FlatSymbolRefAttr grid_ = nullptr);
56 ArrayRef<GridAxesAttr> split_axes_,
57 ArrayRef<int64_t> static_halo_sizes_ = {},
58 ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
59 ArrayRef<Value> dynamic_halo_sizes_ = {},
60 ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
 // Grid symbol reference as stored (may be null).
61 ::mlir::FlatSymbolRefAttr getGridAttr() const { return grid; }
 // Name of the referenced grid, or the empty string when no grid is set.
62 ::llvm::StringRef getGrid() const { return grid ? grid.getValue() : ""; }
 // Read-only views of the stored members.
63 ArrayRef<GridAxesAttr> getSplitAxes() const { return split_axes; }
64 ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
66 return static_sharded_dims_offsets;
67 }
68 ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
70 return dynamic_sharded_dims_offsets;
71 }
 // True iff a grid reference is set. Note that this conversion is implicit.
72 operator bool() const { return (!grid) == false; }
 // Comparison helpers; only declared here, the definitions are not part of
 // this header.
73 bool operator==(Value rhs) const;
74 bool operator!=(Value rhs) const;
75 bool operator==(const Sharding &rhs) const;
76 bool operator!=(const Sharding &rhs) const;
77 bool equalSplitAxes(const Sharding &rhs) const;
78 bool equalHaloAndShardSizes(const Sharding &rhs) const;
79 bool equalHaloSizes(const Sharding &rhs) const;
80 bool equalShardSizes(const Sharding &rhs) const;
81};
82
83llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Sharding &sharding);
84
85inline Diagnostic &operator<<(Diagnostic &diag, const Sharding &sharding) {
86 std::string str;
87 llvm::raw_string_ostream os(str);
88 os << sharding;
89 return diag << os.str();
90}
91
92} // namespace shard
93} // namespace mlir
94
95#define GET_TYPEDEF_CLASSES
96#include "mlir/Dialect/Shard/IR/ShardTypes.h.inc"
97
98#define GET_OP_CLASSES
99#include "mlir/Dialect/Shard/IR/ShardOps.h.inc"
100
101namespace mlir {
102namespace shard {
103
104inline bool isReductionLoop(utils::IteratorType iType) {
105 return iType == utils::IteratorType::reduction;
106}
107
108// Remove empty subarrays of `array` until a minimum lengh of one is reached.
109template <typename T>
111 while (array.size() > 1 && array.back().empty())
112 array.pop_back();
113}
114
115// Is the same tensor replicated on all processes.
116inline bool isFullReplication(Sharding sharding) {
117 return llvm::all_of(sharding.getSplitAxes(), [](GridAxesAttr axes) {
118 return axes.asArrayRef().empty();
119 });
120}
121
122inline shard::GridOp
124 SymbolTableCollection &symbolTableCollection) {
125 if (!gridSymbol)
126 return nullptr;
127 return symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
128 op, gridSymbol);
129}
130
131inline shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol,
132 SymbolTableCollection &symbolTableCollection) {
133 shard::GridOp gridOp = getGridOrNull(op, gridSymbol, symbolTableCollection);
134 assert(gridOp);
135 return gridOp;
136}
137
138// Get the corresponding grid op using the standard attribute nomenclature.
139template <typename Op>
140shard::GridOp getGrid(Op op, SymbolTableCollection &symbolTableCollection) {
141 return getGrid(op.getOperation(), op.getGridAttr(), symbolTableCollection);
142}
143
144template <>
145inline shard::GridOp
146getGrid<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
147 return getGrid(
148 op.getOperation(),
149 cast<ShardingOp>(op.getSharding().getDefiningOp()).getGridAttr(),
150 symbolTableCollection);
151}
152
153// Get the number of processes that participate in each group
154// induced by `gridAxes`.
155template <typename GridAxesRange, typename GridShapeRange>
156int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes,
157 GridShapeRange &&gridShape) {
158 int64_t res = 1;
159
160 for (GridAxis axis : gridAxes) {
161 auto axisSize = *(std::begin(gridShape) + axis);
162 if (ShapedType::isDynamic(axisSize)) {
163 return ShapedType::kDynamic;
164 }
165 res *= axisSize;
166 }
167
168 return res;
169}
170
171template <typename GridAxesRange>
172int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridOp grid) {
173 return collectiveProcessGroupSize(std::forward<GridAxesRange>(gridAxes),
174 grid.getShape());
175}
176
177// Get the size of a sharded dimension.
178inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
179 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
180 return ShapedType::kDynamic;
181
182 assert(dimSize % shardCount == 0);
183 return dimSize / shardCount;
184}
185
186// Get the size of an unsharded dimension.
187inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
188 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
189 return ShapedType::kDynamic;
190
191 return dimSize * shardCount;
192}
193
194// Return the per-device sharded type for `shape` based on `sharding`.
195// This is the tensor shape on each grid partition.
196// Example:
197// On a 2x4x? grid with split axes = [[0], [1], [2]] the shape ?x5x1 would
198// result in a shape for each shard of ?x2x?.
// In the example, `?` denotes a dynamic size: a dimension that is dynamic or
// is split along a dynamic grid axis stays dynamic in the result.
// Declaration only; the definition is not part of this header.
199ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding);
200
201// If ranked tensor type return its sharded counterpart.
202//
203// If not ranked tensor type return `type`.
204// `sharding` in that case must be null.
// NOTE(review): presumably "null" means a Sharding that converts to false,
// i.e. one without a grid reference - confirm against the definition.
// Declaration only; the definition is not part of this header.
205Type shardType(Type type, GridOp grid, Sharding sharding);
206
207// Insert shard op if there is not one that already has the same sharding.
208// Use newShardOp if it is not null. Otherwise create a new one.
209// May insert resharding if required.
210// Potentially updates newShardOp.
212 OpBuilder &builder);
213void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand,
214 OpBuilder &builder);
215
216/// Converts a vector of OpFoldResults (ints) into vector of Values of the
217/// provided type.
220 ValueRange dynamics, Type type = Type());
221} // namespace shard
222} // namespace mlir
223
224#endif // MLIR_DIALECT_SHARD_IR_SHARDOPS_H
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static std::string diag(const llvm::Value &value)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
ArrayRef< Value > getDynamicShardedDimsOffsets() const
Definition ShardOps.h:69
bool operator!=(Value rhs) const
Definition ShardOps.cpp:743
bool equalShardSizes(const Sharding &rhs) const
Definition ShardOps.cpp:711
Sharding(::mlir::FlatSymbolRefAttr grid_=nullptr)
Definition ShardOps.cpp:774
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition ShardOps.cpp:792
bool equalSplitAxes(const Sharding &rhs) const
Definition ShardOps.cpp:688
ArrayRef< int64_t > getStaticHaloSizes() const
Definition ShardOps.h:64
::mlir::FlatSymbolRefAttr getGridAttr() const
Definition ShardOps.h:61
::llvm::StringRef getGrid() const
Definition ShardOps.h:62
bool equalHaloAndShardSizes(const Sharding &rhs) const
Definition ShardOps.cpp:707
bool operator==(Value rhs) const
Definition ShardOps.cpp:739
ArrayRef< Value > getDynamicHaloSizes() const
Definition ShardOps.h:68
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition ShardOps.h:65
ArrayRef< GridAxesAttr > getSplitAxes() const
Definition ShardOps.h:63
bool equalHaloSizes(const Sharding &rhs) const
Definition ShardOps.cpp:727
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const Sharding &sharding)
Definition ShardOps.cpp:751
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:281
shard::GridOp getGrid< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:146
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
Definition ShardOps.cpp:338
DenseI16ArrayAttr GridAxesAttr
Definition ShardOps.h:28
void removeTrailingEmptySubArray(SmallVector< SmallVector< T > > &array)
Definition ShardOps.h:110
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:123
bool isFullReplication(Sharding sharding)
Definition ShardOps.h:116
int16_t GridAxis
Definition ShardOps.h:27
DenseI64ArrayAttr HaloSizePairAttr
Definition ShardOps.h:30
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
Definition ShardOps.cpp:352
bool isReductionLoop(utils::IteratorType iType)
Definition ShardOps.h:104
DenseI64ArrayAttr ShardShapeAttr
Definition ShardOps.h:29
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:178
Type shardType(Type type, GridOp grid, Sharding sharding)
Definition ShardOps.cpp:291
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
Definition ShardOps.cpp:77
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Definition ShardOps.h:156
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:131
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition ShardOps.h:187
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr