MLIR 22.0.0git
Simplifications.h
Go to the documentation of this file.
1//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
10#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
11
15#include "llvm/Support/Casting.h"
16#include <algorithm>
17#include <iterator>
18#include <memory>
19#include <utility>
20
21namespace mlir {
22
24
25namespace shard {
26
27// If we have an algebraic op like "+" and a summing all-reduce,
28// `all_reduce_sum(x) + all_reduce_sum(y)` will be transformed to
29// `all_reduce_sum(x + y)`.
30//
31// Another example with `min`.
32// `min(all_reduce_min(x), all_reduce_min(y))` will be transformed to
33// `all_reduce_min(min(x, y))`.
34//
35// Works only with algebraic ops that have all their operands relevant
36// to the all-reduce endomorphism.
37// Will not work with some op `f(x, y, z)` where only `x` and `y` form
38// the algebraic structure.
39template <typename AlgebraicOp>
41 RewritePatternSet &patterns, ReductionKind reduction) {
42 auto getEndomorphismOpOperand = [](Operation *op) {
43 auto allReduceOp = llvm::cast<AllReduceOp>(op);
44 return &allReduceOp.getInputMutable();
45 };
46 auto getEndomorphismOpResult = [](Operation *op) {
47 auto allReduceOp = llvm::cast<AllReduceOp>(op);
48 return allReduceOp->getResult(0);
49 };
50 auto getAlgebraicOpOperands = [](Operation *op,
51 SmallVector<OpOperand *> &operands) {
52 auto algebraicOp = llvm::cast<AlgebraicOp>(op);
53 llvm::append_range(operands,
54 llvm::make_pointer_range(algebraicOp->getOpOperands()));
55 };
56 auto getAlgebraicOpResult = [](Operation *op) {
57 auto algebraicOp = llvm::cast<AlgebraicOp>(op);
58 return algebraicOp->getResult(0);
59 };
60 auto isEndomorphismOp = [reduction](Operation *op,
61 std::optional<Operation *> referenceOp) {
62 auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
63 if (!allReduceOp)
64 return false;
65 auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
66 auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
67 if (inType.getElementType() != outType.getElementType() ||
68 allReduceOp.getReduction() != reduction) {
69 return false;
70 }
71
72 // Dont't use simplify if the all-reduce is used other than by the
73 // algebraic op.
74 // TODO: maybe handle this by an additional pass that later reverses the
75 // simplification if there are other uses left other optimizations have
76 // been done.
77 if (!allReduceOp->hasOneUse()) {
78 return false;
79 }
80
81 if (!referenceOp) {
82 return true;
83 }
84
85 auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
86 auto refType = cast<ShapedType>(refAllReduceOp.getResult().getType());
87 return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
88 inType.getElementType() == refType.getElementType();
89 };
90 auto isAlgebraicOp = [](Operation *op) { return isa<AlgebraicOp>(op); };
91
92 using ConcreteEndomorphismSimplification = EndomorphismSimplification<
93 std::decay_t<decltype(getEndomorphismOpOperand)>,
94 std::decay_t<decltype(getEndomorphismOpResult)>,
95 std::decay_t<decltype(getAlgebraicOpOperands)>,
96 std::decay_t<decltype(getAlgebraicOpResult)>,
97 std::decay_t<decltype(isEndomorphismOp)>,
98 std::decay_t<decltype(isAlgebraicOp)>>;
99 patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
100 std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
101 std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
102 std::move(isEndomorphismOp), std::move(isAlgebraicOp),
103 AlgebraicOp::getOperationName(), 1, patterns.getContext()));
104}
105
106// It is invalid to change ops that declare symbols during the application of
107// these patterns, because symbolTableCollection is used to cache them.
109 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
111 SymbolTableCollection &symbolTableCollection);
112
113} // namespace shard
114} // namespace mlir
115
116#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents a collection of SymbolTables.
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateAllReduceEndomorphismSimplificationPatterns(RewritePatternSet &patterns, ReductionKind reduction)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns