MLIR 22.0.0git
Simplifications.cpp
Go to the documentation of this file.
//===- Simplifications.cpp - Shard Simplifications --------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "TransformsDetail.h"
15#include "mlir/IR/SymbolTable.h"
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/SmallVector.h"
18#include <numeric>
19
20namespace mlir {
21namespace shard {
22
48
49namespace {
50
51// This folding can not be done with an operation's fold method or
52// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
53// symbol tables.
54// We can't use DialectFoldInterface since the cache may be invalidated by some
55// pass changing the referenced GridOp ops.
56struct GridShapeFolder
57 : OpRewritePatternWithSymbolTableCollection<GridShapeOp> {
58 using OpRewritePatternWithSymbolTableCollection::
59 OpRewritePatternWithSymbolTableCollection;
60 LogicalResult matchAndRewrite(GridShapeOp op,
61 PatternRewriter &rewriter) const override {
62 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
63 GridOp grid = symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
64 op.getOperation(), op.getGridAttr());
65 if (!grid) {
66 return failure();
67 }
68 ArrayRef<GridAxis> opGridAxes = op.getAxes();
69 SmallVector<GridAxis> opAxesIota;
70 if (opGridAxes.empty()) {
71 opAxesIota.resize(grid.getRank());
72 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
73 opGridAxes = opAxesIota;
74 }
75 if (llvm::all_of(opGridAxes, [&grid](GridAxis axis) {
76 return ShapedType::isDynamic(grid.getShape()[axis]);
77 })) {
78 // All grid dimensions are dynamic. Nothing to fold.
79 return failure();
80 }
81
82 SmallVector<Value> newResults(op->getResults().size());
83 SmallVector<GridAxis> newShapeOpGridAxes;
84 SmallVector<size_t> newToOldResultsIndexMap;
85
86 for (size_t i = 0; i < opGridAxes.size(); ++i) {
87 auto gridAxisSize = grid.getShape()[opGridAxes[i]];
88 if (ShapedType::isDynamic(gridAxisSize)) {
89 newToOldResultsIndexMap.push_back(i);
90 newShapeOpGridAxes.push_back(opGridAxes[i]);
91 } else {
92 // Fold static grid axes.
93 newResults[i] = arith::ConstantOp::create(
94 builder, builder.getIndexAttr(gridAxisSize));
95 }
96 }
97
98 // Leave only the dynamic grid axes to be queried.
99 if (!newShapeOpGridAxes.empty()) {
100 GridShapeOp newShapeOp =
101 GridShapeOp::create(builder, grid.getSymName(), newShapeOpGridAxes);
102 for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
103 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
104 }
105 }
106 rewriter.replaceOp(op, newResults);
107
108 return success();
109 }
110};
111
112} // namespace
113
115 SymbolTableCollection &symbolTableCollection) {
116 patterns.add<GridShapeFolder>(symbolTableCollection, patterns.getContext());
117}
118
119} // namespace shard
120} // namespace mlir
return success()
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents a collection of SymbolTables.
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
int16_t GridAxis
Definition ShardOps.h:26
void populateAllReduceEndomorphismSimplificationPatterns(RewritePatternSet &patterns, ReductionKind reduction)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns