9 #ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
10 #define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/Support/Casting.h"
20 #include <type_traits>
/// Rewrite pattern that moves a homomorphism across an algebraic op.
///
/// Each operand of a source algebraic op may be produced by a homomorphism op
/// `h`. The pattern then builds a target algebraic op on the homomorphisms'
/// inputs, clones `h` once on top of it, and redirects uses of the source op's
/// result. Schematically: `src(h(x0), h(x1), ...)` -> `h(tgt(x0, x1, ...))`.
///
/// All op-specific knowledge is supplied through the callables below. Each one
/// is stored in the member of the same name in lower camel case
/// (`CreateTargetAlgebraicOpFn` is stored as `createTargetAlgebraicOpFn`).
/// Their call shapes, as used by the pattern:
///   GetHomomorphismOpOperandFn(Operation *) -> pointer-like whose ->get() is
///     the homomorphism op's input value.
///   GetHomomorphismOpResultFn(Operation *) -> something convertible to
///     OpResult; the homomorphism op's result.
///   GetSourceAlgebraicOpOperandsFn(Operation *, SmallVector<OpOperand *> &)
///     appends the source algebraic op's operands to be checked and rewritten.
///   GetSourceAlgebraicOpResultFn(Operation *) -> the source algebraic op's
///     result, whose uses get replaced.
///   GetTargetAlgebraicOpResultFn(Operation *) -> the result of the target
///     algebraic op created by CreateTargetAlgebraicOpFn.
///   IsHomomorphismOpFn(Operation *, <std::nullopt or Operation *>) -> bool.
///     The second argument is std::nullopt when checking the first operand's
///     producer. For the other operands it is that first producer
///     (presumably a "same kind of homomorphism" check; confirm with callers).
///   IsSourceAlgebraicOpFn(Operation *) -> bool.
///   CreateTargetAlgebraicOpFn(Operation *, <value mapping> &,
///     PatternRewriter &) -> Operation *; builds the target algebraic op.
60 template <
typename GetHomomorphismOpOperandFn,
61 typename GetHomomorphismOpResultFn,
62 typename GetSourceAlgebraicOpOperandsFn,
63 typename GetSourceAlgebraicOpResultFn,
64 typename GetTargetAlgebraicOpResultFn,
typename IsHomomorphismOpFn,
65 typename IsSourceAlgebraicOpFn,
typename CreateTargetAlgebraicOpFn>
/// Constructs the pattern by perfect-forwarding each callable argument into
/// the member of the same name. Lvalues are copied and rvalues are moved.
/// The trailing `args` are not used by the member initializers listed here.
/// NOTE(review): presumably they are forwarded to the RewritePattern base
/// constructor; confirm.
67 template <
typename GetHomomorphismOpOperandFnArg,
68 typename GetHomomorphismOpResultFnArg,
69 typename GetSourceAlgebraicOpOperandsFnArg,
70 typename GetSourceAlgebraicOpResultFnArg,
71 typename GetTargetAlgebraicOpResultFnArg,
72 typename IsHomomorphismOpFnArg,
typename IsSourceAlgebraicOpFnArg,
73 typename CreateTargetAlgebraicOpFnArg,
74 typename... RewritePatternArgs>
76 GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand,
77 GetHomomorphismOpResultFnArg &&getHomomorphismOpResult,
78 GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands,
79 GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult,
80 GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult,
81 IsHomomorphismOpFnArg &&isHomomorphismOp,
82 IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp,
83 CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn,
84 RewritePatternArgs &&...args)
86 getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>(
87 getHomomorphismOpOperand)),
88 getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>(
89 getHomomorphismOpResult)),
90 getSourceAlgebraicOpOperands(
91 std::forward<GetSourceAlgebraicOpOperandsFnArg>(
92 getSourceAlgebraicOpOperands)),
93 getSourceAlgebraicOpResult(
94 std::forward<GetSourceAlgebraicOpResultFnArg>(
95 getSourceAlgebraicOpResult)),
96 getTargetAlgebraicOpResult(
97 std::forward<GetTargetAlgebraicOpResultFnArg>(
98 getTargetAlgebraicOpResult)),
99 isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)),
101 std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)),
102 createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>(
103 createTargetAlgebraicOpFn)) {}
// Match first. On success matchOp has filled `algebraicOpOperands` with the
// operands it validated, and the rewrite below reuses that same list.
108 if (failed(matchOp(op, algebraicOpOperands))) {
// Matching succeeded: perform the rewrite and propagate its result.
111 return rewriteOp(op, algebraicOpOperands, rewriter);
// Only ops accepted by the user's source-algebraic predicate are candidates.
118 if (!isSourceAlgebraicOp(sourceAlgebraicOp)) {
// Reset the caller-provided output list, then let the user callback fill it.
121 sourceAlgebraicOpOperands.clear();
122 getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands);
// With no operands there is no homomorphism to move; this also guarantees
// that front() below (and rewriteOp's non-empty assert) are valid.
123 if (sourceAlgebraicOpOperands.empty()) {
// The producer of the first operand acts as the reference homomorphism.
// It must have a defining op and pass the predicate with no reference
// (std::nullopt).
127 Operation *firstHomomorphismOp =
128 sourceAlgebraicOpOperands.front()->get().getDefiningOp();
129 if (!firstHomomorphismOp ||
130 !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) {
// The first operand must be exactly the result that the callback designates
// as the homomorphism's result.
133 OpResult firstHomomorphismOpResult =
134 getHomomorphismOpResult(firstHomomorphismOp);
135 if (firstHomomorphismOpResult != sourceAlgebraicOpOperands.front()->get()) {
// Every operand (the first included) must be produced by an op that passes
// the predicate with the first homomorphism op as the reference.
139 for (
auto operand : sourceAlgebraicOpOperands) {
140 Operation *homomorphismOp = operand->get().getDefiningOp();
141 if (!homomorphismOp ||
142 !isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) {
150 rewriteOp(Operation *sourceAlgebraicOp,
151 const SmallVector<OpOperand *> &sourceAlgebraicOpOperands,
152 PatternRewriter &rewriter)
const {
// Rewrites `src(h(x0), h(x1), ...)` into `h(tgt(x0, x1, ...))`.
// Step 1: map each operand value (a homomorphism result) to the input of the
// homomorphism op that produced it.
154 for (
auto operand : sourceAlgebraicOpOperands) {
155 Operation *homomorphismOp = operand->get().getDefiningOp();
156 irMapping.map(operand->get(),
157 getHomomorphismOpOperand(homomorphismOp)->
get());
// Step 2: the user callback builds the target algebraic op. Presumably it
// uses `irMapping` to substitute the homomorphism inputs for the source
// operands; confirm with the callback's implementation.
159 Operation *targetAlgebraicOp =
160 createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter);
// Step 3: re-create the homomorphism once by cloning the first operand's
// producer, with its input remapped to the target algebraic op's result.
// matchOp rejects an empty operand list, so element 0 exists.
163 assert(!sourceAlgebraicOpOperands.empty());
164 Operation *firstHomomorphismOp =
165 sourceAlgebraicOpOperands[0]->get().getDefiningOp();
166 irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->
get(),
167 getTargetAlgebraicOpResult(targetAlgebraicOp));
168 Operation *newHomomorphismOp =
169 rewriter.clone(*firstHomomorphismOp, irMapping);
// Step 4: redirect every use of the source algebraic op's result to the
// cloned homomorphism's result.
170 rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp),
171 getHomomorphismOpResult(newHomomorphismOp));
/// Returns the homomorphism op's input operand (its ->get() value is used).
175 GetHomomorphismOpOperandFn getHomomorphismOpOperand;
/// Returns the homomorphism op's result.
176 GetHomomorphismOpResultFn getHomomorphismOpResult;
/// Appends the source algebraic op's operands to a SmallVector<OpOperand *>.
177 GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands;
/// Returns the source algebraic op's result, whose uses are replaced.
178 GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult;
/// Returns the result of the newly created target algebraic op.
179 GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult;
/// Predicate on a candidate homomorphism op. It takes an optional reference op
/// (std::nullopt for the first operand's producer).
180 IsHomomorphismOpFn isHomomorphismOp;
/// Predicate selecting the ops this pattern may rewrite.
181 IsSourceAlgebraicOpFn isSourceAlgebraicOp;
/// Builds the target algebraic op from the source op, a value mapping and
/// the rewriter.
182 CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn;
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
RewritePattern is the common base class for all DAG to DAG replacements.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as getRootKind().
HomomorphismSimplification(GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand, GetHomomorphismOpResultFnArg &&getHomomorphismOpResult, GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands, GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult, GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult, IsHomomorphismOpFnArg &&isHomomorphismOp, IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp, CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn, RewritePatternArgs &&...args)