MLIR  21.0.0git
Pattern.cpp
Go to the documentation of this file.
1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
25 
26 using namespace mlir;
27 using namespace tblgen;
28 
29 using llvm::DagInit;
30 using llvm::dbgs;
31 using llvm::DefInit;
32 using llvm::formatv;
33 using llvm::IntInit;
34 using llvm::Record;
35 
36 //===----------------------------------------------------------------------===//
37 // DagLeaf
38 //===----------------------------------------------------------------------===//
39 
40 bool DagLeaf::isUnspecified() const {
41  return isa_and_nonnull<llvm::UnsetInit>(def);
42 }
43 
45  // Operand matchers specify a type constraint.
46  return isSubClassOf("TypeConstraint");
47 }
48 
49 bool DagLeaf::isAttrMatcher() const {
50  // Attribute matchers specify an attribute constraint.
51  return isSubClassOf("AttrConstraint");
52 }
53 
55  return isSubClassOf("NativeCodeCall");
56 }
57 
58 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
59 
60 bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumCase"); }
61 
62 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
63 
65  assert((isOperandMatcher() || isAttrMatcher()) &&
66  "the DAG leaf must be operand or attribute");
67  return Constraint(cast<DefInit>(def)->getDef());
68 }
69 
71  assert(isConstantAttr() && "the DAG leaf must be constant attribute");
72  return ConstantAttr(cast<DefInit>(def));
73 }
74 
76  assert(isEnumCase() && "the DAG leaf must be an enum attribute case");
77  return EnumCase(cast<DefInit>(def));
78 }
79 
80 std::string DagLeaf::getConditionTemplate() const {
82 }
83 
84 StringRef DagLeaf::getNativeCodeTemplate() const {
85  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
86  return cast<DefInit>(def)->getDef()->getValueAsString("expression");
87 }
88 
90  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
91  return cast<DefInit>(def)->getDef()->getValueAsInt("numReturns");
92 }
93 
94 std::string DagLeaf::getStringAttr() const {
95  assert(isStringAttr() && "the DAG leaf must be string attribute");
96  return def->getAsUnquotedString();
97 }
98 bool DagLeaf::isSubClassOf(StringRef superclass) const {
99  if (auto *defInit = dyn_cast_or_null<DefInit>(def))
100  return defInit->getDef()->isSubClassOf(superclass);
101  return false;
102 }
103 
104 void DagLeaf::print(raw_ostream &os) const {
105  if (def)
106  def->print(os);
107 }
108 
109 //===----------------------------------------------------------------------===//
110 // DagNode
111 //===----------------------------------------------------------------------===//
112 
114  if (auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
115  return defInit->getDef()->isSubClassOf("NativeCodeCall");
116  return false;
117 }
118 
119 bool DagNode::isOperation() const {
120  return !isNativeCodeCall() && !isReplaceWithValue() &&
122  !isVariadic();
123 }
124 
126  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
127  return cast<DefInit>(node->getOperator())
128  ->getDef()
129  ->getValueAsString("expression");
130 }
131 
133  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
134  return cast<DefInit>(node->getOperator())
135  ->getDef()
136  ->getValueAsInt("numReturns");
137 }
138 
139 StringRef DagNode::getSymbol() const { return node->getNameStr(); }
140 
142  const Record *opDef = cast<DefInit>(node->getOperator())->getDef();
143  auto [it, inserted] = mapper->try_emplace(opDef);
144  if (inserted)
145  it->second = std::make_unique<Operator>(opDef);
146  return *it->second;
147 }
148 
149 int DagNode::getNumOps() const {
150  // We want to get number of operations recursively involved in the DAG tree.
151  // All other directives should be excluded.
152  int count = isOperation() ? 1 : 0;
153  for (int i = 0, e = getNumArgs(); i != e; ++i) {
154  if (auto child = getArgAsNestedDag(i))
155  count += child.getNumOps();
156  }
157  return count;
158 }
159 
160 int DagNode::getNumArgs() const { return node->getNumArgs(); }
161 
162 bool DagNode::isNestedDagArg(unsigned index) const {
163  return isa<DagInit>(node->getArg(index));
164 }
165 
166 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
167  return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
168 }
169 
170 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
171  assert(!isNestedDagArg(index));
172  return DagLeaf(node->getArg(index));
173 }
174 
175 StringRef DagNode::getArgName(unsigned index) const {
176  return node->getArgNameStr(index);
177 }
178 
180  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
181  return dagOpDef->getName() == "replaceWithValue";
182 }
183 
185  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
186  return dagOpDef->getName() == "location";
187 }
188 
190  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
191  return dagOpDef->getName() == "returnType";
192 }
193 
194 bool DagNode::isEither() const {
195  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
196  return dagOpDef->getName() == "either";
197 }
198 
199 bool DagNode::isVariadic() const {
200  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
201  return dagOpDef->getName() == "variadic";
202 }
203 
204 void DagNode::print(raw_ostream &os) const {
205  if (node)
206  node->print(os);
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // SymbolInfoMap
211 //===----------------------------------------------------------------------===//
212 
213 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
214  int idx = -1;
215  auto [name, indexStr] = symbol.rsplit("__");
216 
217  if (indexStr.consumeInteger(10, idx)) {
218  // The second part is not an index; we return the whole symbol as-is.
219  return symbol;
220  }
221  if (index) {
222  *index = idx;
223  }
224  return name;
225 }
226 
227 SymbolInfoMap::SymbolInfo::SymbolInfo(
228  const Operator *op, SymbolInfo::Kind kind,
229  std::optional<DagAndConstant> dagAndConstant)
230  : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
231 
232 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
233  switch (kind) {
234  case Kind::Attr:
235  case Kind::Operand:
236  case Kind::Value:
237  return 1;
238  case Kind::Result:
239  return op->getNumResults();
240  case Kind::MultipleValues:
241  return getSize();
242  }
243  llvm_unreachable("unknown kind");
244 }
245 
246 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
247  return alternativeName ? *alternativeName : name.str();
248 }
249 
250 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
251  LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': ");
252  switch (kind) {
253  case Kind::Attr: {
254  if (op)
255  return cast<NamedAttribute *>(op->getArg(getArgIndex()))
256  ->attr.getStorageType()
257  .str();
258  // TODO(suderman): Use a more exact type when available.
259  return "::mlir::Attribute";
260  }
261  case Kind::Operand: {
262  // Use operand range for captured operands (to support potential variadic
263  // operands).
264  return "::mlir::Operation::operand_range";
265  }
266  case Kind::Value: {
267  return "::mlir::Value";
268  }
269  case Kind::MultipleValues: {
270  return "::mlir::ValueRange";
271  }
272  case Kind::Result: {
273  // Use the op itself for captured results.
274  return op->getQualCppClassName();
275  }
276  }
277  llvm_unreachable("unknown kind");
278 }
279 
280 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
281  LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': ");
282  std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
283  return std::string(
284  formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
285 }
286 
287 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
288  LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': ");
289  return std::string(
290  formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
291 }
292 
293 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
294  StringRef name, int index, const char *fmt, const char *separator) const {
295  LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': ");
296  switch (kind) {
297  case Kind::Attr: {
298  assert(index < 0);
299  auto repl = formatv(fmt, name);
300  LLVM_DEBUG(dbgs() << repl << " (Attr)\n");
301  return std::string(repl);
302  }
303  case Kind::Operand: {
304  assert(index < 0);
305  auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
306  if (operand->isOptional()) {
307  auto repl = formatv(
308  fmt, formatv("({0}.empty() ? ::mlir::Value() : *{0}.begin())", name));
309  LLVM_DEBUG(dbgs() << repl << " (OptionalOperand)\n");
310  return std::string(repl);
311  }
312  // If this operand is variadic and this SymbolInfo doesn't have a range
313  // index, then return the full variadic operand_range. Otherwise, return
314  // the value itself.
315  if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
316  auto repl = formatv(fmt, name);
317  LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n");
318  return std::string(repl);
319  }
320  auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
321  LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n");
322  return std::string(repl);
323  }
324  case Kind::Result: {
325  // If `index` is greater than zero, then we are referencing a specific
326  // result of a multi-result op. The result can still be variadic.
327  if (index >= 0) {
328  std::string v =
329  std::string(formatv("{0}.getODSResults({1})", name, index));
330  if (!op->getResult(index).isVariadic())
331  v = std::string(formatv("(*{0}.begin())", v));
332  auto repl = formatv(fmt, v);
333  LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
334  return std::string(repl);
335  }
336 
337  // If this op has no result at all but still we bind a symbol to it, it
338  // means we want to capture the op itself.
339  if (op->getNumResults() == 0) {
340  LLVM_DEBUG(dbgs() << name << " (Op)\n");
341  return formatv(fmt, name);
342  }
343 
344  // We are referencing all results of the multi-result op. A specific result
345  // can either be a value or a range. Then join them with `separator`.
347  values.reserve(op->getNumResults());
348 
349  for (int i = 0, e = op->getNumResults(); i < e; ++i) {
350  std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
351  if (!op->getResult(i).isVariadic()) {
352  v = std::string(formatv("(*{0}.begin())", v));
353  }
354  values.push_back(std::string(formatv(fmt, v)));
355  }
356  auto repl = llvm::join(values, separator);
357  LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
358  return repl;
359  }
360  case Kind::Value: {
361  assert(index < 0);
362  assert(op == nullptr);
363  auto repl = formatv(fmt, name);
364  LLVM_DEBUG(dbgs() << repl << " (Value)\n");
365  return std::string(repl);
366  }
367  case Kind::MultipleValues: {
368  assert(op == nullptr);
369  assert(index < getSize());
370  if (index >= 0) {
371  std::string repl =
372  formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
373  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
374  return repl;
375  }
376  // If it doesn't specify certain element, unpack them all.
377  auto repl =
378  formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
379  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
380  return std::string(repl);
381  }
382  }
383  llvm_unreachable("unknown kind");
384 }
385 
386 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
387  StringRef name, int index, const char *fmt, const char *separator) const {
388  LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': ");
389  switch (kind) {
390  case Kind::Attr:
391  case Kind::Operand: {
392  assert(index < 0 && "only allowed for symbol bound to result");
393  auto repl = formatv(fmt, name);
394  LLVM_DEBUG(dbgs() << repl << " (Operand/Attr)\n");
395  return std::string(repl);
396  }
397  case Kind::Result: {
398  if (index >= 0) {
399  auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
400  LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
401  return std::string(repl);
402  }
403 
404  // We are referencing all results of the multi-result op. Each result should
405  // have a value range, and then join them with `separator`.
407  values.reserve(op->getNumResults());
408 
409  for (int i = 0, e = op->getNumResults(); i < e; ++i) {
410  values.push_back(std::string(
411  formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
412  }
413  auto repl = llvm::join(values, separator);
414  LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
415  return repl;
416  }
417  case Kind::Value: {
418  assert(index < 0 && "only allowed for symbol bound to result");
419  assert(op == nullptr);
420  auto repl = formatv(fmt, formatv("{{{0}}", name));
421  LLVM_DEBUG(dbgs() << repl << " (Value)\n");
422  return std::string(repl);
423  }
424  case Kind::MultipleValues: {
425  assert(op == nullptr);
426  assert(index < getSize());
427  if (index >= 0) {
428  std::string repl =
429  formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
430  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
431  return repl;
432  }
433  auto repl =
434  formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
435  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
436  return std::string(repl);
437  }
438  }
439  llvm_unreachable("unknown kind");
440 }
441 
442 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
443  const Operator &op, int argIndex,
444  std::optional<int> variadicSubIndex) {
445  StringRef name = getValuePackName(symbol);
446  if (name != symbol) {
447  auto error = formatv(
448  "symbol '{0}' with trailing index cannot bind to op argument", symbol);
449  PrintFatalError(loc, error);
450  }
451 
452  auto symInfo =
453  isa<NamedAttribute *>(op.getArg(argIndex))
454  ? SymbolInfo::getAttr(&op, argIndex)
455  : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
456 
457  std::string key = symbol.str();
458  if (symbolInfoMap.count(key)) {
459  // Only non unique name for the operand is supported.
460  if (symInfo.kind != SymbolInfo::Kind::Operand) {
461  return false;
462  }
463 
464  // Cannot add new operand if there is already non operand with the same
465  // name.
466  if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
467  return false;
468  }
469  }
470 
471  symbolInfoMap.emplace(key, symInfo);
472  return true;
473 }
474 
475 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
476  std::string name = getValuePackName(symbol).str();
477  auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
478 
479  return symbolInfoMap.count(inserted->first) == 1;
480 }
481 
482 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
483  std::string name = getValuePackName(symbol).str();
484  if (numValues > 1)
485  return bindMultipleValues(name, numValues);
486  return bindValue(name);
487 }
488 
489 bool SymbolInfoMap::bindValue(StringRef symbol) {
490  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
491  return symbolInfoMap.count(inserted->first) == 1;
492 }
493 
494 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
495  std::string name = getValuePackName(symbol).str();
496  auto inserted =
497  symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
498  return symbolInfoMap.count(inserted->first) == 1;
499 }
500 
501 bool SymbolInfoMap::bindAttr(StringRef symbol) {
502  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
503  return symbolInfoMap.count(inserted->first) == 1;
504 }
505 
506 bool SymbolInfoMap::contains(StringRef symbol) const {
507  return find(symbol) != symbolInfoMap.end();
508 }
509 
511  std::string name = getValuePackName(key).str();
512 
513  return symbolInfoMap.find(name);
514 }
515 
517 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
518  int argIndex,
519  std::optional<int> variadicSubIndex) const {
520  return findBoundSymbol(
521  key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
522 }
523 
526  const SymbolInfo &symbolInfo) const {
527  std::string name = getValuePackName(key).str();
528  auto range = symbolInfoMap.equal_range(name);
529 
530  for (auto it = range.first; it != range.second; ++it)
531  if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
532  return it;
533 
534  return symbolInfoMap.end();
535 }
536 
537 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
539  std::string name = getValuePackName(key).str();
540 
541  return symbolInfoMap.equal_range(name);
542 }
543 
544 int SymbolInfoMap::count(StringRef key) const {
545  std::string name = getValuePackName(key).str();
546  return symbolInfoMap.count(name);
547 }
548 
549 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
550  StringRef name = getValuePackName(symbol);
551  if (name != symbol) {
552  // If there is a trailing index inside symbol, it references just one
553  // static value.
554  return 1;
555  }
556  // Otherwise, find how many it represents by querying the symbol's info.
557  return find(name)->second.getStaticValueCount();
558 }
559 
560 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
561  const char *fmt,
562  const char *separator) const {
563  int index = -1;
564  StringRef name = getValuePackName(symbol, &index);
565 
566  auto it = symbolInfoMap.find(name.str());
567  if (it == symbolInfoMap.end()) {
568  auto error = formatv("referencing unbound symbol '{0}'", symbol);
569  PrintFatalError(loc, error);
570  }
571 
572  return it->second.getValueAndRangeUse(name, index, fmt, separator);
573 }
574 
575 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
576  const char *separator) const {
577  int index = -1;
578  StringRef name = getValuePackName(symbol, &index);
579 
580  auto it = symbolInfoMap.find(name.str());
581  if (it == symbolInfoMap.end()) {
582  auto error = formatv("referencing unbound symbol '{0}'", symbol);
583  PrintFatalError(loc, error);
584  }
585 
586  return it->second.getAllRangeUse(name, index, fmt, separator);
587 }
588 
590  llvm::StringSet<> usedNames;
591 
592  for (auto symbolInfoIt = symbolInfoMap.begin();
593  symbolInfoIt != symbolInfoMap.end();) {
594  auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
595  auto startRange = range.first;
596  auto endRange = range.second;
597 
598  auto operandName = symbolInfoIt->first;
599  int startSearchIndex = 0;
600  for (++startRange; startRange != endRange; ++startRange) {
601  // Current operand name is not unique, find a unique one
602  // and set the alternative name.
603  for (int i = startSearchIndex;; ++i) {
604  std::string alternativeName = operandName + std::to_string(i);
605  if (!usedNames.contains(alternativeName) &&
606  symbolInfoMap.count(alternativeName) == 0) {
607  usedNames.insert(alternativeName);
608  startRange->second.alternativeName = alternativeName;
609  startSearchIndex = i + 1;
610 
611  break;
612  }
613  }
614  }
615 
616  symbolInfoIt = endRange;
617  }
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // Pattern
622 //==----------------------------------------------------------------------===//
623 
624 Pattern::Pattern(const Record *def, RecordOperatorMap *mapper)
625  : def(*def), recordOpMap(mapper) {}
626 
628  return DagNode(def.getValueAsDag("sourcePattern"));
629 }
630 
632  auto *results = def.getValueAsListInit("resultPatterns");
633  return results->size();
634 }
635 
636 DagNode Pattern::getResultPattern(unsigned index) const {
637  auto *results = def.getValueAsListInit("resultPatterns");
638  return DagNode(cast<DagInit>(results->getElement(index)));
639 }
640 
642  LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n");
643  collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
644  LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n");
645 
646  LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n");
648  LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n");
649 }
650 
652  LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n");
653  for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
654  auto pattern = getResultPattern(i);
655  collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
656  }
657  LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n");
658 }
659 
661  return getSourcePattern().getDialectOp(recordOpMap);
662 }
663 
665  return node.getDialectOp(recordOpMap);
666 }
667 
668 std::vector<AppliedConstraint> Pattern::getConstraints() const {
669  auto *listInit = def.getValueAsListInit("constraints");
670  std::vector<AppliedConstraint> ret;
671  ret.reserve(listInit->size());
672 
673  for (auto *it : *listInit) {
674  auto *dagInit = dyn_cast<DagInit>(it);
675  if (!dagInit)
676  PrintFatalError(&def, "all elements in Pattern multi-entity "
677  "constraints should be DAG nodes");
678 
679  std::vector<std::string> entities;
680  entities.reserve(dagInit->arg_size());
681  for (auto *argName : dagInit->getArgNames()) {
682  if (!argName) {
683  PrintFatalError(
684  &def,
685  "operands to additional constraints can only be symbol references");
686  }
687  entities.emplace_back(argName->getValue());
688  }
689 
690  ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
691  dagInit->getNameStr(), std::move(entities));
692  }
693  return ret;
694 }
695 
697  auto *results = def.getValueAsListInit("supplementalPatterns");
698  return results->size();
699 }
700 
702  auto *results = def.getValueAsListInit("supplementalPatterns");
703  return DagNode(cast<DagInit>(results->getElement(index)));
704 }
705 
706 int Pattern::getBenefit() const {
707  // The initial benefit value is a heuristic with number of ops in the source
708  // pattern.
709  int initBenefit = getSourcePattern().getNumOps();
710  const DagInit *delta = def.getValueAsDag("benefitDelta");
711  if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
712  PrintFatalError(&def,
713  "The 'addBenefit' takes and only takes one integer value");
714  }
715  return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
716 }
717 
718 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
719  std::vector<std::pair<StringRef, unsigned>> result;
720  result.reserve(def.getLoc().size());
721  for (auto loc : def.getLoc()) {
722  unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
723  assert(buf && "invalid source location");
724  result.emplace_back(
725  llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
726  llvm::SrcMgr.getLineAndColumn(loc, buf).first);
727  }
728  return result;
729 }
730 
731 void Pattern::verifyBind(bool result, StringRef symbolName) {
732  if (!result) {
733  auto err = formatv("symbol '{0}' bound more than once", symbolName);
734  PrintFatalError(&def, err);
735  }
736 }
737 
739  bool isSrcPattern) {
740  auto treeName = tree.getSymbol();
741  auto numTreeArgs = tree.getNumArgs();
742 
743  if (tree.isNativeCodeCall()) {
744  if (!treeName.empty()) {
745  if (!isSrcPattern) {
746  LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: "
747  << treeName << '\n');
748  verifyBind(
749  infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
750  treeName);
751  } else {
752  PrintFatalError(&def,
753  formatv("binding symbol '{0}' to NativecodeCall in "
754  "MatchPattern is not supported",
755  treeName));
756  }
757  }
758 
759  for (int i = 0; i != numTreeArgs; ++i) {
760  if (auto treeArg = tree.getArgAsNestedDag(i)) {
761  // This DAG node argument is a DAG node itself. Go inside recursively.
762  collectBoundSymbols(treeArg, infoMap, isSrcPattern);
763  continue;
764  }
765 
766  if (!isSrcPattern)
767  continue;
768 
769  // We can only bind symbols to arguments in source pattern. Those
770  // symbols are referenced in result patterns.
771  auto treeArgName = tree.getArgName(i);
772 
773  // `$_` is a special symbol meaning ignore the current argument.
774  if (!treeArgName.empty() && treeArgName != "_") {
775  DagLeaf leaf = tree.getArgAsLeaf(i);
776 
777  // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
778  if (leaf.isUnspecified()) {
779  // This is case of $c, a Value without any constraints.
780  verifyBind(infoMap.bindValue(treeArgName), treeArgName);
781  } else {
782  auto constraint = leaf.getAsConstraint();
783  bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() ||
784  leaf.isConstantAttr() ||
785  constraint.getKind() == Constraint::Kind::CK_Attr;
786 
787  if (isAttr) {
788  // This is case of $a, a binding to a certain attribute.
789  verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
790  continue;
791  }
792 
793  // This is case of $b, a binding to a certain type.
794  verifyBind(infoMap.bindValue(treeArgName), treeArgName);
795  }
796  }
797  }
798 
799  return;
800  }
801 
802  if (tree.isOperation()) {
803  auto &op = getDialectOp(tree);
804  auto numOpArgs = op.getNumArgs();
805  int numEither = 0;
806 
807  // We need to exclude the trailing directives and `either` directive groups
808  // two operands of the operation.
809  int numDirectives = 0;
810  for (int i = numTreeArgs - 1; i >= 0; --i) {
811  if (auto dagArg = tree.getArgAsNestedDag(i)) {
812  if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
813  ++numDirectives;
814  else if (dagArg.isEither())
815  ++numEither;
816  }
817  }
818 
819  if (numOpArgs != numTreeArgs - numDirectives + numEither) {
820  auto err =
821  formatv("op '{0}' argument number mismatch: "
822  "{1} in pattern vs. {2} in definition",
823  op.getOperationName(), numTreeArgs + numEither, numOpArgs);
824  PrintFatalError(&def, err);
825  }
826 
827  // The name attached to the DAG node's operator is for representing the
828  // results generated from this op. It should be remembered as bound results.
829  if (!treeName.empty()) {
830  LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName
831  << '\n');
832  verifyBind(infoMap.bindOpResult(treeName, op), treeName);
833  }
834 
835  // The operand in `either` DAG should be bound to the operation in the
836  // parent DagNode.
837  auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
838  int opArgIdx) {
839  for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
840  if (DagNode subTree = tree.getArgAsNestedDag(i)) {
841  collectBoundSymbols(subTree, infoMap, isSrcPattern);
842  } else {
843  auto argName = tree.getArgName(i);
844  if (!argName.empty() && argName != "_") {
845  verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
846  argName);
847  }
848  }
849  }
850  };
851 
852  // The operand in `variadic` DAG should be bound to the operation in the
853  // parent DagNode. The range index must be included as well to distinguish
854  // (potentially) repeating argName within the `variadic` DAG.
855  auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
856  int opArgIdx) {
857  auto treeName = tree.getSymbol();
858  if (!treeName.empty()) {
859  // If treeName is specified, bind to the full variadic operand_range.
860  verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
861  std::nullopt),
862  treeName);
863  }
864 
865  for (int i = 0; i < tree.getNumArgs(); ++i) {
866  if (DagNode subTree = tree.getArgAsNestedDag(i)) {
867  collectBoundSymbols(subTree, infoMap, isSrcPattern);
868  } else {
869  auto argName = tree.getArgName(i);
870  if (!argName.empty() && argName != "_") {
871  verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
872  /*variadicSubIndex=*/i),
873  argName);
874  }
875  }
876  }
877  };
878 
879  for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
880  if (auto treeArg = tree.getArgAsNestedDag(i)) {
881  if (treeArg.isEither()) {
882  collectSymbolInEither(tree, treeArg, opArgIdx);
883  // `either` DAG is *flattened*. For example,
884  //
885  // (FooOp (either arg0, arg1), arg2)
886  //
887  // can be viewed as:
888  //
889  // (FooOp arg0, arg1, arg2)
890  ++opArgIdx;
891  } else if (treeArg.isVariadic()) {
892  collectSymbolInVariadic(tree, treeArg, opArgIdx);
893  } else {
894  // This DAG node argument is a DAG node itself. Go inside recursively.
895  collectBoundSymbols(treeArg, infoMap, isSrcPattern);
896  }
897  continue;
898  }
899 
900  if (isSrcPattern) {
901  // We can only bind symbols to op arguments in source pattern. Those
902  // symbols are referenced in result patterns.
903  auto treeArgName = tree.getArgName(i);
904  // `$_` is a special symbol meaning ignore the current argument.
905  if (!treeArgName.empty() && treeArgName != "_") {
906  LLVM_DEBUG(dbgs() << "found symbol bound to op argument: "
907  << treeArgName << '\n');
908  verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
909  treeArgName);
910  }
911  }
912  }
913  return;
914  }
915 
916  if (!treeName.empty()) {
917  PrintFatalError(
918  &def, formatv("binding symbol '{0}' to non-operation/native code call "
919  "unsupported right now",
920  treeName));
921  }
922 }
union mlir::linalg::@1194::ArityGroupAndKind::Kind kind
std::string getConditionTemplate() const
Definition: Constraint.cpp:51
Constraint getAsConstraint() const
Definition: Pattern.cpp:64
bool isNativeCodeCall() const
Definition: Pattern.cpp:54
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:89
ConstantAttr getAsConstantAttr() const
Definition: Pattern.cpp:70
void print(raw_ostream &os) const
Definition: Pattern.cpp:104
std::string getStringAttr() const
Definition: Pattern.cpp:94
bool isEnumCase() const
Definition: Pattern.cpp:60
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:84
std::string getConditionTemplate() const
Definition: Pattern.cpp:80
bool isUnspecified() const
Definition: Pattern.cpp:40
EnumCase getAsEnumCase() const
Definition: Pattern.cpp:75
bool isAttrMatcher() const
Definition: Pattern.cpp:49
bool isOperandMatcher() const
Definition: Pattern.cpp:44
bool isConstantAttr() const
Definition: Pattern.cpp:58
bool isStringAttr() const
Definition: Pattern.cpp:62
bool isReturnTypeDirective() const
Definition: Pattern.cpp:189
bool isLocationDirective() const
Definition: Pattern.cpp:184
bool isReplaceWithValue() const
Definition: Pattern.cpp:179
DagNode getArgAsNestedDag(unsigned index) const
Definition: Pattern.cpp:166
bool isOperation() const
Definition: Pattern.cpp:119
DagLeaf getArgAsLeaf(unsigned index) const
Definition: Pattern.cpp:170
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:132
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:125
void print(raw_ostream &os) const
Definition: Pattern.cpp:204
int getNumOps() const
Definition: Pattern.cpp:149
Operator & getDialectOp(RecordOperatorMap *mapper) const
Definition: Pattern.cpp:141
bool isVariadic() const
Definition: Pattern.cpp:199
bool isNativeCodeCall() const
Definition: Pattern.cpp:113
bool isEither() const
Definition: Pattern.cpp:194
bool isNestedDagArg(unsigned index) const
Definition: Pattern.cpp:162
StringRef getSymbol() const
Definition: Pattern.cpp:139
int getNumArgs() const
Definition: Pattern.cpp:160
DagNode(const llvm::DagInit *node)
Definition: Pattern.h:144
StringRef getArgName(unsigned index) const
Definition: Pattern.cpp:175
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
int getNumResults() const
Returns the number of results this op produces.
Definition: Operator.cpp:166
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Definition: Operator.cpp:185
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
Definition: Operator.cpp:359
int getBenefit() const
int getNumResultPatterns() const
Definition: Pattern.cpp:631
std::vector< IdentifierLine > getLocation() const
Definition: Pattern.cpp:718
DagNode getSourcePattern() const
Definition: Pattern.cpp:627
const Operator & getSourceRootOp()
Definition: Pattern.cpp:660
std::vector< AppliedConstraint > getConstraints() const
Definition: Pattern.cpp:668
DagNode getResultPattern(unsigned index) const
Definition: Pattern.cpp:636
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Definition: Pattern.cpp:738
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
Definition: Pattern.cpp:664
DagNode getSupplementalPattern(unsigned index) const
Definition: Pattern.cpp:701
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:641
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:651
int getNumSupplementalPatterns() const
Definition: Pattern.cpp:696
std::string getArgDecl(StringRef name) const
Definition: Pattern.cpp:287
std::string getVarName(StringRef name) const
Definition: Pattern.cpp:246
std::string getVarTypeStr(StringRef name) const
Definition: Pattern.cpp:250
std::string getVarDecl(StringRef name) const
Definition: Pattern.cpp:280
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
Definition: Pattern.cpp:213
int count(StringRef key) const
Definition: Pattern.cpp:544
const_iterator find(StringRef key) const
Definition: Pattern.cpp:510
bool bindMultipleValues(StringRef symbol, int numValues)
Definition: Pattern.cpp:494
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
Definition: Pattern.cpp:442
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:575
bool bindValues(StringRef symbol, int numValues=1)
Definition: Pattern.cpp:482
bool bindAttr(StringRef symbol)
Definition: Pattern.cpp:501
bool bindValue(StringRef symbol)
Definition: Pattern.cpp:489
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
Definition: Pattern.cpp:517
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
Definition: Pattern.cpp:538
int getStaticValueCount(StringRef symbol) const
Definition: Pattern.cpp:549
bool contains(StringRef symbol) const
Definition: Pattern.cpp:506
BaseT::const_iterator const_iterator
Definition: Pattern.h:462
bool bindOpResult(StringRef symbol, const Operator &op)
Definition: Pattern.cpp:475
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:560
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:3776
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.