9 #ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_ 
   10 #define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_ 
   16 #include "llvm/ADT/SmallVector.h" 
   17 #include "llvm/Support/Casting.h" 
   20 #include <type_traits> 
   60 template <
typename GetHomomorphismOpOperandFn,
 
   61           typename GetHomomorphismOpResultFn,
 
   62           typename GetSourceAlgebraicOpOperandsFn,
 
   63           typename GetSourceAlgebraicOpResultFn,
 
   64           typename GetTargetAlgebraicOpResultFn, 
typename IsHomomorphismOpFn,
 
   65           typename IsSourceAlgebraicOpFn, 
typename CreateTargetAlgebraicOpFn>
 
   67   template <
typename GetHomomorphismOpOperandFnArg,
 
   68             typename GetHomomorphismOpResultFnArg,
 
   69             typename GetSourceAlgebraicOpOperandsFnArg,
 
   70             typename GetSourceAlgebraicOpResultFnArg,
 
   71             typename GetTargetAlgebraicOpResultFnArg,
 
   72             typename IsHomomorphismOpFnArg, 
typename IsSourceAlgebraicOpFnArg,
 
   73             typename CreateTargetAlgebraicOpFnArg,
 
   74             typename... RewritePatternArgs>
 
   76       GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand,
 
   77       GetHomomorphismOpResultFnArg &&getHomomorphismOpResult,
 
   78       GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands,
 
   79       GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult,
 
   80       GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult,
 
   81       IsHomomorphismOpFnArg &&isHomomorphismOp,
 
   82       IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp,
 
   83       CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn,
 
   84       RewritePatternArgs &&...args)
 
   86         getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>(
 
   87             getHomomorphismOpOperand)),
 
   88         getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>(
 
   89             getHomomorphismOpResult)),
 
   90         getSourceAlgebraicOpOperands(
 
   91             std::forward<GetSourceAlgebraicOpOperandsFnArg>(
 
   92                 getSourceAlgebraicOpOperands)),
 
   93         getSourceAlgebraicOpResult(
 
   94             std::forward<GetSourceAlgebraicOpResultFnArg>(
 
   95                 getSourceAlgebraicOpResult)),
 
   96         getTargetAlgebraicOpResult(
 
   97             std::forward<GetTargetAlgebraicOpResultFnArg>(
 
   98                 getTargetAlgebraicOpResult)),
 
   99         isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)),
 
  101             std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)),
 
  102         createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>(
 
  103             createTargetAlgebraicOpFn)) {}
 
  108     if (
failed(matchOp(op, algebraicOpOperands))) {
 
  111     return rewriteOp(op, algebraicOpOperands, rewriter);
 
  118     if (!isSourceAlgebraicOp(sourceAlgebraicOp)) {
 
  121     sourceAlgebraicOpOperands.clear();
 
  122     getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands);
 
  123     if (sourceAlgebraicOpOperands.empty()) {
 
  127     Operation *firstHomomorphismOp =
 
  128         sourceAlgebraicOpOperands.front()->get().getDefiningOp();
 
  129     if (!firstHomomorphismOp ||
 
  130         !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) {
 
  133     OpResult firstHomomorphismOpResult =
 
  134         getHomomorphismOpResult(firstHomomorphismOp);
 
  135     if (firstHomomorphismOpResult != sourceAlgebraicOpOperands.front()->get()) {
 
  139     for (
auto operand : sourceAlgebraicOpOperands) {
 
  140       Operation *homomorphismOp = operand->get().getDefiningOp();
 
  141       if (!homomorphismOp ||
 
  142           !isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) {
 
  150   rewriteOp(Operation *sourceAlgebraicOp,
 
  151             const SmallVector<OpOperand *> &sourceAlgebraicOpOperands,
 
  152             PatternRewriter &rewriter)
 const {
 
  154     for (
auto operand : sourceAlgebraicOpOperands) {
 
  155       Operation *homomorphismOp = operand->get().getDefiningOp();
 
  156       irMapping.map(operand->get(),
 
  157                     getHomomorphismOpOperand(homomorphismOp)->
get());
 
  159     Operation *targetAlgebraicOp =
 
  160         createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter);
 
  163     assert(!sourceAlgebraicOpOperands.empty());
 
  164     Operation *firstHomomorphismOp =
 
  165         sourceAlgebraicOpOperands[0]->get().getDefiningOp();
 
  166     irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->
get(),
 
  167                   getTargetAlgebraicOpResult(targetAlgebraicOp));
 
  168     Operation *newHomomorphismOp =
 
  169         rewriter.clone(*firstHomomorphismOp, irMapping);
 
  170     rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp),
 
  171                                 getHomomorphismOpResult(newHomomorphismOp));
 
  175   GetHomomorphismOpOperandFn getHomomorphismOpOperand;
 
  176   GetHomomorphismOpResultFn getHomomorphismOpResult;
 
  177   GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands;
 
  178   GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult;
 
  179   GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult;
 
  180   IsHomomorphismOpFn isHomomorphismOp;
 
  181   IsSourceAlgebraicOpFn isSourceAlgebraicOp;
 
  182   CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn;
 
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
HomomorphismSimplification(GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand, GetHomomorphismOpResultFnArg &&getHomomorphismOpResult, GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands, GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult, GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult, IsHomomorphismOpFnArg &&isHomomorphismOp, IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp, CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn, RewritePatternArgs &&...args)