19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/SmallVector.h"
26#define GEN_PASS_DEF_SHARDSIMPLIFY
27#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
37 : OpRewritePatternWithSymbolTableCollection<GridShapeOp> {
38 using OpRewritePatternWithSymbolTableCollection::
39 OpRewritePatternWithSymbolTableCollection;
40 LogicalResult matchAndRewrite(GridShapeOp op,
41 PatternRewriter &rewriter)
const override {
42 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
43 GridOp grid = symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
44 op.getOperation(), op.getGridAttr());
48 ArrayRef<GridAxis> opGridAxes = op.getAxes();
49 SmallVector<GridAxis> opAxesIota;
50 if (opGridAxes.empty()) {
51 opAxesIota.resize(grid.getRank());
52 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
53 opGridAxes = opAxesIota;
55 if (llvm::all_of(opGridAxes, [&grid](
GridAxis axis) {
56 return ShapedType::isDynamic(grid.getShape()[axis]);
62 SmallVector<Value> newResults(op->getResults().size());
63 SmallVector<GridAxis> newShapeOpGridAxes;
64 SmallVector<size_t> newToOldResultsIndexMap;
66 for (
size_t i = 0; i < opGridAxes.size(); ++i) {
67 auto gridAxisSize = grid.getShape()[opGridAxes[i]];
68 if (ShapedType::isDynamic(gridAxisSize)) {
69 newToOldResultsIndexMap.push_back(i);
70 newShapeOpGridAxes.push_back(opGridAxes[i]);
73 newResults[i] = arith::ConstantOp::create(
74 builder, builder.getIndexAttr(gridAxisSize));
79 if (!newShapeOpGridAxes.empty()) {
80 GridShapeOp newShapeOp =
81 GridShapeOp::create(builder, grid.getSymName(), newShapeOpGridAxes);
82 for (
size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
83 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
86 rewriter.replaceOp(op, newResults);
112 LogicalResult matchAndRewrite(AllSliceOp sliceOp,
115 auto reduceOp = sliceOp.getInput().getDefiningOp<AllReduceOp>();
116 if (!reduceOp || !reduceOp->hasOneUse())
120 if (reduceOp.getGrid() != sliceOp.getGrid() ||
121 reduceOp.getGridAxes() != sliceOp.getGridAxes())
126 sliceOp, sliceOp.getResult().
getType(), sliceOp.getGridAttr(),
127 sliceOp.getGridAxesAttr(), reduceOp.getInput(),
128 reduceOp.getReductionAttr(), sliceOp.getSliceAxisAttr());
139 patterns, ReductionKind::Sum);
141 patterns, ReductionKind::Sum);
144 patterns, ReductionKind::Min);
146 patterns, ReductionKind::Min);
148 patterns, ReductionKind::Min);
151 patterns, ReductionKind::Max);
153 patterns, ReductionKind::Max);
155 patterns, ReductionKind::Max);
166 patterns.
add<GridShapeFolder>(symbolTableCollection, patterns.
getContext());
171struct ShardSimplifyPass :
public impl::ShardSimplifyBase<ShardSimplifyPass> {
173 void runOnOperation()
override {
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
void populateAllReduceEndomorphismSimplifyPatterns(RewritePatternSet &patterns, ReductionKind reduction)
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateSimplifyPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...