MLIR  19.0.0git
Simplifications.cpp
Go to the documentation of this file.
1 //===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "TransformsDetail.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/SymbolTable.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include <numeric>
21 #include <utility>
22 
23 namespace mlir {
24 namespace mesh {
25 
27  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
28  populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
29  patterns, ReductionKind::Sum);
30  populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
31  patterns, ReductionKind::Sum);
32 
33  populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
34  patterns, ReductionKind::Min);
35  populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
36  patterns, ReductionKind::Min);
37  populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
38  patterns, ReductionKind::Min);
39 
40  populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
41  patterns, ReductionKind::Max);
42  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
43  patterns, ReductionKind::Max);
44  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
45  patterns, ReductionKind::Max);
46 
47  // TODO: add simplifications for all-gather and other collectives.
48 
49  populateFoldingPatterns(patterns, symbolTableCollection);
50 }
51 
52 namespace {
53 
54 // This folding can not be done with an operation's fold method or
55 // DialectFoldInterface, because it needs a SymbolTableCollection to cache the
56 // symbol tables.
57 // We can't use DialectFoldInterface since the cache may be invalidated by some
58 // pass changing the referenced MeshOp ops.
59 struct MeshShapeFolder
60  : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
63  LogicalResult matchAndRewrite(MeshShapeOp op,
64  PatternRewriter &rewriter) const override {
65  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
66  MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
67  op.getOperation(), op.getMeshAttr());
68  if (!mesh) {
69  return failure();
70  }
71  ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
72  SmallVector<MeshAxis> opAxesIota;
73  if (opMeshAxes.empty()) {
74  opAxesIota.resize(mesh.getRank());
75  std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
76  opMeshAxes = opAxesIota;
77  }
78  if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
79  return ShapedType::isDynamic(mesh.getShape()[axis]);
80  })) {
81  // All mesh dimensions are dynamic. Nothing to fold.
82  return failure();
83  }
84 
85  SmallVector<Value> newResults(op->getResults().size());
86  SmallVector<MeshAxis> newShapeOpMeshAxes;
87  SmallVector<size_t> newToOldResultsIndexMap;
88 
89  for (size_t i = 0; i < opMeshAxes.size(); ++i) {
90  auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
91  if (ShapedType::isDynamic(meshAxisSize)) {
92  newToOldResultsIndexMap.push_back(i);
93  newShapeOpMeshAxes.push_back(opMeshAxes[i]);
94  } else {
95  // Fold static mesh axes.
96  newResults[i] = builder.create<arith::ConstantOp>(
97  builder.getIndexAttr(meshAxisSize));
98  }
99  }
100 
101  // Leave only the dynamic mesh axes to be queried.
102  if (!newShapeOpMeshAxes.empty()) {
103  MeshShapeOp newShapeOp =
104  builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
105  for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
106  newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
107  }
108  }
109  rewriter.replaceOp(op, newResults);
110 
111  return success();
112  }
113 };
114 
115 } // namespace
116 
118  SymbolTableCollection &symbolTableCollection) {
119  patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext());
120 }
121 
122 } // namespace mesh
123 } // namespace mlir
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
int16_t MeshAxis
Definition: MeshOps.h:25
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePatternWithSymbolTableCollection(SymbolTableCollection &symbolTableCollection, OpRewritePatternArgs &&...opRewritePatternArgs)