MLIR 23.0.0git
Simplify.cpp
Go to the documentation of this file.
1//===- Simplify.cpp - Shard Simplify ----------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "TransformsDetail.h"
17#include "mlir/IR/SymbolTable.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/SmallVector.h"
21#include <numeric>
22#include <type_traits>
23
24namespace mlir {
25namespace shard {
26
27#define GEN_PASS_DEF_SHARDSIMPLIFY
28#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
29
30namespace {
31
// Returns true when both ops report equal grids and equal grid axes.
// The grid axes are only compared once the grids are known to match.
template <typename LhsOp, typename RhsOp>
static bool haveSameGridAndGridAxes(LhsOp lhsOp, RhsOp rhsOp) {
  if (!(lhsOp.getGrid() == rhsOp.getGrid()))
    return false;
  return lhsOp.getGridAxes() == rhsOp.getGridAxes();
}
37
38static bool isAllGatherAllSliceFoldable(AllGatherOp gatherOp,
39 AllSliceOp sliceOp) {
40 return haveSameGridAndGridAxes(gatherOp, sliceOp) &&
41 gatherOp.getGatherAxis() == sliceOp.getSliceAxis();
42}
43
44template <typename OuterOp, typename InnerOp>
45static LogicalResult foldAllGatherAllSlice(OuterOp outerOp, InnerOp innerOp,
46 PatternRewriter &rewriter) {
47 if (!innerOp)
48 return failure();
49
50 AllGatherOp gatherOp;
51 AllSliceOp sliceOp;
52 if constexpr (std::is_same_v<OuterOp, AllGatherOp>) {
53 gatherOp = outerOp;
54 sliceOp = innerOp;
55 } else {
56 gatherOp = innerOp;
57 sliceOp = outerOp;
58 }
59
60 if (!isAllGatherAllSliceFoldable(gatherOp, sliceOp))
61 return failure();
62
63 rewriter.replaceOp(outerOp, innerOp.getInput());
64 return success();
65}
66
67// This folding can not be done with an operation's fold method or
68// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
69// symbol tables.
70// We can't use DialectFoldInterface since the cache may be invalidated by some
71// pass changing the referenced GridOp ops.
72struct GridShapeFolder
73 : OpRewritePatternWithSymbolTableCollection<GridShapeOp> {
74 using OpRewritePatternWithSymbolTableCollection::
75 OpRewritePatternWithSymbolTableCollection;
76 LogicalResult matchAndRewrite(GridShapeOp op,
77 PatternRewriter &rewriter) const override {
78 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
79 GridOp grid = symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
80 op.getOperation(), op.getGridAttr());
81 if (!grid) {
82 return failure();
83 }
84 ArrayRef<GridAxis> opGridAxes = op.getAxes();
85 SmallVector<GridAxis> opAxesIota;
86 if (opGridAxes.empty()) {
87 opAxesIota.resize(grid.getRank());
88 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
89 opGridAxes = opAxesIota;
90 }
91 if (llvm::all_of(opGridAxes, [&grid](GridAxis axis) {
92 return ShapedType::isDynamic(grid.getShape()[axis]);
93 })) {
94 // All grid dimensions are dynamic. Nothing to fold.
95 return failure();
96 }
97
98 SmallVector<Value> newResults(op->getResults().size());
99 SmallVector<GridAxis> newShapeOpGridAxes;
100 SmallVector<size_t> newToOldResultsIndexMap;
101
102 for (size_t i = 0; i < opGridAxes.size(); ++i) {
103 auto gridAxisSize = grid.getShape()[opGridAxes[i]];
104 if (ShapedType::isDynamic(gridAxisSize)) {
105 newToOldResultsIndexMap.push_back(i);
106 newShapeOpGridAxes.push_back(opGridAxes[i]);
107 } else {
108 // Fold static grid axes.
109 newResults[i] = arith::ConstantOp::create(
110 builder, builder.getIndexAttr(gridAxisSize));
112 }
113
114 // Leave only the dynamic grid axes to be queried.
115 if (!newShapeOpGridAxes.empty()) {
116 GridShapeOp newShapeOp =
117 GridShapeOp::create(builder, grid.getSymName(), newShapeOpGridAxes);
118 for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
119 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
120 }
121 }
122 rewriter.replaceOp(op, newResults);
123
124 return success();
126};
127
128// Simplify AllSliceOp(AllReduceOp) -> ReduceScatterOp when both ops share the
129// same grid and grid_axes.
130//
131// AllReduceOp performs an element-wise reduction across all devices in the
132// group, and AllSliceOp then slices (scatters) the result along a tensor
133// dimension. This is exactly what ReduceScatterOp does in a single collective.
134//
135// With a ring algorithm over N ranks and M elements:
136// AllReduce: 2*(N-1) steps of M/N each => ~2M total data transferred
137// AllSlice: local slice, no communication
138// ReduceScatter: (N-1) steps of M/N each => ~M total data transferred
139// So this fusion roughly halves the communication volume.
140//
141// Memory-wise, AllReduce produces a full-sized M-element result that the
142// subsequent AllSlice must keep alive until the slice is taken. ReduceScatter
143// only materializes the M/N-element local slice, reducing peak memory by
144// a factor of N.
145struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
147
148 LogicalResult matchAndRewrite(AllSliceOp sliceOp,
149 PatternRewriter &rewriter) const override {
150 // Check if the input to AllSliceOp is produced by an AllReduceOp.
151 auto reduceOp = sliceOp.getInput().getDefiningOp<AllReduceOp>();
152 if (!reduceOp || !reduceOp->hasOneUse())
153 return failure();
154
155 // Both ops must operate on the same grid and grid axes.
156 if (!haveSameGridAndGridAxes(reduceOp, sliceOp))
157 return failure();
158
159 // Replace with a single ReduceScatterOp.
160 rewriter.replaceOpWithNewOp<ReduceScatterOp>(
161 sliceOp, sliceOp.getResult().getType(), sliceOp.getGridAttr(),
162 sliceOp.getGridAxesAttr(), reduceOp.getInput(),
163 reduceOp.getReductionAttr(), sliceOp.getSliceAxisAttr());
164
165 return success();
166 }
167};
168
169// Simplify all_slice(all_gather(x)) and all_gather(all_slice(x)) to x when
170// both ops share grid, grid_axes, and axis.
171template <typename OuterOp, typename InnerOp>
172struct AllGatherAllSliceSimplification : OpRewritePattern<OuterOp> {
173 using OpRewritePattern<OuterOp>::OpRewritePattern;
174
175 LogicalResult matchAndRewrite(OuterOp outerOp,
176 PatternRewriter &rewriter) const override {
177 auto innerOp = outerOp.getInput().template getDefiningOp<InnerOp>();
178 return foldAllGatherAllSlice(outerOp, innerOp, rewriter);
179 }
180};
181
182} // namespace
183
185 SymbolTableCollection &symbolTableCollection) {
187 patterns, ReductionKind::Sum);
189 patterns, ReductionKind::Sum);
190
192 patterns, ReductionKind::Min);
194 patterns, ReductionKind::Min);
196 patterns, ReductionKind::Min);
197
199 patterns, ReductionKind::Max);
201 patterns, ReductionKind::Max);
203 patterns, ReductionKind::Max);
204
205 patterns.add<AllReduceAllSliceSimplification,
206 AllGatherAllSliceSimplification<AllSliceOp, AllGatherOp>,
207 AllGatherAllSliceSimplification<AllGatherOp, AllSliceOp>>(
208 patterns.getContext());
209
210 // TODO: add simplify patterns for all-gather and other collectives.
211
212 populateFoldingPatterns(patterns, symbolTableCollection);
213}
214
216 SymbolTableCollection &symbolTableCollection) {
217 patterns.add<GridShapeFolder>(symbolTableCollection, patterns.getContext());
218}
219
220namespace {
221
222struct ShardSimplifyPass : public impl::ShardSimplifyBase<ShardSimplifyPass> {
223
224 void runOnOperation() override {
225 RewritePatternSet patterns(&getContext());
226 SymbolTableCollection symbolTableCollection;
227 populateSimplifyPatterns(patterns, symbolTableCollection);
228 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
229 signalPassFailure();
230 }
231};
232
233} // namespace
234
235} // namespace shard
236} // namespace mlir
return success()
b getContext())
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
void populateAllReduceEndomorphismSimplifyPatterns(RewritePatternSet &patterns, ReductionKind reduction)
Definition Simplify.h:40
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition Simplify.cpp:215
int16_t GridAxis
Definition ShardOps.h:27
void populateSimplifyPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition Simplify.cpp:184
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...