MLIR  20.0.0git
Simplifications.cpp
Go to the documentation of this file.
1 //===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "TransformsDetail.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include <numeric>
20 #include <utility>
21 
22 namespace mlir {
23 namespace mesh {
24 
26  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
// Register the all-reduce endomorphism simplification patterns. Each arith op
// is instantiated together with the mesh ReductionKind it is paired with.
// NOTE(review): the actual rewrite lives in
// populateAllReduceEndomorphismSimplificationPatterns (declared elsewhere);
// presumably it requires the op to commute with an all-reduce of the given
// kind — confirm there before adding new pairings.

// Addition ops pair with ReductionKind::Sum.
populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
    patterns, ReductionKind::Sum);
populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
    patterns, ReductionKind::Sum);

// Minimum ops pair with ReductionKind::Min.
populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
    patterns, ReductionKind::Min);
populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
    patterns, ReductionKind::Min);
populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
    patterns, ReductionKind::Min);

// Maximum ops pair with ReductionKind::Max.
populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
    patterns, ReductionKind::Max);
populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
    patterns, ReductionKind::Max);
populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
    patterns, ReductionKind::Max);

// TODO: add simplifications for all-gather and other collectives.

// Also register the symbol-table-aware folding patterns (MeshShapeFolder),
// forwarding the caller's SymbolTableCollection so lookups can be cached.
populateFoldingPatterns(patterns, symbolTableCollection);
49 }
50 
51 namespace {
52 
53 // This folding cannot be done with an operation's fold method or
54 // DialectFoldInterface, because it needs a SymbolTableCollection to cache the
55 // symbol tables.
56 // We can't use DialectFoldInterface since the cache may be invalidated by some
57 // pass changing the referenced MeshOp ops.
58 struct MeshShapeFolder
59  : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
62  LogicalResult matchAndRewrite(MeshShapeOp op,
63  PatternRewriter &rewriter) const override {
64  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
65  MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
66  op.getOperation(), op.getMeshAttr());
67  if (!mesh) {
68  return failure();
69  }
70  ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
71  SmallVector<MeshAxis> opAxesIota;
72  if (opMeshAxes.empty()) {
73  opAxesIota.resize(mesh.getRank());
74  std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
75  opMeshAxes = opAxesIota;
76  }
77  if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
78  return ShapedType::isDynamic(mesh.getShape()[axis]);
79  })) {
80  // All mesh dimensions are dynamic. Nothing to fold.
81  return failure();
82  }
83 
84  SmallVector<Value> newResults(op->getResults().size());
85  SmallVector<MeshAxis> newShapeOpMeshAxes;
86  SmallVector<size_t> newToOldResultsIndexMap;
87 
88  for (size_t i = 0; i < opMeshAxes.size(); ++i) {
89  auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
90  if (ShapedType::isDynamic(meshAxisSize)) {
91  newToOldResultsIndexMap.push_back(i);
92  newShapeOpMeshAxes.push_back(opMeshAxes[i]);
93  } else {
94  // Fold static mesh axes.
95  newResults[i] = builder.create<arith::ConstantOp>(
96  builder.getIndexAttr(meshAxisSize));
97  }
98  }
99 
100  // Leave only the dynamic mesh axes to be queried.
101  if (!newShapeOpMeshAxes.empty()) {
102  MeshShapeOp newShapeOp =
103  builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
104  for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
105  newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
106  }
107  }
108  rewriter.replaceOp(op, newResults);
109 
110  return success();
111  }
112 };
113 
114 } // namespace
115 
117  SymbolTableCollection &symbolTableCollection) {
// Register MeshShapeFolder, handing it the caller's SymbolTableCollection so
// mesh symbol lookups can reuse cached symbol tables across matches.
// NOTE(review): the pattern's constructor takes the collection by reference;
// presumably it keeps that reference, so the collection must outlive the
// pattern set — confirm.
patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext());
119 }
120 
121 } // namespace mesh
122 } // namespace mlir
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
int16_t MeshAxis
Definition: MeshOps.h:25
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.
OpRewritePatternWithSymbolTableCollection(SymbolTableCollection &symbolTableCollection, OpRewritePatternArgs &&...opRewritePatternArgs)