MLIR 23.0.0git
Simplify.cpp
Go to the documentation of this file.
1//===- Simplify.cpp - Shard Simplify ----------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "TransformsDetail.h"
17#include "mlir/IR/SymbolTable.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/SmallVector.h"
21#include <numeric>
22
23namespace mlir {
24namespace shard {
25
26#define GEN_PASS_DEF_SHARDSIMPLIFY
27#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
28
29namespace {
30
31// This folding can not be done with an operation's fold method or
32// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
33// symbol tables.
34// We can't use DialectFoldInterface since the cache may be invalidated by some
35// pass changing the referenced GridOp ops.
36struct GridShapeFolder
37 : OpRewritePatternWithSymbolTableCollection<GridShapeOp> {
38 using OpRewritePatternWithSymbolTableCollection::
39 OpRewritePatternWithSymbolTableCollection;
40 LogicalResult matchAndRewrite(GridShapeOp op,
41 PatternRewriter &rewriter) const override {
42 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
43 GridOp grid = symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
44 op.getOperation(), op.getGridAttr());
45 if (!grid) {
46 return failure();
47 }
48 ArrayRef<GridAxis> opGridAxes = op.getAxes();
49 SmallVector<GridAxis> opAxesIota;
50 if (opGridAxes.empty()) {
51 opAxesIota.resize(grid.getRank());
52 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
53 opGridAxes = opAxesIota;
54 }
55 if (llvm::all_of(opGridAxes, [&grid](GridAxis axis) {
56 return ShapedType::isDynamic(grid.getShape()[axis]);
57 })) {
58 // All grid dimensions are dynamic. Nothing to fold.
59 return failure();
60 }
61
62 SmallVector<Value> newResults(op->getResults().size());
63 SmallVector<GridAxis> newShapeOpGridAxes;
64 SmallVector<size_t> newToOldResultsIndexMap;
65
66 for (size_t i = 0; i < opGridAxes.size(); ++i) {
67 auto gridAxisSize = grid.getShape()[opGridAxes[i]];
68 if (ShapedType::isDynamic(gridAxisSize)) {
69 newToOldResultsIndexMap.push_back(i);
70 newShapeOpGridAxes.push_back(opGridAxes[i]);
71 } else {
72 // Fold static grid axes.
73 newResults[i] = arith::ConstantOp::create(
74 builder, builder.getIndexAttr(gridAxisSize));
75 }
76 }
77
78 // Leave only the dynamic grid axes to be queried.
79 if (!newShapeOpGridAxes.empty()) {
80 GridShapeOp newShapeOp =
81 GridShapeOp::create(builder, grid.getSymName(), newShapeOpGridAxes);
82 for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
83 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
84 }
85 }
86 rewriter.replaceOp(op, newResults);
87
88 return success();
89 }
90};
91
92// Simplify AllSliceOp(AllReduceOp) -> ReduceScatterOp when both ops share the
93// same grid and grid_axes.
94//
95// AllReduceOp performs an element-wise reduction across all devices in the
96// group, and AllSliceOp then slices (scatters) the result along a tensor
97// dimension. This is exactly what ReduceScatterOp does in a single collective.
98//
99// With a ring algorithm over N ranks and M elements:
100// AllReduce: 2*(N-1) steps of M/N each => ~2M total data transferred
101// AllSlice: local slice, no communication
102// ReduceScatter: (N-1) steps of M/N each => ~M total data transferred
103// So this fusion roughly halves the communication volume.
105// Memory-wise, AllReduce produces a full-sized M-element result that the
106// subsequent AllSlice must keep alive until the slice is taken. ReduceScatter
107// only materializes the M/N-element local slice, reducing peak memory by
108// a factor of N.
109struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
112 LogicalResult matchAndRewrite(AllSliceOp sliceOp,
113 PatternRewriter &rewriter) const override {
114 // Check if the input to AllSliceOp is produced by an AllReduceOp.
115 auto reduceOp = sliceOp.getInput().getDefiningOp<AllReduceOp>();
116 if (!reduceOp || !reduceOp->hasOneUse())
117 return failure();
118
119 // Both ops must operate on the same grid and grid axes.
120 if (reduceOp.getGrid() != sliceOp.getGrid() ||
121 reduceOp.getGridAxes() != sliceOp.getGridAxes())
122 return failure();
123
124 // Replace with a single ReduceScatterOp.
125 rewriter.replaceOpWithNewOp<ReduceScatterOp>(
126 sliceOp, sliceOp.getResult().getType(), sliceOp.getGridAttr(),
127 sliceOp.getGridAxesAttr(), reduceOp.getInput(),
128 reduceOp.getReductionAttr(), sliceOp.getSliceAxisAttr());
129
130 return success();
131 }
132};
134} // namespace
135
// NOTE(review): this listing lost several source lines in extraction — the
// signature head (original line 136, per the cross-reference:
// `void populateSimplifyPatterns(RewritePatternSet &patterns, ...)`) and the
// heads of the `populateAllReduceEndomorphismSimplifyPatterns<arith::...Op>(`
// calls whose continuations appear below. Restore from the upstream
// Simplify.cpp before compiling; the template arguments (which arith ops map
// to Min/Max) cannot be reconstructed from what is visible here.
 137 SymbolTableCollection &symbolTableCollection) {
 142
// Register endomorphism simplifications for min-reductions (three element
// types, per the three continuation lines — the op kinds are in the lost
// heads; TODO confirm against upstream).
 144 patterns, ReductionKind::Min);
 146 patterns, ReductionKind::Min);
 148 patterns, ReductionKind::Min);
 149
// Same for max-reductions.
 151 patterns, ReductionKind::Max);
 153 patterns, ReductionKind::Max);
 155 patterns, ReductionKind::Max);
 156
 157 patterns.add<AllReduceAllSliceSimplification>(patterns.getContext());
 158
 159 // TODO: add simplify patterns for all-gather and other collectives.
 160
 161 populateFoldingPatterns(patterns, symbolTableCollection);
 162}
163
165 SymbolTableCollection &symbolTableCollection) {
166 patterns.add<GridShapeFolder>(symbolTableCollection, patterns.getContext());
167}
168
169namespace {
170
171struct ShardSimplifyPass : public impl::ShardSimplifyBase<ShardSimplifyPass> {
172
173 void runOnOperation() override {
174 RewritePatternSet patterns(&getContext());
175 SymbolTableCollection symbolTableCollection;
176 populateSimplifyPatterns(patterns, symbolTableCollection);
177 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
178 signalPassFailure();
179 }
180};
181
182} // namespace
183
184} // namespace shard
185} // namespace mlir
return success()
b getContext())
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
void populateAllReduceEndomorphismSimplifyPatterns(RewritePatternSet &patterns, ReductionKind reduction)
Definition Simplify.h:40
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition Simplify.cpp:164
int16_t GridAxis
Definition ShardOps.h:27
void populateSimplifyPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition Simplify.cpp:136
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...