MLIR  20.0.0git
Pattern.cpp
Go to the documentation of this file.
1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
25 
26 using namespace mlir;
27 using namespace tblgen;
28 
29 using llvm::DagInit;
30 using llvm::dbgs;
31 using llvm::DefInit;
32 using llvm::formatv;
33 using llvm::IntInit;
34 using llvm::Record;
35 
36 //===----------------------------------------------------------------------===//
37 // DagLeaf
38 //===----------------------------------------------------------------------===//
39 
40 bool DagLeaf::isUnspecified() const {
41  return isa_and_nonnull<llvm::UnsetInit>(def);
42 }
43 
45  // Operand matchers specify a type constraint.
46  return isSubClassOf("TypeConstraint");
47 }
48 
49 bool DagLeaf::isAttrMatcher() const {
50  // Attribute matchers specify an attribute constraint.
51  return isSubClassOf("AttrConstraint");
52 }
53 
55  return isSubClassOf("NativeCodeCall");
56 }
57 
58 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
59 
61  return isSubClassOf("EnumAttrCaseInfo");
62 }
63 
64 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
65 
67  assert((isOperandMatcher() || isAttrMatcher()) &&
68  "the DAG leaf must be operand or attribute");
69  return Constraint(cast<DefInit>(def)->getDef());
70 }
71 
73  assert(isConstantAttr() && "the DAG leaf must be constant attribute");
74  return ConstantAttr(cast<DefInit>(def));
75 }
76 
78  assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
79  return EnumAttrCase(cast<DefInit>(def));
80 }
81 
82 std::string DagLeaf::getConditionTemplate() const {
84 }
85 
86 StringRef DagLeaf::getNativeCodeTemplate() const {
87  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
88  return cast<DefInit>(def)->getDef()->getValueAsString("expression");
89 }
90 
92  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
93  return cast<DefInit>(def)->getDef()->getValueAsInt("numReturns");
94 }
95 
96 std::string DagLeaf::getStringAttr() const {
97  assert(isStringAttr() && "the DAG leaf must be string attribute");
98  return def->getAsUnquotedString();
99 }
100 bool DagLeaf::isSubClassOf(StringRef superclass) const {
101  if (auto *defInit = dyn_cast_or_null<DefInit>(def))
102  return defInit->getDef()->isSubClassOf(superclass);
103  return false;
104 }
105 
106 void DagLeaf::print(raw_ostream &os) const {
107  if (def)
108  def->print(os);
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // DagNode
113 //===----------------------------------------------------------------------===//
114 
116  if (auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
117  return defInit->getDef()->isSubClassOf("NativeCodeCall");
118  return false;
119 }
120 
121 bool DagNode::isOperation() const {
122  return !isNativeCodeCall() && !isReplaceWithValue() &&
124  !isVariadic();
125 }
126 
128  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
129  return cast<DefInit>(node->getOperator())
130  ->getDef()
131  ->getValueAsString("expression");
132 }
133 
135  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
136  return cast<DefInit>(node->getOperator())
137  ->getDef()
138  ->getValueAsInt("numReturns");
139 }
140 
141 StringRef DagNode::getSymbol() const { return node->getNameStr(); }
142 
144  const Record *opDef = cast<DefInit>(node->getOperator())->getDef();
145  auto [it, inserted] = mapper->try_emplace(opDef);
146  if (inserted)
147  it->second = std::make_unique<Operator>(opDef);
148  return *it->second;
149 }
150 
151 int DagNode::getNumOps() const {
152  // We want to get number of operations recursively involved in the DAG tree.
153  // All other directives should be excluded.
154  int count = isOperation() ? 1 : 0;
155  for (int i = 0, e = getNumArgs(); i != e; ++i) {
156  if (auto child = getArgAsNestedDag(i))
157  count += child.getNumOps();
158  }
159  return count;
160 }
161 
162 int DagNode::getNumArgs() const { return node->getNumArgs(); }
163 
164 bool DagNode::isNestedDagArg(unsigned index) const {
165  return isa<DagInit>(node->getArg(index));
166 }
167 
168 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
169  return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
170 }
171 
172 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
173  assert(!isNestedDagArg(index));
174  return DagLeaf(node->getArg(index));
175 }
176 
177 StringRef DagNode::getArgName(unsigned index) const {
178  return node->getArgNameStr(index);
179 }
180 
182  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
183  return dagOpDef->getName() == "replaceWithValue";
184 }
185 
187  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
188  return dagOpDef->getName() == "location";
189 }
190 
192  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
193  return dagOpDef->getName() == "returnType";
194 }
195 
196 bool DagNode::isEither() const {
197  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
198  return dagOpDef->getName() == "either";
199 }
200 
201 bool DagNode::isVariadic() const {
202  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
203  return dagOpDef->getName() == "variadic";
204 }
205 
206 void DagNode::print(raw_ostream &os) const {
207  if (node)
208  node->print(os);
209 }
210 
211 //===----------------------------------------------------------------------===//
212 // SymbolInfoMap
213 //===----------------------------------------------------------------------===//
214 
215 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
216  int idx = -1;
217  auto [name, indexStr] = symbol.rsplit("__");
218 
219  if (indexStr.consumeInteger(10, idx)) {
220  // The second part is not an index; we return the whole symbol as-is.
221  return symbol;
222  }
223  if (index) {
224  *index = idx;
225  }
226  return name;
227 }
228 
229 SymbolInfoMap::SymbolInfo::SymbolInfo(
230  const Operator *op, SymbolInfo::Kind kind,
231  std::optional<DagAndConstant> dagAndConstant)
232  : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
233 
234 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
235  switch (kind) {
236  case Kind::Attr:
237  case Kind::Operand:
238  case Kind::Value:
239  return 1;
240  case Kind::Result:
241  return op->getNumResults();
242  case Kind::MultipleValues:
243  return getSize();
244  }
245  llvm_unreachable("unknown kind");
246 }
247 
248 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
249  return alternativeName ? *alternativeName : name.str();
250 }
251 
252 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
253  LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': ");
254  switch (kind) {
255  case Kind::Attr: {
256  if (op)
257  return cast<NamedAttribute *>(op->getArg(getArgIndex()))
258  ->attr.getStorageType()
259  .str();
260  // TODO(suderman): Use a more exact type when available.
261  return "::mlir::Attribute";
262  }
263  case Kind::Operand: {
264  // Use operand range for captured operands (to support potential variadic
265  // operands).
266  return "::mlir::Operation::operand_range";
267  }
268  case Kind::Value: {
269  return "::mlir::Value";
270  }
271  case Kind::MultipleValues: {
272  return "::mlir::ValueRange";
273  }
274  case Kind::Result: {
275  // Use the op itself for captured results.
276  return op->getQualCppClassName();
277  }
278  }
279  llvm_unreachable("unknown kind");
280 }
281 
282 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
283  LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': ");
284  std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
285  return std::string(
286  formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
287 }
288 
289 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
290  LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': ");
291  return std::string(
292  formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
293 }
294 
295 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
296  StringRef name, int index, const char *fmt, const char *separator) const {
297  LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': ");
298  switch (kind) {
299  case Kind::Attr: {
300  assert(index < 0);
301  auto repl = formatv(fmt, name);
302  LLVM_DEBUG(dbgs() << repl << " (Attr)\n");
303  return std::string(repl);
304  }
305  case Kind::Operand: {
306  assert(index < 0);
307  auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
308  // If this operand is variadic and this SymbolInfo doesn't have a range
309  // index, then return the full variadic operand_range. Otherwise, return
310  // the value itself.
311  if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
312  auto repl = formatv(fmt, name);
313  LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n");
314  return std::string(repl);
315  }
316  auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
317  LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n");
318  return std::string(repl);
319  }
320  case Kind::Result: {
321  // If `index` is greater than zero, then we are referencing a specific
322  // result of a multi-result op. The result can still be variadic.
323  if (index >= 0) {
324  std::string v =
325  std::string(formatv("{0}.getODSResults({1})", name, index));
326  if (!op->getResult(index).isVariadic())
327  v = std::string(formatv("(*{0}.begin())", v));
328  auto repl = formatv(fmt, v);
329  LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
330  return std::string(repl);
331  }
332 
333  // If this op has no result at all but still we bind a symbol to it, it
334  // means we want to capture the op itself.
335  if (op->getNumResults() == 0) {
336  LLVM_DEBUG(dbgs() << name << " (Op)\n");
337  return formatv(fmt, name);
338  }
339 
340  // We are referencing all results of the multi-result op. A specific result
341  // can either be a value or a range. Then join them with `separator`.
343  values.reserve(op->getNumResults());
344 
345  for (int i = 0, e = op->getNumResults(); i < e; ++i) {
346  std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
347  if (!op->getResult(i).isVariadic()) {
348  v = std::string(formatv("(*{0}.begin())", v));
349  }
350  values.push_back(std::string(formatv(fmt, v)));
351  }
352  auto repl = llvm::join(values, separator);
353  LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
354  return repl;
355  }
356  case Kind::Value: {
357  assert(index < 0);
358  assert(op == nullptr);
359  auto repl = formatv(fmt, name);
360  LLVM_DEBUG(dbgs() << repl << " (Value)\n");
361  return std::string(repl);
362  }
363  case Kind::MultipleValues: {
364  assert(op == nullptr);
365  assert(index < getSize());
366  if (index >= 0) {
367  std::string repl =
368  formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
369  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
370  return repl;
371  }
372  // If it doesn't specify certain element, unpack them all.
373  auto repl =
374  formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
375  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
376  return std::string(repl);
377  }
378  }
379  llvm_unreachable("unknown kind");
380 }
381 
382 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
383  StringRef name, int index, const char *fmt, const char *separator) const {
384  LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': ");
385  switch (kind) {
386  case Kind::Attr:
387  case Kind::Operand: {
388  assert(index < 0 && "only allowed for symbol bound to result");
389  auto repl = formatv(fmt, name);
390  LLVM_DEBUG(dbgs() << repl << " (Operand/Attr)\n");
391  return std::string(repl);
392  }
393  case Kind::Result: {
394  if (index >= 0) {
395  auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
396  LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
397  return std::string(repl);
398  }
399 
400  // We are referencing all results of the multi-result op. Each result should
401  // have a value range, and then join them with `separator`.
403  values.reserve(op->getNumResults());
404 
405  for (int i = 0, e = op->getNumResults(); i < e; ++i) {
406  values.push_back(std::string(
407  formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
408  }
409  auto repl = llvm::join(values, separator);
410  LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
411  return repl;
412  }
413  case Kind::Value: {
414  assert(index < 0 && "only allowed for symbol bound to result");
415  assert(op == nullptr);
416  auto repl = formatv(fmt, formatv("{{{0}}", name));
417  LLVM_DEBUG(dbgs() << repl << " (Value)\n");
418  return std::string(repl);
419  }
420  case Kind::MultipleValues: {
421  assert(op == nullptr);
422  assert(index < getSize());
423  if (index >= 0) {
424  std::string repl =
425  formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
426  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
427  return repl;
428  }
429  auto repl =
430  formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
431  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
432  return std::string(repl);
433  }
434  }
435  llvm_unreachable("unknown kind");
436 }
437 
438 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
439  const Operator &op, int argIndex,
440  std::optional<int> variadicSubIndex) {
441  StringRef name = getValuePackName(symbol);
442  if (name != symbol) {
443  auto error = formatv(
444  "symbol '{0}' with trailing index cannot bind to op argument", symbol);
445  PrintFatalError(loc, error);
446  }
447 
448  auto symInfo =
449  isa<NamedAttribute *>(op.getArg(argIndex))
450  ? SymbolInfo::getAttr(&op, argIndex)
451  : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
452 
453  std::string key = symbol.str();
454  if (symbolInfoMap.count(key)) {
455  // Only non unique name for the operand is supported.
456  if (symInfo.kind != SymbolInfo::Kind::Operand) {
457  return false;
458  }
459 
460  // Cannot add new operand if there is already non operand with the same
461  // name.
462  if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
463  return false;
464  }
465  }
466 
467  symbolInfoMap.emplace(key, symInfo);
468  return true;
469 }
470 
471 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
472  std::string name = getValuePackName(symbol).str();
473  auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
474 
475  return symbolInfoMap.count(inserted->first) == 1;
476 }
477 
478 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
479  std::string name = getValuePackName(symbol).str();
480  if (numValues > 1)
481  return bindMultipleValues(name, numValues);
482  return bindValue(name);
483 }
484 
485 bool SymbolInfoMap::bindValue(StringRef symbol) {
486  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
487  return symbolInfoMap.count(inserted->first) == 1;
488 }
489 
490 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
491  std::string name = getValuePackName(symbol).str();
492  auto inserted =
493  symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
494  return symbolInfoMap.count(inserted->first) == 1;
495 }
496 
497 bool SymbolInfoMap::bindAttr(StringRef symbol) {
498  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
499  return symbolInfoMap.count(inserted->first) == 1;
500 }
501 
502 bool SymbolInfoMap::contains(StringRef symbol) const {
503  return find(symbol) != symbolInfoMap.end();
504 }
505 
507  std::string name = getValuePackName(key).str();
508 
509  return symbolInfoMap.find(name);
510 }
511 
513 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
514  int argIndex,
515  std::optional<int> variadicSubIndex) const {
516  return findBoundSymbol(
517  key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
518 }
519 
522  const SymbolInfo &symbolInfo) const {
523  std::string name = getValuePackName(key).str();
524  auto range = symbolInfoMap.equal_range(name);
525 
526  for (auto it = range.first; it != range.second; ++it)
527  if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
528  return it;
529 
530  return symbolInfoMap.end();
531 }
532 
533 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
535  std::string name = getValuePackName(key).str();
536 
537  return symbolInfoMap.equal_range(name);
538 }
539 
540 int SymbolInfoMap::count(StringRef key) const {
541  std::string name = getValuePackName(key).str();
542  return symbolInfoMap.count(name);
543 }
544 
545 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
546  StringRef name = getValuePackName(symbol);
547  if (name != symbol) {
548  // If there is a trailing index inside symbol, it references just one
549  // static value.
550  return 1;
551  }
552  // Otherwise, find how many it represents by querying the symbol's info.
553  return find(name)->second.getStaticValueCount();
554 }
555 
556 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
557  const char *fmt,
558  const char *separator) const {
559  int index = -1;
560  StringRef name = getValuePackName(symbol, &index);
561 
562  auto it = symbolInfoMap.find(name.str());
563  if (it == symbolInfoMap.end()) {
564  auto error = formatv("referencing unbound symbol '{0}'", symbol);
565  PrintFatalError(loc, error);
566  }
567 
568  return it->second.getValueAndRangeUse(name, index, fmt, separator);
569 }
570 
571 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
572  const char *separator) const {
573  int index = -1;
574  StringRef name = getValuePackName(symbol, &index);
575 
576  auto it = symbolInfoMap.find(name.str());
577  if (it == symbolInfoMap.end()) {
578  auto error = formatv("referencing unbound symbol '{0}'", symbol);
579  PrintFatalError(loc, error);
580  }
581 
582  return it->second.getAllRangeUse(name, index, fmt, separator);
583 }
584 
586  llvm::StringSet<> usedNames;
587 
588  for (auto symbolInfoIt = symbolInfoMap.begin();
589  symbolInfoIt != symbolInfoMap.end();) {
590  auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
591  auto startRange = range.first;
592  auto endRange = range.second;
593 
594  auto operandName = symbolInfoIt->first;
595  int startSearchIndex = 0;
596  for (++startRange; startRange != endRange; ++startRange) {
597  // Current operand name is not unique, find a unique one
598  // and set the alternative name.
599  for (int i = startSearchIndex;; ++i) {
600  std::string alternativeName = operandName + std::to_string(i);
601  if (!usedNames.contains(alternativeName) &&
602  symbolInfoMap.count(alternativeName) == 0) {
603  usedNames.insert(alternativeName);
604  startRange->second.alternativeName = alternativeName;
605  startSearchIndex = i + 1;
606 
607  break;
608  }
609  }
610  }
611 
612  symbolInfoIt = endRange;
613  }
614 }
615 
616 //===----------------------------------------------------------------------===//
617 // Pattern
618 //===----------------------------------------------------------------------===//
619 
620 Pattern::Pattern(const Record *def, RecordOperatorMap *mapper)
621  : def(*def), recordOpMap(mapper) {}
622 
624  return DagNode(def.getValueAsDag("sourcePattern"));
625 }
626 
628  auto *results = def.getValueAsListInit("resultPatterns");
629  return results->size();
630 }
631 
632 DagNode Pattern::getResultPattern(unsigned index) const {
633  auto *results = def.getValueAsListInit("resultPatterns");
634  return DagNode(cast<DagInit>(results->getElement(index)));
635 }
636 
638  LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n");
639  collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
640  LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n");
641 
642  LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n");
644  LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n");
645 }
646 
648  LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n");
649  for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
650  auto pattern = getResultPattern(i);
651  collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
652  }
653  LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n");
654 }
655 
657  return getSourcePattern().getDialectOp(recordOpMap);
658 }
659 
661  return node.getDialectOp(recordOpMap);
662 }
663 
664 std::vector<AppliedConstraint> Pattern::getConstraints() const {
665  auto *listInit = def.getValueAsListInit("constraints");
666  std::vector<AppliedConstraint> ret;
667  ret.reserve(listInit->size());
668 
669  for (auto *it : *listInit) {
670  auto *dagInit = dyn_cast<DagInit>(it);
671  if (!dagInit)
672  PrintFatalError(&def, "all elements in Pattern multi-entity "
673  "constraints should be DAG nodes");
674 
675  std::vector<std::string> entities;
676  entities.reserve(dagInit->arg_size());
677  for (auto *argName : dagInit->getArgNames()) {
678  if (!argName) {
679  PrintFatalError(
680  &def,
681  "operands to additional constraints can only be symbol references");
682  }
683  entities.emplace_back(argName->getValue());
684  }
685 
686  ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
687  dagInit->getNameStr(), std::move(entities));
688  }
689  return ret;
690 }
691 
693  auto *results = def.getValueAsListInit("supplementalPatterns");
694  return results->size();
695 }
696 
698  auto *results = def.getValueAsListInit("supplementalPatterns");
699  return DagNode(cast<DagInit>(results->getElement(index)));
700 }
701 
702 int Pattern::getBenefit() const {
703  // The initial benefit value is a heuristic with number of ops in the source
704  // pattern.
705  int initBenefit = getSourcePattern().getNumOps();
706  const DagInit *delta = def.getValueAsDag("benefitDelta");
707  if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
708  PrintFatalError(&def,
709  "The 'addBenefit' takes and only takes one integer value");
710  }
711  return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
712 }
713 
714 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
715  std::vector<std::pair<StringRef, unsigned>> result;
716  result.reserve(def.getLoc().size());
717  for (auto loc : def.getLoc()) {
718  unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
719  assert(buf && "invalid source location");
720  result.emplace_back(
721  llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
722  llvm::SrcMgr.getLineAndColumn(loc, buf).first);
723  }
724  return result;
725 }
726 
727 void Pattern::verifyBind(bool result, StringRef symbolName) {
728  if (!result) {
729  auto err = formatv("symbol '{0}' bound more than once", symbolName);
730  PrintFatalError(&def, err);
731  }
732 }
733 
735  bool isSrcPattern) {
736  auto treeName = tree.getSymbol();
737  auto numTreeArgs = tree.getNumArgs();
738 
739  if (tree.isNativeCodeCall()) {
740  if (!treeName.empty()) {
741  if (!isSrcPattern) {
742  LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: "
743  << treeName << '\n');
744  verifyBind(
745  infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
746  treeName);
747  } else {
748  PrintFatalError(&def,
749  formatv("binding symbol '{0}' to NativecodeCall in "
750  "MatchPattern is not supported",
751  treeName));
752  }
753  }
754 
755  for (int i = 0; i != numTreeArgs; ++i) {
756  if (auto treeArg = tree.getArgAsNestedDag(i)) {
757  // This DAG node argument is a DAG node itself. Go inside recursively.
758  collectBoundSymbols(treeArg, infoMap, isSrcPattern);
759  continue;
760  }
761 
762  if (!isSrcPattern)
763  continue;
764 
765  // We can only bind symbols to arguments in source pattern. Those
766  // symbols are referenced in result patterns.
767  auto treeArgName = tree.getArgName(i);
768 
769  // `$_` is a special symbol meaning ignore the current argument.
770  if (!treeArgName.empty() && treeArgName != "_") {
771  DagLeaf leaf = tree.getArgAsLeaf(i);
772 
773  // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
774  if (leaf.isUnspecified()) {
775  // This is case of $c, a Value without any constraints.
776  verifyBind(infoMap.bindValue(treeArgName), treeArgName);
777  } else {
778  auto constraint = leaf.getAsConstraint();
779  bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
780  leaf.isConstantAttr() ||
781  constraint.getKind() == Constraint::Kind::CK_Attr;
782 
783  if (isAttr) {
784  // This is case of $a, a binding to a certain attribute.
785  verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
786  continue;
787  }
788 
789  // This is case of $b, a binding to a certain type.
790  verifyBind(infoMap.bindValue(treeArgName), treeArgName);
791  }
792  }
793  }
794 
795  return;
796  }
797 
798  if (tree.isOperation()) {
799  auto &op = getDialectOp(tree);
800  auto numOpArgs = op.getNumArgs();
801  int numEither = 0;
802 
803  // We need to exclude the trailing directives and `either` directive groups
804  // two operands of the operation.
805  int numDirectives = 0;
806  for (int i = numTreeArgs - 1; i >= 0; --i) {
807  if (auto dagArg = tree.getArgAsNestedDag(i)) {
808  if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
809  ++numDirectives;
810  else if (dagArg.isEither())
811  ++numEither;
812  }
813  }
814 
815  if (numOpArgs != numTreeArgs - numDirectives + numEither) {
816  auto err =
817  formatv("op '{0}' argument number mismatch: "
818  "{1} in pattern vs. {2} in definition",
819  op.getOperationName(), numTreeArgs + numEither, numOpArgs);
820  PrintFatalError(&def, err);
821  }
822 
823  // The name attached to the DAG node's operator is for representing the
824  // results generated from this op. It should be remembered as bound results.
825  if (!treeName.empty()) {
826  LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName
827  << '\n');
828  verifyBind(infoMap.bindOpResult(treeName, op), treeName);
829  }
830 
831  // The operand in `either` DAG should be bound to the operation in the
832  // parent DagNode.
833  auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
834  int opArgIdx) {
835  for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
836  if (DagNode subTree = tree.getArgAsNestedDag(i)) {
837  collectBoundSymbols(subTree, infoMap, isSrcPattern);
838  } else {
839  auto argName = tree.getArgName(i);
840  if (!argName.empty() && argName != "_") {
841  verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
842  argName);
843  }
844  }
845  }
846  };
847 
848  // The operand in `variadic` DAG should be bound to the operation in the
849  // parent DagNode. The range index must be included as well to distinguish
850  // (potentially) repeating argName within the `variadic` DAG.
851  auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
852  int opArgIdx) {
853  auto treeName = tree.getSymbol();
854  if (!treeName.empty()) {
855  // If treeName is specified, bind to the full variadic operand_range.
856  verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
857  std::nullopt),
858  treeName);
859  }
860 
861  for (int i = 0; i < tree.getNumArgs(); ++i) {
862  if (DagNode subTree = tree.getArgAsNestedDag(i)) {
863  collectBoundSymbols(subTree, infoMap, isSrcPattern);
864  } else {
865  auto argName = tree.getArgName(i);
866  if (!argName.empty() && argName != "_") {
867  verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
868  /*variadicSubIndex=*/i),
869  argName);
870  }
871  }
872  }
873  };
874 
875  for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
876  if (auto treeArg = tree.getArgAsNestedDag(i)) {
877  if (treeArg.isEither()) {
878  collectSymbolInEither(tree, treeArg, opArgIdx);
879  // `either` DAG is *flattened*. For example,
880  //
881  // (FooOp (either arg0, arg1), arg2)
882  //
883  // can be viewed as:
884  //
885  // (FooOp arg0, arg1, arg2)
886  ++opArgIdx;
887  } else if (treeArg.isVariadic()) {
888  collectSymbolInVariadic(tree, treeArg, opArgIdx);
889  } else {
890  // This DAG node argument is a DAG node itself. Go inside recursively.
891  collectBoundSymbols(treeArg, infoMap, isSrcPattern);
892  }
893  continue;
894  }
895 
896  if (isSrcPattern) {
897  // We can only bind symbols to op arguments in source pattern. Those
898  // symbols are referenced in result patterns.
899  auto treeArgName = tree.getArgName(i);
900  // `$_` is a special symbol meaning ignore the current argument.
901  if (!treeArgName.empty() && treeArgName != "_") {
902  LLVM_DEBUG(dbgs() << "found symbol bound to op argument: "
903  << treeArgName << '\n');
904  verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
905  treeArgName);
906  }
907  }
908  }
909  return;
910  }
911 
912  if (!treeName.empty()) {
913  PrintFatalError(
914  &def, formatv("binding symbol '{0}' to non-operation/native code call "
915  "unsupported right now",
916  treeName));
917  }
918 }
std::string getConditionTemplate() const
Definition: Constraint.cpp:51
Constraint getAsConstraint() const
Definition: Pattern.cpp:66
bool isNativeCodeCall() const
Definition: Pattern.cpp:54
bool isEnumAttrCase() const
Definition: Pattern.cpp:60
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:91
ConstantAttr getAsConstantAttr() const
Definition: Pattern.cpp:72
void print(raw_ostream &os) const
Definition: Pattern.cpp:106
std::string getStringAttr() const
Definition: Pattern.cpp:96
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:86
std::string getConditionTemplate() const
Definition: Pattern.cpp:82
bool isUnspecified() const
Definition: Pattern.cpp:40
bool isAttrMatcher() const
Definition: Pattern.cpp:49
EnumAttrCase getAsEnumAttrCase() const
Definition: Pattern.cpp:77
bool isOperandMatcher() const
Definition: Pattern.cpp:44
bool isConstantAttr() const
Definition: Pattern.cpp:58
bool isStringAttr() const
Definition: Pattern.cpp:64
bool isReturnTypeDirective() const
Definition: Pattern.cpp:191
bool isLocationDirective() const
Definition: Pattern.cpp:186
bool isReplaceWithValue() const
Definition: Pattern.cpp:181
DagNode getArgAsNestedDag(unsigned index) const
Definition: Pattern.cpp:168
bool isOperation() const
Definition: Pattern.cpp:121
DagLeaf getArgAsLeaf(unsigned index) const
Definition: Pattern.cpp:172
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:134
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:127
void print(raw_ostream &os) const
Definition: Pattern.cpp:206
int getNumOps() const
Definition: Pattern.cpp:151
Operator & getDialectOp(RecordOperatorMap *mapper) const
Definition: Pattern.cpp:143
bool isVariadic() const
Definition: Pattern.cpp:201
bool isNativeCodeCall() const
Definition: Pattern.cpp:115
bool isEither() const
Definition: Pattern.cpp:196
bool isNestedDagArg(unsigned index) const
Definition: Pattern.cpp:164
StringRef getSymbol() const
Definition: Pattern.cpp:141
int getNumArgs() const
Definition: Pattern.cpp:162
DagNode(const llvm::DagInit *node)
Definition: Pattern.h:143
StringRef getArgName(unsigned index) const
Definition: Pattern.cpp:177
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
int getNumResults() const
Returns the number of results this op produces.
Definition: Operator.cpp:166
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Definition: Operator.cpp:185
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
Definition: Operator.cpp:359
int getBenefit() const
int getNumResultPatterns() const
Definition: Pattern.cpp:627
std::vector< IdentifierLine > getLocation() const
Definition: Pattern.cpp:714
DagNode getSourcePattern() const
Definition: Pattern.cpp:623
const Operator & getSourceRootOp()
Definition: Pattern.cpp:656
std::vector< AppliedConstraint > getConstraints() const
Definition: Pattern.cpp:664
DagNode getResultPattern(unsigned index) const
Definition: Pattern.cpp:632
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Definition: Pattern.cpp:734
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
Definition: Pattern.cpp:660
DagNode getSupplementalPattern(unsigned index) const
Definition: Pattern.cpp:697
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:637
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:647
int getNumSupplementalPatterns() const
Definition: Pattern.cpp:692
std::string getArgDecl(StringRef name) const
Definition: Pattern.cpp:289
std::string getVarName(StringRef name) const
Definition: Pattern.cpp:248
std::string getVarTypeStr(StringRef name) const
Definition: Pattern.cpp:252
std::string getVarDecl(StringRef name) const
Definition: Pattern.cpp:282
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
Definition: Pattern.cpp:215
int count(StringRef key) const
Definition: Pattern.cpp:540
const_iterator find(StringRef key) const
Definition: Pattern.cpp:506
bool bindMultipleValues(StringRef symbol, int numValues)
Definition: Pattern.cpp:490
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
Definition: Pattern.cpp:438
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:571
bool bindValues(StringRef symbol, int numValues=1)
Definition: Pattern.cpp:478
bool bindAttr(StringRef symbol)
Definition: Pattern.cpp:497
bool bindValue(StringRef symbol)
Definition: Pattern.cpp:485
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
Definition: Pattern.cpp:513
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
Definition: Pattern.cpp:534
int getStaticValueCount(StringRef symbol) const
Definition: Pattern.cpp:545
bool contains(StringRef symbol) const
Definition: Pattern.cpp:502
BaseT::const_iterator const_iterator
Definition: Pattern.h:461
bool bindOpResult(StringRef symbol, const Operator &op)
Definition: Pattern.cpp:471
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:556
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:3184
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.