MLIR  21.0.0git
StateStack.h
Go to the documentation of this file.
1 //===- StateStack.h - Utility for storing a stack of state ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utilities for storing a stack of generic context.
10 // The context can be arbitrary data, possibly including file-scoped types. Data
11 // must be derived from StateStackFrameBase and implement MLIR TypeID.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_SUPPORT_STACKFRAME_H
16 #define MLIR_SUPPORT_STACKFRAME_H
17 
#include "mlir/Support/TypeID.h"
#include "mlir/Support/WalkResult.h"
#include <memory>
#include <type_traits>
#include <utility>
21 
22 namespace mlir {
23 
24 /// Common CRTP base class for StateStack frames.
26 public:
27  virtual ~StateStackFrame() = default;
28  TypeID getTypeID() const { return typeID; }
29 
30 protected:
31  explicit StateStackFrame(TypeID typeID) : typeID(typeID) {}
32 
33 private:
34  const TypeID typeID;
35  virtual void anchor();
36 };
37 
38 /// Concrete CRTP base class for StateStack frames. This is used for keeping a
39 /// stack of common state useful for recursive IR conversions. For example, when
40 /// translating operations with regions, users of StateStack can store state on
41 /// StateStack before entering the region and inspect it when converting
42 /// operations nested within that region. Users are expected to derive this
43 /// class and put any relevant information into fields of the derived class. The
44 /// usual isa/dyn_cast functionality is available for instances of derived
45 /// classes.
46 template <typename Derived>
48 public:
49  explicit StateStackFrameBase() : StateStackFrame(TypeID::get<Derived>()) {}
50 };
51 
52 class StateStack {
53 public:
54  /// Creates a stack frame of type `T` on StateStack. `T` must
55  /// be derived from `StackFrameBase<T>` and constructible from the provided
56  /// arguments. Doing this before entering the region of the op being
57  /// translated makes the frame available when translating ops within that
58  /// region.
59  template <typename T, typename... Args>
60  void stackPush(Args &&...args) {
61  static_assert(std::is_base_of<StateStackFrame, T>::value,
62  "can only push instances of StackFrame on StateStack");
63  stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
64  }
65 
66  /// Pops the last element from the StateStack.
67  void stackPop() { stack.pop_back(); }
68 
69  /// Calls `callback` for every StateStack frame of type `T`
70  /// starting from the top of the stack.
71  template <typename T>
73  static_assert(std::is_base_of<StateStackFrame, T>::value,
74  "expected T derived from StackFrame");
75  if (!callback)
76  return WalkResult::skip();
77  for (std::unique_ptr<StateStackFrame> &frame : llvm::reverse(stack)) {
78  if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
79  WalkResult result = callback(*ptr);
80  if (result.wasInterrupted())
81  return result;
82  }
83  }
84  return WalkResult::advance();
85  }
86 
87  /// Get the top instance of frame type `T` or nullptr if none are found
88  template <typename T>
89  T *getStackTop() {
90  T *top = nullptr;
91  stackWalk<T>([&](T &frame) -> mlir::WalkResult {
92  top = &frame;
94  });
95  return top;
96  }
97 
98 private:
100 };
101 
102 /// RAII object calling stackPush/stackPop on construction/destruction.
103 /// HostClass could be a StateStack or some other class which forwards calls to
104 /// one.
105 template <typename T, typename HostClass = StateStack>
107  template <typename... Args>
108  explicit SaveStateStack(HostClass &host, Args &&...args) : host(host) {
109  host.template stackPush<T>(std::forward<Args>(args)...);
110  }
111  ~SaveStateStack() { host.stackPop(); }
112 
113 private:
114  HostClass &host;
115 };
116 
117 } // namespace mlir
118 
119 namespace llvm {
120 template <typename T>
121 struct isa_impl<T, ::mlir::StateStackFrame> {
122  static inline bool doit(const ::mlir::StateStackFrame &frame) {
123  return frame.getTypeID() == ::mlir::TypeID::get<T>();
124  }
125 };
126 } // namespace llvm
127 
128 #endif // MLIR_SUPPORT_STACKFRAME_H
Concrete CRTP base class for StateStack frames.
Definition: StateStack.h:47
Common (non-template) base class for StateStack frames.
Definition: StateStack.h:25
TypeID getTypeID() const
Definition: StateStack.h:28
virtual ~StateStackFrame()=default
StateStackFrame(TypeID typeID)
Definition: StateStack.h:31
T * getStackTop()
Get the top instance of frame type T or nullptr if none are found.
Definition: StateStack.h:89
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every StateStack frame of type T starting from the top of the stack.
Definition: StateStack.h:72
void stackPop()
Pops the last element from the StateStack.
Definition: StateStack.h:67
void stackPush(Args &&...args)
Creates a stack frame of type T on StateStack.
Definition: StateStack.h:60
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult skip()
Definition: WalkResult.h:48
static WalkResult advance()
Definition: WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: WalkResult.h:51
static WalkResult interrupt()
Definition: WalkResult.h:46
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
static bool doit(const ::mlir::StateStackFrame &frame)
Definition: StateStack.h:122
RAII object calling stackPush/stackPop on construction/destruction.
Definition: StateStack.h:106
SaveStateStack(HostClass &host, Args &&...args)
Definition: StateStack.h:108