MLIR  19.0.0git
HomomorphismSimplification.h
Go to the documentation of this file.
1 //===- HomomorphismSimplification.h -----------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
10 #define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
11 
12 #include "mlir/IR/IRMapping.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Support/Casting.h"
19 #include <iterator>
20 #include <optional>
21 #include <type_traits>
22 #include <utility>
23 
24 namespace mlir {
25 
26 // If `h` is an homomorphism with respect to the source algebraic structure
27 // induced by function `s` and the target algebraic structure induced by
28 // function `t`, transforms `s(h(x1), h(x2) ..., h(xn))` into
29 // `h(t(x1, x2, ..., xn))`.
30 //
31 // Functors:
32 // ---------
33 // `GetHomomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
34 // Returns the operand relevant to the homomorphism.
35 // There may be other operands that are not relevant.
36 //
37 // `GetHomomorphismOpResultFn`: `(Operation*) -> OpResult`
38 // Returns the result relevant to the homomorphism.
39 //
40 // `GetSourceAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) ->
41 // void` Populates into the vector the operands relevant to the homomorphism.
42 //
43 // `GetSourceAlgebraicOpResultFn`: `(Operation*) -> OpResult`
44 // Return the result of the source algebraic operation relevant to the
45 // homomorphism.
46 //
47 // `GetTargetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
48 // Return the result of the target algebraic operation relevant to the
49 // homomorphism.
50 //
51 // `IsHomomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
52 // Check if the operation is an homomorphism of the required type.
53 // Additionally if the optional is present checks if the operations are
54 // compatible homomorphisms.
55 //
56 // `IsSourceAlgebraicOpFn`: `(Operation*) -> bool`
57 // Check if the operation is an operation of the algebraic structure.
58 //
59 // `CreateTargetAlgebraicOpFn`: `(Operation*, IRMapping& operandsRemapping,
60 // PatternRewriter &rewriter) -> Operation*`
61 template <typename GetHomomorphismOpOperandFn,
62  typename GetHomomorphismOpResultFn,
63  typename GetSourceAlgebraicOpOperandsFn,
64  typename GetSourceAlgebraicOpResultFn,
65  typename GetTargetAlgebraicOpResultFn, typename IsHomomorphismOpFn,
66  typename IsSourceAlgebraicOpFn, typename CreateTargetAlgebraicOpFn>
68  template <typename GetHomomorphismOpOperandFnArg,
69  typename GetHomomorphismOpResultFnArg,
70  typename GetSourceAlgebraicOpOperandsFnArg,
71  typename GetSourceAlgebraicOpResultFnArg,
72  typename GetTargetAlgebraicOpResultFnArg,
73  typename IsHomomorphismOpFnArg, typename IsSourceAlgebraicOpFnArg,
74  typename CreateTargetAlgebraicOpFnArg,
75  typename... RewritePatternArgs>
77  GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand,
78  GetHomomorphismOpResultFnArg &&getHomomorphismOpResult,
79  GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands,
80  GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult,
81  GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult,
82  IsHomomorphismOpFnArg &&isHomomorphismOp,
83  IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp,
84  CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn,
85  RewritePatternArgs &&...args)
86  : RewritePattern(std::forward<RewritePatternArgs>(args)...),
87  getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>(
88  getHomomorphismOpOperand)),
89  getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>(
90  getHomomorphismOpResult)),
91  getSourceAlgebraicOpOperands(
92  std::forward<GetSourceAlgebraicOpOperandsFnArg>(
93  getSourceAlgebraicOpOperands)),
94  getSourceAlgebraicOpResult(
95  std::forward<GetSourceAlgebraicOpResultFnArg>(
96  getSourceAlgebraicOpResult)),
97  getTargetAlgebraicOpResult(
98  std::forward<GetTargetAlgebraicOpResultFnArg>(
99  getTargetAlgebraicOpResult)),
100  isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)),
101  isSourceAlgebraicOp(
102  std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)),
103  createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>(
104  createTargetAlgebraicOpFn)) {}
105 
107  PatternRewriter &rewriter) const override {
108  SmallVector<OpOperand *> algebraicOpOperands;
109  if (failed(matchOp(op, algebraicOpOperands))) {
110  return failure();
111  }
112  return rewriteOp(op, algebraicOpOperands, rewriter);
113  }
114 
115 private:
117  matchOp(Operation *sourceAlgebraicOp,
118  SmallVector<OpOperand *> &sourceAlgebraicOpOperands) const {
119  if (!isSourceAlgebraicOp(sourceAlgebraicOp)) {
120  return failure();
121  }
122  sourceAlgebraicOpOperands.clear();
123  getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands);
124  if (sourceAlgebraicOpOperands.empty()) {
125  return failure();
126  }
127 
128  Operation *firstHomomorphismOp =
129  sourceAlgebraicOpOperands.front()->get().getDefiningOp();
130  if (!firstHomomorphismOp ||
131  !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) {
132  return failure();
133  }
134  OpResult firstHomomorphismOpResult =
135  getHomomorphismOpResult(firstHomomorphismOp);
136  if (firstHomomorphismOpResult != sourceAlgebraicOpOperands.front()->get()) {
137  return failure();
138  }
139 
140  for (auto operand : sourceAlgebraicOpOperands) {
141  Operation *homomorphismOp = operand->get().getDefiningOp();
142  if (!homomorphismOp ||
143  !isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) {
144  return failure();
145  }
146  }
147  return success();
148  }
149 
150  LogicalResult
151  rewriteOp(Operation *sourceAlgebraicOp,
152  const SmallVector<OpOperand *> &sourceAlgebraicOpOperands,
153  PatternRewriter &rewriter) const {
154  IRMapping irMapping;
155  for (auto operand : sourceAlgebraicOpOperands) {
156  Operation *homomorphismOp = operand->get().getDefiningOp();
157  irMapping.map(operand->get(),
158  getHomomorphismOpOperand(homomorphismOp)->get());
159  }
160  Operation *targetAlgebraicOp =
161  createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter);
162 
163  irMapping.clear();
164  assert(!sourceAlgebraicOpOperands.empty());
165  Operation *firstHomomorphismOp =
166  sourceAlgebraicOpOperands[0]->get().getDefiningOp();
167  irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->get(),
168  getTargetAlgebraicOpResult(targetAlgebraicOp));
169  Operation *newHomomorphismOp =
170  rewriter.clone(*firstHomomorphismOp, irMapping);
171  rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp),
172  getHomomorphismOpResult(newHomomorphismOp));
173  return success();
174  }
175 
176  GetHomomorphismOpOperandFn getHomomorphismOpOperand;
177  GetHomomorphismOpResultFn getHomomorphismOpResult;
178  GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands;
179  GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult;
180  GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult;
181  IsHomomorphismOpFn isHomomorphismOp;
182  IsSourceAlgebraicOpFn isSourceAlgebraicOp;
183  CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn;
184 };
185 
186 } // namespace mlir
187 
188 #endif // MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
HomomorphismSimplification(GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand, GetHomomorphismOpResultFnArg &&getHomomorphismOpResult, GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands, GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult, GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult, IsHomomorphismOpFnArg &&isHomomorphismOp, IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp, CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn, RewritePatternArgs &&...args)
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26