19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/SmallVector.h"
27#define GEN_PASS_DEF_SHARDSIMPLIFY
28#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
32template <
typename LhsOp,
typename RhsOp>
33static bool haveSameGridAndGridAxes(LhsOp lhsOp, RhsOp rhsOp) {
34 return lhsOp.getGrid() == rhsOp.getGrid() &&
35 lhsOp.getGridAxes() == rhsOp.getGridAxes();
38static bool isAllGatherAllSliceFoldable(AllGatherOp gatherOp,
40 return haveSameGridAndGridAxes(gatherOp, sliceOp) &&
41 gatherOp.getGatherAxis() == sliceOp.getSliceAxis();
44template <
typename OuterOp,
typename InnerOp>
45static LogicalResult foldAllGatherAllSlice(OuterOp outerOp, InnerOp innerOp,
46 PatternRewriter &rewriter) {
52 if constexpr (std::is_same_v<OuterOp, AllGatherOp>) {
60 if (!isAllGatherAllSliceFoldable(gatherOp, sliceOp))
63 rewriter.replaceOp(outerOp, innerOp.getInput());
73 : OpRewritePatternWithSymbolTableCollection<GridShapeOp> {
74 using OpRewritePatternWithSymbolTableCollection::
75 OpRewritePatternWithSymbolTableCollection;
76 LogicalResult matchAndRewrite(GridShapeOp op,
77 PatternRewriter &rewriter)
const override {
78 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
79 GridOp grid = symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
80 op.getOperation(), op.getGridAttr());
84 ArrayRef<GridAxis> opGridAxes = op.getAxes();
85 SmallVector<GridAxis> opAxesIota;
86 if (opGridAxes.empty()) {
87 opAxesIota.resize(grid.getRank());
88 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
89 opGridAxes = opAxesIota;
91 if (llvm::all_of(opGridAxes, [&grid](
GridAxis axis) {
92 return ShapedType::isDynamic(grid.getShape()[axis]);
98 SmallVector<Value> newResults(op->getResults().size());
99 SmallVector<GridAxis> newShapeOpGridAxes;
100 SmallVector<size_t> newToOldResultsIndexMap;
102 for (
size_t i = 0; i < opGridAxes.size(); ++i) {
103 auto gridAxisSize = grid.getShape()[opGridAxes[i]];
104 if (ShapedType::isDynamic(gridAxisSize)) {
105 newToOldResultsIndexMap.push_back(i);
106 newShapeOpGridAxes.push_back(opGridAxes[i]);
109 newResults[i] = arith::ConstantOp::create(
110 builder, builder.getIndexAttr(gridAxisSize));
115 if (!newShapeOpGridAxes.empty()) {
116 GridShapeOp newShapeOp =
117 GridShapeOp::create(builder, grid.getSymName(), newShapeOpGridAxes);
118 for (
size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
119 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
122 rewriter.replaceOp(op, newResults);
148 LogicalResult matchAndRewrite(AllSliceOp sliceOp,
151 auto reduceOp = sliceOp.getInput().getDefiningOp<AllReduceOp>();
152 if (!reduceOp || !reduceOp->hasOneUse())
156 if (!haveSameGridAndGridAxes(reduceOp, sliceOp))
161 sliceOp, sliceOp.getResult().
getType(), sliceOp.getGridAttr(),
162 sliceOp.getGridAxesAttr(), reduceOp.getInput(),
163 reduceOp.getReductionAttr(), sliceOp.getSliceAxisAttr());
171template <
typename OuterOp,
typename InnerOp>
172struct AllGatherAllSliceSimplification : OpRewritePattern<OuterOp> {
173 using OpRewritePattern<OuterOp>::OpRewritePattern;
175 LogicalResult matchAndRewrite(OuterOp outerOp,
176 PatternRewriter &rewriter)
const override {
177 auto innerOp = outerOp.getInput().template getDefiningOp<InnerOp>();
178 return foldAllGatherAllSlice(outerOp, innerOp, rewriter);
187 patterns, ReductionKind::Sum);
189 patterns, ReductionKind::Sum);
192 patterns, ReductionKind::Min);
194 patterns, ReductionKind::Min);
196 patterns, ReductionKind::Min);
199 patterns, ReductionKind::Max);
201 patterns, ReductionKind::Max);
203 patterns, ReductionKind::Max);
205 patterns.
add<AllReduceAllSliceSimplification,
206 AllGatherAllSliceSimplification<AllSliceOp, AllGatherOp>,
207 AllGatherAllSliceSimplification<AllGatherOp, AllSliceOp>>(
217 patterns.
add<GridShapeFolder>(symbolTableCollection, patterns.
getContext());
222struct ShardSimplifyPass :
public impl::ShardSimplifyBase<ShardSimplifyPass> {
224 void runOnOperation()
override {
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
void populateAllReduceEndomorphismSimplifyPatterns(RewritePatternSet &patterns, ReductionKind reduction)
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateSimplifyPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...