MLIR 22.0.0git
StateStack.h
Go to the documentation of this file.
1//===- StateStack.h - Utility for storing a stack of state ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines utilities for storing a stack of generic context.
10// The context can be arbitrary data, possibly including file-scoped types. Data
11// must be derived from StateStackFrameBase and implement MLIR TypeID.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef MLIR_SUPPORT_STACKFRAME_H
16#define MLIR_SUPPORT_STACKFRAME_H
17
18#include "mlir/Support/TypeID.h"
20#include <memory>
21
22namespace mlir {
23
24/// Common CRTP base class for StateStack frames.
26public:
27 virtual ~StateStackFrame() = default;
28 TypeID getTypeID() const { return typeID; }
29
30protected:
31 explicit StateStackFrame(TypeID typeID) : typeID(typeID) {}
32
33private:
34 const TypeID typeID;
35 virtual void anchor();
36};
37
38/// Concrete CRTP base class for StateStack frames. This is used for keeping a
39/// stack of common state useful for recursive IR conversions. For example, when
40/// translating operations with regions, users of StateStack can store state on
41/// StateStack before entering the region and inspect it when converting
42/// operations nested within that region. Users are expected to derive this
43/// class and put any relevant information into fields of the derived class. The
44/// usual isa/dyn_cast functionality is available for instances of derived
45/// classes.
46template <typename Derived>
48public:
49 explicit StateStackFrameBase() : StateStackFrame(TypeID::get<Derived>()) {}
50};
51
53public:
54 /// Creates a stack frame of type `T` on StateStack. `T` must
55 /// be derived from `StackFrameBase<T>` and constructible from the provided
56 /// arguments. Doing this before entering the region of the op being
57 /// translated makes the frame available when translating ops within that
58 /// region.
59 template <typename T, typename... Args>
60 void stackPush(Args &&...args) {
61 static_assert(std::is_base_of<StateStackFrame, T>::value,
62 "can only push instances of StackFrame on StateStack");
63 stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
64 }
65
66 /// Pops the last element from the StateStack.
67 void stackPop() { stack.pop_back(); }
68
69 /// Calls `callback` for every StateStack frame of type `T`
70 /// starting from the top of the stack.
71 template <typename T>
73 static_assert(std::is_base_of<StateStackFrame, T>::value,
74 "expected T derived from StackFrame");
75 if (!callback)
76 return WalkResult::skip();
77 for (std::unique_ptr<StateStackFrame> &frame : llvm::reverse(stack)) {
78 if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
79 WalkResult result = callback(*ptr);
80 if (result.wasInterrupted())
81 return result;
82 }
83 }
84 return WalkResult::advance();
85 }
86
87 /// Get the top instance of frame type `T` or nullptr if none are found
88 template <typename T>
90 T *top = nullptr;
91 stackWalk<T>([&](T &frame) -> mlir::WalkResult {
92 top = &frame;
94 });
95 return top;
96 }
97
98private:
100};
101
102/// RAII object calling stackPush/stackPop on construction/destruction.
103/// HostClass could be a StateStack or some other class which forwards calls to
104/// one.
105template <typename T, typename HostClass = StateStack>
107 template <typename... Args>
108 explicit SaveStateStack(HostClass &host, Args &&...args) : host(host) {
109 host.template stackPush<T>(std::forward<Args>(args)...);
110 }
111 ~SaveStateStack() { host.stackPop(); }
112
113private:
114 HostClass &host;
115};
116
117} // namespace mlir
118
119namespace llvm {
120template <typename T>
121struct isa_impl<T, ::mlir::StateStackFrame> {
122 static inline bool doit(const ::mlir::StateStackFrame &frame) {
123 return frame.getTypeID() == ::mlir::TypeID::get<T>();
124 }
125};
126} // namespace llvm
127
128#endif // MLIR_SUPPORT_STACKFRAME_H
Common non-template base class for StateStack frames.
Definition StateStack.h:25
TypeID getTypeID() const
Definition StateStack.h:28
virtual ~StateStackFrame()=default
StateStackFrame(TypeID typeID)
Definition StateStack.h:31
T * getStackTop()
Get the top instance of frame type T or nullptr if none are found.
Definition StateStack.h:89
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every StateStack frame of type T starting from the top of the stack.
Definition StateStack.h:72
void stackPop()
Pops the last element from the StateStack.
Definition StateStack.h:67
void stackPush(Args &&...args)
Creates a stack frame of type T on StateStack.
Definition StateStack.h:60
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
static bool doit(const ::mlir::StateStackFrame &frame)
Definition StateStack.h:122
SaveStateStack(HostClass &host, Args &&...args)
Definition StateStack.h:108