MLIR  16.0.0git
NestedMatcher.cpp
Go to the documentation of this file.
1 //===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/Support/Allocator.h"
17 #include "llvm/Support/raw_ostream.h"
18 
19 using namespace mlir;
20 
21 llvm::BumpPtrAllocator *&NestedMatch::allocator() {
22  thread_local llvm::BumpPtrAllocator *allocator = nullptr;
23  return allocator;
24 }
25 
27  ArrayRef<NestedMatch> nestedMatches) {
28  auto *result = allocator()->Allocate<NestedMatch>();
29  auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
30  std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
31  new (result) NestedMatch();
32  result->matchedOperation = operation;
33  result->matchedChildren =
34  ArrayRef<NestedMatch>(children, nestedMatches.size());
35  return *result;
36 }
37 
38 llvm::BumpPtrAllocator *&NestedPattern::allocator() {
39  thread_local llvm::BumpPtrAllocator *allocator = nullptr;
40  return allocator;
41 }
42 
43 void NestedPattern::copyNestedToThis(ArrayRef<NestedPattern> nested) {
44  if (nested.empty())
45  return;
46 
47  auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
48  std::uninitialized_copy(nested.begin(), nested.end(), newNested);
49  nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
50 }
51 
52 void NestedPattern::freeNested() {
53  for (const auto &p : nestedPatterns)
54  p.~NestedPattern();
55 }
56 
58  FilterFunctionType filter)
59  : filter(std::move(filter)), skip(nullptr) {
60  copyNestedToThis(nested);
61 }
62 
64  : filter(other.filter), skip(other.skip) {
65  copyNestedToThis(other.nestedPatterns);
66 }
67 
69  freeNested();
70  filter = other.filter;
71  skip = other.skip;
72  copyNestedToThis(other.nestedPatterns);
73  return *this;
74 }
75 
76 unsigned NestedPattern::getDepth() const {
77  if (nestedPatterns.empty()) {
78  return 1;
79  }
80  unsigned depth = 0;
81  for (auto &c : nestedPatterns) {
82  depth = std::max(depth, c.getDepth());
83  }
84  return depth + 1;
85 }
86 
87 /// Matches a single operation in the following way:
88 /// 1. checks the kind of operation against the matcher, if different then
89 /// there is no match;
90 /// 2. calls the customizable filter function to refine the single operation
91 /// match with extra semantic constraints;
92 /// 3. if all is good, recursively matches the nested patterns;
93 /// 4. if all nested match then the single operation matches too and is
94 /// appended to the list of matches;
95 /// 5. TODO: Optionally applies actions (lambda), in which case we will want
96 /// to traverse in post-order DFS to avoid invalidating iterators.
97 void NestedPattern::matchOne(Operation *op,
99  if (skip == op) {
100  return;
101  }
102  // Local custom filter function
103  if (!filter(*op)) {
104  return;
105  }
106 
107  if (nestedPatterns.empty()) {
108  SmallVector<NestedMatch, 8> nestedMatches;
109  matches->push_back(NestedMatch::build(op, nestedMatches));
110  return;
111  }
112  // Take a copy of each nested pattern so we can match it.
113  for (auto nestedPattern : nestedPatterns) {
114  SmallVector<NestedMatch, 8> nestedMatches;
115  // Skip elem in the walk immediately following. Without this we would
116  // essentially need to reimplement walk here.
117  nestedPattern.skip = op;
118  nestedPattern.match(op, &nestedMatches);
119  // If we could not match even one of the specified nestedPattern, early exit
120  // as this whole branch is not a match.
121  if (nestedMatches.empty()) {
122  return;
123  }
124  matches->push_back(NestedMatch::build(op, nestedMatches));
125  }
126 }
127 
128 static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); }
129 
130 static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); }
131 
132 namespace mlir {
133 namespace matcher {
134 
136  return NestedPattern({}, std::move(filter));
137 }
138 
140  return NestedPattern(child, isAffineIfOp);
141 }
142 NestedPattern If(const FilterFunctionType &filter, const NestedPattern &child) {
143  return NestedPattern(child, [filter](Operation &op) {
144  return isAffineIfOp(op) && filter(op);
145  });
146 }
148  return NestedPattern(nested, isAffineIfOp);
149 }
151  ArrayRef<NestedPattern> nested) {
152  return NestedPattern(nested, [filter](Operation &op) {
153  return isAffineIfOp(op) && filter(op);
154  });
155 }
156 
158  return NestedPattern(child, isAffineForOp);
159 }
161  const NestedPattern &child) {
162  return NestedPattern(
163  child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
164 }
166  return NestedPattern(nested, isAffineForOp);
167 }
169  ArrayRef<NestedPattern> nested) {
170  return NestedPattern(
171  nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
172 }
173 
175  return isa<AffineLoadOp, AffineStoreOp>(op);
176 }
177 
178 } // namespace matcher
179 } // namespace mlir
Include the generated interface declarations.
NestedPattern(ArrayRef< NestedPattern > nested, FilterFunctionType filter=defaultFilterFunction)
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
bool isLoadOrStore(Operation &op)
unsigned getDepth() const
Returns the depth of the pattern.
static bool isAffineForOp(Operation &op)
static NestedMatch build(Operation *operation, ArrayRef< NestedMatch > nestedMatches)
std::function< bool(Operation &)> FilterFunctionType
A NestedPattern is a nested operation walker that:
Definition: NestedMatcher.h:90
NestedPattern If(const NestedPattern &child)
NestedPattern For(const NestedPattern &child)
static bool isAffineIfOp(Operation &op)
A NestedPattern captures nested patterns in the IR.
Definition: NestedMatcher.h:46
NestedPattern & operator=(const NestedPattern &other)
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)