MLIR  19.0.0git
Visitors.h
Go to the documentation of this file.
1 //===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utilities for walking and visiting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_VISITORS_H
14 #define MLIR_IR_VISITORS_H
15 
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
19 
20 namespace mlir {
21 class Diagnostic;
22 class InFlightDiagnostic;
23 class Operation;
24 class Block;
25 class Region;
26 
27 /// A utility result that is used to signal how to proceed with an ongoing walk:
28 /// * Interrupt: the walk will be interrupted and no more operations, regions
29 /// or blocks will be visited.
30 /// * Advance: the walk will continue.
31 /// * Skip: the walk of the current operation, region or block and their
32 /// nested elements that haven't been visited already will be skipped and will
33 /// continue with the next operation, region or block.
34 class WalkResult {
35  enum ResultEnum { Interrupt, Advance, Skip } result;
36 
37 public:
38  WalkResult(ResultEnum result = Advance) : result(result) {}
39 
40  /// Allow LogicalResult to interrupt the walk on failure.
42  : result(failed(result) ? Interrupt : Advance) {}
43 
44  /// Allow diagnostics to interrupt the walk.
45  WalkResult(Diagnostic &&) : result(Interrupt) {}
46  WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
47 
48  bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
49  bool operator!=(const WalkResult &rhs) const { return result != rhs.result; }
50 
51  static WalkResult interrupt() { return {Interrupt}; }
52  static WalkResult advance() { return {Advance}; }
53  static WalkResult skip() { return {Skip}; }
54 
55  /// Returns true if the walk was interrupted.
56  bool wasInterrupted() const { return result == Interrupt; }
57 
58  /// Returns true if the walk was skipped.
59  bool wasSkipped() const { return result == Skip; }
60 };
61 
/// Traversal order for region, block and operation walk utilities: decides
/// whether the callback sees an enclosing entity before or after the entities
/// nested inside it.
enum class WalkOrder {
  /// Invoke the callback on an entity before walking its nested entities.
  PreOrder,
  /// Invoke the callback on an entity after walking its nested entities.
  PostOrder
};
64 
65 /// This iterator enumerates the elements in "forward" order.
67  /// Make operations iterable: return the list of regions.
69 
70  /// Regions and block are already iterable.
71  template <typename T>
72  static constexpr T &makeIterable(T &range) {
73  return range;
74  }
75 };
76 
77 /// A utility class to encode the current walk stage for "generic" walkers.
78 /// When walking an operation, we can either choose a Pre/Post order walker
79 /// which invokes the callback on an operation before/after all its attached
80 /// regions have been visited, or choose a "generic" walker where the callback
81 /// is invoked on the operation N+1 times where N is the number of regions
82 /// attached to that operation. The `WalkStage` class below encodes the current
83 /// stage of the walk, i.e., which regions have already been visited, and the
84 /// callback accepts an additional argument for the current stage. Such
85 /// generic walkers that accept stage-aware callbacks are only applicable when
86 /// the callback operates on an operation (i.e., not applicable for callbacks
87 /// on Blocks or Regions).
88 class WalkStage {
89 public:
90  explicit WalkStage(Operation *op);
91 
92  /// Return true if parent operation is being visited before all regions.
93  bool isBeforeAllRegions() const { return nextRegion == 0; }
94  /// Returns true if parent operation is being visited just before visiting
95  /// region number `region`.
96  bool isBeforeRegion(int region) const { return nextRegion == region; }
97  /// Returns true if parent operation is being visited just after visiting
98  /// region number `region`.
99  bool isAfterRegion(int region) const { return nextRegion == region + 1; }
100  /// Return true if parent operation is being visited after all regions.
101  bool isAfterAllRegions() const { return nextRegion == numRegions; }
102  /// Advance the walk stage.
103  void advance() { nextRegion++; }
104  /// Returns the next region that will be visited.
105  int getNextRegion() const { return nextRegion; }
106 
107 private:
108  const int numRegions;
109  int nextRegion;
110 };
111 
112 namespace detail {
/// Helper overloads used to deduce the type of the first parameter of a
/// callable. They are declared but never defined, and are only meant to be
/// named inside unevaluated contexts such as `decltype`.
///
/// Plain function pointers.
template <typename ResultT, typename FirstT, typename... RestT>
FirstT first_argument_type(ResultT (*)(FirstT, RestT...));
/// Non-const member function pointers (e.g. a `mutable` lambda's operator()).
template <typename ResultT, typename ClassT, typename FirstT,
          typename... RestT>
FirstT first_argument_type(ResultT (ClassT::*)(FirstT, RestT...));
/// Const member function pointers (e.g. a regular lambda's operator()).
template <typename ResultT, typename ClassT, typename FirstT,
          typename... RestT>
FirstT first_argument_type(ResultT (ClassT::*)(FirstT, RestT...) const);
/// Callable objects: defer to the overloads above through `operator()`.
template <typename CallableT>
decltype(first_argument_type(&CallableT::operator()))
first_argument_type(CallableT);

/// Type definition of the first argument to the given callable 'T'.
template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));
126 
127 /// Walk all of the regions, blocks, or operations nested under (and including)
128 /// the given operation. The order in which regions, blocks and operations at
129 /// the same nesting level are visited (e.g., lexicographical or reverse
130 /// lexicographical order) is determined by 'Iterator'. The walk order for
131 /// enclosing regions, blocks and operations with respect to their nested ones
132 /// is specified by 'order'. These methods are invoked for void-returning
133 /// callbacks. A callback on a block or operation is allowed to erase that block
134 /// or operation only if the walk is in post-order. See non-void method for
135 /// pre-order erasure.
136 template <typename Iterator>
137 void walk(Operation *op, function_ref<void(Region *)> callback,
138  WalkOrder order) {
139  // We don't use early increment for regions because they can't be erased from
140  // a callback.
141  for (auto &region : Iterator::makeIterable(*op)) {
142  if (order == WalkOrder::PreOrder)
143  callback(&region);
144  for (auto &block : Iterator::makeIterable(region)) {
145  for (auto &nestedOp : Iterator::makeIterable(block))
146  walk<Iterator>(&nestedOp, callback, order);
147  }
148  if (order == WalkOrder::PostOrder)
149  callback(&region);
150  }
151 }
152 
153 template <typename Iterator>
154 void walk(Operation *op, function_ref<void(Block *)> callback,
155  WalkOrder order) {
156  for (auto &region : Iterator::makeIterable(*op)) {
157  // Early increment here in the case where the block is erased.
158  for (auto &block :
159  llvm::make_early_inc_range(Iterator::makeIterable(region))) {
160  if (order == WalkOrder::PreOrder)
161  callback(&block);
162  for (auto &nestedOp : Iterator::makeIterable(block))
163  walk<Iterator>(&nestedOp, callback, order);
164  if (order == WalkOrder::PostOrder)
165  callback(&block);
166  }
167  }
168 }
169 
170 template <typename Iterator>
171 void walk(Operation *op, function_ref<void(Operation *)> callback,
172  WalkOrder order) {
173  if (order == WalkOrder::PreOrder)
174  callback(op);
175 
176  // TODO: This walk should be iterative over the operations.
177  for (auto &region : Iterator::makeIterable(*op)) {
178  for (auto &block : Iterator::makeIterable(region)) {
179  // Early increment here in the case where the operation is erased.
180  for (auto &nestedOp :
181  llvm::make_early_inc_range(Iterator::makeIterable(block)))
182  walk<Iterator>(&nestedOp, callback, order);
183  }
184  }
185 
186  if (order == WalkOrder::PostOrder)
187  callback(op);
188 }
189 
190 /// Walk all of the regions, blocks, or operations nested under (and including)
191 /// the given operation. The order in which regions, blocks and operations at
192 /// the same nesting level are visited (e.g., lexicographical or reverse
193 /// lexicographical order) is determined by 'Iterator'. The walk order for
194 /// enclosing regions, blocks and operations with respect to their nested ones
195 /// is specified by 'order'. This method is invoked for skippable or
196 /// interruptible callbacks. A callback on a block or operation is allowed to
197 /// erase that block or operation if either:
198 /// * the walk is in post-order, or
199 /// * the walk is in pre-order and the walk is skipped after the erasure.
200 template <typename Iterator>
202  WalkOrder order) {
203  // We don't use early increment for regions because they can't be erased from
204  // a callback.
205  for (auto &region : Iterator::makeIterable(*op)) {
206  if (order == WalkOrder::PreOrder) {
207  WalkResult result = callback(&region);
208  if (result.wasSkipped())
209  continue;
210  if (result.wasInterrupted())
211  return WalkResult::interrupt();
212  }
213  for (auto &block : Iterator::makeIterable(region)) {
214  for (auto &nestedOp : Iterator::makeIterable(block))
215  if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
216  return WalkResult::interrupt();
217  }
218  if (order == WalkOrder::PostOrder) {
219  if (callback(&region).wasInterrupted())
220  return WalkResult::interrupt();
221  // We don't check if this region was skipped because its walk already
222  // finished and the walk will continue with the next region.
223  }
224  }
225  return WalkResult::advance();
226 }
227 
228 template <typename Iterator>
230  WalkOrder order) {
231  for (auto &region : Iterator::makeIterable(*op)) {
232  // Early increment here in the case where the block is erased.
233  for (auto &block :
234  llvm::make_early_inc_range(Iterator::makeIterable(region))) {
235  if (order == WalkOrder::PreOrder) {
236  WalkResult result = callback(&block);
237  if (result.wasSkipped())
238  continue;
239  if (result.wasInterrupted())
240  return WalkResult::interrupt();
241  }
242  for (auto &nestedOp : Iterator::makeIterable(block))
243  if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
244  return WalkResult::interrupt();
245  if (order == WalkOrder::PostOrder) {
246  if (callback(&block).wasInterrupted())
247  return WalkResult::interrupt();
248  // We don't check if this block was skipped because its walk already
249  // finished and the walk will continue with the next block.
250  }
251  }
252  }
253  return WalkResult::advance();
254 }
255 
256 template <typename Iterator>
258  WalkOrder order) {
259  if (order == WalkOrder::PreOrder) {
260  WalkResult result = callback(op);
261  // If skipped, caller will continue the walk on the next operation.
262  if (result.wasSkipped())
263  return WalkResult::advance();
264  if (result.wasInterrupted())
265  return WalkResult::interrupt();
266  }
267 
268  // TODO: This walk should be iterative over the operations.
269  for (auto &region : Iterator::makeIterable(*op)) {
270  for (auto &block : Iterator::makeIterable(region)) {
271  // Early increment here in the case where the operation is erased.
272  for (auto &nestedOp :
273  llvm::make_early_inc_range(Iterator::makeIterable(block))) {
274  if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
275  return WalkResult::interrupt();
276  }
277  }
278  }
279 
280  if (order == WalkOrder::PostOrder)
281  return callback(op);
282  return WalkResult::advance();
283 }
284 
// Below is a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes (Operation/Block/etc.) over these
// methods. They are also templated to allow for statically dispatching based
// upon the type of the callback function.
289 
290 /// Walk all of the regions, blocks, or operations nested under (and including)
291 /// the given operation. The order in which regions, blocks and operations at
292 /// the same nesting level are visited (e.g., lexicographical or reverse
293 /// lexicographical order) is determined by 'Iterator'. The walk order for
294 /// enclosing regions, blocks and operations with respect to their nested ones
295 /// is specified by 'Order' (post-order by default). A callback on a block or
296 /// operation is allowed to erase that block or operation if either:
297 /// * the walk is in post-order, or
298 /// * the walk is in pre-order and the walk is skipped after the erasure.
299 /// This method is selected for callbacks that operate on Region*, Block*, and
300 /// Operation*.
301 ///
302 /// Example:
303 /// op->walk([](Region *r) { ... });
304 /// op->walk([](Block *b) { ... });
305 /// op->walk([](Operation *op) { ... });
306 template <
307  WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
308  typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
309  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
310 std::enable_if_t<llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value,
311  RetT>
312 walk(Operation *op, FuncTy &&callback) {
313  return detail::walk<Iterator>(op, function_ref<RetT(ArgT)>(callback), Order);
314 }
315 
316 /// Walk all of the operations of type 'ArgT' nested under and including the
317 /// given operation. The order in which regions, blocks and operations at
318 /// the same nesting are visited (e.g., lexicographical or reverse
319 /// lexicographical order) is determined by 'Iterator'. The walk order for
320 /// enclosing regions, blocks and operations with respect to their nested ones
321 /// is specified by 'order' (post-order by default). This method is selected for
322 /// void-returning callbacks that operate on a specific derived operation type.
323 /// A callback on an operation is allowed to erase that operation only if the
324 /// walk is in post-order. See non-void method for pre-order erasure.
325 ///
326 /// Example:
327 /// op->walk([](ReturnOp op) { ... });
328 template <
329  WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
330  typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
331  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
332 std::enable_if_t<
333  !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
334  std::is_same<RetT, void>::value,
335  RetT>
336 walk(Operation *op, FuncTy &&callback) {
337  auto wrapperFn = [&](Operation *op) {
338  if (auto derivedOp = dyn_cast<ArgT>(op))
339  callback(derivedOp);
340  };
341  return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn),
342  Order);
343 }
344 
345 /// Walk all of the operations of type 'ArgT' nested under and including the
346 /// given operation. The order in which regions, blocks and operations at
347 /// the same nesting are visited (e.g., lexicographical or reverse
348 /// lexicographical order) is determined by 'Iterator'. The walk order for
349 /// enclosing regions, blocks and operations with respect to their nested ones
350 /// is specified by 'Order' (post-order by default). This method is selected for
351 /// WalkReturn returning skippable or interruptible callbacks that operate on a
352 /// specific derived operation type. A callback on an operation is allowed to
353 /// erase that operation if either:
354 /// * the walk is in post-order, or
355 /// * the walk is in pre-order and the walk is skipped after the erasure.
356 ///
357 /// Example:
358 /// op->walk([](ReturnOp op) {
359 /// if (some_invariant)
360 /// return WalkResult::skip();
361 /// if (another_invariant)
362 /// return WalkResult::interrupt();
363 /// return WalkResult::advance();
364 /// });
365 template <
366  WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
367  typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
368  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
369 std::enable_if_t<
370  !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
371  std::is_same<RetT, WalkResult>::value,
372  RetT>
373 walk(Operation *op, FuncTy &&callback) {
374  auto wrapperFn = [&](Operation *op) {
375  if (auto derivedOp = dyn_cast<ArgT>(op))
376  return callback(derivedOp);
377  return WalkResult::advance();
378  };
379  return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn),
380  Order);
381 }
382 
383 /// Generic walkers with stage aware callbacks.
384 
385 /// Walk all the operations nested under (and including) the given operation,
386 /// with the callback being invoked on each operation N+1 times, where N is the
387 /// number of regions attached to the operation. The `stage` input to the
388 /// callback indicates the current walk stage. This method is invoked for void
389 /// returning callbacks.
390 void walk(Operation *op,
391  function_ref<void(Operation *, const WalkStage &stage)> callback);
392 
393 /// Walk all the operations nested under (and including) the given operation,
394 /// with the callback being invoked on each operation N+1 times, where N is the
395 /// number of regions attached to the operation. The `stage` input to the
396 /// callback indicates the current walk stage. This method is invoked for
397 /// skippable or interruptible callbacks.
400  function_ref<WalkResult(Operation *, const WalkStage &stage)> callback);
401 
402 /// Walk all of the operations nested under and including the given operation.
403 /// This method is selected for stage-aware callbacks that operate on
404 /// Operation*.
405 ///
406 /// Example:
407 /// op->walk([](Operation *op, const WalkStage &stage) { ... });
408 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
409  typename RetT = decltype(std::declval<FuncTy>()(
410  std::declval<ArgT>(), std::declval<const WalkStage &>()))>
411 std::enable_if_t<std::is_same<ArgT, Operation *>::value, RetT>
412 walk(Operation *op, FuncTy &&callback) {
413  return detail::walk(op,
414  function_ref<RetT(ArgT, const WalkStage &)>(callback));
415 }
416 
417 /// Walk all of the operations of type 'ArgT' nested under and including the
418 /// given operation. This method is selected for void returning callbacks that
419 /// operate on a specific derived operation type.
420 ///
421 /// Example:
422 /// op->walk([](ReturnOp op, const WalkStage &stage) { ... });
423 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
424  typename RetT = decltype(std::declval<FuncTy>()(
425  std::declval<ArgT>(), std::declval<const WalkStage &>()))>
426 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
427  std::is_same<RetT, void>::value,
428  RetT>
429 walk(Operation *op, FuncTy &&callback) {
430  auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
431  if (auto derivedOp = dyn_cast<ArgT>(op))
432  callback(derivedOp, stage);
433  };
434  return detail::walk(
435  op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
436 }
437 
438 /// Walk all of the operations of type 'ArgT' nested under and including the
439 /// given operation. This method is selected for WalkReturn returning
440 /// interruptible callbacks that operate on a specific derived operation type.
441 ///
442 /// Example:
443 /// op->walk(op, [](ReturnOp op, const WalkStage &stage) {
444 /// if (some_invariant)
445 /// return WalkResult::interrupt();
446 /// return WalkResult::advance();
447 /// });
448 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
449  typename RetT = decltype(std::declval<FuncTy>()(
450  std::declval<ArgT>(), std::declval<const WalkStage &>()))>
451 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
452  std::is_same<RetT, WalkResult>::value,
453  RetT>
454 walk(Operation *op, FuncTy &&callback) {
455  auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
456  if (auto derivedOp = dyn_cast<ArgT>(op))
457  return callback(derivedOp, stage);
458  return WalkResult::advance();
459  };
460  return detail::walk(
461  op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
462 }
463 
/// The return type of the `walk(Operation *, FnT)` overload that would be
/// selected for a callback of type `FnT`; useful for templated walk methods
/// that forward their result.
template <typename FnT>
using walkResultType =
    decltype(walk(nullptr, std::declval<FnT>()));
