MLIR  15.0.0git
Visitors.h
Go to the documentation of this file.
1 //===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utilities for walking and visiting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_VISITORS_H
14 #define MLIR_IR_VISITORS_H
15 
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
19 
20 namespace mlir {
21 class Diagnostic;
22 class InFlightDiagnostic;
23 class Operation;
24 class Block;
25 class Region;
26 
27 /// A utility result that is used to signal how to proceed with an ongoing walk:
28 /// * Interrupt: the walk will be interrupted and no more operations, regions
29 /// or blocks will be visited.
30 /// * Advance: the walk will continue.
31 /// * Skip: the walk of the current operation, region or block and their
32 /// nested elements that haven't been visited already will be skipped and will
33 /// continue with the next operation, region or block.
34 class WalkResult {
35  enum ResultEnum { Interrupt, Advance, Skip } result;
36 
37 public:
38  WalkResult(ResultEnum result) : result(result) {}
39 
40  /// Allow LogicalResult to interrupt the walk on failure.
42  : result(failed(result) ? Interrupt : Advance) {}
43 
44  /// Allow diagnostics to interrupt the walk.
45  WalkResult(Diagnostic &&) : result(Interrupt) {}
46  WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
47 
48  bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
49 
50  static WalkResult interrupt() { return {Interrupt}; }
51  static WalkResult advance() { return {Advance}; }
52  static WalkResult skip() { return {Skip}; }
53 
54  /// Returns true if the walk was interrupted.
55  bool wasInterrupted() const { return result == Interrupt; }
56 
57  /// Returns true if the walk was skipped.
58  bool wasSkipped() const { return result == Skip; }
59 };
60 
/// Traversal order for region, block and operation walk utilities.
enum class WalkOrder { PreOrder, PostOrder };
63 
64 /// A utility class to encode the current walk stage for "generic" walkers.
65 /// When walking an operation, we can either choose a Pre/Post order walker
66 /// which invokes the callback on an operation before/after all its attached
67 /// regions have been visited, or choose a "generic" walker where the callback
68 /// is invoked on the operation N+1 times where N is the number of regions
69 /// attached to that operation. The `WalkStage` class below encodes the current
70 /// stage of the walk, i.e., which regions have already been visited, and the
71 /// callback accepts an additional argument for the current stage. Such
72 /// generic walkers that accept stage-aware callbacks are only applicable when
73 /// the callback operates on an operation (i.e., not applicable for callbacks
74 /// on Blocks or Regions).
75 class WalkStage {
76 public:
77  explicit WalkStage(Operation *op);
78 
79  /// Return true if parent operation is being visited before all regions.
80  bool isBeforeAllRegions() const { return nextRegion == 0; }
81  /// Returns true if parent operation is being visited just before visiting
82  /// region number `region`.
83  bool isBeforeRegion(int region) const { return nextRegion == region; }
84  /// Returns true if parent operation is being visited just after visiting
85  /// region number `region`.
86  bool isAfterRegion(int region) const { return nextRegion == region + 1; }
87  /// Return true if parent operation is being visited after all regions.
88  bool isAfterAllRegions() const { return nextRegion == numRegions; }
89  /// Advance the walk stage.
90  void advance() { nextRegion++; }
91  /// Returns the next region that will be visited.
92  int getNextRegion() const { return nextRegion; }
93 
94 private:
95  const int numRegions;
96  int nextRegion;
97 };
98 
99 namespace detail {
/// Helper templates to deduce the first argument of a callback parameter.
/// These are declaration-only: they are meant to be used inside `decltype`
/// and are never called.

/// Overload for plain function pointers.
template <typename Ret, typename Arg, typename... Rest>
Arg first_argument_type(Ret (*)(Arg, Rest...));
/// Overload for non-const member functions (e.g. a mutable lambda's call
/// operator).
template <typename Ret, typename F, typename Arg, typename... Rest>
Arg first_argument_type(Ret (F::*)(Arg, Rest...));
/// Overload for const member functions (e.g. a regular lambda's call
/// operator).
template <typename Ret, typename F, typename Arg, typename... Rest>
Arg first_argument_type(Ret (F::*)(Arg, Rest...) const);
/// Overload for functors: forwards to the member-function overloads via the
/// functor's `operator()`. Drops out of overload resolution when `F` has no
/// unambiguous `operator()`.
template <typename F>
decltype(first_argument_type(&F::operator())) first_argument_type(F);

/// Type definition of the first argument to the given callable 'T'.
template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));
113 
114 /// Walk all of the regions, blocks, or operations nested under (and including)
115 /// the given operation. Regions, blocks and operations at the same nesting
116 /// level are visited in lexicographical order. The walk order for enclosing
117 /// regions, blocks and operations with respect to their nested ones is
118 /// specified by 'order'. These methods are invoked for void-returning
119 /// callbacks. A callback on a block or operation is allowed to erase that block
120 /// or operation only if the walk is in post-order. See non-void method for
121 /// pre-order erasure.
122 void walk(Operation *op, function_ref<void(Region *)> callback,
123  WalkOrder order);
124 void walk(Operation *op, function_ref<void(Block *)> callback, WalkOrder order);
125 void walk(Operation *op, function_ref<void(Operation *)> callback,
126  WalkOrder order);
127 /// Walk all of the regions, blocks, or operations nested under (and including)
128 /// the given operation. Regions, blocks and operations at the same nesting
129 /// level are visited in lexicographical order. The walk order for enclosing
130 /// regions, blocks and operations with respect to their nested ones is
131 /// specified by 'order'. This method is invoked for skippable or interruptible
132 /// callbacks. A callback on a block or operation is allowed to erase that block
133 /// or operation if either:
134 /// * the walk is in post-order, or
135 /// * the walk is in pre-order and the walk is skipped after the erasure.
137  WalkOrder order);
139  WalkOrder order);
141  WalkOrder order);
142 
143 // Below are a set of functions to walk nested operations. Users should favor
144 // the direct `walk` methods on the IR classes(Operation/Block/etc) over these
145 // methods. They are also templated to allow for statically dispatching based
146 // upon the type of the callback function.
147 
148 /// Walk all of the regions, blocks, or operations nested under (and including)
149 /// the given operation. Regions, blocks and operations at the same nesting
150 /// level are visited in lexicographical order. The walk order for enclosing
151 /// regions, blocks and operations with respect to their nested ones is
152 /// specified by 'Order' (post-order by default). A callback on a block or
153 /// operation is allowed to erase that block or operation if either:
154 /// * the walk is in post-order, or
155 /// * the walk is in pre-order and the walk is skipped after the erasure.
156 /// This method is selected for callbacks that operate on Region*, Block*, and
157 /// Operation*.
158 ///
159 /// Example:
160 /// op->walk([](Region *r) { ... });
161 /// op->walk([](Block *b) { ... });
162 /// op->walk([](Operation *op) { ... });
163 template <
164  WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
165  typename ArgT = detail::first_argument<FuncTy>,
166  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
167 typename std::enable_if<
168  llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value, RetT>::type
169 walk(Operation *op, FuncTy &&callback) {
170  return detail::walk(op, function_ref<RetT(ArgT)>(callback), Order);
171 }
172 
173 /// Walk all of the operations of type 'ArgT' nested under and including the
174 /// given operation. Regions, blocks and operations at the same nesting
175 /// level are visited in lexicographical order. The walk order for enclosing
176 /// regions, blocks and operations with respect to their nested ones is
177 /// specified by 'order' (post-order by default). This method is selected for
178 /// void-returning callbacks that operate on a specific derived operation type.
179 /// A callback on an operation is allowed to erase that operation only if the
180 /// walk is in post-order. See non-void method for pre-order erasure.
181 ///
182 /// Example:
183 /// op->walk([](ReturnOp op) { ... });
184 template <
185  WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
186  typename ArgT = detail::first_argument<FuncTy>,
187  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
188 typename std::enable_if<
191  RetT>::type
192 walk(Operation *op, FuncTy &&callback) {
193  auto wrapperFn = [&](Operation *op) {
194  if (auto derivedOp = dyn_cast<ArgT>(op))
195  callback(derivedOp);
196  };
197  return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
198 }
199 
200 /// Walk all of the operations of type 'ArgT' nested under and including the
201 /// given operation. Regions, blocks and operations at the same nesting level
202 /// are visited in lexicographical order. The walk order for enclosing regions,
203 /// blocks and operations with respect to their nested ones is specified by
204 /// 'Order' (post-order by default). This method is selected for WalkReturn
205 /// returning skippable or interruptible callbacks that operate on a specific
206 /// derived operation type. A callback on an operation is allowed to erase that
207 /// operation if either:
208 /// * the walk is in post-order, or
209 /// * the walk is in pre-order and the walk is skipped after the erasure.
210 ///
211 /// Example:
212 /// op->walk([](ReturnOp op) {
213 /// if (some_invariant)
214 /// return WalkResult::skip();
215 /// if (another_invariant)
216 /// return WalkResult::interrupt();
217 /// return WalkResult::advance();
218 /// });
219 template <
220  WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
221  typename ArgT = detail::first_argument<FuncTy>,
222  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
223 typename std::enable_if<
226  RetT>::type
227 walk(Operation *op, FuncTy &&callback) {
228  auto wrapperFn = [&](Operation *op) {
229  if (auto derivedOp = dyn_cast<ArgT>(op))
230  return callback(derivedOp);
231  return WalkResult::advance();
232  };
233  return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
234 }
235 
236 /// Generic walkers with stage aware callbacks.
237 
238 /// Walk all the operations nested under (and including) the given operation,
239 /// with the callback being invoked on each operation N+1 times, where N is the
240 /// number of regions attached to the operation. The `stage` input to the
241 /// callback indicates the current walk stage. This method is invoked for void
242 /// returning callbacks.
243 void walk(Operation *op,
244  function_ref<void(Operation *, const WalkStage &stage)> callback);
245 
246 /// Walk all the operations nested under (and including) the given operation,
247 /// with the callback being invoked on each operation N+1 times, where N is the
248 /// number of regions attached to the operation. The `stage` input to the
249 /// callback indicates the current walk stage. This method is invoked for
250 /// skippable or interruptible callbacks.
252 walk(Operation *op,
253  function_ref<WalkResult(Operation *, const WalkStage &stage)> callback);
254 
255 /// Walk all of the operations nested under and including the given operation.
256 /// This method is selected for stage-aware callbacks that operate on
257 /// Operation*.
258 ///
259 /// Example:
260 /// op->walk([](Operation *op, const WalkStage &stage) { ... });
261 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
262  typename RetT = decltype(std::declval<FuncTy>()(
263  std::declval<ArgT>(), std::declval<const WalkStage &>()))>
265 walk(Operation *op, FuncTy &&callback) {
266  return detail::walk(op,
267  function_ref<RetT(ArgT, const WalkStage &)>(callback));
268 }
269 
270 /// Walk all of the operations of type 'ArgT' nested under and including the
271 /// given operation. This method is selected for void returning callbacks that
272 /// operate on a specific derived operation type.
273 ///
274 /// Example:
275 /// op->walk([](ReturnOp op, const WalkStage &stage) { ... });
276 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
277  typename RetT = decltype(std::declval<FuncTy>()(
278  std::declval<ArgT>(), std::declval<const WalkStage &>()))>
280  std::is_same<RetT, void>::value,
281  RetT>::type
282 walk(Operation *op, FuncTy &&callback) {
283  auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
284  if (auto derivedOp = dyn_cast<ArgT>(op))
285  callback(derivedOp, stage);
286  };
287  return detail::walk(
288  op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
289 }
290 
291 /// Walk all of the operations of type 'ArgT' nested under and including the
292 /// given operation. This method is selected for WalkReturn returning
293 /// interruptible callbacks that operate on a specific derived operation type.
294 ///
295 /// Example:
296 /// op->walk(op, [](ReturnOp op, const WalkStage &stage) {
297 /// if (some_invariant)
298 /// return WalkResult::interrupt();
299 /// return WalkResult::advance();
300 /// });
301 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
302  typename RetT = decltype(std::declval<FuncTy>()(
303  std::declval<ArgT>(), std::declval<const WalkStage &>()))>
305  std::is_same<RetT, WalkResult>::value,
306  RetT>::type
307 walk(Operation *op, FuncTy &&callback) {
308  auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
309  if (auto derivedOp = dyn_cast<ArgT>(op))
310  return callback(derivedOp, stage);
311  return WalkResult::advance();
312  };
313  return detail::walk(
314  op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
315 }
316 
/// Utility to provide the return type of a templated walk method.
/// Evaluated in an unevaluated context only: `nullptr` stands in for the
/// `Operation *` argument so overload resolution runs on the callback type.
template <typename FnT>
using walkResultType = decltype(walk(nullptr, std::declval<FnT>()));
320 } // namespace detail
321 
322 } // namespace mlir
323 
324 #endif
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
WalkResult(InFlightDiagnostic &&)
Definition: Visitors.h:46
Explicitly register a set of "builtin" types.
Definition: CallGraph.h:221
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:311
Block represents an ordered list of Operations.
Definition: Block.h:29
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
bool isAfterAllRegions() const
Return true if parent operation is being visited after all regions.
Definition: Visitors.h:88
Arg first_argument_type(Ret(F::*)(Arg, Rest...))
static constexpr const bool value
bool isBeforeRegion(int region) const
Returns true if parent operation is being visited just before visiting region number region...
Definition: Visitors.h:83
decltype(walk(nullptr, std::declval< FnT >())) walkResultType
Utility to provide the return type of a templated walk method.
Definition: Visitors.h:319
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine...
Definition: Diagnostics.h:157
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
void advance()
Advance the walk stage.
Definition: Visitors.h:90
WalkResult(LogicalResult result)
Allow LogicalResult to interrupt the walk on failure.
Definition: Visitors.h:41
static WalkResult advance()
Definition: Visitors.h:51
int getNextRegion() const
Returns the next region that will be visited.
Definition: Visitors.h:92
A utility class to encode the current walk stage for "generic" walkers.
Definition: Visitors.h:75
static WalkResult interrupt()
Definition: Visitors.h:50
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
decltype(first_argument_type(std::declval< T >())) first_argument
Type definition of the first argument to the given callable 'T'.
Definition: Visitors.h:112
static WalkResult skip()
Definition: Visitors.h:52
bool isAfterRegion(int region) const
Returns true if parent operation is being visited just after visiting region number region...
Definition: Visitors.h:86
bool isBeforeAllRegions() const
Return true if parent operation is being visited before all regions.
Definition: Visitors.h:80
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation...
Definition: Visitors.cpp:24
bool wasSkipped() const
Returns true if the walk was skipped.
Definition: Visitors.h:58
bool operator==(const WalkResult &rhs) const
Definition: Visitors.h:48
WalkResult(ResultEnum result)
Definition: Visitors.h:38
std::enable_if<!std::is_same< ArgT, Operation * >::value &&std::is_same< RetT, WalkResult >::value, RetT >::type walk(Operation *op, FuncTy &&callback)
Walk all of the operations of type 'ArgT' nested under and including the given operation.
Definition: Visitors.h:307
WalkOrder
Traversal order for region, block and operation walk utilities.
Definition: Visitors.h:62
WalkResult(Diagnostic &&)
Allow diagnostics to interrupt the walk.
Definition: Visitors.h:45