MLIR  22.0.0git
Pattern.cpp
Go to the documentation of this file.
1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
25 
26 using namespace mlir;
27 using namespace tblgen;
28 
29 using llvm::DagInit;
30 using llvm::dbgs;
31 using llvm::DefInit;
32 using llvm::formatv;
33 using llvm::IntInit;
34 using llvm::Record;
35 
36 //===----------------------------------------------------------------------===//
37 // DagLeaf
38 //===----------------------------------------------------------------------===//
39 
40 bool DagLeaf::isUnspecified() const {
41  return isa_and_nonnull<llvm::UnsetInit>(def);
42 }
43 
45  // Operand matchers specify a type constraint.
46  return isSubClassOf("TypeConstraint");
47 }
48 
49 bool DagLeaf::isAttrMatcher() const {
50  // Attribute matchers specify an attribute constraint.
51  return isSubClassOf("AttrConstraint");
52 }
53 
54 bool DagLeaf::isPropMatcher() const {
55  // Property matchers specify a property constraint.
56  return isSubClassOf("PropConstraint");
57 }
58 
60  // Property matchers specify a property definition.
61  return isSubClassOf("Property");
62 }
63 
65  return isSubClassOf("NativeCodeCall");
66 }
67 
68 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
69 
70 bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumCase"); }
71 
72 bool DagLeaf::isConstantProp() const { return isSubClassOf("ConstantProp"); }
73 
74 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
75 
77  assert((isOperandMatcher() || isAttrMatcher() || isPropMatcher()) &&
78  "the DAG leaf must be operand, attribute, or property");
79  return Constraint(cast<DefInit>(def)->getDef());
80 }
81 
83  assert(isPropMatcher() && "the DAG leaf must be a property matcher");
84  return PropConstraint(cast<DefInit>(def)->getDef());
85 }
86 
88  assert(isPropDefinition() && "the DAG leaf must be a property definition");
89  return Property(cast<DefInit>(def)->getDef());
90 }
91 
93  assert(isConstantAttr() && "the DAG leaf must be constant attribute");
94  return ConstantAttr(cast<DefInit>(def));
95 }
96 
98  assert(isEnumCase() && "the DAG leaf must be an enum attribute case");
99  return EnumCase(cast<DefInit>(def));
100 }
101 
103  assert(isConstantProp() && "the DAG leaf must be a constant property value");
104  return ConstantProp(cast<DefInit>(def));
105 }
106 
107 std::string DagLeaf::getConditionTemplate() const {
109 }
110 
112  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
113  return cast<DefInit>(def)->getDef()->getValueAsString("expression");
114 }
115 
117  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
118  return cast<DefInit>(def)->getDef()->getValueAsInt("numReturns");
119 }
120 
121 std::string DagLeaf::getStringAttr() const {
122  assert(isStringAttr() && "the DAG leaf must be string attribute");
123  return def->getAsUnquotedString();
124 }
125 bool DagLeaf::isSubClassOf(StringRef superclass) const {
126  if (auto *defInit = dyn_cast_or_null<DefInit>(def))
127  return defInit->getDef()->isSubClassOf(superclass);
128  return false;
129 }
130 
131 void DagLeaf::print(raw_ostream &os) const {
132  if (def)
133  def->print(os);
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // DagNode
138 //===----------------------------------------------------------------------===//
139 
141  if (auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
142  return defInit->getDef()->isSubClassOf("NativeCodeCall");
143  return false;
144 }
145 
146 bool DagNode::isOperation() const {
147  return !isNativeCodeCall() && !isReplaceWithValue() &&
149  !isVariadic();
150 }
151 
153  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
154  return cast<DefInit>(node->getOperator())
155  ->getDef()
156  ->getValueAsString("expression");
157 }
158 
160  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
161  return cast<DefInit>(node->getOperator())
162  ->getDef()
163  ->getValueAsInt("numReturns");
164 }
165 
166 StringRef DagNode::getSymbol() const { return node->getNameStr(); }
167 
169  const Record *opDef = cast<DefInit>(node->getOperator())->getDef();
170  auto [it, inserted] = mapper->try_emplace(opDef);
171  if (inserted)
172  it->second = std::make_unique<Operator>(opDef);
173  return *it->second;
174 }
175 
176 int DagNode::getNumOps() const {
177  // We want to get number of operations recursively involved in the DAG tree.
178  // All other directives should be excluded.
179  int count = isOperation() ? 1 : 0;
180  for (int i = 0, e = getNumArgs(); i != e; ++i) {
181  if (auto child = getArgAsNestedDag(i))
182  count += child.getNumOps();
183  }
184  return count;
185 }
186 
187 int DagNode::getNumArgs() const { return node->getNumArgs(); }
188 
189 bool DagNode::isNestedDagArg(unsigned index) const {
190  return isa<DagInit>(node->getArg(index));
191 }
192 
193 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
194  return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
195 }
196 
197 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
198  assert(!isNestedDagArg(index));
199  return DagLeaf(node->getArg(index));
200 }
201 
202 StringRef DagNode::getArgName(unsigned index) const {
203  return node->getArgNameStr(index);
204 }
205 
207  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
208  return dagOpDef->getName() == "replaceWithValue";
209 }
210 
212  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
213  return dagOpDef->getName() == "location";
214 }
215 
217  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
218  return dagOpDef->getName() == "returnType";
219 }
220 
221 bool DagNode::isEither() const {
222  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
223  return dagOpDef->getName() == "either";
224 }
225 
226 bool DagNode::isVariadic() const {
227  auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
228  return dagOpDef->getName() == "variadic";
229 }
230 
231 void DagNode::print(raw_ostream &os) const {
232  if (node)
233  node->print(os);
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // SymbolInfoMap
238 //===----------------------------------------------------------------------===//
239 
240 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
241  int idx = -1;
242  auto [name, indexStr] = symbol.rsplit("__");
243 
244  if (indexStr.consumeInteger(10, idx)) {
245  // The second part is not an index; we return the whole symbol as-is.
246  return symbol;
247  }
248  if (index) {
249  *index = idx;
250  }
251  return name;
252 }
253 
254 SymbolInfoMap::SymbolInfo::SymbolInfo(
255  const Operator *op, SymbolInfo::Kind kind,
256  std::optional<DagAndConstant> dagAndConstant)
257  : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
258 
259 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
260  switch (kind) {
261  case Kind::Attr:
262  case Kind::Prop:
263  case Kind::Operand:
264  case Kind::Value:
265  return 1;
266  case Kind::Result:
267  return op->getNumResults();
268  case Kind::MultipleValues:
269  return getSize();
270  }
271  llvm_unreachable("unknown kind");
272 }
273 
274 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
275  return alternativeName ? *alternativeName : name.str();
276 }
277 
278 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
279  LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': ");
280  switch (kind) {
281  case Kind::Attr: {
282  if (op)
283  return cast<NamedAttribute *>(op->getArg(getArgIndex()))
284  ->attr.getStorageType()
285  .str();
286  // TODO(suderman): Use a more exact type when available.
287  return "::mlir::Attribute";
288  }
289  case Kind::Prop: {
290  if (op)
291  return cast<NamedProperty *>(op->getArg(getArgIndex()))
292  ->prop.getInterfaceType()
293  .str();
294  assert(dagAndConstant && dagAndConstant->dag &&
295  "generic properties must carry their constraint");
296  return reinterpret_cast<const DagLeaf *>(dagAndConstant->dag)
297  ->getAsPropConstraint()
298  .getInterfaceType()
299  .str();
300  }
301  case Kind::Operand: {
302  // Use operand range for captured operands (to support potential variadic
303  // operands).
304  return "::mlir::Operation::operand_range";
305  }
306  case Kind::Value: {
307  return "::mlir::Value";
308  }
309  case Kind::MultipleValues: {
310  return "::mlir::ValueRange";
311  }
312  case Kind::Result: {
313  // Use the op itself for captured results.
314  return op->getQualCppClassName();
315  }
316  }
317  llvm_unreachable("unknown kind");
318 }
319 
320 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
321  LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': ");
322  std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
323  return std::string(
324  formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
325 }
326 
327 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
328  LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': ");
329  return std::string(
330  formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
331 }
332 
333 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
334  StringRef name, int index, const char *fmt, const char *separator) const {
335  LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': ");
336  switch (kind) {
337  case Kind::Attr: {
338  assert(index < 0);
339  auto repl = formatv(fmt, name);
340  LLVM_DEBUG(dbgs() << repl << " (Attr)\n");
341  return std::string(repl);
342  }
343  case Kind::Prop: {
344  assert(index < 0);
345  auto repl = formatv(fmt, name);
346  LLVM_DEBUG(dbgs() << repl << " (Prop)\n");
347  return std::string(repl);
348  }
349  case Kind::Operand: {
350  assert(index < 0);
351  auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
352  if (operand->isOptional()) {
353  auto repl = formatv(
354  fmt, formatv("({0}.empty() ? ::mlir::Value() : *{0}.begin())", name));
355  LLVM_DEBUG(dbgs() << repl << " (OptionalOperand)\n");
356  return std::string(repl);
357  }
358  // If this operand is variadic and this SymbolInfo doesn't have a range
359  // index, then return the full variadic operand_range. Otherwise, return
360  // the value itself.
361  if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
362  auto repl = formatv(fmt, name);
363  LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n");
364  return std::string(repl);
365  }
366  auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
367  LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n");
368  return std::string(repl);
369  }
370  case Kind::Result: {
371  // If `index` is greater than zero, then we are referencing a specific
372  // result of a multi-result op. The result can still be variadic.
373  if (index >= 0) {
374  std::string v =
375  std::string(formatv("{0}.getODSResults({1})", name, index));
376  if (!op->getResult(index).isVariadic())
377  v = std::string(formatv("(*{0}.begin())", v));
378  auto repl = formatv(fmt, v);
379  LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
380  return std::string(repl);
381  }
382 
383  // If this op has no result at all but still we bind a symbol to it, it
384  // means we want to capture the op itself.
385  if (op->getNumResults() == 0) {
386  LLVM_DEBUG(dbgs() << name << " (Op)\n");
387  return formatv(fmt, name);
388  }
389 
390  // We are referencing all results of the multi-result op. A specific result
391  // can either be a value or a range. Then join them with `separator`.
393  values.reserve(op->getNumResults());
394 
395  for (int i = 0, e = op->getNumResults(); i < e; ++i) {
396  std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
397  if (!op->getResult(i).isVariadic()) {
398  v = std::string(formatv("(*{0}.begin())", v));
399  }
400  values.push_back(std::string(formatv(fmt, v)));
401  }
402  auto repl = llvm::join(values, separator);
403  LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
404  return repl;
405  }
406  case Kind::Value: {
407  assert(index < 0);
408  assert(op == nullptr);
409  auto repl = formatv(fmt, name);
410  LLVM_DEBUG(dbgs() << repl << " (Value)\n");
411  return std::string(repl);
412  }
413  case Kind::MultipleValues: {
414  assert(op == nullptr);
415  assert(index < getSize());
416  if (index >= 0) {
417  std::string repl =
418  formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
419  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
420  return repl;
421  }
422  // If it doesn't specify certain element, unpack them all.
423  auto repl =
424  formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
425  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
426  return std::string(repl);
427  }
428  }
429  llvm_unreachable("unknown kind");
430 }
431 
432 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
433  StringRef name, int index, const char *fmt, const char *separator) const {
434  LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': ");
435  switch (kind) {
436  case Kind::Attr:
437  case Kind::Prop:
438  case Kind::Operand: {
439  assert(index < 0 && "only allowed for symbol bound to result");
440  auto repl = formatv(fmt, name);
441  LLVM_DEBUG(dbgs() << repl << " (Operand/Attr/Prop)\n");
442  return std::string(repl);
443  }
444  case Kind::Result: {
445  if (index >= 0) {
446  auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
447  LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
448  return std::string(repl);
449  }
450 
451  // We are referencing all results of the multi-result op. Each result should
452  // have a value range, and then join them with `separator`.
454  values.reserve(op->getNumResults());
455 
456  for (int i = 0, e = op->getNumResults(); i < e; ++i) {
457  values.push_back(std::string(
458  formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
459  }
460  auto repl = llvm::join(values, separator);
461  LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
462  return repl;
463  }
464  case Kind::Value: {
465  assert(index < 0 && "only allowed for symbol bound to result");
466  assert(op == nullptr);
467  auto repl = formatv(fmt, formatv("{{{0}}", name));
468  LLVM_DEBUG(dbgs() << repl << " (Value)\n");
469  return std::string(repl);
470  }
471  case Kind::MultipleValues: {
472  assert(op == nullptr);
473  assert(index < getSize());
474  if (index >= 0) {
475  std::string repl =
476  formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
477  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
478  return repl;
479  }
480  auto repl =
481  formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
482  LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
483  return std::string(repl);
484  }
485  }
486  llvm_unreachable("unknown kind");
487 }
488 
489 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
490  const Operator &op, int argIndex,
491  std::optional<int> variadicSubIndex) {
492  StringRef name = getValuePackName(symbol);
493  if (name != symbol) {
494  auto error = formatv(
495  "symbol '{0}' with trailing index cannot bind to op argument", symbol);
496  PrintFatalError(loc, error);
497  }
498 
499  Argument arg = op.getArg(argIndex);
500  SymbolInfo symInfo =
501  isa<NamedAttribute *>(arg) ? SymbolInfo::getAttr(&op, argIndex)
502  : isa<NamedProperty *>(arg)
503  ? SymbolInfo::getProp(&op, argIndex)
504  : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
505 
506  std::string key = symbol.str();
507  if (symbolInfoMap.count(key)) {
508  // Only non unique name for the operand is supported.
509  if (symInfo.kind != SymbolInfo::Kind::Operand) {
510  return false;
511  }
512 
513  // Cannot add new operand if there is already non operand with the same
514  // name.
515  if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
516  return false;
517  }
518  }
519 
520  symbolInfoMap.emplace(key, symInfo);
521  return true;
522 }
523 
524 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
525  std::string name = getValuePackName(symbol).str();
526  auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
527 
528  return symbolInfoMap.count(inserted->first) == 1;
529 }
530 
531 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
532  std::string name = getValuePackName(symbol).str();
533  if (numValues > 1)
534  return bindMultipleValues(name, numValues);
535  return bindValue(name);
536 }
537 
538 bool SymbolInfoMap::bindValue(StringRef symbol) {
539  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
540  return symbolInfoMap.count(inserted->first) == 1;
541 }
542 
543 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
544  std::string name = getValuePackName(symbol).str();
545  auto inserted =
546  symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
547  return symbolInfoMap.count(inserted->first) == 1;
548 }
549 
550 bool SymbolInfoMap::bindAttr(StringRef symbol) {
551  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
552  return symbolInfoMap.count(inserted->first) == 1;
553 }
554 
555 bool SymbolInfoMap::bindProp(StringRef symbol,
556  const PropConstraint &constraint) {
557  auto inserted =
558  symbolInfoMap.emplace(symbol.str(), SymbolInfo::getProp(&constraint));
559  return symbolInfoMap.count(inserted->first) == 1;
560 }
561 
562 bool SymbolInfoMap::contains(StringRef symbol) const {
563  return find(symbol) != symbolInfoMap.end();
564 }
565 
567  std::string name = getValuePackName(key).str();
568 
569  return symbolInfoMap.find(name);
570 }
571 
573 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
574  int argIndex,
575  std::optional<int> variadicSubIndex) const {
576  return findBoundSymbol(
577  key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
578 }
579 
582  const SymbolInfo &symbolInfo) const {
583  std::string name = getValuePackName(key).str();
584  auto range = symbolInfoMap.equal_range(name);
585 
586  for (auto it = range.first; it != range.second; ++it)
587  if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
588  return it;
589 
590  return symbolInfoMap.end();
591 }
592 
593 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
595  std::string name = getValuePackName(key).str();
596 
597  return symbolInfoMap.equal_range(name);
598 }
599 
600 int SymbolInfoMap::count(StringRef key) const {
601  std::string name = getValuePackName(key).str();
602  return symbolInfoMap.count(name);
603 }
604 
605 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
606  StringRef name = getValuePackName(symbol);
607  if (name != symbol) {
608  // If there is a trailing index inside symbol, it references just one
609  // static value.
610  return 1;
611  }
612  // Otherwise, find how many it represents by querying the symbol's info.
613  return find(name)->second.getStaticValueCount();
614 }
615 
616 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
617  const char *fmt,
618  const char *separator) const {
619  int index = -1;
620  StringRef name = getValuePackName(symbol, &index);
621 
622  auto it = symbolInfoMap.find(name.str());
623  if (it == symbolInfoMap.end()) {
624  auto error = formatv("referencing unbound symbol '{0}'", symbol);
625  PrintFatalError(loc, error);
626  }
627 
628  return it->second.getValueAndRangeUse(name, index, fmt, separator);
629 }
630 
631 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
632  const char *separator) const {
633  int index = -1;
634  StringRef name = getValuePackName(symbol, &index);
635 
636  auto it = symbolInfoMap.find(name.str());
637  if (it == symbolInfoMap.end()) {
638  auto error = formatv("referencing unbound symbol '{0}'", symbol);
639  PrintFatalError(loc, error);
640  }
641 
642  return it->second.getAllRangeUse(name, index, fmt, separator);
643 }
644 
646  llvm::StringSet<> usedNames;
647 
648  for (auto symbolInfoIt = symbolInfoMap.begin();
649  symbolInfoIt != symbolInfoMap.end();) {
650  auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
651  auto startRange = range.first;
652  auto endRange = range.second;
653 
654  auto operandName = symbolInfoIt->first;
655  int startSearchIndex = 0;
656  for (++startRange; startRange != endRange; ++startRange) {
657  // Current operand name is not unique, find a unique one
658  // and set the alternative name.
659  for (int i = startSearchIndex;; ++i) {
660  std::string alternativeName = operandName + std::to_string(i);
661  if (!usedNames.contains(alternativeName) &&
662  symbolInfoMap.count(alternativeName) == 0) {
663  usedNames.insert(alternativeName);
664  startRange->second.alternativeName = alternativeName;
665  startSearchIndex = i + 1;
666 
667  break;
668  }
669  }
670  }
671 
672  symbolInfoIt = endRange;
673  }
674 }
675 
676 //===----------------------------------------------------------------------===//
677 // Pattern
678 //===----------------------------------------------------------------------===//
679 
680 Pattern::Pattern(const Record *def, RecordOperatorMap *mapper)
681  : def(*def), recordOpMap(mapper) {}
682 
684  return DagNode(def.getValueAsDag("sourcePattern"));
685 }
686 
688  auto *results = def.getValueAsListInit("resultPatterns");
689  return results->size();
690 }
691 
692 DagNode Pattern::getResultPattern(unsigned index) const {
693  auto *results = def.getValueAsListInit("resultPatterns");
694  return DagNode(cast<DagInit>(results->getElement(index)));
695 }
696 
698  LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n");
699  collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
700  LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n");
701 
702  LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n");
704  LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n");
705 }
706 
708  LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n");
709  for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
710  auto pattern = getResultPattern(i);
711  collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
712  }
713  LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n");
714 }
715 
717  return getSourcePattern().getDialectOp(recordOpMap);
718 }
719 
721  return node.getDialectOp(recordOpMap);
722 }
723 
724 std::vector<AppliedConstraint> Pattern::getConstraints() const {
725  auto *listInit = def.getValueAsListInit("constraints");
726  std::vector<AppliedConstraint> ret;
727  ret.reserve(listInit->size());
728 
729  for (auto *it : *listInit) {
730  auto *dagInit = dyn_cast<DagInit>(it);
731  if (!dagInit)
732  PrintFatalError(&def, "all elements in Pattern multi-entity "
733  "constraints should be DAG nodes");
734 
735  std::vector<std::string> entities;
736  entities.reserve(dagInit->arg_size());
737  for (auto *argName : dagInit->getArgNames()) {
738  if (!argName) {
739  PrintFatalError(
740  &def,
741  "operands to additional constraints can only be symbol references");
742  }
743  entities.emplace_back(argName->getValue());
744  }
745 
746  ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
747  dagInit->getNameStr(), std::move(entities));
748  }
749  return ret;
750 }
751 
753  auto *results = def.getValueAsListInit("supplementalPatterns");
754  return results->size();
755 }
756 
758  auto *results = def.getValueAsListInit("supplementalPatterns");
759  return DagNode(cast<DagInit>(results->getElement(index)));
760 }
761 
762 int Pattern::getBenefit() const {
763  // The initial benefit value is a heuristic with number of ops in the source
764  // pattern.
765  int initBenefit = getSourcePattern().getNumOps();
766  const DagInit *delta = def.getValueAsDag("benefitDelta");
767  if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
768  PrintFatalError(&def,
769  "The 'addBenefit' takes and only takes one integer value");
770  }
771  return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
772 }
773 
774 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
775  std::vector<std::pair<StringRef, unsigned>> result;
776  result.reserve(def.getLoc().size());
777  for (auto loc : def.getLoc()) {
778  unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
779  assert(buf && "invalid source location");
780  result.emplace_back(
781  llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
782  llvm::SrcMgr.getLineAndColumn(loc, buf).first);
783  }
784  return result;
785 }
786 
787 void Pattern::verifyBind(bool result, StringRef symbolName) {
788  if (!result) {
789  auto err = formatv("symbol '{0}' bound more than once", symbolName);
790  PrintFatalError(&def, err);
791  }
792 }
793 
795  bool isSrcPattern) {
796  auto treeName = tree.getSymbol();
797  auto numTreeArgs = tree.getNumArgs();
798 
799  if (tree.isNativeCodeCall()) {
800  if (!treeName.empty()) {
801  if (!isSrcPattern) {
802  LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: "
803  << treeName << '\n');
804  verifyBind(
805  infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
806  treeName);
807  } else {
808  PrintFatalError(&def,
809  formatv("binding symbol '{0}' to NativecodeCall in "
810  "MatchPattern is not supported",
811  treeName));
812  }
813  }
814 
815  for (int i = 0; i != numTreeArgs; ++i) {
816  if (auto treeArg = tree.getArgAsNestedDag(i)) {
817  // This DAG node argument is a DAG node itself. Go inside recursively.
818  collectBoundSymbols(treeArg, infoMap, isSrcPattern);
819  continue;
820  }
821 
822  if (!isSrcPattern)
823  continue;
824 
825  // We can only bind symbols to arguments in source pattern. Those
826  // symbols are referenced in result patterns.
827  auto treeArgName = tree.getArgName(i);
828 
829  // `$_` is a special symbol meaning ignore the current argument.
830  if (!treeArgName.empty() && treeArgName != "_") {
831  DagLeaf leaf = tree.getArgAsLeaf(i);
832 
833  // In (NativeCodeCall<"Foo($_self, $0, $1, $2, $3)"> I8Attr:$a, I8:$b,
834  // $c, I8Prop:$d),
835  if (leaf.isUnspecified()) {
836  // This is case of $c, a Value without any constraints.
837  verifyBind(infoMap.bindValue(treeArgName), treeArgName);
838  } else if (leaf.isPropMatcher()) {
839  // This is case of $d, a binding to a certain property.
840  auto propConstraint = leaf.getAsPropConstraint();
841  if (propConstraint.getInterfaceType().empty()) {
842  PrintFatalError(&def,
843  formatv("binding symbol '{0}' in NativeCodeCall to "
844  "a property constraint without specifying "
845  "that constraint's type is unsupported",
846  treeArgName));
847  }
848  verifyBind(infoMap.bindProp(treeArgName, propConstraint),
849  treeArgName);
850  } else {
851  auto constraint = leaf.getAsConstraint();
852  bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() ||
853  leaf.isConstantAttr() ||
854  constraint.getKind() == Constraint::Kind::CK_Attr;
855 
856  if (isAttr) {
857  // This is case of $a, a binding to a certain attribute.
858  verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
859  continue;
860  }
861 
862  // This is case of $b, a binding to a certain type.
863  verifyBind(infoMap.bindValue(treeArgName), treeArgName);
864  }
865  }
866  }
867 
868  return;
869  }
870 
871  if (tree.isOperation()) {
872  auto &op = getDialectOp(tree);
873  auto numOpArgs = op.getNumArgs();
874  int numEither = 0;
875 
876  // We need to exclude the trailing directives and `either` directive groups
877  // two operands of the operation.
878  int numDirectives = 0;
879  for (int i = numTreeArgs - 1; i >= 0; --i) {
880  if (auto dagArg = tree.getArgAsNestedDag(i)) {
881  if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
882  ++numDirectives;
883  else if (dagArg.isEither())
884  ++numEither;
885  }
886  }
887 
888  if (numOpArgs != numTreeArgs - numDirectives + numEither) {
889  auto err =
890  formatv("op '{0}' argument number mismatch: "
891  "{1} in pattern vs. {2} in definition",
892  op.getOperationName(), numTreeArgs + numEither, numOpArgs);
893  PrintFatalError(&def, err);
894  }
895 
896  // The name attached to the DAG node's operator is for representing the
897  // results generated from this op. It should be remembered as bound results.
898  if (!treeName.empty()) {
899  LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName
900  << '\n');
901  verifyBind(infoMap.bindOpResult(treeName, op), treeName);
902  }
903 
904  // The operand in `either` DAG should be bound to the operation in the
905  // parent DagNode.
906  auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
907  int opArgIdx) {
908  for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
909  if (DagNode subTree = tree.getArgAsNestedDag(i)) {
910  collectBoundSymbols(subTree, infoMap, isSrcPattern);
911  } else {
912  auto argName = tree.getArgName(i);
913  if (!argName.empty() && argName != "_") {
914  verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
915  argName);
916  }
917  }
918  }
919  };
920 
921  // The operand in `variadic` DAG should be bound to the operation in the
922  // parent DagNode. The range index must be included as well to distinguish
923  // (potentially) repeating argName within the `variadic` DAG.
924  auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
925  int opArgIdx) {
926  auto treeName = tree.getSymbol();
927  if (!treeName.empty()) {
928  // If treeName is specified, bind to the full variadic operand_range.
929  verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
930  std::nullopt),
931  treeName);
932  }
933 
934  for (int i = 0; i < tree.getNumArgs(); ++i) {
935  if (DagNode subTree = tree.getArgAsNestedDag(i)) {
936  collectBoundSymbols(subTree, infoMap, isSrcPattern);
937  } else {
938  auto argName = tree.getArgName(i);
939  if (!argName.empty() && argName != "_") {
940  verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
941  /*variadicSubIndex=*/i),
942  argName);
943  }
944  }
945  }
946  };
947 
948  for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
949  if (auto treeArg = tree.getArgAsNestedDag(i)) {
950  if (treeArg.isEither()) {
951  collectSymbolInEither(tree, treeArg, opArgIdx);
952  // `either` DAG is *flattened*. For example,
953  //
954  // (FooOp (either arg0, arg1), arg2)
955  //
956  // can be viewed as:
957  //
958  // (FooOp arg0, arg1, arg2)
959  ++opArgIdx;
960  } else if (treeArg.isVariadic()) {
961  collectSymbolInVariadic(tree, treeArg, opArgIdx);
962  } else {
963  // This DAG node argument is a DAG node itself. Go inside recursively.
964  collectBoundSymbols(treeArg, infoMap, isSrcPattern);
965  }
966  continue;
967  }
968 
969  if (isSrcPattern) {
970  // We can only bind symbols to op arguments in source pattern. Those
971  // symbols are referenced in result patterns.
972  auto treeArgName = tree.getArgName(i);
973  // `$_` is a special symbol meaning ignore the current argument.
974  if (!treeArgName.empty() && treeArgName != "_") {
975  LLVM_DEBUG(dbgs() << "found symbol bound to op argument: "
976  << treeArgName << '\n');
977  verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
978  treeArgName);
979  }
980  }
981  }
982  return;
983  }
984 
985  if (!treeName.empty()) {
986  PrintFatalError(
987  &def, formatv("binding symbol '{0}' to non-operation/native code call "
988  "unsupported right now",
989  treeName));
990  }
991 }
union mlir::linalg::@1223::ArityGroupAndKind::Kind kind
std::string getConditionTemplate() const
Definition: Constraint.cpp:53
Constraint getAsConstraint() const
Definition: Pattern.cpp:76
bool isNativeCodeCall() const
Definition: Pattern.cpp:64
bool isPropMatcher() const
Definition: Pattern.cpp:54
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:116
ConstantAttr getAsConstantAttr() const
Definition: Pattern.cpp:92
void print(raw_ostream &os) const
Definition: Pattern.cpp:131
std::string getStringAttr() const
Definition: Pattern.cpp:121
Property getAsProperty() const
Definition: Pattern.cpp:87
bool isEnumCase() const
Definition: Pattern.cpp:70
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:111
std::string getConditionTemplate() const
Definition: Pattern.cpp:107
bool isConstantProp() const
Definition: Pattern.cpp:72
PropConstraint getAsPropConstraint() const
Definition: Pattern.cpp:82
ConstantProp getAsConstantProp() const
Definition: Pattern.cpp:102
bool isUnspecified() const
Definition: Pattern.cpp:40
EnumCase getAsEnumCase() const
Definition: Pattern.cpp:97
bool isAttrMatcher() const
Definition: Pattern.cpp:49
bool isOperandMatcher() const
Definition: Pattern.cpp:44
bool isPropDefinition() const
Definition: Pattern.cpp:59
bool isConstantAttr() const
Definition: Pattern.cpp:68
bool isStringAttr() const
Definition: Pattern.cpp:74
bool isReturnTypeDirective() const
Definition: Pattern.cpp:216
bool isLocationDirective() const
Definition: Pattern.cpp:211
bool isReplaceWithValue() const
Definition: Pattern.cpp:206
DagNode getArgAsNestedDag(unsigned index) const
Definition: Pattern.cpp:193
bool isOperation() const
Definition: Pattern.cpp:146
DagLeaf getArgAsLeaf(unsigned index) const
Definition: Pattern.cpp:197
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:159
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:152
void print(raw_ostream &os) const
Definition: Pattern.cpp:231
int getNumOps() const
Definition: Pattern.cpp:176
Operator & getDialectOp(RecordOperatorMap *mapper) const
Definition: Pattern.cpp:168
bool isVariadic() const
Definition: Pattern.cpp:226
bool isNativeCodeCall() const
Definition: Pattern.cpp:140
bool isEither() const
Definition: Pattern.cpp:221
bool isNestedDagArg(unsigned index) const
Definition: Pattern.cpp:189
StringRef getSymbol() const
Definition: Pattern.cpp:166
int getNumArgs() const
Definition: Pattern.cpp:187
DagNode(const llvm::DagInit *node)
Definition: Pattern.h:165
StringRef getArgName(unsigned index) const
Definition: Pattern.cpp:202
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
int getNumResults() const
Returns the number of results this op produces.
Definition: Operator.cpp:164
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Definition: Operator.cpp:183
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
Definition: Operator.cpp:357
int getBenefit() const
int getNumResultPatterns() const
Definition: Pattern.cpp:687
std::vector< IdentifierLine > getLocation() const
Definition: Pattern.cpp:774
DagNode getSourcePattern() const
Definition: Pattern.cpp:683
const Operator & getSourceRootOp()
Definition: Pattern.cpp:716
std::vector< AppliedConstraint > getConstraints() const
Definition: Pattern.cpp:724
DagNode getResultPattern(unsigned index) const
Definition: Pattern.cpp:692
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Definition: Pattern.cpp:794
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
Definition: Pattern.cpp:720
DagNode getSupplementalPattern(unsigned index) const
Definition: Pattern.cpp:757
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:697
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:707
int getNumSupplementalPatterns() const
Definition: Pattern.cpp:752
std::string getArgDecl(StringRef name) const
Definition: Pattern.cpp:327
std::string getVarName(StringRef name) const
Definition: Pattern.cpp:274
std::string getVarTypeStr(StringRef name) const
Definition: Pattern.cpp:278
std::string getVarDecl(StringRef name) const
Definition: Pattern.cpp:320
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
Definition: Pattern.cpp:240
int count(StringRef key) const
Definition: Pattern.cpp:600
const_iterator find(StringRef key) const
Definition: Pattern.cpp:566
bool bindMultipleValues(StringRef symbol, int numValues)
Definition: Pattern.cpp:543
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
Definition: Pattern.cpp:489
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:631
bool bindValues(StringRef symbol, int numValues=1)
Definition: Pattern.cpp:531
bool bindAttr(StringRef symbol)
Definition: Pattern.cpp:550
bool bindProp(StringRef symbol, const PropConstraint &constraint)
Definition: Pattern.cpp:555
bool bindValue(StringRef symbol)
Definition: Pattern.cpp:538
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
Definition: Pattern.cpp:573
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
Definition: Pattern.cpp:594
int getStaticValueCount(StringRef symbol) const
Definition: Pattern.cpp:605
bool contains(StringRef symbol) const
Definition: Pattern.cpp:562
BaseT::const_iterator const_iterator
Definition: Pattern.h:504
bool bindOpResult(StringRef symbol, const Operator &op)
Definition: Pattern.cpp:524
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:616
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:4092
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.