MLIR  16.0.0git
ReconcileUnrealizedCasts.cpp
Go to the documentation of this file.
1 //===- ReconcileUnrealizedCasts.cpp - Eliminate noop unrealized casts -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/IR/BuiltinOps.h"
12 #include "mlir/IR/PatternMatch.h"
13 #include "mlir/Pass/Pass.h"
15 
16 namespace mlir {
17 #define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
18 #include "mlir/Conversion/Passes.h.inc"
19 } // namespace mlir
20 
21 using namespace mlir;
22 
23 namespace {
24 
25 /// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types
26 /// the same as the input ones.
27 /// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
28 /// represent a noop within the IR, and thus the initial input values can be
29 /// propagated.
30 /// The same does not hold for 'open' chains chains of casts, such as
31 /// `A -> B -> C`. In this last case there is no cycle among the types and thus
32 /// the conversion is incomplete. The same hold for 'closed' chains like
33 /// `A -> B -> A`, but with the result of type `B` being used by some non-cast
34 /// operations.
35 /// Bifurcations (that is when a chain starts in between of another one) are
36 /// also taken into considerations, and all the above considerations remain
37 /// valid.
38 /// Special corner cases such as dead casts or single casts with same input and
39 /// output types are also covered.
40 struct UnrealizedConversionCastPassthrough
41  : public OpRewritePattern<UnrealizedConversionCastOp> {
43 
44  LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
45  PatternRewriter &rewriter) const override {
46  // The nodes that either are not used by any operation or have at least
47  // one user that is not an unrealized cast.
49 
50  // The nodes whose users are all unrealized casts
51  DenseSet<UnrealizedConversionCastOp> intermediateNodes;
52 
53  // Stack used for the depth-first traversal of the use-def DAG.
55  visitStack.push_back(op);
56 
57  while (!visitStack.empty()) {
58  UnrealizedConversionCastOp current = visitStack.pop_back_val();
59  auto users = current->getUsers();
60  bool isLive = false;
61 
62  for (Operation *user : users) {
63  if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
64  if (other.getInputs() != current.getOutputs())
65  return rewriter.notifyMatchFailure(
66  op, "mismatching values propagation");
67  } else {
68  isLive = true;
69  }
70 
71  // Continue traversing the DAG of unrealized casts
72  if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
73  visitStack.push_back(other);
74  }
75 
76  // If the cast is live, then we need to check if the results of the last
77  // cast have the same type of the root inputs. It this is the case (e.g.
78  // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
79  // no-op and the inputs can be forwarded. If it's not (e.g.
80  // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
81 
82  bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
83 
84  if (isLive && !isCycle)
85  return rewriter.notifyMatchFailure(op,
86  "live unrealized conversion cast");
87 
88  bool isExitNode = users.empty() || isLive;
89 
90  if (isExitNode) {
91  exitNodes.insert(current);
92  } else {
93  intermediateNodes.insert(current);
94  }
95  }
96 
97  // Replace the sink nodes with the root input values
98  for (UnrealizedConversionCastOp exitNode : exitNodes)
99  rewriter.replaceOp(exitNode, op.getInputs());
100 
101  // Erase all the other casts belonging to the DAG
102  for (UnrealizedConversionCastOp castOp : intermediateNodes)
103  rewriter.eraseOp(castOp);
104 
105  return success();
106  }
107 };
108 
109 /// Pass to simplify and eliminate unrealized conversion casts.
110 struct ReconcileUnrealizedCasts
111  : public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
112  ReconcileUnrealizedCasts() = default;
113 
114  void runOnOperation() override {
115  RewritePatternSet patterns(&getContext());
117  ConversionTarget target(getContext());
118  target.addIllegalOp<UnrealizedConversionCastOp>();
119  if (failed(applyPartialConversion(getOperation(), target,
120  std::move(patterns))))
121  signalPassFailure();
122  }
123 };
124 
125 } // namespace
126 
128  RewritePatternSet &patterns) {
129  patterns.add<UnrealizedConversionCastPassthrough>(patterns.getContext());
130 }
131 
133  return std::make_unique<ReconcileUnrealizedCasts>();
134 }
Include the generated interface declarations.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
Definition: PatternMatch.h:600
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e. this operation is known to not be supported by this target.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
std::unique_ptr< Pass > createReconcileUnrealizedCastsPass()
Creates a pass that eliminates noop unrealized_conversion_cast operation sequences.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.
Definition: PatternMatch.h:355
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
Definition: PatternMatch.h:512
This class describes a specific conversion target.
void populateReconcileUnrealizedCastsPatterns(RewritePatternSet &patterns)
Populates patterns with rewrite patterns that eliminate noop unrealized_conversion_cast operation sequences.
MLIRContext * getContext() const