MLIR  17.0.0git
Visitors.h
Go to the documentation of this file.
1 //===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utilities for walking and visiting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_VISITORS_H
14 #define MLIR_IR_VISITORS_H
15 
16 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/STLExtras.h"
19 
20 namespace mlir {
// Forward declarations of the IR and diagnostic classes used below. This
// header only refers to them through pointers and rvalue references, so their
// complete definitions are not needed here.
class Block;
class Diagnostic;
class InFlightDiagnostic;
class Operation;
class Region;
26 
27 /// A utility result that is used to signal how to proceed with an ongoing walk:
28 /// * Interrupt: the walk will be interrupted and no more operations, regions
29 /// or blocks will be visited.
30 /// * Advance: the walk will continue.
31 /// * Skip: the walk of the current operation, region or block and their
32 /// nested elements that haven't been visited already will be skipped and will
33 /// continue with the next operation, region or block.
34 class WalkResult {
35  enum ResultEnum { Interrupt, Advance, Skip } result;
36 
37 public:
38  WalkResult(ResultEnum result) : result(result) {}
39 
40  /// Allow LogicalResult to interrupt the walk on failure.
42  : result(failed(result) ? Interrupt : Advance) {}
43 
44  /// Allow diagnostics to interrupt the walk.
45  WalkResult(Diagnostic &&) : result(Interrupt) {}
46  WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
47 
48  bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
49  bool operator!=(const WalkResult &rhs) const { return result != rhs.result; }
50 
51  static WalkResult interrupt() { return {Interrupt}; }
52  static WalkResult advance() { return {Advance}; }
53  static WalkResult skip() { return {Skip}; }
54 
55  /// Returns true if the walk was interrupted.
56  bool wasInterrupted() const { return result == Interrupt; }
57 
58  /// Returns true if the walk was skipped.
59  bool wasSkipped() const { return result == Skip; }
60 };
61 
/// Traversal order used by the region, block and operation walk utilities.
/// It decides whether an enclosing region/block/operation is visited before
/// or after the elements nested inside it.
enum class WalkOrder {
  /// Visit the enclosing element first, then its nested elements.
  PreOrder,
  /// Visit the nested elements first, then the enclosing element.
  PostOrder
};
64 
65 /// A utility class to encode the current walk stage for "generic" walkers.
66 /// When walking an operation, we can either choose a Pre/Post order walker
67 /// which invokes the callback on an operation before/after all its attached
68 /// regions have been visited, or choose a "generic" walker where the callback
69 /// is invoked on the operation N+1 times where N is the number of regions
70 /// attached to that operation. The `WalkStage` class below encodes the current
71 /// stage of the walk, i.e., which regions have already been visited, and the
72 /// callback accepts an additional argument for the current stage. Such
73 /// generic walkers that accept stage-aware callbacks are only applicable when
74 /// the callback operates on an operation (i.e., not applicable for callbacks
75 /// on Blocks or Regions).
76 class WalkStage {
77 public:
78  explicit WalkStage(Operation *op);
79 
80  /// Return true if parent operation is being visited before all regions.
81  bool isBeforeAllRegions() const { return nextRegion == 0; }
82  /// Returns true if parent operation is being visited just before visiting
83  /// region number `region`.
84  bool isBeforeRegion(int region) const { return nextRegion == region; }
85  /// Returns true if parent operation is being visited just after visiting
86  /// region number `region`.
87  bool isAfterRegion(int region) const { return nextRegion == region + 1; }
88  /// Return true if parent operation is being visited after all regions.
89  bool isAfterAllRegions() const { return nextRegion == numRegions; }
90  /// Advance the walk stage.
91  void advance() { nextRegion++; }
92  /// Returns the next region that will be visited.
93  int getNextRegion() const { return nextRegion; }
94 
95 private:
96  const int numRegions;
97  int nextRegion;
98 };
99 
100 namespace detail {
/// Helper templates to deduce the type of the first parameter of a callable.
/// They are declared but never defined: they are only named inside
/// `decltype`, where overload resolution selects the matching signature and
/// the declared return type spells the first parameter type.

/// Plain function pointers.
template <typename Ret, typename Arg, typename... Rest>
auto first_argument_type(Ret (*)(Arg, Rest...)) -> Arg;
/// Non-const member functions (e.g. the call operator of a `mutable` lambda).
template <typename Ret, typename F, typename Arg, typename... Rest>
auto first_argument_type(Ret (F::*)(Arg, Rest...)) -> Arg;
/// Const member functions (e.g. the call operator of an ordinary lambda).
template <typename Ret, typename F, typename Arg, typename... Rest>
auto first_argument_type(Ret (F::*)(Arg, Rest...) const) -> Arg;
/// Function objects: deduce through a pointer to their `operator()`.
template <typename F>
auto first_argument_type(F) -> decltype(first_argument_type(&F::operator()));

/// Type of the first parameter of the callable type `T`.
template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));
114 
115 /// Walk all of the regions, blocks, or operations nested under (and including)
116 /// the given operation. Regions, blocks and operations at the same nesting
117 /// level are visited in lexicographical order. The walk order for enclosing
118 /// regions, blocks and operations with respect to their nested ones is
119 /// specified by 'order'. These methods are invoked for void-returning
120 /// callbacks. A callback on a block or operation is allowed to erase that block
121 /// or operation only if the walk is in post-order. See non-void method for
122 /// pre-order erasure.
123 void walk(Operation *op, function_ref<void(Region *)> callback,
124  WalkOrder order);
125 void walk(Operation *op, function_ref<void(Block *)> callback, WalkOrder order);
126 void walk(Operation *op, function_ref<void(Operation *)> callback,
127  WalkOrder order);
128 /// Walk all of the regions, blocks, or operations nested under (and including)
129 /// the given operation. Regions, blocks and operations at the same nesting
130 /// level are visited in lexicographical order. The walk order for enclosing
131 /// regions, blocks and operations with respect to their nested ones is
132 /// specified by 'order'. This method is invoked for skippable or interruptible
133 /// callbacks. A callback on a block or operation is allowed to erase that block
134 /// or operation if either:
135 /// * the walk is in post-order, or
136 /// * the walk is in pre-order and the walk is skipped after the erasure.
138  WalkOrder order);
140  WalkOrder order);
142  WalkOrder order);
143 
// Below is a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes (Operation/Block/etc.) over these
// methods. They are also templated to allow for statically dispatching based
// upon the type of the callback function.
148 
149 /// Walk all of the regions, blocks, or operations nested under (and including)
150 /// the given operation. Regions, blocks and operations at the same nesting
151 /// level are visited in lexicographical order. The walk order for enclosing
152 /// regions, blocks and operations with respect to their nested ones is
153 /// specified by 'Order' (post-order by default). A callback on a block or
154 /// operation is allowed to erase that block or operation if either:
155 /// * the walk is in post-order, or
156 /// * the walk is in pre-order and the walk is skipped after the erasure.
157 /// This method is selected for callbacks that operate on Region*, Block*, and
158 /// Operation*.
159 ///
160 /// Example:
161 /// op->walk([](Region *r) { ... });
162 /// op->walk([](Block *b) { ... });
163 /// op->walk([](Operation *op) { ... });
164 template <
165  WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
166  typename ArgT = detail::first_argument<FuncTy>,
167  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
168 std::enable_if_t<llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value,
169  RetT>
170 walk(Operation *op, FuncTy &&callback) {
171  return detail::walk(op, function_ref<RetT(ArgT)>(callback), Order);
172 }
173 
174 /// Walk all of the operations of type 'ArgT' nested under and including the
175 /// given operation. Regions, blocks and operations at the same nesting
176 /// level are visited in lexicographical order. The walk order for enclosing
177 /// regions, blocks and operations with respect to their nested ones is
178 /// specified by 'order' (post-order by default). This method is selected for
179 /// void-returning callbacks that operate on a specific derived operation type.
180 /// A callback on an operation is allowed to erase that operation only if the
181 /// walk is in post-order. See non-void method for pre-order erasure.
182 ///
183 /// Example:
184 /// op->walk([](ReturnOp op) { ... });
185 template <
186  WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
187  typename ArgT = detail::first_argument<FuncTy>,
188  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
189 std::enable_if_t<
190  !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
191  std::is_same<RetT, void>::value,
192  RetT>
193 walk(Operation *op, FuncTy &&callback) {
194  auto wrapperFn = [&](Operation *op) {
195  if (auto derivedOp = dyn_cast<ArgT>(op))
196  callback(derivedOp);
197  };
198  return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
199 }
200 
201 /// Walk all of the operations of type 'ArgT' nested under and including the
202 /// given operation. Regions, blocks and operations at the same nesting level
203 /// are visited in lexicographical order. The walk order for enclosing regions,
204 /// blocks and operations with respect to their nested ones is specified by
205 /// 'Order' (post-order by default). This method is selected for WalkReturn
206 /// returning skippable or interruptible callbacks that operate on a specific
207 /// derived operation type. A callback on an operation is allowed to erase that
208 /// operation if either:
209 /// * the walk is in post-order, or
210 /// * the walk is in pre-order and the walk is skipped after the erasure.
211 ///
212 /// Example:
213 /// op->walk([](ReturnOp op) {
214 /// if (some_invariant)
215 /// return WalkResult::skip();
216 /// if (another_invariant)
217 /// return WalkResult::interrupt();
218 /// return WalkResult::advance();
219 /// });
220 template <
221  WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
222  typename ArgT = detail::first_argument<FuncTy>,
223  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
224 std::enable_if_t<
225  !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
226  std::is_same<RetT, WalkResult>::value,
227  RetT>
228 walk(Operation *op, FuncTy &&callback) {
229  auto wrapperFn = [&](Operation *op) {
230  if (auto derivedOp = dyn_cast<ArgT>(op))
231  return callback(derivedOp);
232  return WalkResult::advance();
233  };
234  return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
235 }
236 
237 /// Generic walkers with stage aware callbacks.
238 
239 /// Walk all the operations nested under (and including) the given operation,
240 /// with the callback being invoked on each operation N+1 times, where N is the
241 /// number of regions attached to the operation. The `stage` input to the
242 /// callback indicates the current walk stage. This method is invoked for void
243 /// returning callbacks.
244 void walk(Operation *op,
245  function_ref<void(Operation *, const WalkStage &stage)> callback);
246 
247 /// Walk all the operations nested under (and including) the given operation,
248 /// with the callback being invoked on each operation N+1 times, where N is the
249 /// number of regions attached to the operation. The `stage` input to the
250 /// callback indicates the current walk stage. This method is invoked for
251 /// skippable or interruptible callbacks.
254  function_ref<WalkResult(Operation *, const WalkStage &stage)> callback);
255 
256 /// Walk all of the operations nested under and including the given operation.
257 /// This method is selected for stage-aware callbacks that operate on
258 /// Operation*.
259 ///
260 /// Example:
261 /// op->walk([](Operation *op, const WalkStage &stage) { ... });
262 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
263  typename RetT = decltype(std::declval<FuncTy>()(
264  std::declval<ArgT>(), std::declval<const WalkStage &>()))>
265 std::enable_if_t<std::is_same<ArgT, Operation *>::value, RetT>
266 walk(Operation *op, FuncTy &&callback) {
267  return detail::walk(op,
268  function_ref<RetT(ArgT, const WalkStage &)>(callback));
269 }
270 
271 /// Walk all of the operations of type 'ArgT' nested under and including the
272 /// given operation. This method is selected for void returning callbacks that
273 /// operate on a specific derived operation type.
274 ///
275 /// Example:
276 /// op->walk([](ReturnOp op, const WalkStage &stage) { ... });
277 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
278  typename RetT = decltype(std::declval<FuncTy>()(
279  std::declval<ArgT>(), std::declval<const WalkStage &>()))>
280 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
281  std::is_same<RetT, void>::value,
282  RetT>
283 walk(Operation *op, FuncTy &&callback) {
284  auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
285  if (auto derivedOp = dyn_cast<ArgT>(op))
286  callback(derivedOp, stage);
287  };
288  return detail::walk(
289  op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
290 }
291 
292 /// Walk all of the operations of type 'ArgT' nested under and including the
293 /// given operation. This method is selected for WalkReturn returning
294 /// interruptible callbacks that operate on a specific derived operation type.
295 ///
296 /// Example:
297 /// op->walk(op, [](ReturnOp op, const WalkStage &stage) {
298 /// if (some_invariant)
299 /// return WalkResult::interrupt();
300 /// return WalkResult::advance();
301 /// });
302 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
303  typename RetT = decltype(std::declval<FuncTy>()(
304  std::declval<ArgT>(), std::declval<const WalkStage &>()))>
305 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
306  std::is_same<RetT, WalkResult>::value,
307  RetT>
308 walk(Operation *op, FuncTy &&callback) {
309  auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
310  if (auto derivedOp = dyn_cast<ArgT>(op))
311  return callback(derivedOp, stage);
312  return WalkResult::advance();
313  };
314  return detail::walk(
315  op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
316 }
317 
/// Return type of the `walk` overload that is selected for a callback of type
/// `FnT` (for example `void` or `WalkResult`).
template <typename FnT>
using walkResultType =
    decltype(walk(nullptr, std::declval<FnT>()));
