MLIR 22.0.0git
Pattern.cpp
Go to the documentation of this file.
1//===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Pattern wrapper class to simplify using TableGen Record defining a MLIR
10// Pattern.
11//
12//===----------------------------------------------------------------------===//
13
14#include <utility>
15
17#include "llvm/ADT/StringExtras.h"
18#include "llvm/ADT/Twine.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/TableGen/Error.h"
22#include "llvm/TableGen/Record.h"
23
24#define DEBUG_TYPE "mlir-tblgen-pattern"
25
26using namespace mlir;
27using namespace tblgen;
28
29using llvm::DagInit;
30using llvm::dbgs;
31using llvm::DefInit;
32using llvm::formatv;
33using llvm::IntInit;
34using llvm::Record;
35
36//===----------------------------------------------------------------------===//
37// DagLeaf
38//===----------------------------------------------------------------------===//
39
41 return isa_and_nonnull<llvm::UnsetInit>(def);
42}
43
45 // Operand matchers specify a type constraint.
46 return isSubClassOf("TypeConstraint");
47}
48
50 // Attribute matchers specify an attribute constraint.
51 return isSubClassOf("AttrConstraint");
52}
53
55 // Property matchers specify a property constraint.
56 return isSubClassOf("PropConstraint");
57}
58
60 // Property matchers specify a property definition.
61 return isSubClassOf("Property");
62}
63
65 return isSubClassOf("NativeCodeCall");
66}
67
68bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
69
70bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumCase"); }
71
72bool DagLeaf::isConstantProp() const { return isSubClassOf("ConstantProp"); }
73
74bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
75
77 assert((isOperandMatcher() || isAttrMatcher() || isPropMatcher()) &&
78 "the DAG leaf must be operand, attribute, or property");
79 return Constraint(cast<DefInit>(def)->getDef());
80}
81
83 assert(isPropMatcher() && "the DAG leaf must be a property matcher");
84 return PropConstraint(cast<DefInit>(def)->getDef());
85}
86
88 assert(isPropDefinition() && "the DAG leaf must be a property definition");
89 return Property(cast<DefInit>(def)->getDef());
90}
91
93 assert(isConstantAttr() && "the DAG leaf must be constant attribute");
94 return ConstantAttr(cast<DefInit>(def));
95}
96
98 assert(isEnumCase() && "the DAG leaf must be an enum attribute case");
99 return EnumCase(cast<DefInit>(def));
100}
101
103 assert(isConstantProp() && "the DAG leaf must be a constant property value");
104 return ConstantProp(cast<DefInit>(def));
105}
106
107std::string DagLeaf::getConditionTemplate() const {
109}
110
112 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
113 return cast<DefInit>(def)->getDef()->getValueAsString("expression");
114}
115
117 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
118 return cast<DefInit>(def)->getDef()->getValueAsInt("numReturns");
119}
120
121std::string DagLeaf::getStringAttr() const {
122 assert(isStringAttr() && "the DAG leaf must be string attribute");
123 return def->getAsUnquotedString();
124}
125bool DagLeaf::isSubClassOf(StringRef superclass) const {
126 if (auto *defInit = dyn_cast_or_null<DefInit>(def))
127 return defInit->getDef()->isSubClassOf(superclass);
128 return false;
129}
130
132 if (def)
133 def->print(os);
134}
135
136//===----------------------------------------------------------------------===//
137// DagNode
138//===----------------------------------------------------------------------===//
139
141 if (auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
142 return defInit->getDef()->isSubClassOf("NativeCodeCall");
143 return false;
144}
145
147 return !isNativeCodeCall() && !isReplaceWithValue() &&
149 !isVariadic();
150}
151
153 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
154 return cast<DefInit>(node->getOperator())
155 ->getDef()
156 ->getValueAsString("expression");
157}
158
160 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
161 return cast<DefInit>(node->getOperator())
162 ->getDef()
163 ->getValueAsInt("numReturns");
164}
165
166StringRef DagNode::getSymbol() const { return node->getNameStr(); }
167
169 const Record *opDef = cast<DefInit>(node->getOperator())->getDef();
170 auto [it, inserted] = mapper->try_emplace(opDef);
171 if (inserted)
172 it->second = std::make_unique<Operator>(opDef);
173 return *it->second;
174}
175
177 // We want to get number of operations recursively involved in the DAG tree.
178 // All other directives should be excluded.
179 int count = isOperation() ? 1 : 0;
180 for (int i = 0, e = getNumArgs(); i != e; ++i) {
181 if (auto child = getArgAsNestedDag(i))
182 count += child.getNumOps();
183 }
184 return count;
185}
186
187int DagNode::getNumArgs() const { return node->getNumArgs(); }
188
189bool DagNode::isNestedDagArg(unsigned index) const {
190 return isa<DagInit>(node->getArg(index));
191}
192
194 return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
195}
196
198 assert(!isNestedDagArg(index));
199 return DagLeaf(node->getArg(index));
200}
201
202StringRef DagNode::getArgName(unsigned index) const {
203 return node->getArgNameStr(index);
204}
205
207 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
208 return dagOpDef->getName() == "replaceWithValue";
209}
210
212 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
213 return dagOpDef->getName() == "location";
214}
215
217 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
218 return dagOpDef->getName() == "returnType";
219}
220
221bool DagNode::isEither() const {
222 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
223 return dagOpDef->getName() == "either";
224}
225
227 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
228 return dagOpDef->getName() == "variadic";
229}
230
232 if (node)
233 node->print(os);
234}
235
236//===----------------------------------------------------------------------===//
237// SymbolInfoMap
238//===----------------------------------------------------------------------===//
239
240StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
241 int idx = -1;
242 auto [name, indexStr] = symbol.rsplit("__");
243
244 if (indexStr.consumeInteger(10, idx)) {
245 // The second part is not an index; we return the whole symbol as-is.
246 return symbol;
247 }
248 if (index) {
249 *index = idx;
250 }
251 return name;
252}
253
254SymbolInfoMap::SymbolInfo::SymbolInfo(
255 const Operator *op, SymbolInfo::Kind kind,
256 std::optional<DagAndConstant> dagAndConstant)
257 : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
258
259int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
260 switch (kind) {
261 case Kind::Attr:
262 case Kind::Prop:
263 case Kind::Operand:
264 case Kind::Value:
265 return 1;
266 case Kind::Result:
267 return op->getNumResults();
268 case Kind::MultipleValues:
269 return getSize();
270 }
271 llvm_unreachable("unknown kind");
272}
273
274std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
275 return alternativeName ? *alternativeName : name.str();
276}
277
278std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
279 LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': ");
280 switch (kind) {
281 case Kind::Attr: {
282 if (op)
283 return cast<NamedAttribute *>(op->getArg(getArgIndex()))
284 ->attr.getStorageType()
285 .str();
286 // TODO(suderman): Use a more exact type when available.
287 return "::mlir::Attribute";
288 }
289 case Kind::Prop: {
290 if (op)
291 return cast<NamedProperty *>(op->getArg(getArgIndex()))
292 ->prop.getInterfaceType()
293 .str();
294 assert(dagAndConstant && dagAndConstant->dag &&
295 "generic properties must carry their constraint");
296 return reinterpret_cast<const DagLeaf *>(dagAndConstant->dag)
297 ->getAsPropConstraint()
298 .getInterfaceType()
299 .str();
300 }
301 case Kind::Operand: {
302 // Use operand range for captured operands (to support potential variadic
303 // operands).
304 return "::mlir::Operation::operand_range";
305 }
306 case Kind::Value: {
307 return "::mlir::Value";
308 }
309 case Kind::MultipleValues: {
310 return "::mlir::ValueRange";
311 }
312 case Kind::Result: {
313 // Use the op itself for captured results.
314 return op->getQualCppClassName();
315 }
316 }
317 llvm_unreachable("unknown kind");
318}
319
320std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
321 LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': ");
322 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
323 return std::string(
324 formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
325}
326
327std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
328 LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': ");
329 return std::string(
330 formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
331}
332
333std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
334 StringRef name, int index, const char *fmt, const char *separator) const {
335 LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': ");
336 switch (kind) {
337 case Kind::Attr: {
338 assert(index < 0);
339 auto repl = formatv(fmt, name);
340 LLVM_DEBUG(dbgs() << repl << " (Attr)\n");
341 return std::string(repl);
342 }
343 case Kind::Prop: {
344 assert(index < 0);
345 auto repl = formatv(fmt, name);
346 LLVM_DEBUG(dbgs() << repl << " (Prop)\n");
347 return std::string(repl);
348 }
349 case Kind::Operand: {
350 assert(index < 0);
351 auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
352 if (operand->isOptional()) {
353 auto repl = formatv(
354 fmt, formatv("({0}.empty() ? ::mlir::Value() : *{0}.begin())", name));
355 LLVM_DEBUG(dbgs() << repl << " (OptionalOperand)\n");
356 return std::string(repl);
357 }
358 // If this operand is variadic and this SymbolInfo doesn't have a range
359 // index, then return the full variadic operand_range. Otherwise, return
360 // the value itself.
361 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
362 auto repl = formatv(fmt, name);
363 LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n");
364 return std::string(repl);
365 }
366 auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
367 LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n");
368 return std::string(repl);
369 }
370 case Kind::Result: {
371 // If `index` is greater than zero, then we are referencing a specific
372 // result of a multi-result op. The result can still be variadic.
373 if (index >= 0) {
374 std::string v =
375 std::string(formatv("{0}.getODSResults({1})", name, index));
376 if (!op->getResult(index).isVariadic())
377 v = std::string(formatv("(*{0}.begin())", v));
378 auto repl = formatv(fmt, v);
379 LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
380 return std::string(repl);
381 }
382
383 // If this op has no result at all but still we bind a symbol to it, it
384 // means we want to capture the op itself.
385 if (op->getNumResults() == 0) {
386 LLVM_DEBUG(dbgs() << name << " (Op)\n");
387 return formatv(fmt, name);
388 }
389
390 // We are referencing all results of the multi-result op. A specific result
391 // can either be a value or a range. Then join them with `separator`.
392 SmallVector<std::string, 4> values;
393 values.reserve(op->getNumResults());
394
395 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
396 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
397 if (!op->getResult(i).isVariadic()) {
398 v = std::string(formatv("(*{0}.begin())", v));
399 }
400 values.push_back(std::string(formatv(fmt, v)));
401 }
402 auto repl = llvm::join(values, separator);
403 LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
404 return repl;
405 }
406 case Kind::Value: {
407 assert(index < 0);
408 assert(op == nullptr);
409 auto repl = formatv(fmt, name);
410 LLVM_DEBUG(dbgs() << repl << " (Value)\n");
411 return std::string(repl);
412 }
413 case Kind::MultipleValues: {
414 assert(op == nullptr);
415 assert(index < getSize());
416 if (index >= 0) {
417 std::string repl =
418 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
419 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
420 return repl;
421 }
422 // If it doesn't specify certain element, unpack them all.
423 auto repl =
424 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
425 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
426 return std::string(repl);
427 }
428 }
429 llvm_unreachable("unknown kind");
430}
431
432std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
433 StringRef name, int index, const char *fmt, const char *separator) const {
434 LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': ");
435 switch (kind) {
436 case Kind::Attr:
437 case Kind::Prop:
438 case Kind::Operand: {
439 assert(index < 0 && "only allowed for symbol bound to result");
440 auto repl = formatv(fmt, name);
441 LLVM_DEBUG(dbgs() << repl << " (Operand/Attr/Prop)\n");
442 return std::string(repl);
443 }
444 case Kind::Result: {
445 if (index >= 0) {
446 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
447 LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
448 return std::string(repl);
449 }
450
451 // We are referencing all results of the multi-result op. Each result should
452 // have a value range, and then join them with `separator`.
453 SmallVector<std::string, 4> values;
454 values.reserve(op->getNumResults());
455
456 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
457 values.push_back(std::string(
458 formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
459 }
460 auto repl = llvm::join(values, separator);
461 LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
462 return repl;
463 }
464 case Kind::Value: {
465 assert(index < 0 && "only allowed for symbol bound to result");
466 assert(op == nullptr);
467 auto repl = formatv(fmt, formatv("{{{0}}", name));
468 LLVM_DEBUG(dbgs() << repl << " (Value)\n");
469 return std::string(repl);
470 }
471 case Kind::MultipleValues: {
472 assert(op == nullptr);
473 assert(index < getSize());
474 if (index >= 0) {
475 std::string repl =
476 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
477 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
478 return repl;
479 }
480 auto repl =
481 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
482 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
483 return std::string(repl);
484 }
485 }
486 llvm_unreachable("unknown kind");
487}
488
489bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
490 const Operator &op, int argIndex,
491 std::optional<int> variadicSubIndex) {
492 StringRef name = getValuePackName(symbol);
493 if (name != symbol) {
494 auto error = formatv(
495 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
496 PrintFatalError(loc, error);
497 }
498
499 Argument arg = op.getArg(argIndex);
500 SymbolInfo symInfo =
501 isa<NamedAttribute *>(arg) ? SymbolInfo::getAttr(&op, argIndex)
502 : isa<NamedProperty *>(arg)
503 ? SymbolInfo::getProp(&op, argIndex)
504 : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
505
506 std::string key = symbol.str();
507 if (symbolInfoMap.count(key)) {
508 // Only non unique name for the operand is supported.
509 if (symInfo.kind != SymbolInfo::Kind::Operand) {
510 return false;
511 }
512
513 // Cannot add new operand if there is already non operand with the same
514 // name.
515 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
516 return false;
517 }
518 }
519
520 symbolInfoMap.emplace(key, symInfo);
521 return true;
522}
523
524bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
525 std::string name = getValuePackName(symbol).str();
526 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
527
528 return symbolInfoMap.count(inserted->first) == 1;
529}
530
531bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
532 std::string name = getValuePackName(symbol).str();
533 if (numValues > 1)
534 return bindMultipleValues(name, numValues);
535 return bindValue(name);
536}
537
538bool SymbolInfoMap::bindValue(StringRef symbol) {
539 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
540 return symbolInfoMap.count(inserted->first) == 1;
541}
542
543bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
544 std::string name = getValuePackName(symbol).str();
545 auto inserted =
546 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
547 return symbolInfoMap.count(inserted->first) == 1;
548}
549
550bool SymbolInfoMap::bindAttr(StringRef symbol) {
551 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
552 return symbolInfoMap.count(inserted->first) == 1;
553}
554
555bool SymbolInfoMap::bindProp(StringRef symbol,
556 const PropConstraint &constraint) {
557 auto inserted =
558 symbolInfoMap.emplace(symbol.str(), SymbolInfo::getProp(&constraint));
559 return symbolInfoMap.count(inserted->first) == 1;
560}
561
562bool SymbolInfoMap::contains(StringRef symbol) const {
563 return find(symbol) != symbolInfoMap.end();
564}
565
567 std::string name = getValuePackName(key).str();
568
569 return symbolInfoMap.find(name);
570}
571
573SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
574 int argIndex,
575 std::optional<int> variadicSubIndex) const {
576 return findBoundSymbol(
577 key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
578}
579
582 const SymbolInfo &symbolInfo) const {
583 std::string name = getValuePackName(key).str();
584 auto range = symbolInfoMap.equal_range(name);
585
586 for (auto it = range.first; it != range.second; ++it)
587 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
588 return it;
589
590 return symbolInfoMap.end();
591}
592
593std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
595 std::string name = getValuePackName(key).str();
596
597 return symbolInfoMap.equal_range(name);
598}
599
600int SymbolInfoMap::count(StringRef key) const {
601 std::string name = getValuePackName(key).str();
602 return symbolInfoMap.count(name);
603}
604
605int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
606 StringRef name = getValuePackName(symbol);
607 if (name != symbol) {
608 // If there is a trailing index inside symbol, it references just one
609 // static value.
610 return 1;
611 }
612 // Otherwise, find how many it represents by querying the symbol's info.
613 return find(name)->second.getStaticValueCount();
614}
615
616std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
617 const char *fmt,
618 const char *separator) const {
619 int index = -1;
620 StringRef name = getValuePackName(symbol, &index);
621
622 auto it = symbolInfoMap.find(name.str());
623 if (it == symbolInfoMap.end()) {
624 auto error = formatv("referencing unbound symbol '{0}'", symbol);
625 PrintFatalError(loc, error);
626 }
627
628 return it->second.getValueAndRangeUse(name, index, fmt, separator);
629}
630
631std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
632 const char *separator) const {
633 int index = -1;
634 StringRef name = getValuePackName(symbol, &index);
635
636 auto it = symbolInfoMap.find(name.str());
637 if (it == symbolInfoMap.end()) {
638 auto error = formatv("referencing unbound symbol '{0}'", symbol);
639 PrintFatalError(loc, error);
640 }
641
642 return it->second.getAllRangeUse(name, index, fmt, separator);
643}
644
646 llvm::StringSet<> usedNames;
647
648 for (auto symbolInfoIt = symbolInfoMap.begin();
649 symbolInfoIt != symbolInfoMap.end();) {
650 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
651 auto startRange = range.first;
652 auto endRange = range.second;
653
654 auto operandName = symbolInfoIt->first;
655 int startSearchIndex = 0;
656 for (++startRange; startRange != endRange; ++startRange) {
657 // Current operand name is not unique, find a unique one
658 // and set the alternative name.
659 for (int i = startSearchIndex;; ++i) {
660 std::string alternativeName = operandName + std::to_string(i);
661 if (!usedNames.contains(alternativeName) &&
662 symbolInfoMap.count(alternativeName) == 0) {
663 usedNames.insert(alternativeName);
664 startRange->second.alternativeName = alternativeName;
665 startSearchIndex = i + 1;
666
667 break;
668 }
669 }
670 }
671
672 symbolInfoIt = endRange;
673 }
674}
675
676//===----------------------------------------------------------------------===//
677// Pattern
678//==----------------------------------------------------------------------===//
679
680Pattern::Pattern(const Record *def, RecordOperatorMap *mapper)
681 : def(*def), recordOpMap(mapper) {}
682
684 return DagNode(def.getValueAsDag("sourcePattern"));
685}
686
688 auto *results = def.getValueAsListInit("resultPatterns");
689 return results->size();
690}
691
693 auto *results = def.getValueAsListInit("resultPatterns");
694 return DagNode(cast<DagInit>(results->getElement(index)));
695}
696
698 LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n");
699 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
700 LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n");
701
702 LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n");
704 LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n");
705}
706
708 LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n");
709 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
710 auto pattern = getResultPattern(i);
711 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
712 }
713 LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n");
714}
715
717 return getSourcePattern().getDialectOp(recordOpMap);
718}
719
721 return node.getDialectOp(recordOpMap);
722}
723
724std::vector<AppliedConstraint> Pattern::getConstraints() const {
725 auto *listInit = def.getValueAsListInit("constraints");
726 std::vector<AppliedConstraint> ret;
727 ret.reserve(listInit->size());
728
729 for (auto *it : *listInit) {
730 auto *dagInit = dyn_cast<DagInit>(it);
731 if (!dagInit)
732 PrintFatalError(&def, "all elements in Pattern multi-entity "
733 "constraints should be DAG nodes");
734
735 std::vector<std::string> entities;
736 entities.reserve(dagInit->arg_size());
737 for (auto *argName : dagInit->getArgNames()) {
738 if (!argName) {
739 PrintFatalError(
740 &def,
741 "operands to additional constraints can only be symbol references");
742 }
743 entities.emplace_back(argName->getValue());
744 }
745
746 ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
747 dagInit->getNameStr(), std::move(entities));
748 }
749 return ret;
750}
751
753 auto *results = def.getValueAsListInit("supplementalPatterns");
754 return results->size();
755}
756
758 auto *results = def.getValueAsListInit("supplementalPatterns");
759 return DagNode(cast<DagInit>(results->getElement(index)));
760}
761
762int Pattern::getBenefit() const {
763 // The initial benefit value is a heuristic with number of ops in the source
764 // pattern.
765 int initBenefit = getSourcePattern().getNumOps();
766 const DagInit *delta = def.getValueAsDag("benefitDelta");
767 if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
768 PrintFatalError(&def,
769 "The 'addBenefit' takes and only takes one integer value");
770 }
771 return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
772}
773
774std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
775 std::vector<std::pair<StringRef, unsigned>> result;
776 result.reserve(def.getLoc().size());
777 for (auto loc : def.getLoc()) {
778 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
779 assert(buf && "invalid source location");
780 result.emplace_back(
781 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
782 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
783 }
784 return result;
785}
786
787void Pattern::verifyBind(bool result, StringRef symbolName) {
788 if (!result) {
789 auto err = formatv("symbol '{0}' bound more than once", symbolName);
790 PrintFatalError(&def, err);
791 }
792}
793
795 bool isSrcPattern) {
796 auto treeName = tree.getSymbol();
797 auto numTreeArgs = tree.getNumArgs();
798
799 if (tree.isNativeCodeCall()) {
800 if (!treeName.empty()) {
801 if (!isSrcPattern) {
802 LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: "
803 << treeName << '\n');
804 verifyBind(
805 infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
806 treeName);
807 } else {
808 PrintFatalError(&def,
809 formatv("binding symbol '{0}' to NativecodeCall in "
810 "MatchPattern is not supported",
811 treeName));
812 }
813 }
814
815 for (int i = 0; i != numTreeArgs; ++i) {
816 if (auto treeArg = tree.getArgAsNestedDag(i)) {
817 // This DAG node argument is a DAG node itself. Go inside recursively.
818 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
819 continue;
820 }
821
822 if (!isSrcPattern)
823 continue;
824
825 // We can only bind symbols to arguments in source pattern. Those
826 // symbols are referenced in result patterns.
827 auto treeArgName = tree.getArgName(i);
828
829 // `$_` is a special symbol meaning ignore the current argument.
830 if (!treeArgName.empty() && treeArgName != "_") {
831 DagLeaf leaf = tree.getArgAsLeaf(i);
832
833 // In (NativeCodeCall<"Foo($_self, $0, $1, $2, $3)"> I8Attr:$a, I8:$b,
834 // $c, I8Prop:$d),
835 if (leaf.isUnspecified()) {
836 // This is case of $c, a Value without any constraints.
837 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
838 } else if (leaf.isPropMatcher()) {
839 // This is case of $d, a binding to a certain property.
840 auto propConstraint = leaf.getAsPropConstraint();
841 if (propConstraint.getInterfaceType().empty()) {
842 PrintFatalError(&def,
843 formatv("binding symbol '{0}' in NativeCodeCall to "
844 "a property constraint without specifying "
845 "that constraint's type is unsupported",
846 treeArgName));
847 }
848 verifyBind(infoMap.bindProp(treeArgName, propConstraint),
849 treeArgName);
850 } else {
851 auto constraint = leaf.getAsConstraint();
852 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() ||
853 leaf.isConstantAttr() ||
854 constraint.getKind() == Constraint::Kind::CK_Attr;
855
856 if (isAttr) {
857 // This is case of $a, a binding to a certain attribute.
858 verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
859 continue;
860 }
861
862 // This is case of $b, a binding to a certain type.
863 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
864 }
865 }
866 }
867
868 return;
869 }
870
871 if (tree.isOperation()) {
872 auto &op = getDialectOp(tree);
873 auto numOpArgs = op.getNumArgs();
874 int numEither = 0;
875
876 // We need to exclude the trailing directives and `either` directive groups
877 // two operands of the operation.
878 int numDirectives = 0;
879 for (int i = numTreeArgs - 1; i >= 0; --i) {
880 if (auto dagArg = tree.getArgAsNestedDag(i)) {
881 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
882 ++numDirectives;
883 else if (dagArg.isEither())
884 ++numEither;
885 }
886 }
887
888 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
889 auto err =
890 formatv("op '{0}' argument number mismatch: "
891 "{1} in pattern vs. {2} in definition",
892 op.getOperationName(), numTreeArgs + numEither, numOpArgs);
893 PrintFatalError(&def, err);
894 }
895
896 // The name attached to the DAG node's operator is for representing the
897 // results generated from this op. It should be remembered as bound results.
898 if (!treeName.empty()) {
899 LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName
900 << '\n');
901 verifyBind(infoMap.bindOpResult(treeName, op), treeName);
902 }
903
904 // The operand in `either` DAG should be bound to the operation in the
905 // parent DagNode.
906 auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
907 int opArgIdx) {
908 for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
909 if (DagNode subTree = tree.getArgAsNestedDag(i)) {
910 collectBoundSymbols(subTree, infoMap, isSrcPattern);
911 } else {
912 auto argName = tree.getArgName(i);
913 if (!argName.empty() && argName != "_") {
914 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
915 argName);
916 }
917 }
918 }
919 };
920
921 // The operand in `variadic` DAG should be bound to the operation in the
922 // parent DagNode. The range index must be included as well to distinguish
923 // (potentially) repeating argName within the `variadic` DAG.
924 auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
925 int opArgIdx) {
926 auto treeName = tree.getSymbol();
927 if (!treeName.empty()) {
928 // If treeName is specified, bind to the full variadic operand_range.
929 verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
930 std::nullopt),
931 treeName);
932 }
933
934 for (int i = 0; i < tree.getNumArgs(); ++i) {
935 if (DagNode subTree = tree.getArgAsNestedDag(i)) {
936 collectBoundSymbols(subTree, infoMap, isSrcPattern);
937 } else {
938 auto argName = tree.getArgName(i);
939 if (!argName.empty() && argName != "_") {
940 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
941 /*variadicSubIndex=*/i),
942 argName);
943 }
944 }
945 }
946 };
947
948 for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
949 if (auto treeArg = tree.getArgAsNestedDag(i)) {
950 if (treeArg.isEither()) {
951 collectSymbolInEither(tree, treeArg, opArgIdx);
952 // `either` DAG is *flattened*. For example,
953 //
954 // (FooOp (either arg0, arg1), arg2)
955 //
956 // can be viewed as:
957 //
958 // (FooOp arg0, arg1, arg2)
959 ++opArgIdx;
960 } else if (treeArg.isVariadic()) {
961 collectSymbolInVariadic(tree, treeArg, opArgIdx);
962 } else {
963 // This DAG node argument is a DAG node itself. Go inside recursively.
964 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
965 }
966 continue;
967 }
968
969 if (isSrcPattern) {
970 // We can only bind symbols to op arguments in source pattern. Those
971 // symbols are referenced in result patterns.
972 auto treeArgName = tree.getArgName(i);
973 // `$_` is a special symbol meaning ignore the current argument.
974 if (!treeArgName.empty() && treeArgName != "_") {
975 LLVM_DEBUG(dbgs() << "found symbol bound to op argument: "
976 << treeArgName << '\n');
977 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
978 treeArgName);
979 }
980 }
981 }
982 return;
983 }
984
985 if (!treeName.empty()) {
986 PrintFatalError(
987 &def, formatv("binding symbol '{0}' to non-operation/native code call "
988 "unsupported right now",
989 treeName));
990 }
991}
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
std::string getConditionTemplate() const
Constraint getAsConstraint() const
Definition Pattern.cpp:76
bool isNativeCodeCall() const
Definition Pattern.cpp:64
bool isPropMatcher() const
Definition Pattern.cpp:54
int getNumReturnsOfNativeCode() const
Definition Pattern.cpp:116
ConstantAttr getAsConstantAttr() const
Definition Pattern.cpp:92
void print(raw_ostream &os) const
Definition Pattern.cpp:131
std::string getStringAttr() const
Definition Pattern.cpp:121
Property getAsProperty() const
Definition Pattern.cpp:87
bool isEnumCase() const
Definition Pattern.cpp:70
StringRef getNativeCodeTemplate() const
Definition Pattern.cpp:111
std::string getConditionTemplate() const
Definition Pattern.cpp:107
bool isConstantProp() const
Definition Pattern.cpp:72
PropConstraint getAsPropConstraint() const
Definition Pattern.cpp:82
ConstantProp getAsConstantProp() const
Definition Pattern.cpp:102
bool isUnspecified() const
Definition Pattern.cpp:40
EnumCase getAsEnumCase() const
Definition Pattern.cpp:97
bool isAttrMatcher() const
Definition Pattern.cpp:49
bool isOperandMatcher() const
Definition Pattern.cpp:44
bool isPropDefinition() const
Definition Pattern.cpp:59
bool isConstantAttr() const
Definition Pattern.cpp:68
bool isStringAttr() const
Definition Pattern.cpp:74
bool isReturnTypeDirective() const
Definition Pattern.cpp:216
bool isLocationDirective() const
Definition Pattern.cpp:211
bool isReplaceWithValue() const
Definition Pattern.cpp:206
DagNode getArgAsNestedDag(unsigned index) const
Definition Pattern.cpp:193
bool isOperation() const
Definition Pattern.cpp:146
DagLeaf getArgAsLeaf(unsigned index) const
Definition Pattern.cpp:197
int getNumReturnsOfNativeCode() const
Definition Pattern.cpp:159
StringRef getNativeCodeTemplate() const
Definition Pattern.cpp:152
void print(raw_ostream &os) const
Definition Pattern.cpp:231
int getNumOps() const
Definition Pattern.cpp:176
Operator & getDialectOp(RecordOperatorMap *mapper) const
Definition Pattern.cpp:168
bool isVariadic() const
Definition Pattern.cpp:226
bool isNativeCodeCall() const
Definition Pattern.cpp:140
bool isEither() const
Definition Pattern.cpp:221
bool isNestedDagArg(unsigned index) const
Definition Pattern.cpp:189
StringRef getSymbol() const
Definition Pattern.cpp:166
int getNumArgs() const
Definition Pattern.cpp:187
DagNode(const llvm::DagInit *node)
Definition Pattern.h:165
StringRef getArgName(unsigned index) const
Definition Pattern.cpp:202
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition Operator.h:77
int getNumResults() const
Returns the number of results this op produces.
Definition Operator.cpp:164
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Definition Operator.cpp:183
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
Definition Operator.cpp:357
int getBenefit() const
int getNumResultPatterns() const
Definition Pattern.cpp:687
std::vector< IdentifierLine > getLocation() const
Definition Pattern.cpp:774
DagNode getSourcePattern() const
Definition Pattern.cpp:683
const Operator & getSourceRootOp()
Definition Pattern.cpp:716
std::vector< AppliedConstraint > getConstraints() const
Definition Pattern.cpp:724
DagNode getResultPattern(unsigned index) const
Definition Pattern.cpp:692
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Definition Pattern.cpp:794
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
Definition Pattern.cpp:720
DagNode getSupplementalPattern(unsigned index) const
Definition Pattern.cpp:757
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
Definition Pattern.cpp:697
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
Definition Pattern.cpp:707
int getNumSupplementalPatterns() const
Definition Pattern.cpp:752
std::string getArgDecl(StringRef name) const
Definition Pattern.cpp:327
std::string getVarName(StringRef name) const
Definition Pattern.cpp:274
std::string getVarTypeStr(StringRef name) const
Definition Pattern.cpp:278
std::string getVarDecl(StringRef name) const
Definition Pattern.cpp:320
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
Definition Pattern.cpp:240
int count(StringRef key) const
Definition Pattern.cpp:600
const_iterator find(StringRef key) const
Definition Pattern.cpp:566
bool bindMultipleValues(StringRef symbol, int numValues)
Definition Pattern.cpp:543
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
Definition Pattern.cpp:489
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition Pattern.cpp:631
bool bindValues(StringRef symbol, int numValues=1)
Definition Pattern.cpp:531
bool bindAttr(StringRef symbol)
Definition Pattern.cpp:550
bool bindProp(StringRef symbol, const PropConstraint &constraint)
Definition Pattern.cpp:555
bool bindValue(StringRef symbol)
Definition Pattern.cpp:538
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
Definition Pattern.cpp:573
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
Definition Pattern.cpp:594
int getStaticValueCount(StringRef symbol) const
Definition Pattern.cpp:605
bool contains(StringRef symbol) const
Definition Pattern.cpp:562
BaseT::const_iterator const_iterator
Definition Pattern.h:504
bool bindOpResult(StringRef symbol, const Operator &op)
Definition Pattern.cpp:524
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition Pattern.cpp:616
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Definition Argument.h:63
DenseMap< const llvm::Record *, std::unique_ptr< Operator > > RecordOperatorMap
Definition Pattern.h:44
Include the generated interface declarations.