17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
27 using namespace tblgen;
41 return isa_and_nonnull<llvm::UnsetInit>(def);
46 return isSubClassOf(
"TypeConstraint");
51 return isSubClassOf(
"AttrConstraint");
55 return isSubClassOf(
"NativeCodeCall");
61 return isSubClassOf(
"EnumAttrCaseInfo");
68 "the DAG leaf must be operand or attribute");
69 return Constraint(cast<DefInit>(def)->getDef());
73 assert(
isConstantAttr() &&
"the DAG leaf must be constant attribute");
78 assert(
isEnumAttrCase() &&
"the DAG leaf must be an enum attribute case");
88 return cast<DefInit>(def)->getDef()->getValueAsString(
"expression");
93 return cast<DefInit>(def)->getDef()->getValueAsInt(
"numReturns");
97 assert(
isStringAttr() &&
"the DAG leaf must be string attribute");
98 return def->getAsUnquotedString();
100 bool DagLeaf::isSubClassOf(StringRef superclass)
const {
101 if (
auto *defInit = dyn_cast_or_null<DefInit>(def))
102 return defInit->getDef()->isSubClassOf(superclass);
116 if (
auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
117 return defInit->getDef()->isSubClassOf(
"NativeCodeCall");
129 return cast<DefInit>(node->getOperator())
131 ->getValueAsString(
"expression");
136 return cast<DefInit>(node->getOperator())
138 ->getValueAsInt(
"numReturns");
144 const Record *opDef = cast<DefInit>(node->getOperator())->
getDef();
145 auto [it, inserted] = mapper->try_emplace(opDef);
147 it->second = std::make_unique<Operator>(opDef);
155 for (
int i = 0, e =
getNumArgs(); i != e; ++i) {
157 count += child.getNumOps();
165 return isa<DagInit>(node->getArg(index));
169 return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
174 return DagLeaf(node->getArg(index));
178 return node->getArgNameStr(index);
182 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
183 return dagOpDef->getName() ==
"replaceWithValue";
187 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
188 return dagOpDef->getName() ==
"location";
192 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
193 return dagOpDef->getName() ==
"returnType";
197 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
198 return dagOpDef->getName() ==
"either";
202 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
203 return dagOpDef->getName() ==
"variadic";
217 auto [name, indexStr] = symbol.rsplit(
"__");
219 if (indexStr.consumeInteger(10, idx)) {
229 SymbolInfoMap::SymbolInfo::SymbolInfo(
231 std::optional<DagAndConstant> dagAndConstant)
232 : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
234 int SymbolInfoMap::SymbolInfo::getStaticValueCount()
const {
242 case Kind::MultipleValues:
245 llvm_unreachable(
"unknown kind");
249 return alternativeName ? *alternativeName : name.str();
253 LLVM_DEBUG(dbgs() <<
"getVarTypeStr for '" << name <<
"': ");
257 return op->getArg(getArgIndex())
259 ->attr.getStorageType()
262 return "::mlir::Attribute";
264 case Kind::Operand: {
267 return "::mlir::Operation::operand_range";
270 return "::mlir::Value";
272 case Kind::MultipleValues: {
273 return "::mlir::ValueRange";
277 return op->getQualCppClassName();
280 llvm_unreachable(
"unknown kind");
284 LLVM_DEBUG(dbgs() <<
"getVarDecl for '" << name <<
"': ");
285 std::string varInit = kind == Kind::Operand ?
"(op0->getOperands())" :
"";
287 formatv(
"{0} {1}{2};\n", getVarTypeStr(name),
getVarName(name), varInit));
291 LLVM_DEBUG(dbgs() <<
"getArgDecl for '" << name <<
"': ");
293 formatv(
"{0} &{1}", getVarTypeStr(name),
getVarName(name)));
296 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
297 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
298 LLVM_DEBUG(dbgs() <<
"getValueAndRangeUse for '" << name <<
"': ");
302 auto repl = formatv(fmt, name);
303 LLVM_DEBUG(dbgs() << repl <<
" (Attr)\n");
304 return std::string(repl);
306 case Kind::Operand: {
312 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
313 auto repl = formatv(fmt, name);
314 LLVM_DEBUG(dbgs() << repl <<
" (VariadicOperand)\n");
315 return std::string(repl);
317 auto repl = formatv(fmt, formatv(
"(*{0}.begin())", name));
318 LLVM_DEBUG(dbgs() << repl <<
" (SingleOperand)\n");
319 return std::string(repl);
326 std::string(formatv(
"{0}.getODSResults({1})", name, index));
327 if (!op->getResult(index).isVariadic())
328 v = std::string(formatv(
"(*{0}.begin())", v));
329 auto repl = formatv(fmt, v);
330 LLVM_DEBUG(dbgs() << repl <<
" (SingleResult)\n");
331 return std::string(repl);
336 if (op->getNumResults() == 0) {
337 LLVM_DEBUG(dbgs() << name <<
" (Op)\n");
338 return formatv(fmt, name);
344 values.reserve(op->getNumResults());
346 for (
int i = 0, e = op->getNumResults(); i < e; ++i) {
347 std::string v = std::string(formatv(
"{0}.getODSResults({1})", name, i));
348 if (!op->getResult(i).isVariadic()) {
349 v = std::string(formatv(
"(*{0}.begin())", v));
351 values.push_back(std::string(formatv(fmt, v)));
353 auto repl = llvm::join(values, separator);
354 LLVM_DEBUG(dbgs() << repl <<
" (VariadicResult)\n");
359 assert(op ==
nullptr);
360 auto repl = formatv(fmt, name);
361 LLVM_DEBUG(dbgs() << repl <<
" (Value)\n");
362 return std::string(repl);
364 case Kind::MultipleValues: {
365 assert(op ==
nullptr);
366 assert(index < getSize());
369 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
370 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
375 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
376 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
377 return std::string(repl);
380 llvm_unreachable(
"unknown kind");
383 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
384 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
385 LLVM_DEBUG(dbgs() <<
"getAllRangeUse for '" << name <<
"': ");
388 case Kind::Operand: {
389 assert(index < 0 &&
"only allowed for symbol bound to result");
390 auto repl = formatv(fmt, name);
391 LLVM_DEBUG(dbgs() << repl <<
" (Operand/Attr)\n");
392 return std::string(repl);
396 auto repl = formatv(fmt, formatv(
"{0}.getODSResults({1})", name, index));
397 LLVM_DEBUG(dbgs() << repl <<
" (SingleResult)\n");
398 return std::string(repl);
404 values.reserve(op->getNumResults());
406 for (
int i = 0, e = op->getNumResults(); i < e; ++i) {
407 values.push_back(std::string(
408 formatv(fmt, formatv(
"{0}.getODSResults({1})", name, i))));
410 auto repl = llvm::join(values, separator);
411 LLVM_DEBUG(dbgs() << repl <<
" (VariadicResult)\n");
415 assert(index < 0 &&
"only allowed for symbol bound to result");
416 assert(op ==
nullptr);
417 auto repl = formatv(fmt, formatv(
"{{{0}}", name));
418 LLVM_DEBUG(dbgs() << repl <<
" (Value)\n");
419 return std::string(repl);
421 case Kind::MultipleValues: {
422 assert(op ==
nullptr);
423 assert(index < getSize());
426 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
427 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
431 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
432 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
433 return std::string(repl);
436 llvm_unreachable(
"unknown kind");
441 std::optional<int> variadicSubIndex) {
443 if (name != symbol) {
444 auto error = formatv(
445 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
446 PrintFatalError(loc, error);
451 ? SymbolInfo::getAttr(&op, argIndex)
452 : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
454 std::string key = symbol.str();
455 if (symbolInfoMap.count(key)) {
457 if (symInfo.kind != SymbolInfo::Kind::Operand) {
463 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
468 symbolInfoMap.emplace(key, symInfo);
474 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
476 return symbolInfoMap.count(inserted->first) == 1;
487 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
488 return symbolInfoMap.count(inserted->first) == 1;
494 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
495 return symbolInfoMap.count(inserted->first) == 1;
499 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
500 return symbolInfoMap.count(inserted->first) == 1;
504 return find(symbol) != symbolInfoMap.end();
510 return symbolInfoMap.find(name);
516 std::optional<int> variadicSubIndex)
const {
518 key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
525 auto range = symbolInfoMap.equal_range(name);
527 for (
auto it = range.first; it != range.second; ++it)
528 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
531 return symbolInfoMap.end();
534 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
538 return symbolInfoMap.equal_range(name);
543 return symbolInfoMap.count(name);
548 if (name != symbol) {
554 return find(name)->second.getStaticValueCount();
559 const char *separator)
const {
563 auto it = symbolInfoMap.find(name.str());
564 if (it == symbolInfoMap.end()) {
565 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
566 PrintFatalError(loc, error);
569 return it->second.getValueAndRangeUse(name, index, fmt, separator);
573 const char *separator)
const {
577 auto it = symbolInfoMap.find(name.str());
578 if (it == symbolInfoMap.end()) {
579 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
580 PrintFatalError(loc, error);
583 return it->second.getAllRangeUse(name, index, fmt, separator);
589 for (
auto symbolInfoIt = symbolInfoMap.begin();
590 symbolInfoIt != symbolInfoMap.end();) {
591 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
592 auto startRange = range.first;
593 auto endRange = range.second;
595 auto operandName = symbolInfoIt->first;
596 int startSearchIndex = 0;
597 for (++startRange; startRange != endRange; ++startRange) {
600 for (
int i = startSearchIndex;; ++i) {
601 std::string alternativeName = operandName + std::to_string(i);
602 if (!usedNames.contains(alternativeName) &&
603 symbolInfoMap.count(alternativeName) == 0) {
604 usedNames.insert(alternativeName);
605 startRange->second.alternativeName = alternativeName;
606 startSearchIndex = i + 1;
613 symbolInfoIt = endRange;
622 : def(*def), recordOpMap(mapper) {}
625 return DagNode(def.getValueAsDag(
"sourcePattern"));
629 auto *results = def.getValueAsListInit(
"resultPatterns");
630 return results->size();
634 auto *results = def.getValueAsListInit(
"resultPatterns");
635 return DagNode(cast<DagInit>(results->getElement(index)));
639 LLVM_DEBUG(dbgs() <<
"start collecting source pattern bound symbols\n");
641 LLVM_DEBUG(dbgs() <<
"done collecting source pattern bound symbols\n");
643 LLVM_DEBUG(dbgs() <<
"start assigning alternative names for symbols\n");
645 LLVM_DEBUG(dbgs() <<
"done assigning alternative names for symbols\n");
649 LLVM_DEBUG(dbgs() <<
"start collecting result pattern bound symbols\n");
654 LLVM_DEBUG(dbgs() <<
"done collecting result pattern bound symbols\n");
666 auto *listInit = def.getValueAsListInit(
"constraints");
667 std::vector<AppliedConstraint> ret;
668 ret.reserve(listInit->size());
670 for (
auto *it : *listInit) {
671 auto *dagInit = dyn_cast<DagInit>(it);
673 PrintFatalError(&def,
"all elements in Pattern multi-entity "
674 "constraints should be DAG nodes");
676 std::vector<std::string> entities;
677 entities.reserve(dagInit->arg_size());
678 for (
auto *argName : dagInit->getArgNames()) {
682 "operands to additional constraints can only be symbol references");
684 entities.emplace_back(argName->getValue());
687 ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
688 dagInit->getNameStr(), std::move(entities));
694 auto *results = def.getValueAsListInit(
"supplementalPatterns");
695 return results->size();
699 auto *results = def.getValueAsListInit(
"supplementalPatterns");
700 return DagNode(cast<DagInit>(results->getElement(index)));
707 const DagInit *delta = def.getValueAsDag(
"benefitDelta");
708 if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
709 PrintFatalError(&def,
710 "The 'addBenefit' takes and only takes one integer value");
712 return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
716 std::vector<std::pair<StringRef, unsigned>> result;
717 result.reserve(def.getLoc().size());
718 for (
auto loc : def.getLoc()) {
719 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
720 assert(buf &&
"invalid source location");
722 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
723 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
728 void Pattern::verifyBind(
bool result, StringRef symbolName) {
730 auto err = formatv(
"symbol '{0}' bound more than once", symbolName);
731 PrintFatalError(&def, err);
741 if (!treeName.empty()) {
743 LLVM_DEBUG(dbgs() <<
"found symbol bound to NativeCodeCall: "
744 << treeName <<
'\n');
749 PrintFatalError(&def,
750 formatv(
"binding symbol '{0}' to NativecodeCall in "
751 "MatchPattern is not supported",
756 for (
int i = 0; i != numTreeArgs; ++i) {
771 if (!treeArgName.empty() && treeArgName !=
"_") {
777 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
782 constraint.getKind() == Constraint::Kind::CK_Attr;
786 verifyBind(infoMap.
bindAttr(treeArgName), treeArgName);
791 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
801 auto numOpArgs = op.getNumArgs();
806 int numDirectives = 0;
807 for (
int i = numTreeArgs - 1; i >= 0; --i) {
809 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
811 else if (dagArg.isEither())
816 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
818 formatv(
"op '{0}' argument number mismatch: "
819 "{1} in pattern vs. {2} in definition",
820 op.getOperationName(), numTreeArgs + numEither, numOpArgs);
821 PrintFatalError(&def, err);
826 if (!treeName.empty()) {
827 LLVM_DEBUG(dbgs() <<
"found symbol bound to op result: " << treeName
829 verifyBind(infoMap.
bindOpResult(treeName, op), treeName);
836 for (
int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
840 auto argName = tree.getArgName(i);
841 if (!argName.empty() && argName !=
"_") {
854 auto treeName = tree.getSymbol();
855 if (!treeName.empty()) {
862 for (
int i = 0; i < tree.getNumArgs(); ++i) {
866 auto argName = tree.getArgName(i);
867 if (!argName.empty() && argName !=
"_") {
876 for (
int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
877 if (
auto treeArg = tree.getArgAsNestedDag(i)) {
878 if (treeArg.isEither()) {
879 collectSymbolInEither(tree, treeArg, opArgIdx);
888 }
else if (treeArg.isVariadic()) {
889 collectSymbolInVariadic(tree, treeArg, opArgIdx);
900 auto treeArgName = tree.getArgName(i);
902 if (!treeArgName.empty() && treeArgName !=
"_") {
903 LLVM_DEBUG(dbgs() <<
"found symbol bound to op argument: "
904 << treeArgName <<
'\n');
905 verifyBind(infoMap.
bindOpArgument(tree, treeArgName, op, opArgIdx),
913 if (!treeName.empty()) {
915 &def, formatv(
"binding symbol '{0}' to non-operation/native code call "
916 "unsupported right now",
std::string getConditionTemplate() const
Constraint getAsConstraint() const
bool isNativeCodeCall() const
bool isEnumAttrCase() const
int getNumReturnsOfNativeCode() const
ConstantAttr getAsConstantAttr() const
void print(raw_ostream &os) const
std::string getStringAttr() const
StringRef getNativeCodeTemplate() const
std::string getConditionTemplate() const
bool isUnspecified() const
bool isAttrMatcher() const
EnumAttrCase getAsEnumAttrCase() const
bool isOperandMatcher() const
bool isConstantAttr() const
bool isStringAttr() const
bool isReturnTypeDirective() const
bool isLocationDirective() const
bool isReplaceWithValue() const
DagNode getArgAsNestedDag(unsigned index) const
DagLeaf getArgAsLeaf(unsigned index) const
int getNumReturnsOfNativeCode() const
StringRef getNativeCodeTemplate() const
void print(raw_ostream &os) const
Operator & getDialectOp(RecordOperatorMap *mapper) const
bool isNativeCodeCall() const
bool isNestedDagArg(unsigned index) const
StringRef getSymbol() const
DagNode(const llvm::DagInit *node)
StringRef getArgName(unsigned index) const
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
int getNumResults() const
Returns the number of results this op produces.
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
int getNumResultPatterns() const
std::vector< IdentifierLine > getLocation() const
DagNode getSourcePattern() const
const Operator & getSourceRootOp()
std::vector< AppliedConstraint > getConstraints() const
DagNode getResultPattern(unsigned index) const
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
DagNode getSupplementalPattern(unsigned index) const
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
int getNumSupplementalPatterns() const
std::string getArgDecl(StringRef name) const
std::string getVarName(StringRef name) const
std::string getVarTypeStr(StringRef name) const
std::string getVarDecl(StringRef name) const
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
int count(StringRef key) const
const_iterator find(StringRef key) const
void assignUniqueAlternativeNames()
bool bindMultipleValues(StringRef symbol, int numValues)
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
bool bindValues(StringRef symbol, int numValues=1)
bool bindAttr(StringRef symbol)
bool bindValue(StringRef symbol)
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
int getStaticValueCount(StringRef symbol) const
bool contains(StringRef symbol) const
BaseT::const_iterator const_iterator
bool bindOpResult(StringRef symbol, const Operator &op)
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.