26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/ManagedStatic.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/ScopedPrinter.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Parser.h"
49 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
50 curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
51 typeTy(ast::TypeType::
get(ctx)), valueTy(ast::ValueType::
get(ctx)),
52 typeRangeTy(ast::TypeRangeType::
get(ctx)),
53 valueRangeTy(ast::ValueRangeType::
get(ctx)),
54 attrTy(ast::AttributeType::
get(ctx)),
55 codeCompleteContext(codeCompleteContext) {}
66 enum class ParserContext {
83 enum class OpResultTypeContext {
102 return (curDeclScope = newScope);
/// Make `scope` the current declaration scope by assigning it to
/// `curDeclScope`. The previous value of `curDeclScope` is overwritten here
/// and is not saved by this function.
104 void pushDeclScope(
ast::DeclScope *scope) { curDeclScope = scope; }
/// Leave the current declaration scope: `curDeclScope` is replaced by the
/// result of its own `getParentScope()`.
/// NOTE(review): `curDeclScope` is dereferenced without a null check, so this
/// presumably requires an active scope at every call site -- confirm.
107 void popDeclScope() { curDeclScope = curDeclScope->
getParentScope(); }
134 const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
135 return opName ? ctx.getODSContext().lookupOperation(*opName) :
nullptr;
140 StringRef processDoc(StringRef doc) {
141 return enableDocumentation ? doc : StringRef();
146 std::string processAndFormatDoc(
const Twine &doc) {
147 if (!enableDocumentation)
151 llvm::raw_string_ostream docOS(docStr);
152 std::string tmpDocStr = doc.str();
154 StringRef(tmpDocStr).rtrim(
" \t"));
164 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
168 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
173 template <
typename Constra
intT>
175 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
177 StringRef nativeType, StringRef docString);
178 template <
typename Constra
intT>
182 StringRef nativeType);
188 struct ParsedPatternMetadata {
189 std::optional<uint16_t> benefit;
190 bool hasBoundedRecursion =
false;
195 parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
208 parseUserConstraintDecl(
bool isInline =
false);
239 template <
typename T,
typename ParseUserPDLLDeclFnT>
241 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
242 StringRef anonymousNamePrefix,
bool isInline);
246 template <
typename T>
269 bool expectTerminalSemicolon =
true);
272 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
281 defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
285 defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
302 parseConstraint(std::optional<SMRange> &typeConstraint,
304 bool allowInlineTypeConstraints);
319 bool isNegated =
false);
329 parseOperationExpr(OpResultTypeContext inputResultTypeContext =
330 OpResultTypeContext::Explicit);
360 createPatternDecl(SMRange loc,
const ast::Name *name,
361 const ParsedPatternMetadata &metadata,
370 template <
typename T>
379 createVariableDecl(StringRef name, SMRange loc,
ast::Expr *initializer,
385 createArgOrResultVariableDecl(StringRef name, SMRange loc,
409 createCallExpr(SMRange loc,
ast::Expr *parentExpr,
411 bool isNegated =
false);
414 createInlineVariableExpr(
ast::Type type, StringRef name, SMRange loc,
417 createMemberAccessExpr(
ast::Expr *parentExpr, StringRef name, SMRange loc);
422 StringRef name, SMRange loc);
425 OpResultTypeContext resultTypeContext,
430 validateOperationOperands(SMRange loc, std::optional<StringRef> name,
434 std::optional<StringRef> name,
437 void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
440 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
453 createReplaceStmt(SMRange loc,
ast::Expr *rootOp,
456 createRewriteStmt(SMRange loc,
ast::Expr *rootOp,
468 LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
470 bool allowInlineTypeConstraints);
472 LogicalResult codeCompleteOperationName(StringRef dialectName);
474 LogicalResult codeCompleteIncludeFilename(StringRef curPath);
476 void codeCompleteCallSignature(
ast::Node *parent,
unsigned currentNumArgs);
477 void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
478 unsigned currentNumOperands);
479 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
480 unsigned currentNumResults);
489 if (curToken.isNot(kind))
496 void consumeToken() {
498 "shouldn't advance past EOF or errors");
499 curToken = lexer.lexToken();
506 assert(curToken.is(kind) &&
"consumed an unexpected token");
511 void resetToken(SMRange tokLoc) {
512 lexer.resetPointer(tokLoc.Start.getPointer());
513 curToken = lexer.lexToken();
519 if (curToken.getKind() != kind)
520 return emitError(curToken.getLoc(), msg);
525 lexer.emitError(loc, msg);
529 return emitError(curToken.getLoc(), msg);
531 LogicalResult emitErrorAndNote(SMRange loc,
const Twine &msg, SMRange noteLoc,
533 lexer.emitErrorAndNote(loc, msg, noteLoc, note);
552 bool enableDocumentation;
556 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
559 ParserContext parserContext = ParserContext::Global;
567 unsigned anonymousDeclNameCounter = 0;
575 SMLoc moduleLoc = curToken.getStartLoc();
580 if (
failed(parseModuleBody(decls)))
581 return popDeclScope(),
failure();
590 if (
failed(parseDirective(decls)))
598 decls.push_back(*decl);
612 if (exprType == type)
617 expr->
getLoc(), llvm::formatv(
"unable to convert expression of type "
618 "`{0}` to the expected type of "
627 return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
633 if ((exprType == valueTy || exprType == valueRangeTy) &&
634 (type == valueTy || type == valueRangeTy))
636 if ((exprType == typeTy || exprType == typeRangeTy) &&
637 (type == typeTy || type == typeRangeTy))
642 return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
645 return emitConvertError();
654 if (opType.getName())
655 return emitErrorFn();
660 if (type == valueRangeTy) {
667 if (type == valueTy) {
671 if (odsOp->getResults().empty()) {
672 return emitErrorFn()->attachNote(
673 llvm::formatv(
"see the definition of `{0}`, which was defined "
679 unsigned numSingleResults = llvm::count_if(
681 return result.getVariableLengthKind() ==
682 ods::VariableLengthKind::Single;
684 if (numSingleResults > 1) {
685 return emitErrorFn()->attachNote(
686 llvm::formatv(
"see the definition of `{0}`, which was defined "
687 "with at least {1} results",
688 odsOp->getName(), numSingleResults),
697 return emitErrorFn();
706 if (tupleType.size() != exprType.
size())
707 return emitErrorFn();
712 for (
unsigned i = 0, e = exprType.
size(); i < e; ++i) {
714 ctx, expr->
getLoc(), expr, llvm::to_string(i),
718 diag.attachNote(llvm::formatv(
"when converting element #{0} of `{1}`",
723 if (
failed(convertExpressionTo(newExprs.back(),
724 tupleType.getElementTypes()[i], diagFn)))
728 tupleType.getElementNames());
736 if (parserContext != ParserContext::Rewrite) {
737 return emitErrorFn()->attachNote(
"Tuple to Range conversion is currently "
738 "only allowed within a rewrite context");
743 if (!llvm::is_contained(allowedElementTypes, elementType))
744 return emitErrorFn();
749 for (
unsigned i = 0, e = exprType.
size(); i < e; ++i) {
751 ctx, expr->
getLoc(), expr, llvm::to_string(i),
757 if (type == valueRangeTy)
758 return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
759 if (type == typeRangeTy)
760 return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
762 return emitErrorFn();
769 StringRef directive = curToken.getSpelling();
770 if (directive ==
"#include")
771 return parseInclude(decls);
773 return emitError(
"unknown directive `" + directive +
"`");
777 SMRange loc = curToken.getLoc();
782 return codeCompleteIncludeFilename(curToken.getStringValue());
785 if (!curToken.isString())
787 "expected string file name after `include` directive");
788 SMRange fileLoc = curToken.getLoc();
789 std::string filenameStr = curToken.getStringValue();
790 StringRef filename = filenameStr;
795 if (filename.endswith(
".pdll")) {
796 if (
failed(lexer.pushInclude(filename, fileLoc)))
798 "unable to open include file `" + filename +
"`");
803 curToken = lexer.lexToken();
805 curToken = lexer.lexToken();
810 if (filename.endswith(
".td"))
811 return parseTdInclude(filename, fileLoc, decls);
814 "expected include filename to end with `.pdll` or `.td`");
817 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
819 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
822 std::string includedFile;
823 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
824 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
826 return emitError(fileLoc,
"unable to open include file `" + filename +
"`");
829 llvm::SourceMgr tdSrcMgr;
830 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
831 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
835 struct DiagHandlerContext {
839 } handlerContext{*
this, filename, fileLoc};
842 tdSrcMgr.setDiagHandler(
843 [](
const llvm::SMDiagnostic &
diag,
void *rawHandlerContext) {
844 auto *ctx =
reinterpret_cast<DiagHandlerContext *
>(rawHandlerContext);
845 (void)ctx->parser.emitError(
847 llvm::formatv(
"error while processing include file `{0}`: {1}",
848 ctx->filename,
diag.getMessage()));
853 llvm::RecordKeeper tdRecords;
854 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
858 processTdIncludeRecords(tdRecords, decls);
863 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
867 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
870 auto getLengthKind = [](
const auto &value) {
871 if (value.isOptional())
882 cst.constraint.getUniqueDefName(),
883 processDoc(cst.constraint.getSummary()),
884 cst.constraint.getCPPClassName());
886 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
887 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
892 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Op")) {
896 bool supportsResultTypeInferrence =
897 op.getTrait(
"::mlir::InferTypeOpInterface::Trait");
900 op.getOperationName(), processDoc(op.getSummary()),
901 processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
902 supportsResultTypeInferrence, op.
getLoc().front());
909 odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
911 attr.attr.getUniqueDefName(),
912 processDoc(attr.attr.getSummary()),
913 attr.attr.getStorageType()));
916 odsOp->appendOperand(operand.name, getLengthKind(operand),
917 addTypeConstraint(operand));
920 odsOp->appendResult(result.name, getLengthKind(result),
921 addTypeConstraint(result));
925 auto shouldBeSkipped = [
this](llvm::Record *def) {
926 return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
927 def->isSubClassOf(
"DeclareInterfaceMethods");
931 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Attr")) {
932 if (shouldBeSkipped(def))
936 decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
937 constraint, convertLocToRange(def->getLoc().front()), attrTy,
938 constraint.getStorageType()));
941 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Type")) {
942 if (shouldBeSkipped(def))
946 decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
947 constraint, convertLocToRange(def->getLoc().front()), typeTy,
948 constraint.getCPPClassName()));
952 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"OpInterface")) {
953 if (shouldBeSkipped(def))
956 SMRange loc = convertLocToRange(def->getLoc().front());
958 std::string cppClassName =
959 llvm::formatv(
"{0}::{1}", def->getValueAsString(
"cppNamespace"),
960 def->getValueAsString(
"cppInterfaceName"))
962 std::string codeBlock =
963 llvm::formatv(
"return ::mlir::success(llvm::isa<{0}>(self));",
968 processAndFormatDoc(def->getValueAsString(
"description"));
969 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
970 def->getName(), codeBlock, loc, opTy, cppClassName, desc));
974 template <
typename Constra
intT>
975 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
976 StringRef name, StringRef codeBlock, SMRange loc,
ast::Type type,
977 StringRef nativeType, StringRef docString) {
983 argScope->
add(paramVar);
991 constraintDecl->setDocComment(ctx, docString);
992 curDeclScope->add(constraintDecl);
993 return constraintDecl;
996 template <
typename Constra
intT>
1000 StringRef nativeType) {
1011 std::string docString;
1012 if (enableDocumentation) {
1014 docString = processAndFormatDoc(
1019 return createODSNativePDLLConstraintDecl<ConstraintT>(
1029 switch (curToken.getKind()) {
1031 decl = parseUserConstraintDecl();
1034 decl = parsePatternDecl();
1037 decl = parseUserRewriteDecl();
1040 return emitError(
"expected top-level declaration, such as a `Pattern`");
1047 if (
failed(checkDefineNamedDecl(*name)))
1049 curDeclScope->add(*decl);
1055 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1058 return codeCompleteAttributeName(parentOpName);
1060 std::string attrNameStr;
1061 if (curToken.isString())
1062 attrNameStr = curToken.getStringValue();
1064 attrNameStr = curToken.getSpelling().str();
1066 return emitError(
"expected identifier or string attribute name");
1076 attrValue = *attrExpr;
1088 bool expectTerminalSemicolon) {
1092 SMLoc bodyStartLoc = curToken.getStartLoc();
1095 bool failedToParse =
1096 failed(singleStatement) ||
failed(processStatementFn(*singleStatement));
1101 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1108 return emitError(
"expected identifier argument name");
1111 StringRef name = curToken.getSpelling();
1112 SMRange nameLoc = curToken.getLoc();
1116 parseToken(
Token::colon,
"expected `:` before argument constraint")))
1123 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1134 StringRef name = curToken.getSpelling();
1135 SMRange nameLoc = curToken.getLoc();
1139 "expected `:` before result constraint")))
1146 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1156 return createArgOrResultVariableDecl(
"", cst->referenceLoc, *cst);
1160 Parser::parseUserConstraintDecl(
bool isInline) {
1163 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1164 [&](
auto &&...args) {
1165 return this->parseUserPDLLConstraintDecl(args...);
1167 ParserContext::Constraint,
"constraint", isInline);
1172 parseUserConstraintDecl(
true);
1173 if (
failed(decl) ||
failed(checkDefineNamedDecl((*decl)->getName())))
1176 curDeclScope->add(*decl);
1186 pushDeclScope(argumentScope);
1194 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1197 "expected `Constraint` lambda body to contain a "
1198 "single expression");
1214 auto bodyIt = body->
begin(), bodyE = body->
end();
1215 for (; bodyIt != bodyE; ++bodyIt)
1216 if (isa<ast::ReturnStmt>(*bodyIt))
1218 if (
failed(validateUserConstraintOrRewriteReturn(
1219 "Constraint", body, bodyIt, bodyE, results, resultType)))
1224 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1225 name, arguments, results, resultType, body);
1231 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1232 [&](
auto &&...args) {
return this->parseUserPDLLRewriteDecl(args...); },
1233 ParserContext::Rewrite,
"rewrite", isInline);
1238 parseUserRewriteDecl(
true);
1239 if (
failed(decl) ||
failed(checkDefineNamedDecl((*decl)->getName())))
1242 curDeclScope->add(*decl);
1252 curDeclScope = argumentScope;
1257 if (isa<ast::OpRewriteStmt>(statement))
1260 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1261 if (!statementExpr) {
1264 "expected `Rewrite` lambda body to contain a single expression "
1265 "or an operation rewrite statement; such as `erase`, "
1266 "`replace`, or `rewrite`");
1285 auto bodyIt = body->
begin(), bodyE = body->
end();
1286 for (; bodyIt != bodyE; ++bodyIt)
1287 if (isa<ast::ReturnStmt>(*bodyIt))
1289 if (
failed(validateUserConstraintOrRewriteReturn(
"Rewrite", body, bodyIt,
1290 bodyE, results, resultType)))
1292 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1293 name, arguments, results, resultType, body);
1296 template <
typename T,
typename ParseUserPDLLDeclFnT>
1298 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1299 StringRef anonymousNamePrefix,
bool isInline) {
1300 SMRange loc = curToken.getLoc();
1302 llvm::SaveAndRestore saveCtx(parserContext, declContext);
1310 return emitError(
"expected identifier name");
1314 std::string anonName =
1315 llvm::formatv(
"<anonymous_{0}_{1}>", anonymousNamePrefix,
1316 anonymousDeclNameCounter++)
1329 if (
failed(parseUserConstraintOrRewriteSignature(arguments, results,
1330 argumentScope, resultType)))
1336 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1340 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1341 results, resultType);
1344 template <
typename T>
1350 std::string codeStrStorage;
1351 std::optional<StringRef> optCodeStr;
1352 if (curToken.isString()) {
1353 codeStrStorage = curToken.getStringValue();
1354 optCodeStr = codeStrStorage;
1356 }
else if (isInline) {
1358 "external declarations must be declared in global scope");
1363 "expected `;` after native declaration")))
1367 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1369 "native Constraints currently do not support returning results");
1371 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1374 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1382 argumentScope = pushDeclScope();
1388 arguments.emplace_back(*argument);
1402 results.emplace_back(*result);
1409 if (
failed(parseResultFn()))
1416 }
else if (
failed(parseResultFn())) {
1423 resultType = createUserConstraintRewriteResultType(results);
1426 if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1428 results.front()->getLoc(),
1429 "cannot create a single-element tuple with an element label");
1434 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1440 if (bodyIt != bodyE) {
1442 if (std::next(bodyIt) != bodyE) {
1444 (*std::next(bodyIt))->getLoc(),
1445 llvm::formatv(
"`return` terminated the `{0}` body, but found "
1446 "trailing statements afterwards",
1452 }
else if (!results.empty()) {
1454 {body->getLoc().End, body->getLoc().End},
1455 llvm::formatv(
"missing return in a `{0}` expected to return `{1}`",
1456 declType, resultType));
1463 if (isa<ast::OpRewriteStmt>(statement))
1467 "expected Pattern lambda body to contain a single operation "
1468 "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1473 SMRange loc = curToken.getLoc();
1475 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1485 ParsedPatternMetadata metadata;
1500 return emitError(
"expected `{` or `=>` to start pattern body");
1507 auto bodyIt = body->
begin(), bodyE = body->
end();
1508 for (; bodyIt != bodyE; ++bodyIt) {
1509 if (isa<ast::ReturnStmt>(*bodyIt)) {
1511 "`return` statements are only permitted within a "
1512 "`Constraint` or `Rewrite` body");
1515 if (isa<ast::OpRewriteStmt>(*bodyIt))
1518 if (bodyIt == bodyE) {
1520 "expected Pattern body to terminate with an operation "
1521 "rewrite statement, such as `erase`");
1523 if (std::next(bodyIt) != bodyE) {
1524 return emitError((*std::next(bodyIt))->getLoc(),
1525 "Pattern body was terminated by an operation "
1526 "rewrite statement, but found trailing statements");
1530 return createPatternDecl(loc, name, metadata, body);
1534 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1535 std::optional<SMRange> benefitLoc;
1536 std::optional<SMRange> hasBoundedRecursionLoc;
1541 return codeCompletePatternMetadata();
1544 return emitError(
"expected pattern metadata identifier");
1545 StringRef metadataStr = curToken.getSpelling();
1546 SMRange metadataLoc = curToken.getLoc();
1550 if (metadataStr ==
"benefit") {
1552 return emitErrorAndNote(metadataLoc,
1553 "pattern benefit has already been specified",
1554 *benefitLoc,
"see previous definition here");
1557 "expected `(` before pattern benefit")))
1560 uint16_t benefitValue = 0;
1562 return emitError(
"expected integral pattern benefit");
1563 if (curToken.getSpelling().getAsInteger(10, benefitValue))
1565 "expected pattern benefit to fit within a 16-bit integer");
1568 metadata.benefit = benefitValue;
1569 benefitLoc = metadataLoc;
1572 parseToken(
Token::r_paren,
"expected `)` after pattern benefit")))
1578 if (metadataStr ==
"recursion") {
1579 if (hasBoundedRecursionLoc) {
1580 return emitErrorAndNote(
1582 "pattern recursion metadata has already been specified",
1583 *hasBoundedRecursionLoc,
"see previous definition here");
1585 metadata.hasBoundedRecursion =
true;
1586 hasBoundedRecursionLoc = metadataLoc;
1590 return emitError(metadataLoc,
"unknown pattern metadata");
1602 "expected `>` after variable type constraint")))
1608 assert(curDeclScope &&
"defining decl outside of a decl scope");
1610 return emitErrorAndNote(
1611 name.
getLoc(),
"`" + name.
getName() +
"` has already been defined",
1612 lastDecl->getName()->getLoc(),
"see previous definition here");
1618 Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
1621 assert(curDeclScope &&
"defining variable outside of decl scope");
1626 if (name.empty() || name ==
"_") {
1630 if (
failed(checkDefineNamedDecl(nameDecl)))
1635 curDeclScope->add(varDecl);
1640 Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
1642 return defineVariableDecl(name, nameLoc, type,
nullptr,
1648 std::optional<SMRange> typeConstraint;
1649 auto parseSingleConstraint = [&] {
1651 typeConstraint, constraints,
true);
1654 constraints.push_back(*constraint);
1660 return parseSingleConstraint();
1663 if (
failed(parseSingleConstraint()))
1666 return parseToken(
Token::r_square,
"expected `]` after constraint list");
1670 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1672 bool allowInlineTypeConstraints) {
1674 if (!allowInlineTypeConstraints) {
1677 "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1678 "permitted on arguments or results");
1681 return emitErrorAndNote(
1683 "the type of this variable has already been constrained",
1684 *typeConstraint,
"see previous constraint location here");
1686 if (
failed(constraintExpr))
1688 typeExpr = *constraintExpr;
1689 typeConstraint = typeExpr->getLoc();
1693 SMRange loc = curToken.getLoc();
1694 switch (curToken.getKind()) {
1711 parseWrappedOperationName(
true);
1756 StringRef constraintName = curToken.getSpelling();
1762 return emitError(loc,
"unknown reference to constraint `" +
1763 constraintName +
"`");
1767 if (
auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1770 return emitErrorAndNote(
1771 loc,
"invalid reference to non-constraint", cstDecl->
getLoc(),
1772 "see the definition of `" + constraintName +
"` here");
1778 if (
failed(validateVariableConstraints(existingConstraints, inferredType)))
1781 return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1786 return emitError(loc,
"expected identifier constraint");
1790 std::optional<SMRange> typeConstraint;
1791 return parseConstraint(typeConstraint, std::nullopt,
1800 return parseUnderscoreExpr();
1804 switch (curToken.getKind()) {
1806 lhsExpr = parseAttributeExpr();
1809 lhsExpr = parseInlineConstraintLambdaExpr();
1812 lhsExpr = parseNegatedExpr();
1815 lhsExpr = parseIdentifierExpr();
1818 lhsExpr = parseOperationExpr();
1821 lhsExpr = parseInlineRewriteLambdaExpr();
1824 lhsExpr = parseTypeExpr();
1827 lhsExpr = parseTupleExpr();
1830 return emitError(
"expected expression");
1837 switch (curToken.getKind()) {
1839 lhsExpr = parseMemberAccessExpr(*lhsExpr);
1842 lhsExpr = parseCallExpr(*lhsExpr);
1853 SMRange loc = curToken.getLoc();
1860 return parseIdentifierExpr();
1863 if (!curToken.isString())
1864 return emitError(
"expected string literal containing MLIR attribute");
1865 std::string attrExpr = curToken.getStringValue();
1868 loc.End = curToken.getEndLoc();
1870 parseToken(
Token::greater,
"expected `>` after attribute literal")))
1885 codeCompleteCallSignature(parentExpr, arguments.size());
1892 arguments.push_back(*argument);
1896 SMRange loc(parentExpr->
getLoc().Start, curToken.getEndLoc());
1900 return createCallExpr(loc, parentExpr, arguments, isNegated);
1904 ast::Decl *decl = curDeclScope->lookup(name);
1906 return emitError(loc,
"undefined reference to `" + name +
"`");
1908 return createDeclRefExpr(loc, decl);
1912 StringRef name = curToken.getSpelling();
1913 SMRange nameLoc = curToken.getLoc();
1920 if (
failed(parseVariableDeclConstraintList(constraints)))
1923 if (
failed(validateVariableConstraints(constraints, type)))
1925 return createInlineVariableExpr(type, name, nameLoc, constraints);
1928 return parseDeclRefExpr(name, nameLoc);
1950 SMRange dotLoc = curToken.getLoc();
1955 return codeCompleteMemberAccess(parentExpr);
1958 Token memberNameTok = curToken;
1961 return emitError(dotLoc,
"expected identifier or numeric member name");
1962 StringRef memberName = memberNameTok.
getSpelling();
1963 SMRange loc(parentExpr->
getLoc().Start, curToken.getEndLoc());
1966 return createMemberAccessExpr(parentExpr, memberName, loc);
1973 return emitError(
"expected native constraint");
1975 if (
failed(identifierExpr))
1977 return parseCallExpr(*identifierExpr,
true);
1981 SMRange loc = curToken.getLoc();
1985 return codeCompleteDialectName();
1991 return emitError(
"expected dialect namespace");
1993 StringRef name = curToken.getSpelling();
1997 if (
failed(parseToken(
Token::dot,
"expected `.` after dialect namespace")))
2002 return codeCompleteOperationName(name);
2005 return emitError(
"expected operation name after dialect namespace");
2007 name = StringRef(name.data(), name.size() + 1);
2009 name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
2010 loc.End = curToken.getEndLoc();
2013 curToken.isKeyword());
2018 Parser::parseWrappedOperationName(
bool allowEmptyName) {
2032 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2033 SMRange loc = curToken.getLoc();
2040 return parseIdentifierExpr();
2046 bool allowEmptyName = parserContext != ParserContext::Rewrite;
2048 parseWrappedOperationName(allowEmptyName);
2051 std::optional<StringRef> opName = (*opNameDecl)->getName();
2058 assert(
succeeded(rangeVar) &&
"expected range variable to be valid");
2069 if (parserContext != ParserContext::Rewrite) {
2070 operands.push_back(createImplicitRangeVar(
2078 codeCompleteOperationOperandsSignature(opName, operands.size());
2085 operands.push_back(*operand);
2089 "expected `)` after operation operand list")))
2098 parseNamedAttributeDecl(opName);
2101 attributes.emplace_back(*decl);
2105 "expected `}` after operation attribute list")))
2111 OpResultTypeContext resultTypeContext = inputResultTypeContext;
2116 "expected `(` before operation result type list")))
2124 resultTypeContext = OpResultTypeContext::Explicit;
2131 codeCompleteOperationResultsSignature(opName, resultTypes.size());
2136 if (
failed(resultTypeExpr))
2138 resultTypes.push_back(*resultTypeExpr);
2142 "expected `)` after operation result type list")))
2145 }
else if (parserContext != ParserContext::Rewrite) {
2150 resultTypes.push_back(createImplicitRangeVar(
2152 }
else if (resultTypeContext == OpResultTypeContext::Explicit) {
2155 resultTypeContext = OpResultTypeContext::Interface;
2158 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2159 attributes, resultTypes);
2163 SMRange loc = curToken.getLoc();
2172 StringRef elementName;
2174 Token elementNameTok = curToken;
2182 auto elementNameIt =
2183 usedNames.try_emplace(elementName, elementNameTok.
getLoc());
2184 if (!elementNameIt.second) {
2185 return emitErrorAndNote(
2187 llvm::formatv(
"duplicate tuple element label `{0}`",
2189 elementNameIt.first->getSecond(),
2190 "see previous label use here");
2195 resetToken(elementNameTok.
getLoc());
2198 elementNames.push_back(elementName);
2204 elements.push_back(*element);
2207 loc.End = curToken.getEndLoc();
2209 parseToken(
Token::r_paren,
"expected `)` after tuple element list")))
2211 return createTupleExpr(loc, elements, elementNames);
2215 SMRange loc = curToken.getLoc();
2222 return parseIdentifierExpr();
2225 if (!curToken.isString())
2226 return emitError(
"expected string literal containing MLIR type");
2227 std::string attrExpr = curToken.getStringValue();
2230 loc.End = curToken.getEndLoc();
2237 StringRef name = curToken.getSpelling();
2238 SMRange nameLoc = curToken.getLoc();
2247 if (
failed(parseVariableDeclConstraintList(constraints)))
2251 if (
failed(validateVariableConstraints(constraints, type)))
2253 return createInlineVariableExpr(type, name, nameLoc, constraints);
2261 switch (curToken.getKind()) {
2263 stmt = parseEraseStmt();
2266 stmt = parseLetStmt();
2269 stmt = parseReplaceStmt();
2272 stmt = parseReturnStmt();
2275 stmt = parseRewriteStmt();
2282 (expectTerminalSemicolon &&
2289 SMLoc startLoc = curToken.getStartLoc();
2298 return popDeclScope(),
failure();
2299 statements.push_back(*statement);
2304 SMRange location(startLoc, curToken.getEndLoc());
2311 if (parserContext == ParserContext::Constraint)
2312 return emitError(
"`erase` cannot be used within a Constraint");
2313 SMRange loc = curToken.getLoc();
2321 return createEraseStmt(loc, *rootOp);
2325 SMRange loc = curToken.getLoc();
2329 SMRange varLoc = curToken.getLoc();
2334 "`_` may only be used to define \"inline\" variables");
2337 "expected identifier after `let` to name a new variable");
2339 StringRef varName = curToken.getSpelling();
2345 failed(parseVariableDeclConstraintList(constraints)))
2352 if (
failed(initOrFailure))
2354 initializer = *initOrFailure;
2363 if (cst->getTypeExpr()) {
2365 constraint.referenceLoc,
2366 "type constraints are not permitted on variables with "
2378 createVariableDecl(varName, varLoc, initializer, constraints);
2385 if (parserContext == ParserContext::Constraint)
2386 return emitError(
"`replace` cannot be used within a Constraint");
2387 SMRange loc = curToken.getLoc();
2396 parseToken(
Token::kw_with,
"expected `with` after root operation")))
2400 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2407 loc,
"expected at least one replacement value, consider using "
2408 "`erase` if no replacement values are desired");
2415 replValues.emplace_back(*replExpr);
2419 "expected `)` after replacement values")))
2426 replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2428 replExpr = parseExpr();
2431 replValues.emplace_back(*replExpr);
2434 return createReplaceStmt(loc, *rootOp, replValues);
2438 SMRange loc = curToken.getLoc();
2450 if (parserContext == ParserContext::Constraint)
2451 return emitError(
"`rewrite` cannot be used within a Constraint");
2452 SMRange loc = curToken.getLoc();
2464 return emitError(
"expected `{` to start rewrite body");
2467 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2474 for (
const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2475 if (isa<ast::ReturnStmt>(stmt)) {
2477 "`return` statements are only permitted within a "
2478 "`Constraint` or `Rewrite` body");
2482 return createRewriteStmt(loc, *rootOp, *rewriteBody);
2494 if (
auto *init = dyn_cast<ast::DeclRefExpr>(node))
2495 node = init->getDecl();
2496 return dyn_cast<ast::CallableDecl>(node);
2500 Parser::createPatternDecl(SMRange loc,
const ast::Name *name,
2501 const ParsedPatternMetadata &metadata,
2504 metadata.hasBoundedRecursion, body);
2507 ast::Type Parser::createUserConstraintRewriteResultType(
2510 if (results.size() == 1)
2511 return results[0]->getType();
2515 auto resultTypes = llvm::map_range(
2516 results, [&](
const auto *result) {
return result->getType(); });
2517 auto resultNames = llvm::map_range(
2518 results, [&](
const auto *result) {
return result->getName().getName(); });
2520 llvm::to_vector(resultNames));
2523 template <
typename T>
2529 if (
auto *retStmt = dyn_cast<ast::ReturnStmt>(body->
getChildren().back())) {
2530 ast::Expr *resultExpr = retStmt->getResultExpr();
2535 if (results.empty())
2536 resultType = resultExpr->
getType();
2537 else if (
failed(convertExpressionTo(resultExpr, resultType)))
2540 retStmt->setResultExpr(resultExpr);
2543 return T::createPDLL(ctx, name, arguments, results, body, resultType);
2547 Parser::createVariableDecl(StringRef name, SMRange loc,
ast::Expr *initializer,
2552 if (
failed(validateVariableConstraints(constraints, type)))
2559 type = initializer->
getType();
2562 else if (
failed(convertExpressionTo(initializer, type)))
2568 return emitErrorAndNote(
2569 loc,
"unable to infer type for variable `" + name +
"`", loc,
2570 "the type of a variable must be inferable from the constraint "
2571 "list or the initializer");
2577 loc, llvm::formatv(
"unable to define variable of `{0}` type", type));
2582 defineVariableDecl(name, loc, type, initializer, constraints);
2590 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2593 if (
failed(validateVariableConstraint(constraint, argType)))
2595 return defineVariableDecl(name, loc, argType, constraint);
2602 if (
failed(validateVariableConstraint(ref, inferredType)))
2610 if (
const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.
constraint)) {
2611 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2612 if (
failed(validateTypeConstraintExpr(typeExpr)))
2616 }
else if (
const auto *cst =
2617 dyn_cast<ast::OpConstraintDecl>(ref.
constraint)) {
2620 }
else if (isa<ast::TypeConstraintDecl>(ref.
constraint)) {
2621 constraintType = typeTy;
2622 }
else if (isa<ast::TypeRangeConstraintDecl>(ref.
constraint)) {
2623 constraintType = typeRangeTy;
2624 }
else if (
const auto *cst =
2625 dyn_cast<ast::ValueConstraintDecl>(ref.
constraint)) {
2626 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2627 if (
failed(validateTypeConstraintExpr(typeExpr)))
2630 constraintType = valueTy;
2631 }
else if (
const auto *cst =
2632 dyn_cast<ast::ValueRangeConstraintDecl>(ref.
constraint)) {
2633 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2634 if (
failed(validateTypeRangeConstraintExpr(typeExpr)))
2637 constraintType = valueRangeTy;
2638 }
else if (
const auto *cst =
2639 dyn_cast<ast::UserConstraintDecl>(ref.
constraint)) {
2641 if (inputs.size() != 1) {
2643 "`Constraint`s applied via a variable constraint "
2644 "list must take a single input, but got " +
2645 Twine(inputs.size()),
2647 "see definition of constraint here");
2649 constraintType = inputs.front()->getType();
2651 llvm_unreachable(
"unknown constraint type");
2656 if (!inferredType) {
2657 inferredType = constraintType;
2659 inferredType = mergedTy;
2662 llvm::formatv(
"constraint type `{0}` is incompatible "
2663 "with the previously inferred type `{1}`",
2664 constraintType, inferredType));
2671 if (typeExprType != typeTy) {
2673 "expected expression of `Type` in type constraint");
2679 Parser::validateTypeRangeConstraintExpr(
const ast::Expr *typeExpr) {
2681 if (typeExprType != typeRangeTy) {
2683 "expected expression of `TypeRange` in type constraint");
2692 Parser::createCallExpr(SMRange loc,
ast::Expr *parentExpr,
2697 if (!callableDecl) {
2699 llvm::formatv(
"expected a reference to a callable "
2700 "`Constraint` or `Rewrite`, but got: `{0}`",
2703 if (parserContext == ParserContext::Rewrite) {
2704 if (isa<ast::UserConstraintDecl>(callableDecl))
2706 loc,
"unable to invoke `Constraint` within a rewrite section");
2708 return emitError(loc,
"unable to negate a Rewrite");
2710 if (isa<ast::UserRewriteDecl>(callableDecl))
2712 "unable to invoke `Rewrite` within a match section");
2713 if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2714 return emitError(loc,
"unable to negate non native constraints");
2720 if (callArgs.size() != arguments.size()) {
2721 return emitErrorAndNote(
2723 llvm::formatv(
"invalid number of arguments for {0} call; expected "
2728 llvm::formatv(
"see the definition of {0} here",
2734 diag.attachNote(llvm::formatv(
"see the definition of `{0}` here",
2738 for (
auto it : llvm::zip(callArgs, arguments)) {
2739 if (
failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2752 if (isa<ast::ConstraintDecl>(decl))
2754 else if (isa<ast::UserRewriteDecl>(decl))
2756 else if (
auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2757 declType = varDecl->getType();
2759 return emitError(loc,
"invalid reference to `" +
2766 Parser::createInlineVariableExpr(
ast::Type type, StringRef name, SMRange loc,
2769 defineVariableDecl(name, loc, type, constraints);
2776 Parser::createMemberAccessExpr(
ast::Expr *parentExpr, StringRef name,
2787 StringRef name, SMRange loc) {
2791 return valueRangeTy;
2795 auto results = odsOp->getResults();
2799 if (llvm::isDigit(name[0]) && !name.getAsInteger(10, index) &&
2800 index < results.size()) {
2801 return results[index].isVariadic() ? valueRangeTy : valueTy;
2805 const auto *it = llvm::find_if(results, [&](
const auto &result) {
2806 return result.getName() == name;
2808 if (it != results.end())
2809 return it->isVariadic() ? valueRangeTy : valueTy;
2810 }
else if (llvm::isDigit(name[0])) {
2818 if (llvm::isDigit(name[0]) && !name.getAsInteger(10, index) &&
2819 index < tupleType.size()) {
2820 return tupleType.getElementTypes()[index];
2824 auto elementNames = tupleType.getElementNames();
2825 const auto *it = llvm::find(elementNames, name);
2826 if (it != elementNames.end())
2827 return tupleType.getElementTypes()[it - elementNames.begin()];
2831 llvm::formatv(
"invalid member access `{0}` on expression of type `{1}`",
2837 OpResultTypeContext resultTypeContext,
2841 std::optional<StringRef> opNameRef = name->
getName();
2845 if (
failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2851 ast::Type attrType = attr->getValue()->getType();
2854 attr->getValue()->getLoc(),
2855 llvm::formatv(
"expected `Attr` expression, but got `{0}`", attrType));
2860 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2861 "unexpected inferrence when results were explicitly specified");
2865 if (resultTypeContext == OpResultTypeContext::Explicit) {
2866 if (
failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2870 }
else if (resultTypeContext == OpResultTypeContext::Interface) {
2872 "expected valid operation name when inferring operation results");
2873 checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2881 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2884 return validateOperationOperandsOrResults(
2885 "operand", loc, odsOp ? odsOp->
getLoc() : std::optional<SMRange>(), name,
2886 operands, odsOp ? odsOp->
getOperands() : std::nullopt, valueTy,
2891 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2894 return validateOperationOperandsOrResults(
2895 "result", loc, odsOp ? odsOp->
getLoc() : std::optional<SMRange>(), name,
2896 results, odsOp ? odsOp->
getResults() : std::nullopt, typeTy, typeRangeTy);
2899 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2908 ctx.getDiagEngine().emitWarning(
2910 "operation result types are marked to be inferred, but "
2911 "`{0}` is unknown. Ensure that `{0}` supports zero "
2912 "results or implements `InferTypeOpInterface`. Include "
2913 "the ODS definition of this operation to remove this warning.",
2922 bool requiresInferrence =
2924 return !result.isVariableLength();
2929 llvm::formatv(
"operation result types are marked to be inferred, but "
2930 "`{0}` does not provide an implementation of "
2931 "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2932 "`InferTypeOpInterface` at runtime, or add support to "
2933 "the ODS definition to remove this warning.",
2935 diag->attachNote(llvm::formatv(
"see the definition of `{0}` here", opName),
2942 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2947 if (values.size() == 1) {
2948 if (
failed(convertExpressionTo(values[0], rangeTy)))
2956 auto emitSizeMismatchError = [&] {
2957 return emitErrorAndNote(
2959 llvm::formatv(
"invalid number of {0} groups for `{1}`; expected "
2961 groupName, *name, odsValues.size(), values.size()),
2962 *odsOpLoc, llvm::formatv(
"see the definition of `{0}` here", *name));
2966 if (values.empty()) {
2968 if (odsValues.empty())
2973 unsigned numVariadic = 0;
2974 for (
const auto &odsValue : odsValues) {
2975 if (!odsValue.isVariableLength())
2976 return emitSizeMismatchError();
2982 if (parserContext != ParserContext::Rewrite)
2989 if (numVariadic == 1)
2994 for (
unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2996 ctx, loc, std::nullopt, rangeTy));
3003 if (odsValues.size() != values.size())
3004 return emitSizeMismatchError();
3007 diag.attachNote(llvm::formatv(
"see the definition of `{0}` here", *name),
3010 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
3011 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3012 if (
failed(convertExpressionTo(values[i], expectedType, diagFn)))
3021 ast::Type valueExprType = valueExpr->getType();
3024 if (valueExprType == rangeTy || valueExprType == singleTy)
3030 if (singleTy == valueTy) {
3032 valueExpr = convertOpToValue(valueExpr);
3038 if (
succeeded(convertExpressionTo(valueExpr, rangeTy)))
3042 valueExpr->getLoc(),
3044 "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3045 singleTy, rangeTy, valueExprType));
3053 for (
const ast::Expr *element : elements) {
3058 llvm::formatv(
"unable to build a tuple with `{0}` element", eleTy));
3078 Parser::createReplaceStmt(SMRange loc,
ast::Expr *rootOp,
3085 llvm::formatv(
"expected `Op` expression, but got `{0}`", rootType));
3090 bool shouldConvertOpToValues = replValues.size() > 1;
3091 for (
ast::Expr *&replExpr : replValues) {
3092 ast::Type replType = replExpr->getType();
3096 if (shouldConvertOpToValues)
3097 replExpr = convertOpToValue(replExpr);
3101 if (replType != valueTy && replType != valueRangeTy) {
3103 llvm::formatv(
"expected `Op`, `Value` or `ValueRange` "
3104 "expression, but got `{0}`",
3113 Parser::createRewriteStmt(SMRange loc,
ast::Expr *rootOp,
3120 llvm::formatv(
"expected `Op` expression, but got `{0}`", rootType));
3133 codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3135 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3140 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3142 codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3147 Parser::codeCompleteConstraintName(
ast::Type inferredType,
3148 bool allowInlineTypeConstraints) {
3149 codeCompleteContext->codeCompleteConstraintName(
3150 inferredType, allowInlineTypeConstraints, curDeclScope);
3155 codeCompleteContext->codeCompleteDialectName();
3159 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3160 codeCompleteContext->codeCompleteOperationName(dialectName);
3165 codeCompleteContext->codeCompletePatternMetadata();
3169 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3170 codeCompleteContext->codeCompleteIncludeFilename(curPath);
3174 void Parser::codeCompleteCallSignature(
ast::Node *parent,
3175 unsigned currentNumArgs) {
3180 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3183 void Parser::codeCompleteOperationOperandsSignature(
3184 std::optional<StringRef> opName,
unsigned currentNumOperands) {
3185 codeCompleteContext->codeCompleteOperationOperandsSignature(
3186 opName, currentNumOperands);
3189 void Parser::codeCompleteOperationResultsSignature(
3190 std::optional<StringRef> opName,
unsigned currentNumResults) {
3191 codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3201 bool enableDocumentation,
3203 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3204 return parser.parseModule();
static std::string diag(const llvm::Value &value)
This class provides support for representing a failure result, or a valid value of type T.
This class breaks up the current file into a token stream.
Location getLoc()
The source location the operation was defined or derived from.
operand_range getOperands()
Returns an iterator over the underlying Values.
result_range getResults()
This represents a token in the MLIR syntax.
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
StringRef getSpelling() const
This class provides an abstract interface into the parser for hooking in code completion events.
@ code_complete_string
Token signifying a code completion location within a string.
@ code_complete
Token signifying a code completion location.
@ less
Paired punctuation.
@ kw_Attr
General keywords.
static StringRef getMemberName()
Return the member name used for the "all-results" access.
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
This class represents a PDLL type that corresponds to an mlir::Attribute.
static AttributeType get(Context &context)
Return an instance of the Attribute type.
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
This decl represents a shared interface for all callable decls.
Type getResultType() const
Return the result type of this decl.
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
StringRef getCallableType() const
Return the callable type of this decl.
This statement represents a compound statement, which contains a collection of other statements.
ArrayRef< Stmt * >::iterator begin() const
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
ArrayRef< Stmt * >::iterator end() const
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
This class represents the base of all AST Constraint decls.
This class represents a PDLL type that corresponds to a constraint.
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
This class represents the main context of the PDLL AST.
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
This class represents a scope for named AST decls.
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
void add(Decl *decl)
Add a new decl to the scope.
This class represents the base Decl node.
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
This class provides a simple implementation of a PDLL diagnostic.
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
This class represents a base AST Expression node.
Type getType() const
Return the type of this expression.
This class represents a diagnostic that is inflight and set to be reported.
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
This Decl represents a NamedAttribute, and contains a string name and attribute value.
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
This class represents a base AST node.
SMRange getLoc() const
Return the location of this node.
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
This Decl represents an OperationName.
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
static OpNameDecl * create(Context &ctx, const Name &name)
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
This class represents a PDLL type that corresponds to an mlir::Operation.
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
This class represents a PDLL type that corresponds to a range of elements with a given element type.
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
This class represents a PDLL type that corresponds to a rewrite reference.
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
This class represents a base AST Statement node.
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
This class represents a PDLL tuple type, i.e.
size_t size() const
Return the number of elements within this tuple.
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Type refineWith(Type other) const
Try to refine this type with the one provided.
bool isa() const
Provide type casting support.
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
The class represents a Value constraint, and constrains a variable to be a Value.
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
This class contains all of the registered ODS operation classes.
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation with the context.
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint with the context.
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint with the context.
This class provides an ODS representation of a specific operation operand or result.
This class provides an ODS representation of a specific operation.
SMRange getLoc() const
Return the source location of this operation.
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
bool hasResultTypeInferrence() const
Return true if the operation is known to support result type inference.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
This class represents a generic ODS Type constraint.
raw_ostream subclass that simplifies indenting a sequence of code.
raw_indented_ostream & printReindented(StringRef str, StringRef extraPrefix="")
Prints a string re-indented to the current indent.
StringRef getSummary() const
std::string getUniqueDefName() const
Returns a unique name for the TableGen def of this constraint.
StringRef getDescription() const
std::string getConditionTemplate() const
Format context containing substitutions for special placeholders.
FmtContext & withSelf(Twine subst)
Wrapper class that contains an MLIR op's information (e.g., operands, attributes) defined in TableGen ...
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This class represents a reference to a constraint, and contains a constraint and the location of the ...
const ConstraintDecl * constraint
This class provides a convenient API for interacting with source names.
StringRef getName() const
Return the raw string name.
SMRange getLoc() const
Get the location of this name.
static const Name & create(Context &ctx, StringRef name, SMRange location)