17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
27 using namespace tblgen;
36 return isa_and_nonnull<llvm::UnsetInit>(def);
41 return isSubClassOf(
"TypeConstraint");
46 return isSubClassOf(
"AttrConstraint");
50 return isSubClassOf(
"NativeCodeCall");
56 return isSubClassOf(
"EnumAttrCaseInfo");
63 "the DAG leaf must be operand or attribute");
64 return Constraint(cast<llvm::DefInit>(def)->getDef());
68 assert(
isConstantAttr() &&
"the DAG leaf must be constant attribute");
73 assert(
isEnumAttrCase() &&
"the DAG leaf must be an enum attribute case");
83 return cast<llvm::DefInit>(def)->getDef()->getValueAsString(
"expression");
88 return cast<llvm::DefInit>(def)->getDef()->getValueAsInt(
"numReturns");
92 assert(
isStringAttr() &&
"the DAG leaf must be string attribute");
93 return def->getAsUnquotedString();
// Reports whether the record behind this leaf derives from the TableGen class
// named `superclass`.
// NOTE(review): the rest of this function (the path taken when `def` is not a
// DefInit, and the closing brace) is missing from this listing.
95 bool DagLeaf::isSubClassOf(StringRef superclass)
const {
// dyn_cast_or_null accepts a null `def`, so a leaf without an initializer
// skips the record lookup instead of dereferencing a null pointer.
96 if (
auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
97 return defInit->getDef()->isSubClassOf(superclass);
111 if (
auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
112 return defInit->getDef()->isSubClassOf(
"NativeCodeCall");
124 return cast<llvm::DefInit>(node->getOperator())
126 ->getValueAsString(
"expression");
131 return cast<llvm::DefInit>(node->getOperator())
133 ->getValueAsInt(
"numReturns");
139 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->
getDef();
140 auto it = mapper->find(opDef);
141 if (it != mapper->end())
143 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
151 for (
int i = 0, e =
getNumArgs(); i != e; ++i) {
153 count += child.getNumOps();
161 return isa<llvm::DagInit>(node->getArg(index));
165 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
170 return DagLeaf(node->getArg(index));
174 return node->getArgNameStr(index);
// The five fragments below share one shape: take this DAG node's operator,
// cast<> it to a DefInit (cast<> asserts rather than returning null on a
// mismatch), fetch the record, and compare the record name to a fixed string.
// NOTE(review): the enclosing function headers are missing from this listing.

// Operator record is named "replaceWithValue".
178 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
179 return dagOpDef->getName() ==
"replaceWithValue";
// Operator record is named "location".
183 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
184 return dagOpDef->getName() ==
"location";
// Operator record is named "returnType".
188 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
189 return dagOpDef->getName() ==
"returnType";
// Operator record is named "either".
193 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
194 return dagOpDef->getName() ==
"either";
// Operator record is named "variadic".
198 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
199 return dagOpDef->getName() ==
"variadic";
213 auto [name, indexStr] = symbol.rsplit(
"__");
215 if (indexStr.consumeInteger(10, idx)) {
// Constructor: stores `op` and `kind` as given and takes ownership of the
// optional `dagAndConstant` by moving it into the member of the same name.
// NOTE(review): the parameter line declaring `op` and `kind` is missing from
// this listing.
225 SymbolInfoMap::SymbolInfo::SymbolInfo(
227 std::optional<DagAndConstant> dagAndConstant)
228 : op(op), kind(kind), dagAndConstant(std::move(dagAndConstant)) {}
230 int SymbolInfoMap::SymbolInfo::getStaticValueCount()
const {
238 case Kind::MultipleValues:
241 llvm_unreachable(
"unknown kind");
245 return alternativeName ? *alternativeName : name.str();
// Fragment that selects the C++ type spelling for a bound symbol's variable.
// NOTE(review): the function header, the switch statement and several case
// labels/conditions are missing from this listing; only the visible returns
// are described.
249 LLVM_DEBUG(llvm::dbgs() <<
"getVarTypeStr for '" << name <<
"': ");
// One path takes the storage type of the attribute found at this symbol's
// argument index on `op`.
253 return op->getArg(getArgIndex())
255 ->attr.getStorageType()
// Another path falls back to the generic attribute type.
258 return "::mlir::Attribute";
260 case Kind::Operand: {
// Two spellings are visible for operands: a range and a single value. The
// condition choosing between them is not visible here.
263 return "::mlir::Operation::operand_range";
266 return "::mlir::Value";
268 case Kind::MultipleValues: {
269 return "::mlir::ValueRange";
// A further path (case label not visible) uses the op's qualified C++ class.
273 return op->getQualCppClassName();
276 llvm_unreachable(
"unknown kind");
// Fragment that builds a variable declaration statement for a bound symbol.
// NOTE(review): the function header and the statement that consumes the
// formatv result are missing from this listing.
280 LLVM_DEBUG(llvm::dbgs() <<
"getVarDecl for '" << name <<
"': ");
// Only Operand-kind symbols get initializer text; every other kind is declared
// without one.
281 std::string varInit = kind == Kind::Operand ?
"(op0->getOperands())" :
"";
// Emitted text has the shape `<type> <variable><initializer>;` plus a newline.
283 formatv(
"{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
// Fragment that builds a by-reference parameter declaration for a bound
// symbol, with the shape `<type> &<variable>`.
// NOTE(review): the function header and the statement that consumes the
// formatv result are missing from this listing.
287 LLVM_DEBUG(llvm::dbgs() <<
"getArgDecl for '" << name <<
"': ");
289 formatv(
"{0} &{1}", getVarTypeStr(name), getVarName(name)));
// Produces the C++ expression text that refers to this symbol, by substituting
// an expression derived from `name` into the caller-supplied `fmt`.
//   name      - variable name to reference in the emitted expression
//   index     - result index used in the `getODSResults(...)` / `[...]` forms
//   fmt       - formatv template applied to each produced expression
//   separator - joiner used when several expressions are emitted
// Each path also prints the chosen expression and a tag to the debug stream.
// NOTE(review): the switch header and several branch conditions are missing
// from this listing, so the conditions guarding some paths are not documented.
292 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
293 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
294 LLVM_DEBUG(llvm::dbgs() <<
"getValueAndRangeUse for '" << name <<
"': ");
// Attribute path (per the debug tag): the name is used unchanged.
298 auto repl = formatv(fmt, name);
299 LLVM_DEBUG(llvm::dbgs() << repl <<
" (Attr)\n");
300 return std::string(repl);
302 case Kind::Operand: {
// A variable-length operand with no variadic sub-index is referenced as the
// whole variable.
308 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
309 auto repl = formatv(fmt, name);
310 LLVM_DEBUG(llvm::dbgs() << repl <<
" (VariadicOperand)\n");
311 return std::string(repl);
// Otherwise the first element of the variable is dereferenced.
313 auto repl = formatv(fmt, formatv(
"(*{0}.begin())", name));
314 LLVM_DEBUG(llvm::dbgs() << repl <<
" (SingleOperand)\n");
315 return std::string(repl);
// Single-result path: start from the ODS result group at `index`.
322 std::string(formatv(
"{0}.getODSResults({1})", name, index));
// Then (under a condition not visible here) take the group's first element.
324 v = std::string(formatv(
"(*{0}.begin())", v));
325 auto repl = formatv(fmt, v);
326 LLVM_DEBUG(llvm::dbgs() << repl <<
" (SingleResult)\n");
327 return std::string(repl);
// "Op" path: the name is used unchanged. Unlike the other paths, the formatv
// result is returned directly rather than through an explicit std::string.
333 LLVM_DEBUG(llvm::dbgs() << name <<
" (Op)\n");
334 return formatv(fmt, name);
// Multi-result path: one expression per result index `i` (loop header not
// visible), collected into `values` and joined with `separator`.
343 std::string v = std::string(formatv(
"{0}.getODSResults({1})", name, i));
345 v = std::string(formatv(
"(*{0}.begin())", v));
347 values.push_back(std::string(formatv(fmt, v)));
349 auto repl = llvm::join(values, separator);
350 LLVM_DEBUG(llvm::dbgs() << repl <<
" (VariadicResult)\n");
// Value path: such symbols carry no op, and the name is used unchanged.
355 assert(op ==
nullptr);
356 auto repl = formatv(fmt, name);
357 LLVM_DEBUG(llvm::dbgs() << repl <<
" (Value)\n");
358 return std::string(repl);
360 case Kind::MultipleValues: {
361 assert(op ==
nullptr);
362 assert(index < getSize());
// Two forms are visible: one indexed element, or a begin/end iterator pair.
// The condition selecting between them is not visible here.
365 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
366 LLVM_DEBUG(llvm::dbgs() << repl <<
" (MultipleValues)\n");
371 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
372 LLVM_DEBUG(llvm::dbgs() << repl <<
" (MultipleValues)\n");
373 return std::string(repl);
376 llvm_unreachable(
"unknown kind");
// Produces the C++ expression text that refers to this symbol as a range, by
// substituting an expression derived from `name` into `fmt`.
//   name      - variable name to reference in the emitted expression
//   index     - result index; the Operand and Value paths assert it is negative
//   fmt       - formatv template applied to each produced expression
//   separator - joiner used when several expressions are emitted
// Each path also prints the chosen expression and a tag to the debug stream.
// NOTE(review): the switch header and several case labels/conditions are
// missing from this listing.
379 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
380 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
381 LLVM_DEBUG(llvm::dbgs() <<
"getAllRangeUse for '" << name <<
"': ");
384 case Kind::Operand: {
// A non-negative index is only meaningful for result-bound symbols.
385 assert(index < 0 &&
"only allowed for symbol bound to result");
386 auto repl = formatv(fmt, name);
387 LLVM_DEBUG(llvm::dbgs() << repl <<
" (Operand/Attr)\n");
388 return std::string(repl);
// Single-result path: the whole ODS result group at `index`, with no
// dereference of its first element.
392 auto repl = formatv(fmt, formatv(
"{0}.getODSResults({1})", name, index));
393 LLVM_DEBUG(llvm::dbgs() << repl <<
" (SingleResult)\n");
394 return std::string(repl);
// Multi-result path: one result group per index `i` (loop header not
// visible), joined with `separator`.
403 values.push_back(std::string(
404 formatv(fmt, formatv(
"{0}.getODSResults({1})", name, i))));
406 auto repl = llvm::join(values, separator);
407 LLVM_DEBUG(llvm::dbgs() << repl <<
" (VariadicResult)\n");
// Value path: index must be negative and the symbol carries no op.
411 assert(index < 0 &&
"only allowed for symbol bound to result");
412 assert(op ==
nullptr);
// `{{` is formatv's escape for a literal `{`, so this yields `{<name>}`.
// NOTE(review): presumably a braced list wrapping the single value; confirm.
413 auto repl = formatv(fmt, formatv(
"{{{0}}", name));
414 LLVM_DEBUG(llvm::dbgs() << repl <<
" (Value)\n");
415 return std::string(repl);
417 case Kind::MultipleValues: {
418 assert(op ==
nullptr);
419 assert(index < getSize());
// Two forms are visible: one indexed element, or a begin/end iterator pair.
// The condition selecting between them is not visible here.
422 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
423 LLVM_DEBUG(llvm::dbgs() << repl <<
" (MultipleValues)\n");
427 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
428 LLVM_DEBUG(llvm::dbgs() << repl <<
" (MultipleValues)\n");
429 return std::string(repl);
432 llvm_unreachable(
"unknown kind");
437 std::optional<int> variadicSubIndex) {
439 if (name != symbol) {
440 auto error = formatv(
441 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
442 PrintFatalError(loc, error);
447 ? SymbolInfo::getAttr(&op, argIndex)
448 : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
450 std::string key = symbol.str();
451 if (symbolInfoMap.count(key)) {
453 if (symInfo.kind != SymbolInfo::Kind::Operand) {
459 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
464 symbolInfoMap.emplace(key, symInfo);
// Four bind fragments with one pattern: add an entry for the symbol, then
// report whether that key now has exactly one entry. The result is therefore
// false when the key was already present before this call.
// `inserted` is dereferenced as an iterator (`inserted->first`), so `emplace`
// here returns an iterator rather than an (iterator, bool) pair.
// NOTE(review): presumably `symbolInfoMap` allows duplicate keys; confirm
// against its declaration. The function headers are missing from this listing.

// Entry describing a result of `op`.
470 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
472 return symbolInfoMap.count(inserted->first) == 1;
// Entry describing a single value.
483 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
484 return symbolInfoMap.count(inserted->first) == 1;
// Entry describing `numValues` values (the `auto inserted =` line is not
// visible in this listing).
490 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
491 return symbolInfoMap.count(inserted->first) == 1;
// Entry describing an attribute.
495 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
496 return symbolInfoMap.count(inserted->first) == 1;
500 return find(symbol) != symbolInfoMap.end();
506 return symbolInfoMap.find(name);
512 std::optional<int> variadicSubIndex)
const {
514 key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
521 auto range = symbolInfoMap.equal_range(name);
523 for (
auto it = range.first; it != range.second; ++it)
524 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
527 return symbolInfoMap.end();
530 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
534 return symbolInfoMap.equal_range(name);
539 return symbolInfoMap.count(name);
544 if (name != symbol) {
550 return find(name)->second.getStaticValueCount();
// Tail of a map-level lookup wrapper: find the entry for `name`, fail fatally
// if there is none, otherwise delegate to that entry's getValueAndRangeUse.
// NOTE(review): the start of the signature and the lines that derive `name`
// and `index` from `symbol` are missing from this listing.
555 const char *separator)
const {
559 auto it = symbolInfoMap.find(name.str());
560 if (it == symbolInfoMap.end()) {
// The diagnostic quotes the caller's original `symbol`, not the derived `name`.
561 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
562 PrintFatalError(loc, error);
565 return it->second.getValueAndRangeUse(name, index, fmt, separator);
// Tail of a map-level lookup wrapper: find the entry for `name`, fail fatally
// if there is none, otherwise delegate to that entry's getAllRangeUse.
// NOTE(review): the start of the signature and the lines that derive `name`
// and `index` from `symbol` are missing from this listing.
569 const char *separator)
const {
573 auto it = symbolInfoMap.find(name.str());
574 if (it == symbolInfoMap.end()) {
// The diagnostic quotes the caller's original `symbol`, not the derived `name`.
575 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
576 PrintFatalError(loc, error);
579 return it->second.getAllRangeUse(name, index, fmt, separator);
// Walks `symbolInfoMap` one group of equal keys at a time and gives every
// entry after the first in each group a distinct `alternativeName`.
// NOTE(review): the function header, the declaration of `usedNames` and the
// statement that leaves the innermost loop are missing from this listing.
// The loop header has no increment; the iterator is advanced at the bottom.
585 for (
auto symbolInfoIt = symbolInfoMap.begin();
586 symbolInfoIt != symbolInfoMap.end();) {
587 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
588 auto startRange = range.first;
589 auto endRange = range.second;
// The shared key is the base for every generated name in this group.
591 auto operandName = symbolInfoIt->first;
592 int startSearchIndex = 0;
// Pre-incrementing skips the group's first entry, which gets no alternative.
593 for (++startRange; startRange != endRange; ++startRange) {
// Try suffixes in increasing order until an unused name is found.
596 for (
int i = startSearchIndex;; ++i) {
597 std::string alternativeName = operandName + std::to_string(i);
// A candidate is rejected if it was already handed out or is itself a key.
598 if (!usedNames.contains(alternativeName) &&
599 symbolInfoMap.count(alternativeName) == 0) {
600 usedNames.insert(alternativeName);
601 startRange->second.alternativeName = alternativeName;
// The next entry in this group resumes the search after this suffix.
602 startSearchIndex = i + 1;
// Jump to the first entry of the next key group.
609 symbolInfoIt = endRange;
618 : def(*def), recordOpMap(mapper) {}
621 return DagNode(def.getValueAsDag(
"sourcePattern"));
625 auto *results = def.getValueAsListInit(
"resultPatterns");
626 return results->size();
630 auto *results = def.getValueAsListInit(
"resultPatterns");
631 return DagNode(cast<llvm::DagInit>(results->getElement(index)));
635 LLVM_DEBUG(llvm::dbgs() <<
"start collecting source pattern bound symbols\n");
637 LLVM_DEBUG(llvm::dbgs() <<
"done collecting source pattern bound symbols\n");
639 LLVM_DEBUG(llvm::dbgs() <<
"start assigning alternative names for symbols\n");
641 LLVM_DEBUG(llvm::dbgs() <<
"done assigning alternative names for symbols\n");
645 LLVM_DEBUG(llvm::dbgs() <<
"start collecting result pattern bound symbols\n");
650 LLVM_DEBUG(llvm::dbgs() <<
"done collecting result pattern bound symbols\n");
// Builds one AppliedConstraint per element of the record's `constraints` list.
// NOTE(review): the function header, the conditions guarding the two fatal
// errors and the final return are missing from this listing.
662 auto *listInit = def.getValueAsListInit(
"constraints");
663 std::vector<AppliedConstraint> ret;
664 ret.reserve(listInit->size());
666 for (
auto *it : *listInit) {
// dyn_cast yields null for a non-DAG element; the fatal error below is
// presumably guarded by that null check.
667 auto *dagInit = dyn_cast<llvm::DagInit>(it);
669 PrintFatalError(&def,
"all elements in Pattern multi-entity "
670 "constraints should be DAG nodes");
// Collect the DAG's argument names as the entities the constraint applies to.
672 std::vector<std::string> entities;
673 entities.reserve(dagInit->arg_size());
674 for (
auto *argName : dagInit->getArgNames()) {
678 "operands to additional constraints can only be symbol references");
680 entities.emplace_back(argName->getValue());
// The constraint is built from the DAG operator's record, the DAG's own name
// and the collected entities (moved in).
683 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
684 dagInit->getNameStr(), std::move(entities));
690 auto *results = def.getValueAsListInit(
"supplementalPatterns");
691 return results->size();
695 auto *results = def.getValueAsListInit(
"supplementalPatterns");
696 return DagNode(cast<llvm::DagInit>(results->getElement(index)));
// Adds the integer carried by the record's `benefitDelta` DAG to
// `initBenefit`.
// NOTE(review): the function header and the definition of `initBenefit` are
// missing from this listing.
703 llvm::DagInit *delta = def.getValueAsDag(
"benefitDelta");
// The DAG must have exactly one argument, and it must be an integer.
704 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
705 PrintFatalError(&def,
706 "The 'addBenefit' takes and only takes one integer value");
// NOTE(review): the dyn_cast<> result is dereferenced without a null check.
// This relies on the isa<> test above and on PrintFatalError not returning;
// cast<> would state that invariant directly and assert it.
708 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
// Collects one (buffer identifier, line number) pair per source location
// attached to the record.
// NOTE(review): the function header, the call that appends to `result` and
// the final return are missing from this listing.
712 std::vector<std::pair<StringRef, unsigned>> result;
713 result.reserve(def.getLoc().size());
714 for (
auto loc : def.getLoc()) {
// The assert treats a zero buffer id as an unresolved location.
715 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
716 assert(buf &&
"invalid source location");
// Only the line (`.first`) of the line/column pair is kept.
718 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
719 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
// Raises a fatal "bound more than once" diagnostic for `symbolName`, reported
// against this pattern's record.
//   result     - outcome of a preceding bind call
//   symbolName - symbol quoted in the diagnostic
// NOTE(review): the condition on `result` that guards the error is missing
// from this listing; presumably the error fires when `result` is false.
724 void Pattern::verifyBind(
bool result, StringRef symbolName) {
726 auto err = formatv(
"symbol '{0}' bound more than once", symbolName);
727 PrintFatalError(&def, err);
737 if (!treeName.empty()) {
739 LLVM_DEBUG(llvm::dbgs() <<
"found symbol bound to NativeCodeCall: "
740 << treeName <<
'\n');
745 PrintFatalError(&def,
746 formatv(
"binding symbol '{0}' to NativecodeCall in "
747 "MatchPattern is not supported",
752 for (
int i = 0; i != numTreeArgs; ++i) {
767 if (!treeArgName.empty() && treeArgName !=
"_") {
773 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
778 constraint.getKind() == Constraint::Kind::CK_Attr;
782 verifyBind(infoMap.
bindAttr(treeArgName), treeArgName);
787 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
// Checks that the number of arguments written in the pattern matches the
// number the op definition declares.
// NOTE(review): the enclosing function header, the declarations of
// `numTreeArgs`, `numEither` and `dagArg`, and the counter updates inside the
// loop are missing from this listing.
797 auto numOpArgs = op.getNumArgs();
802 int numDirectives = 0;
// Scan the pattern's arguments from last to first, classifying directive
// arguments (location / return type) separately from `either` arguments.
803 for (
int i = numTreeArgs - 1; i >= 0; --i) {
805 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
807 else if (dagArg.isEither())
// Expected count: pattern arguments minus directives, plus `numEither`.
812 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
// NOTE(review): the message prints `numTreeArgs + numEither`, without the
// `numDirectives` adjustment used in the comparison, so the reported pattern
// count can differ from the value actually compared when directives exist.
814 formatv(
"op '{0}' argument number mismatch: "
815 "{1} in pattern vs. {2} in definition",
816 op.getOperationName(), numTreeArgs + numEither, numOpArgs);
817 PrintFatalError(&def, err);
822 if (!treeName.empty()) {
823 LLVM_DEBUG(llvm::dbgs()
824 <<
"found symbol bound to op result: " << treeName <<
'\n');
825 verifyBind(infoMap.
bindOpResult(treeName, op), treeName);
832 for (
int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
836 auto argName = tree.getArgName(i);
837 if (!argName.empty() && argName !=
"_") {
850 auto treeName = tree.getSymbol();
851 if (!treeName.empty()) {
858 for (
int i = 0; i < tree.getNumArgs(); ++i) {
862 auto argName = tree.getArgName(i);
863 if (!argName.empty() && argName !=
"_") {
872 for (
int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
873 if (
auto treeArg = tree.getArgAsNestedDag(i)) {
874 if (treeArg.isEither()) {
875 collectSymbolInEither(tree, treeArg, opArgIdx);
884 }
else if (treeArg.isVariadic()) {
885 collectSymbolInVariadic(tree, treeArg, opArgIdx);
896 auto treeArgName = tree.getArgName(i);
898 if (!treeArgName.empty() && treeArgName !=
"_") {
899 LLVM_DEBUG(llvm::dbgs() <<
"found symbol bound to op argument: "
900 << treeArgName <<
'\n');
901 verifyBind(infoMap.
bindOpArgument(tree, treeArgName, op, opArgIdx),
909 if (!treeName.empty()) {
911 &def, formatv(
"binding symbol '{0}' to non-operation/native code call "
912 "unsupported right now",
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumResults()
Return the number of results held by this operation.
std::string getConditionTemplate() const
Constraint getAsConstraint() const
bool isNativeCodeCall() const
bool isEnumAttrCase() const
int getNumReturnsOfNativeCode() const
ConstantAttr getAsConstantAttr() const
void print(raw_ostream &os) const
std::string getStringAttr() const
std::string getConditionTemplate() const
StringRef getNativeCodeTemplate() const
bool isUnspecified() const
bool isAttrMatcher() const
EnumAttrCase getAsEnumAttrCase() const
bool isOperandMatcher() const
bool isConstantAttr() const
bool isStringAttr() const
bool isReturnTypeDirective() const
bool isLocationDirective() const
bool isReplaceWithValue() const
DagNode getArgAsNestedDag(unsigned index) const
DagLeaf getArgAsLeaf(unsigned index) const
StringRef getNativeCodeTemplate() const
int getNumReturnsOfNativeCode() const
void print(raw_ostream &os) const
Operator & getDialectOp(RecordOperatorMap *mapper) const
bool isNativeCodeCall() const
bool isNestedDagArg(unsigned index) const
DagNode(const llvm::DagInit *node)
StringRef getSymbol() const
StringRef getArgName(unsigned index) const
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
int getNumResults() const
Returns the number of results this op produces.
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
int getNumResultPatterns() const
std::vector< IdentifierLine > getLocation() const
DagNode getSourcePattern() const
const Operator & getSourceRootOp()
std::vector< AppliedConstraint > getConstraints() const
DagNode getResultPattern(unsigned index) const
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Operator & getDialectOp(DagNode node)
DagNode getSupplementalPattern(unsigned index) const
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
int getNumSupplementalPatterns() const
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
std::string getArgDecl(StringRef name) const
std::string getVarName(StringRef name) const
std::string getVarTypeStr(StringRef name) const
std::string getVarDecl(StringRef name) const
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
int count(StringRef key) const
const_iterator find(StringRef key) const
void assignUniqueAlternativeNames()
bool bindMultipleValues(StringRef symbol, int numValues)
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
bool bindValues(StringRef symbol, int numValues=1)
bool bindAttr(StringRef symbol)
bool bindValue(StringRef symbol)
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
int getStaticValueCount(StringRef symbol) const
bool contains(StringRef symbol) const
BaseT::const_iterator const_iterator
bool bindOpResult(StringRef symbol, const Operator &op)
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.