MLIR  22.0.0git
Simplifications.cpp
Go to the documentation of this file.
1 //===- Simplifications.cpp - Shard Simplifications --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "TransformsDetail.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/SymbolTable.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include <numeric>
19 
20 namespace mlir {
21 namespace shard {
22 
24  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
25  populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
26  patterns, ReductionKind::Sum);
27  populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
28  patterns, ReductionKind::Sum);
29 
30  populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
31  patterns, ReductionKind::Min);
32  populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
33  patterns, ReductionKind::Min);
34  populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
35  patterns, ReductionKind::Min);
36 
37  populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
38  patterns, ReductionKind::Max);
39  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
40  patterns, ReductionKind::Max);
41  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
42  patterns, ReductionKind::Max);
43 
44  // TODO: add simplifications for all-gather and other collectives.
45 
46  populateFoldingPatterns(patterns, symbolTableCollection);
47 }
48 
49 namespace {
50 
51 // This folding can not be done with an operation's fold method or
52 // DialectFoldInterface, because it needs a SymbolTableCollection to cache the
53 // symbol tables.
54 // We can't use DialectFoldInterface since the cache may be invalidated by some
55 // pass changing the referenced GridOp ops.
56 struct GridShapeFolder
57  : OpRewritePatternWithSymbolTableCollection<GridShapeOp> {
60  LogicalResult matchAndRewrite(GridShapeOp op,
61  PatternRewriter &rewriter) const override {
62  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
63  GridOp grid = symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
64  op.getOperation(), op.getGridAttr());
65  if (!grid) {
66  return failure();
67  }
68  ArrayRef<GridAxis> opGridAxes = op.getAxes();
69  SmallVector<GridAxis> opAxesIota;
70  if (opGridAxes.empty()) {
71  opAxesIota.resize(grid.getRank());
72  std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
73  opGridAxes = opAxesIota;
74  }
75  if (llvm::all_of(opGridAxes, [&grid](GridAxis axis) {
76  return ShapedType::isDynamic(grid.getShape()[axis]);
77  })) {
78  // All grid dimensions are dynamic. Nothing to fold.
79  return failure();
80  }
81 
82  SmallVector<Value> newResults(op->getResults().size());
83  SmallVector<GridAxis> newShapeOpGridAxes;
84  SmallVector<size_t> newToOldResultsIndexMap;
85 
86  for (size_t i = 0; i < opGridAxes.size(); ++i) {
87  auto gridAxisSize = grid.getShape()[opGridAxes[i]];
88  if (ShapedType::isDynamic(gridAxisSize)) {
89  newToOldResultsIndexMap.push_back(i);
90  newShapeOpGridAxes.push_back(opGridAxes[i]);
91  } else {
92  // Fold static grid axes.
93  newResults[i] = arith::ConstantOp::create(
94  builder, builder.getIndexAttr(gridAxisSize));
95  }
96  }
97 
98  // Leave only the dynamic grid axes to be queried.
99  if (!newShapeOpGridAxes.empty()) {
100  GridShapeOp newShapeOp =
101  GridShapeOp::create(builder, grid.getSymName(), newShapeOpGridAxes);
102  for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
103  newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
104  }
105  }
106  rewriter.replaceOp(op, newResults);
107 
108  return success();
109  }
110 };
111 
112 } // namespace
113 
115  SymbolTableCollection &symbolTableCollection) {
116  patterns.add<GridShapeFolder>(symbolTableCollection, patterns.getContext());
117 }
118 
119 } // namespace shard
120 } // namespace mlir
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Definition: Builders.h:621
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:783
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
shard::GridOp GridOp
int16_t GridAxis
Definition: ShardOps.h:26
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePatternWithSymbolTableCollection(SymbolTableCollection &symbolTableCollection, OpRewritePatternArgs &&...opRewritePatternArgs)