MLIR  19.0.0git
Simplifications.h
Go to the documentation of this file.
1 //===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
10 #define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
11 
13 #include "mlir/IR/PatternMatch.h"
15 #include "llvm/Support/Casting.h"
16 #include <algorithm>
17 #include <iterator>
18 #include <memory>
19 #include <utility>
20 
21 namespace mlir {
22 
23 class SymbolTableCollection;
24 
25 namespace mesh {
26 
27 // If we have an algebraic op like "+" and a summing all-reduce,
28 // `all_reduce_sum(x) + all_reduce_sum(y)` will be transformed to
29 // `all_reduce_sum(x + y)`.
30 //
31 // Another example with `min`.
32 // `min(all_reduce_min(x), all_reduce_min(y))` will be transformed to
33 // `all_reduce_min(min(x, y))`.
34 //
35 // Works only with algebraic ops that have all their operands relevant
36 // to the all-reduce endomorphism.
37 // Will not work with some op `f(x, y, z)` where only `x` and `y` form
38 // the algebraic structure.
39 template <typename AlgebraicOp>
41  RewritePatternSet &patterns, ReductionKind reduction) {
42  auto getEndomorphismOpOperand = [](Operation *op) {
43  auto allReduceOp = llvm::cast<AllReduceOp>(op);
44  return &allReduceOp.getInputMutable();
45  };
46  auto getEndomorphismOpResult = [](Operation *op) {
47  auto allReduceOp = llvm::cast<AllReduceOp>(op);
48  return allReduceOp->getResult(0);
49  };
50  auto getAlgebraicOpOperands = [](Operation *op,
51  SmallVector<OpOperand *> &operands) {
52  auto algebraicOp = llvm::cast<AlgebraicOp>(op);
53  std::transform(algebraicOp->getOpOperands().begin(),
54  algebraicOp->getOpOperands().end(),
55  std::back_inserter(operands),
56  [](OpOperand &operand) { return &operand; });
57  };
58  auto getAlgebraicOpResult = [](Operation *op) {
59  auto algebraicOp = llvm::cast<AlgebraicOp>(op);
60  return algebraicOp->getResult(0);
61  };
62  auto isEndomorphismOp = [reduction](Operation *op,
63  std::optional<Operation *> referenceOp) {
64  auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
65  if (!allReduceOp ||
66  allReduceOp.getInput().getType().getElementType() !=
67  allReduceOp.getResult().getType().getElementType() ||
68  allReduceOp.getReduction() != reduction) {
69  return false;
70  }
71 
72  // Dont't use simplify if the all-reduce is used other than by the
73  // algebraic op.
74  // TODO: maybe handle this by an additional pass that later reverses the
75  // simplification if there are other uses left other optimizations have
76  // been done.
77  if (!allReduceOp->hasOneUse()) {
78  return false;
79  }
80 
81  if (!referenceOp) {
82  return true;
83  }
84 
85  auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
86  return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
87  allReduceOp.getInput().getType().getElementType() ==
88  refAllReduceOp.getInput().getType().getElementType();
89  };
90  auto isAlgebraicOp = [](Operation *op) {
91  return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
92  };
93 
94  using ConcreteEndomorphismSimplification = EndomorphismSimplification<
95  std::decay_t<decltype(getEndomorphismOpOperand)>,
96  std::decay_t<decltype(getEndomorphismOpResult)>,
97  std::decay_t<decltype(getAlgebraicOpOperands)>,
98  std::decay_t<decltype(getAlgebraicOpResult)>,
99  std::decay_t<decltype(isEndomorphismOp)>,
100  std::decay_t<decltype(isAlgebraicOp)>>;
101  patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
102  std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
103  std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
104  std::move(isEndomorphismOp), std::move(isAlgebraicOp),
105  AlgebraicOp::getOperationName(), 1, patterns.getContext()));
106 }
107 
108 // It is invalid to change ops that declare symbols during the application of
109 // these patterns, because symbolTableCollection is used to cache them.
111  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
113  SymbolTableCollection &symbolTableCollection);
114 
115 } // namespace mesh
116 } // namespace mlir
117 
118 #endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
mesh::ReductionKind ReductionKind
void populateAllReduceEndomorphismSimplificationPatterns(RewritePatternSet &patterns, ReductionKind reduction)
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.