MLIR  20.0.0git
ReconcileUnrealizedCasts.cpp
Go to the documentation of this file.
1 //===- ReconcileUnrealizedCasts.cpp - Eliminate noop unrealized casts -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
13 
14 namespace mlir {
15 #define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
16 #include "mlir/Conversion/Passes.h.inc"
17 } // namespace mlir
18 
19 using namespace mlir;
20 
21 namespace {
22 
23 /// Pass to simplify and eliminate unrealized conversion casts.
24 ///
25 /// This pass processes unrealized_conversion_cast ops in a worklist-driven
26 /// fashion. For each matched cast op, if the chain of input casts eventually
27 /// reaches a cast op where the input types match the output types of the
28 /// matched op, replace the matched op with the inputs.
29 ///
30 /// Example:
31 /// %1 = unrealized_conversion_cast %0 : !A to !B
32 /// %2 = unrealized_conversion_cast %1 : !B to !C
33 /// %3 = unrealized_conversion_cast %2 : !C to !A
34 ///
35 /// In the above example, %0 can be used instead of %3 and all cast ops are
36 /// folded away.
37 struct ReconcileUnrealizedCasts
38  : public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
39  ReconcileUnrealizedCasts() = default;
40 
41  void runOnOperation() override {
42  // Gather all unrealized_conversion_cast ops.
44  getOperation()->walk(
45  [&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
46 
47  // Helper function that adds all operands to the worklist that are an
48  // unrealized_conversion_cast op result.
49  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
50  for (Value v : castOp.getInputs())
51  if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
52  worklist.insert(inputCastOp);
53  };
54 
55  // Helper function that return the unrealized_conversion_cast op that
56  // defines all inputs of the given op (in the same order). Return "nullptr"
57  // if there is no such op.
58  auto getInputCast =
59  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
60  if (castOp.getInputs().empty())
61  return {};
62  auto inputCastOp = castOp.getInputs()
63  .front()
64  .getDefiningOp<UnrealizedConversionCastOp>();
65  if (!inputCastOp)
66  return {};
67  if (inputCastOp.getOutputs() != castOp.getInputs())
68  return {};
69  return inputCastOp;
70  };
71 
72  // Process ops in the worklist bottom-to-top.
73  while (!worklist.empty()) {
74  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
75  if (castOp->use_empty()) {
76  // DCE: If the op has no users, erase it. Add the operands to the
77  // worklist to find additional DCE opportunities.
78  enqueueOperands(castOp);
79  castOp->erase();
80  continue;
81  }
82 
83  // Traverse the chain of input cast ops to see if an op with the same
84  // input types can be found.
85  UnrealizedConversionCastOp nextCast = castOp;
86  while (nextCast) {
87  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
88  // Found a cast where the input types match the output types of the
89  // matched op. We can directly use those inputs and the matched op can
90  // be removed.
91  enqueueOperands(castOp);
92  castOp.replaceAllUsesWith(nextCast.getInputs());
93  castOp->erase();
94  break;
95  }
96  nextCast = getInputCast(nextCast);
97  }
98  }
99  }
100 };
101 
102 } // namespace
103 
105  return std::make_unique<ReconcileUnrealizedCasts>();
106 }
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
std::unique_ptr< Pass > createReconcileUnrealizedCastsPass()
Creates a pass that eliminates noop unrealized_conversion_cast operation sequences.