MLIR  20.0.0git
Visitors.h
Go to the documentation of this file.
1 //===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utilities for walking and visiting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_VISITORS_H
14 #define MLIR_IR_VISITORS_H
15 
16 #include "mlir/Support/LLVM.h"
17 #include "llvm/ADT/STLExtras.h"
18 
19 namespace mlir {
// Forward declarations of classes this header refers to without needing their
// full definitions.
class Block;
class Operation;
class Region;
class Diagnostic;
class InFlightDiagnostic;
25 
26 /// A utility result that is used to signal how to proceed with an ongoing walk:
27 /// * Interrupt: the walk will be interrupted and no more operations, regions
28 /// or blocks will be visited.
29 /// * Advance: the walk will continue.
30 /// * Skip: the walk of the current operation, region or block and their
31 /// nested elements that haven't been visited already will be skipped and will
32 /// continue with the next operation, region or block.
33 class WalkResult {
34  enum ResultEnum { Interrupt, Advance, Skip } result;
35 
36 public:
37  WalkResult(ResultEnum result = Advance) : result(result) {}
38 
39  /// Allow LogicalResult to interrupt the walk on failure.
40  WalkResult(LogicalResult result)
41  : result(failed(result) ? Interrupt : Advance) {}
42 
43  /// Allow diagnostics to interrupt the walk.
44  WalkResult(Diagnostic &&) : result(Interrupt) {}
45  WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
46 
47  bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
48  bool operator!=(const WalkResult &rhs) const { return result != rhs.result; }
49 
50  static WalkResult interrupt() { return {Interrupt}; }
51  static WalkResult advance() { return {Advance}; }
52  static WalkResult skip() { return {Skip}; }
53 
54  /// Returns true if the walk was interrupted.
55  bool wasInterrupted() const { return result == Interrupt; }
56 
57  /// Returns true if the walk was skipped.
58  bool wasSkipped() const { return result == Skip; }
59 };
60 
/// Order in which the region, block and operation walk utilities visit an
/// enclosing element relative to the elements nested inside it.
enum class WalkOrder {
  PreOrder,  ///< Visit the enclosing element before its nested elements.
  PostOrder, ///< Visit the enclosing element after its nested elements.
};
63 
64 /// This iterator enumerates the elements in "forward" order.
66  /// Make operations iterable: return the list of regions.
68 
69  /// Regions and block are already iterable.
70  template <typename T>
71  static constexpr T &makeIterable(T &range) {
72  return range;
73  }
74 };
75 
76 /// A utility class to encode the current walk stage for "generic" walkers.
77 /// When walking an operation, we can either choose a Pre/Post order walker
78 /// which invokes the callback on an operation before/after all its attached
79 /// regions have been visited, or choose a "generic" walker where the callback
80 /// is invoked on the operation N+1 times where N is the number of regions
81 /// attached to that operation. The `WalkStage` class below encodes the current
82 /// stage of the walk, i.e., which regions have already been visited, and the
83 /// callback accepts an additional argument for the current stage. Such
84 /// generic walkers that accept stage-aware callbacks are only applicable when
85 /// the callback operates on an operation (i.e., not applicable for callbacks
86 /// on Blocks or Regions).
87 class WalkStage {
88 public:
89  explicit WalkStage(Operation *op);
90 
91  /// Return true if parent operation is being visited before all regions.
92  bool isBeforeAllRegions() const { return nextRegion == 0; }
93  /// Returns true if parent operation is being visited just before visiting
94  /// region number `region`.
95  bool isBeforeRegion(int region) const { return nextRegion == region; }
96  /// Returns true if parent operation is being visited just after visiting
97  /// region number `region`.
98  bool isAfterRegion(int region) const { return nextRegion == region + 1; }
99  /// Return true if parent operation is being visited after all regions.
100  bool isAfterAllRegions() const { return nextRegion == numRegions; }
101  /// Advance the walk stage.
102  void advance() { nextRegion++; }
103  /// Returns the next region that will be visited.
104  int getNextRegion() const { return nextRegion; }
105 
106 private:
107  const int numRegions;
108  int nextRegion;
109 };
110 
111 namespace detail {
/// Helper declarations used to deduce the type of the first parameter of a
/// callable. They are declared but never defined, so they may only be named in
/// unevaluated contexts such as `decltype`.

/// Plain function pointers.
template <typename R, typename First, typename... Others>
First first_argument_type(R (*)(First, Others...));
/// Non-const member function pointers.
template <typename R, typename Class, typename First, typename... Others>
First first_argument_type(R (Class::*)(First, Others...));
/// Const member function pointers (e.g. the call operator of a lambda).
template <typename R, typename Class, typename First, typename... Others>
First first_argument_type(R (Class::*)(First, Others...) const);
/// Function objects: defer to the overload matching their call operator.
template <typename Callable>
auto first_argument_type(Callable)
    -> decltype(first_argument_type(&Callable::operator()));

/// The type of the first parameter accepted by the callable type `T`.
template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));
125 
126 /// Walk all of the regions, blocks, or operations nested under (and including)
127 /// the given operation. The order in which regions, blocks and operations at
128 /// the same nesting level are visited (e.g., lexicographical or reverse
129 /// lexicographical order) is determined by 'Iterator'. The walk order for
130 /// enclosing regions, blocks and operations with respect to their nested ones
131 /// is specified by 'order'. These methods are invoked for void-returning
132 /// callbacks. A callback on a block or operation is allowed to erase that block
133 /// or operation only if the walk is in post-order. See non-void method for
134 /// pre-order erasure.
135 template <typename Iterator>
136 void walk(Operation *op, function_ref<void(Region *)> callback,
137  WalkOrder order) {
138  // We don't use early increment for regions because they can't be erased from
139  // a callback.
140  for (auto &region : Iterator::makeIterable(*op)) {
141  if (order == WalkOrder::PreOrder)
142  callback(&region);
143  for (auto &block : Iterator::makeIterable(region)) {
144  for (auto &nestedOp : Iterator::makeIterable(block))
145  walk<Iterator>(&nestedOp, callback, order);
146  }
147  if (order == WalkOrder::PostOrder)
148  callback(&region);
149  }
150 }
151 
152 template <typename Iterator>
153 void walk(Operation *op, function_ref<void(Block *)> callback,
154  WalkOrder order) {
155  for (auto &region : Iterator::makeIterable(*op)) {
156  // Early increment here in the case where the block is erased.
157  for (auto &block :
158  llvm::make_early_inc_range(Iterator::makeIterable(region))) {
159  if (order == WalkOrder::PreOrder)
160  callback(&block);
161  for (auto &nestedOp : Iterator::makeIterable(block))
162  walk<Iterator>(&nestedOp, callback, order);
163  if (order == WalkOrder::PostOrder)
164  callback(&block);
165  }
166  }
167 }
168 
169 template <typename Iterator>
170 void walk(Operation *op, function_ref<void(Operation *)> callback,
171  WalkOrder order) {
172  if (order == WalkOrder::PreOrder)
173  callback(op);
174 
175  // TODO: This walk should be iterative over the operations.
176  for (auto &region : Iterator::makeIterable(*op)) {
177  for (auto &block : Iterator::makeIterable(region)) {
178  // Early increment here in the case where the operation is erased.
179  for (auto &nestedOp :
180  llvm::make_early_inc_range(Iterator::makeIterable(block)))
181  walk<Iterator>(&nestedOp, callback, order);
182  }
183  }
184 
185  if (order == WalkOrder::PostOrder)
186  callback(op);
187 }
188 
189 /// Walk all of the regions, blocks, or operations nested under (and including)
190 /// the given operation. The order in which regions, blocks and operations at
191 /// the same nesting level are visited (e.g., lexicographical or reverse
192 /// lexicographical order) is determined by 'Iterator'. The walk order for
193 /// enclosing regions, blocks and operations with respect to their nested ones
194 /// is specified by 'order'. This method is invoked for skippable or
195 /// interruptible callbacks. A callback on a block or operation is allowed to
196 /// erase that block or operation if either:
197 /// * the walk is in post-order, or
198 /// * the walk is in pre-order and the walk is skipped after the erasure.
199 template <typename Iterator>
201  WalkOrder order) {
202  // We don't use early increment for regions because they can't be erased from
203  // a callback.
204  for (auto &region : Iterator::makeIterable(*op)) {
205  if (order == WalkOrder::PreOrder) {
206  WalkResult result = callback(&region);
207  if (result.wasSkipped())
208  continue;
209  if (result.wasInterrupted())
210  return WalkResult::interrupt();
211  }
212  for (auto &block : Iterator::makeIterable(region)) {
213  for (auto &nestedOp : Iterator::makeIterable(block))
214  if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
215  return WalkResult::interrupt();
216  }
217  if (order == WalkOrder::PostOrder) {
218  if (callback(&region).wasInterrupted())
219  return WalkResult::interrupt();
220  // We don't check if this region was skipped because its walk already
221  // finished and the walk will continue with the next region.
222  }
223  }
224  return WalkResult::advance();
225 }
226 
227 template <typename Iterator>
229  WalkOrder order) {
230  for (auto &region : Iterator::makeIterable(*op)) {
231  // Early increment here in the case where the block is erased.
232  for (auto &block :
233  llvm::make_early_inc_range(Iterator::makeIterable(region))) {
234  if (order == WalkOrder::PreOrder) {
235  WalkResult result = callback(&block);
236  if (result.wasSkipped())
237  continue;
238  if (result.wasInterrupted())
239  return WalkResult::interrupt();
240  }
241  for (auto &nestedOp : Iterator::makeIterable(block))
242  if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
243  return WalkResult::interrupt();
244  if (order == WalkOrder::PostOrder) {
245  if (callback(&block).wasInterrupted())
246  return WalkResult::interrupt();
247  // We don't check if this block was skipped because its walk already
248  // finished and the walk will continue with the next block.
249  }
250  }
251  }
252  return WalkResult::advance();
253 }
254 
255 template <typename Iterator>
257  WalkOrder order) {
258  if (order == WalkOrder::PreOrder) {
259  WalkResult result = callback(op);
260  // If skipped, caller will continue the walk on the next operation.
261  if (result.wasSkipped())
262  return WalkResult::advance();
263  if (result.wasInterrupted())
264  return WalkResult::interrupt();
265  }
266 
267  // TODO: This walk should be iterative over the operations.
268  for (auto &region : Iterator::makeIterable(*op)) {
269  for (auto &block : Iterator::makeIterable(region)) {
270  // Early increment here in the case where the operation is erased.
271  for (auto &nestedOp :
272  llvm::make_early_inc_range(Iterator::makeIterable(block))) {
273  if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
274  return WalkResult::interrupt();
275  }
276  }
277  }
278 
279  if (order == WalkOrder::PostOrder)
280  return callback(op);
281  return WalkResult::advance();
282 }
283 
284 // Below are a set of functions to walk nested operations. Users should favor
285 // the direct `walk` methods on the IR classes(Operation/Block/etc) over these
286 // methods. They are also templated to allow for statically dispatching based
287 // upon the type of the callback function.
288 
289 /// Walk all of the regions, blocks, or operations nested under (and including)
290 /// the given operation. The order in which regions, blocks and operations at
291 /// the same nesting level are visited (e.g., lexicographical or reverse
292 /// lexicographical order) is determined by 'Iterator'. The walk order for
293 /// enclosing regions, blocks and operations with respect to their nested ones
294 /// is specified by 'Order' (post-order by default). A callback on a block or
295 /// operation is allowed to erase that block or operation if either:
296 /// * the walk is in post-order, or
297 /// * the walk is in pre-order and the walk is skipped after the erasure.
298 /// This method is selected for callbacks that operate on Region*, Block*, and
299 /// Operation*.
300 ///
301 /// Example:
302 /// op->walk([](Region *r) { ... });
303 /// op->walk([](Block *b) { ... });
304 /// op->walk([](Operation *op) { ... });
305 template <
306  WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
307  typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
308  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
309 std::enable_if_t<llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value,
310  RetT>
311 walk(Operation *op, FuncTy &&callback) {
312  return detail::walk<Iterator>(op, function_ref<RetT(ArgT)>(callback), Order);
313 }
314 
315 /// Walk all of the operations of type 'ArgT' nested under and including the
316 /// given operation. The order in which regions, blocks and operations at
317 /// the same nesting are visited (e.g., lexicographical or reverse
318 /// lexicographical order) is determined by 'Iterator'. The walk order for
319 /// enclosing regions, blocks and operations with respect to their nested ones
320 /// is specified by 'order' (post-order by default). This method is selected for
321 /// void-returning callbacks that operate on a specific derived operation type.
322 /// A callback on an operation is allowed to erase that operation only if the
323 /// walk is in post-order. See non-void method for pre-order erasure.
324 ///
325 /// Example:
326 /// op->walk([](ReturnOp op) { ... });
327 template <
328  WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
329  typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
330  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
331 std::enable_if_t<
332  !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
333  std::is_same<RetT, void>::value,
334  RetT>
335 walk(Operation *op, FuncTy &&callback) {
336  auto wrapperFn = [&](Operation *op) {
337  if (auto derivedOp = dyn_cast<ArgT>(op))
338  callback(derivedOp);
339  };
340  return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn),
341  Order);
342 }
343 
344 /// Walk all of the operations of type 'ArgT' nested under and including the
345 /// given operation. The order in which regions, blocks and operations at
346 /// the same nesting are visited (e.g., lexicographical or reverse
347 /// lexicographical order) is determined by 'Iterator'. The walk order for
348 /// enclosing regions, blocks and operations with respect to their nested ones
349 /// is specified by 'Order' (post-order by default). This method is selected for
350 /// WalkReturn returning skippable or interruptible callbacks that operate on a
351 /// specific derived operation type. A callback on an operation is allowed to
352 /// erase that operation if either:
353 /// * the walk is in post-order, or
354 /// * the walk is in pre-order and the walk is skipped after the erasure.
355 ///
356 /// Example:
357 /// op->walk([](ReturnOp op) {
358 /// if (some_invariant)
359 /// return WalkResult::skip();
360 /// if (another_invariant)
361 /// return WalkResult::interrupt();
362 /// return WalkResult::advance();
363 /// });
364 template <
365  WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
366  typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
367  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
368 std::enable_if_t<
369  !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
370  std::is_same<RetT, WalkResult>::value,
371  RetT>
372 walk(Operation *op, FuncTy &&callback) {
373  auto wrapperFn = [&](Operation *op) {
374  if (auto derivedOp = dyn_cast<ArgT>(op))
375  return callback(derivedOp);
376  return WalkResult::advance();
377  };
378  return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn),
379  Order);
380 }
381 
382 /// Generic walkers with stage aware callbacks.
383 
384 /// Walk all the operations nested under (and including) the given operation,
385 /// with the callback being invoked on each operation N+1 times, where N is the
386 /// number of regions attached to the operation. The `stage` input to the
387 /// callback indicates the current walk stage. This method is invoked for void
388 /// returning callbacks.
389 void walk(Operation *op,
390  function_ref<void(Operation *, const WalkStage &stage)> callback);
391 
392 /// Walk all the operations nested under (and including) the given operation,
393 /// with the callback being invoked on each operation N+1 times, where N is the
394 /// number of regions attached to the operation. The `stage` input to the
395 /// callback indicates the current walk stage. This method is invoked for
396 /// skippable or interruptible callbacks.
399  function_ref<WalkResult(Operation *, const WalkStage &stage)> callback);
400 
401 /// Walk all of the operations nested under and including the given operation.
402 /// This method is selected for stage-aware callbacks that operate on
403 /// Operation*.
404 ///
405 /// Example:
406 /// op->walk([](Operation *op, const WalkStage &stage) { ... });
407 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
408  typename RetT = decltype(std::declval<FuncTy>()(
409  std::declval<ArgT>(), std::declval<const WalkStage &>()))>
410 std::enable_if_t<std::is_same<ArgT, Operation *>::value, RetT>
411 walk(Operation *op, FuncTy &&callback) {
412  return detail::walk(op,
413  function_ref<RetT(ArgT, const WalkStage &)>(callback));
414 }
415 
416 /// Walk all of the operations of type 'ArgT' nested under and including the
417 /// given operation. This method is selected for void returning callbacks that
418 /// operate on a specific derived operation type.
419 ///
420 /// Example:
421 /// op->walk([](ReturnOp op, const WalkStage &stage) { ... });
422 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
423  typename RetT = decltype(std::declval<FuncTy>()(
424  std::declval<ArgT>(), std::declval<const WalkStage &>()))>
425 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
426  std::is_same<RetT, void>::value,
427  RetT>
428 walk(Operation *op, FuncTy &&callback) {
429  auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
430  if (auto derivedOp = dyn_cast<ArgT>(op))
431  callback(derivedOp, stage);
432  };
433  return detail::walk(
434  op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
435 }
436 
437 /// Walk all of the operations of type 'ArgT' nested under and including the
438 /// given operation. This method is selected for WalkReturn returning
439 /// interruptible callbacks that operate on a specific derived operation type.
440 ///
441 /// Example:
442 /// op->walk(op, [](ReturnOp op, const WalkStage &stage) {
443 /// if (some_invariant)
444 /// return WalkResult::interrupt();
445 /// return WalkResult::advance();
446 /// });
447 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
448  typename RetT = decltype(std::declval<FuncTy>()(
449  std::declval<ArgT>(), std::declval<const WalkStage &>()))>
450 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
451  std::is_same<RetT, WalkResult>::value,
452  RetT>
453 walk(Operation *op, FuncTy &&callback) {
454  auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
455  if (auto derivedOp = dyn_cast<ArgT>(op))
456  return callback(derivedOp, stage);
457  return WalkResult::advance();
458  };
459  return detail::walk(
460  op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
461 }
462 
/// The return type produced by calling the templated `walk` with a callable of
/// type `FnT` (e.g. `void` or `WalkResult`).
template <typename FnT>
using walkResultType = decltype(walk(nullptr, std::declval<FnT>()));
466 } // namespace detail
467 
468 } // namespace mlir
469 
470 #endif
Block represents an ordered list of Operations.
Definition: Block.h:33
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
WalkResult(InFlightDiagnostic &&)
Definition: Visitors.h:45
bool operator==(const WalkResult &rhs) const
Definition: Visitors.h:47
WalkResult(LogicalResult result)
Allow LogicalResult to interrupt the walk on failure.
Definition: Visitors.h:40
WalkResult(ResultEnum result=Advance)
Definition: Visitors.h:37
static WalkResult skip()
Definition: Visitors.h:52
bool wasSkipped() const
Returns true if the walk was skipped.
Definition: Visitors.h:58
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
WalkResult(Diagnostic &&)
Allow diagnostics to interrupt the walk.
Definition: Visitors.h:44
static WalkResult interrupt()
Definition: Visitors.h:50
bool operator!=(const WalkResult &rhs) const
Definition: Visitors.h:48
A utility class to encode the current walk stage for "generic" walkers.
Definition: Visitors.h:87
void advance()
Advance the walk stage.
Definition: Visitors.h:102
int getNextRegion() const
Returns the next region that will be visited.
Definition: Visitors.h:104
bool isBeforeRegion(int region) const
Returns true if the parent operation is being visited just before visiting region number `region`.
Definition: Visitors.h:95
WalkStage(Operation *op)
Definition: Visitors.cpp:14
bool isAfterAllRegions() const
Return true if parent operation is being visited after all regions.
Definition: Visitors.h:100
bool isAfterRegion(int region) const
Returns true if the parent operation is being visited just after visiting region number `region`.
Definition: Visitors.h:98
bool isBeforeAllRegions() const
Return true if parent operation is being visited before all regions.
Definition: Visitors.h:92
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
decltype(walk(nullptr, std::declval< FnT >())) walkResultType
Utility to provide the return type of a templated walk method.
Definition: Visitors.h:465
decltype(first_argument_type(&F::operator())) first_argument_type(F)
Definition: Visitors.h:120
decltype(first_argument_type(std::declval< T >())) first_argument
Type definition of the first argument to the given callable 'T'.
Definition: Visitors.h:124
Include the generated interface declarations.
WalkOrder
Traversal order for region, block and operation walk utilities.
Definition: Visitors.h:62
This iterator enumerates the elements in "forward" order.
Definition: Visitors.h:65
static MutableArrayRef< Region > makeIterable(Operation &range)
Make operations iterable: return the list of regions.
Definition: Visitors.cpp:17
static constexpr T & makeIterable(T &range)
Regions and blocks are already iterable.
Definition: Visitors.h:71