17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
27 using namespace tblgen;
41 return isa_and_nonnull<llvm::UnsetInit>(def);
46 return isSubClassOf(
"TypeConstraint");
51 return isSubClassOf(
"AttrConstraint");
55 return isSubClassOf(
"NativeCodeCall");
66 "the DAG leaf must be operand or attribute");
67 return Constraint(cast<DefInit>(def)->getDef());
71 assert(
isConstantAttr() &&
"the DAG leaf must be constant attribute");
76 assert(
isEnumCase() &&
"the DAG leaf must be an enum attribute case");
86 return cast<DefInit>(def)->getDef()->getValueAsString(
"expression");
91 return cast<DefInit>(def)->getDef()->getValueAsInt(
"numReturns");
95 assert(
isStringAttr() &&
"the DAG leaf must be string attribute");
96 return def->getAsUnquotedString();
98 bool DagLeaf::isSubClassOf(StringRef superclass)
const {
99 if (
auto *defInit = dyn_cast_or_null<DefInit>(def))
100 return defInit->getDef()->isSubClassOf(superclass);
114 if (
auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
115 return defInit->getDef()->isSubClassOf(
"NativeCodeCall");
127 return cast<DefInit>(node->getOperator())
129 ->getValueAsString(
"expression");
134 return cast<DefInit>(node->getOperator())
136 ->getValueAsInt(
"numReturns");
142 const Record *opDef = cast<DefInit>(node->getOperator())->
getDef();
143 auto [it, inserted] = mapper->try_emplace(opDef);
145 it->second = std::make_unique<Operator>(opDef);
153 for (
int i = 0, e =
getNumArgs(); i != e; ++i) {
155 count += child.getNumOps();
163 return isa<DagInit>(node->getArg(index));
167 return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
172 return DagLeaf(node->getArg(index));
176 return node->getArgNameStr(index);
180 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
181 return dagOpDef->getName() ==
"replaceWithValue";
185 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
186 return dagOpDef->getName() ==
"location";
190 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
191 return dagOpDef->getName() ==
"returnType";
195 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
196 return dagOpDef->getName() ==
"either";
200 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
201 return dagOpDef->getName() ==
"variadic";
215 auto [name, indexStr] = symbol.rsplit(
"__");
217 if (indexStr.consumeInteger(10, idx)) {
227 SymbolInfoMap::SymbolInfo::SymbolInfo(
229 std::optional<DagAndConstant> dagAndConstant)
230 : op(op),
kind(
kind), dagAndConstant(dagAndConstant) {}
232 int SymbolInfoMap::SymbolInfo::getStaticValueCount()
const {
240 case Kind::MultipleValues:
243 llvm_unreachable(
"unknown kind");
247 return alternativeName ? *alternativeName : name.str();
251 LLVM_DEBUG(dbgs() <<
"getVarTypeStr for '" << name <<
"': ");
255 return cast<NamedAttribute *>(op->getArg(getArgIndex()))
256 ->attr.getStorageType()
259 return "::mlir::Attribute";
261 case Kind::Operand: {
264 return "::mlir::Operation::operand_range";
267 return "::mlir::Value";
269 case Kind::MultipleValues: {
270 return "::mlir::ValueRange";
274 return op->getQualCppClassName();
277 llvm_unreachable(
"unknown kind");
281 LLVM_DEBUG(dbgs() <<
"getVarDecl for '" << name <<
"': ");
282 std::string varInit =
kind == Kind::Operand ?
"(op0->getOperands())" :
"";
284 formatv(
"{0} {1}{2};\n", getVarTypeStr(name),
getVarName(name), varInit));
288 LLVM_DEBUG(dbgs() <<
"getArgDecl for '" << name <<
"': ");
290 formatv(
"{0} &{1}", getVarTypeStr(name),
getVarName(name)));
293 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
294 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
295 LLVM_DEBUG(dbgs() <<
"getValueAndRangeUse for '" << name <<
"': ");
299 auto repl = formatv(fmt, name);
300 LLVM_DEBUG(dbgs() << repl <<
" (Attr)\n");
301 return std::string(repl);
303 case Kind::Operand: {
305 auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
306 if (operand->isOptional()) {
308 fmt, formatv(
"({0}.empty() ? ::mlir::Value() : *{0}.begin())", name));
309 LLVM_DEBUG(dbgs() << repl <<
" (OptionalOperand)\n");
310 return std::string(repl);
315 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
316 auto repl = formatv(fmt, name);
317 LLVM_DEBUG(dbgs() << repl <<
" (VariadicOperand)\n");
318 return std::string(repl);
320 auto repl = formatv(fmt, formatv(
"(*{0}.begin())", name));
321 LLVM_DEBUG(dbgs() << repl <<
" (SingleOperand)\n");
322 return std::string(repl);
329 std::string(formatv(
"{0}.getODSResults({1})", name, index));
330 if (!op->getResult(index).isVariadic())
331 v = std::string(formatv(
"(*{0}.begin())", v));
332 auto repl = formatv(fmt, v);
333 LLVM_DEBUG(dbgs() << repl <<
" (SingleResult)\n");
334 return std::string(repl);
339 if (op->getNumResults() == 0) {
340 LLVM_DEBUG(dbgs() << name <<
" (Op)\n");
341 return formatv(fmt, name);
347 values.reserve(op->getNumResults());
349 for (
int i = 0, e = op->getNumResults(); i < e; ++i) {
350 std::string v = std::string(formatv(
"{0}.getODSResults({1})", name, i));
351 if (!op->getResult(i).isVariadic()) {
352 v = std::string(formatv(
"(*{0}.begin())", v));
354 values.push_back(std::string(formatv(fmt, v)));
356 auto repl = llvm::join(values, separator);
357 LLVM_DEBUG(dbgs() << repl <<
" (VariadicResult)\n");
362 assert(op ==
nullptr);
363 auto repl = formatv(fmt, name);
364 LLVM_DEBUG(dbgs() << repl <<
" (Value)\n");
365 return std::string(repl);
367 case Kind::MultipleValues: {
368 assert(op ==
nullptr);
369 assert(index < getSize());
372 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
373 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
378 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
379 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
380 return std::string(repl);
383 llvm_unreachable(
"unknown kind");
386 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
387 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
388 LLVM_DEBUG(dbgs() <<
"getAllRangeUse for '" << name <<
"': ");
391 case Kind::Operand: {
392 assert(index < 0 &&
"only allowed for symbol bound to result");
393 auto repl = formatv(fmt, name);
394 LLVM_DEBUG(dbgs() << repl <<
" (Operand/Attr)\n");
395 return std::string(repl);
399 auto repl = formatv(fmt, formatv(
"{0}.getODSResults({1})", name, index));
400 LLVM_DEBUG(dbgs() << repl <<
" (SingleResult)\n");
401 return std::string(repl);
407 values.reserve(op->getNumResults());
409 for (
int i = 0, e = op->getNumResults(); i < e; ++i) {
410 values.push_back(std::string(
411 formatv(fmt, formatv(
"{0}.getODSResults({1})", name, i))));
413 auto repl = llvm::join(values, separator);
414 LLVM_DEBUG(dbgs() << repl <<
" (VariadicResult)\n");
418 assert(index < 0 &&
"only allowed for symbol bound to result");
419 assert(op ==
nullptr);
420 auto repl = formatv(fmt, formatv(
"{{{0}}", name));
421 LLVM_DEBUG(dbgs() << repl <<
" (Value)\n");
422 return std::string(repl);
424 case Kind::MultipleValues: {
425 assert(op ==
nullptr);
426 assert(index < getSize());
429 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
430 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
434 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
435 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
436 return std::string(repl);
439 llvm_unreachable(
"unknown kind");
444 std::optional<int> variadicSubIndex) {
446 if (name != symbol) {
447 auto error = formatv(
448 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
449 PrintFatalError(loc, error);
453 isa<NamedAttribute *>(op.
getArg(argIndex))
454 ? SymbolInfo::getAttr(&op, argIndex)
455 : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
457 std::string key = symbol.str();
458 if (symbolInfoMap.count(key)) {
460 if (symInfo.kind != SymbolInfo::Kind::Operand) {
466 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
471 symbolInfoMap.emplace(key, symInfo);
477 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
479 return symbolInfoMap.count(inserted->first) == 1;
490 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
491 return symbolInfoMap.count(inserted->first) == 1;
497 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
498 return symbolInfoMap.count(inserted->first) == 1;
502 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
503 return symbolInfoMap.count(inserted->first) == 1;
507 return find(symbol) != symbolInfoMap.end();
513 return symbolInfoMap.find(name);
519 std::optional<int> variadicSubIndex)
const {
521 key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
528 auto range = symbolInfoMap.equal_range(name);
530 for (
auto it = range.first; it != range.second; ++it)
531 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
534 return symbolInfoMap.end();
537 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
541 return symbolInfoMap.equal_range(name);
546 return symbolInfoMap.count(name);
551 if (name != symbol) {
557 return find(name)->second.getStaticValueCount();
562 const char *separator)
const {
566 auto it = symbolInfoMap.find(name.str());
567 if (it == symbolInfoMap.end()) {
568 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
569 PrintFatalError(loc, error);
572 return it->second.getValueAndRangeUse(name, index, fmt, separator);
576 const char *separator)
const {
580 auto it = symbolInfoMap.find(name.str());
581 if (it == symbolInfoMap.end()) {
582 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
583 PrintFatalError(loc, error);
586 return it->second.getAllRangeUse(name, index, fmt, separator);
592 for (
auto symbolInfoIt = symbolInfoMap.begin();
593 symbolInfoIt != symbolInfoMap.end();) {
594 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
595 auto startRange = range.first;
596 auto endRange = range.second;
598 auto operandName = symbolInfoIt->first;
599 int startSearchIndex = 0;
600 for (++startRange; startRange != endRange; ++startRange) {
603 for (
int i = startSearchIndex;; ++i) {
604 std::string alternativeName = operandName + std::to_string(i);
605 if (!usedNames.contains(alternativeName) &&
606 symbolInfoMap.count(alternativeName) == 0) {
607 usedNames.insert(alternativeName);
608 startRange->second.alternativeName = alternativeName;
609 startSearchIndex = i + 1;
616 symbolInfoIt = endRange;
625 : def(*def), recordOpMap(mapper) {}
628 return DagNode(def.getValueAsDag(
"sourcePattern"));
632 auto *results = def.getValueAsListInit(
"resultPatterns");
633 return results->size();
637 auto *results = def.getValueAsListInit(
"resultPatterns");
638 return DagNode(cast<DagInit>(results->getElement(index)));
642 LLVM_DEBUG(dbgs() <<
"start collecting source pattern bound symbols\n");
644 LLVM_DEBUG(dbgs() <<
"done collecting source pattern bound symbols\n");
646 LLVM_DEBUG(dbgs() <<
"start assigning alternative names for symbols\n");
648 LLVM_DEBUG(dbgs() <<
"done assigning alternative names for symbols\n");
652 LLVM_DEBUG(dbgs() <<
"start collecting result pattern bound symbols\n");
657 LLVM_DEBUG(dbgs() <<
"done collecting result pattern bound symbols\n");
669 auto *listInit = def.getValueAsListInit(
"constraints");
670 std::vector<AppliedConstraint> ret;
671 ret.reserve(listInit->size());
673 for (
auto *it : *listInit) {
674 auto *dagInit = dyn_cast<DagInit>(it);
676 PrintFatalError(&def,
"all elements in Pattern multi-entity "
677 "constraints should be DAG nodes");
679 std::vector<std::string> entities;
680 entities.reserve(dagInit->arg_size());
681 for (
auto *argName : dagInit->getArgNames()) {
685 "operands to additional constraints can only be symbol references");
687 entities.emplace_back(argName->getValue());
690 ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
691 dagInit->getNameStr(), std::move(entities));
697 auto *results = def.getValueAsListInit(
"supplementalPatterns");
698 return results->size();
702 auto *results = def.getValueAsListInit(
"supplementalPatterns");
703 return DagNode(cast<DagInit>(results->getElement(index)));
710 const DagInit *delta = def.getValueAsDag(
"benefitDelta");
711 if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
712 PrintFatalError(&def,
713 "The 'addBenefit' takes and only takes one integer value");
715 return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
719 std::vector<std::pair<StringRef, unsigned>> result;
720 result.reserve(def.getLoc().size());
721 for (
auto loc : def.getLoc()) {
722 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
723 assert(buf &&
"invalid source location");
725 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
726 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
731 void Pattern::verifyBind(
bool result, StringRef symbolName) {
733 auto err = formatv(
"symbol '{0}' bound more than once", symbolName);
734 PrintFatalError(&def, err);
744 if (!treeName.empty()) {
746 LLVM_DEBUG(dbgs() <<
"found symbol bound to NativeCodeCall: "
747 << treeName <<
'\n');
752 PrintFatalError(&def,
753 formatv(
"binding symbol '{0}' to NativecodeCall in "
754 "MatchPattern is not supported",
759 for (
int i = 0; i != numTreeArgs; ++i) {
774 if (!treeArgName.empty() && treeArgName !=
"_") {
780 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
785 constraint.getKind() == Constraint::Kind::CK_Attr;
789 verifyBind(infoMap.
bindAttr(treeArgName), treeArgName);
794 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
804 auto numOpArgs = op.getNumArgs();
809 int numDirectives = 0;
810 for (
int i = numTreeArgs - 1; i >= 0; --i) {
812 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
814 else if (dagArg.isEither())
819 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
821 formatv(
"op '{0}' argument number mismatch: "
822 "{1} in pattern vs. {2} in definition",
823 op.getOperationName(), numTreeArgs + numEither, numOpArgs);
824 PrintFatalError(&def, err);
829 if (!treeName.empty()) {
830 LLVM_DEBUG(dbgs() <<
"found symbol bound to op result: " << treeName
832 verifyBind(infoMap.
bindOpResult(treeName, op), treeName);
839 for (
int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
843 auto argName = tree.getArgName(i);
844 if (!argName.empty() && argName !=
"_") {
857 auto treeName = tree.getSymbol();
858 if (!treeName.empty()) {
865 for (
int i = 0; i < tree.getNumArgs(); ++i) {
869 auto argName = tree.getArgName(i);
870 if (!argName.empty() && argName !=
"_") {
879 for (
int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
880 if (
auto treeArg = tree.getArgAsNestedDag(i)) {
881 if (treeArg.isEither()) {
882 collectSymbolInEither(tree, treeArg, opArgIdx);
891 }
else if (treeArg.isVariadic()) {
892 collectSymbolInVariadic(tree, treeArg, opArgIdx);
903 auto treeArgName = tree.getArgName(i);
905 if (!treeArgName.empty() && treeArgName !=
"_") {
906 LLVM_DEBUG(dbgs() <<
"found symbol bound to op argument: "
907 << treeArgName <<
'\n');
908 verifyBind(infoMap.
bindOpArgument(tree, treeArgName, op, opArgIdx),
916 if (!treeName.empty()) {
918 &def, formatv(
"binding symbol '{0}' to non-operation/native code call "
919 "unsupported right now",
union mlir::linalg::@1194::ArityGroupAndKind::Kind kind
std::string getConditionTemplate() const
Constraint getAsConstraint() const
bool isNativeCodeCall() const
int getNumReturnsOfNativeCode() const
ConstantAttr getAsConstantAttr() const
void print(raw_ostream &os) const
std::string getStringAttr() const
StringRef getNativeCodeTemplate() const
std::string getConditionTemplate() const
bool isUnspecified() const
EnumCase getAsEnumCase() const
bool isAttrMatcher() const
bool isOperandMatcher() const
bool isConstantAttr() const
bool isStringAttr() const
bool isReturnTypeDirective() const
bool isLocationDirective() const
bool isReplaceWithValue() const
DagNode getArgAsNestedDag(unsigned index) const
DagLeaf getArgAsLeaf(unsigned index) const
int getNumReturnsOfNativeCode() const
StringRef getNativeCodeTemplate() const
void print(raw_ostream &os) const
Operator & getDialectOp(RecordOperatorMap *mapper) const
bool isNativeCodeCall() const
bool isNestedDagArg(unsigned index) const
StringRef getSymbol() const
DagNode(const llvm::DagInit *node)
StringRef getArgName(unsigned index) const
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
int getNumResults() const
Returns the number of results this op produces.
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
int getNumResultPatterns() const
std::vector< IdentifierLine > getLocation() const
DagNode getSourcePattern() const
const Operator & getSourceRootOp()
std::vector< AppliedConstraint > getConstraints() const
DagNode getResultPattern(unsigned index) const
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
DagNode getSupplementalPattern(unsigned index) const
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
int getNumSupplementalPatterns() const
std::string getArgDecl(StringRef name) const
std::string getVarName(StringRef name) const
std::string getVarTypeStr(StringRef name) const
std::string getVarDecl(StringRef name) const
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
int count(StringRef key) const
const_iterator find(StringRef key) const
void assignUniqueAlternativeNames()
bool bindMultipleValues(StringRef symbol, int numValues)
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
bool bindValues(StringRef symbol, int numValues=1)
bool bindAttr(StringRef symbol)
bool bindValue(StringRef symbol)
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
int getStaticValueCount(StringRef symbol) const
bool contains(StringRef symbol) const
BaseT::const_iterator const_iterator
bool bindOpResult(StringRef symbol, const Operator &op)
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.