17#include "llvm/ADT/StringExtras.h"
18#include "llvm/ADT/Twine.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/Support/Path.h"
22#include "llvm/TableGen/Error.h"
23#include "llvm/TableGen/Record.h"
25#define DEBUG_TYPE "mlir-tblgen-pattern"
42 return isa_and_nonnull<llvm::UnsetInit>(def);
47 return isSubClassOf(
"TypeConstraint");
52 return isSubClassOf(
"AttrConstraint");
57 return isSubClassOf(
"PropConstraint");
62 return isSubClassOf(
"Property");
66 return isSubClassOf(
"NativeCodeCall");
79 "the DAG leaf must be operand, attribute, or property");
80 return Constraint(cast<DefInit>(def)->getDef());
84 assert(
isPropMatcher() &&
"the DAG leaf must be a property matcher");
90 return Property(cast<DefInit>(def)->getDef());
94 assert(
isConstantAttr() &&
"the DAG leaf must be constant attribute");
99 assert(
isEnumCase() &&
"the DAG leaf must be an enum attribute case");
100 return EnumCase(cast<DefInit>(def));
104 assert(
isConstantProp() &&
"the DAG leaf must be a constant property value");
114 return cast<DefInit>(def)->getDef()->getValueAsString(
"expression");
119 return cast<DefInit>(def)->getDef()->getValueAsInt(
"numReturns");
123 assert(
isStringAttr() &&
"the DAG leaf must be string attribute");
124 return def->getAsUnquotedString();
126bool DagLeaf::isSubClassOf(StringRef superclass)
const {
127 if (
auto *defInit = dyn_cast_or_null<DefInit>(def))
128 return defInit->getDef()->isSubClassOf(superclass);
142 if (
auto *defInit = dyn_cast_or_null<DefInit>(node->getOperator()))
143 return defInit->getDef()->isSubClassOf(
"NativeCodeCall");
155 return cast<DefInit>(node->getOperator())
157 ->getValueAsString(
"expression");
162 return cast<DefInit>(node->getOperator())
164 ->getValueAsInt(
"numReturns");
170 const Record *opDef = cast<DefInit>(node->getOperator())->
getDef();
171 auto [it,
inserted] = mapper->try_emplace(opDef);
173 it->second = std::make_unique<Operator>(opDef);
181 for (
int i = 0, e =
getNumArgs(); i != e; ++i) {
183 count += child.getNumOps();
191 return isa<DagInit>(node->getArg(
index));
195 return DagNode(dyn_cast_or_null<DagInit>(node->getArg(
index)));
204 return node->getArgNameStr(
index);
208 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
209 return dagOpDef->getName() ==
"replaceWithValue";
213 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
214 return dagOpDef->getName() ==
"location";
218 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
219 return dagOpDef->getName() ==
"returnType";
223 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
224 return dagOpDef->getName() ==
"either";
228 auto *dagOpDef = cast<DefInit>(node->getOperator())->getDef();
229 return dagOpDef->getName() ==
"variadic";
243 auto [name, indexStr] = symbol.rsplit(
"__");
245 if (indexStr.consumeInteger(10, idx)) {
255SymbolInfoMap::SymbolInfo::SymbolInfo(
256 const Operator *op, SymbolInfo::Kind kind,
257 std::optional<DagAndConstant> dagAndConstant)
258 : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
260int SymbolInfoMap::SymbolInfo::getStaticValueCount()
const {
269 case Kind::MultipleValues:
272 llvm_unreachable(
"unknown kind");
276 return alternativeName ? *alternativeName : name.str();
280 LLVM_DEBUG(dbgs() <<
"getVarTypeStr for '" << name <<
"': ");
284 return cast<NamedAttribute *>(op->getArg(getArgIndex()))
285 ->attr.getStorageType()
288 return "::mlir::Attribute";
292 return cast<NamedProperty *>(op->getArg(getArgIndex()))
293 ->prop.getInterfaceType()
295 assert(dagAndConstant && dagAndConstant->dag &&
296 "generic properties must carry their constraint");
297 return reinterpret_cast<const DagLeaf *
>(dagAndConstant->dag)
298 ->getAsPropConstraint()
302 case Kind::Operand: {
305 return "::mlir::Operation::operand_range";
308 return "::mlir::Value";
310 case Kind::MultipleValues: {
311 return "::mlir::ValueRange";
315 return op->getQualCppClassName();
318 llvm_unreachable(
"unknown kind");
322 LLVM_DEBUG(dbgs() <<
"getVarDecl for '" << name <<
"': ");
323 std::string varInit = kind == Kind::Operand ?
"(op0->getOperands())" :
"";
329 LLVM_DEBUG(dbgs() <<
"getArgDecl for '" << name <<
"': ");
334std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
335 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
336 LLVM_DEBUG(dbgs() <<
"getValueAndRangeUse for '" << name <<
"': ");
340 auto repl = formatv(fmt, name);
341 LLVM_DEBUG(dbgs() << repl <<
" (Attr)\n");
342 return std::string(repl);
346 auto repl = formatv(fmt, name);
347 LLVM_DEBUG(dbgs() << repl <<
" (Prop)\n");
348 return std::string(repl);
350 case Kind::Operand: {
352 auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
353 if (operand->isOptional()) {
355 fmt, formatv(
"({0}.empty() ? ::mlir::Value() : *{0}.begin())", name));
356 LLVM_DEBUG(dbgs() << repl <<
" (OptionalOperand)\n");
357 return std::string(repl);
362 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
363 auto repl = formatv(fmt, name);
364 LLVM_DEBUG(dbgs() << repl <<
" (VariadicOperand)\n");
365 return std::string(repl);
367 auto repl = formatv(fmt, formatv(
"(*{0}.begin())", name));
368 LLVM_DEBUG(dbgs() << repl <<
" (SingleOperand)\n");
369 return std::string(repl);
376 std::string(formatv(
"{0}.getODSResults({1})", name, index));
377 if (!op->getResult(index).isVariadic())
378 v = std::string(formatv(
"(*{0}.begin())", v));
379 auto repl = formatv(fmt, v);
380 LLVM_DEBUG(dbgs() << repl <<
" (SingleResult)\n");
381 return std::string(repl);
386 if (op->getNumResults() == 0) {
387 LLVM_DEBUG(dbgs() << name <<
" (Op)\n");
388 return formatv(fmt, name);
393 SmallVector<std::string, 4> values;
394 values.reserve(op->getNumResults());
396 for (
int i = 0, e = op->getNumResults(); i < e; ++i) {
397 std::string v = std::string(formatv(
"{0}.getODSResults({1})", name, i));
398 if (!op->getResult(i).isVariadic()) {
399 v = std::string(formatv(
"(*{0}.begin())", v));
401 values.push_back(std::string(formatv(fmt, v)));
403 auto repl = llvm::join(values, separator);
404 LLVM_DEBUG(dbgs() << repl <<
" (VariadicResult)\n");
409 assert(op ==
nullptr);
410 auto repl = formatv(fmt, name);
411 LLVM_DEBUG(dbgs() << repl <<
" (Value)\n");
412 return std::string(repl);
414 case Kind::MultipleValues: {
415 assert(op ==
nullptr);
416 assert(index < getSize());
419 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
420 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
425 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
426 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
427 return std::string(repl);
430 llvm_unreachable(
"unknown kind");
433std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
434 StringRef name,
int index,
const char *fmt,
const char *separator)
const {
435 LLVM_DEBUG(dbgs() <<
"getAllRangeUse for '" << name <<
"': ");
439 case Kind::Operand: {
440 assert(index < 0 &&
"only allowed for symbol bound to result");
441 auto repl = formatv(fmt, name);
442 LLVM_DEBUG(dbgs() << repl <<
" (Operand/Attr/Prop)\n");
443 return std::string(repl);
447 auto repl = formatv(fmt, formatv(
"{0}.getODSResults({1})", name, index));
448 LLVM_DEBUG(dbgs() << repl <<
" (SingleResult)\n");
449 return std::string(repl);
454 SmallVector<std::string, 4> values;
455 values.reserve(op->getNumResults());
457 for (
int i = 0, e = op->getNumResults(); i < e; ++i) {
458 values.push_back(std::string(
459 formatv(fmt, formatv(
"{0}.getODSResults({1})", name, i))));
461 auto repl = llvm::join(values, separator);
462 LLVM_DEBUG(dbgs() << repl <<
" (VariadicResult)\n");
466 assert(index < 0 &&
"only allowed for symbol bound to result");
467 assert(op ==
nullptr);
468 auto repl = formatv(fmt, formatv(
"{{{0}}", name));
469 LLVM_DEBUG(dbgs() << repl <<
" (Value)\n");
470 return std::string(repl);
472 case Kind::MultipleValues: {
473 assert(op ==
nullptr);
474 assert(index < getSize());
477 formatv(fmt, std::string(formatv(
"{0}[{1}]", name, index)));
478 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
482 formatv(fmt, std::string(formatv(
"{0}.begin(), {0}.end()", name)));
483 LLVM_DEBUG(dbgs() << repl <<
" (MultipleValues)\n");
484 return std::string(repl);
487 llvm_unreachable(
"unknown kind");
492 std::optional<int> variadicSubIndex) {
494 if (name != symbol) {
495 auto error = formatv(
496 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
497 PrintFatalError(loc, error);
502 isa<NamedAttribute *>(arg) ? SymbolInfo::getAttr(&op, argIndex)
503 : isa<NamedProperty *>(arg)
504 ? SymbolInfo::getProp(&op, argIndex)
505 : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
507 std::string key = symbol.str();
508 if (symbolInfoMap.count(key)) {
510 if (symInfo.kind != SymbolInfo::Kind::Operand) {
516 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
521 symbolInfoMap.emplace(key, symInfo);
527 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
529 return symbolInfoMap.count(
inserted->first) == 1;
540 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
541 return symbolInfoMap.count(
inserted->first) == 1;
547 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
548 return symbolInfoMap.count(
inserted->first) == 1;
552 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
553 return symbolInfoMap.count(
inserted->first) == 1;
559 symbolInfoMap.emplace(symbol.str(), SymbolInfo::getProp(&constraint));
560 return symbolInfoMap.count(
inserted->first) == 1;
564 return find(symbol) != symbolInfoMap.end();
570 return symbolInfoMap.find(name);
576 std::optional<int> variadicSubIndex)
const {
578 key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
585 auto range = symbolInfoMap.equal_range(name);
587 for (
auto it = range.first; it != range.second; ++it)
588 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
591 return symbolInfoMap.end();
594std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
598 return symbolInfoMap.equal_range(name);
603 return symbolInfoMap.count(name);
608 if (name != symbol) {
614 return find(name)->second.getStaticValueCount();
619 const char *separator)
const {
623 auto it = symbolInfoMap.find(name.str());
624 if (it == symbolInfoMap.end()) {
625 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
626 PrintFatalError(loc, error);
629 return it->second.getValueAndRangeUse(name,
index, fmt, separator);
633 const char *separator)
const {
637 auto it = symbolInfoMap.find(name.str());
638 if (it == symbolInfoMap.end()) {
639 auto error = formatv(
"referencing unbound symbol '{0}'", symbol);
640 PrintFatalError(loc, error);
643 return it->second.getAllRangeUse(name,
index, fmt, separator);
649 for (
auto symbolInfoIt = symbolInfoMap.begin();
650 symbolInfoIt != symbolInfoMap.end();) {
651 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
652 auto startRange = range.first;
653 auto endRange = range.second;
655 auto operandName = symbolInfoIt->first;
656 int startSearchIndex = 0;
657 for (++startRange; startRange != endRange; ++startRange) {
660 for (
int i = startSearchIndex;; ++i) {
661 std::string alternativeName = operandName + std::to_string(i);
662 if (!usedNames.contains(alternativeName) &&
663 symbolInfoMap.count(alternativeName) == 0) {
664 usedNames.insert(alternativeName);
665 startRange->second.alternativeName = alternativeName;
666 startSearchIndex = i + 1;
673 symbolInfoIt = endRange;
682 : def(*def), recordOpMap(mapper) {}
685 return DagNode(def.getValueAsDag(
"sourcePattern"));
689 auto *results = def.getValueAsListInit(
"resultPatterns");
690 return results->size();
694 auto *results = def.getValueAsListInit(
"resultPatterns");
695 return DagNode(cast<DagInit>(results->getElement(
index)));
699 LLVM_DEBUG(dbgs() <<
"start collecting source pattern bound symbols\n");
701 LLVM_DEBUG(dbgs() <<
"done collecting source pattern bound symbols\n");
703 LLVM_DEBUG(dbgs() <<
"start assigning alternative names for symbols\n");
705 LLVM_DEBUG(dbgs() <<
"done assigning alternative names for symbols\n");
709 LLVM_DEBUG(dbgs() <<
"start collecting result pattern bound symbols\n");
714 LLVM_DEBUG(dbgs() <<
"done collecting result pattern bound symbols\n");
726 auto *listInit = def.getValueAsListInit(
"constraints");
727 std::vector<AppliedConstraint> ret;
728 ret.reserve(listInit->size());
730 for (
auto *it : *listInit) {
731 auto *dagInit = dyn_cast<DagInit>(it);
733 PrintFatalError(&def,
"all elements in Pattern multi-entity "
734 "constraints should be DAG nodes");
736 std::vector<std::string> entities;
737 entities.reserve(dagInit->arg_size());
738 for (
auto *argName : dagInit->getArgNames()) {
742 "operands to additional constraints can only be symbol references");
744 entities.emplace_back(argName->getValue());
747 ret.emplace_back(cast<DefInit>(dagInit->getOperator())->getDef(),
748 dagInit->getNameStr(), std::move(entities));
754 auto *results = def.getValueAsListInit(
"supplementalPatterns");
755 return results->size();
759 auto *results = def.getValueAsListInit(
"supplementalPatterns");
760 return DagNode(cast<DagInit>(results->getElement(
index)));
767 const DagInit *delta = def.getValueAsDag(
"benefitDelta");
768 if (delta->getNumArgs() != 1 || !isa<IntInit>(delta->getArg(0))) {
769 PrintFatalError(&def,
770 "The 'addBenefit' takes and only takes one integer value");
772 return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
775std::vector<Pattern::IdentifierLine>
777 std::vector<std::pair<StringRef, unsigned>>
result;
778 result.reserve(def.getLoc().size());
779 for (
auto loc : def.getLoc()) {
780 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
781 assert(buf &&
"invalid source location");
783 StringRef bufferName =
784 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier();
791 if (forSourceOutput && llvm::sys::path::is_absolute(bufferName))
792 bufferName = llvm::sys::path::filename(bufferName);
794 result.emplace_back(bufferName,
795 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
800void Pattern::verifyBind(
bool result, StringRef symbolName) {
802 auto err = formatv(
"symbol '{0}' bound more than once", symbolName);
803 PrintFatalError(&def, err);
813 if (!treeName.empty()) {
815 LLVM_DEBUG(dbgs() <<
"found symbol bound to NativeCodeCall: "
816 << treeName <<
'\n');
821 PrintFatalError(&def,
822 formatv(
"binding symbol '{0}' to NativecodeCall in "
823 "MatchPattern is not supported",
828 for (
int i = 0; i != numTreeArgs; ++i) {
843 if (!treeArgName.empty() && treeArgName !=
"_") {
850 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
854 if (propConstraint.getInterfaceType().empty()) {
855 PrintFatalError(&def,
856 formatv(
"binding symbol '{0}' in NativeCodeCall to "
857 "a property constraint without specifying "
858 "that constraint's type is unsupported",
861 verifyBind(infoMap.
bindProp(treeArgName, propConstraint),
871 verifyBind(infoMap.
bindAttr(treeArgName), treeArgName);
876 verifyBind(infoMap.
bindValue(treeArgName), treeArgName);
886 auto numOpArgs = op.getNumArgs();
891 int numDirectives = 0;
892 for (
int i = numTreeArgs - 1; i >= 0; --i) {
894 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
896 else if (dagArg.isEither())
901 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
903 formatv(
"op '{0}' argument number mismatch: "
904 "{1} in pattern vs. {2} in definition",
905 op.getOperationName(), numTreeArgs + numEither, numOpArgs);
906 PrintFatalError(&def, err);
911 if (!treeName.empty()) {
912 LLVM_DEBUG(dbgs() <<
"found symbol bound to op result: " << treeName
914 verifyBind(infoMap.
bindOpResult(treeName, op), treeName);
921 for (
int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
925 auto argName = tree.getArgName(i);
926 if (!argName.empty() && argName !=
"_") {
939 auto treeName = tree.getSymbol();
940 if (!treeName.empty()) {
947 for (
int i = 0; i < tree.getNumArgs(); ++i) {
951 auto argName = tree.getArgName(i);
952 if (!argName.empty() && argName !=
"_") {
961 for (
int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
962 if (
auto treeArg = tree.getArgAsNestedDag(i)) {
963 if (treeArg.isEither()) {
964 collectSymbolInEither(tree, treeArg, opArgIdx);
973 }
else if (treeArg.isVariadic()) {
974 collectSymbolInVariadic(tree, treeArg, opArgIdx);
985 auto treeArgName = tree.getArgName(i);
987 if (!treeArgName.empty() && treeArgName !=
"_") {
988 LLVM_DEBUG(dbgs() <<
"found symbol bound to op argument: "
989 << treeArgName <<
'\n');
990 verifyBind(infoMap.
bindOpArgument(tree, treeArgName, op, opArgIdx),
998 if (!treeName.empty()) {
1000 &def, formatv(
"binding symbol '{0}' to non-operation/native code call "
1001 "unsupported right now",
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
std::string getConditionTemplate() const
Constraint getAsConstraint() const
bool isNativeCodeCall() const
bool isPropMatcher() const
int getNumReturnsOfNativeCode() const
ConstantAttr getAsConstantAttr() const
void print(raw_ostream &os) const
std::string getStringAttr() const
Property getAsProperty() const
StringRef getNativeCodeTemplate() const
std::string getConditionTemplate() const
bool isConstantProp() const
PropConstraint getAsPropConstraint() const
ConstantProp getAsConstantProp() const
bool isUnspecified() const
EnumCase getAsEnumCase() const
bool isAttrMatcher() const
bool isOperandMatcher() const
bool isPropDefinition() const
bool isConstantAttr() const
bool isStringAttr() const
bool isReturnTypeDirective() const
bool isLocationDirective() const
bool isReplaceWithValue() const
DagNode getArgAsNestedDag(unsigned index) const
DagLeaf getArgAsLeaf(unsigned index) const
int getNumReturnsOfNativeCode() const
StringRef getNativeCodeTemplate() const
void print(raw_ostream &os) const
Operator & getDialectOp(RecordOperatorMap *mapper) const
bool isNativeCodeCall() const
bool isNestedDagArg(unsigned index) const
StringRef getSymbol() const
DagNode(const llvm::DagInit *node)
StringRef getArgName(unsigned index) const
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
int getNumResults() const
Returns the number of results this op produces.
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
int getNumResultPatterns() const
DagNode getSourcePattern() const
const Operator & getSourceRootOp()
std::vector< AppliedConstraint > getConstraints() const
DagNode getResultPattern(unsigned index) const
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
DagNode getSupplementalPattern(unsigned index) const
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
int getNumSupplementalPatterns() const
std::vector< IdentifierLine > getLocation(bool forSourceOutput=false) const
std::string getArgDecl(StringRef name) const
std::string getVarName(StringRef name) const
std::string getVarTypeStr(StringRef name) const
std::string getVarDecl(StringRef name) const
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
int count(StringRef key) const
const_iterator find(StringRef key) const
void assignUniqueAlternativeNames()
bool bindMultipleValues(StringRef symbol, int numValues)
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
bool bindValues(StringRef symbol, int numValues=1)
bool bindAttr(StringRef symbol)
bool bindProp(StringRef symbol, const PropConstraint &constraint)
bool bindValue(StringRef symbol)
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
int getStaticValueCount(StringRef symbol) const
bool contains(StringRef symbol) const
BaseT::const_iterator const_iterator
bool bindOpResult(StringRef symbol, const Operator &op)
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
DenseMap< const llvm::Record *, std::unique_ptr< Operator > > RecordOperatorMap
Include the generated interface declarations.