9 #ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
10 #define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Support/Casting.h"
21 #include <type_traits>
// Template parameter list for the pattern class; the declaration it introduces
// is not visible in this excerpt (presumably `HomomorphismSimplification` --
// TODO confirm). Each parameter is the type of the like-named callable that the
// class stores as a data member:
//   GetHomomorphismOpOperandFn     -> getHomomorphismOpOperand
//   GetHomomorphismOpResultFn      -> getHomomorphismOpResult
//   GetSourceAlgebraicOpOperandsFn -> getSourceAlgebraicOpOperands
//   GetSourceAlgebraicOpResultFn   -> getSourceAlgebraicOpResult
//   GetTargetAlgebraicOpResultFn   -> getTargetAlgebraicOpResult
//   IsHomomorphismOpFn             -> isHomomorphismOp
//   IsSourceAlgebraicOpFn          -> isSourceAlgebraicOp
//   CreateTargetAlgebraicOpFn      -> createTargetAlgebraicOpFn
61 template <
typename GetHomomorphismOpOperandFn,
62 typename GetHomomorphismOpResultFn,
63 typename GetSourceAlgebraicOpOperandsFn,
64 typename GetSourceAlgebraicOpResultFn,
65 typename GetTargetAlgebraicOpResultFn,
typename IsHomomorphismOpFn,
66 typename IsSourceAlgebraicOpFn,
typename CreateTargetAlgebraicOpFn>
// Constructor template. Every callable is accepted through its own deduced
// `...Arg` type as a forwarding reference, so the stored members can be
// initialized from either lvalues or rvalues.
// NOTE(review): the line naming the constructor and the first entry of the
// initializer list are not visible in this excerpt.
68 template <
typename GetHomomorphismOpOperandFnArg,
69 typename GetHomomorphismOpResultFnArg,
70 typename GetSourceAlgebraicOpOperandsFnArg,
71 typename GetSourceAlgebraicOpResultFnArg,
72 typename GetTargetAlgebraicOpResultFnArg,
73 typename IsHomomorphismOpFnArg,
typename IsSourceAlgebraicOpFnArg,
74 typename CreateTargetAlgebraicOpFnArg,
75 typename... RewritePatternArgs>
// Parameters: one forwarding reference per stored callable, followed by a
// trailing pack. No initializer visible here consumes `args`; presumably it is
// forwarded to the RewritePattern base class -- TODO confirm.
77 GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand,
78 GetHomomorphismOpResultFnArg &&getHomomorphismOpResult,
79 GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands,
80 GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult,
81 GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult,
82 IsHomomorphismOpFnArg &&isHomomorphismOp,
83 IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp,
84 CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn,
85 RewritePatternArgs &&...args)
// Member initializers: each member is initialized from the same-named
// parameter via std::forward with that parameter's deduced type.
87 getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>(
88 getHomomorphismOpOperand)),
89 getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>(
90 getHomomorphismOpResult)),
91 getSourceAlgebraicOpOperands(
92 std::forward<GetSourceAlgebraicOpOperandsFnArg>(
93 getSourceAlgebraicOpOperands)),
94 getSourceAlgebraicOpResult(
95 std::forward<GetSourceAlgebraicOpResultFnArg>(
96 getSourceAlgebraicOpResult)),
97 getTargetAlgebraicOpResult(
98 std::forward<GetTargetAlgebraicOpResultFnArg>(
99 getTargetAlgebraicOpResult)),
100 isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)),
// NOTE(review): the opening of the next initializer is missing from this
// excerpt; from the forwarded argument it is presumably `isSourceAlgebraicOp(`.
102 std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)),
103 createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>(
104 createTargetAlgebraicOpFn)) {}
// Fragment of a function body (presumably matchAndRewrite -- its signature and
// the declaration of `algebraicOpOperands` are not visible in this excerpt).
// matchOp is run first and receives `algebraicOpOperands`; the body of the
// failure branch is not shown here.
109 if (
failed(matchOp(op, algebraicOpOperands))) {
// Otherwise the same operand list is handed to rewriteOp, whose result is
// returned unchanged.
112 return rewriteOp(op, algebraicOpOperands, rewriter);
// Fragment of the matching routine (presumably matchOp -- its signature and the
// bodies of all the guard branches below are not visible in this excerpt).
//
// Guard 1: the candidate op must be accepted by the isSourceAlgebraicOp
// predicate.
119 if (!isSourceAlgebraicOp(sourceAlgebraicOp)) {
// Reset the operand list before handing it to the callback, then guard
// against the callback leaving it empty (front() is used right below).
122 sourceAlgebraicOpOperands.clear();
123 getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands);
124 if (sourceAlgebraicOpOperands.empty()) {
// Guard 2: the value used as first operand must have a defining op, and that
// op must be accepted by isHomomorphismOp. std::nullopt is passed as the
// second argument because there is no earlier homomorphism op to refer to.
128 Operation *firstHomomorphismOp =
129 sourceAlgebraicOpOperands.front()->get().getDefiningOp();
130 if (!firstHomomorphismOp ||
131 !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) {
// Guard 3: the result that getHomomorphismOpResult reports for the first
// homomorphism op must be the very value consumed as the first operand.
134 OpResult firstHomomorphismOpResult =
135 getHomomorphismOpResult(firstHomomorphismOp);
136 if (firstHomomorphismOpResult != sourceAlgebraicOpOperands.front()->get()) {
// Guard 4: every operand (the first one included) must be defined by an op
// that isHomomorphismOp accepts when given the first homomorphism op as its
// second argument.
// NOTE(review): unlike Guard 3, no visible line checks that each operand is
// the value reported by getHomomorphismOpResult for its defining op -- confirm
// whether isHomomorphismOp is expected to cover that.
140 for (
auto operand : sourceAlgebraicOpOperands) {
141 Operation *homomorphismOp = operand->get().getDefiningOp();
142 if (!homomorphismOp ||
143 !isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) {
// Rewrites `sourceAlgebraicOp`, whose operands are listed in
// `sourceAlgebraicOpOperands`, into a target algebraic op followed by a clone
// of the first homomorphism op.
// NOTE(review): the return type line, the declaration of `irMapping`, the
// closing braces and the final return are not visible in this excerpt.
151 rewriteOp(Operation *sourceAlgebraicOp,
152 const SmallVector<OpOperand *> &sourceAlgebraicOpOperands,
153 PatternRewriter &rewriter)
const {
// Step 1: for every operand, record in irMapping a mapping from the value the
// source op consumes to the value that the defining homomorphism op itself
// consumes (as reported by getHomomorphismOpOperand).
155 for (
auto operand : sourceAlgebraicOpOperands) {
156 Operation *homomorphismOp = operand->get().getDefiningOp();
157 irMapping.map(operand->get(),
158 getHomomorphismOpOperand(homomorphismOp)->
get());
// Step 2: let the user-supplied callback create the target algebraic op from
// the source op, the mapping built above, and the rewriter.
160 Operation *targetAlgebraicOp =
161 createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter);
// Index 0 is read right below, so the operand list must not be empty.
164 assert(!sourceAlgebraicOpOperands.empty());
// Step 3: map the first homomorphism op's own operand value to the result of
// the target algebraic op, then clone that homomorphism op with the mapping
// (presumably so the clone consumes the target op's result -- see the
// RewriterBase::clone / IRMapping documentation).
165 Operation *firstHomomorphismOp =
166 sourceAlgebraicOpOperands[0]->get().getDefiningOp();
167 irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->
get(),
168 getTargetAlgebraicOpResult(targetAlgebraicOp));
169 Operation *newHomomorphismOp =
170 rewriter.clone(*firstHomomorphismOp, irMapping);
// Step 4: redirect every use of the source algebraic op's result to the
// result of the freshly cloned homomorphism op.
171 rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp),
172 getHomomorphismOpResult(newHomomorphismOp));
// Stored callables, one per class template parameter. Each is initialized by
// the constructor from the same-named argument and invoked by the matching and
// rewriting code.
// Given a homomorphism op, yields a pointer-like handle whose get() is the
// value that op consumes.
176 GetHomomorphismOpOperandFn getHomomorphismOpOperand;
// Given a homomorphism op, yields the result it produces.
177 GetHomomorphismOpResultFn getHomomorphismOpResult;
// Fills the list passed as second argument with the source op's operands.
178 GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands;
// Given the source algebraic op, yields the result whose uses get replaced.
179 GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult;
// Given the newly created target algebraic op, yields its result.
180 GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult;
// Predicate over (candidate op, reference homomorphism op or std::nullopt).
181 IsHomomorphismOpFn isHomomorphismOp;
// Predicate deciding whether an op is a source algebraic op to be matched.
182 IsSourceAlgebraicOpFn isSourceAlgebraicOp;
// Factory called with (source op, IR mapping, rewriter) to create the target
// algebraic op.
183 CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn;
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
HomomorphismSimplification(GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand, GetHomomorphismOpResultFnArg &&getHomomorphismOpResult, GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands, GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult, GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult, IsHomomorphismOpFnArg &&isHomomorphismOp, IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp, CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn, RewritePatternArgs &&...args)
This class represents an efficient way to signal success or failure.