MLIR 22.0.0git
HomomorphismSimplification.h
Go to the documentation of this file.
1//===- HomomorphismSimplification.h -----------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
10#define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
11
12#include "mlir/IR/IRMapping.h"
14#include "mlir/IR/Value.h"
15#include "mlir/Support/LLVM.h"
16#include "llvm/ADT/SmallVector.h"
17#include "llvm/Support/Casting.h"
18#include <iterator>
19#include <optional>
20#include <type_traits>
21#include <utility>
22
23namespace mlir {
24
// If `h` is a homomorphism with respect to the source algebraic structure
// induced by function `s` and the target algebraic structure induced by
// function `t`, transforms `s(h(x1), h(x2), ..., h(xn))` into
// `h(t(x1, x2, ..., xn))`.
29//
30// Functors:
31// ---------
32// `GetHomomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
33// Returns the operand relevant to the homomorphism.
34// There may be other operands that are not relevant.
35//
36// `GetHomomorphismOpResultFn`: `(Operation*) -> OpResult`
37// Returns the result relevant to the homomorphism.
38//
// `GetSourceAlgebraicOpOperandsFn`:
// `(Operation*, SmallVector<OpOperand*>&) -> void`
// Populates the vector with the operands relevant to the homomorphism.
41//
42// `GetSourceAlgebraicOpResultFn`: `(Operation*) -> OpResult`
43// Return the result of the source algebraic operation relevant to the
44// homomorphism.
45//
46// `GetTargetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
47// Return the result of the target algebraic operation relevant to the
48// homomorphism.
49//
// `IsHomomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
// Checks whether the operation is a homomorphism of the required type.
// Additionally, if the optional is present, checks whether the two operations
// are compatible homomorphisms.
54//
55// `IsSourceAlgebraicOpFn`: `(Operation*) -> bool`
56// Check if the operation is an operation of the algebraic structure.
57//
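// `CreateTargetAlgebraicOpFn`: `(Operation*, IRMapping& operandsRemapping,
// PatternRewriter &rewriter) -> Operation*`
// Creates the target algebraic operation for the given source algebraic
// operation. `operandsRemapping` maps each relevant source operand `h(xi)` to
// the corresponding homomorphism input `xi`. Returns the created operation.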
60template <typename GetHomomorphismOpOperandFn,
61 typename GetHomomorphismOpResultFn,
62 typename GetSourceAlgebraicOpOperandsFn,
63 typename GetSourceAlgebraicOpResultFn,
64 typename GetTargetAlgebraicOpResultFn, typename IsHomomorphismOpFn,
65 typename IsSourceAlgebraicOpFn, typename CreateTargetAlgebraicOpFn>
67 template <typename GetHomomorphismOpOperandFnArg,
68 typename GetHomomorphismOpResultFnArg,
69 typename GetSourceAlgebraicOpOperandsFnArg,
70 typename GetSourceAlgebraicOpResultFnArg,
71 typename GetTargetAlgebraicOpResultFnArg,
72 typename IsHomomorphismOpFnArg, typename IsSourceAlgebraicOpFnArg,
73 typename CreateTargetAlgebraicOpFnArg,
74 typename... RewritePatternArgs>
76 GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand,
77 GetHomomorphismOpResultFnArg &&getHomomorphismOpResult,
78 GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands,
79 GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult,
80 GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult,
81 IsHomomorphismOpFnArg &&isHomomorphismOp,
82 IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp,
83 CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn,
84 RewritePatternArgs &&...args)
85 : RewritePattern(std::forward<RewritePatternArgs>(args)...),
86 getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>(
87 getHomomorphismOpOperand)),
88 getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>(
89 getHomomorphismOpResult)),
90 getSourceAlgebraicOpOperands(
91 std::forward<GetSourceAlgebraicOpOperandsFnArg>(
92 getSourceAlgebraicOpOperands)),
93 getSourceAlgebraicOpResult(
94 std::forward<GetSourceAlgebraicOpResultFnArg>(
95 getSourceAlgebraicOpResult)),
96 getTargetAlgebraicOpResult(
97 std::forward<GetTargetAlgebraicOpResultFnArg>(
98 getTargetAlgebraicOpResult)),
99 isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)),
100 isSourceAlgebraicOp(
101 std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)),
102 createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>(
103 createTargetAlgebraicOpFn)) {}
104
105 LogicalResult matchAndRewrite(Operation *op,
106 PatternRewriter &rewriter) const override {
107 SmallVector<OpOperand *> algebraicOpOperands;
108 if (failed(matchOp(op, algebraicOpOperands))) {
109 return failure();
110 }
111 return rewriteOp(op, algebraicOpOperands, rewriter);
112 }
113
114private:
115 LogicalResult
116 matchOp(Operation *sourceAlgebraicOp,
117 SmallVector<OpOperand *> &sourceAlgebraicOpOperands) const {
118 if (!isSourceAlgebraicOp(sourceAlgebraicOp)) {
119 return failure();
120 }
121 sourceAlgebraicOpOperands.clear();
122 getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands);
123 if (sourceAlgebraicOpOperands.empty()) {
124 return failure();
125 }
126
127 Operation *firstHomomorphismOp =
128 sourceAlgebraicOpOperands.front()->get().getDefiningOp();
129 if (!firstHomomorphismOp ||
130 !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) {
131 return failure();
132 }
133 OpResult firstHomomorphismOpResult =
134 getHomomorphismOpResult(firstHomomorphismOp);
135 if (firstHomomorphismOpResult != sourceAlgebraicOpOperands.front()->get()) {
136 return failure();
137 }
138
139 for (auto operand : sourceAlgebraicOpOperands) {
140 Operation *homomorphismOp = operand->get().getDefiningOp();
141 if (!homomorphismOp ||
142 !isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) {
143 return failure();
144 }
145 }
146 return success();
147 }
148
149 LogicalResult
150 rewriteOp(Operation *sourceAlgebraicOp,
151 const SmallVector<OpOperand *> &sourceAlgebraicOpOperands,
152 PatternRewriter &rewriter) const {
153 IRMapping irMapping;
154 for (auto operand : sourceAlgebraicOpOperands) {
155 Operation *homomorphismOp = operand->get().getDefiningOp();
156 irMapping.map(operand->get(),
157 getHomomorphismOpOperand(homomorphismOp)->get());
158 }
159 Operation *targetAlgebraicOp =
160 createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter);
161
162 irMapping.clear();
163 assert(!sourceAlgebraicOpOperands.empty());
164 Operation *firstHomomorphismOp =
165 sourceAlgebraicOpOperands[0]->get().getDefiningOp();
166 irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->get(),
167 getTargetAlgebraicOpResult(targetAlgebraicOp));
168 Operation *newHomomorphismOp =
169 rewriter.clone(*firstHomomorphismOp, irMapping);
170 rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp),
171 getHomomorphismOpResult(newHomomorphismOp));
172 return success();
173 }
174
175 GetHomomorphismOpOperandFn getHomomorphismOpOperand;
176 GetHomomorphismOpResultFn getHomomorphismOpResult;
177 GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands;
178 GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult;
179 GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult;
180 IsHomomorphismOpFn isHomomorphismOp;
181 IsSourceAlgebraicOpFn isSourceAlgebraicOp;
182 CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn;
183};
184
185} // namespace mlir
186
187#endif // MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
return success()
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
HomomorphismSimplification(GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand, GetHomomorphismOpResultFnArg &&getHomomorphismOpResult, GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands, GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult, GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult, IsHomomorphismOpFnArg &&isHomomorphismOp, IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp, CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn, RewritePatternArgs &&...args)