467 } // namespace detail
468 
469 } // namespace mlir
470 
471 #endif
Block represents an ordered list of Operations.
Definition: Block.h:30
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
WalkResult(InFlightDiagnostic &&)
Definition: Visitors.h:46
bool operator==(const WalkResult &rhs) const
Definition: Visitors.h:48
WalkResult(LogicalResult result)
Allow LogicalResult to interrupt the walk on failure.
Definition: Visitors.h:41
WalkResult(ResultEnum result=Advance)
Definition: Visitors.h:38
static WalkResult skip()
Definition: Visitors.h:53
bool wasSkipped() const
Returns true if the walk was skipped.
Definition: Visitors.h:59
static WalkResult advance()
Definition: Visitors.h:52
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:56
WalkResult(Diagnostic &&)
Allow diagnostics to interrupt the walk.
Definition: Visitors.h:45
static WalkResult interrupt()
Definition: Visitors.h:51
bool operator!=(const WalkResult &rhs) const
Definition: Visitors.h:49
A utility class to encode the current walk stage for "generic" walkers.
Definition: Visitors.h:88
void advance()
Advance the walk stage.
Definition: Visitors.h:103
int getNextRegion() const
Returns the next region that will be visited.
Definition: Visitors.h:105
bool isBeforeRegion(int region) const
Returns true if parent operation is being visited just before visiting region number region.
Definition: Visitors.h:96
WalkStage(Operation *op)
Definition: Visitors.cpp:14
bool isAfterAllRegions() const
Return true if parent operation is being visited after all regions.
Definition: Visitors.h:101
bool isAfterRegion(int region) const
Returns true if parent operation is being visited just after visiting region number region.
Definition: Visitors.h:99
bool isBeforeAllRegions() const
Return true if parent operation is being visited before all regions.
Definition: Visitors.h:93
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
decltype(walk(nullptr, std::declval< FnT >())) walkResultType
Utility to provide the return type of a templated walk method.
Definition: Visitors.h:466
decltype(first_argument_type(&F::operator())) first_argument_type(F)
Definition: Visitors.h:121
decltype(first_argument_type(std::declval< T >())) first_argument
Type definition of the first argument to the given callable 'T'.
Definition: Visitors.h:125
Include the generated interface declarations.
WalkOrder
Traversal order for region, block and operation walk utilities.
Definition: Visitors.h:63
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This iterator enumerates the elements in "forward" order.
Definition: Visitors.h:66
static MutableArrayRef< Region > makeIterable(Operation &range)
Make operations iterable: return the list of regions.
Definition: Visitors.cpp:17
static constexpr T & makeIterable(T &range)
Regions and blocks are already iterable.
Definition: Visitors.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26