MLIR  21.0.0git
Pattern.cpp
Go to the documentation of this file.
1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
25 
26 using namespace mlir;
27 using namespace tblgen;
28 
29 using llvm::DagInit;
30 using llvm::dbgs;
31 using llvm::DefInit;
32 using llvm::formatv;
33 using llvm::IntInit;
34 using llvm::Record;
35 
36 //===----------------------------------------------------------------------===//
37 // DagLeaf
38 //===----------------------------------------------------------------------===//
39 
40 bool DagLeaf::isUnspecified() const {
41  return isa_and_nonnull<llvm::UnsetInit>(def);
42 }
43 
45  // Operand matchers specify a type constraint.
46  return isSubClassOf("TypeConstraint");
47 }
48 
49 bool DagLeaf::isAttrMatcher() const {
50  // Attribute matchers specify an attribute constraint.
51  return isSubClassOf("AttrConstraint");
52 }
53 
55  return isSubClassOf("NativeCodeCall");
56 }
57 
58 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
59 
60 bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumCase"); }
61 
62 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
63 
65  assert((isOperandMatcher() || isAttrMatcher()) &&
66  "the DAG leaf must be operand or attribute");
67  return Constraint(cast<DefInit>(def)->getDef());
68 }
69 
71  assert(isConstantAttr() && "the DAG leaf must be constant attribute");
72  return ConstantAttr(cast<DefInit>(def));
73 }
74 
76  assert(isEnumCase() && "the DAG leaf must be an enum attribute case");
77  return EnumCase(cast<DefInit>(def));
78 }
79 
80 std::string DagLeaf::getConditionTemplate() const {
82 }
83 
84 StringRef DagLeaf::getNativeCodeTemplate() const {
85  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
86  return cast<DefInit>(def)->getDef()->getValueAsString("expression");
87 }
88 
90  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
91  return cast<DefInit>(def)->getDef()->getValueAsInt("numReturns");
92 }
93 
94 std::string DagLeaf::getStringAttr() const {
95  assert(isStringAttr() && "the DAG leaf must be string attribute");
96  return def->getAsUnquotedString();
97 }
98 bool DagLeaf::isSubClassOf(StringRef superclass) const {
99  if (auto *defInit = dyn_cast_or_null<DefInit>(def))
100  return defInit->getDef()->isSubClassOf(superclass);
101  return false;
102 }
103 
104 void DagLeaf::print(raw_ostream &os) const {
105  if (def)
106  def->print(os);
107 }
108 
109 //===----------------------------------------------------------------------===//
110 // DagNode
111 //===----------------------------------------------------------------------===//
112 
114  if (auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
115  return defInit->getDef()->isSubClassOf("NativeCodeCall");
116  return false;
117 }
118 
119 bool DagNode::isOperation() const {
120  return !isNativeCodeCall() && !isReplaceWithValue() &&
122  !isVariadic();
123 }
124 
126  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
127  return cast<DefInit>(node->getOperator())
128  ->getDef()
129  ->getValueAsString("expression");
130 }
131 
133  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
134  return cast<DefInit>(node->getOperator())
135  ->getDef()
136  ->getValueAsInt("numReturns");
137 }
138 
139 StringRef DagNode::getSymbol() const { return node->getNameStr(); }
140 
142  const Record *opDef = cast<DefInit>(node->getOperator())->getDef();
143  auto [it, inserted] = mapper->try_emplace(opDef);
144  if (inserted)
145  it->second = std::make_unique<Operator>(opDef);
146  return *it->second;
147 }
148 
149 int DagNode::getNumOps() const {
150  // We want to get number of operations recursively involved in the DAG tree.
151  // All other directives should be excluded.
152  int count = isOperation() ? 1 : 0;
153  for (int i = 0, e = getNumArgs(); i != e; ++i) {
154  if (auto child = getArgAsNestedDag(i))
155  count += child.getNumOps();
156  }
157  return count;
158 }
159 
160 int DagNode::getNumArgs() const { return node->getNumArgs(); }
161 
162 bool DagNode::isNestedDagArg(unsigned index) const {
163  return isa<DagInit>(node->getArg(index));
164 }
165 
166 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
167  return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
168 }
169 
170 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
171  assert(!isNestedDagArg(index));
172  return DagLeaf(node->getArg(index));
173 }
174 
175 StringRef DagNode::getArgName(unsigned index) const {
176  return node->getArgNameStr(index);
177 }
178 
180  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
181  return dagOpDef->getName() == "replaceWithValue";
182 }
183 
185  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
186  return dagOpDef->getName() == "location";
187 }
188 
190  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
191  return dagOpDef->getName() == "returnType";
192 }
193 
194 bool DagNode::isEither() const {
195  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
196  return dagOpDef->getName() == "either";
197 }
198 
199 bool DagNode::isVariadic() const {
200  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
201  return dagOpDef->getName() == "variadic";
202 }
203 
204 void DagNode::print(raw_ostream &os) const {
205  if (node)
206  node->print(os);
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // SymbolInfoMap
211 //===----------------------------------------------------------------------===//
212 
213 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
214  int idx = -1;
215  auto [name, indexStr] = symbol.rsplit("__");
216 
217  if (indexStr.consumeInteger(10, idx)) {
218  // The second part is not an index; we return the whole symbol as-is.
219  return symbol;
220  }
221  if (index) {
222  *index = idx;
223  }
224  return name;
225 }
226 
227 SymbolInfoMap::SymbolInfo::SymbolInfo(
228  const Operator *op, SymbolInfo::Kind kind,
229  std::optional<DagAndConstant> dagAndConstant)
230  : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
231 
232 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
233  switch (kind) {
234  case Kind::Attr:
235  case Kind::Operand:
236  case Kind::Value:
237  return 1;
238  case Kind::Result:
239  return op->getNumResults();
240  case Kind::MultipleValues:
241  return getSize();
242  }
243  llvm_unreachable("unknown kind");
244 }
245 
246 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
247  return alternativeName ? *alternativeName : name.str();
248 }
249 
250 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
251  LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': ");
252  switch (kind) {
253  case Kind::Attr: {
254  if (op)
255  return cast<NamedAttribute *>(op->getArg(getArgIndex()))
256  ->attr.getStorageType()
257  .str();
258  // TODO(suderman): Use a more exact type when available.
259  return "::mlir::Attribute";
260  }
261  case Kind::Operand: {
262  // Use operand range for captured operands (to support potential variadic
263  // operands).
264  return "::mlir::Operation::operand_range";
265  }
266  case Kind::Value: {
267  return "::mlir::Value";
268  }
269  case Kind::MultipleValues: {
270  return "::mlir::ValueRange";
271  }
272  case Kind::Result: {
273  // Use the op itself for captured results.
274  return op->getQualCppClassName();
275  }
276  }
277  llvm_unreachable("unknown kind");
278 }
279 
280 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
281  LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': ");
282  std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
283  return std::string(
284  formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
285 }
286 
287 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
288  LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': ");
289  return std::string(
290  formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
291 }
292 
293 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
294  StringRef name, int index, const char *fmt, const char *separator) const {
295  LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': ");
296  switch (kind) {
297  case Kind::Attr: {
298  assert(index < 0);
299  auto repl = formatv(fmt, name);
300  LLVM_DEBUG(dbgs() << repl << " (Attr)\n");
301  return std::string(repl);
302  }
303  case Kind::Operand: {
304  assert(index < 0);
305  auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
306  // If this operand is variadic and this SymbolInfo doesn't have a range
307  // index, then return the full variadic operand_range. Otherwise, return
308  // the value itself.
309  if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
310  auto repl = formatv(fmt, name);
311  LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n");
312  return std::string(repl);
313  }
314  auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
315  LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n");
316  return std::string(repl);
317  }
318  case Kind::Result: {
319  // If `index` is greater than zero, then we are referencing a specific
320  // result of a multi-result op. The result can still be variadic.
321  if (index >= 0) {
322  std::string v =
323  std::string(formatv("{0}.getODSResults({1})", name, index));
324  if (!op->getResult(index).isVariadic())
325  v = std::string(formatv("(*{0}.begin())", v));
326  auto repl = formatv(fmt, v);
327  LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
328  return std::string(repl);
329  }
330 
331  // If this op has no result at all but still we bind a symbol to it, it
332  // means we want to capture the op itself.
333  if (op->getNumResults() == 0) {
334  LLVM_DEBUG(dbgs() << name << " (Op)\n");
335  return formatv(fmt, name);
336  }
337 
338  // We are referencing all results of the multi-result op. A specific result
339  // can either be a value or a range. Then join them with `separator`.
341  values.reserve(op->getNumResults());
342 
343  for (int i = 0, e = op->getNumResults(); i < e; ++i) {
344  std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
345  if (!op->getResult(i).isVariadic()) {
346  v = std::string(formatv("(*{0}.begin())", v));
347  }
348  values.push_back(std::string(formatv(fmt, v)));
349  }
350  auto repl = llvm::join(values, separator);
351  LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
352  return repl;
353  }
354  case Kind::Value: {
355  assert(index < 0);
356  assert(op == nullptr);
357  auto repl = formatv(fmt, name);
358  LLVM_DEBUG(dbgs() << repl << " (Value)\n");
359  return std::string(repl);
360  }
361  case Kind::MultipleValues: {
362  assert(op == nullptr);
363  assert(index < getSize());
364  if (index >= 0) {
365  std::string repl =
366  formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
367  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
368  return repl;
369  }
370  // If it doesn't specify certain element, unpack them all.
371  auto repl =
372  formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
373  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
374  return std::string(repl);
375  }
376  }
377  llvm_unreachable("unknown kind");
378 }
379 
380 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
381  StringRef name, int index, const char *fmt, const char *separator) const {
382  LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': ");
383  switch (kind) {
384  case Kind::Attr:
385  case Kind::Operand: {
386  assert(index < 0 && "only allowed for symbol bound to result");
387  auto repl = formatv(fmt, name);
388  LLVM_DEBUG(dbgs() << repl << " (Operand/Attr)\n");
389  return std::string(repl);
390  }
391  case Kind::Result: {
392  if (index >= 0) {
393  auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
394  LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
395  return std::string(repl);
396  }
397 
398  // We are referencing all results of the multi-result op. Each result should
399  // have a value range, and then join them with `separator`.
401  values.reserve(op->getNumResults());
402 
403  for (int i = 0, e = op->getNumResults(); i < e; ++i) {
404  values.push_back(std::string(
405  formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
406  }
407  auto repl = llvm::join(values, separator);
408  LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
409  return repl;
410  }
411  case Kind::Value: {
412  assert(index < 0 && "only allowed for symbol bound to result");
413  assert(op == nullptr);
414  auto repl = formatv(fmt, formatv("{{{0}}", name));
415  LLVM_DEBUG(dbgs() << repl << " (Value)\n");
416  return std::string(repl);
417  }
418  case Kind::MultipleValues: {
419  assert(op == nullptr);
420  assert(index < getSize());
421  if (index >= 0) {
422  std::string repl =
423  formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
424  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
425  return repl;
426  }
427  auto repl =
428  formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
429  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
430  return std::string(repl);
431  }
432  }
433  llvm_unreachable("unknown kind");
434 }
435 
436 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
437  const Operator &op, int argIndex,
438  std::optional<int> variadicSubIndex) {
439  StringRef name = getValuePackName(symbol);
440  if (name != symbol) {
441  auto error = formatv(
442  "symbol '{0}' with trailing index cannot bind to op argument", symbol);
443  PrintFatalError(loc, error);
444  }
445 
446  auto symInfo =
447  isa<NamedAttribute *>(op.getArg(argIndex))
448  ? SymbolInfo::getAttr(&op, argIndex)
449  : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
450 
451  std::string key = symbol.str();
452  if (symbolInfoMap.count(key)) {
453  // Only non unique name for the operand is supported.
454  if (symInfo.kind != SymbolInfo::Kind::Operand) {
455  return false;
456  }
457 
458  // Cannot add new operand if there is already non operand with the same
459  // name.
460  if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
461  return false;
462  }
463  }
464 
465  symbolInfoMap.emplace(key, symInfo);
466  return true;
467 }
468 
469 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
470  std::string name = getValuePackName(symbol).str();
471  auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
472 
473  return symbolInfoMap.count(inserted->first) == 1;
474 }
475 
476 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
477  std::string name = getValuePackName(symbol).str();
478  if (numValues > 1)
479  return bindMultipleValues(name, numValues);
480  return bindValue(name);
481 }
482 
483 bool SymbolInfoMap::bindValue(StringRef symbol) {
484  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
485  return symbolInfoMap.count(inserted->first) == 1;
486 }
487 
488 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
489  std::string name = getValuePackName(symbol).str();
490  auto inserted =
491  symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
492  return symbolInfoMap.count(inserted->first) == 1;
493 }
494 
495 bool SymbolInfoMap::bindAttr(StringRef symbol) {
496  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
497  return symbolInfoMap.count(inserted->first) == 1;
498 }
499 
500 bool SymbolInfoMap::contains(StringRef symbol) const {
501  return find(symbol) != symbolInfoMap.end();
502 }
503 
505  std::string name = getValuePackName(key).str();
506 
507  return symbolInfoMap.find(name);
508 }
509 
511 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
512  int argIndex,
513  std::optional<int> variadicSubIndex) const {
514  return findBoundSymbol(
515  key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
516 }
517 
520  const SymbolInfo &symbolInfo) const {
521  std::string name = getValuePackName(key).str();
522  auto range = symbolInfoMap.equal_range(name);
523 
524  for (auto it = range.first; it != range.second; ++it)
525  if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
526  return it;
527 
528  return symbolInfoMap.end();
529 }
530 
531 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
533  std::string name = getValuePackName(key).str();
534 
535  return symbolInfoMap.equal_range(name);
536 }
537 
538 int SymbolInfoMap::count(StringRef key) const {
539  std::string name = getValuePackName(key).str();
540  return symbolInfoMap.count(name);
541 }
542 
543 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
544  StringRef name = getValuePackName(symbol);
545  if (name != symbol) {
546  // If there is a trailing index inside symbol, it references just one
547  // static value.
548  return 1;
549  }
550  // Otherwise, find how many it represents by querying the symbol's info.
551  return find(name)->second.getStaticValueCount();
552 }
553 
554 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
555  const char *fmt,
556  const char *separator) const {
557  int index = -1;
558  StringRef name = getValuePackName(symbol, &index);
559 
560  auto it = symbolInfoMap.find(name.str());
561  if (it == symbolInfoMap.end()) {
562  auto error = formatv("referencing unbound symbol '{0}'", symbol);
563  PrintFatalError(loc, error);
564  }
565 
566  return it->second.getValueAndRangeUse(name, index, fmt, separator);
567 }
568 
569 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
570  const char *separator) const {
571  int index = -1;
572  StringRef name = getValuePackName(symbol, &index);
573 
574  auto it = symbolInfoMap.find(name.str());
575  if (it == symbolInfoMap.end()) {
576  auto error = formatv("referencing unbound symbol '{0}'", symbol);
577  PrintFatalError(loc, error);
578  }
579 
580  return it->second.getAllRangeUse(name, index, fmt, separator);
581 }
582 
584  llvm::StringSet<> usedNames;
585 
586  for (auto symbolInfoIt = symbolInfoMap.begin();
587  symbolInfoIt != symbolInfoMap.end();) {
588  auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
589  auto startRange = range.first;
590  auto endRange = range.second;
591 
592  auto operandName = symbolInfoIt->first;
593  int startSearchIndex = 0;
594  for (++startRange; startRange != endRange; ++startRange) {
595  // Current operand name is not unique, find a unique one
596  // and set the alternative name.
597  for (int i = startSearchIndex;; ++i) {
598  std::string alternativeName = operandName + std::to_string(i);
599  if (!usedNames.contains(alternativeName) &&
600  symbolInfoMap.count(alternativeName) == 0) {
601  usedNames.insert(alternativeName);
602  startRange->second.alternativeName = alternativeName;
603  startSearchIndex = i + 1;
604 
605  break;
606  }
607  }
608  }
609 
610  symbolInfoIt = endRange;
611  }
612 }
613 
614 //===----------------------------------------------------------------------===//
615 // Pattern
616 //==----------------------------------------------------------------------===//
617 
618 Pattern::Pattern(const Record *def, RecordOperatorMap *mapper)
619  : def(*def), recordOpMap(mapper) {}
620 
622  return DagNode(def.getValueAsDag("sourcePattern"));
623 }
624 
626  auto *results = def.getValueAsListInit("resultPatterns");
627  return results->size();
628 }
629 
630 DagNode Pattern::getResultPattern(unsigned index) const {
631  auto *results = def.getValueAsListInit("resultPatterns");
632  return DagNode(cast<DagInit>(results->getElement(index)));
633 }
634 
636  LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n");
637  collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
638  LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n");
639 
640  LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n");
642  LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n");
643 }
644 
646  LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n");
647  for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
648  auto pattern = getResultPattern(i);
649  collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
650  }
651  LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n");
652 }
653 
655  return getSourcePattern().getDialectOp(recordOpMap);
656 }
657 
659  return node.getDialectOp(recordOpMap);
660 }
661 
662 std::vector<AppliedConstraint> Pattern::getConstraints() const {
663  auto *listInit = def.getValueAsListInit("constraints");
664  std::vector<AppliedConstraint> ret;
665  ret.reserve(listInit->size());
666 
667  for (auto *it : *listInit) {
668  auto *dagInit = dyn_cast<DagInit>(it);
669  if (!dagInit)
670  PrintFatalError(&def, "all elements in Pattern multi-entity "
671  "constraints should be DAG nodes");
672 
673  std::vector<std::string> entities;
674  entities.reserve(dagInit->arg_size());
675  for (auto *argName : dagInit->getArgNames()) {
676  if (!argName) {
677  PrintFatalError(
678  &def,
679  "operands to additional constraints can only be symbol references");
680  }
681  entities.emplace_back(argName->getValue());
682  }
683 
684  ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
685  dagInit->getNameStr(), std::move(entities));
686  }
687  return ret;
688 }
689 
691  auto *results = def.getValueAsListInit("supplementalPatterns");
692  return results->size();
693 }
694 
696  auto *results = def.getValueAsListInit("supplementalPatterns");
697  return DagNode(cast<DagInit>(results->getElement(index)));
698 }
699 
700 int Pattern::getBenefit() const {
701  // The initial benefit value is a heuristic with number of ops in the source
702  // pattern.
703  int initBenefit = getSourcePattern().getNumOps();
704  const DagInit *delta = def.getValueAsDag("benefitDelta");
705  if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
706  PrintFatalError(&def,
707  "The 'addBenefit' takes and only takes one integer value");
708  }
709  return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
710 }
711 
712 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
713  std::vector<std::pair<StringRef, unsigned>> result;
714  result.reserve(def.getLoc().size());
715  for (auto loc : def.getLoc()) {
716  unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
717  assert(buf && "invalid source location");
718  result.emplace_back(
719  llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
720  llvm::SrcMgr.getLineAndColumn(loc, buf).first);
721  }
722  return result;
723 }
724 
725 void Pattern::verifyBind(bool result, StringRef symbolName) {
726  if (!result) {
727  auto err = formatv("symbol '{0}' bound more than once", symbolName);
728  PrintFatalError(&def, err);
729  }
730 }
731 
733  bool isSrcPattern) {
734  auto treeName = tree.getSymbol();
735  auto numTreeArgs = tree.getNumArgs();
736 
737  if (tree.isNativeCodeCall()) {
738  if (!treeName.empty()) {
739  if (!isSrcPattern) {
740  LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: "
741  << treeName << '\n');
742  verifyBind(
743  infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
744  treeName);
745  } else {
746  PrintFatalError(&def,
747  formatv("binding symbol '{0}' to NativecodeCall in "
748  "MatchPattern is not supported",
749  treeName));
750  }
751  }
752 
753  for (int i = 0; i != numTreeArgs; ++i) {
754  if (auto treeArg = tree.getArgAsNestedDag(i)) {
755  // This DAG node argument is a DAG node itself. Go inside recursively.
756  collectBoundSymbols(treeArg, infoMap, isSrcPattern);
757  continue;
758  }
759 
760  if (!isSrcPattern)
761  continue;
762 
763  // We can only bind symbols to arguments in source pattern. Those
764  // symbols are referenced in result patterns.
765  auto treeArgName = tree.getArgName(i);
766 
767  // `$_` is a special symbol meaning ignore the current argument.
768  if (!treeArgName.empty() && treeArgName != "_") {
769  DagLeaf leaf = tree.getArgAsLeaf(i);
770 
771  // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
772  if (leaf.isUnspecified()) {
773  // This is case of $c, a Value without any constraints.
774  verifyBind(infoMap.bindValue(treeArgName), treeArgName);
775  } else {
776  auto constraint = leaf.getAsConstraint();
777  bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() ||
778  leaf.isConstantAttr() ||
779  constraint.getKind() == Constraint::Kind::CK_Attr;
780 
781  if (isAttr) {
782  // This is case of $a, a binding to a certain attribute.
783  verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
784  continue;
785  }
786 
787  // This is case of $b, a binding to a certain type.
788  verifyBind(infoMap.bindValue(treeArgName), treeArgName);
789  }
790  }
791  }
792 
793  return;
794  }
795 
796  if (tree.isOperation()) {
797  auto &op = getDialectOp(tree);
798  auto numOpArgs = op.getNumArgs();
799  int numEither = 0;
800 
801  // We need to exclude the trailing directives and `either` directive groups
802  // two operands of the operation.
803  int numDirectives = 0;
804  for (int i = numTreeArgs - 1; i >= 0; --i) {
805  if (auto dagArg = tree.getArgAsNestedDag(i)) {
806  if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
807  ++numDirectives;
808  else if (dagArg.isEither())
809  ++numEither;
810  }
811  }
812 
813  if (numOpArgs != numTreeArgs - numDirectives + numEither) {
814  auto err =
815  formatv("op '{0}' argument number mismatch: "
816  "{1} in pattern vs. {2} in definition",
817  op.getOperationName(), numTreeArgs + numEither, numOpArgs);
818  PrintFatalError(&def, err);
819  }
820 
821  // The name attached to the DAG node's operator is for representing the
822  // results generated from this op. It should be remembered as bound results.
823  if (!treeName.empty()) {
824  LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName
825  << '\n');
826  verifyBind(infoMap.bindOpResult(treeName, op), treeName);
827  }
828 
829  // The operand in `either` DAG should be bound to the operation in the
830  // parent DagNode.
831  auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
832  int opArgIdx) {
833  for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
834  if (DagNode subTree = tree.getArgAsNestedDag(i)) {
835  collectBoundSymbols(subTree, infoMap, isSrcPattern);
836  } else {
837  auto argName = tree.getArgName(i);
838  if (!argName.empty() && argName != "_") {
839  verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
840  argName);
841  }
842  }
843  }
844  };
845 
846  // The operand in `variadic` DAG should be bound to the operation in the
847  // parent DagNode. The range index must be included as well to distinguish
848  // (potentially) repeating argName within the `variadic` DAG.
849  auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
850  int opArgIdx) {
851  auto treeName = tree.getSymbol();
852  if (!treeName.empty()) {
853  // If treeName is specified, bind to the full variadic operand_range.
854  verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
855  std::nullopt),
856  treeName);
857  }
858 
859  for (int i = 0; i < tree.getNumArgs(); ++i) {
860  if (DagNode subTree = tree.getArgAsNestedDag(i)) {
861  collectBoundSymbols(subTree, infoMap, isSrcPattern);
862  } else {
863  auto argName = tree.getArgName(i);
864  if (!argName.empty() && argName != "_") {
865  verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
866  /*variadicSubIndex=*/i),
867  argName);
868  }
869  }
870  }
871  };
872 
873  for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
874  if (auto treeArg = tree.getArgAsNestedDag(i)) {
875  if (treeArg.isEither()) {
876  collectSymbolInEither(tree, treeArg, opArgIdx);
877  // `either` DAG is *flattened*. For example,
878  //
879  // (FooOp (either arg0, arg1), arg2)
880  //
881  // can be viewed as:
882  //
883  // (FooOp arg0, arg1, arg2)
884  ++opArgIdx;
885  } else if (treeArg.isVariadic()) {
886  collectSymbolInVariadic(tree, treeArg, opArgIdx);
887  } else {
888  // This DAG node argument is a DAG node itself. Go inside recursively.
889  collectBoundSymbols(treeArg, infoMap, isSrcPattern);
890  }
891  continue;
892  }
893 
894  if (isSrcPattern) {
895  // We can only bind symbols to op arguments in source pattern. Those
896  // symbols are referenced in result patterns.
897  auto treeArgName = tree.getArgName(i);
898  // `$_` is a special symbol meaning ignore the current argument.
899  if (!treeArgName.empty() && treeArgName != "_") {
900  LLVM_DEBUG(dbgs() << "found symbol bound to op argument: "
901  << treeArgName << '\n');
902  verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
903  treeArgName);
904  }
905  }
906  }
907  return;
908  }
909 
910  if (!treeName.empty()) {
911  PrintFatalError(
912  &def, formatv("binding symbol '{0}' to non-operation/native code call "
913  "unsupported right now",
914  treeName));
915  }
916 }
union mlir::linalg::@1183::ArityGroupAndKind::Kind kind
std::string getConditionTemplate() const
Definition: Constraint.cpp:51
Constraint getAsConstraint() const
Definition: Pattern.cpp:64
bool isNativeCodeCall() const
Definition: Pattern.cpp:54
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:89
ConstantAttr getAsConstantAttr() const
Definition: Pattern.cpp:70
void print(raw_ostream &os) const
Definition: Pattern.cpp:104
std::string getStringAttr() const
Definition: Pattern.cpp:94
bool isEnumCase() const
Definition: Pattern.cpp:60
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:84
std::string getConditionTemplate() const
Definition: Pattern.cpp:80
bool isUnspecified() const
Definition: Pattern.cpp:40
EnumCase getAsEnumCase() const
Definition: Pattern.cpp:75
bool isAttrMatcher() const
Definition: Pattern.cpp:49
bool isOperandMatcher() const
Definition: Pattern.cpp:44
bool isConstantAttr() const
Definition: Pattern.cpp:58
bool isStringAttr() const
Definition: Pattern.cpp:62
bool isReturnTypeDirective() const
Definition: Pattern.cpp:189
bool isLocationDirective() const
Definition: Pattern.cpp:184
bool isReplaceWithValue() const
Definition: Pattern.cpp:179
DagNode getArgAsNestedDag(unsigned index) const
Definition: Pattern.cpp:166
bool isOperation() const
Definition: Pattern.cpp:119
DagLeaf getArgAsLeaf(unsigned index) const
Definition: Pattern.cpp:170
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:132
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:125
void print(raw_ostream &os) const
Definition: Pattern.cpp:204
int getNumOps() const
Definition: Pattern.cpp:149
Operator & getDialectOp(RecordOperatorMap *mapper) const
Definition: Pattern.cpp:141
bool isVariadic() const
Definition: Pattern.cpp:199
bool isNativeCodeCall() const
Definition: Pattern.cpp:113
bool isEither() const
Definition: Pattern.cpp:194
bool isNestedDagArg(unsigned index) const
Definition: Pattern.cpp:162
StringRef getSymbol() const
Definition: Pattern.cpp:139
int getNumArgs() const
Definition: Pattern.cpp:160
DagNode(const llvm::DagInit *node)
Definition: Pattern.h:144
StringRef getArgName(unsigned index) const
Definition: Pattern.cpp:175
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
int getNumResults() const
Returns the number of results this op produces.
Definition: Operator.cpp:166
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Definition: Operator.cpp:185
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
Definition: Operator.cpp:359
int getBenefit() const
int getNumResultPatterns() const
Definition: Pattern.cpp:625
std::vector< IdentifierLine > getLocation() const
Definition: Pattern.cpp:712
DagNode getSourcePattern() const
Definition: Pattern.cpp:621
const Operator & getSourceRootOp()
Definition: Pattern.cpp:654
std::vector< AppliedConstraint > getConstraints() const
Definition: Pattern.cpp:662
DagNode getResultPattern(unsigned index) const
Definition: Pattern.cpp:630
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Definition: Pattern.cpp:732
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
Definition: Pattern.cpp:658
DagNode getSupplementalPattern(unsigned index) const
Definition: Pattern.cpp:695
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:635
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:645
int getNumSupplementalPatterns() const
Definition: Pattern.cpp:690
std::string getArgDecl(StringRef name) const
Definition: Pattern.cpp:287
std::string getVarName(StringRef name) const
Definition: Pattern.cpp:246
std::string getVarTypeStr(StringRef name) const
Definition: Pattern.cpp:250
std::string getVarDecl(StringRef name) const
Definition: Pattern.cpp:280
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
Definition: Pattern.cpp:213
int count(StringRef key) const
Definition: Pattern.cpp:538
const_iterator find(StringRef key) const
Definition: Pattern.cpp:504
bool bindMultipleValues(StringRef symbol, int numValues)
Definition: Pattern.cpp:488
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
Definition: Pattern.cpp:436
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:569
bool bindValues(StringRef symbol, int numValues=1)
Definition: Pattern.cpp:476
bool bindAttr(StringRef symbol)
Definition: Pattern.cpp:495
bool bindValue(StringRef symbol)
Definition: Pattern.cpp:483
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
Definition: Pattern.cpp:511
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
Definition: Pattern.cpp:532
int getStaticValueCount(StringRef symbol) const
Definition: Pattern.cpp:543
bool contains(StringRef symbol) const
Definition: Pattern.cpp:500
BaseT::const_iterator const_iterator
Definition: Pattern.h:462
bool bindOpResult(StringRef symbol, const Operator &op)
Definition: Pattern.cpp:469
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:554
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:3216
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.