17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
27 using namespace tblgen;
41 return isa_and_nonnull<llvm::UnsetInit>(def);
46 return isSubClassOf(
"TypeConstraint");
51 return isSubClassOf(
"AttrConstraint");
56 return isSubClassOf(
"PropConstraint");
61 return isSubClassOf(
"Property");
65 return isSubClassOf(
"NativeCodeCall");
78 "the DAG leaf must be operand, attribute, or property");
79 return Constraint(cast<DefInit>(def)->getDef());
83 assert(
isPropMatcher() &&
"the DAG leaf must be a property matcher");
89 return Property(cast<DefInit>(def)->getDef());
93 assert(
isConstantAttr() &&
"the DAG leaf must be constant attribute");
98 assert(
isEnumCase() &&
"the DAG leaf must be an enum attribute case");
103 assert(
isConstantProp() &&
"the DAG leaf must be a constant property value");
113 return cast<DefInit>(def)->getDef()->getValueAsString(
"expression");
118 return cast<DefInit>(def)->getDef()->getValueAsInt(
"numReturns");
122 assert(
isStringAttr() &&
"the DAG leaf must be string attribute");
123 return def->getAsUnquotedString();
125 bool DagLeaf::isSubClassOf(StringRef superclass)
const {
126 if (
auto *defInit = dyn_cast_or_null<DefInit>(def))
127 return defInit->getDef()->isSubClassOf(superclass);
141 if (
auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
142 return defInit->getDef()->isSubClassOf(
"NativeCodeCall");
154 return cast<DefInit>(node->getOperator())
156 ->getValueAsString(
"expression");
161 return cast<DefInit>(node->getOperator())
163 ->getValueAsInt(
"numReturns");
169 const Record *opDef = cast<DefInit>(node->getOperator())->
getDef();
170 auto [it, inserted] = mapper->try_emplace(opDef);
172 it->second = std::make_unique<Operator>(opDef);
180 for (
int i = 0, e =
getNumArgs(); i != e; ++i) {
182 count += child.getNumOps();
190 return isa<DagInit>(node->getArg(index));
194 return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
199 return DagLeaf(node->getArg(index));
203 return node->getArgNameStr(index);
207 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
208 return dagOpDef->getName() ==
"replaceWithValue";
212 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
213 return dagOpDef->getName() ==
"location";
217 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
218 return dagOpDef->getName() ==
"returnType";
222 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
223 return dagOpDef->getName() ==
"either";
227 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
228 return dagOpDef->getName() ==
"variadic";
242 auto [name, indexStr] = symbol.rsplit(
"__");
244 if (indexStr.consumeInteger(10, idx)) {
254 SymbolInfoMap::SymbolInfo::SymbolInfo(
256 std::optional<DagAndConstant> dagAndConstant)
257 : op(op),
kind(
kind), dagAndConstant(dagAndConstant) {}
259 int SymbolInfoMap::SymbolInfo::getStaticValueCount()
const {
268 case Kind::MultipleValues:
271 llvm_unreachable(
"unknown kind");
275 return alternativeName ? *alternativeName : name.str();
279 LLVM_DEBUG(dbgs() <<
"getVarTypeStr for '" << name <<
"': ");
283 return cast<NamedAttribute *>(op->getArg(getArgIndex()))
284 ->attr.getStorageType()
287 return "::mlir::Attribute";
291 return cast<NamedProperty *>(op->getArg(getArgIndex()))
292 ->prop.getInterfaceType()
294 assert(dagAndConstant && dagAndConstant->dag &&
295 "generic properties must carry their constraint");
296 return reinterpret_cast<const DagLeaf *
>(dagAndConstant->dag)
297 ->getAsPropConstraint()
301 case Kind::Operand: {
304 return "::mlir::Operation::operand_range";
307 return "::mlir::Value";
309 case Kind::MultipleValues: {
310 return "::mlir::ValueRange";
314 return op->getQualCppClassName();
317 llvm_unreachable(
"unknown kind");
321 LLVM_DEBUG(dbgs() <<
"getVarDecl for '" << name <<
"': ");
322 std::string varInit =
kind == Kind::Operand ?
"(op0->getOperands())" :
"";
324 formatv(
"{0} {1}{2};\n", getVarTypeStr(name),
getVarName(name), varInit));
328 LLVM_DEBUG(dbgs() <<
"getArgDecl for '" << name <<
"': ");
330 formatv(
"{0} &{1}", getVarTypeStr(name),
getVarName(name)));
333 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
334 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
335 LLVM_DEBUG(dbgs() <<
"getValueAndRangeUse for '" << name <<
"': ");
339 auto repl = formatv(fmt, name);
340 LLVM_DEBUG(dbgs() << repl <<
" (Attr)\n");
341 return std::string(repl);
345 auto repl = formatv(fmt, name);
346 LLVM_DEBUG(dbgs() << repl <<
" (Prop)\n");
347 return std::string(repl);
349 case Kind::Operand: {
351 auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
352 if (operand->isOptional()) {
354 fmt, formatv(
"({0}.empty() ? ::mlir::Value() : *{0}.begin())", name));
355 LLVM_DEBUG(dbgs() << repl <<
" (OptionalOperand)\n");
356 return std::string(repl);
361 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
362 auto repl = formatv(fmt, name);
363 LLVM_DEBUG(dbgs() << repl <<
" (VariadicOperand)\n");
364 return std::string(repl);
366 auto repl = formatv(fmt, formatv(
"(*{0}.begin())", name));
367 LLVM_DEBUG(dbgs() << repl <<
" (SingleOperand)\n");
368 return std::string(repl);
375 std::string(formatv(
"{0}.getODSResults({1})", name, index));
376 if (!op->getResult(index).isVariadic())
377 v = std::string(formatv(
"(*{0}.begin())", v));
378 auto repl = formatv(fmt, v);
379 LLVM_DEBUG(dbgs() << repl <<
" (SingleResult)\n");
380 return std::string(repl);
385 if (op->getNumResults() == 0) {
386 LLVM_DEBUG(dbgs() << name <<
" (Op)\n");
387 return formatv(fmt, name);
393 values.reserve(op->getNumResults());
395 for (
int i = 0, e = op->getNumResults(); i < e; ++i) {
396 std::string v = std::string(formatv(
"{0}.getODSResults({1})", name, i));
397 if (!op->getResult(i).isVariadic()) {
398 v = std::string(formatv(
"(*{0}.begin())", v));
400 values.push_back(std::string(formatv(fmt, v)));
402 auto repl = llvm::join(values, separator);
403 LLVM_DEBUG(dbgs() << repl <<
" (VariadicResult)\n");
408 assert(op ==
nullptr);
409 auto repl = formatv(fmt, name);
410 LLVM_DEBUG(dbgs() << repl <<
" (Value)\n");
411 return std::string(repl);
413 case Kind::MultipleValues: {
414 assert(op ==
nullptr);
415 assert(index < getSize());
418 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
419 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
424 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
425 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
426 return std::string(repl);
429 llvm_unreachable(
"unknown kind");
432 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
433 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
434 LLVM_DEBUG(dbgs() <<
"getAllRangeUse for '" << name <<
"': ");
438 case Kind::Operand: {
439 assert(index < 0 &&
"only allowed for symbol bound to result");
440 auto repl = formatv(fmt, name);
441 LLVM_DEBUG(dbgs() << repl <<
" (Operand/Attr/Prop)\n");
442 return std::string(repl);
446 auto repl = formatv(fmt, formatv(
"{0}.getODSResults({1})", name, index));
447 LLVM_DEBUG(dbgs() << repl <<
" (SingleResult)\n");
448 return std::string(repl);
454 values.reserve(op->getNumResults());
456 for (
int i = 0, e = op->getNumResults(); i < e; ++i) {
457 values.push_back(std::string(
458 formatv(fmt, formatv(
"{0}.getODSResults({1})", name, i))));
460 auto repl = llvm::join(values, separator);
461 LLVM_DEBUG(dbgs() << repl <<
" (VariadicResult)\n");
465 assert(index < 0 &&
"only allowed for symbol bound to result");
466 assert(op ==
nullptr);
467 auto repl = formatv(fmt, formatv(
"{{{0}}", name));
468 LLVM_DEBUG(dbgs() << repl <<
" (Value)\n");
469 return std::string(repl);
471 case Kind::MultipleValues: {
472 assert(op ==
nullptr);
473 assert(index < getSize());
476 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
477 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
481 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
482 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
483 return std::string(repl);
486 llvm_unreachable(
"unknown kind");
491 std::optional<int> variadicSubIndex) {
493 if (name != symbol) {
494 auto error = formatv(
495 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
496 PrintFatalError(loc, error);
501 isa<NamedAttribute *>(arg) ? SymbolInfo::getAttr(&op, argIndex)
502 : isa<NamedProperty *>(arg)
503 ? SymbolInfo::getProp(&op, argIndex)
504 : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
506 std::string key = symbol.str();
507 if (symbolInfoMap.count(key)) {
509 if (symInfo.kind != SymbolInfo::Kind::Operand) {
515 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
520 symbolInfoMap.emplace(key, symInfo);
526 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
528 return symbolInfoMap.count(inserted->first) == 1;
539 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
540 return symbolInfoMap.count(inserted->first) == 1;
546 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
547 return symbolInfoMap.count(inserted->first) == 1;
551 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
552 return symbolInfoMap.count(inserted->first) == 1;
558 symbolInfoMap.emplace(symbol.str(), SymbolInfo::getProp(&constraint));
559 return symbolInfoMap.count(inserted->first) == 1;
563 return find(symbol) != symbolInfoMap.end();
569 return symbolInfoMap.find(name);
575 std::optional<int> variadicSubIndex)
const {
577 key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
584 auto range = symbolInfoMap.equal_range(name);
586 for (
auto it = range.first; it != range.second; ++it)
587 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
590 return symbolInfoMap.end();
593 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
597 return symbolInfoMap.equal_range(name);
602 return symbolInfoMap.count(name);
607 if (name != symbol) {
613 return find(name)->second.getStaticValueCount();
618 const char *separator)
const {
622 auto it = symbolInfoMap.find(name.str());
623 if (it == symbolInfoMap.end()) {
624 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
625 PrintFatalError(loc, error);
628 return it->second.getValueAndRangeUse(name, index, fmt, separator);
632 const char *separator)
const {
636 auto it = symbolInfoMap.find(name.str());
637 if (it == symbolInfoMap.end()) {
638 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
639 PrintFatalError(loc, error);
642 return it->second.getAllRangeUse(name, index, fmt, separator);
648 for (
auto symbolInfoIt = symbolInfoMap.begin();
649 symbolInfoIt != symbolInfoMap.end();) {
650 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
651 auto startRange = range.first;
652 auto endRange = range.second;
654 auto operandName = symbolInfoIt->first;
655 int startSearchIndex = 0;
656 for (++startRange; startRange != endRange; ++startRange) {
659 for (
int i = startSearchIndex;; ++i) {
660 std::string alternativeName = operandName + std::to_string(i);
661 if (!usedNames.contains(alternativeName) &&
662 symbolInfoMap.count(alternativeName) == 0) {
663 usedNames.insert(alternativeName);
664 startRange->second.alternativeName = alternativeName;
665 startSearchIndex = i + 1;
672 symbolInfoIt = endRange;
681 : def(*def), recordOpMap(mapper) {}
684 return DagNode(def.getValueAsDag(
"sourcePattern"));
688 auto *results = def.getValueAsListInit(
"resultPatterns");
689 return results->size();
693 auto *results = def.getValueAsListInit(
"resultPatterns");
694 return DagNode(cast<DagInit>(results->getElement(index)));
698 LLVM_DEBUG(dbgs() <<
"start collecting source pattern bound symbols\n");
700 LLVM_DEBUG(dbgs() <<
"done collecting source pattern bound symbols\n");
702 LLVM_DEBUG(dbgs() <<
"start assigning alternative names for symbols\n");
704 LLVM_DEBUG(dbgs() <<
"done assigning alternative names for symbols\n");
708 LLVM_DEBUG(dbgs() <<
"start collecting result pattern bound symbols\n");
713 LLVM_DEBUG(dbgs() <<
"done collecting result pattern bound symbols\n");
725 auto *listInit = def.getValueAsListInit(
"constraints");
726 std::vector<AppliedConstraint> ret;
727 ret.reserve(listInit->size());
729 for (
auto *it : *listInit) {
730 auto *dagInit = dyn_cast<DagInit>(it);
732 PrintFatalError(&def,
"all elements in Pattern multi-entity "
733 "constraints should be DAG nodes");
735 std::vector<std::string> entities;
736 entities.reserve(dagInit->arg_size());
737 for (
auto *argName : dagInit->getArgNames()) {
741 "operands to additional constraints can only be symbol references");
743 entities.emplace_back(argName->getValue());
746 ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
747 dagInit->getNameStr(), std::move(entities));
753 auto *results = def.getValueAsListInit(
"supplementalPatterns");
754 return results->size();
758 auto *results = def.getValueAsListInit(
"supplementalPatterns");
759 return DagNode(cast<DagInit>(results->getElement(index)));
766 const DagInit *delta = def.getValueAsDag(
"benefitDelta");
767 if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
768 PrintFatalError(&def,
769 "The 'addBenefit' takes and only takes one integer value");
771 return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
775 std::vector<std::pair<StringRef, unsigned>> result;
776 result.reserve(def.getLoc().size());
777 for (
auto loc : def.getLoc()) {
778 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
779 assert(buf &&
"invalid source location");
781 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
782 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
787 void Pattern::verifyBind(
bool result, StringRef symbolName) {
789 auto err = formatv(
"symbol '{0}' bound more than once", symbolName);
790 PrintFatalError(&def, err);
800 if (!treeName.empty()) {
802 LLVM_DEBUG(dbgs() <<
"found symbol bound to NativeCodeCall: "
803 << treeName <<
'\n');
808 PrintFatalError(&def,
809 formatv(
"binding symbol '{0}' to NativecodeCall in "
810 "MatchPattern is not supported",
815 for (
int i = 0; i != numTreeArgs; ++i) {
830 if (!treeArgName.empty() && treeArgName !=
"_") {
837 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
841 if (propConstraint.getInterfaceType().empty()) {
842 PrintFatalError(&def,
843 formatv(
"binding symbol '{0}' in NativeCodeCall to "
844 "a property constraint without specifying "
845 "that constraint's type is unsupported",
848 verifyBind(infoMap.
bindProp(treeArgName, propConstraint),
854 constraint.getKind() == Constraint::Kind::CK_Attr;
858 verifyBind(infoMap.
bindAttr(treeArgName), treeArgName);
863 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
873 auto numOpArgs = op.getNumArgs();
878 int numDirectives = 0;
879 for (
int i = numTreeArgs - 1; i >= 0; --i) {
881 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
883 else if (dagArg.isEither())
888 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
890 formatv(
"op '{0}' argument number mismatch: "
891 "{1} in pattern vs. {2} in definition",
892 op.getOperationName(), numTreeArgs + numEither, numOpArgs);
893 PrintFatalError(&def, err);
898 if (!treeName.empty()) {
899 LLVM_DEBUG(dbgs() <<
"found symbol bound to op result: " << treeName
901 verifyBind(infoMap.
bindOpResult(treeName, op), treeName);
908 for (
int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
912 auto argName = tree.getArgName(i);
913 if (!argName.empty() && argName !=
"_") {
926 auto treeName = tree.getSymbol();
927 if (!treeName.empty()) {
934 for (
int i = 0; i < tree.getNumArgs(); ++i) {
938 auto argName = tree.getArgName(i);
939 if (!argName.empty() && argName !=
"_") {
948 for (
int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
949 if (
auto treeArg = tree.getArgAsNestedDag(i)) {
950 if (treeArg.isEither()) {
951 collectSymbolInEither(tree, treeArg, opArgIdx);
960 }
else if (treeArg.isVariadic()) {
961 collectSymbolInVariadic(tree, treeArg, opArgIdx);
972 auto treeArgName = tree.getArgName(i);
974 if (!treeArgName.empty() && treeArgName !=
"_") {
975 LLVM_DEBUG(dbgs() <<
"found symbol bound to op argument: "
976 << treeArgName <<
'\n');
977 verifyBind(infoMap.
bindOpArgument(tree, treeArgName, op, opArgIdx),
985 if (!treeName.empty()) {
987 &def, formatv(
"binding symbol '{0}' to non-operation/native code call "
988 "unsupported right now",
union mlir::linalg::@1223::ArityGroupAndKind::Kind kind
std::string getConditionTemplate() const
Constraint getAsConstraint() const
bool isNativeCodeCall() const
bool isPropMatcher() const
int getNumReturnsOfNativeCode() const
ConstantAttr getAsConstantAttr() const
void print(raw_ostream &os) const
std::string getStringAttr() const
Property getAsProperty() const
StringRef getNativeCodeTemplate() const
std::string getConditionTemplate() const
bool isConstantProp() const
PropConstraint getAsPropConstraint() const
ConstantProp getAsConstantProp() const
bool isUnspecified() const
EnumCase getAsEnumCase() const
bool isAttrMatcher() const
bool isOperandMatcher() const
bool isPropDefinition() const
bool isConstantAttr() const
bool isStringAttr() const
bool isReturnTypeDirective() const
bool isLocationDirective() const
bool isReplaceWithValue() const
DagNode getArgAsNestedDag(unsigned index) const
DagLeaf getArgAsLeaf(unsigned index) const
int getNumReturnsOfNativeCode() const
StringRef getNativeCodeTemplate() const
void print(raw_ostream &os) const
Operator & getDialectOp(RecordOperatorMap *mapper) const
bool isNativeCodeCall() const
bool isNestedDagArg(unsigned index) const
StringRef getSymbol() const
DagNode(const llvm::DagInit *node)
StringRef getArgName(unsigned index) const
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
int getNumResults() const
Returns the number of results this op produces.
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
int getNumResultPatterns() const
std::vector< IdentifierLine > getLocation() const
DagNode getSourcePattern() const
const Operator & getSourceRootOp()
std::vector< AppliedConstraint > getConstraints() const
DagNode getResultPattern(unsigned index) const
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
DagNode getSupplementalPattern(unsigned index) const
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
int getNumSupplementalPatterns() const
std::string getArgDecl(StringRef name) const
std::string getVarName(StringRef name) const
std::string getVarTypeStr(StringRef name) const
std::string getVarDecl(StringRef name) const
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
int count(StringRef key) const
const_iterator find(StringRef key) const
void assignUniqueAlternativeNames()
bool bindMultipleValues(StringRef symbol, int numValues)
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
bool bindValues(StringRef symbol, int numValues=1)
bool bindAttr(StringRef symbol)
bool bindProp(StringRef symbol, const PropConstraint &constraint)
bool bindValue(StringRef symbol)
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
int getStaticValueCount(StringRef symbol) const
bool contains(StringRef symbol) const
BaseT::const_iterator const_iterator
bool bindOpResult(StringRef symbol, const Operator &op)
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.