MLIR  22.0.0git
Simplifications.h
Go to the documentation of this file.
1 //===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
10 #define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
11 
13 #include "mlir/IR/PatternMatch.h"
15 #include "llvm/Support/Casting.h"
16 #include <algorithm>
17 #include <iterator>
18 #include <memory>
19 #include <utility>
20 
21 namespace mlir {
22 
23 class SymbolTableCollection;
24 
25 namespace shard {
26 
27 // If we have an algebraic op like "+" and a summing all-reduce,
28 // `all_reduce_sum(x) + all_reduce_sum(y)` will be transformed to
29 // `all_reduce_sum(x + y)`.
30 //
31 // Another example with `min`.
32 // `min(all_reduce_min(x), all_reduce_min(y))` will be transformed to
33 // `all_reduce_min(min(x, y))`.
34 //
35 // Works only with algebraic ops that have all their operands relevant
36 // to the all-reduce endomorphism.
37 // Will not work with some op `f(x, y, z)` where only `x` and `y` form
38 // the algebraic structure.
39 template <typename AlgebraicOp>
42  auto getEndomorphismOpOperand = [](Operation *op) {
43  auto allReduceOp = llvm::cast<AllReduceOp>(op);
44  return &allReduceOp.getInputMutable();
45  };
46  auto getEndomorphismOpResult = [](Operation *op) {
47  auto allReduceOp = llvm::cast<AllReduceOp>(op);
48  return allReduceOp->getResult(0);
49  };
50  auto getAlgebraicOpOperands = [](Operation *op,
51  SmallVector<OpOperand *> &operands) {
52  auto algebraicOp = llvm::cast<AlgebraicOp>(op);
53  std::transform(algebraicOp->getOpOperands().begin(),
54  algebraicOp->getOpOperands().end(),
55  std::back_inserter(operands),
56  [](OpOperand &operand) { return &operand; });
57  };
58  auto getAlgebraicOpResult = [](Operation *op) {
59  auto algebraicOp = llvm::cast<AlgebraicOp>(op);
60  return algebraicOp->getResult(0);
61  };
62  auto isEndomorphismOp = [reduction](Operation *op,
63  std::optional<Operation *> referenceOp) {
64  auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
65  if (!allReduceOp)
66  return false;
67  auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
68  auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
69  if (inType.getElementType() != outType.getElementType() ||
70  allReduceOp.getReduction() != reduction) {
71  return false;
72  }
73 
74  // Dont't use simplify if the all-reduce is used other than by the
75  // algebraic op.
76  // TODO: maybe handle this by an additional pass that later reverses the
77  // simplification if there are other uses left other optimizations have
78  // been done.
79  if (!allReduceOp->hasOneUse()) {
80  return false;
81  }
82 
83  if (!referenceOp) {
84  return true;
85  }
86 
87  auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
88  auto refType = cast<ShapedType>(refAllReduceOp.getResult().getType());
89  return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
90  inType.getElementType() == refType.getElementType();
91  };
92  auto isAlgebraicOp = [](Operation *op) { return isa<AlgebraicOp>(op); };
93 
94  using ConcreteEndomorphismSimplification = EndomorphismSimplification<
95  std::decay_t<decltype(getEndomorphismOpOperand)>,
96  std::decay_t<decltype(getEndomorphismOpResult)>,
97  std::decay_t<decltype(getAlgebraicOpOperands)>,
98  std::decay_t<decltype(getAlgebraicOpResult)>,
99  std::decay_t<decltype(isEndomorphismOp)>,
100  std::decay_t<decltype(isAlgebraicOp)>>;
101  patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
102  std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
103  std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
104  std::move(isEndomorphismOp), std::move(isAlgebraicOp),
105  AlgebraicOp::getOperationName(), 1, patterns.getContext()));
106 }
107 
108 // It is invalid to change ops that declare symbols during the application of
109 // these patterns, because symbolTableCollection is used to cache them.
111  RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
113  SymbolTableCollection &symbolTableCollection);
114 
115 } // namespace shard
116 } // namespace mlir
117 
118 #endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
shard::ReductionKind ReductionKind
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateAllReduceEndomorphismSimplificationPatterns(RewritePatternSet &patterns, ReductionKind reduction)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns