17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
27 using namespace tblgen;
41 return isa_and_nonnull<llvm::UnsetInit>(def);
46 return isSubClassOf(
"TypeConstraint");
51 return isSubClassOf(
"AttrConstraint");
55 return isSubClassOf(
"NativeCodeCall");
66 "the DAG leaf must be operand or attribute");
67 return Constraint(cast<DefInit>(def)->getDef());
71 assert(
isConstantAttr() &&
"the DAG leaf must be constant attribute");
76 assert(
isEnumCase() &&
"the DAG leaf must be an enum attribute case");
86 return cast<DefInit>(def)->getDef()->getValueAsString(
"expression");
91 return cast<DefInit>(def)->getDef()->getValueAsInt(
"numReturns");
95 assert(
isStringAttr() &&
"the DAG leaf must be string attribute");
96 return def->getAsUnquotedString();
98 bool DagLeaf::isSubClassOf(StringRef superclass)
const {
99 if (
auto *defInit = dyn_cast_or_null<DefInit>(def))
100 return defInit->getDef()->isSubClassOf(superclass);
114 if (
auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
115 return defInit->getDef()->isSubClassOf(
"NativeCodeCall");
127 return cast<DefInit>(node->getOperator())
129 ->getValueAsString(
"expression");
134 return cast<DefInit>(node->getOperator())
136 ->getValueAsInt(
"numReturns");
142 const Record *opDef = cast<DefInit>(node->getOperator())->
getDef();
143 auto [it, inserted] = mapper->try_emplace(opDef);
145 it->second = std::make_unique<Operator>(opDef);
153 for (
int i = 0, e =
getNumArgs(); i != e; ++i) {
155 count += child.getNumOps();
163 return isa<DagInit>(node->getArg(index));
167 return DagNode(dyn_cast_or_null<DagInit>(node->getArg(index)));
172 return DagLeaf(node->getArg(index));
176 return node->getArgNameStr(index);
180 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
181 return dagOpDef->getName() ==
"replaceWithValue";
185 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
186 return dagOpDef->getName() ==
"location";
190 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
191 return dagOpDef->getName() ==
"returnType";
195 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
196 return dagOpDef->getName() ==
"either";
200 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
201 return dagOpDef->getName() ==
"variadic";
215 auto [name, indexStr] = symbol.rsplit(
"__");
217 if (indexStr.consumeInteger(10, idx)) {
227 SymbolInfoMap::SymbolInfo::SymbolInfo(
229 std::optional<DagAndConstant> dagAndConstant)
230 : op(op),
kind(
kind), dagAndConstant(dagAndConstant) {}
232 int SymbolInfoMap::SymbolInfo::getStaticValueCount()
const {
240 case Kind::MultipleValues:
243 llvm_unreachable(
"unknown kind");
247 return alternativeName ? *alternativeName : name.str();
251 LLVM_DEBUG(dbgs() <<
"getVarTypeStr for '" << name <<
"': ");
255 return cast<NamedAttribute *>(op->getArg(getArgIndex()))
256 ->attr.getStorageType()
259 return "::mlir::Attribute";
261 case Kind::Operand: {
264 return "::mlir::Operation::operand_range";
267 return "::mlir::Value";
269 case Kind::MultipleValues: {
270 return "::mlir::ValueRange";
274 return op->getQualCppClassName();
277 llvm_unreachable(
"unknown kind");
281 LLVM_DEBUG(dbgs() <<
"getVarDecl for '" << name <<
"': ");
282 std::string varInit =
kind == Kind::Operand ?
"(op0->getOperands())" :
"";
284 formatv(
"{0} {1}{2};\n", getVarTypeStr(name),
getVarName(name), varInit));
288 LLVM_DEBUG(dbgs() <<
"getArgDecl for '" << name <<
"': ");
290 formatv(
"{0} &{1}", getVarTypeStr(name),
getVarName(name)));
293 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
294 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
295 LLVM_DEBUG(dbgs() <<
"getValueAndRangeUse for '" << name <<
"': ");
299 auto repl = formatv(fmt, name);
300 LLVM_DEBUG(dbgs() << repl <<
" (Attr)\n");
301 return std::string(repl);
303 case Kind::Operand: {
305 auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
309 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
310 auto repl = formatv(fmt, name);
311 LLVM_DEBUG(dbgs() << repl <<
" (VariadicOperand)\n");
312 return std::string(repl);
314 auto repl = formatv(fmt, formatv(
"(*{0}.begin())", name));
315 LLVM_DEBUG(dbgs() << repl <<
" (SingleOperand)\n");
316 return std::string(repl);
323 std::string(formatv(
"{0}.getODSResults({1})", name, index));
324 if (!op->getResult(index).isVariadic())
325 v = std::string(formatv(
"(*{0}.begin())", v));
326 auto repl = formatv(fmt, v);
327 LLVM_DEBUG(dbgs() << repl <<
" (SingleResult)\n");
328 return std::string(repl);
333 if (op->getNumResults() == 0) {
334 LLVM_DEBUG(dbgs() << name <<
" (Op)\n");
335 return formatv(fmt, name);
341 values.reserve(op->getNumResults());
343 for (
int i = 0, e = op->getNumResults(); i < e; ++i) {
344 std::string v = std::string(formatv(
"{0}.getODSResults({1})", name, i));
345 if (!op->getResult(i).isVariadic()) {
346 v = std::string(formatv(
"(*{0}.begin())", v));
348 values.push_back(std::string(formatv(fmt, v)));
350 auto repl = llvm::join(values, separator);
351 LLVM_DEBUG(dbgs() << repl <<
" (VariadicResult)\n");
356 assert(op ==
nullptr);
357 auto repl = formatv(fmt, name);
358 LLVM_DEBUG(dbgs() << repl <<
" (Value)\n");
359 return std::string(repl);
361 case Kind::MultipleValues: {
362 assert(op ==
nullptr);
363 assert(index < getSize());
366 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
367 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
372 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
373 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
374 return std::string(repl);
377 llvm_unreachable(
"unknown kind");
380 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
381 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
382 LLVM_DEBUG(dbgs() <<
"getAllRangeUse for '" << name <<
"': ");
385 case Kind::Operand: {
386 assert(index < 0 &&
"only allowed for symbol bound to result");
387 auto repl = formatv(fmt, name);
388 LLVM_DEBUG(dbgs() << repl <<
" (Operand/Attr)\n");
389 return std::string(repl);
393 auto repl = formatv(fmt, formatv(
"{0}.getODSResults({1})", name, index));
394 LLVM_DEBUG(dbgs() << repl <<
" (SingleResult)\n");
395 return std::string(repl);
401 values.reserve(op->getNumResults());
403 for (
int i = 0, e = op->getNumResults(); i < e; ++i) {
404 values.push_back(std::string(
405 formatv(fmt, formatv(
"{0}.getODSResults({1})", name, i))));
407 auto repl = llvm::join(values, separator);
408 LLVM_DEBUG(dbgs() << repl <<
" (VariadicResult)\n");
412 assert(index < 0 &&
"only allowed for symbol bound to result");
413 assert(op ==
nullptr);
414 auto repl = formatv(fmt, formatv(
"{{{0}}", name));
415 LLVM_DEBUG(dbgs() << repl <<
" (Value)\n");
416 return std::string(repl);
418 case Kind::MultipleValues: {
419 assert(op ==
nullptr);
420 assert(index < getSize());
423 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
424 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
428 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
429 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
430 return std::string(repl);
433 llvm_unreachable(
"unknown kind");
438 std::optional<int> variadicSubIndex) {
440 if (name != symbol) {
441 auto error = formatv(
442 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
443 PrintFatalError(loc, error);
447 isa<NamedAttribute *>(op.
getArg(argIndex))
448 ? SymbolInfo::getAttr(&op, argIndex)
449 : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
451 std::string key = symbol.str();
452 if (symbolInfoMap.count(key)) {
454 if (symInfo.kind != SymbolInfo::Kind::Operand) {
460 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
465 symbolInfoMap.emplace(key, symInfo);
471 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
473 return symbolInfoMap.count(inserted->first) == 1;
484 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
485 return symbolInfoMap.count(inserted->first) == 1;
491 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
492 return symbolInfoMap.count(inserted->first) == 1;
496 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
497 return symbolInfoMap.count(inserted->first) == 1;
501 return find(symbol) != symbolInfoMap.end();
507 return symbolInfoMap.find(name);
513 std::optional<int> variadicSubIndex)
const {
515 key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
522 auto range = symbolInfoMap.equal_range(name);
524 for (
auto it = range.first; it != range.second; ++it)
525 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
528 return symbolInfoMap.end();
531 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
535 return symbolInfoMap.equal_range(name);
540 return symbolInfoMap.count(name);
545 if (name != symbol) {
551 return find(name)->second.getStaticValueCount();
556 const char *separator)
const {
560 auto it = symbolInfoMap.find(name.str());
561 if (it == symbolInfoMap.end()) {
562 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
563 PrintFatalError(loc, error);
566 return it->second.getValueAndRangeUse(name, index, fmt, separator);
570 const char *separator)
const {
574 auto it = symbolInfoMap.find(name.str());
575 if (it == symbolInfoMap.end()) {
576 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
577 PrintFatalError(loc, error);
580 return it->second.getAllRangeUse(name, index, fmt, separator);
586 for (
auto symbolInfoIt = symbolInfoMap.begin();
587 symbolInfoIt != symbolInfoMap.end();) {
588 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
589 auto startRange = range.first;
590 auto endRange = range.second;
592 auto operandName = symbolInfoIt->first;
593 int startSearchIndex = 0;
594 for (++startRange; startRange != endRange; ++startRange) {
597 for (
int i = startSearchIndex;; ++i) {
598 std::string alternativeName = operandName + std::to_string(i);
599 if (!usedNames.contains(alternativeName) &&
600 symbolInfoMap.count(alternativeName) == 0) {
601 usedNames.insert(alternativeName);
602 startRange->second.alternativeName = alternativeName;
603 startSearchIndex = i + 1;
610 symbolInfoIt = endRange;
619 : def(*def), recordOpMap(mapper) {}
622 return DagNode(def.getValueAsDag(
"sourcePattern"));
626 auto *results = def.getValueAsListInit(
"resultPatterns");
627 return results->size();
631 auto *results = def.getValueAsListInit(
"resultPatterns");
632 return DagNode(cast<DagInit>(results->getElement(index)));
636 LLVM_DEBUG(dbgs() <<
"start collecting source pattern bound symbols\n");
638 LLVM_DEBUG(dbgs() <<
"done collecting source pattern bound symbols\n");
640 LLVM_DEBUG(dbgs() <<
"start assigning alternative names for symbols\n");
642 LLVM_DEBUG(dbgs() <<
"done assigning alternative names for symbols\n");
646 LLVM_DEBUG(dbgs() <<
"start collecting result pattern bound symbols\n");
651 LLVM_DEBUG(dbgs() <<
"done collecting result pattern bound symbols\n");
663 auto *listInit = def.getValueAsListInit(
"constraints");
664 std::vector<AppliedConstraint> ret;
665 ret.reserve(listInit->size());
667 for (
auto *it : *listInit) {
668 auto *dagInit = dyn_cast<DagInit>(it);
670 PrintFatalError(&def,
"all elements in Pattern multi-entity "
671 "constraints should be DAG nodes");
673 std::vector<std::string> entities;
674 entities.reserve(dagInit->arg_size());
675 for (
auto *argName : dagInit->getArgNames()) {
679 "operands to additional constraints can only be symbol references");
681 entities.emplace_back(argName->getValue());
684 ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
685 dagInit->getNameStr(), std::move(entities));
691 auto *results = def.getValueAsListInit(
"supplementalPatterns");
692 return results->size();
696 auto *results = def.getValueAsListInit(
"supplementalPatterns");
697 return DagNode(cast<DagInit>(results->getElement(index)));
704 const DagInit *delta = def.getValueAsDag(
"benefitDelta");
705 if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
706 PrintFatalError(&def,
707 "The 'addBenefit' takes and only takes one integer value");
709 return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
713 std::vector<std::pair<StringRef, unsigned>> result;
714 result.reserve(def.getLoc().size());
715 for (
auto loc : def.getLoc()) {
716 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
717 assert(buf &&
"invalid source location");
719 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
720 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
725 void Pattern::verifyBind(
bool result, StringRef symbolName) {
727 auto err = formatv(
"symbol '{0}' bound more than once", symbolName);
728 PrintFatalError(&def, err);
738 if (!treeName.empty()) {
740 LLVM_DEBUG(dbgs() <<
"found symbol bound to NativeCodeCall: "
741 << treeName <<
'\n');
746 PrintFatalError(&def,
747 formatv(
"binding symbol '{0}' to NativecodeCall in "
748 "MatchPattern is not supported",
753 for (
int i = 0; i != numTreeArgs; ++i) {
768 if (!treeArgName.empty() && treeArgName !=
"_") {
774 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
779 constraint.getKind() == Constraint::Kind::CK_Attr;
783 verifyBind(infoMap.
bindAttr(treeArgName), treeArgName);
788 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
798 auto numOpArgs = op.getNumArgs();
803 int numDirectives = 0;
804 for (
int i = numTreeArgs - 1; i >= 0; --i) {
806 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
808 else if (dagArg.isEither())
813 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
815 formatv(
"op '{0}' argument number mismatch: "
816 "{1} in pattern vs. {2} in definition",
817 op.getOperationName(), numTreeArgs + numEither, numOpArgs);
818 PrintFatalError(&def, err);
823 if (!treeName.empty()) {
824 LLVM_DEBUG(dbgs() <<
"found symbol bound to op result: " << treeName
826 verifyBind(infoMap.
bindOpResult(treeName, op), treeName);
833 for (
int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
837 auto argName = tree.getArgName(i);
838 if (!argName.empty() && argName !=
"_") {
851 auto treeName = tree.getSymbol();
852 if (!treeName.empty()) {
859 for (
int i = 0; i < tree.getNumArgs(); ++i) {
863 auto argName = tree.getArgName(i);
864 if (!argName.empty() && argName !=
"_") {
873 for (
int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
874 if (
auto treeArg = tree.getArgAsNestedDag(i)) {
875 if (treeArg.isEither()) {
876 collectSymbolInEither(tree, treeArg, opArgIdx);
885 }
else if (treeArg.isVariadic()) {
886 collectSymbolInVariadic(tree, treeArg, opArgIdx);
897 auto treeArgName = tree.getArgName(i);
899 if (!treeArgName.empty() && treeArgName !=
"_") {
900 LLVM_DEBUG(dbgs() <<
"found symbol bound to op argument: "
901 << treeArgName <<
'\n');
902 verifyBind(infoMap.
bindOpArgument(tree, treeArgName, op, opArgIdx),
910 if (!treeName.empty()) {
912 &def, formatv(
"binding symbol '{0}' to non-operation/native code call "
913 "unsupported right now",
union mlir::linalg::@1183::ArityGroupAndKind::Kind kind
std::string getConditionTemplate() const
Constraint getAsConstraint() const
bool isNativeCodeCall() const
int getNumReturnsOfNativeCode() const
ConstantAttr getAsConstantAttr() const
void print(raw_ostream &os) const
std::string getStringAttr() const
StringRef getNativeCodeTemplate() const
std::string getConditionTemplate() const
bool isUnspecified() const
EnumCase getAsEnumCase() const
bool isAttrMatcher() const
bool isOperandMatcher() const
bool isConstantAttr() const
bool isStringAttr() const
bool isReturnTypeDirective() const
bool isLocationDirective() const
bool isReplaceWithValue() const
DagNode getArgAsNestedDag(unsigned index) const
DagLeaf getArgAsLeaf(unsigned index) const
int getNumReturnsOfNativeCode() const
StringRef getNativeCodeTemplate() const
void print(raw_ostream &os) const
Operator & getDialectOp(RecordOperatorMap *mapper) const
bool isNativeCodeCall() const
bool isNestedDagArg(unsigned index) const
StringRef getSymbol() const
DagNode(const llvm::DagInit *node)
StringRef getArgName(unsigned index) const
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
int getNumResults() const
Returns the number of results this op produces.
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
int getNumResultPatterns() const
std::vector< IdentifierLine > getLocation() const
DagNode getSourcePattern() const
const Operator & getSourceRootOp()
std::vector< AppliedConstraint > getConstraints() const
DagNode getResultPattern(unsigned index) const
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
DagNode getSupplementalPattern(unsigned index) const
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
int getNumSupplementalPatterns() const
std::string getArgDecl(StringRef name) const
std::string getVarName(StringRef name) const
std::string getVarTypeStr(StringRef name) const
std::string getVarDecl(StringRef name) const
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
int count(StringRef key) const
const_iterator find(StringRef key) const
void assignUniqueAlternativeNames()
bool bindMultipleValues(StringRef symbol, int numValues)
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
bool bindValues(StringRef symbol, int numValues=1)
bool bindAttr(StringRef symbol)
bool bindValue(StringRef symbol)
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
int getStaticValueCount(StringRef symbol) const
bool contains(StringRef symbol) const
BaseT::const_iterator const_iterator
bool bindOpResult(StringRef symbol, const Operator &op)
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.