MLIR  20.0.0git
HomomorphismSimplification.h
Go to the documentation of this file.
1 //===- HomomorphismSimplification.h -----------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
10 #define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
11 
12 #include "mlir/IR/IRMapping.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/Support/Casting.h"
18 #include <iterator>
19 #include <optional>
20 #include <type_traits>
21 #include <utility>
22 
23 namespace mlir {
24 
25 // If `h` is a homomorphism with respect to the source algebraic structure
26 // induced by function `s` and the target algebraic structure induced by
27 // function `t`, transforms `s(h(x1), h(x2), ..., h(xn))` into
28 // `h(t(x1, x2, ..., xn))`.
29 //
30 // Functors:
31 // ---------
32 // `GetHomomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
33 // Returns the operand relevant to the homomorphism.
34 // There may be other operands that are not relevant.
35 //
36 // `GetHomomorphismOpResultFn`: `(Operation*) -> OpResult`
37 // Returns the result relevant to the homomorphism.
38 //
39 // `GetSourceAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> void`
40 // Populates the vector with the operands relevant to the homomorphism.
41 //
42 // `GetSourceAlgebraicOpResultFn`: `(Operation*) -> OpResult`
43 // Returns the result of the source algebraic operation relevant to the
44 // homomorphism.
45 //
46 // `GetTargetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
47 // Returns the result of the target algebraic operation relevant to the
48 // homomorphism.
49 //
50 // `IsHomomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
51 // Checks if the operation is a homomorphism of the required type.
52 // Additionally, if the optional is present, checks if the operations are
53 // compatible homomorphisms.
54 //
55 // `IsSourceAlgebraicOpFn`: `(Operation*) -> bool`
56 // Checks if the operation is an operation of the algebraic structure.
57 //
58 // `CreateTargetAlgebraicOpFn`: `(Operation*, IRMapping& operandsRemapping, PatternRewriter &rewriter) -> Operation*`
59 // Creates the target algebraic operation; `operandsRemapping` maps each relevant source operand to the corresponding homomorphism input.
60 template <typename GetHomomorphismOpOperandFn,
61  typename GetHomomorphismOpResultFn,
62  typename GetSourceAlgebraicOpOperandsFn,
63  typename GetSourceAlgebraicOpResultFn,
64  typename GetTargetAlgebraicOpResultFn, typename IsHomomorphismOpFn,
65  typename IsSourceAlgebraicOpFn, typename CreateTargetAlgebraicOpFn>
67  template <typename GetHomomorphismOpOperandFnArg,
68  typename GetHomomorphismOpResultFnArg,
69  typename GetSourceAlgebraicOpOperandsFnArg,
70  typename GetSourceAlgebraicOpResultFnArg,
71  typename GetTargetAlgebraicOpResultFnArg,
72  typename IsHomomorphismOpFnArg, typename IsSourceAlgebraicOpFnArg,
73  typename CreateTargetAlgebraicOpFnArg,
74  typename... RewritePatternArgs>
76  GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand,
77  GetHomomorphismOpResultFnArg &&getHomomorphismOpResult,
78  GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands,
79  GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult,
80  GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult,
81  IsHomomorphismOpFnArg &&isHomomorphismOp,
82  IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp,
83  CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn,
84  RewritePatternArgs &&...args)
85  : RewritePattern(std::forward<RewritePatternArgs>(args)...),
86  getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>(
87  getHomomorphismOpOperand)),
88  getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>(
89  getHomomorphismOpResult)),
90  getSourceAlgebraicOpOperands(
91  std::forward<GetSourceAlgebraicOpOperandsFnArg>(
92  getSourceAlgebraicOpOperands)),
93  getSourceAlgebraicOpResult(
94  std::forward<GetSourceAlgebraicOpResultFnArg>(
95  getSourceAlgebraicOpResult)),
96  getTargetAlgebraicOpResult(
97  std::forward<GetTargetAlgebraicOpResultFnArg>(
98  getTargetAlgebraicOpResult)),
99  isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)),
100  isSourceAlgebraicOp(
101  std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)),
102  createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>(
103  createTargetAlgebraicOpFn)) {}
104 
105  LogicalResult matchAndRewrite(Operation *op,
106  PatternRewriter &rewriter) const override {
107  SmallVector<OpOperand *> algebraicOpOperands;
108  if (failed(matchOp(op, algebraicOpOperands))) {
109  return failure();
110  }
111  return rewriteOp(op, algebraicOpOperands, rewriter);
112  }
113 
114 private:
115  LogicalResult
116  matchOp(Operation *sourceAlgebraicOp,
117  SmallVector<OpOperand *> &sourceAlgebraicOpOperands) const {
118  if (!isSourceAlgebraicOp(sourceAlgebraicOp)) {
119  return failure();
120  }
121  sourceAlgebraicOpOperands.clear();
122  getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands);
123  if (sourceAlgebraicOpOperands.empty()) {
124  return failure();
125  }
126 
127  Operation *firstHomomorphismOp =
128  sourceAlgebraicOpOperands.front()->get().getDefiningOp();
129  if (!firstHomomorphismOp ||
130  !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) {
131  return failure();
132  }
133  OpResult firstHomomorphismOpResult =
134  getHomomorphismOpResult(firstHomomorphismOp);
135  if (firstHomomorphismOpResult != sourceAlgebraicOpOperands.front()->get()) {
136  return failure();
137  }
138 
139  for (auto operand : sourceAlgebraicOpOperands) {
140  Operation *homomorphismOp = operand->get().getDefiningOp();
141  if (!homomorphismOp ||
142  !isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) {
143  return failure();
144  }
145  }
146  return success();
147  }
148 
149  LogicalResult
150  rewriteOp(Operation *sourceAlgebraicOp,
151  const SmallVector<OpOperand *> &sourceAlgebraicOpOperands,
152  PatternRewriter &rewriter) const {
153  IRMapping irMapping;
154  for (auto operand : sourceAlgebraicOpOperands) {
155  Operation *homomorphismOp = operand->get().getDefiningOp();
156  irMapping.map(operand->get(),
157  getHomomorphismOpOperand(homomorphismOp)->get());
158  }
159  Operation *targetAlgebraicOp =
160  createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter);
161 
162  irMapping.clear();
163  assert(!sourceAlgebraicOpOperands.empty());
164  Operation *firstHomomorphismOp =
165  sourceAlgebraicOpOperands[0]->get().getDefiningOp();
166  irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->get(),
167  getTargetAlgebraicOpResult(targetAlgebraicOp));
168  Operation *newHomomorphismOp =
169  rewriter.clone(*firstHomomorphismOp, irMapping);
170  rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp),
171  getHomomorphismOpResult(newHomomorphismOp));
172  return success();
173  }
174 
175  GetHomomorphismOpOperandFn getHomomorphismOpOperand;
176  GetHomomorphismOpResultFn getHomomorphismOpResult;
177  GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands;
178  GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult;
179  GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult;
180  IsHomomorphismOpFn isHomomorphismOp;
181  IsSourceAlgebraicOpFn isSourceAlgebraicOp;
182  CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn;
183 };
184 
185 } // namespace mlir
186 
187 #endif // MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
HomomorphismSimplification(GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand, GetHomomorphismOpResultFnArg &&getHomomorphismOpResult, GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands, GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult, GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult, IsHomomorphismOpFnArg &&isHomomorphismOp, IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp, CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn, RewritePatternArgs &&...args)