MLIR 22.0.0git
Predicate.cpp
Go to the documentation of this file.
1//===- Predicate.cpp - Predicate class ------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Wrapper around predicates defined in TableGen.
10//
11//===----------------------------------------------------------------------===//
12
14#include "llvm/ADT/SmallPtrSet.h"
15#include "llvm/ADT/StringExtras.h"
16#include "llvm/ADT/StringSwitch.h"
17#include "llvm/TableGen/Error.h"
18#include "llvm/TableGen/Record.h"
19
20using namespace mlir;
21using namespace tblgen;
22using llvm::Init;
23using llvm::Record;
24using llvm::SpecificBumpPtrAllocator;
25
26// Construct a Predicate from a record.
27Pred::Pred(const Record *record) : def(record) {
28 assert(def->isSubClassOf("Pred") &&
29 "must be a subclass of TableGen 'Pred' class");
30}
31
32// Construct a Predicate from an initializer.
33Pred::Pred(const Init *init) {
34 if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
35 def = defInit->getDef();
36}
37
38std::string Pred::getCondition() const {
39 // Static dispatch to subclasses.
40 if (def->isSubClassOf("CombinedPred"))
41 return static_cast<const CombinedPred *>(this)->getConditionImpl();
42 if (def->isSubClassOf("CPred"))
43 return static_cast<const CPred *>(this)->getConditionImpl();
44 llvm_unreachable("Pred::getCondition must be overridden in subclasses");
45}
46
47bool Pred::isCombined() const {
48 return def && def->isSubClassOf("CombinedPred");
49}
50
51ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); }
52
53CPred::CPred(const Record *record) : Pred(record) {
54 assert(def->isSubClassOf("CPred") &&
55 "must be a subclass of Tablegen 'CPred' class");
56}
57
58CPred::CPred(const Init *init) : Pred(init) {
59 assert((!def || def->isSubClassOf("CPred")) &&
60 "must be a subclass of Tablegen 'CPred' class");
61}
62
63// Get condition of the C Predicate.
64std::string CPred::getConditionImpl() const {
65 assert(!isNull() && "null predicate does not have a condition");
66 return std::string(def->getValueAsString("predExpr"));
67}
68
69CombinedPred::CombinedPred(const Record *record) : Pred(record) {
70 assert(def->isSubClassOf("CombinedPred") &&
71 "must be a subclass of Tablegen 'CombinedPred' class");
72}
73
74CombinedPred::CombinedPred(const Init *init) : Pred(init) {
75 assert((!def || def->isSubClassOf("CombinedPred")) &&
76 "must be a subclass of Tablegen 'CombinedPred' class");
77}
78
79const Record *CombinedPred::getCombinerDef() const {
80 assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
81 return def->getValueAsDef("kind");
82}
83
84std::vector<const Record *> CombinedPred::getChildren() const {
85 assert(def->getValue("children") &&
86 "CombinedPred must have a value 'children'");
87 return def->getValueAsListOfDefs("children");
88}
89
90namespace {
91// Kinds of nodes in a logical predicate tree.
92enum class PredCombinerKind {
93 Leaf,
94 And,
95 Or,
96 Not,
97 SubstLeaves,
98 Concat,
99 // Special kinds that are used in simplification.
100 False,
101 True
102};
103
104// A node in a logical predicate tree.
105struct PredNode {
106 PredCombinerKind kind;
107 const Pred *predicate;
109 std::string expr;
110
111 // Prefix and suffix are used by ConcatPred.
112 std::string prefix;
113 std::string suffix;
114};
115} // namespace
116
117// Get a predicate tree node kind based on the kind used in the predicate
118// TableGen record.
119static PredCombinerKind getPredCombinerKind(const Pred &pred) {
120 if (!pred.isCombined())
121 return PredCombinerKind::Leaf;
122
123 const auto &combinedPred = static_cast<const CombinedPred &>(pred);
125 combinedPred.getCombinerDef()->getName())
126 .Case("PredCombinerAnd", PredCombinerKind::And)
127 .Case("PredCombinerOr", PredCombinerKind::Or)
128 .Case("PredCombinerNot", PredCombinerKind::Not)
129 .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
130 .Case("PredCombinerConcat", PredCombinerKind::Concat);
131}
132
namespace {
// Substitution<pattern, replacement>: `first` is the text to search for and
// `second` is the text that replaces each occurrence. Both are non-owning
// string references.
using Subst = std::pair<StringRef, StringRef>;
} // namespace
137
138/// Perform the given substitutions on 'str' in-place.
139static void performSubstitutions(std::string &str,
140 ArrayRef<Subst> substitutions) {
141 // Apply all parent substitutions from innermost to outermost.
142 for (const auto &subst : llvm::reverse(substitutions)) {
143 auto pos = str.find(subst.first);
144 while (pos != std::string::npos) {
145 str.replace(pos, subst.first.size(), std::string(subst.second));
146 // Skip the newly inserted substring, which itself may consider the
147 // pattern to match.
148 pos += subst.second.size();
149 // Find the next possible match position.
150 pos = str.find(subst.first, pos);
151 }
152 }
153}
154
155// Build the predicate tree starting from the top-level predicate, which may
156// have children, and perform leaf substitutions inplace. Note that after
157// substitution, nodes are still pointing to the original TableGen record.
158// All nodes are created within "allocator".
159static PredNode *
161 SpecificBumpPtrAllocator<PredNode> &allocator,
162 ArrayRef<Subst> substitutions) {
163 auto *rootNode = allocator.Allocate();
164 new (rootNode) PredNode;
165 rootNode->kind = getPredCombinerKind(root);
166 rootNode->predicate = &root;
167 if (!root.isCombined()) {
168 rootNode->expr = root.getCondition();
169 performSubstitutions(rootNode->expr, substitutions);
170 return rootNode;
171 }
172
173 // If the current combined predicate is a leaf substitution, append it to the
174 // list before continuing.
175 auto allSubstitutions = llvm::to_vector<4>(substitutions);
176 if (rootNode->kind == PredCombinerKind::SubstLeaves) {
177 const auto &substPred = static_cast<const SubstLeavesPred &>(root);
178 allSubstitutions.push_back(
179 {substPred.getPattern(), substPred.getReplacement()});
180
181 // If the current predicate is a ConcatPred, record the prefix and suffix.
182 } else if (rootNode->kind == PredCombinerKind::Concat) {
183 const auto &concatPred = static_cast<const ConcatPred &>(root);
184 rootNode->prefix = std::string(concatPred.getPrefix());
185 performSubstitutions(rootNode->prefix, substitutions);
186 rootNode->suffix = std::string(concatPred.getSuffix());
187 performSubstitutions(rootNode->suffix, substitutions);
188 }
189
190 // Build child subtrees.
191 auto combined = static_cast<const CombinedPred &>(root);
192 for (const auto *record : combined.getChildren()) {
193 auto *childTree =
194 buildPredicateTree(Pred(record), allocator, allSubstitutions);
195 rootNode->children.push_back(childTree);
196 }
197 return rootNode;
198}
199
200// Simplify a predicate tree rooted at "node" using the predicates that are
201// known to be true(false). For AND(OR) combined predicates, if any of the
202// children is known to be false(true), the result is also false(true).
203// Furthermore, for AND(OR) combined predicates, children that are known to be
204// true(false) don't have to be checked dynamically.
205static PredNode *
206propagateGroundTruth(PredNode *node,
207 const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
208 const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
209 // If the current predicate is known to be true or false, change the kind of
210 // the node and return immediately.
211 if (knownTruePreds.count(node->predicate) != 0) {
212 node->kind = PredCombinerKind::True;
213 node->children.clear();
214 return node;
215 }
216 if (knownFalsePreds.count(node->predicate) != 0) {
217 node->kind = PredCombinerKind::False;
218 node->children.clear();
219 return node;
220 }
221
222 // If the current node is a substitution, stop recursion now.
223 // The expressions in the leaves below this node were rewritten, but the nodes
224 // still point to the original predicate records. While the original
225 // predicate may be known to be true or false, it is not necessarily the case
226 // after rewriting.
227 // TODO: we can support ground truth for rewritten
228 // predicates by either (a) having our own unique'ing of the predicates
229 // instead of relying on TableGen record pointers or (b) taking ground truth
230 // values optionally prefixed with a list of substitutions to apply, e.g.
231 // "predX is true by itself as well as predSubY leaf substitution had been
232 // applied to it".
233 if (node->kind == PredCombinerKind::SubstLeaves) {
234 return node;
235 }
236
237 if (node->kind == PredCombinerKind::And && node->children.empty()) {
238 node->kind = PredCombinerKind::True;
239 return node;
240 }
241
242 if (node->kind == PredCombinerKind::Or && node->children.empty()) {
243 node->kind = PredCombinerKind::False;
244 return node;
245 }
246
247 // Otherwise, look at child nodes.
248
249 // Move child nodes into some local variable so that they can be optimized
250 // separately and re-added if necessary.
252 std::swap(node->children, children);
253
254 for (auto &child : children) {
255 // First, simplify the child. This maintains the predicate as it was.
256 auto *simplifiedChild =
257 propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
258
259 // Just add the child if we don't know how to simplify the current node.
260 if (node->kind != PredCombinerKind::And &&
261 node->kind != PredCombinerKind::Or) {
262 node->children.push_back(simplifiedChild);
263 continue;
264 }
265
266 // Second, based on the type define which known values of child predicates
267 // immediately collapse this predicate to a known value, and which others
268 // may be safely ignored.
269 // OR(..., True, ...) = True
270 // OR(..., False, ...) = OR(..., ...)
271 // AND(..., False, ...) = False
272 // AND(..., True, ...) = AND(..., ...)
273 auto collapseKind = node->kind == PredCombinerKind::And
274 ? PredCombinerKind::False
275 : PredCombinerKind::True;
276 auto eraseKind = node->kind == PredCombinerKind::And
277 ? PredCombinerKind::True
278 : PredCombinerKind::False;
279 const auto &collapseList =
280 node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
281 const auto &eraseList =
282 node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
283 if (simplifiedChild->kind == collapseKind ||
284 collapseList.count(simplifiedChild->predicate) != 0) {
285 node->kind = collapseKind;
286 node->children.clear();
287 return node;
288 }
289 if (simplifiedChild->kind == eraseKind ||
290 eraseList.count(simplifiedChild->predicate) != 0) {
291 continue;
292 }
293 node->children.push_back(simplifiedChild);
294 }
295 return node;
296}
297
298// Combine a list of predicate expressions using a binary combiner. If a list
299// is empty, return "init".
300static std::string combineBinary(ArrayRef<std::string> children,
301 const std::string &combiner,
302 std::string init) {
303 if (children.empty())
304 return init;
305
306 auto size = children.size();
307 if (size == 1)
308 return children.front();
309
310 std::string str;
311 llvm::raw_string_ostream os(str);
312 os << '(' << children.front() << ')';
313 for (unsigned i = 1; i < size; ++i) {
314 os << ' ' << combiner << " (" << children[i] << ')';
315 }
316 return str;
317}
318
319// Prepend negation to the only condition in the predicate expression list.
320static std::string combineNot(ArrayRef<std::string> children) {
321 assert(children.size() == 1 && "expected exactly one child predicate of Neg");
322 return (Twine("!(") + children.front() + Twine(')')).str();
323}
324
325// Recursively traverse the predicate tree in depth-first post-order and build
326// the final expression.
327static std::string getCombinedCondition(const PredNode &root) {
328 // Immediately return for non-combiner predicates that don't have children.
329 if (root.kind == PredCombinerKind::Leaf)
330 return root.expr;
331 if (root.kind == PredCombinerKind::True)
332 return "true";
333 if (root.kind == PredCombinerKind::False)
334 return "false";
335
336 // Recurse into children.
337 llvm::SmallVector<std::string, 4> childExpressions;
338 childExpressions.reserve(root.children.size());
339 for (const auto &child : root.children)
340 childExpressions.push_back(getCombinedCondition(*child));
341
342 // Combine the expressions based on the predicate node kind.
343 if (root.kind == PredCombinerKind::And)
344 return combineBinary(childExpressions, "&&", "true");
345 if (root.kind == PredCombinerKind::Or)
346 return combineBinary(childExpressions, "||", "false");
347 if (root.kind == PredCombinerKind::Not)
348 return combineNot(childExpressions);
349 if (root.kind == PredCombinerKind::Concat) {
350 assert(childExpressions.size() == 1 &&
351 "ConcatPred should only have one child");
352 return root.prefix + childExpressions.front() + root.suffix;
353 }
354
355 // Substitutions were applied before so just ignore them.
356 if (root.kind == PredCombinerKind::SubstLeaves) {
357 assert(childExpressions.size() == 1 &&
358 "substitution predicate must have one child");
359 return childExpressions[0];
360 }
361
362 llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
363}
364
366 SpecificBumpPtrAllocator<PredNode> allocator;
367 auto *predicateTree = buildPredicateTree(*this, allocator, {});
368 predicateTree =
369 propagateGroundTruth(predicateTree,
370 /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
371 /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
372
373 return getCombinedCondition(*predicateTree);
374}
375
377 return def->getValueAsString("pattern");
378}
379
381 return def->getValueAsString("replacement");
382}
383
384StringRef ConcatPred::getPrefix() const {
385 return def->getValueAsString("prefix");
386}
387
388StringRef ConcatPred::getSuffix() const {
389 return def->getValueAsString("suffix");
390}
static std::string combineBinary(ArrayRef< std::string > children, const std::string &combiner, std::string init)
static std::string getCombinedCondition(const PredNode &root)
static PredCombinerKind getPredCombinerKind(const Pred &pred)
static void performSubstitutions(std::string &str, ArrayRef< Subst > substitutions)
Perform the given substitutions on 'str' in-place.
static PredNode * buildPredicateTree(const Pred &root, SpecificBumpPtrAllocator< PredNode > &allocator, ArrayRef< Subst > substitutions)
static std::string combineNot(ArrayRef< std::string > children)
static PredNode * propagateGroundTruth(PredNode *node, const llvm::SmallPtrSetImpl< Pred * > &knownTruePreds, const llvm::SmallPtrSetImpl< Pred * > &knownFalsePreds)
std::string getConditionImpl() const
Definition Predicate.cpp:64
CPred(const llvm::Record *record)
CombinedPred(const llvm::Record *record)
const llvm::Record * getCombinerDef() const
Definition Predicate.cpp:79
std::vector< const llvm::Record * > getChildren() const
Definition Predicate.cpp:84
std::string getConditionImpl() const
StringRef getSuffix() const
StringRef getPrefix() const
bool isNull() const
Definition Predicate.h:46
std::string getCondition() const
Definition Predicate.cpp:38
bool isCombined() const
Definition Predicate.cpp:47
ArrayRef< SMLoc > getLoc() const
Definition Predicate.cpp:51
const llvm::Record * def
Definition Predicate.h:75
StringRef getReplacement() const
StringRef getPattern() const
Include the generated interface declarations.
llvm::StringSwitch< T, R > StringSwitch
Definition LLVM.h:141