13 #ifndef MLIR_IR_VISITORS_H
14 #define MLIR_IR_VISITORS_H
18 #include "llvm/ADT/STLExtras.h"
22 class InFlightDiagnostic;
// Declares ResultEnum with the enumerators Interrupt, Advance and Skip, together
// with a member `result` of that type that stores the selected value.
// NOTE(review): the reference text for WalkResult describes it as signalling how
// to proceed with an ongoing walk; the exact meaning of Skip is not visible here.
35 enum ResultEnum { Interrupt, Advance, Skip } result;
/// Constructs a WalkResult that stores `result`; when no argument is supplied
/// the stored value defaults to Advance.
38 WalkResult(ResultEnum result = Advance) : result(result) {}
42 : result(
failed(result) ? Interrupt : Advance) {}
/// Returns true if the parent operation is being visited just after visiting
/// region number `region`, i.e. when `nextRegion` equals `region + 1`.
99 bool isAfterRegion(
int region)
const {
return nextRegion == region + 1; }
108 const int numRegions;
114 template <
typename Ret,
typename Arg,
typename... Rest>
116 template <
typename Ret,
typename F,
typename Arg,
typename... Rest>
118 template <
typename Ret,
typename F,
typename Arg,
typename... Rest>
120 template <
typename F>
124 template <typename T>
136 template <typename Iterator>
141 for (
auto &region : Iterator::makeIterable(*op)) {
144 for (
auto &block : Iterator::makeIterable(region)) {
145 for (
auto &nestedOp : Iterator::makeIterable(block))
146 walk<Iterator>(&nestedOp, callback, order);
153 template <
typename Iterator>
156 for (
auto &region : Iterator::makeIterable(*op)) {
159 llvm::make_early_inc_range(Iterator::makeIterable(region))) {
162 for (
auto &nestedOp : Iterator::makeIterable(block))
163 walk<Iterator>(&nestedOp, callback, order);
170 template <
typename Iterator>
177 for (
auto &region : Iterator::makeIterable(*op)) {
178 for (
auto &block : Iterator::makeIterable(region)) {
180 for (
auto &nestedOp :
181 llvm::make_early_inc_range(Iterator::makeIterable(block)))
182 walk<Iterator>(&nestedOp, callback, order);
200 template <
typename Iterator>
205 for (
auto &region : Iterator::makeIterable(*op)) {
213 for (
auto &block : Iterator::makeIterable(region)) {
214 for (
auto &nestedOp : Iterator::makeIterable(block))
215 if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
219 if (callback(&region).wasInterrupted())
228 template <
typename Iterator>
231 for (
auto &region : Iterator::makeIterable(*op)) {
234 llvm::make_early_inc_range(Iterator::makeIterable(region))) {
242 for (
auto &nestedOp : Iterator::makeIterable(block))
243 if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
246 if (callback(&block).wasInterrupted())
256 template <
typename Iterator>
269 for (
auto &region : Iterator::makeIterable(*op)) {
270 for (
auto &block : Iterator::makeIterable(region)) {
272 for (
auto &nestedOp :
273 llvm::make_early_inc_range(Iterator::makeIterable(block))) {
274 if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
309 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
310 std::enable_if_t<llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value,
313 return detail::walk<Iterator>(op,
function_ref<RetT(ArgT)>(callback), Order);
331 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
333 !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
334 std::is_same<RetT, void>::value,
338 if (
auto derivedOp = dyn_cast<ArgT>(op))
368 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
370 !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
371 std::is_same<RetT, WalkResult>::value,
375 if (
auto derivedOp = dyn_cast<ArgT>(op))
376 return callback(derivedOp);
408 template <
typename FuncTy,
typename ArgT = detail::first_argument<FuncTy>,
409 typename RetT = decltype(std::declval<FuncTy>()(
410 std::declval<ArgT>(), std::declval<const WalkStage &>()))>
411 std::enable_if_t<std::is_same<ArgT, Operation *>::value, RetT>
423 template <
typename FuncTy,
typename ArgT = detail::first_argument<FuncTy>,
424 typename RetT = decltype(std::declval<FuncTy>()(
425 std::declval<ArgT>(), std::declval<const WalkStage &>()))>
426 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
427 std::is_same<RetT, void>::value,
431 if (
auto derivedOp = dyn_cast<ArgT>(op))
432 callback(derivedOp, stage);
448 template <
typename FuncTy,
typename ArgT = detail::first_argument<FuncTy>,
449 typename RetT = decltype(std::declval<FuncTy>()(
450 std::declval<ArgT>(), std::declval<const WalkStage &>()))>
451 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
452 std::is_same<RetT, WalkResult>::value,
456 if (
auto derivedOp = dyn_cast<ArgT>(op))
457 return callback(derivedOp, stage);
465 template <
typename FnT>
Block represents an ordered list of Operations.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a diagnostic that is inflight and set to be reported.
Operation is the basic unit of execution within MLIR.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
A utility result that is used to signal how to proceed with an ongoing walk:
WalkResult(InFlightDiagnostic &&)
bool operator==(const WalkResult &rhs) const
WalkResult(LogicalResult result)
Allow LogicalResult to interrupt the walk on failure.
WalkResult(ResultEnum result=Advance)
bool wasSkipped() const
Returns true if the walk was skipped.
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
WalkResult(Diagnostic &&)
Allow diagnostics to interrupt the walk.
static WalkResult interrupt()
bool operator!=(const WalkResult &rhs) const
A utility class to encode the current walk stage for "generic" walkers.
void advance()
Advance the walk stage.
int getNextRegion() const
Returns the next region that will be visited.
bool isBeforeRegion(int region) const
Returns true if the parent operation is being visited just before visiting region number `region`.
bool isAfterAllRegions() const
Return true if parent operation is being visited after all regions.
bool isAfterRegion(int region) const
Returns true if the parent operation is being visited just after visiting region number `region`.
bool isBeforeAllRegions() const
Return true if parent operation is being visited before all regions.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
decltype(walk(nullptr, std::declval< FnT >())) walkResultType
Utility to provide the return type of a templated walk method.
decltype(first_argument_type(&F::operator())) first_argument_type(F)
decltype(first_argument_type(std::declval< T >())) first_argument
Type definition of the first argument to the given callable 'T'.
This header declares functions that assist transformations in the MemRef dialect.
WalkOrder
Traversal order for region, block and operation walk utilities.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This iterator enumerates the elements in "forward" order.
static MutableArrayRef< Region > makeIterable(Operation &range)
Make operations iterable: return the list of regions.
static constexpr T & makeIterable(T &range)
Regions and blocks are already iterable.
This class represents an efficient way to signal success or failure.