MLIR  19.0.0git
Pattern.cpp
Go to the documentation of this file.
1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
25 
26 using namespace mlir;
27 using namespace tblgen;
28 
29 using llvm::formatv;
30 
31 //===----------------------------------------------------------------------===//
32 // DagLeaf
33 //===----------------------------------------------------------------------===//
34 
35 bool DagLeaf::isUnspecified() const {
36  return isa_and_nonnull<llvm::UnsetInit>(def);
37 }
38 
40  // Operand matchers specify a type constraint.
41  return isSubClassOf("TypeConstraint");
42 }
43 
44 bool DagLeaf::isAttrMatcher() const {
45  // Attribute matchers specify an attribute constraint.
46  return isSubClassOf("AttrConstraint");
47 }
48 
50  return isSubClassOf("NativeCodeCall");
51 }
52 
53 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
54 
56  return isSubClassOf("EnumAttrCaseInfo");
57 }
58 
59 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
60 
62  assert((isOperandMatcher() || isAttrMatcher()) &&
63  "the DAG leaf must be operand or attribute");
64  return Constraint(cast<llvm::DefInit>(def)->getDef());
65 }
66 
68  assert(isConstantAttr() && "the DAG leaf must be constant attribute");
69  return ConstantAttr(cast<llvm::DefInit>(def));
70 }
71 
73  assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
74  return EnumAttrCase(cast<llvm::DefInit>(def));
75 }
76 
77 std::string DagLeaf::getConditionTemplate() const {
79 }
80 
81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
82  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
83  return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
84 }
85 
87  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
88  return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
89 }
90 
91 std::string DagLeaf::getStringAttr() const {
92  assert(isStringAttr() && "the DAG leaf must be string attribute");
93  return def->getAsUnquotedString();
94 }
95 bool DagLeaf::isSubClassOf(StringRef superclass) const {
96  if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
97  return defInit->getDef()->isSubClassOf(superclass);
98  return false;
99 }
100 
101 void DagLeaf::print(raw_ostream &os) const {
102  if (def)
103  def->print(os);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // DagNode
108 //===----------------------------------------------------------------------===//
109 
111  if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
112  return defInit->getDef()->isSubClassOf("NativeCodeCall");
113  return false;
114 }
115 
116 bool DagNode::isOperation() const {
117  return !isNativeCodeCall() && !isReplaceWithValue() &&
119  !isVariadic();
120 }
121 
122 llvm::StringRef DagNode::getNativeCodeTemplate() const {
123  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
124  return cast<llvm::DefInit>(node->getOperator())
125  ->getDef()
126  ->getValueAsString("expression");
127 }
128 
130  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
131  return cast<llvm::DefInit>(node->getOperator())
132  ->getDef()
133  ->getValueAsInt("numReturns");
134 }
135 
136 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
137 
139  llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
140  auto it = mapper->find(opDef);
141  if (it != mapper->end())
142  return *it->second;
143  return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
144  .first->second;
145 }
146 
147 int DagNode::getNumOps() const {
148  // We want to get number of operations recursively involved in the DAG tree.
149  // All other directives should be excluded.
150  int count = isOperation() ? 1 : 0;
151  for (int i = 0, e = getNumArgs(); i != e; ++i) {
152  if (auto child = getArgAsNestedDag(i))
153  count += child.getNumOps();
154  }
155  return count;
156 }
157 
158 int DagNode::getNumArgs() const { return node->getNumArgs(); }
159 
160 bool DagNode::isNestedDagArg(unsigned index) const {
161  return isa<llvm::DagInit>(node->getArg(index));
162 }
163 
164 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
165  return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
166 }
167 
168 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
169  assert(!isNestedDagArg(index));
170  return DagLeaf(node->getArg(index));
171 }
172 
173 StringRef DagNode::getArgName(unsigned index) const {
174  return node->getArgNameStr(index);
175 }
176 
178  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
179  return dagOpDef->getName() == "replaceWithValue";
180 }
181 
183  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
184  return dagOpDef->getName() == "location";
185 }
186 
188  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
189  return dagOpDef->getName() == "returnType";
190 }
191 
192 bool DagNode::isEither() const {
193  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
194  return dagOpDef->getName() == "either";
195 }
196 
197 bool DagNode::isVariadic() const {
198  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
199  return dagOpDef->getName() == "variadic";
200 }
201 
202 void DagNode::print(raw_ostream &os) const {
203  if (node)
204  node->print(os);
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // SymbolInfoMap
209 //===----------------------------------------------------------------------===//
210 
211 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
212  int idx = -1;
213  auto [name, indexStr] = symbol.rsplit("__");
214 
215  if (indexStr.consumeInteger(10, idx)) {
216  // The second part is not an index; we return the whole symbol as-is.
217  return symbol;
218  }
219  if (index) {
220  *index = idx;
221  }
222  return name;
223 }
224 
225 SymbolInfoMap::SymbolInfo::SymbolInfo(
226  const Operator *op, SymbolInfo::Kind kind,
227  std::optional<DagAndConstant> dagAndConstant)
228  : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
229 
230 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
231  switch (kind) {
232  case Kind::Attr:
233  case Kind::Operand:
234  case Kind::Value:
235  return 1;
236  case Kind::Result:
237  return op->getNumResults();
238  case Kind::MultipleValues:
239  return getSize();
240  }
241  llvm_unreachable("unknown kind");
242 }
243 
244 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
245  return alternativeName ? *alternativeName : name.str();
246 }
247 
248 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
249  LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': ");
250  switch (kind) {
251  case Kind::Attr: {
252  if (op)
253  return op->getArg(getArgIndex())
254  .get<NamedAttribute *>()
255  ->attr.getStorageType()
256  .str();
257  // TODO(suderman): Use a more exact type when available.
258  return "::mlir::Attribute";
259  }
260  case Kind::Operand: {
261  // Use operand range for captured operands (to support potential variadic
262  // operands).
263  return "::mlir::Operation::operand_range";
264  }
265  case Kind::Value: {
266  return "::mlir::Value";
267  }
268  case Kind::MultipleValues: {
269  return "::mlir::ValueRange";
270  }
271  case Kind::Result: {
272  // Use the op itself for captured results.
273  return op->getQualCppClassName();
274  }
275  }
276  llvm_unreachable("unknown kind");
277 }
278 
279 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
280  LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
281  std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
282  return std::string(
283  formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
284 }
285 
286 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
287  LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': ");
288  return std::string(
289  formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
290 }
291 
292 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
293  StringRef name, int index, const char *fmt, const char *separator) const {
294  LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
295  switch (kind) {
296  case Kind::Attr: {
297  assert(index < 0);
298  auto repl = formatv(fmt, name);
299  LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
300  return std::string(repl);
301  }
302  case Kind::Operand: {
303  assert(index < 0);
304  auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
305  // If this operand is variadic and this SymbolInfo doesn't have a range
306  // index, then return the full variadic operand_range. Otherwise, return
307  // the value itself.
308  if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
309  auto repl = formatv(fmt, name);
310  LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
311  return std::string(repl);
312  }
313  auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
314  LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
315  return std::string(repl);
316  }
317  case Kind::Result: {
318  // If `index` is greater than zero, then we are referencing a specific
319  // result of a multi-result op. The result can still be variadic.
320  if (index >= 0) {
321  std::string v =
322  std::string(formatv("{0}.getODSResults({1})", name, index));
323  if (!op->getResult(index).isVariadic())
324  v = std::string(formatv("(*{0}.begin())", v));
325  auto repl = formatv(fmt, v);
326  LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
327  return std::string(repl);
328  }
329 
330  // If this op has no result at all but still we bind a symbol to it, it
331  // means we want to capture the op itself.
332  if (op->getNumResults() == 0) {
333  LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
334  return formatv(fmt, name);
335  }
336 
337  // We are referencing all results of the multi-result op. A specific result
338  // can either be a value or a range. Then join them with `separator`.
340  values.reserve(op->getNumResults());
341 
342  for (int i = 0, e = op->getNumResults(); i < e; ++i) {
343  std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
344  if (!op->getResult(i).isVariadic()) {
345  v = std::string(formatv("(*{0}.begin())", v));
346  }
347  values.push_back(std::string(formatv(fmt, v)));
348  }
349  auto repl = llvm::join(values, separator);
350  LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
351  return repl;
352  }
353  case Kind::Value: {
354  assert(index < 0);
355  assert(op == nullptr);
356  auto repl = formatv(fmt, name);
357  LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
358  return std::string(repl);
359  }
360  case Kind::MultipleValues: {
361  assert(op == nullptr);
362  assert(index < getSize());
363  if (index >= 0) {
364  std::string repl =
365  formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
366  LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
367  return repl;
368  }
369  // If it doesn't specify certain element, unpack them all.
370  auto repl =
371  formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
372  LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
373  return std::string(repl);
374  }
375  }
376  llvm_unreachable("unknown kind");
377 }
378 
379 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
380  StringRef name, int index, const char *fmt, const char *separator) const {
381  LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
382  switch (kind) {
383  case Kind::Attr:
384  case Kind::Operand: {
385  assert(index < 0 && "only allowed for symbol bound to result");
386  auto repl = formatv(fmt, name);
387  LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
388  return std::string(repl);
389  }
390  case Kind::Result: {
391  if (index >= 0) {
392  auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
393  LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
394  return std::string(repl);
395  }
396 
397  // We are referencing all results of the multi-result op. Each result should
398  // have a value range, and then join them with `separator`.
400  values.reserve(op->getNumResults());
401 
402  for (int i = 0, e = op->getNumResults(); i < e; ++i) {
403  values.push_back(std::string(
404  formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
405  }
406  auto repl = llvm::join(values, separator);
407  LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
408  return repl;
409  }
410  case Kind::Value: {
411  assert(index < 0 && "only allowed for symbol bound to result");
412  assert(op == nullptr);
413  auto repl = formatv(fmt, formatv("{{{0}}", name));
414  LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
415  return std::string(repl);
416  }
417  case Kind::MultipleValues: {
418  assert(op == nullptr);
419  assert(index < getSize());
420  if (index >= 0) {
421  std::string repl =
422  formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
423  LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
424  return repl;
425  }
426  auto repl =
427  formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
428  LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
429  return std::string(repl);
430  }
431  }
432  llvm_unreachable("unknown kind");
433 }
434 
435 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
436  const Operator &op, int argIndex,
437  std::optional<int> variadicSubIndex) {
438  StringRef name = getValuePackName(symbol);
439  if (name != symbol) {
440  auto error = formatv(
441  "symbol '{0}' with trailing index cannot bind to op argument", symbol);
442  PrintFatalError(loc, error);
443  }
444 
445  auto symInfo =
446  op.getArg(argIndex).is<NamedAttribute *>()
447  ? SymbolInfo::getAttr(&op, argIndex)
448  : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
449 
450  std::string key = symbol.str();
451  if (symbolInfoMap.count(key)) {
452  // Only non unique name for the operand is supported.
453  if (symInfo.kind != SymbolInfo::Kind::Operand) {
454  return false;
455  }
456 
457  // Cannot add new operand if there is already non operand with the same
458  // name.
459  if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
460  return false;
461  }
462  }
463 
464  symbolInfoMap.emplace(key, symInfo);
465  return true;
466 }
467 
468 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
469  std::string name = getValuePackName(symbol).str();
470  auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
471 
472  return symbolInfoMap.count(inserted->first) == 1;
473 }
474 
475 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
476  std::string name = getValuePackName(symbol).str();
477  if (numValues > 1)
478  return bindMultipleValues(name, numValues);
479  return bindValue(name);
480 }
481 
482 bool SymbolInfoMap::bindValue(StringRef symbol) {
483  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
484  return symbolInfoMap.count(inserted->first) == 1;
485 }
486 
487 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
488  std::string name = getValuePackName(symbol).str();
489  auto inserted =
490  symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
491  return symbolInfoMap.count(inserted->first) == 1;
492 }
493 
494 bool SymbolInfoMap::bindAttr(StringRef symbol) {
495  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
496  return symbolInfoMap.count(inserted->first) == 1;
497 }
498 
499 bool SymbolInfoMap::contains(StringRef symbol) const {
500  return find(symbol) != symbolInfoMap.end();
501 }
502 
504  std::string name = getValuePackName(key).str();
505 
506  return symbolInfoMap.find(name);
507 }
508 
510 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
511  int argIndex,
512  std::optional<int> variadicSubIndex) const {
513  return findBoundSymbol(
514  key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
515 }
516 
519  const SymbolInfo &symbolInfo) const {
520  std::string name = getValuePackName(key).str();
521  auto range = symbolInfoMap.equal_range(name);
522 
523  for (auto it = range.first; it != range.second; ++it)
524  if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
525  return it;
526 
527  return symbolInfoMap.end();
528 }
529 
530 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
532  std::string name = getValuePackName(key).str();
533 
534  return symbolInfoMap.equal_range(name);
535 }
536 
537 int SymbolInfoMap::count(StringRef key) const {
538  std::string name = getValuePackName(key).str();
539  return symbolInfoMap.count(name);
540 }
541 
542 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
543  StringRef name = getValuePackName(symbol);
544  if (name != symbol) {
545  // If there is a trailing index inside symbol, it references just one
546  // static value.
547  return 1;
548  }
549  // Otherwise, find how many it represents by querying the symbol's info.
550  return find(name)->second.getStaticValueCount();
551 }
552 
553 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
554  const char *fmt,
555  const char *separator) const {
556  int index = -1;
557  StringRef name = getValuePackName(symbol, &index);
558 
559  auto it = symbolInfoMap.find(name.str());
560  if (it == symbolInfoMap.end()) {
561  auto error = formatv("referencing unbound symbol '{0}'", symbol);
562  PrintFatalError(loc, error);
563  }
564 
565  return it->second.getValueAndRangeUse(name, index, fmt, separator);
566 }
567 
568 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
569  const char *separator) const {
570  int index = -1;
571  StringRef name = getValuePackName(symbol, &index);
572 
573  auto it = symbolInfoMap.find(name.str());
574  if (it == symbolInfoMap.end()) {
575  auto error = formatv("referencing unbound symbol '{0}'", symbol);
576  PrintFatalError(loc, error);
577  }
578 
579  return it->second.getAllRangeUse(name, index, fmt, separator);
580 }
581 
583  llvm::StringSet<> usedNames;
584 
585  for (auto symbolInfoIt = symbolInfoMap.begin();
586  symbolInfoIt != symbolInfoMap.end();) {
587  auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
588  auto startRange = range.first;
589  auto endRange = range.second;
590 
591  auto operandName = symbolInfoIt->first;
592  int startSearchIndex = 0;
593  for (++startRange; startRange != endRange; ++startRange) {
594  // Current operand name is not unique, find a unique one
595  // and set the alternative name.
596  for (int i = startSearchIndex;; ++i) {
597  std::string alternativeName = operandName + std::to_string(i);
598  if (!usedNames.contains(alternativeName) &&
599  symbolInfoMap.count(alternativeName) == 0) {
600  usedNames.insert(alternativeName);
601  startRange->second.alternativeName = alternativeName;
602  startSearchIndex = i + 1;
603 
604  break;
605  }
606  }
607  }
608 
609  symbolInfoIt = endRange;
610  }
611 }
612 
613 //===----------------------------------------------------------------------===//
614 // Pattern
615 //===----------------------------------------------------------------------===//
616 
617 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
618  : def(*def), recordOpMap(mapper) {}
619 
621  return DagNode(def.getValueAsDag("sourcePattern"));
622 }
623 
625  auto *results = def.getValueAsListInit("resultPatterns");
626  return results->size();
627 }
628 
629 DagNode Pattern::getResultPattern(unsigned index) const {
630  auto *results = def.getValueAsListInit("resultPatterns");
631  return DagNode(cast<llvm::DagInit>(results->getElement(index)));
632 }
633 
635  LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
636  collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
637  LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
638 
639  LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
641  LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
642 }
643 
645  LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
646  for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
647  auto pattern = getResultPattern(i);
648  collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
649  }
650  LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
651 }
652 
654  return getSourcePattern().getDialectOp(recordOpMap);
655 }
656 
658  return node.getDialectOp(recordOpMap);
659 }
660 
661 std::vector<AppliedConstraint> Pattern::getConstraints() const {
662  auto *listInit = def.getValueAsListInit("constraints");
663  std::vector<AppliedConstraint> ret;
664  ret.reserve(listInit->size());
665 
666  for (auto *it : *listInit) {
667  auto *dagInit = dyn_cast<llvm::DagInit>(it);
668  if (!dagInit)
669  PrintFatalError(&def, "all elements in Pattern multi-entity "
670  "constraints should be DAG nodes");
671 
672  std::vector<std::string> entities;
673  entities.reserve(dagInit->arg_size());
674  for (auto *argName : dagInit->getArgNames()) {
675  if (!argName) {
676  PrintFatalError(
677  &def,
678  "operands to additional constraints can only be symbol references");
679  }
680  entities.emplace_back(argName->getValue());
681  }
682 
683  ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
684  dagInit->getNameStr(), std::move(entities));
685  }
686  return ret;
687 }
688 
690  auto *results = def.getValueAsListInit("supplementalPatterns");
691  return results->size();
692 }
693 
695  auto *results = def.getValueAsListInit("supplementalPatterns");
696  return DagNode(cast<llvm::DagInit>(results->getElement(index)));
697 }
698 
699 int Pattern::getBenefit() const {
700  // The initial benefit value is a heuristic with number of ops in the source
701  // pattern.
702  int initBenefit = getSourcePattern().getNumOps();
703  llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
704  if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
705  PrintFatalError(&def,
706  "The 'addBenefit' takes and only takes one integer value");
707  }
708  return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
709 }
710 
711 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
712  std::vector<std::pair<StringRef, unsigned>> result;
713  result.reserve(def.getLoc().size());
714  for (auto loc : def.getLoc()) {
715  unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
716  assert(buf && "invalid source location");
717  result.emplace_back(
718  llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
719  llvm::SrcMgr.getLineAndColumn(loc, buf).first);
720  }
721  return result;
722 }
723 
724 void Pattern::verifyBind(bool result, StringRef symbolName) {
725  if (!result) {
726  auto err = formatv("symbol '{0}' bound more than once", symbolName);
727  PrintFatalError(&def, err);
728  }
729 }
730 
732  bool isSrcPattern) {
733  auto treeName = tree.getSymbol();
734  auto numTreeArgs = tree.getNumArgs();
735 
736  if (tree.isNativeCodeCall()) {
737  if (!treeName.empty()) {
738  if (!isSrcPattern) {
739  LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
740  << treeName << '\n');
741  verifyBind(
742  infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
743  treeName);
744  } else {
745  PrintFatalError(&def,
746  formatv("binding symbol '{0}' to NativecodeCall in "
747  "MatchPattern is not supported",
748  treeName));
749  }
750  }
751 
752  for (int i = 0; i != numTreeArgs; ++i) {
753  if (auto treeArg = tree.getArgAsNestedDag(i)) {
754  // This DAG node argument is a DAG node itself. Go inside recursively.
755  collectBoundSymbols(treeArg, infoMap, isSrcPattern);
756  continue;
757  }
758 
759  if (!isSrcPattern)
760  continue;
761 
762  // We can only bind symbols to arguments in source pattern. Those
763  // symbols are referenced in result patterns.
764  auto treeArgName = tree.getArgName(i);
765 
766  // `$_` is a special symbol meaning ignore the current argument.
767  if (!treeArgName.empty() && treeArgName != "_") {
768  DagLeaf leaf = tree.getArgAsLeaf(i);
769 
770  // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
771  if (leaf.isUnspecified()) {
772  // This is case of $c, a Value without any constraints.
773  verifyBind(infoMap.bindValue(treeArgName), treeArgName);
774  } else {
775  auto constraint = leaf.getAsConstraint();
776  bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
777  leaf.isConstantAttr() ||
778  constraint.getKind() == Constraint::Kind::CK_Attr;
779 
780  if (isAttr) {
781  // This is case of $a, a binding to a certain attribute.
782  verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
783  continue;
784  }
785 
786  // This is case of $b, a binding to a certain type.
787  verifyBind(infoMap.bindValue(treeArgName), treeArgName);
788  }
789  }
790  }
791 
792  return;
793  }
794 
795  if (tree.isOperation()) {
796  auto &op = getDialectOp(tree);
797  auto numOpArgs = op.getNumArgs();
798  int numEither = 0;
799 
800  // We need to exclude the trailing directives and `either` directive groups
801  // two operands of the operation.
802  int numDirectives = 0;
803  for (int i = numTreeArgs - 1; i >= 0; --i) {
804  if (auto dagArg = tree.getArgAsNestedDag(i)) {
805  if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
806  ++numDirectives;
807  else if (dagArg.isEither())
808  ++numEither;
809  }
810  }
811 
812  if (numOpArgs != numTreeArgs - numDirectives + numEither) {
813  auto err =
814  formatv("op '{0}' argument number mismatch: "
815  "{1} in pattern vs. {2} in definition",
816  op.getOperationName(), numTreeArgs + numEither, numOpArgs);
817  PrintFatalError(&def, err);
818  }
819 
820  // The name attached to the DAG node's operator is for representing the
821  // results generated from this op. It should be remembered as bound results.
822  if (!treeName.empty()) {
823  LLVM_DEBUG(llvm::dbgs()
824  << "found symbol bound to op result: " << treeName << '\n');
825  verifyBind(infoMap.bindOpResult(treeName, op), treeName);
826  }
827 
828  // The operand in `either` DAG should be bound to the operation in the
829  // parent DagNode.
830  auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
831  int opArgIdx) {
832  for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
833  if (DagNode subTree = tree.getArgAsNestedDag(i)) {
834  collectBoundSymbols(subTree, infoMap, isSrcPattern);
835  } else {
836  auto argName = tree.getArgName(i);
837  if (!argName.empty() && argName != "_") {
838  verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
839  argName);
840  }
841  }
842  }
843  };
844 
845  // The operand in `variadic` DAG should be bound to the operation in the
846  // parent DagNode. The range index must be included as well to distinguish
847  // (potentially) repeating argName within the `variadic` DAG.
848  auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
849  int opArgIdx) {
850  auto treeName = tree.getSymbol();
851  if (!treeName.empty()) {
852  // If treeName is specified, bind to the full variadic operand_range.
853  verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
854  std::nullopt),
855  treeName);
856  }
857 
858  for (int i = 0; i < tree.getNumArgs(); ++i) {
859  if (DagNode subTree = tree.getArgAsNestedDag(i)) {
860  collectBoundSymbols(subTree, infoMap, isSrcPattern);
861  } else {
862  auto argName = tree.getArgName(i);
863  if (!argName.empty() && argName != "_") {
864  verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
865  /*variadicSubIndex=*/i),
866  argName);
867  }
868  }
869  }
870  };
871 
872  for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
873  if (auto treeArg = tree.getArgAsNestedDag(i)) {
874  if (treeArg.isEither()) {
875  collectSymbolInEither(tree, treeArg, opArgIdx);
876  // `either` DAG is *flattened*. For example,
877  //
878  // (FooOp (either arg0, arg1), arg2)
879  //
880  // can be viewed as:
881  //
882  // (FooOp arg0, arg1, arg2)
883  ++opArgIdx;
884  } else if (treeArg.isVariadic()) {
885  collectSymbolInVariadic(tree, treeArg, opArgIdx);
886  } else {
887  // This DAG node argument is a DAG node itself. Go inside recursively.
888  collectBoundSymbols(treeArg, infoMap, isSrcPattern);
889  }
890  continue;
891  }
892 
893  if (isSrcPattern) {
894  // We can only bind symbols to op arguments in source pattern. Those
895  // symbols are referenced in result patterns.
896  auto treeArgName = tree.getArgName(i);
897  // `$_` is a special symbol meaning ignore the current argument.
898  if (!treeArgName.empty() && treeArgName != "_") {
899  LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
900  << treeArgName << '\n');
901  verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
902  treeArgName);
903  }
904  }
905  }
906  return;
907  }
908 
909  if (!treeName.empty()) {
910  PrintFatalError(
911  &def, formatv("binding symbol '{0}' to non-operation/native code call "
912  "unsupported right now",
913  treeName));
914  }
915 }
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
std::string getConditionTemplate() const
Definition: Constraint.cpp:51
Constraint getAsConstraint() const
Definition: Pattern.cpp:61
bool isNativeCodeCall() const
Definition: Pattern.cpp:49
bool isEnumAttrCase() const
Definition: Pattern.cpp:55
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:86
ConstantAttr getAsConstantAttr() const
Definition: Pattern.cpp:67
void print(raw_ostream &os) const
Definition: Pattern.cpp:101
std::string getStringAttr() const
Definition: Pattern.cpp:91
std::string getConditionTemplate() const
Definition: Pattern.cpp:77
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:81
bool isUnspecified() const
Definition: Pattern.cpp:35
bool isAttrMatcher() const
Definition: Pattern.cpp:44
EnumAttrCase getAsEnumAttrCase() const
Definition: Pattern.cpp:72
bool isOperandMatcher() const
Definition: Pattern.cpp:39
bool isConstantAttr() const
Definition: Pattern.cpp:53
bool isStringAttr() const
Definition: Pattern.cpp:59
bool isReturnTypeDirective() const
Definition: Pattern.cpp:187
bool isLocationDirective() const
Definition: Pattern.cpp:182
bool isReplaceWithValue() const
Definition: Pattern.cpp:177
DagNode getArgAsNestedDag(unsigned index) const
Definition: Pattern.cpp:164
bool isOperation() const
Definition: Pattern.cpp:116
DagLeaf getArgAsLeaf(unsigned index) const
Definition: Pattern.cpp:168
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:122
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:129
void print(raw_ostream &os) const
Definition: Pattern.cpp:202
int getNumOps() const
Definition: Pattern.cpp:147
Operator & getDialectOp(RecordOperatorMap *mapper) const
Definition: Pattern.cpp:138
bool isVariadic() const
Definition: Pattern.cpp:197
bool isNativeCodeCall() const
Definition: Pattern.cpp:110
bool isEither() const
Definition: Pattern.cpp:192
bool isNestedDagArg(unsigned index) const
Definition: Pattern.cpp:160
int getNumArgs() const
Definition: Pattern.cpp:158
DagNode(const llvm::DagInit *node)
Definition: Pattern.h:143
StringRef getSymbol() const
Definition: Pattern.cpp:136
StringRef getArgName(unsigned index) const
Definition: Pattern.cpp:173
Wrapper class that contains an MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
int getNumResults() const
Returns the number of results this op produces.
Definition: Operator.cpp:163
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Definition: Operator.cpp:182
int getBenefit() const
int getNumResultPatterns() const
Definition: Pattern.cpp:624
std::vector< IdentifierLine > getLocation() const
Definition: Pattern.cpp:711
DagNode getSourcePattern() const
Definition: Pattern.cpp:620
const Operator & getSourceRootOp()
Definition: Pattern.cpp:653
std::vector< AppliedConstraint > getConstraints() const
Definition: Pattern.cpp:661
DagNode getResultPattern(unsigned index) const
Definition: Pattern.cpp:629
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Definition: Pattern.cpp:731
Operator & getDialectOp(DagNode node)
Definition: Pattern.cpp:657
DagNode getSupplementalPattern(unsigned index) const
Definition: Pattern.cpp:694
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:634
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:644
int getNumSupplementalPatterns() const
Definition: Pattern.cpp:689
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Definition: Pattern.cpp:617
std::string getArgDecl(StringRef name) const
Definition: Pattern.cpp:286
std::string getVarName(StringRef name) const
Definition: Pattern.cpp:244
std::string getVarTypeStr(StringRef name) const
Definition: Pattern.cpp:248
std::string getVarDecl(StringRef name) const
Definition: Pattern.cpp:279
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
Definition: Pattern.cpp:211
int count(StringRef key) const
Definition: Pattern.cpp:537
const_iterator find(StringRef key) const
Definition: Pattern.cpp:503
bool bindMultipleValues(StringRef symbol, int numValues)
Definition: Pattern.cpp:487
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
Definition: Pattern.cpp:435
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:568
bool bindValues(StringRef symbol, int numValues=1)
Definition: Pattern.cpp:475
bool bindAttr(StringRef symbol)
Definition: Pattern.cpp:494
bool bindValue(StringRef symbol)
Definition: Pattern.cpp:482
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
Definition: Pattern.cpp:510
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
Definition: Pattern.cpp:531
int getStaticValueCount(StringRef symbol) const
Definition: Pattern.cpp:542
bool contains(StringRef symbol) const
Definition: Pattern.cpp:499
BaseT::const_iterator const_iterator
Definition: Pattern.h:461
bool bindOpResult(StringRef symbol, const Operator &op)
Definition: Pattern.cpp:468
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:553
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:2883
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.