MLIR  20.0.0git
Pattern.cpp
Go to the documentation of this file.
1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
25 
26 using namespace mlir;
27 using namespace tblgen;
28 
29 using llvm::DagInit;
30 using llvm::dbgs;
31 using llvm::DefInit;
32 using llvm::formatv;
33 using llvm::IntInit;
34 using llvm::Record;
35 
36 //===----------------------------------------------------------------------===//
37 // DagLeaf
38 //===----------------------------------------------------------------------===//
39 
40 bool DagLeaf::isUnspecified() const {
41  return isa_and_nonnull<llvm::UnsetInit>(def);
42 }
43 
45  // Operand matchers specify a type constraint.
46  return isSubClassOf("TypeConstraint");
47 }
48 
49 bool DagLeaf::isAttrMatcher() const {
50  // Attribute matchers specify an attribute constraint.
51  return isSubClassOf("AttrConstraint");
52 }
53 
55  return isSubClassOf("NativeCodeCall");
56 }
57 
58 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
59 
61  return isSubClassOf("EnumAttrCaseInfo");
62 }
63 
64 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
65 
67  assert((isOperandMatcher() || isAttrMatcher()) &&
68  "the DAG leaf must be operand or attribute");
69  return Constraint(cast<DefInit>(def)->getDef());
70 }
71 
73  assert(isConstantAttr() && "the DAG leaf must be constant attribute");
74  return ConstantAttr(cast<DefInit>(def));
75 }
76 
78  assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
79  return EnumAttrCase(cast<DefInit>(def));
80 }
81 
82 std::string DagLeaf::getConditionTemplate() const {
84 }
85 
86 StringRef DagLeaf::getNativeCodeTemplate() const {
87  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
88  return cast<DefInit>(def)->getDef()->getValueAsString("expression");
89 }
90 
92  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
93  return cast<DefInit>(def)->getDef()->getValueAsInt("numReturns");
94 }
95 
96 std::string DagLeaf::getStringAttr() const {
97  assert(isStringAttr() && "the DAG leaf must be string attribute");
98  return def->getAsUnquotedString();
99 }
100 bool DagLeaf::isSubClassOf(StringRef superclass) const {
101  if (auto *defInit = dyn_cast_or_null<DefInit>(def))
102  return defInit->getDef()->isSubClassOf(superclass);
103  return false;
104 }
105 
106 void DagLeaf::print(raw_ostream &os) const {
107  if (def)
108  def->print(os);
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // DagNode
113 //===----------------------------------------------------------------------===//
114 
116  if (auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
117  return defInit->getDef()->isSubClassOf("NativeCodeCall");
118  return false;
119 }
120 
121 bool DagNode::isOperation() const {
122  return !isNativeCodeCall() && !isReplaceWithValue() &&
124  !isVariadic();
125 }
126 
128  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
129  return cast<DefInit>(node->getOperator())
130  ->getDef()
131  ->getValueAsString("expression");
132 }
133 
135  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
136  return cast<DefInit>(node->getOperator())
137  ->getDef()
138  ->getValueAsInt("numReturns");
139 }
140 
141 StringRef DagNode::getSymbol() const { return node->getNameStr(); }
142 
144  const Record *opDef = cast<DefInit>(node->getOperator())->getDef();
145  auto [it, inserted] = mapper->try_emplace(opDef);
146  if (inserted)
147  it->second = std::make_unique<Operator>(opDef);
148  return *it->second;
149 }
150 
151 int DagNode::getNumOps() const {
152  // We want to get number of operations recursively involved in the DAG tree.
153  // All other directives should be excluded.
154  int count = isOperation() ? 1 : 0;
155  for (int i = 0, e = getNumArgs(); i != e; ++i) {
156  if (auto child = getArgAsNestedDag(i))
157  count += child.getNumOps();
158  }
159  return count;
160 }
161 
162 int DagNode::getNumArgs() const { return node->getNumArgs(); }
163 
164 bool DagNode::isNestedDagArg(unsigned index) const {
165  return isa<DagInit>(node->getArg(index));
166 }
167 
168 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
169  return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
170 }
171 
172 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
173  assert(!isNestedDagArg(index));
174  return DagLeaf(node->getArg(index));
175 }
176 
177 StringRef DagNode::getArgName(unsigned index) const {
178  return node->getArgNameStr(index);
179 }
180 
182  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
183  return dagOpDef->getName() == "replaceWithValue";
184 }
185 
187  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
188  return dagOpDef->getName() == "location";
189 }
190 
192  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
193  return dagOpDef->getName() == "returnType";
194 }
195 
196 bool DagNode::isEither() const {
197  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
198  return dagOpDef->getName() == "either";
199 }
200 
201 bool DagNode::isVariadic() const {
202  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
203  return dagOpDef->getName() == "variadic";
204 }
205 
206 void DagNode::print(raw_ostream &os) const {
207  if (node)
208  node->print(os);
209 }
210 
211 //===----------------------------------------------------------------------===//
212 // SymbolInfoMap
213 //===----------------------------------------------------------------------===//
214 
215 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
216  int idx = -1;
217  auto [name, indexStr] = symbol.rsplit("__");
218 
219  if (indexStr.consumeInteger(10, idx)) {
220  // The second part is not an index; we return the whole symbol as-is.
221  return symbol;
222  }
223  if (index) {
224  *index = idx;
225  }
226  return name;
227 }
228 
229 SymbolInfoMap::SymbolInfo::SymbolInfo(
230  const Operator *op, SymbolInfo::Kind kind,
231  std::optional<DagAndConstant> dagAndConstant)
232  : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
233 
234 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
235  switch (kind) {
236  case Kind::Attr:
237  case Kind::Operand:
238  case Kind::Value:
239  return 1;
240  case Kind::Result:
241  return op->getNumResults();
242  case Kind::MultipleValues:
243  return getSize();
244  }
245  llvm_unreachable("unknown kind");
246 }
247 
248 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
249  return alternativeName ? *alternativeName : name.str();
250 }
251 
252 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
253  LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': ");
254  switch (kind) {
255  case Kind::Attr: {
256  if (op)
257  return op->getArg(getArgIndex())
258  .get<NamedAttribute *>()
259  ->attr.getStorageType()
260  .str();
261  // TODO(suderman): Use a more exact type when available.
262  return "::mlir::Attribute";
263  }
264  case Kind::Operand: {
265  // Use operand range for captured operands (to support potential variadic
266  // operands).
267  return "::mlir::Operation::operand_range";
268  }
269  case Kind::Value: {
270  return "::mlir::Value";
271  }
272  case Kind::MultipleValues: {
273  return "::mlir::ValueRange";
274  }
275  case Kind::Result: {
276  // Use the op itself for captured results.
277  return op->getQualCppClassName();
278  }
279  }
280  llvm_unreachable("unknown kind");
281 }
282 
283 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
284  LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': ");
285  std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
286  return std::string(
287  formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
288 }
289 
290 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
291  LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': ");
292  return std::string(
293  formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
294 }
295 
296 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
297  StringRef name, int index, const char *fmt, const char *separator) const {
298  LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': ");
299  switch (kind) {
300  case Kind::Attr: {
301  assert(index < 0);
302  auto repl = formatv(fmt, name);
303  LLVM_DEBUG(dbgs() << repl << " (Attr)\n");
304  return std::string(repl);
305  }
306  case Kind::Operand: {
307  assert(index < 0);
308  auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
309  // If this operand is variadic and this SymbolInfo doesn't have a range
310  // index, then return the full variadic operand_range. Otherwise, return
311  // the value itself.
312  if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
313  auto repl = formatv(fmt, name);
314  LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n");
315  return std::string(repl);
316  }
317  auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
318  LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n");
319  return std::string(repl);
320  }
321  case Kind::Result: {
322  // If `index` is greater than zero, then we are referencing a specific
323  // result of a multi-result op. The result can still be variadic.
324  if (index >= 0) {
325  std::string v =
326  std::string(formatv("{0}.getODSResults({1})", name, index));
327  if (!op->getResult(index).isVariadic())
328  v = std::string(formatv("(*{0}.begin())", v));
329  auto repl = formatv(fmt, v);
330  LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
331  return std::string(repl);
332  }
333 
334  // If this op has no result at all but still we bind a symbol to it, it
335  // means we want to capture the op itself.
336  if (op->getNumResults() == 0) {
337  LLVM_DEBUG(dbgs() << name << " (Op)\n");
338  return formatv(fmt, name);
339  }
340 
341  // We are referencing all results of the multi-result op. A specific result
342  // can either be a value or a range. Then join them with `separator`.
344  values.reserve(op->getNumResults());
345 
346  for (int i = 0, e = op->getNumResults(); i < e; ++i) {
347  std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
348  if (!op->getResult(i).isVariadic()) {
349  v = std::string(formatv("(*{0}.begin())", v));
350  }
351  values.push_back(std::string(formatv(fmt, v)));
352  }
353  auto repl = llvm::join(values, separator);
354  LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
355  return repl;
356  }
357  case Kind::Value: {
358  assert(index < 0);
359  assert(op == nullptr);
360  auto repl = formatv(fmt, name);
361  LLVM_DEBUG(dbgs() << repl << " (Value)\n");
362  return std::string(repl);
363  }
364  case Kind::MultipleValues: {
365  assert(op == nullptr);
366  assert(index < getSize());
367  if (index >= 0) {
368  std::string repl =
369  formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
370  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
371  return repl;
372  }
373  // If it doesn't specify certain element, unpack them all.
374  auto repl =
375  formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
376  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
377  return std::string(repl);
378  }
379  }
380  llvm_unreachable("unknown kind");
381 }
382 
383 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
384  StringRef name, int index, const char *fmt, const char *separator) const {
385  LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': ");
386  switch (kind) {
387  case Kind::Attr:
388  case Kind::Operand: {
389  assert(index < 0 && "only allowed for symbol bound to result");
390  auto repl = formatv(fmt, name);
391  LLVM_DEBUG(dbgs() << repl << " (Operand/Attr)\n");
392  return std::string(repl);
393  }
394  case Kind::Result: {
395  if (index >= 0) {
396  auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
397  LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
398  return std::string(repl);
399  }
400 
401  // We are referencing all results of the multi-result op. Each result should
402  // have a value range, and then join them with `separator`.
404  values.reserve(op->getNumResults());
405 
406  for (int i = 0, e = op->getNumResults(); i < e; ++i) {
407  values.push_back(std::string(
408  formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
409  }
410  auto repl = llvm::join(values, separator);
411  LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
412  return repl;
413  }
414  case Kind::Value: {
415  assert(index < 0 && "only allowed for symbol bound to result");
416  assert(op == nullptr);
417  auto repl = formatv(fmt, formatv("{{{0}}", name));
418  LLVM_DEBUG(dbgs() << repl << " (Value)\n");
419  return std::string(repl);
420  }
421  case Kind::MultipleValues: {
422  assert(op == nullptr);
423  assert(index < getSize());
424  if (index >= 0) {
425  std::string repl =
426  formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
427  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
428  return repl;
429  }
430  auto repl =
431  formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
432  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
433  return std::string(repl);
434  }
435  }
436  llvm_unreachable("unknown kind");
437 }
438 
439 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
440  const Operator &op, int argIndex,
441  std::optional<int> variadicSubIndex) {
442  StringRef name = getValuePackName(symbol);
443  if (name != symbol) {
444  auto error = formatv(
445  "symbol '{0}' with trailing index cannot bind to op argument", symbol);
446  PrintFatalError(loc, error);
447  }
448 
449  auto symInfo =
450  op.getArg(argIndex).is<NamedAttribute *>()
451  ? SymbolInfo::getAttr(&op, argIndex)
452  : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
453 
454  std::string key = symbol.str();
455  if (symbolInfoMap.count(key)) {
456  // Only non unique name for the operand is supported.
457  if (symInfo.kind != SymbolInfo::Kind::Operand) {
458  return false;
459  }
460 
461  // Cannot add new operand if there is already non operand with the same
462  // name.
463  if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
464  return false;
465  }
466  }
467 
468  symbolInfoMap.emplace(key, symInfo);
469  return true;
470 }
471 
472 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
473  std::string name = getValuePackName(symbol).str();
474  auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
475 
476  return symbolInfoMap.count(inserted->first) == 1;
477 }
478 
479 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
480  std::string name = getValuePackName(symbol).str();
481  if (numValues > 1)
482  return bindMultipleValues(name, numValues);
483  return bindValue(name);
484 }
485 
486 bool SymbolInfoMap::bindValue(StringRef symbol) {
487  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
488  return symbolInfoMap.count(inserted->first) == 1;
489 }
490 
491 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
492  std::string name = getValuePackName(symbol).str();
493  auto inserted =
494  symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
495  return symbolInfoMap.count(inserted->first) == 1;
496 }
497 
498 bool SymbolInfoMap::bindAttr(StringRef symbol) {
499  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
500  return symbolInfoMap.count(inserted->first) == 1;
501 }
502 
503 bool SymbolInfoMap::contains(StringRef symbol) const {
504  return find(symbol) != symbolInfoMap.end();
505 }
506 
508  std::string name = getValuePackName(key).str();
509 
510  return symbolInfoMap.find(name);
511 }
512 
514 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
515  int argIndex,
516  std::optional<int> variadicSubIndex) const {
517  return findBoundSymbol(
518  key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
519 }
520 
523  const SymbolInfo &symbolInfo) const {
524  std::string name = getValuePackName(key).str();
525  auto range = symbolInfoMap.equal_range(name);
526 
527  for (auto it = range.first; it != range.second; ++it)
528  if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
529  return it;
530 
531  return symbolInfoMap.end();
532 }
533 
534 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
536  std::string name = getValuePackName(key).str();
537 
538  return symbolInfoMap.equal_range(name);
539 }
540 
541 int SymbolInfoMap::count(StringRef key) const {
542  std::string name = getValuePackName(key).str();
543  return symbolInfoMap.count(name);
544 }
545 
546 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
547  StringRef name = getValuePackName(symbol);
548  if (name != symbol) {
549  // If there is a trailing index inside symbol, it references just one
550  // static value.
551  return 1;
552  }
553  // Otherwise, find how many it represents by querying the symbol's info.
554  return find(name)->second.getStaticValueCount();
555 }
556 
557 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
558  const char *fmt,
559  const char *separator) const {
560  int index = -1;
561  StringRef name = getValuePackName(symbol, &index);
562 
563  auto it = symbolInfoMap.find(name.str());
564  if (it == symbolInfoMap.end()) {
565  auto error = formatv("referencing unbound symbol '{0}'", symbol);
566  PrintFatalError(loc, error);
567  }
568 
569  return it->second.getValueAndRangeUse(name, index, fmt, separator);
570 }
571 
572 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
573  const char *separator) const {
574  int index = -1;
575  StringRef name = getValuePackName(symbol, &index);
576 
577  auto it = symbolInfoMap.find(name.str());
578  if (it == symbolInfoMap.end()) {
579  auto error = formatv("referencing unbound symbol '{0}'", symbol);
580  PrintFatalError(loc, error);
581  }
582 
583  return it->second.getAllRangeUse(name, index, fmt, separator);
584 }
585 
587  llvm::StringSet<> usedNames;
588 
589  for (auto symbolInfoIt = symbolInfoMap.begin();
590  symbolInfoIt != symbolInfoMap.end();) {
591  auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
592  auto startRange = range.first;
593  auto endRange = range.second;
594 
595  auto operandName = symbolInfoIt->first;
596  int startSearchIndex = 0;
597  for (++startRange; startRange != endRange; ++startRange) {
598  // Current operand name is not unique, find a unique one
599  // and set the alternative name.
600  for (int i = startSearchIndex;; ++i) {
601  std::string alternativeName = operandName + std::to_string(i);
602  if (!usedNames.contains(alternativeName) &&
603  symbolInfoMap.count(alternativeName) == 0) {
604  usedNames.insert(alternativeName);
605  startRange->second.alternativeName = alternativeName;
606  startSearchIndex = i + 1;
607 
608  break;
609  }
610  }
611  }
612 
613  symbolInfoIt = endRange;
614  }
615 }
616 
617 //===----------------------------------------------------------------------===//
618 // Pattern
619 //===----------------------------------------------------------------------===//
620 
621 Pattern::Pattern(const Record *def, RecordOperatorMap *mapper)
622  : def(*def), recordOpMap(mapper) {}
623 
625  return DagNode(def.getValueAsDag("sourcePattern"));
626 }
627 
629  auto *results = def.getValueAsListInit("resultPatterns");
630  return results->size();
631 }
632 
633 DagNode Pattern::getResultPattern(unsigned index) const {
634  auto *results = def.getValueAsListInit("resultPatterns");
635  return DagNode(cast<DagInit>(results->getElement(index)));
636 }
637 
639  LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n");
640  collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
641  LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n");
642 
643  LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n");
645  LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n");
646 }
647 
649  LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n");
650  for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
651  auto pattern = getResultPattern(i);
652  collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
653  }
654  LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n");
655 }
656 
658  return getSourcePattern().getDialectOp(recordOpMap);
659 }
660 
662  return node.getDialectOp(recordOpMap);
663 }
664 
665 std::vector<AppliedConstraint> Pattern::getConstraints() const {
666  auto *listInit = def.getValueAsListInit("constraints");
667  std::vector<AppliedConstraint> ret;
668  ret.reserve(listInit->size());
669 
670  for (auto *it : *listInit) {
671  auto *dagInit = dyn_cast<DagInit>(it);
672  if (!dagInit)
673  PrintFatalError(&def, "all elements in Pattern multi-entity "
674  "constraints should be DAG nodes");
675 
676  std::vector<std::string> entities;
677  entities.reserve(dagInit->arg_size());
678  for (auto *argName : dagInit->getArgNames()) {
679  if (!argName) {
680  PrintFatalError(
681  &def,
682  "operands to additional constraints can only be symbol references");
683  }
684  entities.emplace_back(argName->getValue());
685  }
686 
687  ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
688  dagInit->getNameStr(), std::move(entities));
689  }
690  return ret;
691 }
692 
694  auto *results = def.getValueAsListInit("supplementalPatterns");
695  return results->size();
696 }
697 
699  auto *results = def.getValueAsListInit("supplementalPatterns");
700  return DagNode(cast<DagInit>(results->getElement(index)));
701 }
702 
703 int Pattern::getBenefit() const {
704  // The initial benefit value is a heuristic with number of ops in the source
705  // pattern.
706  int initBenefit = getSourcePattern().getNumOps();
707  const DagInit *delta = def.getValueAsDag("benefitDelta");
708  if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
709  PrintFatalError(&def,
710  "The 'addBenefit' takes and only takes one integer value");
711  }
712  return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
713 }
714 
715 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
716  std::vector<std::pair<StringRef, unsigned>> result;
717  result.reserve(def.getLoc().size());
718  for (auto loc : def.getLoc()) {
719  unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
720  assert(buf && "invalid source location");
721  result.emplace_back(
722  llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
723  llvm::SrcMgr.getLineAndColumn(loc, buf).first);
724  }
725  return result;
726 }
727 
728 void Pattern::verifyBind(bool result, StringRef symbolName) {
729  if (!result) {
730  auto err = formatv("symbol '{0}' bound more than once", symbolName);
731  PrintFatalError(&def, err);
732  }
733 }
734 
736  bool isSrcPattern) {
737  auto treeName = tree.getSymbol();
738  auto numTreeArgs = tree.getNumArgs();
739 
740  if (tree.isNativeCodeCall()) {
741  if (!treeName.empty()) {
742  if (!isSrcPattern) {
743  LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: "
744  << treeName << '\n');
745  verifyBind(
746  infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
747  treeName);
748  } else {
749  PrintFatalError(&def,
750  formatv("binding symbol '{0}' to NativecodeCall in "
751  "MatchPattern is not supported",
752  treeName));
753  }
754  }
755 
756  for (int i = 0; i != numTreeArgs; ++i) {
757  if (auto treeArg = tree.getArgAsNestedDag(i)) {
758  // This DAG node argument is a DAG node itself. Go inside recursively.
759  collectBoundSymbols(treeArg, infoMap, isSrcPattern);
760  continue;
761  }
762 
763  if (!isSrcPattern)
764  continue;
765 
766  // We can only bind symbols to arguments in source pattern. Those
767  // symbols are referenced in result patterns.
768  auto treeArgName = tree.getArgName(i);
769 
770  // `$_` is a special symbol meaning ignore the current argument.
771  if (!treeArgName.empty() && treeArgName != "_") {
772  DagLeaf leaf = tree.getArgAsLeaf(i);
773 
774  // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
775  if (leaf.isUnspecified()) {
776  // This is case of $c, a Value without any constraints.
777  verifyBind(infoMap.bindValue(treeArgName), treeArgName);
778  } else {
779  auto constraint = leaf.getAsConstraint();
780  bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
781  leaf.isConstantAttr() ||
782  constraint.getKind() == Constraint::Kind::CK_Attr;
783 
784  if (isAttr) {
785  // This is case of $a, a binding to a certain attribute.
786  verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
787  continue;
788  }
789 
790  // This is case of $b, a binding to a certain type.
791  verifyBind(infoMap.bindValue(treeArgName), treeArgName);
792  }
793  }
794  }
795 
796  return;
797  }
798 
799  if (tree.isOperation()) {
800  auto &op = getDialectOp(tree);
801  auto numOpArgs = op.getNumArgs();
802  int numEither = 0;
803 
804  // We need to exclude the trailing directives and `either` directive groups
805  // two operands of the operation.
806  int numDirectives = 0;
807  for (int i = numTreeArgs - 1; i >= 0; --i) {
808  if (auto dagArg = tree.getArgAsNestedDag(i)) {
809  if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
810  ++numDirectives;
811  else if (dagArg.isEither())
812  ++numEither;
813  }
814  }
815 
816  if (numOpArgs != numTreeArgs - numDirectives + numEither) {
817  auto err =
818  formatv("op '{0}' argument number mismatch: "
819  "{1} in pattern vs. {2} in definition",
820  op.getOperationName(), numTreeArgs + numEither, numOpArgs);
821  PrintFatalError(&def, err);
822  }
823 
824  // The name attached to the DAG node's operator is for representing the
825  // results generated from this op. It should be remembered as bound results.
826  if (!treeName.empty()) {
827  LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName
828  << '\n');
829  verifyBind(infoMap.bindOpResult(treeName, op), treeName);
830  }
831 
832  // The operand in `either` DAG should be bound to the operation in the
833  // parent DagNode.
834  auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
835  int opArgIdx) {
836  for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
837  if (DagNode subTree = tree.getArgAsNestedDag(i)) {
838  collectBoundSymbols(subTree, infoMap, isSrcPattern);
839  } else {
840  auto argName = tree.getArgName(i);
841  if (!argName.empty() && argName != "_") {
842  verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
843  argName);
844  }
845  }
846  }
847  };
848 
849  // The operand in `variadic` DAG should be bound to the operation in the
850  // parent DagNode. The range index must be included as well to distinguish
851  // (potentially) repeating argName within the `variadic` DAG.
852  auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
853  int opArgIdx) {
854  auto treeName = tree.getSymbol();
855  if (!treeName.empty()) {
856  // If treeName is specified, bind to the full variadic operand_range.
857  verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
858  std::nullopt),
859  treeName);
860  }
861 
862  for (int i = 0; i < tree.getNumArgs(); ++i) {
863  if (DagNode subTree = tree.getArgAsNestedDag(i)) {
864  collectBoundSymbols(subTree, infoMap, isSrcPattern);
865  } else {
866  auto argName = tree.getArgName(i);
867  if (!argName.empty() && argName != "_") {
868  verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
869  /*variadicSubIndex=*/i),
870  argName);
871  }
872  }
873  }
874  };
875 
876  for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
877  if (auto treeArg = tree.getArgAsNestedDag(i)) {
878  if (treeArg.isEither()) {
879  collectSymbolInEither(tree, treeArg, opArgIdx);
880  // `either` DAG is *flattened*. For example,
881  //
882  // (FooOp (either arg0, arg1), arg2)
883  //
884  // can be viewed as:
885  //
886  // (FooOp arg0, arg1, arg2)
887  ++opArgIdx;
888  } else if (treeArg.isVariadic()) {
889  collectSymbolInVariadic(tree, treeArg, opArgIdx);
890  } else {
891  // This DAG node argument is a DAG node itself. Go inside recursively.
892  collectBoundSymbols(treeArg, infoMap, isSrcPattern);
893  }
894  continue;
895  }
896 
897  if (isSrcPattern) {
898  // We can only bind symbols to op arguments in source pattern. Those
899  // symbols are referenced in result patterns.
900  auto treeArgName = tree.getArgName(i);
901  // `$_` is a special symbol meaning ignore the current argument.
902  if (!treeArgName.empty() && treeArgName != "_") {
903  LLVM_DEBUG(dbgs() << "found symbol bound to op argument: "
904  << treeArgName << '\n');
905  verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
906  treeArgName);
907  }
908  }
909  }
910  return;
911  }
912 
913  if (!treeName.empty()) {
914  PrintFatalError(
915  &def, formatv("binding symbol '{0}' to non-operation/native code call "
916  "unsupported right now",
917  treeName));
918  }
919 }
std::string getConditionTemplate() const
Definition: Constraint.cpp:51
Constraint getAsConstraint() const
Definition: Pattern.cpp:66
bool isNativeCodeCall() const
Definition: Pattern.cpp:54
bool isEnumAttrCase() const
Definition: Pattern.cpp:60
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:91
ConstantAttr getAsConstantAttr() const
Definition: Pattern.cpp:72
void print(raw_ostream &os) const
Definition: Pattern.cpp:106
std::string getStringAttr() const
Definition: Pattern.cpp:96
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:86
std::string getConditionTemplate() const
Definition: Pattern.cpp:82
bool isUnspecified() const
Definition: Pattern.cpp:40
bool isAttrMatcher() const
Definition: Pattern.cpp:49
EnumAttrCase getAsEnumAttrCase() const
Definition: Pattern.cpp:77
bool isOperandMatcher() const
Definition: Pattern.cpp:44
bool isConstantAttr() const
Definition: Pattern.cpp:58
bool isStringAttr() const
Definition: Pattern.cpp:64
bool isReturnTypeDirective() const
Definition: Pattern.cpp:191
bool isLocationDirective() const
Definition: Pattern.cpp:186
bool isReplaceWithValue() const
Definition: Pattern.cpp:181
DagNode getArgAsNestedDag(unsigned index) const
Definition: Pattern.cpp:168
bool isOperation() const
Definition: Pattern.cpp:121
DagLeaf getArgAsLeaf(unsigned index) const
Definition: Pattern.cpp:172
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:134
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:127
void print(raw_ostream &os) const
Definition: Pattern.cpp:206
int getNumOps() const
Definition: Pattern.cpp:151
Operator & getDialectOp(RecordOperatorMap *mapper) const
Definition: Pattern.cpp:143
bool isVariadic() const
Definition: Pattern.cpp:201
bool isNativeCodeCall() const
Definition: Pattern.cpp:115
bool isEither() const
Definition: Pattern.cpp:196
bool isNestedDagArg(unsigned index) const
Definition: Pattern.cpp:164
StringRef getSymbol() const
Definition: Pattern.cpp:141
int getNumArgs() const
Definition: Pattern.cpp:162
DagNode(const llvm::DagInit *node)
Definition: Pattern.h:143
StringRef getArgName(unsigned index) const
Definition: Pattern.cpp:177
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
int getNumResults() const
Returns the number of results this op produces.
Definition: Operator.cpp:166
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Definition: Operator.cpp:185
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
Definition: Operator.cpp:359
int getBenefit() const
int getNumResultPatterns() const
Definition: Pattern.cpp:628
std::vector< IdentifierLine > getLocation() const
Definition: Pattern.cpp:715
DagNode getSourcePattern() const
Definition: Pattern.cpp:624
const Operator & getSourceRootOp()
Definition: Pattern.cpp:657
std::vector< AppliedConstraint > getConstraints() const
Definition: Pattern.cpp:665
DagNode getResultPattern(unsigned index) const
Definition: Pattern.cpp:633
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Definition: Pattern.cpp:735
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
Definition: Pattern.cpp:661
DagNode getSupplementalPattern(unsigned index) const
Definition: Pattern.cpp:698
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:638
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:648
int getNumSupplementalPatterns() const
Definition: Pattern.cpp:693
std::string getArgDecl(StringRef name) const
Definition: Pattern.cpp:290
std::string getVarName(StringRef name) const
Definition: Pattern.cpp:248
std::string getVarTypeStr(StringRef name) const
Definition: Pattern.cpp:252
std::string getVarDecl(StringRef name) const
Definition: Pattern.cpp:283
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
Definition: Pattern.cpp:215
int count(StringRef key) const
Definition: Pattern.cpp:541
const_iterator find(StringRef key) const
Definition: Pattern.cpp:507
bool bindMultipleValues(StringRef symbol, int numValues)
Definition: Pattern.cpp:491
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
Definition: Pattern.cpp:439
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:572
bool bindValues(StringRef symbol, int numValues=1)
Definition: Pattern.cpp:479
bool bindAttr(StringRef symbol)
Definition: Pattern.cpp:498
bool bindValue(StringRef symbol)
Definition: Pattern.cpp:486
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
Definition: Pattern.cpp:514
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
Definition: Pattern.cpp:535
int getStaticValueCount(StringRef symbol) const
Definition: Pattern.cpp:546
bool contains(StringRef symbol) const
Definition: Pattern.cpp:503
BaseT::const_iterator const_iterator
Definition: Pattern.h:461
bool bindOpResult(StringRef symbol, const Operator &op)
Definition: Pattern.cpp:472
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:557
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:2992
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.