26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/ManagedStatic.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/ScopedPrinter.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Parser.h"
49 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
50 curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
51 typeTy(ast::TypeType::
get(ctx)), valueTy(ast::ValueType::
get(ctx)),
52 typeRangeTy(ast::TypeRangeType::
get(ctx)),
53 valueRangeTy(ast::ValueRangeType::
get(ctx)),
54 attrTy(ast::AttributeType::
get(ctx)),
55 codeCompleteContext(codeCompleteContext) {}
66 enum class ParserContext {
83 enum class OpResultTypeContext {
102 return (curDeclScope = newScope);
104 void pushDeclScope(
ast::DeclScope *scope) { curDeclScope = scope; }
107 void popDeclScope() { curDeclScope = curDeclScope->
getParentScope(); }
134 const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
135 return opName ? ctx.getODSContext().lookupOperation(*opName) :
nullptr;
140 StringRef processDoc(StringRef doc) {
141 return enableDocumentation ? doc : StringRef();
146 std::string processAndFormatDoc(
const Twine &doc) {
147 if (!enableDocumentation)
151 llvm::raw_string_ostream docOS(docStr);
152 std::string tmpDocStr = doc.str();
154 StringRef(tmpDocStr).rtrim(
" \t"));
164 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
168 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
173 template <
typename Constra
intT>
175 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
177 StringRef nativeType, StringRef docString);
178 template <
typename Constra
intT>
182 StringRef nativeType);
188 struct ParsedPatternMetadata {
189 std::optional<uint16_t> benefit;
190 bool hasBoundedRecursion =
false;
195 parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
208 parseUserConstraintDecl(
bool isInline =
false);
239 template <
typename T,
typename ParseUserPDLLDeclFnT>
241 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
242 StringRef anonymousNamePrefix,
bool isInline);
246 template <
typename T>
269 bool expectTerminalSemicolon =
true);
272 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
281 defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
285 defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
302 parseConstraint(std::optional<SMRange> &typeConstraint,
304 bool allowInlineTypeConstraints);
327 parseOperationExpr(OpResultTypeContext inputResultTypeContext =
328 OpResultTypeContext::Explicit);
358 createPatternDecl(SMRange loc,
const ast::Name *name,
359 const ParsedPatternMetadata &metadata,
368 template <
typename T>
377 createVariableDecl(StringRef name, SMRange loc,
ast::Expr *initializer,
383 createArgOrResultVariableDecl(StringRef name, SMRange loc,
407 createCallExpr(SMRange loc,
ast::Expr *parentExpr,
411 createInlineVariableExpr(
ast::Type type, StringRef name, SMRange loc,
414 createMemberAccessExpr(
ast::Expr *parentExpr, StringRef name, SMRange loc);
419 StringRef name, SMRange loc);
422 OpResultTypeContext resultTypeContext,
427 validateOperationOperands(SMRange loc, std::optional<StringRef> name,
431 std::optional<StringRef> name,
434 void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
437 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
450 createReplaceStmt(SMRange loc,
ast::Expr *rootOp,
453 createRewriteStmt(SMRange loc,
ast::Expr *rootOp,
465 LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
467 bool allowInlineTypeConstraints);
469 LogicalResult codeCompleteOperationName(StringRef dialectName);
471 LogicalResult codeCompleteIncludeFilename(StringRef curPath);
473 void codeCompleteCallSignature(
ast::Node *parent,
unsigned currentNumArgs);
474 void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
475 unsigned currentNumOperands);
476 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
477 unsigned currentNumResults);
486 if (curToken.isNot(kind))
493 void consumeToken() {
495 "shouldn't advance past EOF or errors");
496 curToken = lexer.lexToken();
503 assert(curToken.is(kind) &&
"consumed an unexpected token");
508 void resetToken(SMRange tokLoc) {
509 lexer.resetPointer(tokLoc.Start.getPointer());
510 curToken = lexer.lexToken();
516 if (curToken.getKind() != kind)
517 return emitError(curToken.getLoc(), msg);
522 lexer.emitError(loc, msg);
526 return emitError(curToken.getLoc(), msg);
528 LogicalResult emitErrorAndNote(SMRange loc,
const Twine &msg, SMRange noteLoc,
530 lexer.emitErrorAndNote(loc, msg, noteLoc, note);
549 bool enableDocumentation;
553 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
556 ParserContext parserContext = ParserContext::Global;
564 unsigned anonymousDeclNameCounter = 0;
572 SMLoc moduleLoc = curToken.getStartLoc();
577 if (
failed(parseModuleBody(decls)))
578 return popDeclScope(),
failure();
587 if (
failed(parseDirective(decls)))
595 decls.push_back(*decl);
609 if (exprType == type)
614 expr->
getLoc(), llvm::formatv(
"unable to convert expression of type "
615 "`{0}` to the expected type of "
624 return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
630 if ((exprType == valueTy || exprType == valueRangeTy) &&
631 (type == valueTy || type == valueRangeTy))
633 if ((exprType == typeTy || exprType == typeRangeTy) &&
634 (type == typeTy || type == typeRangeTy))
639 return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
642 return emitConvertError();
651 if (opType.getName())
652 return emitErrorFn();
657 if (type == valueRangeTy) {
664 if (type == valueTy) {
668 if (odsOp->getResults().empty()) {
669 return emitErrorFn()->attachNote(
670 llvm::formatv(
"see the definition of `{0}`, which was defined "
676 unsigned numSingleResults = llvm::count_if(
678 return result.getVariableLengthKind() ==
679 ods::VariableLengthKind::Single;
681 if (numSingleResults > 1) {
682 return emitErrorFn()->attachNote(
683 llvm::formatv(
"see the definition of `{0}`, which was defined "
684 "with at least {1} results",
685 odsOp->getName(), numSingleResults),
694 return emitErrorFn();
703 if (tupleType.size() != exprType.
size())
704 return emitErrorFn();
709 for (
unsigned i = 0, e = exprType.
size(); i < e; ++i) {
711 ctx, expr->
getLoc(), expr, llvm::to_string(i),
715 diag.attachNote(llvm::formatv(
"when converting element #{0} of `{1}`",
720 if (
failed(convertExpressionTo(newExprs.back(),
721 tupleType.getElementTypes()[i], diagFn)))
725 tupleType.getElementNames());
733 if (parserContext != ParserContext::Rewrite) {
734 return emitErrorFn()->attachNote(
"Tuple to Range conversion is currently "
735 "only allowed within a rewrite context");
740 if (!llvm::is_contained(allowedElementTypes, elementType))
741 return emitErrorFn();
746 for (
unsigned i = 0, e = exprType.
size(); i < e; ++i) {
748 ctx, expr->
getLoc(), expr, llvm::to_string(i),
754 if (type == valueRangeTy)
755 return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
756 if (type == typeRangeTy)
757 return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
759 return emitErrorFn();
766 StringRef directive = curToken.getSpelling();
767 if (directive ==
"#include")
768 return parseInclude(decls);
770 return emitError(
"unknown directive `" + directive +
"`");
774 SMRange loc = curToken.getLoc();
779 return codeCompleteIncludeFilename(curToken.getStringValue());
782 if (!curToken.isString())
784 "expected string file name after `include` directive");
785 SMRange fileLoc = curToken.getLoc();
786 std::string filenameStr = curToken.getStringValue();
787 StringRef filename = filenameStr;
792 if (filename.endswith(
".pdll")) {
793 if (
failed(lexer.pushInclude(filename, fileLoc)))
795 "unable to open include file `" + filename +
"`");
800 curToken = lexer.lexToken();
802 curToken = lexer.lexToken();
807 if (filename.endswith(
".td"))
808 return parseTdInclude(filename, fileLoc, decls);
811 "expected include filename to end with `.pdll` or `.td`");
814 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
816 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
819 std::string includedFile;
820 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
821 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
823 return emitError(fileLoc,
"unable to open include file `" + filename +
"`");
826 llvm::SourceMgr tdSrcMgr;
827 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
828 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
832 struct DiagHandlerContext {
836 } handlerContext{*
this, filename, fileLoc};
839 tdSrcMgr.setDiagHandler(
840 [](
const llvm::SMDiagnostic &
diag,
void *rawHandlerContext) {
841 auto *ctx =
reinterpret_cast<DiagHandlerContext *
>(rawHandlerContext);
842 (void)ctx->parser.emitError(
844 llvm::formatv(
"error while processing include file `{0}`: {1}",
845 ctx->filename,
diag.getMessage()));
850 llvm::RecordKeeper tdRecords;
851 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
855 processTdIncludeRecords(tdRecords, decls);
860 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
864 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
867 auto getLengthKind = [](
const auto &value) {
868 if (value.isOptional())
879 cst.constraint.getUniqueDefName(),
880 processDoc(cst.constraint.getSummary()),
881 cst.constraint.getCPPClassName());
883 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
884 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
889 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Op")) {
893 bool supportsResultTypeInferrence =
894 op.getTrait(
"::mlir::InferTypeOpInterface::Trait");
897 op.getOperationName(), processDoc(op.getSummary()),
898 processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
899 supportsResultTypeInferrence, op.
getLoc().front());
906 odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
908 attr.attr.getUniqueDefName(),
909 processDoc(attr.attr.getSummary()),
910 attr.attr.getStorageType()));
913 odsOp->appendOperand(operand.name, getLengthKind(operand),
914 addTypeConstraint(operand));
917 odsOp->appendResult(result.name, getLengthKind(result),
918 addTypeConstraint(result));
922 auto shouldBeSkipped = [
this](llvm::Record *def) {
923 return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
924 def->isSubClassOf(
"DeclareInterfaceMethods");
928 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Attr")) {
929 if (shouldBeSkipped(def))
933 decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
934 constraint, convertLocToRange(def->getLoc().front()), attrTy,
935 constraint.getStorageType()));
938 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Type")) {
939 if (shouldBeSkipped(def))
943 decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
944 constraint, convertLocToRange(def->getLoc().front()), typeTy,
945 constraint.getCPPClassName()));
949 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"OpInterface")) {
950 if (shouldBeSkipped(def))
953 SMRange loc = convertLocToRange(def->getLoc().front());
955 std::string cppClassName =
956 llvm::formatv(
"{0}::{1}", def->getValueAsString(
"cppNamespace"),
957 def->getValueAsString(
"cppInterfaceName"))
959 std::string codeBlock =
960 llvm::formatv(
"return ::mlir::success(llvm::isa<{0}>(self));",
965 processAndFormatDoc(def->getValueAsString(
"description"));
966 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
967 def->getName(), codeBlock, loc, opTy, cppClassName, desc));
971 template <
typename Constra
intT>
972 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
973 StringRef name, StringRef codeBlock, SMRange loc,
ast::Type type,
974 StringRef nativeType, StringRef docString) {
980 argScope->
add(paramVar);
988 constraintDecl->setDocComment(ctx, docString);
989 curDeclScope->add(constraintDecl);
990 return constraintDecl;
993 template <
typename Constra
intT>
997 StringRef nativeType) {
1008 std::string docString;
1009 if (enableDocumentation) {
1011 docString = processAndFormatDoc(
1016 return createODSNativePDLLConstraintDecl<ConstraintT>(
1026 switch (curToken.getKind()) {
1028 decl = parseUserConstraintDecl();
1031 decl = parsePatternDecl();
1034 decl = parseUserRewriteDecl();
1037 return emitError(
"expected top-level declaration, such as a `Pattern`");
1044 if (
failed(checkDefineNamedDecl(*name)))
1046 curDeclScope->add(*decl);
1052 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1055 return codeCompleteAttributeName(parentOpName);
1057 std::string attrNameStr;
1058 if (curToken.isString())
1059 attrNameStr = curToken.getStringValue();
1061 attrNameStr = curToken.getSpelling().str();
1063 return emitError(
"expected identifier or string attribute name");
1073 attrValue = *attrExpr;
1085 bool expectTerminalSemicolon) {
1089 SMLoc bodyStartLoc = curToken.getStartLoc();
1092 bool failedToParse =
1093 failed(singleStatement) ||
failed(processStatementFn(*singleStatement));
1098 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1105 return emitError(
"expected identifier argument name");
1108 StringRef name = curToken.getSpelling();
1109 SMRange nameLoc = curToken.getLoc();
1113 parseToken(
Token::colon,
"expected `:` before argument constraint")))
1120 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1131 StringRef name = curToken.getSpelling();
1132 SMRange nameLoc = curToken.getLoc();
1136 "expected `:` before result constraint")))
1143 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1153 return createArgOrResultVariableDecl(
"", cst->referenceLoc, *cst);
1157 Parser::parseUserConstraintDecl(
bool isInline) {
1160 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1161 [&](
auto &&...args) {
1162 return this->parseUserPDLLConstraintDecl(args...);
1164 ParserContext::Constraint,
"constraint", isInline);
1169 parseUserConstraintDecl(
true);
1170 if (
failed(decl) ||
failed(checkDefineNamedDecl((*decl)->getName())))
1173 curDeclScope->add(*decl);
1183 pushDeclScope(argumentScope);
1191 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1194 "expected `Constraint` lambda body to contain a "
1195 "single expression");
1211 auto bodyIt = body->
begin(), bodyE = body->
end();
1212 for (; bodyIt != bodyE; ++bodyIt)
1213 if (isa<ast::ReturnStmt>(*bodyIt))
1215 if (
failed(validateUserConstraintOrRewriteReturn(
1216 "Constraint", body, bodyIt, bodyE, results, resultType)))
1221 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1222 name, arguments, results, resultType, body);
1228 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1229 [&](
auto &&...args) {
return this->parseUserPDLLRewriteDecl(args...); },
1230 ParserContext::Rewrite,
"rewrite", isInline);
1235 parseUserRewriteDecl(
true);
1236 if (
failed(decl) ||
failed(checkDefineNamedDecl((*decl)->getName())))
1239 curDeclScope->add(*decl);
1249 curDeclScope = argumentScope;
1254 if (isa<ast::OpRewriteStmt>(statement))
1257 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1258 if (!statementExpr) {
1261 "expected `Rewrite` lambda body to contain a single expression "
1262 "or an operation rewrite statement; such as `erase`, "
1263 "`replace`, or `rewrite`");
1282 auto bodyIt = body->
begin(), bodyE = body->
end();
1283 for (; bodyIt != bodyE; ++bodyIt)
1284 if (isa<ast::ReturnStmt>(*bodyIt))
1286 if (
failed(validateUserConstraintOrRewriteReturn(
"Rewrite", body, bodyIt,
1287 bodyE, results, resultType)))
1289 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1290 name, arguments, results, resultType, body);
1293 template <
typename T,
typename ParseUserPDLLDeclFnT>
1295 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1296 StringRef anonymousNamePrefix,
bool isInline) {
1297 SMRange loc = curToken.getLoc();
1299 llvm::SaveAndRestore saveCtx(parserContext, declContext);
1307 return emitError(
"expected identifier name");
1311 std::string anonName =
1312 llvm::formatv(
"<anonymous_{0}_{1}>", anonymousNamePrefix,
1313 anonymousDeclNameCounter++)
1326 if (
failed(parseUserConstraintOrRewriteSignature(arguments, results,
1327 argumentScope, resultType)))
1333 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1337 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1338 results, resultType);
1341 template <
typename T>
1347 std::string codeStrStorage;
1348 std::optional<StringRef> optCodeStr;
1349 if (curToken.isString()) {
1350 codeStrStorage = curToken.getStringValue();
1351 optCodeStr = codeStrStorage;
1353 }
else if (isInline) {
1355 "external declarations must be declared in global scope");
1360 "expected `;` after native declaration")))
1364 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1366 "native Constraints currently do not support returning results");
1368 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1371 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1379 argumentScope = pushDeclScope();
1385 arguments.emplace_back(*argument);
1399 results.emplace_back(*result);
1406 if (
failed(parseResultFn()))
1413 }
else if (
failed(parseResultFn())) {
1420 resultType = createUserConstraintRewriteResultType(results);
1423 if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1425 results.front()->getLoc(),
1426 "cannot create a single-element tuple with an element label");
1431 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1437 if (bodyIt != bodyE) {
1439 if (std::next(bodyIt) != bodyE) {
1441 (*std::next(bodyIt))->getLoc(),
1442 llvm::formatv(
"`return` terminated the `{0}` body, but found "
1443 "trailing statements afterwards",
1449 }
else if (!results.empty()) {
1451 {body->getLoc().End, body->getLoc().End},
1452 llvm::formatv(
"missing return in a `{0}` expected to return `{1}`",
1453 declType, resultType));
1460 if (isa<ast::OpRewriteStmt>(statement))
1464 "expected Pattern lambda body to contain a single operation "
1465 "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1470 SMRange loc = curToken.getLoc();
1472 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1482 ParsedPatternMetadata metadata;
1497 return emitError(
"expected `{` or `=>` to start pattern body");
1504 auto bodyIt = body->
begin(), bodyE = body->
end();
1505 for (; bodyIt != bodyE; ++bodyIt) {
1506 if (isa<ast::ReturnStmt>(*bodyIt)) {
1508 "`return` statements are only permitted within a "
1509 "`Constraint` or `Rewrite` body");
1512 if (isa<ast::OpRewriteStmt>(*bodyIt))
1515 if (bodyIt == bodyE) {
1517 "expected Pattern body to terminate with an operation "
1518 "rewrite statement, such as `erase`");
1520 if (std::next(bodyIt) != bodyE) {
1521 return emitError((*std::next(bodyIt))->getLoc(),
1522 "Pattern body was terminated by an operation "
1523 "rewrite statement, but found trailing statements");
1527 return createPatternDecl(loc, name, metadata, body);
1531 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1532 std::optional<SMRange> benefitLoc;
1533 std::optional<SMRange> hasBoundedRecursionLoc;
1538 return codeCompletePatternMetadata();
1541 return emitError(
"expected pattern metadata identifier");
1542 StringRef metadataStr = curToken.getSpelling();
1543 SMRange metadataLoc = curToken.getLoc();
1547 if (metadataStr ==
"benefit") {
1549 return emitErrorAndNote(metadataLoc,
1550 "pattern benefit has already been specified",
1551 *benefitLoc,
"see previous definition here");
1554 "expected `(` before pattern benefit")))
1557 uint16_t benefitValue = 0;
1559 return emitError(
"expected integral pattern benefit");
1560 if (curToken.getSpelling().getAsInteger(10, benefitValue))
1562 "expected pattern benefit to fit within a 16-bit integer");
1565 metadata.benefit = benefitValue;
1566 benefitLoc = metadataLoc;
1569 parseToken(
Token::r_paren,
"expected `)` after pattern benefit")))
1575 if (metadataStr ==
"recursion") {
1576 if (hasBoundedRecursionLoc) {
1577 return emitErrorAndNote(
1579 "pattern recursion metadata has already been specified",
1580 *hasBoundedRecursionLoc,
"see previous definition here");
1582 metadata.hasBoundedRecursion =
true;
1583 hasBoundedRecursionLoc = metadataLoc;
1587 return emitError(metadataLoc,
"unknown pattern metadata");
1599 "expected `>` after variable type constraint")))
1605 assert(curDeclScope &&
"defining decl outside of a decl scope");
1607 return emitErrorAndNote(
1608 name.
getLoc(),
"`" + name.
getName() +
"` has already been defined",
1609 lastDecl->getName()->getLoc(),
"see previous definition here");
1615 Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
1618 assert(curDeclScope &&
"defining variable outside of decl scope");
1623 if (name.empty() || name ==
"_") {
1627 if (
failed(checkDefineNamedDecl(nameDecl)))
1632 curDeclScope->add(varDecl);
1637 Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
1639 return defineVariableDecl(name, nameLoc, type,
nullptr,
1645 std::optional<SMRange> typeConstraint;
1646 auto parseSingleConstraint = [&] {
1648 typeConstraint, constraints,
true);
1651 constraints.push_back(*constraint);
1657 return parseSingleConstraint();
1660 if (
failed(parseSingleConstraint()))
1663 return parseToken(
Token::r_square,
"expected `]` after constraint list");
1667 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1669 bool allowInlineTypeConstraints) {
1671 if (!allowInlineTypeConstraints) {
1674 "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1675 "permitted on arguments or results");
1678 return emitErrorAndNote(
1680 "the type of this variable has already been constrained",
1681 *typeConstraint,
"see previous constraint location here");
1683 if (
failed(constraintExpr))
1685 typeExpr = *constraintExpr;
1686 typeConstraint = typeExpr->getLoc();
1690 SMRange loc = curToken.getLoc();
1691 switch (curToken.getKind()) {
1708 parseWrappedOperationName(
true);
1753 StringRef constraintName = curToken.getSpelling();
1759 return emitError(loc,
"unknown reference to constraint `" +
1760 constraintName +
"`");
1764 if (
auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1767 return emitErrorAndNote(
1768 loc,
"invalid reference to non-constraint", cstDecl->
getLoc(),
1769 "see the definition of `" + constraintName +
"` here");
1775 if (
failed(validateVariableConstraints(existingConstraints, inferredType)))
1778 return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1783 return emitError(loc,
"expected identifier constraint");
1787 std::optional<SMRange> typeConstraint;
1788 return parseConstraint(typeConstraint, std::nullopt,
1797 return parseUnderscoreExpr();
1801 switch (curToken.getKind()) {
1803 lhsExpr = parseAttributeExpr();
1806 lhsExpr = parseInlineConstraintLambdaExpr();
1809 lhsExpr = parseIdentifierExpr();
1812 lhsExpr = parseOperationExpr();
1815 lhsExpr = parseInlineRewriteLambdaExpr();
1818 lhsExpr = parseTypeExpr();
1821 lhsExpr = parseTupleExpr();
1824 return emitError(
"expected expression");
1831 switch (curToken.getKind()) {
1833 lhsExpr = parseMemberAccessExpr(*lhsExpr);
1836 lhsExpr = parseCallExpr(*lhsExpr);
1847 SMRange loc = curToken.getLoc();
1854 return parseIdentifierExpr();
1857 if (!curToken.isString())
1858 return emitError(
"expected string literal containing MLIR attribute");
1859 std::string attrExpr = curToken.getStringValue();
1862 loc.End = curToken.getEndLoc();
1864 parseToken(
Token::greater,
"expected `>` after attribute literal")))
1878 codeCompleteCallSignature(parentExpr, arguments.size());
1885 arguments.push_back(*argument);
1889 SMRange loc(parentExpr->
getLoc().Start, curToken.getEndLoc());
1893 return createCallExpr(loc, parentExpr, arguments);
1897 ast::Decl *decl = curDeclScope->lookup(name);
1899 return emitError(loc,
"undefined reference to `" + name +
"`");
1901 return createDeclRefExpr(loc, decl);
1905 StringRef name = curToken.getSpelling();
1906 SMRange nameLoc = curToken.getLoc();
1913 if (
failed(parseVariableDeclConstraintList(constraints)))
1916 if (
failed(validateVariableConstraints(constraints, type)))
1918 return createInlineVariableExpr(type, name, nameLoc, constraints);
1921 return parseDeclRefExpr(name, nameLoc);
1943 SMRange dotLoc = curToken.getLoc();
1948 return codeCompleteMemberAccess(parentExpr);
1951 Token memberNameTok = curToken;
1954 return emitError(dotLoc,
"expected identifier or numeric member name");
1955 StringRef memberName = memberNameTok.
getSpelling();
1956 SMRange loc(parentExpr->
getLoc().Start, curToken.getEndLoc());
1959 return createMemberAccessExpr(parentExpr, memberName, loc);
1963 SMRange loc = curToken.getLoc();
1967 return codeCompleteDialectName();
1973 return emitError(
"expected dialect namespace");
1975 StringRef name = curToken.getSpelling();
1979 if (
failed(parseToken(
Token::dot,
"expected `.` after dialect namespace")))
1984 return codeCompleteOperationName(name);
1987 return emitError(
"expected operation name after dialect namespace");
1989 name = StringRef(name.data(), name.size() + 1);
1991 name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
1992 loc.End = curToken.getEndLoc();
1995 curToken.isKeyword());
2000 Parser::parseWrappedOperationName(
bool allowEmptyName) {
2014 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2015 SMRange loc = curToken.getLoc();
2022 return parseIdentifierExpr();
2028 bool allowEmptyName = parserContext != ParserContext::Rewrite;
2030 parseWrappedOperationName(allowEmptyName);
2033 std::optional<StringRef> opName = (*opNameDecl)->getName();
2040 assert(
succeeded(rangeVar) &&
"expected range variable to be valid");
2051 if (parserContext != ParserContext::Rewrite) {
2052 operands.push_back(createImplicitRangeVar(
2060 codeCompleteOperationOperandsSignature(opName, operands.size());
2067 operands.push_back(*operand);
2071 "expected `)` after operation operand list")))
2080 parseNamedAttributeDecl(opName);
2083 attributes.emplace_back(*decl);
2087 "expected `}` after operation attribute list")))
2093 OpResultTypeContext resultTypeContext = inputResultTypeContext;
2098 "expected `(` before operation result type list")))
2106 resultTypeContext = OpResultTypeContext::Explicit;
2113 codeCompleteOperationResultsSignature(opName, resultTypes.size());
2118 if (
failed(resultTypeExpr))
2120 resultTypes.push_back(*resultTypeExpr);
2124 "expected `)` after operation result type list")))
2127 }
else if (parserContext != ParserContext::Rewrite) {
2132 resultTypes.push_back(createImplicitRangeVar(
2134 }
else if (resultTypeContext == OpResultTypeContext::Explicit) {
2137 resultTypeContext = OpResultTypeContext::Interface;
2140 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2141 attributes, resultTypes);
2145 SMRange loc = curToken.getLoc();
2154 StringRef elementName;
2156 Token elementNameTok = curToken;
2164 auto elementNameIt =
2165 usedNames.try_emplace(elementName, elementNameTok.
getLoc());
2166 if (!elementNameIt.second) {
2167 return emitErrorAndNote(
2169 llvm::formatv(
"duplicate tuple element label `{0}`",
2171 elementNameIt.first->getSecond(),
2172 "see previous label use here");
2177 resetToken(elementNameTok.
getLoc());
2180 elementNames.push_back(elementName);
2186 elements.push_back(*element);
2189 loc.End = curToken.getEndLoc();
2191 parseToken(
Token::r_paren,
"expected `)` after tuple element list")))
2193 return createTupleExpr(loc, elements, elementNames);
2197 SMRange loc = curToken.getLoc();
2204 return parseIdentifierExpr();
2207 if (!curToken.isString())
2208 return emitError(
"expected string literal containing MLIR type");
2209 std::string attrExpr = curToken.getStringValue();
2212 loc.End = curToken.getEndLoc();
2219 StringRef name = curToken.getSpelling();
2220 SMRange nameLoc = curToken.getLoc();
2229 if (
failed(parseVariableDeclConstraintList(constraints)))
2233 if (
failed(validateVariableConstraints(constraints, type)))
2235 return createInlineVariableExpr(type, name, nameLoc, constraints);
2243 switch (curToken.getKind()) {
2245 stmt = parseEraseStmt();
2248 stmt = parseLetStmt();
2251 stmt = parseReplaceStmt();
2254 stmt = parseReturnStmt();
2257 stmt = parseRewriteStmt();
2264 (expectTerminalSemicolon &&
2271 SMLoc startLoc = curToken.getStartLoc();
2280 return popDeclScope(),
failure();
2281 statements.push_back(*statement);
2286 SMRange location(startLoc, curToken.getEndLoc());
2293 if (parserContext == ParserContext::Constraint)
2294 return emitError(
"`erase` cannot be used within a Constraint");
2295 SMRange loc = curToken.getLoc();
2303 return createEraseStmt(loc, *rootOp);
2307 SMRange loc = curToken.getLoc();
2311 SMRange varLoc = curToken.getLoc();
2316 "`_` may only be used to define \"inline\" variables");
2319 "expected identifier after `let` to name a new variable");
2321 StringRef varName = curToken.getSpelling();
2327 failed(parseVariableDeclConstraintList(constraints)))
2334 if (
failed(initOrFailure))
2336 initializer = *initOrFailure;
2345 if (
auto *typeConstraintExpr = cst->getTypeExpr()) {
2347 constraint.referenceLoc,
2348 "type constraints are not permitted on variables with "
2360 createVariableDecl(varName, varLoc, initializer, constraints);
2367 if (parserContext == ParserContext::Constraint)
2368 return emitError(
"`replace` cannot be used within a Constraint");
2369 SMRange loc = curToken.getLoc();
2378 parseToken(
Token::kw_with,
"expected `with` after root operation")))
2382 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2389 loc,
"expected at least one replacement value, consider using "
2390 "`erase` if no replacement values are desired");
2397 replValues.emplace_back(*replExpr);
2401 "expected `)` after replacement values")))
2408 replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2410 replExpr = parseExpr();
2413 replValues.emplace_back(*replExpr);
2416 return createReplaceStmt(loc, *rootOp, replValues);
2420 SMRange loc = curToken.getLoc();
2432 if (parserContext == ParserContext::Constraint)
2433 return emitError(
"`rewrite` cannot be used within a Constraint");
2434 SMRange loc = curToken.getLoc();
2446 return emitError(
"expected `{` to start rewrite body");
2449 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2456 for (
const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2457 if (isa<ast::ReturnStmt>(stmt)) {
2459 "`return` statements are only permitted within a "
2460 "`Constraint` or `Rewrite` body");
2464 return createRewriteStmt(loc, *rootOp, *rewriteBody);
2476 if (
auto *init = dyn_cast<ast::DeclRefExpr>(node))
2477 node = init->getDecl();
2478 return dyn_cast<ast::CallableDecl>(node);
2482 Parser::createPatternDecl(SMRange loc,
const ast::Name *name,
2483 const ParsedPatternMetadata &metadata,
2486 metadata.hasBoundedRecursion, body);
2489 ast::Type Parser::createUserConstraintRewriteResultType(
2492 if (results.size() == 1)
2493 return results[0]->getType();
2497 auto resultTypes = llvm::map_range(
2498 results, [&](
const auto *result) {
return result->getType(); });
2499 auto resultNames = llvm::map_range(
2500 results, [&](
const auto *result) {
return result->getName().getName(); });
2502 llvm::to_vector(resultNames));
2505 template <
typename T>
2511 if (
auto *retStmt = dyn_cast<ast::ReturnStmt>(body->
getChildren().back())) {
2512 ast::Expr *resultExpr = retStmt->getResultExpr();
2517 if (results.empty())
2518 resultType = resultExpr->
getType();
2519 else if (
failed(convertExpressionTo(resultExpr, resultType)))
2522 retStmt->setResultExpr(resultExpr);
2525 return T::createPDLL(ctx, name, arguments, results, body, resultType);
2529 Parser::createVariableDecl(StringRef name, SMRange loc,
ast::Expr *initializer,
2534 if (
failed(validateVariableConstraints(constraints, type)))
2541 type = initializer->
getType();
2544 else if (
failed(convertExpressionTo(initializer, type)))
2550 return emitErrorAndNote(
2551 loc,
"unable to infer type for variable `" + name +
"`", loc,
2552 "the type of a variable must be inferable from the constraint "
2553 "list or the initializer");
2559 loc, llvm::formatv(
"unable to define variable of `{0}` type", type));
2564 defineVariableDecl(name, loc, type, initializer, constraints);
2572 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2575 if (
failed(validateVariableConstraint(constraint, argType)))
2577 return defineVariableDecl(name, loc, argType, constraint);
2584 if (
failed(validateVariableConstraint(ref, inferredType)))
2592 if (
const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.
constraint)) {
2593 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2594 if (
failed(validateTypeConstraintExpr(typeExpr)))
2598 }
else if (
const auto *cst =
2599 dyn_cast<ast::OpConstraintDecl>(ref.
constraint)) {
2602 }
else if (isa<ast::TypeConstraintDecl>(ref.
constraint)) {
2603 constraintType = typeTy;
2604 }
else if (isa<ast::TypeRangeConstraintDecl>(ref.
constraint)) {
2605 constraintType = typeRangeTy;
2606 }
else if (
const auto *cst =
2607 dyn_cast<ast::ValueConstraintDecl>(ref.
constraint)) {
2608 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2609 if (
failed(validateTypeConstraintExpr(typeExpr)))
2612 constraintType = valueTy;
2613 }
else if (
const auto *cst =
2614 dyn_cast<ast::ValueRangeConstraintDecl>(ref.
constraint)) {
2615 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2616 if (
failed(validateTypeRangeConstraintExpr(typeExpr)))
2619 constraintType = valueRangeTy;
2620 }
else if (
const auto *cst =
2621 dyn_cast<ast::UserConstraintDecl>(ref.
constraint)) {
2623 if (inputs.size() != 1) {
2625 "`Constraint`s applied via a variable constraint "
2626 "list must take a single input, but got " +
2627 Twine(inputs.size()),
2629 "see definition of constraint here");
2631 constraintType = inputs.front()->getType();
2633 llvm_unreachable(
"unknown constraint type");
2638 if (!inferredType) {
2639 inferredType = constraintType;
2641 inferredType = mergedTy;
2644 llvm::formatv(
"constraint type `{0}` is incompatible "
2645 "with the previously inferred type `{1}`",
2646 constraintType, inferredType));
2653 if (typeExprType != typeTy) {
2655 "expected expression of `Type` in type constraint");
2661 Parser::validateTypeRangeConstraintExpr(
const ast::Expr *typeExpr) {
2663 if (typeExprType != typeRangeTy) {
2665 "expected expression of `TypeRange` in type constraint");
2674 Parser::createCallExpr(SMRange loc,
ast::Expr *parentExpr,
2679 if (!callableDecl) {
2681 llvm::formatv(
"expected a reference to a callable "
2682 "`Constraint` or `Rewrite`, but got: `{0}`",
2685 if (parserContext == ParserContext::Rewrite) {
2686 if (isa<ast::UserConstraintDecl>(callableDecl))
2688 loc,
"unable to invoke `Constraint` within a rewrite section");
2689 }
else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2690 return emitError(loc,
"unable to invoke `Rewrite` within a match section");
2696 if (callArgs.size() != arguments.size()) {
2697 return emitErrorAndNote(
2699 llvm::formatv(
"invalid number of arguments for {0} call; expected "
2704 llvm::formatv(
"see the definition of {0} here",
2710 diag.attachNote(llvm::formatv(
"see the definition of `{0}` here",
2714 for (
auto it : llvm::zip(callArgs, arguments)) {
2715 if (
failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2728 if (isa<ast::ConstraintDecl>(decl))
2730 else if (isa<ast::UserRewriteDecl>(decl))
2732 else if (
auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2733 declType = varDecl->getType();
2735 return emitError(loc,
"invalid reference to `" +
2742 Parser::createInlineVariableExpr(
ast::Type type, StringRef name, SMRange loc,
2745 defineVariableDecl(name, loc, type, constraints);
2752 Parser::createMemberAccessExpr(
ast::Expr *parentExpr, StringRef name,
2763 StringRef name, SMRange loc) {
2767 return valueRangeTy;
2771 auto results = odsOp->getResults();
2775 if (llvm::isDigit(name[0]) && !name.getAsInteger(10, index) &&
2776 index < results.size()) {
2777 return results[index].isVariadic() ? valueRangeTy : valueTy;
2781 const auto *it = llvm::find_if(results, [&](
const auto &result) {
2782 return result.getName() == name;
2784 if (it != results.end())
2785 return it->isVariadic() ? valueRangeTy : valueTy;
2786 }
else if (llvm::isDigit(name[0])) {
2794 if (llvm::isDigit(name[0]) && !name.getAsInteger(10, index) &&
2795 index < tupleType.size()) {
2796 return tupleType.getElementTypes()[index];
2800 auto elementNames = tupleType.getElementNames();
2801 const auto *it = llvm::find(elementNames, name);
2802 if (it != elementNames.end())
2803 return tupleType.getElementTypes()[it - elementNames.begin()];
2807 llvm::formatv(
"invalid member access `{0}` on expression of type `{1}`",
2813 OpResultTypeContext resultTypeContext,
2817 std::optional<StringRef> opNameRef = name->
getName();
2821 if (
failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2827 ast::Type attrType = attr->getValue()->getType();
2830 attr->getValue()->getLoc(),
2831 llvm::formatv(
"expected `Attr` expression, but got `{0}`", attrType));
2836 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2837 "unexpected inferrence when results were explicitly specified");
2841 if (resultTypeContext == OpResultTypeContext::Explicit) {
2842 if (
failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2846 }
else if (resultTypeContext == OpResultTypeContext::Interface) {
2848 "expected valid operation name when inferring operation results");
2849 checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2857 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2860 return validateOperationOperandsOrResults(
2861 "operand", loc, odsOp ? odsOp->
getLoc() : std::optional<SMRange>(), name,
2862 operands, odsOp ? odsOp->
getOperands() : std::nullopt, valueTy,
2867 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2870 return validateOperationOperandsOrResults(
2871 "result", loc, odsOp ? odsOp->
getLoc() : std::optional<SMRange>(), name,
2872 results, odsOp ? odsOp->
getResults() : std::nullopt, typeTy, typeRangeTy);
2875 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2884 ctx.getDiagEngine().emitWarning(
2886 "operation result types are marked to be inferred, but "
2887 "`{0}` is unknown. Ensure that `{0}` supports zero "
2888 "results or implements `InferTypeOpInterface`. Include "
2889 "the ODS definition of this operation to remove this warning.",
2898 bool requiresInferrence =
2900 return !result.isVariableLength();
2905 llvm::formatv(
"operation result types are marked to be inferred, but "
2906 "`{0}` does not provide an implementation of "
2907 "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2908 "`InferTypeOpInterface` at runtime, or add support to "
2909 "the ODS definition to remove this warning.",
2911 diag->attachNote(llvm::formatv(
"see the definition of `{0}` here", opName),
2918 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2923 if (values.size() == 1) {
2924 if (
failed(convertExpressionTo(values[0], rangeTy)))
2932 auto emitSizeMismatchError = [&] {
2933 return emitErrorAndNote(
2935 llvm::formatv(
"invalid number of {0} groups for `{1}`; expected "
2937 groupName, *name, odsValues.size(), values.size()),
2938 *odsOpLoc, llvm::formatv(
"see the definition of `{0}` here", *name));
2942 if (values.empty()) {
2944 if (odsValues.empty())
2949 unsigned numVariadic = 0;
2950 for (
const auto &odsValue : odsValues) {
2951 if (!odsValue.isVariableLength())
2952 return emitSizeMismatchError();
2958 if (parserContext != ParserContext::Rewrite)
2965 if (numVariadic == 1)
2970 for (
unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2972 ctx, loc, std::nullopt, rangeTy));
2979 if (odsValues.size() != values.size())
2980 return emitSizeMismatchError();
2983 diag.attachNote(llvm::formatv(
"see the definition of `{0}` here", *name),
2986 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
2987 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
2988 if (
failed(convertExpressionTo(values[i], expectedType, diagFn)))
2997 ast::Type valueExprType = valueExpr->getType();
3000 if (valueExprType == rangeTy || valueExprType == singleTy)
3006 if (singleTy == valueTy) {
3008 valueExpr = convertOpToValue(valueExpr);
3014 if (
succeeded(convertExpressionTo(valueExpr, rangeTy)))
3018 valueExpr->getLoc(),
3020 "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3021 singleTy, rangeTy, valueExprType));
3029 for (
const ast::Expr *element : elements) {
3034 llvm::formatv(
"unable to build a tuple with `{0}` element", eleTy));
3054 Parser::createReplaceStmt(SMRange loc,
ast::Expr *rootOp,
3061 llvm::formatv(
"expected `Op` expression, but got `{0}`", rootType));
3066 bool shouldConvertOpToValues = replValues.size() > 1;
3067 for (
ast::Expr *&replExpr : replValues) {
3068 ast::Type replType = replExpr->getType();
3072 if (shouldConvertOpToValues)
3073 replExpr = convertOpToValue(replExpr);
3077 if (replType != valueTy && replType != valueRangeTy) {
3079 llvm::formatv(
"expected `Op`, `Value` or `ValueRange` "
3080 "expression, but got `{0}`",
3089 Parser::createRewriteStmt(SMRange loc,
ast::Expr *rootOp,
3096 llvm::formatv(
"expected `Op` expression, but got `{0}`", rootType));
3109 codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3111 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3116 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3118 codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3123 Parser::codeCompleteConstraintName(
ast::Type inferredType,
3124 bool allowInlineTypeConstraints) {
3125 codeCompleteContext->codeCompleteConstraintName(
3126 inferredType, allowInlineTypeConstraints, curDeclScope);
3131 codeCompleteContext->codeCompleteDialectName();
3135 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3136 codeCompleteContext->codeCompleteOperationName(dialectName);
3141 codeCompleteContext->codeCompletePatternMetadata();
3145 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3146 codeCompleteContext->codeCompleteIncludeFilename(curPath);
3150 void Parser::codeCompleteCallSignature(
ast::Node *parent,
3151 unsigned currentNumArgs) {
3156 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3159 void Parser::codeCompleteOperationOperandsSignature(
3160 std::optional<StringRef> opName,
unsigned currentNumOperands) {
3161 codeCompleteContext->codeCompleteOperationOperandsSignature(
3162 opName, currentNumOperands);
3165 void Parser::codeCompleteOperationResultsSignature(
3166 std::optional<StringRef> opName,
unsigned currentNumResults) {
3167 codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3177 bool enableDocumentation,
3179 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3180 return parser.parseModule();
static std::string diag(const llvm::Value &value)
This class provides support for representing a failure result, or a valid value of type T.
This class breaks up the current file into a token stream.
Location getLoc()
The source location the operation was defined or derived from.
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getResults()
This represents a token in the MLIR syntax.
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
StringRef getSpelling() const
This class provides an abstract interface into the parser for hooking in code completion events.
@ code_complete_string
Token signifying a code completion location within a string.
@ code_complete
Token signifying a code completion location.
@ less
Paired punctuation.
@ kw_Attr
General keywords.
static StringRef getMemberName()
Return the member name used for the "all-results" access.
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
This class represents a PDLL type that corresponds to an mlir::Attribute.
static AttributeType get(Context &context)
Return an instance of the Attribute type.
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType)
This decl represents a shared interface for all callable decls.
Type getResultType() const
Return the result type of this decl.
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
StringRef getCallableType() const
Return the callable type of this decl.
This statement represents a compound statement, which contains a collection of other statements.
ArrayRef< Stmt * >::iterator begin() const
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
ArrayRef< Stmt * >::iterator end() const
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
This class represents the base of all AST Constraint decls.
This class represents a PDLL type that corresponds to a constraint.
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
This class represents the main context of the PDLL AST.
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
This class represents a scope for named AST decls.
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
void add(Decl *decl)
Add a new decl to the scope.
This class represents the base Decl node.
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
This class provides a simple implementation of a PDLL diagnostic.
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
This class represents a base AST Expression node.
Type getType() const
Return the type of this expression.
This class represents a diagnostic that is inflight and set to be reported.
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
This Decl represents a NamedAttribute, and contains a string name and attribute value.
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
This class represents a base AST node.
SMRange getLoc() const
Return the location of this node.
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
This Decl represents an OperationName.
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
static OpNameDecl * create(Context &ctx, const Name &name)
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
This class represents a PDLL type that corresponds to an mlir::Operation.
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
This class represents a PDLL type that corresponds to a range of elements with a given element type.
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
This class represents a PDLL type that corresponds to a rewrite reference.
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
This class represents a base AST Statement node.
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
This class represents a PDLL tuple type, i.e.
size_t size() const
Return the number of elements within this tuple.
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Type refineWith(Type other) const
Try to refine this type with the one provided.
bool isa() const
Provide type casting support.
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
The class represents a Value constraint, and constrains a variable to be a Value.
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
This class contains all of the registered ODS operation classes.
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation into the context.
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint into the context.
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint into the context.
This class provides an ODS representation of a specific operation operand or result.
This class provides an ODS representation of a specific operation.
SMRange getLoc() const
Return the source location of this operation.
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
bool hasResultTypeInferrence() const
Return whether the operation is known to support result type inference.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
This class represents a generic ODS Type constraint.
raw_ostream subclass that simplifies indenting a sequence of code.
raw_indented_ostream & printReindented(StringRef str, StringRef extraPrefix="")
Prints a string re-indented to the current indent.
StringRef getSummary() const
std::string getUniqueDefName() const
Returns a unique name for the TableGen def of this constraint.
StringRef getDescription() const
std::string getConditionTemplate() const
Format context containing substitutions for special placeholders.
FmtContext & withSelf(Twine subst)
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This class represents a reference to a constraint, and contains a constraint and the location of the ...
const ConstraintDecl * constraint
This class provides a convenient API for interacting with source names.
StringRef getName() const
Return the raw string name.
SMRange getLoc() const
Get the location of this name.
static const Name & create(Context &ctx, StringRef name, SMRange location)