MLIR 22.0.0git
Pattern.cpp
Go to the documentation of this file.
1//===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Pattern wrapper class to simplify using TableGen Record defining a MLIR
10// Pattern.
11//
12//===----------------------------------------------------------------------===//
13
14#include <utility>
15
17#include "llvm/ADT/StringExtras.h"
18#include "llvm/ADT/Twine.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/Support/Path.h"
22#include "llvm/TableGen/Error.h"
23#include "llvm/TableGen/Record.h"
24
25#define DEBUG_TYPE "mlir-tblgen-pattern"
26
27using namespace mlir;
28using namespace tblgen;
29
30using llvm::DagInit;
31using llvm::dbgs;
32using llvm::DefInit;
33using llvm::formatv;
34using llvm::IntInit;
35using llvm::Record;
36
37//===----------------------------------------------------------------------===//
38// DagLeaf
39//===----------------------------------------------------------------------===//
40
42 return isa_and_nonnull<llvm::UnsetInit>(def);
43}
44
46 // Operand matchers specify a type constraint.
47 return isSubClassOf("TypeConstraint");
48}
49
51 // Attribute matchers specify an attribute constraint.
52 return isSubClassOf("AttrConstraint");
53}
54
56 // Property matchers specify a property constraint.
57 return isSubClassOf("PropConstraint");
58}
59
61 // Property matchers specify a property definition.
62 return isSubClassOf("Property");
63}
64
66 return isSubClassOf("NativeCodeCall");
67}
68
69bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
70
71bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumCase"); }
72
73bool DagLeaf::isConstantProp() const { return isSubClassOf("ConstantProp"); }
74
75bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
76
78 assert((isOperandMatcher() || isAttrMatcher() || isPropMatcher()) &&
79 "the DAG leaf must be operand, attribute, or property");
80 return Constraint(cast<DefInit>(def)->getDef());
81}
82
84 assert(isPropMatcher() && "the DAG leaf must be a property matcher");
85 return PropConstraint(cast<DefInit>(def)->getDef());
86}
87
89 assert(isPropDefinition() && "the DAG leaf must be a property definition");
90 return Property(cast<DefInit>(def)->getDef());
91}
92
94 assert(isConstantAttr() && "the DAG leaf must be constant attribute");
95 return ConstantAttr(cast<DefInit>(def));
96}
97
99 assert(isEnumCase() && "the DAG leaf must be an enum attribute case");
100 return EnumCase(cast<DefInit>(def));
101}
102
104 assert(isConstantProp() && "the DAG leaf must be a constant property value");
105 return ConstantProp(cast<DefInit>(def));
106}
107
108std::string DagLeaf::getConditionTemplate() const {
110}
111
113 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
114 return cast<DefInit>(def)->getDef()->getValueAsString("expression");
115}
116
118 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
119 return cast<DefInit>(def)->getDef()->getValueAsInt("numReturns");
120}
121
122std::string DagLeaf::getStringAttr() const {
123 assert(isStringAttr() && "the DAG leaf must be string attribute");
124 return def->getAsUnquotedString();
125}
126bool DagLeaf::isSubClassOf(StringRef superclass) const {
127 if (auto *defInit = dyn_cast_or_null<DefInit>(def))
128 return defInit->getDef()->isSubClassOf(superclass);
129 return false;
130}
131
133 if (def)
134 def->print(os);
135}
136
137//===----------------------------------------------------------------------===//
138// DagNode
139//===----------------------------------------------------------------------===//
140
142 if (auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
143 return defInit->getDef()->isSubClassOf("NativeCodeCall");
144 return false;
145}
146
148 return !isNativeCodeCall() && !isReplaceWithValue() &&
150 !isVariadic();
151}
152
154 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
155 return cast<DefInit>(node->getOperator())
156 ->getDef()
157 ->getValueAsString("expression");
158}
159
161 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
162 return cast<DefInit>(node->getOperator())
163 ->getDef()
164 ->getValueAsInt("numReturns");
165}
166
167StringRef DagNode::getSymbol() const { return node->getNameStr(); }
168
170 const Record *opDef = cast<DefInit>(node->getOperator())->getDef();
171 auto [it, inserted] = mapper->try_emplace(opDef);
172 if (inserted)
173 it->second = std::make_unique<Operator>(opDef);
174 return *it->second;
175}
176
178 // We want to get number of operations recursively involved in the DAG tree.
179 // All other directives should be excluded.
180 int count = isOperation() ? 1 : 0;
181 for (int i = 0, e = getNumArgs(); i != e; ++i) {
182 if (auto child = getArgAsNestedDag(i))
183 count += child.getNumOps();
184 }
185 return count;
186}
187
188int DagNode::getNumArgs() const { return node->getNumArgs(); }
189
190bool DagNode::isNestedDagArg(unsigned index) const {
191 return isa<DagInit>(node->getArg(index));
192}
193
195 return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
196}
197
199 assert(!isNestedDagArg(index));
200 return DagLeaf(node->getArg(index));
201}
202
203StringRef DagNode::getArgName(unsigned index) const {
204 return node->getArgNameStr(index);
205}
206
208 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
209 return dagOpDef->getName() == "replaceWithValue";
210}
211
213 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
214 return dagOpDef->getName() == "location";
215}
216
218 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
219 return dagOpDef->getName() == "returnType";
220}
221
222bool DagNode::isEither() const {
223 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
224 return dagOpDef->getName() == "either";
225}
226
228 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
229 return dagOpDef->getName() == "variadic";
230}
231
233 if (node)
234 node->print(os);
235}
236
237//===----------------------------------------------------------------------===//
238// SymbolInfoMap
239//===----------------------------------------------------------------------===//
240
241StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
242 int idx = -1;
243 auto [name, indexStr] = symbol.rsplit("__");
244
245 if (indexStr.consumeInteger(10, idx)) {
246 // The second part is not an index; we return the whole symbol as-is.
247 return symbol;
248 }
249 if (index) {
250 *index = idx;
251 }
252 return name;
253}
254
255SymbolInfoMap::SymbolInfo::SymbolInfo(
256 const Operator *op, SymbolInfo::Kind kind,
257 std::optional<DagAndConstant> dagAndConstant)
258 : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
259
260int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
261 switch (kind) {
262 case Kind::Attr:
263 case Kind::Prop:
264 case Kind::Operand:
265 case Kind::Value:
266 return 1;
267 case Kind::Result:
268 return op->getNumResults();
269 case Kind::MultipleValues:
270 return getSize();
271 }
272 llvm_unreachable("unknown kind");
273}
274
275std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
276 return alternativeName ? *alternativeName : name.str();
277}
278
279std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
280 LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': ");
281 switch (kind) {
282 case Kind::Attr: {
283 if (op)
284 return cast<NamedAttribute *>(op->getArg(getArgIndex()))
285 ->attr.getStorageType()
286 .str();
287 // TODO(suderman): Use a more exact type when available.
288 return "::mlir::Attribute";
289 }
290 case Kind::Prop: {
291 if (op)
292 return cast<NamedProperty *>(op->getArg(getArgIndex()))
293 ->prop.getInterfaceType()
294 .str();
295 assert(dagAndConstant && dagAndConstant->dag &&
296 "generic properties must carry their constraint");
297 return reinterpret_cast<const DagLeaf *>(dagAndConstant->dag)
298 ->getAsPropConstraint()
299 .getInterfaceType()
300 .str();
301 }
302 case Kind::Operand: {
303 // Use operand range for captured operands (to support potential variadic
304 // operands).
305 return "::mlir::Operation::operand_range";
306 }
307 case Kind::Value: {
308 return "::mlir::Value";
309 }
310 case Kind::MultipleValues: {
311 return "::mlir::ValueRange";
312 }
313 case Kind::Result: {
314 // Use the op itself for captured results.
315 return op->getQualCppClassName();
316 }
317 }
318 llvm_unreachable("unknown kind");
319}
320
321std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
322 LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': ");
323 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
324 return std::string(
325 formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
326}
327
328std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
329 LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': ");
330 return std::string(
331 formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
332}
333
334std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
335 StringRef name, int index, const char *fmt, const char *separator) const {
336 LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': ");
337 switch (kind) {
338 case Kind::Attr: {
339 assert(index < 0);
340 auto repl = formatv(fmt, name);
341 LLVM_DEBUG(dbgs() << repl << " (Attr)\n");
342 return std::string(repl);
343 }
344 case Kind::Prop: {
345 assert(index < 0);
346 auto repl = formatv(fmt, name);
347 LLVM_DEBUG(dbgs() << repl << " (Prop)\n");
348 return std::string(repl);
349 }
350 case Kind::Operand: {
351 assert(index < 0);
352 auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
353 if (operand->isOptional()) {
354 auto repl = formatv(
355 fmt, formatv("({0}.empty() ? ::mlir::Value() : *{0}.begin())", name));
356 LLVM_DEBUG(dbgs() << repl << " (OptionalOperand)\n");
357 return std::string(repl);
358 }
359 // If this operand is variadic and this SymbolInfo doesn't have a range
360 // index, then return the full variadic operand_range. Otherwise, return
361 // the value itself.
362 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
363 auto repl = formatv(fmt, name);
364 LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n");
365 return std::string(repl);
366 }
367 auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
368 LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n");
369 return std::string(repl);
370 }
371 case Kind::Result: {
372 // If `index` is greater than zero, then we are referencing a specific
373 // result of a multi-result op. The result can still be variadic.
374 if (index >= 0) {
375 std::string v =
376 std::string(formatv("{0}.getODSResults({1})", name, index));
377 if (!op->getResult(index).isVariadic())
378 v = std::string(formatv("(*{0}.begin())", v));
379 auto repl = formatv(fmt, v);
380 LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
381 return std::string(repl);
382 }
383
384 // If this op has no result at all but still we bind a symbol to it, it
385 // means we want to capture the op itself.
386 if (op->getNumResults() == 0) {
387 LLVM_DEBUG(dbgs() << name << " (Op)\n");
388 return formatv(fmt, name);
389 }
390
391 // We are referencing all results of the multi-result op. A specific result
392 // can either be a value or a range. Then join them with `separator`.
393 SmallVector<std::string, 4> values;
394 values.reserve(op->getNumResults());
395
396 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
397 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
398 if (!op->getResult(i).isVariadic()) {
399 v = std::string(formatv("(*{0}.begin())", v));
400 }
401 values.push_back(std::string(formatv(fmt, v)));
402 }
403 auto repl = llvm::join(values, separator);
404 LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
405 return repl;
406 }
407 case Kind::Value: {
408 assert(index < 0);
409 assert(op == nullptr);
410 auto repl = formatv(fmt, name);
411 LLVM_DEBUG(dbgs() << repl << " (Value)\n");
412 return std::string(repl);
413 }
414 case Kind::MultipleValues: {
415 assert(op == nullptr);
416 assert(index < getSize());
417 if (index >= 0) {
418 std::string repl =
419 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
420 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
421 return repl;
422 }
423 // If it doesn't specify certain element, unpack them all.
424 auto repl =
425 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
426 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
427 return std::string(repl);
428 }
429 }
430 llvm_unreachable("unknown kind");
431}
432
433std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
434 StringRef name, int index, const char *fmt, const char *separator) const {
435 LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': ");
436 switch (kind) {
437 case Kind::Attr:
438 case Kind::Prop:
439 case Kind::Operand: {
440 assert(index < 0 && "only allowed for symbol bound to result");
441 auto repl = formatv(fmt, name);
442 LLVM_DEBUG(dbgs() << repl << " (Operand/Attr/Prop)\n");
443 return std::string(repl);
444 }
445 case Kind::Result: {
446 if (index >= 0) {
447 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
448 LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
449 return std::string(repl);
450 }
451
452 // We are referencing all results of the multi-result op. Each result should
453 // have a value range, and then join them with `separator`.
454 SmallVector<std::string, 4> values;
455 values.reserve(op->getNumResults());
456
457 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
458 values.push_back(std::string(
459 formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
460 }
461 auto repl = llvm::join(values, separator);
462 LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
463 return repl;
464 }
465 case Kind::Value: {
466 assert(index < 0 && "only allowed for symbol bound to result");
467 assert(op == nullptr);
468 auto repl = formatv(fmt, formatv("{{{0}}", name));
469 LLVM_DEBUG(dbgs() << repl << " (Value)\n");
470 return std::string(repl);
471 }
472 case Kind::MultipleValues: {
473 assert(op == nullptr);
474 assert(index < getSize());
475 if (index >= 0) {
476 std::string repl =
477 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
478 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
479 return repl;
480 }
481 auto repl =
482 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
483 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
484 return std::string(repl);
485 }
486 }
487 llvm_unreachable("unknown kind");
488}
489
490bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
491 const Operator &op, int argIndex,
492 std::optional<int> variadicSubIndex) {
493 StringRef name = getValuePackName(symbol);
494 if (name != symbol) {
495 auto error = formatv(
496 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
497 PrintFatalError(loc, error);
498 }
499
500 Argument arg = op.getArg(argIndex);
501 SymbolInfo symInfo =
502 isa<NamedAttribute *>(arg) ? SymbolInfo::getAttr(&op, argIndex)
503 : isa<NamedProperty *>(arg)
504 ? SymbolInfo::getProp(&op, argIndex)
505 : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
506
507 std::string key = symbol.str();
508 if (symbolInfoMap.count(key)) {
509 // Only non unique name for the operand is supported.
510 if (symInfo.kind != SymbolInfo::Kind::Operand) {
511 return false;
512 }
513
514 // Cannot add new operand if there is already non operand with the same
515 // name.
516 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
517 return false;
518 }
519 }
520
521 symbolInfoMap.emplace(key, symInfo);
522 return true;
523}
524
525bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
526 std::string name = getValuePackName(symbol).str();
527 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
528
529 return symbolInfoMap.count(inserted->first) == 1;
530}
531
532bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
533 std::string name = getValuePackName(symbol).str();
534 if (numValues > 1)
535 return bindMultipleValues(name, numValues);
536 return bindValue(name);
537}
538
539bool SymbolInfoMap::bindValue(StringRef symbol) {
540 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
541 return symbolInfoMap.count(inserted->first) == 1;
542}
543
544bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
545 std::string name = getValuePackName(symbol).str();
546 auto inserted =
547 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
548 return symbolInfoMap.count(inserted->first) == 1;
549}
550
551bool SymbolInfoMap::bindAttr(StringRef symbol) {
552 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
553 return symbolInfoMap.count(inserted->first) == 1;
554}
555
556bool SymbolInfoMap::bindProp(StringRef symbol,
557 const PropConstraint &constraint) {
558 auto inserted =
559 symbolInfoMap.emplace(symbol.str(), SymbolInfo::getProp(&constraint));
560 return symbolInfoMap.count(inserted->first) == 1;
561}
562
563bool SymbolInfoMap::contains(StringRef symbol) const {
564 return find(symbol) != symbolInfoMap.end();
565}
566
568 std::string name = getValuePackName(key).str();
569
570 return symbolInfoMap.find(name);
571}
572
574SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
575 int argIndex,
576 std::optional<int> variadicSubIndex) const {
577 return findBoundSymbol(
578 key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
579}
580
583 const SymbolInfo &symbolInfo) const {
584 std::string name = getValuePackName(key).str();
585 auto range = symbolInfoMap.equal_range(name);
586
587 for (auto it = range.first; it != range.second; ++it)
588 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
589 return it;
590
591 return symbolInfoMap.end();
592}
593
594std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
596 std::string name = getValuePackName(key).str();
597
598 return symbolInfoMap.equal_range(name);
599}
600
601int SymbolInfoMap::count(StringRef key) const {
602 std::string name = getValuePackName(key).str();
603 return symbolInfoMap.count(name);
604}
605
606int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
607 StringRef name = getValuePackName(symbol);
608 if (name != symbol) {
609 // If there is a trailing index inside symbol, it references just one
610 // static value.
611 return 1;
612 }
613 // Otherwise, find how many it represents by querying the symbol's info.
614 return find(name)->second.getStaticValueCount();
615}
616
617std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
618 const char *fmt,
619 const char *separator) const {
620 int index = -1;
621 StringRef name = getValuePackName(symbol, &index);
622
623 auto it = symbolInfoMap.find(name.str());
624 if (it == symbolInfoMap.end()) {
625 auto error = formatv("referencing unbound symbol '{0}'", symbol);
626 PrintFatalError(loc, error);
627 }
628
629 return it->second.getValueAndRangeUse(name, index, fmt, separator);
630}
631
632std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
633 const char *separator) const {
634 int index = -1;
635 StringRef name = getValuePackName(symbol, &index);
636
637 auto it = symbolInfoMap.find(name.str());
638 if (it == symbolInfoMap.end()) {
639 auto error = formatv("referencing unbound symbol '{0}'", symbol);
640 PrintFatalError(loc, error);
641 }
642
643 return it->second.getAllRangeUse(name, index, fmt, separator);
644}
645
647 llvm::StringSet<> usedNames;
648
649 for (auto symbolInfoIt = symbolInfoMap.begin();
650 symbolInfoIt != symbolInfoMap.end();) {
651 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
652 auto startRange = range.first;
653 auto endRange = range.second;
654
655 auto operandName = symbolInfoIt->first;
656 int startSearchIndex = 0;
657 for (++startRange; startRange != endRange; ++startRange) {
658 // Current operand name is not unique, find a unique one
659 // and set the alternative name.
660 for (int i = startSearchIndex;; ++i) {
661 std::string alternativeName = operandName + std::to_string(i);
662 if (!usedNames.contains(alternativeName) &&
663 symbolInfoMap.count(alternativeName) == 0) {
664 usedNames.insert(alternativeName);
665 startRange->second.alternativeName = alternativeName;
666 startSearchIndex = i + 1;
667
668 break;
669 }
670 }
671 }
672
673 symbolInfoIt = endRange;
674 }
675}
676
677//===----------------------------------------------------------------------===//
678// Pattern
679//==----------------------------------------------------------------------===//
680
681Pattern::Pattern(const Record *def, RecordOperatorMap *mapper)
682 : def(*def), recordOpMap(mapper) {}
683
685 return DagNode(def.getValueAsDag("sourcePattern"));
686}
687
689 auto *results = def.getValueAsListInit("resultPatterns");
690 return results->size();
691}
692
694 auto *results = def.getValueAsListInit("resultPatterns");
695 return DagNode(cast<DagInit>(results->getElement(index)));
696}
697
699 LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n");
700 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
701 LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n");
702
703 LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n");
705 LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n");
706}
707
709 LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n");
710 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
711 auto pattern = getResultPattern(i);
712 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
713 }
714 LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n");
715}
716
718 return getSourcePattern().getDialectOp(recordOpMap);
719}
720
722 return node.getDialectOp(recordOpMap);
723}
724
725std::vector<AppliedConstraint> Pattern::getConstraints() const {
726 auto *listInit = def.getValueAsListInit("constraints");
727 std::vector<AppliedConstraint> ret;
728 ret.reserve(listInit->size());
729
730 for (auto *it : *listInit) {
731 auto *dagInit = dyn_cast<DagInit>(it);
732 if (!dagInit)
733 PrintFatalError(&def, "all elements in Pattern multi-entity "
734 "constraints should be DAG nodes");
735
736 std::vector<std::string> entities;
737 entities.reserve(dagInit->arg_size());
738 for (auto *argName : dagInit->getArgNames()) {
739 if (!argName) {
740 PrintFatalError(
741 &def,
742 "operands to additional constraints can only be symbol references");
743 }
744 entities.emplace_back(argName->getValue());
745 }
746
747 ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
748 dagInit->getNameStr(), std::move(entities));
749 }
750 return ret;
751}
752
754 auto *results = def.getValueAsListInit("supplementalPatterns");
755 return results->size();
756}
757
759 auto *results = def.getValueAsListInit("supplementalPatterns");
760 return DagNode(cast<DagInit>(results->getElement(index)));
761}
762
763int Pattern::getBenefit() const {
764 // The initial benefit value is a heuristic with number of ops in the source
765 // pattern.
766 int initBenefit = getSourcePattern().getNumOps();
767 const DagInit *delta = def.getValueAsDag("benefitDelta");
768 if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
769 PrintFatalError(&def,
770 "The 'addBenefit' takes and only takes one integer value");
771 }
772 return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
773}
774
775std::vector<Pattern::IdentifierLine>
776Pattern::getLocation(bool forSourceOutput) const {
777 std::vector<std::pair<StringRef, unsigned>> result;
778 result.reserve(def.getLoc().size());
779 for (auto loc : def.getLoc()) {
780 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
781 assert(buf && "invalid source location");
782
783 StringRef bufferName =
784 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier();
785 // If we're emitting a generated file, we'd like to have some indication of
786 // where our patterns came from. However, LLVM's build rules use absolute
787 // paths as arguments to TableGen, and naively echoing such paths makes the
788 // contents of the generated source file depend on the build location,
789 // making MLIR builds substantially less reproducable. As a compromise, we
790 // trim absolute paths back to only the filename component.
791 if (forSourceOutput && llvm::sys::path::is_absolute(bufferName))
792 bufferName = llvm::sys::path::filename(bufferName);
793
794 result.emplace_back(bufferName,
795 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
796 }
797 return result;
798}
799
800void Pattern::verifyBind(bool result, StringRef symbolName) {
801 if (!result) {
802 auto err = formatv("symbol '{0}' bound more than once", symbolName);
803 PrintFatalError(&def, err);
804 }
805}
806
808 bool isSrcPattern) {
809 auto treeName = tree.getSymbol();
810 auto numTreeArgs = tree.getNumArgs();
811
812 if (tree.isNativeCodeCall()) {
813 if (!treeName.empty()) {
814 if (!isSrcPattern) {
815 LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: "
816 << treeName << '\n');
817 verifyBind(
818 infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
819 treeName);
820 } else {
821 PrintFatalError(&def,
822 formatv("binding symbol '{0}' to NativecodeCall in "
823 "MatchPattern is not supported",
824 treeName));
825 }
826 }
827
828 for (int i = 0; i != numTreeArgs; ++i) {
829 if (auto treeArg = tree.getArgAsNestedDag(i)) {
830 // This DAG node argument is a DAG node itself. Go inside recursively.
831 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
832 continue;
833 }
834
835 if (!isSrcPattern)
836 continue;
837
838 // We can only bind symbols to arguments in source pattern. Those
839 // symbols are referenced in result patterns.
840 auto treeArgName = tree.getArgName(i);
841
842 // `$_` is a special symbol meaning ignore the current argument.
843 if (!treeArgName.empty() && treeArgName != "_") {
844 DagLeaf leaf = tree.getArgAsLeaf(i);
845
846 // In (NativeCodeCall<"Foo($_self, $0, $1, $2, $3)"> I8Attr:$a, I8:$b,
847 // $c, I8Prop:$d),
848 if (leaf.isUnspecified()) {
849 // This is case of $c, a Value without any constraints.
850 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
851 } else if (leaf.isPropMatcher()) {
852 // This is case of $d, a binding to a certain property.
853 auto propConstraint = leaf.getAsPropConstraint();
854 if (propConstraint.getInterfaceType().empty()) {
855 PrintFatalError(&def,
856 formatv("binding symbol '{0}' in NativeCodeCall to "
857 "a property constraint without specifying "
858 "that constraint's type is unsupported",
859 treeArgName));
860 }
861 verifyBind(infoMap.bindProp(treeArgName, propConstraint),
862 treeArgName);
863 } else {
864 auto constraint = leaf.getAsConstraint();
865 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() ||
866 leaf.isConstantAttr() ||
867 constraint.getKind() == Constraint::Kind::CK_Attr;
868
869 if (isAttr) {
870 // This is case of $a, a binding to a certain attribute.
871 verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
872 continue;
873 }
874
875 // This is case of $b, a binding to a certain type.
876 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
877 }
878 }
879 }
880
881 return;
882 }
883
884 if (tree.isOperation()) {
885 auto &op = getDialectOp(tree);
886 auto numOpArgs = op.getNumArgs();
887 int numEither = 0;
888
889 // We need to exclude the trailing directives and `either` directive groups
890 // two operands of the operation.
891 int numDirectives = 0;
892 for (int i = numTreeArgs - 1; i >= 0; --i) {
893 if (auto dagArg = tree.getArgAsNestedDag(i)) {
894 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
895 ++numDirectives;
896 else if (dagArg.isEither())
897 ++numEither;
898 }
899 }
900
901 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
902 auto err =
903 formatv("op '{0}' argument number mismatch: "
904 "{1} in pattern vs. {2} in definition",
905 op.getOperationName(), numTreeArgs + numEither, numOpArgs);
906 PrintFatalError(&def, err);
907 }
908
909 // The name attached to the DAG node's operator is for representing the
910 // results generated from this op. It should be remembered as bound results.
911 if (!treeName.empty()) {
912 LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName
913 << '\n');
914 verifyBind(infoMap.bindOpResult(treeName, op), treeName);
915 }
916
917 // The operand in `either` DAG should be bound to the operation in the
918 // parent DagNode.
919 auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
920 int opArgIdx) {
921 for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
922 if (DagNode subTree = tree.getArgAsNestedDag(i)) {
923 collectBoundSymbols(subTree, infoMap, isSrcPattern);
924 } else {
925 auto argName = tree.getArgName(i);
926 if (!argName.empty() && argName != "_") {
927 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
928 argName);
929 }
930 }
931 }
932 };
933
934 // The operand in `variadic` DAG should be bound to the operation in the
935 // parent DagNode. The range index must be included as well to distinguish
936 // (potentially) repeating argName within the `variadic` DAG.
937 auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
938 int opArgIdx) {
939 auto treeName = tree.getSymbol();
940 if (!treeName.empty()) {
941 // If treeName is specified, bind to the full variadic operand_range.
942 verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
943 std::nullopt),
944 treeName);
945 }
946
947 for (int i = 0; i < tree.getNumArgs(); ++i) {
948 if (DagNode subTree = tree.getArgAsNestedDag(i)) {
949 collectBoundSymbols(subTree, infoMap, isSrcPattern);
950 } else {
951 auto argName = tree.getArgName(i);
952 if (!argName.empty() && argName != "_") {
953 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
954 /*variadicSubIndex=*/i),
955 argName);
956 }
957 }
958 }
959 };
960
961 for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
962 if (auto treeArg = tree.getArgAsNestedDag(i)) {
963 if (treeArg.isEither()) {
964 collectSymbolInEither(tree, treeArg, opArgIdx);
965 // `either` DAG is *flattened*. For example,
966 //
967 // (FooOp (either arg0, arg1), arg2)
968 //
969 // can be viewed as:
970 //
971 // (FooOp arg0, arg1, arg2)
972 ++opArgIdx;
973 } else if (treeArg.isVariadic()) {
974 collectSymbolInVariadic(tree, treeArg, opArgIdx);
975 } else {
976 // This DAG node argument is a DAG node itself. Go inside recursively.
977 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
978 }
979 continue;
980 }
981
982 if (isSrcPattern) {
983 // We can only bind symbols to op arguments in source pattern. Those
984 // symbols are referenced in result patterns.
985 auto treeArgName = tree.getArgName(i);
986 // `$_` is a special symbol meaning ignore the current argument.
987 if (!treeArgName.empty() && treeArgName != "_") {
988 LLVM_DEBUG(dbgs() << "found symbol bound to op argument: "
989 << treeArgName << '\n');
990 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
991 treeArgName);
992 }
993 }
994 }
995 return;
996 }
997
998 if (!treeName.empty()) {
999 PrintFatalError(
1000 &def, formatv("binding symbol '{0}' to non-operation/native code call "
1001 "unsupported right now",
1002 treeName));
1003 }
1004}
*if copies could not be generated due to yet-unimplemented cases. *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
std::string getConditionTemplate() const
Constraint getAsConstraint() const
Definition Pattern.cpp:77
bool isNativeCodeCall() const
Definition Pattern.cpp:65
bool isPropMatcher() const
Definition Pattern.cpp:55
int getNumReturnsOfNativeCode() const
Definition Pattern.cpp:117
ConstantAttr getAsConstantAttr() const
Definition Pattern.cpp:93
void print(raw_ostream &os) const
Definition Pattern.cpp:132
std::string getStringAttr() const
Definition Pattern.cpp:122
Property getAsProperty() const
Definition Pattern.cpp:88
bool isEnumCase() const
Definition Pattern.cpp:71
StringRef getNativeCodeTemplate() const
Definition Pattern.cpp:112
std::string getConditionTemplate() const
Definition Pattern.cpp:108
bool isConstantProp() const
Definition Pattern.cpp:73
PropConstraint getAsPropConstraint() const
Definition Pattern.cpp:83
ConstantProp getAsConstantProp() const
Definition Pattern.cpp:103
bool isUnspecified() const
Definition Pattern.cpp:41
EnumCase getAsEnumCase() const
Definition Pattern.cpp:98
bool isAttrMatcher() const
Definition Pattern.cpp:50
bool isOperandMatcher() const
Definition Pattern.cpp:45
bool isPropDefinition() const
Definition Pattern.cpp:60
bool isConstantAttr() const
Definition Pattern.cpp:69
bool isStringAttr() const
Definition Pattern.cpp:75
bool isReturnTypeDirective() const
Definition Pattern.cpp:217
bool isLocationDirective() const
Definition Pattern.cpp:212
bool isReplaceWithValue() const
Definition Pattern.cpp:207
DagNode getArgAsNestedDag(unsigned index) const
Definition Pattern.cpp:194
bool isOperation() const
Definition Pattern.cpp:147
DagLeaf getArgAsLeaf(unsigned index) const
Definition Pattern.cpp:198
int getNumReturnsOfNativeCode() const
Definition Pattern.cpp:160
StringRef getNativeCodeTemplate() const
Definition Pattern.cpp:153
void print(raw_ostream &os) const
Definition Pattern.cpp:232
int getNumOps() const
Definition Pattern.cpp:177
Operator & getDialectOp(RecordOperatorMap *mapper) const
Definition Pattern.cpp:169
bool isVariadic() const
Definition Pattern.cpp:227
bool isNativeCodeCall() const
Definition Pattern.cpp:141
bool isEither() const
Definition Pattern.cpp:222
bool isNestedDagArg(unsigned index) const
Definition Pattern.cpp:190
StringRef getSymbol() const
Definition Pattern.cpp:167
int getNumArgs() const
Definition Pattern.cpp:188
DagNode(const llvm::DagInit *node)
Definition Pattern.h:165
StringRef getArgName(unsigned index) const
Definition Pattern.cpp:203
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition Operator.h:77
int getNumResults() const
Returns the number of results this op produces.
Definition Operator.cpp:164
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Definition Operator.cpp:183
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
Definition Operator.cpp:357
int getBenefit() const
int getNumResultPatterns() const
Definition Pattern.cpp:688
DagNode getSourcePattern() const
Definition Pattern.cpp:684
const Operator & getSourceRootOp()
Definition Pattern.cpp:717
std::vector< AppliedConstraint > getConstraints() const
Definition Pattern.cpp:725
DagNode getResultPattern(unsigned index) const
Definition Pattern.cpp:693
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Definition Pattern.cpp:807
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
Definition Pattern.cpp:721
DagNode getSupplementalPattern(unsigned index) const
Definition Pattern.cpp:758
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
Definition Pattern.cpp:698
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
Definition Pattern.cpp:708
int getNumSupplementalPatterns() const
Definition Pattern.cpp:753
std::vector< IdentifierLine > getLocation(bool forSourceOutput=false) const
Definition Pattern.cpp:776
std::string getArgDecl(StringRef name) const
Definition Pattern.cpp:328
std::string getVarName(StringRef name) const
Definition Pattern.cpp:275
std::string getVarTypeStr(StringRef name) const
Definition Pattern.cpp:279
std::string getVarDecl(StringRef name) const
Definition Pattern.cpp:321
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
Definition Pattern.cpp:241
int count(StringRef key) const
Definition Pattern.cpp:601
const_iterator find(StringRef key) const
Definition Pattern.cpp:567
bool bindMultipleValues(StringRef symbol, int numValues)
Definition Pattern.cpp:544
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
Definition Pattern.cpp:490
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition Pattern.cpp:632
bool bindValues(StringRef symbol, int numValues=1)
Definition Pattern.cpp:532
bool bindAttr(StringRef symbol)
Definition Pattern.cpp:551
bool bindProp(StringRef symbol, const PropConstraint &constraint)
Definition Pattern.cpp:556
bool bindValue(StringRef symbol)
Definition Pattern.cpp:539
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
Definition Pattern.cpp:574
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
Definition Pattern.cpp:595
int getStaticValueCount(StringRef symbol) const
Definition Pattern.cpp:606
bool contains(StringRef symbol) const
Definition Pattern.cpp:563
BaseT::const_iterator const_iterator
Definition Pattern.h:504
bool bindOpResult(StringRef symbol, const Operator &op)
Definition Pattern.cpp:525
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition Pattern.cpp:617
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Definition Argument.h:63
DenseMap< const llvm::Record *, std::unique_ptr< Operator > > RecordOperatorMap
Definition Pattern.h:44
Include the generated interface declarations.