MLIR  22.0.0git
Predicate.cpp
Go to the documentation of this file.
1 //===- Predicate.cpp - Predicate class ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Wrapper around predicates defined in TableGen.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/TableGen/Predicate.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
19 
20 using namespace mlir;
21 using namespace tblgen;
22 using llvm::Init;
23 using llvm::Record;
24 using llvm::SpecificBumpPtrAllocator;
25 
26 // Construct a Predicate from a record.
27 Pred::Pred(const Record *record) : def(record) {
28  assert(def->isSubClassOf("Pred") &&
29  "must be a subclass of TableGen 'Pred' class");
30 }
31 
32 // Construct a Predicate from an initializer.
33 Pred::Pred(const Init *init) {
34  if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
35  def = defInit->getDef();
36 }
37 
38 std::string Pred::getCondition() const {
39  // Static dispatch to subclasses.
40  if (def->isSubClassOf("CombinedPred"))
41  return static_cast<const CombinedPred *>(this)->getConditionImpl();
42  if (def->isSubClassOf("CPred"))
43  return static_cast<const CPred *>(this)->getConditionImpl();
44  llvm_unreachable("Pred::getCondition must be overridden in subclasses");
45 }
46 
47 bool Pred::isCombined() const {
48  return def && def->isSubClassOf("CombinedPred");
49 }
50 
51 ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); }
52 
53 CPred::CPred(const Record *record) : Pred(record) {
54  assert(def->isSubClassOf("CPred") &&
55  "must be a subclass of Tablegen 'CPred' class");
56 }
57 
58 CPred::CPred(const Init *init) : Pred(init) {
59  assert((!def || def->isSubClassOf("CPred")) &&
60  "must be a subclass of Tablegen 'CPred' class");
61 }
62 
63 // Get condition of the C Predicate.
64 std::string CPred::getConditionImpl() const {
65  assert(!isNull() && "null predicate does not have a condition");
66  return std::string(def->getValueAsString("predExpr"));
67 }
68 
69 CombinedPred::CombinedPred(const Record *record) : Pred(record) {
70  assert(def->isSubClassOf("CombinedPred") &&
71  "must be a subclass of Tablegen 'CombinedPred' class");
72 }
73 
74 CombinedPred::CombinedPred(const Init *init) : Pred(init) {
75  assert((!def || def->isSubClassOf("CombinedPred")) &&
76  "must be a subclass of Tablegen 'CombinedPred' class");
77 }
78 
79 const Record *CombinedPred::getCombinerDef() const {
80  assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
81  return def->getValueAsDef("kind");
82 }
83 
84 std::vector<const Record *> CombinedPred::getChildren() const {
85  assert(def->getValue("children") &&
86  "CombinedPred must have a value 'children'");
87  return def->getValueAsListOfDefs("children");
88 }
89 
90 namespace {
91 // Kinds of nodes in a logical predicate tree.
92 enum class PredCombinerKind {
93  Leaf,
94  And,
95  Or,
96  Not,
97  SubstLeaves,
98  Concat,
99  // Special kinds that are used in simplification.
100  False,
101  True
102 };
103 
104 // A node in a logical predicate tree.
105 struct PredNode {
106  PredCombinerKind kind;
107  const Pred *predicate;
109  std::string expr;
110 
111  // Prefix and suffix are used by ConcatPred.
112  std::string prefix;
113  std::string suffix;
114 };
115 } // namespace
116 
117 // Get a predicate tree node kind based on the kind used in the predicate
118 // TableGen record.
119 static PredCombinerKind getPredCombinerKind(const Pred &pred) {
120  if (!pred.isCombined())
121  return PredCombinerKind::Leaf;
122 
123  const auto &combinedPred = static_cast<const CombinedPred &>(pred);
125  combinedPred.getCombinerDef()->getName())
126  .Case("PredCombinerAnd", PredCombinerKind::And)
127  .Case("PredCombinerOr", PredCombinerKind::Or)
128  .Case("PredCombinerNot", PredCombinerKind::Not)
129  .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
130  .Case("PredCombinerConcat", PredCombinerKind::Concat);
131 }
132 
133 namespace {
134 // Substitution<pattern, replacement>.
135 using Subst = std::pair<StringRef, StringRef>;
136 } // namespace
137 
138 /// Perform the given substitutions on 'str' in-place.
139 static void performSubstitutions(std::string &str,
140  ArrayRef<Subst> substitutions) {
141  // Apply all parent substitutions from innermost to outermost.
142  for (const auto &subst : llvm::reverse(substitutions)) {
143  auto pos = str.find(subst.first);
144  while (pos != std::string::npos) {
145  str.replace(pos, subst.first.size(), std::string(subst.second));
146  // Skip the newly inserted substring, which itself may consider the
147  // pattern to match.
148  pos += subst.second.size();
149  // Find the next possible match position.
150  pos = str.find(subst.first, pos);
151  }
152  }
153 }
154 
155 // Build the predicate tree starting from the top-level predicate, which may
156 // have children, and perform leaf substitutions inplace. Note that after
157 // substitution, nodes are still pointing to the original TableGen record.
158 // All nodes are created within "allocator".
159 static PredNode *
161  SpecificBumpPtrAllocator<PredNode> &allocator,
162  ArrayRef<Subst> substitutions) {
163  auto *rootNode = allocator.Allocate();
164  new (rootNode) PredNode;
165  rootNode->kind = getPredCombinerKind(root);
166  rootNode->predicate = &root;
167  if (!root.isCombined()) {
168  rootNode->expr = root.getCondition();
169  performSubstitutions(rootNode->expr, substitutions);
170  return rootNode;
171  }
172 
173  // If the current combined predicate is a leaf substitution, append it to the
174  // list before continuing.
175  auto allSubstitutions = llvm::to_vector<4>(substitutions);
176  if (rootNode->kind == PredCombinerKind::SubstLeaves) {
177  const auto &substPred = static_cast<const SubstLeavesPred &>(root);
178  allSubstitutions.push_back(
179  {substPred.getPattern(), substPred.getReplacement()});
180 
181  // If the current predicate is a ConcatPred, record the prefix and suffix.
182  } else if (rootNode->kind == PredCombinerKind::Concat) {
183  const auto &concatPred = static_cast<const ConcatPred &>(root);
184  rootNode->prefix = std::string(concatPred.getPrefix());
185  performSubstitutions(rootNode->prefix, substitutions);
186  rootNode->suffix = std::string(concatPred.getSuffix());
187  performSubstitutions(rootNode->suffix, substitutions);
188  }
189 
190  // Build child subtrees.
191  auto combined = static_cast<const CombinedPred &>(root);
192  for (const auto *record : combined.getChildren()) {
193  auto *childTree =
194  buildPredicateTree(Pred(record), allocator, allSubstitutions);
195  rootNode->children.push_back(childTree);
196  }
197  return rootNode;
198 }
199 
200 // Simplify a predicate tree rooted at "node" using the predicates that are
201 // known to be true(false). For AND(OR) combined predicates, if any of the
202 // children is known to be false(true), the result is also false(true).
203 // Furthermore, for AND(OR) combined predicates, children that are known to be
204 // true(false) don't have to be checked dynamically.
205 static PredNode *
206 propagateGroundTruth(PredNode *node,
207  const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
208  const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
209  // If the current predicate is known to be true or false, change the kind of
210  // the node and return immediately.
211  if (knownTruePreds.count(node->predicate) != 0) {
212  node->kind = PredCombinerKind::True;
213  node->children.clear();
214  return node;
215  }
216  if (knownFalsePreds.count(node->predicate) != 0) {
217  node->kind = PredCombinerKind::False;
218  node->children.clear();
219  return node;
220  }
221 
222  // If the current node is a substitution, stop recursion now.
223  // The expressions in the leaves below this node were rewritten, but the nodes
224  // still point to the original predicate records. While the original
225  // predicate may be known to be true or false, it is not necessarily the case
226  // after rewriting.
227  // TODO: we can support ground truth for rewritten
228  // predicates by either (a) having our own unique'ing of the predicates
229  // instead of relying on TableGen record pointers or (b) taking ground truth
230  // values optionally prefixed with a list of substitutions to apply, e.g.
231  // "predX is true by itself as well as predSubY leaf substitution had been
232  // applied to it".
233  if (node->kind == PredCombinerKind::SubstLeaves) {
234  return node;
235  }
236 
237  if (node->kind == PredCombinerKind::And && node->children.empty()) {
238  node->kind = PredCombinerKind::True;
239  return node;
240  }
241 
242  if (node->kind == PredCombinerKind::Or && node->children.empty()) {
243  node->kind = PredCombinerKind::False;
244  return node;
245  }
246 
247  // Otherwise, look at child nodes.
248 
249  // Move child nodes into some local variable so that they can be optimized
250  // separately and re-added if necessary.
252  std::swap(node->children, children);
253 
254  for (auto &child : children) {
255  // First, simplify the child. This maintains the predicate as it was.
256  auto *simplifiedChild =
257  propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
258 
259  // Just add the child if we don't know how to simplify the current node.
260  if (node->kind != PredCombinerKind::And &&
261  node->kind != PredCombinerKind::Or) {
262  node->children.push_back(simplifiedChild);
263  continue;
264  }
265 
266  // Second, based on the type define which known values of child predicates
267  // immediately collapse this predicate to a known value, and which others
268  // may be safely ignored.
269  // OR(..., True, ...) = True
270  // OR(..., False, ...) = OR(..., ...)
271  // AND(..., False, ...) = False
272  // AND(..., True, ...) = AND(..., ...)
273  auto collapseKind = node->kind == PredCombinerKind::And
274  ? PredCombinerKind::False
275  : PredCombinerKind::True;
276  auto eraseKind = node->kind == PredCombinerKind::And
277  ? PredCombinerKind::True
278  : PredCombinerKind::False;
279  const auto &collapseList =
280  node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
281  const auto &eraseList =
282  node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
283  if (simplifiedChild->kind == collapseKind ||
284  collapseList.count(simplifiedChild->predicate) != 0) {
285  node->kind = collapseKind;
286  node->children.clear();
287  return node;
288  }
289  if (simplifiedChild->kind == eraseKind ||
290  eraseList.count(simplifiedChild->predicate) != 0) {
291  continue;
292  }
293  node->children.push_back(simplifiedChild);
294  }
295  return node;
296 }
297 
298 // Combine a list of predicate expressions using a binary combiner. If a list
299 // is empty, return "init".
300 static std::string combineBinary(ArrayRef<std::string> children,
301  const std::string &combiner,
302  std::string init) {
303  if (children.empty())
304  return init;
305 
306  auto size = children.size();
307  if (size == 1)
308  return children.front();
309 
310  std::string str;
311  llvm::raw_string_ostream os(str);
312  os << '(' << children.front() << ')';
313  for (unsigned i = 1; i < size; ++i) {
314  os << ' ' << combiner << " (" << children[i] << ')';
315  }
316  return str;
317 }
318 
319 // Prepend negation to the only condition in the predicate expression list.
320 static std::string combineNot(ArrayRef<std::string> children) {
321  assert(children.size() == 1 && "expected exactly one child predicate of Neg");
322  return (Twine("!(") + children.front() + Twine(')')).str();
323 }
324 
325 // Recursively traverse the predicate tree in depth-first post-order and build
326 // the final expression.
327 static std::string getCombinedCondition(const PredNode &root) {
328  // Immediately return for non-combiner predicates that don't have children.
329  if (root.kind == PredCombinerKind::Leaf)
330  return root.expr;
331  if (root.kind == PredCombinerKind::True)
332  return "true";
333  if (root.kind == PredCombinerKind::False)
334  return "false";
335 
336  // Recurse into children.
337  llvm::SmallVector<std::string, 4> childExpressions;
338  childExpressions.reserve(root.children.size());
339  for (const auto &child : root.children)
340  childExpressions.push_back(getCombinedCondition(*child));
341 
342  // Combine the expressions based on the predicate node kind.
343  if (root.kind == PredCombinerKind::And)
344  return combineBinary(childExpressions, "&&", "true");
345  if (root.kind == PredCombinerKind::Or)
346  return combineBinary(childExpressions, "||", "false");
347  if (root.kind == PredCombinerKind::Not)
348  return combineNot(childExpressions);
349  if (root.kind == PredCombinerKind::Concat) {
350  assert(childExpressions.size() == 1 &&
351  "ConcatPred should only have one child");
352  return root.prefix + childExpressions.front() + root.suffix;
353  }
354 
355  // Substitutions were applied before so just ignore them.
356  if (root.kind == PredCombinerKind::SubstLeaves) {
357  assert(childExpressions.size() == 1 &&
358  "substitution predicate must have one child");
359  return childExpressions[0];
360  }
361 
362  llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
363 }
364 
365 std::string CombinedPred::getConditionImpl() const {
366  SpecificBumpPtrAllocator<PredNode> allocator;
367  auto *predicateTree = buildPredicateTree(*this, allocator, {});
368  predicateTree =
369  propagateGroundTruth(predicateTree,
370  /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
371  /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
372 
373  return getCombinedCondition(*predicateTree);
374 }
375 
376 StringRef SubstLeavesPred::getPattern() const {
377  return def->getValueAsString("pattern");
378 }
379 
381  return def->getValueAsString("replacement");
382 }
383 
384 StringRef ConcatPred::getPrefix() const {
385  return def->getValueAsString("prefix");
386 }
387 
388 StringRef ConcatPred::getSuffix() const {
389  return def->getValueAsString("suffix");
390 }
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
static std::string combineBinary(ArrayRef< std::string > children, const std::string &combiner, std::string init)
Definition: Predicate.cpp:300
static std::string getCombinedCondition(const PredNode &root)
Definition: Predicate.cpp:327
static PredCombinerKind getPredCombinerKind(const Pred &pred)
Definition: Predicate.cpp:119
static void performSubstitutions(std::string &str, ArrayRef< Subst > substitutions)
Perform the given substitutions on 'str' in-place.
Definition: Predicate.cpp:139
static PredNode * buildPredicateTree(const Pred &root, SpecificBumpPtrAllocator< PredNode > &allocator, ArrayRef< Subst > substitutions)
Definition: Predicate.cpp:160
static PredNode * propagateGroundTruth(PredNode *node, const llvm::SmallPtrSetImpl< Pred * > &knownTruePreds, const llvm::SmallPtrSetImpl< Pred * > &knownFalsePreds)
Definition: Predicate.cpp:206
static std::string combineNot(ArrayRef< std::string > children)
Definition: Predicate.cpp:320
std::string getConditionImpl() const
Definition: Predicate.cpp:64
CPred(const llvm::Record *record)
CombinedPred(const llvm::Record *record)
const llvm::Record * getCombinerDef() const
Definition: Predicate.cpp:79
std::vector< const llvm::Record * > getChildren() const
Definition: Predicate.cpp:84
std::string getConditionImpl() const
Definition: Predicate.cpp:365
StringRef getSuffix() const
Definition: Predicate.cpp:388
StringRef getPrefix() const
Definition: Predicate.cpp:384
bool isNull() const
Definition: Predicate.h:46
std::string getCondition() const
Definition: Predicate.cpp:38
bool isCombined() const
Definition: Predicate.cpp:47
ArrayRef< SMLoc > getLoc() const
Definition: Predicate.cpp:51
const llvm::Record * def
Definition: Predicate.h:75
StringRef getReplacement() const
Definition: Predicate.cpp:380
StringRef getPattern() const
Definition: Predicate.cpp:376
Include the generated interface declarations.