321 } // namespace detail
322 
323 } // namespace mlir
324 
325 #endif
Block represents an ordered list of Operations.
Definition: Block.h:30
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
WalkResult(InFlightDiagnostic &&)
Definition: Visitors.h:46
bool operator==(const WalkResult &rhs) const
Definition: Visitors.h:48
WalkResult(LogicalResult result)
Allow LogicalResult to interrupt the walk on failure.
Definition: Visitors.h:41
static WalkResult skip()
Definition: Visitors.h:53
bool wasSkipped() const
Returns true if the walk was skipped.
Definition: Visitors.h:59
WalkResult(ResultEnum result)
Definition: Visitors.h:38
static WalkResult advance()
Definition: Visitors.h:52
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:56
WalkResult(Diagnostic &&)
Allow diagnostics to interrupt the walk.
Definition: Visitors.h:45
static WalkResult interrupt()
Definition: Visitors.h:51
bool operator!=(const WalkResult &rhs) const
Definition: Visitors.h:49
A utility class to encode the current walk stage for "generic" walkers.
Definition: Visitors.h:76
void advance()
Advance the walk stage.
Definition: Visitors.h:91
int getNextRegion() const
Returns the next region that will be visited.
Definition: Visitors.h:93
bool isBeforeRegion(int region) const
Returns true if parent operation is being visited just before visiting region number `region`.
Definition: Visitors.h:84
WalkStage(Operation *op)
Definition: Visitors.cpp:14
bool isAfterAllRegions() const
Return true if parent operation is being visited after all regions.
Definition: Visitors.h:89
bool isAfterRegion(int region) const
Returns true if parent operation is being visited just after visiting region number `region`.
Definition: Visitors.h:87
bool isBeforeAllRegions() const
Return true if parent operation is being visited before all regions.
Definition: Visitors.h:81
Include the generated interface declarations.
Definition: CallGraph.h:229
decltype(walk(nullptr, std::declval< FnT >())) walkResultType
Utility to provide the return type of a templated walk method.
Definition: Visitors.h:320
decltype(first_argument_type(&F::operator())) first_argument_type(F)
Definition: Visitors.h:109
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.cpp:24
decltype(first_argument_type(std::declval< T >())) first_argument
Type definition of the first argument to the given callable 'T'.
Definition: Visitors.h:113
Include the generated interface declarations.
WalkOrder
Traversal order for region, block and operation walk utilities.
Definition: Visitors.h:63
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26