17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
27 using namespace tblgen;
41 return isa_and_nonnull<llvm::UnsetInit>(def);
46 return isSubClassOf(
"TypeConstraint");
51 return isSubClassOf(
"AttrConstraint");
55 return isSubClassOf(
"NativeCodeCall");
61 return isSubClassOf(
"EnumAttrCaseInfo");
68 "the DAG leaf must be operand or attribute");
69 return Constraint(cast<DefInit>(def)->getDef());
73 assert(
isConstantAttr() &&
"the DAG leaf must be constant attribute");
78 assert(
isEnumAttrCase() &&
"the DAG leaf must be an enum attribute case");
88 return cast<DefInit>(def)->getDef()->getValueAsString(
"expression");
93 return cast<DefInit>(def)->getDef()->getValueAsInt(
"numReturns");
97 assert(
isStringAttr() &&
"the DAG leaf must be string attribute");
98 return def->getAsUnquotedString();
100 bool DagLeaf::isSubClassOf(StringRef superclass)
const {
101 if (
auto *defInit = dyn_cast_or_null<DefInit>(def))
102 return defInit->getDef()->isSubClassOf(superclass);
116 if (
auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
117 return defInit->getDef()->isSubClassOf(
"NativeCodeCall");
129 return cast<DefInit>(node->getOperator())
131 ->getValueAsString(
"expression");
136 return cast<DefInit>(node->getOperator())
138 ->getValueAsInt(
"numReturns");
144 const Record *opDef = cast<DefInit>(node->getOperator())->
getDef();
145 auto [it, inserted] = mapper->try_emplace(opDef);
147 it->second = std::make_unique<Operator>(opDef);
155 for (
int i = 0, e =
getNumArgs(); i != e; ++i) {
157 count += child.getNumOps();
165 return isa<DagInit>(node->getArg(index));
169 return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
174 return DagLeaf(node->getArg(index));
178 return node->getArgNameStr(index);
182 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
183 return dagOpDef->getName() ==
"replaceWithValue";
187 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
188 return dagOpDef->getName() ==
"location";
192 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
193 return dagOpDef->getName() ==
"returnType";
197 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
198 return dagOpDef->getName() ==
"either";
202 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
203 return dagOpDef->getName() ==
"variadic";
217 auto [name, indexStr] = symbol.rsplit(
"__");
219 if (indexStr.consumeInteger(10, idx)) {
229 SymbolInfoMap::SymbolInfo::SymbolInfo(
231 std::optional<DagAndConstant> dagAndConstant)
232 : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
234 int SymbolInfoMap::SymbolInfo::getStaticValueCount()
const {
242 case Kind::MultipleValues:
245 llvm_unreachable(
"unknown kind");
249 return alternativeName ? *alternativeName : name.str();
253 LLVM_DEBUG(dbgs() <<
"getVarTypeStr for '" << name <<
"': ");
257 return cast<NamedAttribute *>(op->getArg(getArgIndex()))
258 ->attr.getStorageType()
261 return "::mlir::Attribute";
263 case Kind::Operand: {
266 return "::mlir::Operation::operand_range";
269 return "::mlir::Value";
271 case Kind::MultipleValues: {
272 return "::mlir::ValueRange";
276 return op->getQualCppClassName();
279 llvm_unreachable(
"unknown kind");
283 LLVM_DEBUG(dbgs() <<
"getVarDecl for '" << name <<
"': ");
284 std::string varInit = kind == Kind::Operand ?
"(op0->getOperands())" :
"";
286 formatv(
"{0} {1}{2};\n", getVarTypeStr(name),
getVarName(name), varInit));
290 LLVM_DEBUG(dbgs() <<
"getArgDecl for '" << name <<
"': ");
292 formatv(
"{0} &{1}", getVarTypeStr(name),
getVarName(name)));
295 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
296 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
297 LLVM_DEBUG(dbgs() <<
"getValueAndRangeUse for '" << name <<
"': ");
301 auto repl = formatv(fmt, name);
302 LLVM_DEBUG(dbgs() << repl <<
" (Attr)\n");
303 return std::string(repl);
305 case Kind::Operand: {
307 auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
311 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
312 auto repl = formatv(fmt, name);
313 LLVM_DEBUG(dbgs() << repl <<
" (VariadicOperand)\n");
314 return std::string(repl);
316 auto repl = formatv(fmt, formatv(
"(*{0}.begin())", name));
317 LLVM_DEBUG(dbgs() << repl <<
" (SingleOperand)\n");
318 return std::string(repl);
325 std::string(formatv(
"{0}.getODSResults({1})", name, index));
326 if (!op->getResult(index).isVariadic())
327 v = std::string(formatv(
"(*{0}.begin())", v));
328 auto repl = formatv(fmt, v);
329 LLVM_DEBUG(dbgs() << repl <<
" (SingleResult)\n");
330 return std::string(repl);
335 if (op->getNumResults() == 0) {
336 LLVM_DEBUG(dbgs() << name <<
" (Op)\n");
337 return formatv(fmt, name);
343 values.reserve(op->getNumResults());
345 for (
int i = 0, e = op->getNumResults(); i < e; ++i) {
346 std::string v = std::string(formatv(
"{0}.getODSResults({1})", name, i));
347 if (!op->getResult(i).isVariadic()) {
348 v = std::string(formatv(
"(*{0}.begin())", v));
350 values.push_back(std::string(formatv(fmt, v)));
352 auto repl = llvm::join(values, separator);
353 LLVM_DEBUG(dbgs() << repl <<
" (VariadicResult)\n");
358 assert(op ==
nullptr);
359 auto repl = formatv(fmt, name);
360 LLVM_DEBUG(dbgs() << repl <<
" (Value)\n");
361 return std::string(repl);
363 case Kind::MultipleValues: {
364 assert(op ==
nullptr);
365 assert(index < getSize());
368 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
369 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
374 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
375 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
376 return std::string(repl);
379 llvm_unreachable(
"unknown kind");
382 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
383 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
384 LLVM_DEBUG(dbgs() <<
"getAllRangeUse for '" << name <<
"': ");
387 case Kind::Operand: {
388 assert(index < 0 &&
"only allowed for symbol bound to result");
389 auto repl = formatv(fmt, name);
390 LLVM_DEBUG(dbgs() << repl <<
" (Operand/Attr)\n");
391 return std::string(repl);
395 auto repl = formatv(fmt, formatv(
"{0}.getODSResults({1})", name, index));
396 LLVM_DEBUG(dbgs() << repl <<
" (SingleResult)\n");
397 return std::string(repl);
403 values.reserve(op->getNumResults());
405 for (
int i = 0, e = op->getNumResults(); i < e; ++i) {
406 values.push_back(std::string(
407 formatv(fmt, formatv(
"{0}.getODSResults({1})", name, i))));
409 auto repl = llvm::join(values, separator);
410 LLVM_DEBUG(dbgs() << repl <<
" (VariadicResult)\n");
414 assert(index < 0 &&
"only allowed for symbol bound to result");
415 assert(op ==
nullptr);
416 auto repl = formatv(fmt, formatv(
"{{{0}}", name));
417 LLVM_DEBUG(dbgs() << repl <<
" (Value)\n");
418 return std::string(repl);
420 case Kind::MultipleValues: {
421 assert(op ==
nullptr);
422 assert(index < getSize());
425 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
426 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
430 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
431 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
432 return std::string(repl);
435 llvm_unreachable(
"unknown kind");
440 std::optional<int> variadicSubIndex) {
442 if (name != symbol) {
443 auto error = formatv(
444 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
445 PrintFatalError(loc, error);
449 isa<NamedAttribute *>(op.
getArg(argIndex))
450 ? SymbolInfo::getAttr(&op, argIndex)
451 : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
453 std::string key = symbol.str();
454 if (symbolInfoMap.count(key)) {
456 if (symInfo.kind != SymbolInfo::Kind::Operand) {
462 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
467 symbolInfoMap.emplace(key, symInfo);
473 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
475 return symbolInfoMap.count(inserted->first) == 1;
486 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
487 return symbolInfoMap.count(inserted->first) == 1;
493 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
494 return symbolInfoMap.count(inserted->first) == 1;
498 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
499 return symbolInfoMap.count(inserted->first) == 1;
503 return find(symbol) != symbolInfoMap.end();
509 return symbolInfoMap.find(name);
515 std::optional<int> variadicSubIndex)
const {
517 key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
524 auto range = symbolInfoMap.equal_range(name);
526 for (
auto it = range.first; it != range.second; ++it)
527 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
530 return symbolInfoMap.end();
533 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
537 return symbolInfoMap.equal_range(name);
542 return symbolInfoMap.count(name);
547 if (name != symbol) {
553 return find(name)->second.getStaticValueCount();
558 const char *separator)
const {
562 auto it = symbolInfoMap.find(name.str());
563 if (it == symbolInfoMap.end()) {
564 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
565 PrintFatalError(loc, error);
568 return it->second.getValueAndRangeUse(name, index, fmt, separator);
572 const char *separator)
const {
576 auto it = symbolInfoMap.find(name.str());
577 if (it == symbolInfoMap.end()) {
578 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
579 PrintFatalError(loc, error);
582 return it->second.getAllRangeUse(name, index, fmt, separator);
588 for (
auto symbolInfoIt = symbolInfoMap.begin();
589 symbolInfoIt != symbolInfoMap.end();) {
590 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
591 auto startRange = range.first;
592 auto endRange = range.second;
594 auto operandName = symbolInfoIt->first;
595 int startSearchIndex = 0;
596 for (++startRange; startRange != endRange; ++startRange) {
599 for (
int i = startSearchIndex;; ++i) {
600 std::string alternativeName = operandName + std::to_string(i);
601 if (!usedNames.contains(alternativeName) &&
602 symbolInfoMap.count(alternativeName) == 0) {
603 usedNames.insert(alternativeName);
604 startRange->second.alternativeName = alternativeName;
605 startSearchIndex = i + 1;
612 symbolInfoIt = endRange;
621 : def(*def), recordOpMap(mapper) {}
624 return DagNode(def.getValueAsDag(
"sourcePattern"));
628 auto *results = def.getValueAsListInit(
"resultPatterns");
629 return results->size();
633 auto *results = def.getValueAsListInit(
"resultPatterns");
634 return DagNode(cast<DagInit>(results->getElement(index)));
638 LLVM_DEBUG(dbgs() <<
"start collecting source pattern bound symbols\n");
640 LLVM_DEBUG(dbgs() <<
"done collecting source pattern bound symbols\n");
642 LLVM_DEBUG(dbgs() <<
"start assigning alternative names for symbols\n");
644 LLVM_DEBUG(dbgs() <<
"done assigning alternative names for symbols\n");
648 LLVM_DEBUG(dbgs() <<
"start collecting result pattern bound symbols\n");
653 LLVM_DEBUG(dbgs() <<
"done collecting result pattern bound symbols\n");
665 auto *listInit = def.getValueAsListInit(
"constraints");
666 std::vector<AppliedConstraint> ret;
667 ret.reserve(listInit->size());
669 for (
auto *it : *listInit) {
670 auto *dagInit = dyn_cast<DagInit>(it);
672 PrintFatalError(&def,
"all elements in Pattern multi-entity "
673 "constraints should be DAG nodes");
675 std::vector<std::string> entities;
676 entities.reserve(dagInit->arg_size());
677 for (
auto *argName : dagInit->getArgNames()) {
681 "operands to additional constraints can only be symbol references");
683 entities.emplace_back(argName->getValue());
686 ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
687 dagInit->getNameStr(), std::move(entities));
693 auto *results = def.getValueAsListInit(
"supplementalPatterns");
694 return results->size();
698 auto *results = def.getValueAsListInit(
"supplementalPatterns");
699 return DagNode(cast<DagInit>(results->getElement(index)));
706 const DagInit *delta = def.getValueAsDag(
"benefitDelta");
707 if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
708 PrintFatalError(&def,
709 "The 'addBenefit' takes and only takes one integer value");
711 return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
715 std::vector<std::pair<StringRef, unsigned>> result;
716 result.reserve(def.getLoc().size());
717 for (
auto loc : def.getLoc()) {
718 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
719 assert(buf &&
"invalid source location");
721 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
722 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
727 void Pattern::verifyBind(
bool result, StringRef symbolName) {
729 auto err = formatv(
"symbol '{0}' bound more than once", symbolName);
730 PrintFatalError(&def, err);
740 if (!treeName.empty()) {
742 LLVM_DEBUG(dbgs() <<
"found symbol bound to NativeCodeCall: "
743 << treeName <<
'\n');
748 PrintFatalError(&def,
749 formatv(
"binding symbol '{0}' to NativecodeCall in "
750 "MatchPattern is not supported",
755 for (
int i = 0; i != numTreeArgs; ++i) {
770 if (!treeArgName.empty() && treeArgName !=
"_") {
776 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
781 constraint.getKind() == Constraint::Kind::CK_Attr;
785 verifyBind(infoMap.
bindAttr(treeArgName), treeArgName);
790 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
800 auto numOpArgs = op.getNumArgs();
805 int numDirectives = 0;
806 for (
int i = numTreeArgs - 1; i >= 0; --i) {
808 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
810 else if (dagArg.isEither())
815 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
817 formatv(
"op '{0}' argument number mismatch: "
818 "{1} in pattern vs. {2} in definition",
819 op.getOperationName(), numTreeArgs + numEither, numOpArgs);
820 PrintFatalError(&def, err);
825 if (!treeName.empty()) {
826 LLVM_DEBUG(dbgs() <<
"found symbol bound to op result: " << treeName
828 verifyBind(infoMap.
bindOpResult(treeName, op), treeName);
835 for (
int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
839 auto argName = tree.getArgName(i);
840 if (!argName.empty() && argName !=
"_") {
853 auto treeName = tree.getSymbol();
854 if (!treeName.empty()) {
861 for (
int i = 0; i < tree.getNumArgs(); ++i) {
865 auto argName = tree.getArgName(i);
866 if (!argName.empty() && argName !=
"_") {
875 for (
int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
876 if (
auto treeArg = tree.getArgAsNestedDag(i)) {
877 if (treeArg.isEither()) {
878 collectSymbolInEither(tree, treeArg, opArgIdx);
887 }
else if (treeArg.isVariadic()) {
888 collectSymbolInVariadic(tree, treeArg, opArgIdx);
899 auto treeArgName = tree.getArgName(i);
901 if (!treeArgName.empty() && treeArgName !=
"_") {
902 LLVM_DEBUG(dbgs() <<
"found symbol bound to op argument: "
903 << treeArgName <<
'\n');
904 verifyBind(infoMap.
bindOpArgument(tree, treeArgName, op, opArgIdx),
912 if (!treeName.empty()) {
914 &def, formatv(
"binding symbol '{0}' to non-operation/native code call "
915 "unsupported right now",
std::string getConditionTemplate() const
Constraint getAsConstraint() const
bool isNativeCodeCall() const
bool isEnumAttrCase() const
int getNumReturnsOfNativeCode() const
ConstantAttr getAsConstantAttr() const
void print(raw_ostream &os) const
std::string getStringAttr() const
StringRef getNativeCodeTemplate() const
std::string getConditionTemplate() const
bool isUnspecified() const
bool isAttrMatcher() const
EnumAttrCase getAsEnumAttrCase() const
bool isOperandMatcher() const
bool isConstantAttr() const
bool isStringAttr() const
bool isReturnTypeDirective() const
bool isLocationDirective() const
bool isReplaceWithValue() const
DagNode getArgAsNestedDag(unsigned index) const
DagLeaf getArgAsLeaf(unsigned index) const
int getNumReturnsOfNativeCode() const
StringRef getNativeCodeTemplate() const
void print(raw_ostream &os) const
Operator & getDialectOp(RecordOperatorMap *mapper) const
bool isNativeCodeCall() const
bool isNestedDagArg(unsigned index) const
StringRef getSymbol() const
DagNode(const llvm::DagInit *node)
StringRef getArgName(unsigned index) const
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
int getNumResults() const
Returns the number of results this op produces.
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
int getNumResultPatterns() const
std::vector< IdentifierLine > getLocation() const
DagNode getSourcePattern() const
const Operator & getSourceRootOp()
std::vector< AppliedConstraint > getConstraints() const
DagNode getResultPattern(unsigned index) const
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
DagNode getSupplementalPattern(unsigned index) const
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
int getNumSupplementalPatterns() const
std::string getArgDecl(StringRef name) const
std::string getVarName(StringRef name) const
std::string getVarTypeStr(StringRef name) const
std::string getVarDecl(StringRef name) const
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
int count(StringRef key) const
const_iterator find(StringRef key) const
void assignUniqueAlternativeNames()
bool bindMultipleValues(StringRef symbol, int numValues)
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
bool bindValues(StringRef symbol, int numValues=1)
bool bindAttr(StringRef symbol)
bool bindValue(StringRef symbol)
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
int getStaticValueCount(StringRef symbol) const
bool contains(StringRef symbol) const
BaseT::const_iterator const_iterator
bool bindOpResult(StringRef symbol, const Operator &op)
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.