25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/ManagedStatic.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/ScopedPrinter.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Parser.h"
48 : ctx(ctx), lexer(sourceMgr, ctx.
getDiagEngine(), codeCompleteContext),
49 curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
54 codeCompleteContext(codeCompleteContext) {}
57 FailureOr<ast::Module *> parseModule();
65 enum class ParserContext {
82 enum class OpResultTypeContext {
101 return (curDeclScope = newScope);
/// Make `scope` the current declaration scope (`curDeclScope`), without
/// allocating a new scope. Only `curDeclScope` is updated here; the matching
/// `popDeclScope` restores `curDeclScope->getParentScope()`, so `scope` is
/// presumably expected to have the previously-current scope as its parent
/// (NOTE(review): confirm with callers before relying on this).
103 void pushDeclScope(
ast::DeclScope *scope) { curDeclScope = scope; }
/// Leave the current declaration scope by making its parent the current
/// scope. This dereferences `curDeclScope` unconditionally, so a scope must be
/// active when it is called.
106 void popDeclScope() { curDeclScope = curDeclScope->
getParentScope(); }
115 LogicalResult convertExpressionTo(
122 LogicalResult convertTupleExpressionTo(
133 const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
139 StringRef processDoc(StringRef doc) {
140 return enableDocumentation ? doc : StringRef();
145 std::string processAndFormatDoc(
const Twine &doc) {
146 if (!enableDocumentation)
150 llvm::raw_string_ostream docOS(docStr);
152 StringRef(docStr).rtrim(
" \t"));
162 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
166 void processTdIncludeRecords(
const llvm::RecordKeeper &tdRecords,
171 template <
typename Constra
intT>
173 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
175 StringRef nativeType, StringRef docString);
176 template <
typename Constra
intT>
180 StringRef nativeType);
186 struct ParsedPatternMetadata {
187 std::optional<uint16_t> benefit;
188 bool hasBoundedRecursion =
false;
191 FailureOr<ast::Decl *> parseTopLevelDecl();
192 FailureOr<ast::NamedAttributeDecl *>
193 parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
197 FailureOr<ast::VariableDecl *> parseArgumentDecl();
201 FailureOr<ast::VariableDecl *> parseResultDecl(
unsigned resultNum);
205 FailureOr<ast::UserConstraintDecl *>
206 parseUserConstraintDecl(
bool isInline =
false);
210 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
214 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
221 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(
bool isInline =
false);
225 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
229 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
237 template <
typename T,
typename ParseUserPDLLDeclFnT>
238 FailureOr<T *> parseUserConstraintOrRewriteDecl(
239 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
240 StringRef anonymousNamePrefix,
bool isInline);
244 template <
typename T>
245 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
252 LogicalResult parseUserConstraintOrRewriteSignature(
259 LogicalResult validateUserConstraintOrRewriteReturn(
265 FailureOr<ast::CompoundStmt *>
267 bool expectTerminalSemicolon =
true);
268 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
269 FailureOr<ast::Decl *> parsePatternDecl();
270 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
274 LogicalResult checkDefineNamedDecl(
const ast::Name &name);
278 FailureOr<ast::VariableDecl *>
279 defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
282 FailureOr<ast::VariableDecl *>
283 defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
287 LogicalResult parseVariableDeclConstraintList(
291 FailureOr<ast::Expr *> parseTypeConstraintExpr();
299 FailureOr<ast::ConstraintRef>
300 parseConstraint(std::optional<SMRange> &typeConstraint,
302 bool allowInlineTypeConstraints);
307 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
312 FailureOr<ast::Expr *> parseExpr();
315 FailureOr<ast::Expr *> parseAttributeExpr();
316 FailureOr<ast::Expr *> parseCallExpr(
ast::Expr *parentExpr,
317 bool isNegated =
false);
318 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
319 FailureOr<ast::Expr *> parseIdentifierExpr();
320 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
321 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
322 FailureOr<ast::Expr *> parseMemberAccessExpr(
ast::Expr *parentExpr);
323 FailureOr<ast::Expr *> parseNegatedExpr();
324 FailureOr<ast::OpNameDecl *> parseOperationName(
bool allowEmptyName =
false);
325 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(
bool allowEmptyName);
326 FailureOr<ast::Expr *>
327 parseOperationExpr(OpResultTypeContext inputResultTypeContext =
328 OpResultTypeContext::Explicit);
329 FailureOr<ast::Expr *> parseTupleExpr();
330 FailureOr<ast::Expr *> parseTypeExpr();
331 FailureOr<ast::Expr *> parseUnderscoreExpr();
336 FailureOr<ast::Stmt *> parseStmt(
bool expectTerminalSemicolon =
true);
337 FailureOr<ast::CompoundStmt *> parseCompoundStmt();
338 FailureOr<ast::EraseStmt *> parseEraseStmt();
339 FailureOr<ast::LetStmt *> parseLetStmt();
340 FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
341 FailureOr<ast::ReturnStmt *> parseReturnStmt();
342 FailureOr<ast::RewriteStmt *> parseRewriteStmt();
357 FailureOr<ast::PatternDecl *>
358 createPatternDecl(SMRange loc,
const ast::Name *name,
359 const ParsedPatternMetadata &metadata,
368 template <
typename T>
369 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
376 FailureOr<ast::VariableDecl *>
377 createVariableDecl(StringRef name, SMRange loc,
ast::Expr *initializer,
382 FailureOr<ast::VariableDecl *>
383 createArgOrResultVariableDecl(StringRef name, SMRange loc,
400 LogicalResult validateTypeConstraintExpr(
const ast::Expr *typeExpr);
401 LogicalResult validateTypeRangeConstraintExpr(
const ast::Expr *typeExpr);
406 FailureOr<ast::CallExpr *>
407 createCallExpr(SMRange loc,
ast::Expr *parentExpr,
409 bool isNegated =
false);
410 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc,
ast::Decl *decl);
411 FailureOr<ast::DeclRefExpr *>
412 createInlineVariableExpr(
ast::Type type, StringRef name, SMRange loc,
414 FailureOr<ast::MemberAccessExpr *>
415 createMemberAccessExpr(
ast::Expr *parentExpr, StringRef name, SMRange loc);
419 FailureOr<ast::Type> validateMemberAccess(
ast::Expr *parentExpr,
420 StringRef name, SMRange loc);
421 FailureOr<ast::OperationExpr *>
423 OpResultTypeContext resultTypeContext,
428 validateOperationOperands(SMRange loc, std::optional<StringRef> name,
431 LogicalResult validateOperationResults(SMRange loc,
432 std::optional<StringRef> name,
435 void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
437 LogicalResult validateOperationOperandsOrResults(
438 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
442 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
449 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc,
ast::Expr *rootOp);
450 FailureOr<ast::ReplaceStmt *>
451 createReplaceStmt(SMRange loc,
ast::Expr *rootOp,
453 FailureOr<ast::RewriteStmt *>
454 createRewriteStmt(SMRange loc,
ast::Expr *rootOp,
465 LogicalResult codeCompleteMemberAccess(
ast::Expr *parentExpr);
466 LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
467 LogicalResult codeCompleteConstraintName(
ast::Type inferredType,
468 bool allowInlineTypeConstraints);
469 LogicalResult codeCompleteDialectName();
470 LogicalResult codeCompleteOperationName(StringRef dialectName);
471 LogicalResult codeCompletePatternMetadata();
472 LogicalResult codeCompleteIncludeFilename(StringRef curPath);
474 void codeCompleteCallSignature(
ast::Node *parent,
unsigned currentNumArgs);
475 void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
476 unsigned currentNumOperands);
477 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
478 unsigned currentNumResults);
487 if (curToken.isNot(kind))
494 void consumeToken() {
496 "shouldn't advance past EOF or errors");
497 curToken = lexer.lexToken();
504 assert(curToken.is(kind) &&
"consumed an unexpected token");
509 void resetToken(SMRange tokLoc) {
510 lexer.resetPointer(tokLoc.Start.getPointer());
511 curToken = lexer.lexToken();
516 LogicalResult parseToken(
Token::Kind kind,
const Twine &msg) {
517 if (curToken.getKind() != kind)
518 return emitError(curToken.getLoc(), msg);
522 LogicalResult
emitError(SMRange loc,
const Twine &msg) {
523 lexer.emitError(loc, msg);
526 LogicalResult
emitError(
const Twine &msg) {
527 return emitError(curToken.getLoc(), msg);
529 LogicalResult emitErrorAndNote(SMRange loc,
const Twine &msg, SMRange noteLoc,
531 lexer.emitErrorAndNote(loc, msg, noteLoc, note);
550 bool enableDocumentation;
554 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
557 ParserContext parserContext = ParserContext::Global;
565 unsigned anonymousDeclNameCounter = 0;
572 FailureOr<ast::Module *> Parser::parseModule() {
573 SMLoc moduleLoc = curToken.getStartLoc();
578 if (failed(parseModuleBody(decls)))
579 return popDeclScope(), failure();
588 if (failed(parseDirective(decls)))
593 FailureOr<ast::Decl *> decl = parseTopLevelDecl();
596 decls.push_back(*decl);
606 LogicalResult Parser::convertExpressionTo(
610 if (exprType == type)
615 expr->
getLoc(), llvm::formatv(
"unable to convert expression of type "
616 "`{0}` to the expected type of "
624 if (
auto exprOpType = dyn_cast<ast::OperationType>(exprType))
625 return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
631 if ((exprType == valueTy || exprType == valueRangeTy) &&
632 (type == valueTy || type == valueRangeTy))
634 if ((exprType == typeTy || exprType == typeRangeTy) &&
635 (type == typeTy || type == typeRangeTy))
639 if (
auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
640 return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
643 return emitConvertError();
646 LogicalResult Parser::convertOpExpressionTo(
651 if (
auto opType = dyn_cast<ast::OperationType>(type)) {
652 if (opType.getName())
653 return emitErrorFn();
658 if (type == valueRangeTy) {
665 if (type == valueTy) {
669 if (odsOp->getResults().empty()) {
670 return emitErrorFn()->attachNote(
671 llvm::formatv(
"see the definition of `{0}`, which was defined "
677 unsigned numSingleResults = llvm::count_if(
679 return result.getVariableLengthKind() ==
680 ods::VariableLengthKind::Single;
682 if (numSingleResults > 1) {
683 return emitErrorFn()->attachNote(
684 llvm::formatv(
"see the definition of `{0}`, which was defined "
685 "with at least {1} results",
686 odsOp->getName(), numSingleResults),
695 return emitErrorFn();
698 LogicalResult Parser::convertTupleExpressionTo(
703 if (
auto tupleType = dyn_cast<ast::TupleType>(type)) {
704 if (tupleType.size() != exprType.
size())
705 return emitErrorFn();
710 for (
unsigned i = 0, e = exprType.
size(); i < e; ++i) {
712 ctx, expr->
getLoc(), expr, llvm::to_string(i),
716 diag.attachNote(llvm::formatv(
"when converting element #{0} of `{1}`",
721 if (failed(convertExpressionTo(newExprs.back(),
722 tupleType.getElementTypes()[i], diagFn)))
726 tupleType.getElementNames());
734 if (parserContext != ParserContext::Rewrite) {
735 return emitErrorFn()->attachNote(
"Tuple to Range conversion is currently "
736 "only allowed within a rewrite context");
741 if (!llvm::is_contained(allowedElementTypes, elementType))
742 return emitErrorFn();
747 for (
unsigned i = 0, e = exprType.
size(); i < e; ++i) {
749 ctx, expr->
getLoc(), expr, llvm::to_string(i),
755 if (type == valueRangeTy)
756 return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
757 if (type == typeRangeTy)
758 return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
760 return emitErrorFn();
767 StringRef directive = curToken.getSpelling();
768 if (directive ==
"#include")
769 return parseInclude(decls);
771 return emitError(
"unknown directive `" + directive +
"`");
775 SMRange loc = curToken.getLoc();
780 return codeCompleteIncludeFilename(curToken.getStringValue());
783 if (!curToken.isString())
785 "expected string file name after `include` directive");
786 SMRange fileLoc = curToken.getLoc();
787 std::string filenameStr = curToken.getStringValue();
788 StringRef filename = filenameStr;
793 if (filename.ends_with(
".pdll")) {
794 if (failed(lexer.pushInclude(filename, fileLoc)))
796 "unable to open include file `" + filename +
"`");
801 curToken = lexer.lexToken();
802 LogicalResult result = parseModuleBody(decls);
803 curToken = lexer.lexToken();
808 if (filename.ends_with(
".td"))
809 return parseTdInclude(filename, fileLoc, decls);
812 "expected include filename to end with `.pdll` or `.td`");
815 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
817 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
820 std::string includedFile;
821 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
822 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
824 return emitError(fileLoc,
"unable to open include file `" + filename +
"`");
827 llvm::SourceMgr tdSrcMgr;
828 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
829 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
833 struct DiagHandlerContext {
837 } handlerContext{*
this, filename, fileLoc};
840 tdSrcMgr.setDiagHandler(
841 [](
const llvm::SMDiagnostic &
diag,
void *rawHandlerContext) {
842 auto *ctx =
reinterpret_cast<DiagHandlerContext *
>(rawHandlerContext);
843 (void)ctx->parser.emitError(
845 llvm::formatv(
"error while processing include file `{0}`: {1}",
846 ctx->filename,
diag.getMessage()));
851 llvm::RecordKeeper tdRecords;
852 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
856 processTdIncludeRecords(tdRecords, decls);
861 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
865 void Parser::processTdIncludeRecords(
const llvm::RecordKeeper &tdRecords,
868 auto getLengthKind = [](
const auto &value) {
869 if (value.isOptional())
880 cst.constraint.getUniqueDefName(),
881 processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
883 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
884 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
889 for (
const llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Op")) {
893 bool supportsResultTypeInferrence =
894 op.getTrait(
"::mlir::InferTypeOpInterface::Trait");
897 op.getOperationName(), processDoc(op.getSummary()),
898 processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
899 supportsResultTypeInferrence, op.getLoc().front());
906 odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
908 attr.attr.getUniqueDefName(),
909 processDoc(attr.attr.getSummary()),
910 attr.attr.getStorageType()));
913 odsOp->appendOperand(operand.name, getLengthKind(operand),
914 addTypeConstraint(operand));
917 odsOp->appendResult(result.name, getLengthKind(result),
918 addTypeConstraint(result));
922 auto shouldBeSkipped = [
this](
const llvm::Record *def) {
923 return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
924 def->isSubClassOf(
"DeclareInterfaceMethods");
928 for (
const llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Attr")) {
929 if (shouldBeSkipped(def))
933 decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
934 constraint, convertLocToRange(def->getLoc().front()), attrTy,
935 constraint.getStorageType()));
938 for (
const llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Type")) {
939 if (shouldBeSkipped(def))
943 decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
944 constraint, convertLocToRange(def->getLoc().front()), typeTy,
945 constraint.getCppType()));
949 for (
const llvm::Record *def :
950 tdRecords.getAllDerivedDefinitions(
"OpInterface")) {
951 if (shouldBeSkipped(def))
954 SMRange loc = convertLocToRange(def->getLoc().front());
956 std::string cppClassName =
957 llvm::formatv(
"{0}::{1}", def->getValueAsString(
"cppNamespace"),
958 def->getValueAsString(
"cppInterfaceName"))
960 std::string codeBlock =
961 llvm::formatv(
"return ::mlir::success(llvm::isa<{0}>(self));",
966 processAndFormatDoc(def->getValueAsString(
"description"));
967 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
968 def->getName(), codeBlock, loc, opTy, cppClassName, desc));
972 template <
typename Constra
intT>
973 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
974 StringRef name, StringRef codeBlock, SMRange loc,
ast::Type type,
975 StringRef nativeType, StringRef docString) {
981 argScope->
add(paramVar);
989 constraintDecl->setDocComment(ctx, docString);
990 curDeclScope->add(constraintDecl);
991 return constraintDecl;
994 template <
typename Constra
intT>
998 StringRef nativeType) {
1009 std::string docString;
1010 if (enableDocumentation) {
1012 docString = processAndFormatDoc(
1017 return createODSNativePDLLConstraintDecl<ConstraintT>(
1025 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1026 FailureOr<ast::Decl *> decl;
1027 switch (curToken.getKind()) {
1029 decl = parseUserConstraintDecl();
1032 decl = parsePatternDecl();
1035 decl = parseUserRewriteDecl();
1038 return emitError(
"expected top-level declaration, such as a `Pattern`");
1045 if (failed(checkDefineNamedDecl(*name)))
1047 curDeclScope->add(*decl);
1052 FailureOr<ast::NamedAttributeDecl *>
1053 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1056 return codeCompleteAttributeName(parentOpName);
1058 std::string attrNameStr;
1059 if (curToken.isString())
1060 attrNameStr = curToken.getStringValue();
1062 attrNameStr = curToken.getSpelling().str();
1064 return emitError(
"expected identifier or string attribute name");
1071 FailureOr<ast::Expr *> attrExpr = parseExpr();
1072 if (failed(attrExpr))
1074 attrValue = *attrExpr;
1084 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1086 bool expectTerminalSemicolon) {
1090 SMLoc bodyStartLoc = curToken.getStartLoc();
1092 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1093 bool failedToParse =
1094 failed(singleStatement) || failed(processStatementFn(*singleStatement));
1099 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1103 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1106 return emitError(
"expected identifier argument name");
1109 StringRef name = curToken.getSpelling();
1110 SMRange nameLoc = curToken.getLoc();
1114 parseToken(
Token::colon,
"expected `:` before argument constraint")))
1117 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1121 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1124 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(
unsigned resultNum) {
1132 StringRef name = curToken.getSpelling();
1133 SMRange nameLoc = curToken.getLoc();
1137 "expected `:` before result constraint")))
1140 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1144 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1150 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1154 return createArgOrResultVariableDecl(
"", cst->referenceLoc, *cst);
1157 FailureOr<ast::UserConstraintDecl *>
1158 Parser::parseUserConstraintDecl(
bool isInline) {
1161 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1162 [&](
auto &&...args) {
1163 return this->parseUserPDLLConstraintDecl(args...);
1165 ParserContext::Constraint,
"constraint", isInline);
1168 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1169 FailureOr<ast::UserConstraintDecl *> decl =
1170 parseUserConstraintDecl(
true);
1171 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1174 curDeclScope->add(*decl);
1178 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1184 pushDeclScope(argumentScope);
1190 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1191 [&](
ast::Stmt *&stmt) -> LogicalResult {
1192 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1195 "expected `Constraint` lambda body to contain a "
1196 "single expression");
1202 if (failed(bodyResult))
1206 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1207 if (failed(bodyResult))
1212 auto bodyIt = body->
begin(), bodyE = body->
end();
1213 for (; bodyIt != bodyE; ++bodyIt)
1214 if (isa<ast::ReturnStmt>(*bodyIt))
1216 if (failed(validateUserConstraintOrRewriteReturn(
1217 "Constraint", body, bodyIt, bodyE, results, resultType)))
1222 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1223 name, arguments, results, resultType, body);
1226 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(
bool isInline) {
1229 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1230 [&](
auto &&...args) {
return this->parseUserPDLLRewriteDecl(args...); },
1231 ParserContext::Rewrite,
"rewrite", isInline);
1234 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1235 FailureOr<ast::UserRewriteDecl *> decl =
1236 parseUserRewriteDecl(
true);
1237 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1240 curDeclScope->add(*decl);
1244 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1250 curDeclScope = argumentScope;
1253 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1254 [&](
ast::Stmt *&statement) -> LogicalResult {
1255 if (isa<ast::OpRewriteStmt>(statement))
1258 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1259 if (!statementExpr) {
1262 "expected `Rewrite` lambda body to contain a single expression "
1263 "or an operation rewrite statement; such as `erase`, "
1264 "`replace`, or `rewrite`");
1271 if (failed(bodyResult))
1275 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1276 if (failed(bodyResult))
1283 auto bodyIt = body->
begin(), bodyE = body->
end();
1284 for (; bodyIt != bodyE; ++bodyIt)
1285 if (isa<ast::ReturnStmt>(*bodyIt))
1287 if (failed(validateUserConstraintOrRewriteReturn(
"Rewrite", body, bodyIt,
1288 bodyE, results, resultType)))
1290 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1291 name, arguments, results, resultType, body);
1294 template <
typename T,
typename ParseUserPDLLDeclFnT>
1295 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1296 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1297 StringRef anonymousNamePrefix,
bool isInline) {
1298 SMRange loc = curToken.getLoc();
1300 llvm::SaveAndRestore saveCtx(parserContext, declContext);
1308 return emitError(
"expected identifier name");
1312 std::string anonName =
1313 llvm::formatv(
"<anonymous_{0}_{1}>", anonymousNamePrefix,
1314 anonymousDeclNameCounter++)
1327 if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1328 argumentScope, resultType)))
1334 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1338 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1339 results, resultType);
1342 template <
typename T>
1343 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1348 std::string codeStrStorage;
1349 std::optional<StringRef> optCodeStr;
1350 if (curToken.isString()) {
1351 codeStrStorage = curToken.getStringValue();
1352 optCodeStr = codeStrStorage;
1354 }
else if (isInline) {
1356 "external declarations must be declared in global scope");
1361 "expected `;` after native declaration")))
1363 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1366 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1371 if (failed(parseToken(
Token::l_paren,
"expected `(` to start argument list")))
1374 argumentScope = pushDeclScope();
1377 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1378 if (failed(argument))
1380 arguments.emplace_back(*argument);
1384 if (failed(parseToken(
Token::r_paren,
"expected `)` to end argument list")))
1390 auto parseResultFn = [&]() -> LogicalResult {
1391 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1394 results.emplace_back(*result);
1401 if (failed(parseResultFn()))
1404 if (failed(parseToken(
Token::r_paren,
"expected `)` to end result list")))
1408 }
else if (failed(parseResultFn())) {
1415 resultType = createUserConstraintRewriteResultType(results);
1418 if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1420 results.front()->getLoc(),
1421 "cannot create a single-element tuple with an element label");
1426 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1432 if (bodyIt != bodyE) {
1434 if (std::next(bodyIt) != bodyE) {
1436 (*std::next(bodyIt))->getLoc(),
1437 llvm::formatv(
"`return` terminated the `{0}` body, but found "
1438 "trailing statements afterwards",
1444 }
else if (!results.empty()) {
1446 {body->getLoc().End, body->getLoc().End},
1447 llvm::formatv(
"missing return in a `{0}` expected to return `{1}`",
1448 declType, resultType));
1453 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1454 return parseLambdaBody([&](
ast::Stmt *&statement) -> LogicalResult {
1455 if (isa<ast::OpRewriteStmt>(statement))
1459 "expected Pattern lambda body to contain a single operation "
1460 "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1464 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1465 SMRange loc = curToken.getLoc();
1467 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1477 ParsedPatternMetadata metadata;
1478 if (consumeIf(
Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1486 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1487 if (failed(bodyResult))
1492 return emitError(
"expected `{` or `=>` to start pattern body");
1493 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1494 if (failed(bodyResult))
1499 auto bodyIt = body->
begin(), bodyE = body->
end();
1500 for (; bodyIt != bodyE; ++bodyIt) {
1501 if (isa<ast::ReturnStmt>(*bodyIt)) {
1503 "`return` statements are only permitted within a "
1504 "`Constraint` or `Rewrite` body");
1507 if (isa<ast::OpRewriteStmt>(*bodyIt))
1510 if (bodyIt == bodyE) {
1512 "expected Pattern body to terminate with an operation "
1513 "rewrite statement, such as `erase`");
1515 if (std::next(bodyIt) != bodyE) {
1516 return emitError((*std::next(bodyIt))->getLoc(),
1517 "Pattern body was terminated by an operation "
1518 "rewrite statement, but found trailing statements");
1522 return createPatternDecl(loc, name, metadata, body);
1526 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1527 std::optional<SMRange> benefitLoc;
1528 std::optional<SMRange> hasBoundedRecursionLoc;
1533 return codeCompletePatternMetadata();
1536 return emitError(
"expected pattern metadata identifier");
1537 StringRef metadataStr = curToken.getSpelling();
1538 SMRange metadataLoc = curToken.getLoc();
1542 if (metadataStr ==
"benefit") {
1544 return emitErrorAndNote(metadataLoc,
1545 "pattern benefit has already been specified",
1546 *benefitLoc,
"see previous definition here");
1549 "expected `(` before pattern benefit")))
1552 uint16_t benefitValue = 0;
1554 return emitError(
"expected integral pattern benefit");
1555 if (curToken.getSpelling().getAsInteger(10, benefitValue))
1557 "expected pattern benefit to fit within a 16-bit integer");
1560 metadata.benefit = benefitValue;
1561 benefitLoc = metadataLoc;
1564 parseToken(
Token::r_paren,
"expected `)` after pattern benefit")))
1570 if (metadataStr ==
"recursion") {
1571 if (hasBoundedRecursionLoc) {
1572 return emitErrorAndNote(
1574 "pattern recursion metadata has already been specified",
1575 *hasBoundedRecursionLoc,
"see previous definition here");
1577 metadata.hasBoundedRecursion =
true;
1578 hasBoundedRecursionLoc = metadataLoc;
1582 return emitError(metadataLoc,
"unknown pattern metadata");
1588 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1591 FailureOr<ast::Expr *> typeExpr = parseExpr();
1592 if (failed(typeExpr) ||
1594 "expected `>` after variable type constraint")))
1599 LogicalResult Parser::checkDefineNamedDecl(
const ast::Name &name) {
1600 assert(curDeclScope &&
"defining decl outside of a decl scope");
1602 return emitErrorAndNote(
1603 name.
getLoc(),
"`" + name.
getName() +
"` has already been defined",
1604 lastDecl->getName()->getLoc(),
"see previous definition here");
1609 FailureOr<ast::VariableDecl *>
1610 Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
1613 assert(curDeclScope &&
"defining variable outside of decl scope");
1618 if (name.empty() || name ==
"_") {
1622 if (failed(checkDefineNamedDecl(nameDecl)))
1627 curDeclScope->add(varDecl);
1631 FailureOr<ast::VariableDecl *>
1632 Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
1634 return defineVariableDecl(name, nameLoc, type,
nullptr,
1638 LogicalResult Parser::parseVariableDeclConstraintList(
1640 std::optional<SMRange> typeConstraint;
1641 auto parseSingleConstraint = [&] {
1642 FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1643 typeConstraint, constraints,
true);
1644 if (failed(constraint))
1646 constraints.push_back(*constraint);
1652 return parseSingleConstraint();
1655 if (failed(parseSingleConstraint()))
1658 return parseToken(
Token::r_square,
"expected `]` after constraint list");
1661 FailureOr<ast::ConstraintRef>
1662 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1664 bool allowInlineTypeConstraints) {
1665 auto parseTypeConstraint = [&](
ast::Expr *&typeExpr) -> LogicalResult {
1666 if (!allowInlineTypeConstraints) {
1669 "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1670 "permitted on arguments or results");
1673 return emitErrorAndNote(
1675 "the type of this variable has already been constrained",
1676 *typeConstraint,
"see previous constraint location here");
1677 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1678 if (failed(constraintExpr))
1680 typeExpr = *constraintExpr;
1681 typeConstraint = typeExpr->getLoc();
1685 SMRange loc = curToken.getLoc();
1686 switch (curToken.getKind()) {
1692 if (curToken.is(
Token::less) && failed(parseTypeConstraint(typeExpr)))
1702 FailureOr<ast::OpNameDecl *> opName =
1703 parseWrappedOperationName(
true);
1722 if (curToken.is(
Token::less) && failed(parseTypeConstraint(typeExpr)))
1733 if (curToken.is(
Token::less) && failed(parseTypeConstraint(typeExpr)))
1742 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1748 StringRef constraintName = curToken.getSpelling();
1754 return emitError(loc,
"unknown reference to constraint `" +
1755 constraintName +
"`");
1759 if (
auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1762 return emitErrorAndNote(
1763 loc,
"invalid reference to non-constraint", cstDecl->
getLoc(),
1764 "see the definition of `" + constraintName +
"` here");
1770 if (failed(validateVariableConstraints(existingConstraints, inferredType)))
1773 return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1778 return emitError(loc,
"expected identifier constraint");
1781 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1782 std::optional<SMRange> typeConstraint;
1783 return parseConstraint(typeConstraint, std::nullopt,
1790 FailureOr<ast::Expr *> Parser::parseExpr() {
1792 return parseUnderscoreExpr();
1795 FailureOr<ast::Expr *> lhsExpr;
1796 switch (curToken.getKind()) {
1798 lhsExpr = parseAttributeExpr();
1801 lhsExpr = parseInlineConstraintLambdaExpr();
1804 lhsExpr = parseNegatedExpr();
1807 lhsExpr = parseIdentifierExpr();
1810 lhsExpr = parseOperationExpr();
1813 lhsExpr = parseInlineRewriteLambdaExpr();
1816 lhsExpr = parseTypeExpr();
1819 lhsExpr = parseTupleExpr();
1822 return emitError(
"expected expression");
1824 if (failed(lhsExpr))
1829 switch (curToken.getKind()) {
1831 lhsExpr = parseMemberAccessExpr(*lhsExpr);
1834 lhsExpr = parseCallExpr(*lhsExpr);
1839 if (failed(lhsExpr))
1844 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1845 SMRange loc = curToken.getLoc();
1852 return parseIdentifierExpr();
1855 if (!curToken.isString())
1856 return emitError(
"expected string literal containing MLIR attribute");
1857 std::string attrExpr = curToken.getStringValue();
1860 loc.End = curToken.getEndLoc();
1862 parseToken(
Token::greater,
"expected `>` after attribute literal")))
1867 FailureOr<ast::Expr *> Parser::parseCallExpr(
ast::Expr *parentExpr,
1877 codeCompleteCallSignature(parentExpr, arguments.size());
1881 FailureOr<ast::Expr *> argument = parseExpr();
1882 if (failed(argument))
1884 arguments.push_back(*argument);
1888 SMRange loc(parentExpr->
getLoc().Start, curToken.getEndLoc());
1889 if (failed(parseToken(
Token::r_paren,
"expected `)` after argument list")))
1892 return createCallExpr(loc, parentExpr, arguments, isNegated);
1895 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1896 ast::Decl *decl = curDeclScope->lookup(name);
1898 return emitError(loc,
"undefined reference to `" + name +
"`");
1900 return createDeclRefExpr(loc, decl);
1903 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1904 StringRef name = curToken.getSpelling();
1905 SMRange nameLoc = curToken.getLoc();
1912 if (failed(parseVariableDeclConstraintList(constraints)))
1915 if (failed(validateVariableConstraints(constraints, type)))
1917 return createInlineVariableExpr(type, name, nameLoc, constraints);
1920 return parseDeclRefExpr(name, nameLoc);
1923 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1924 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1932 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1933 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1941 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(
ast::Expr *parentExpr) {
1942 SMRange dotLoc = curToken.getLoc();
1947 return codeCompleteMemberAccess(parentExpr);
1950 Token memberNameTok = curToken;
1953 return emitError(dotLoc,
"expected identifier or numeric member name");
1954 StringRef memberName = memberNameTok.
getSpelling();
1955 SMRange loc(parentExpr->
getLoc().Start, curToken.getEndLoc());
1958 return createMemberAccessExpr(parentExpr, memberName, loc);
1961 FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1965 return emitError(
"expected native constraint");
1966 FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1967 if (failed(identifierExpr))
1970 return emitError(
"expected `(` after function name");
1971 return parseCallExpr(*identifierExpr,
true);
1974 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(
bool allowEmptyName) {
1975 SMRange loc = curToken.getLoc();
1979 return codeCompleteDialectName();
1985 return emitError(
"expected dialect namespace");
1987 StringRef name = curToken.getSpelling();
1991 if (failed(parseToken(
Token::dot,
"expected `.` after dialect namespace")))
1996 return codeCompleteOperationName(name);
1999 return emitError(
"expected operation name after dialect namespace");
2001 name = StringRef(name.data(), name.size() + 1);
2003 name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
2004 loc.End = curToken.getEndLoc();
2007 curToken.isKeyword());
2011 FailureOr<ast::OpNameDecl *>
2012 Parser::parseWrappedOperationName(
bool allowEmptyName) {
2016 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2017 if (failed(opNameDecl))
2020 if (failed(parseToken(
Token::greater,
"expected `>` after operation name")))
2025 FailureOr<ast::Expr *>
2026 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2027 SMRange loc = curToken.getLoc();
2034 return parseIdentifierExpr();
2040 bool allowEmptyName = parserContext != ParserContext::Rewrite;
2041 FailureOr<ast::OpNameDecl *> opNameDecl =
2042 parseWrappedOperationName(allowEmptyName);
2043 if (failed(opNameDecl))
2045 std::optional<StringRef> opName = (*opNameDecl)->getName();
2050 FailureOr<ast::VariableDecl *> rangeVar =
2052 assert(succeeded(rangeVar) &&
"expected range variable to be valid");
2063 if (parserContext != ParserContext::Rewrite) {
2064 operands.push_back(createImplicitRangeVar(
2072 codeCompleteOperationOperandsSignature(opName, operands.size());
2076 FailureOr<ast::Expr *> operand = parseExpr();
2077 if (failed(operand))
2079 operands.push_back(*operand);
2083 "expected `)` after operation operand list")))
2091 FailureOr<ast::NamedAttributeDecl *> decl =
2092 parseNamedAttributeDecl(opName);
2095 attributes.emplace_back(*decl);
2099 "expected `}` after operation attribute list")))
2105 OpResultTypeContext resultTypeContext = inputResultTypeContext;
2110 "expected `(` before operation result type list")))
2118 resultTypeContext = OpResultTypeContext::Explicit;
2125 codeCompleteOperationResultsSignature(opName, resultTypes.size());
2129 FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2130 if (failed(resultTypeExpr))
2132 resultTypes.push_back(*resultTypeExpr);
2136 "expected `)` after operation result type list")))
2139 }
else if (parserContext != ParserContext::Rewrite) {
2144 resultTypes.push_back(createImplicitRangeVar(
2146 }
else if (resultTypeContext == OpResultTypeContext::Explicit) {
2149 resultTypeContext = OpResultTypeContext::Interface;
2152 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2153 attributes, resultTypes);
2156 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2157 SMRange loc = curToken.getLoc();
2166 StringRef elementName;
2168 Token elementNameTok = curToken;
2176 auto elementNameIt =
2177 usedNames.try_emplace(elementName, elementNameTok.
getLoc());
2178 if (!elementNameIt.second) {
2179 return emitErrorAndNote(
2181 llvm::formatv(
"duplicate tuple element label `{0}`",
2183 elementNameIt.first->getSecond(),
2184 "see previous label use here");
2189 resetToken(elementNameTok.
getLoc());
2192 elementNames.push_back(elementName);
2195 FailureOr<ast::Expr *> element = parseExpr();
2196 if (failed(element))
2198 elements.push_back(*element);
2201 loc.End = curToken.getEndLoc();
2203 parseToken(
Token::r_paren,
"expected `)` after tuple element list")))
2205 return createTupleExpr(loc, elements, elementNames);
2208 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2209 SMRange loc = curToken.getLoc();
2216 return parseIdentifierExpr();
2219 if (!curToken.isString())
2220 return emitError(
"expected string literal containing MLIR type");
2221 std::string attrExpr = curToken.getStringValue();
2224 loc.End = curToken.getEndLoc();
2225 if (failed(parseToken(
Token::greater,
"expected `>` after type literal")))
2230 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2231 StringRef name = curToken.getSpelling();
2232 SMRange nameLoc = curToken.getLoc();
2236 if (failed(parseToken(
Token::colon,
"expected `:` after `_` variable")))
2241 if (failed(parseVariableDeclConstraintList(constraints)))
2245 if (failed(validateVariableConstraints(constraints, type)))
2247 return createInlineVariableExpr(type, name, nameLoc, constraints);
2253 FailureOr<ast::Stmt *> Parser::parseStmt(
bool expectTerminalSemicolon) {
2254 FailureOr<ast::Stmt *> stmt;
2255 switch (curToken.getKind()) {
2257 stmt = parseEraseStmt();
2260 stmt = parseLetStmt();
2263 stmt = parseReplaceStmt();
2266 stmt = parseReturnStmt();
2269 stmt = parseRewriteStmt();
2276 (expectTerminalSemicolon &&
2282 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2283 SMLoc startLoc = curToken.getStartLoc();
2290 FailureOr<ast::Stmt *> statement = parseStmt();
2291 if (failed(statement))
2292 return popDeclScope(), failure();
2293 statements.push_back(*statement);
2298 SMRange location(startLoc, curToken.getEndLoc());
2304 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2305 if (parserContext == ParserContext::Constraint)
2306 return emitError(
"`erase` cannot be used within a Constraint");
2307 SMRange loc = curToken.getLoc();
2311 FailureOr<ast::Expr *> rootOp = parseExpr();
2315 return createEraseStmt(loc, *rootOp);
2318 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2319 SMRange loc = curToken.getLoc();
2323 SMRange varLoc = curToken.getLoc();
2328 "`_` may only be used to define \"inline\" variables");
2331 "expected identifier after `let` to name a new variable");
2333 StringRef varName = curToken.getSpelling();
2339 failed(parseVariableDeclConstraintList(constraints)))
2345 FailureOr<ast::Expr *> initOrFailure = parseExpr();
2346 if (failed(initOrFailure))
2348 initializer = *initOrFailure;
2353 LogicalResult result =
2357 if (cst->getTypeExpr()) {
2359 constraint.referenceLoc,
2360 "type constraints are not permitted on variables with "
2365 .Default(success());
2371 FailureOr<ast::VariableDecl *> varDecl =
2372 createVariableDecl(varName, varLoc, initializer, constraints);
2373 if (failed(varDecl))
2378 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2379 if (parserContext == ParserContext::Constraint)
2380 return emitError(
"`replace` cannot be used within a Constraint");
2381 SMRange loc = curToken.getLoc();
2385 FailureOr<ast::Expr *> rootOp = parseExpr();
2390 parseToken(
Token::kw_with,
"expected `with` after root operation")))
2394 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2401 loc,
"expected at least one replacement value, consider using "
2402 "`erase` if no replacement values are desired");
2406 FailureOr<ast::Expr *> replExpr = parseExpr();
2407 if (failed(replExpr))
2409 replValues.emplace_back(*replExpr);
2413 "expected `)` after replacement values")))
2418 FailureOr<ast::Expr *> replExpr;
2420 replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2422 replExpr = parseExpr();
2423 if (failed(replExpr))
2425 replValues.emplace_back(*replExpr);
2428 return createReplaceStmt(loc, *rootOp, replValues);
2431 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2432 SMRange loc = curToken.getLoc();
2436 FailureOr<ast::Expr *> resultExpr = parseExpr();
2437 if (failed(resultExpr))
2443 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2444 if (parserContext == ParserContext::Constraint)
2445 return emitError(
"`rewrite` cannot be used within a Constraint");
2446 SMRange loc = curToken.getLoc();
2450 FailureOr<ast::Expr *> rootOp = parseExpr();
2454 if (failed(parseToken(
Token::kw_with,
"expected `with` before rewrite body")))
2458 return emitError(
"expected `{` to start rewrite body");
2461 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2463 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2464 if (failed(rewriteBody))
2468 for (
const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2469 if (isa<ast::ReturnStmt>(stmt)) {
2471 "`return` statements are only permitted within a "
2472 "`Constraint` or `Rewrite` body");
2476 return createRewriteStmt(loc, *rootOp, *rewriteBody);
2488 if (
auto *init = dyn_cast<ast::DeclRefExpr>(node))
2489 node = init->getDecl();
2490 return dyn_cast<ast::CallableDecl>(node);
2493 FailureOr<ast::PatternDecl *>
2494 Parser::createPatternDecl(SMRange loc,
const ast::Name *name,
2495 const ParsedPatternMetadata &metadata,
2498 metadata.hasBoundedRecursion, body);
2501 ast::Type Parser::createUserConstraintRewriteResultType(
2504 if (results.size() == 1)
2505 return results[0]->getType();
2509 auto resultTypes = llvm::map_range(
2510 results, [&](
const auto *result) {
return result->getType(); });
2511 auto resultNames = llvm::map_range(
2512 results, [&](
const auto *result) {
return result->getName().getName(); });
2514 llvm::to_vector(resultNames));
2517 template <
typename T>
2518 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2523 if (
auto *retStmt = dyn_cast<ast::ReturnStmt>(body->
getChildren().back())) {
2524 ast::Expr *resultExpr = retStmt->getResultExpr();
2529 if (results.empty())
2530 resultType = resultExpr->
getType();
2531 else if (failed(convertExpressionTo(resultExpr, resultType)))
2534 retStmt->setResultExpr(resultExpr);
2537 return T::createPDLL(ctx, name, arguments, results, body, resultType);
2540 FailureOr<ast::VariableDecl *>
2541 Parser::createVariableDecl(StringRef name, SMRange loc,
ast::Expr *initializer,
2546 if (failed(validateVariableConstraints(constraints, type)))
2553 type = initializer->
getType();
2556 else if (failed(convertExpressionTo(initializer, type)))
2562 return emitErrorAndNote(
2563 loc,
"unable to infer type for variable `" + name +
"`", loc,
2564 "the type of a variable must be inferable from the constraint "
2565 "list or the initializer");
2569 if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
2571 loc, llvm::formatv(
"unable to define variable of `{0}` type", type));
2575 FailureOr<ast::VariableDecl *> varDecl =
2576 defineVariableDecl(name, loc, type, initializer, constraints);
2577 if (failed(varDecl))
2583 FailureOr<ast::VariableDecl *>
2584 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2587 if (failed(validateVariableConstraint(constraint, argType)))
2589 return defineVariableDecl(name, loc, argType, constraint);
2596 if (failed(validateVariableConstraint(ref, inferredType)))
2604 if (
const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.
constraint)) {
2605 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2606 if (failed(validateTypeConstraintExpr(typeExpr)))
2610 }
else if (
const auto *cst =
2611 dyn_cast<ast::OpConstraintDecl>(ref.
constraint)) {
2614 }
else if (isa<ast::TypeConstraintDecl>(ref.
constraint)) {
2615 constraintType = typeTy;
2616 }
else if (isa<ast::TypeRangeConstraintDecl>(ref.
constraint)) {
2617 constraintType = typeRangeTy;
2618 }
else if (
const auto *cst =
2619 dyn_cast<ast::ValueConstraintDecl>(ref.
constraint)) {
2620 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2621 if (failed(validateTypeConstraintExpr(typeExpr)))
2624 constraintType = valueTy;
2625 }
else if (
const auto *cst =
2626 dyn_cast<ast::ValueRangeConstraintDecl>(ref.
constraint)) {
2627 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2628 if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2631 constraintType = valueRangeTy;
2632 }
else if (
const auto *cst =
2633 dyn_cast<ast::UserConstraintDecl>(ref.
constraint)) {
2635 if (inputs.size() != 1) {
2637 "`Constraint`s applied via a variable constraint "
2638 "list must take a single input, but got " +
2639 Twine(inputs.size()),
2641 "see definition of constraint here");
2643 constraintType = inputs.front()->getType();
2645 llvm_unreachable(
"unknown constraint type");
2650 if (!inferredType) {
2651 inferredType = constraintType;
2653 inferredType = mergedTy;
2656 llvm::formatv(
"constraint type `{0}` is incompatible "
2657 "with the previously inferred type `{1}`",
2658 constraintType, inferredType));
2663 LogicalResult Parser::validateTypeConstraintExpr(
const ast::Expr *typeExpr) {
2665 if (typeExprType != typeTy) {
2667 "expected expression of `Type` in type constraint");
2673 Parser::validateTypeRangeConstraintExpr(
const ast::Expr *typeExpr) {
2675 if (typeExprType != typeRangeTy) {
2677 "expected expression of `TypeRange` in type constraint");
2685 FailureOr<ast::CallExpr *>
2686 Parser::createCallExpr(SMRange loc,
ast::Expr *parentExpr,
2691 if (!callableDecl) {
2693 llvm::formatv(
"expected a reference to a callable "
2694 "`Constraint` or `Rewrite`, but got: `{0}`",
2697 if (parserContext == ParserContext::Rewrite) {
2698 if (isa<ast::UserConstraintDecl>(callableDecl))
2700 loc,
"unable to invoke `Constraint` within a rewrite section");
2702 return emitError(loc,
"unable to negate a Rewrite");
2704 if (isa<ast::UserRewriteDecl>(callableDecl))
2706 "unable to invoke `Rewrite` within a match section");
2707 if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2708 return emitError(loc,
"unable to negate non native constraints");
2714 if (callArgs.size() != arguments.size()) {
2715 return emitErrorAndNote(
2717 llvm::formatv(
"invalid number of arguments for {0} call; expected "
2722 llvm::formatv(
"see the definition of {0} here",
2728 diag.attachNote(llvm::formatv(
"see the definition of `{0}` here",
2732 for (
auto it : llvm::zip(callArgs, arguments)) {
2733 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->
getType(),
2742 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2746 if (isa<ast::ConstraintDecl>(decl))
2748 else if (isa<ast::UserRewriteDecl>(decl))
2750 else if (
auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2751 declType = varDecl->getType();
2753 return emitError(loc,
"invalid reference to `" +
2759 FailureOr<ast::DeclRefExpr *>
2760 Parser::createInlineVariableExpr(
ast::Type type, StringRef name, SMRange loc,
2762 FailureOr<ast::VariableDecl *> decl =
2763 defineVariableDecl(name, loc, type, constraints);
2769 FailureOr<ast::MemberAccessExpr *>
2770 Parser::createMemberAccessExpr(
ast::Expr *parentExpr, StringRef name,
2773 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2774 if (failed(memberType))
2780 FailureOr<ast::Type> Parser::validateMemberAccess(
ast::Expr *parentExpr,
2781 StringRef name, SMRange loc) {
2785 return valueRangeTy;
2789 auto results = odsOp->getResults();
2793 if (llvm::isDigit(name[0]) && !name.getAsInteger(10, index) &&
2794 index < results.size()) {
2795 return results[index].isVariadic() ? valueRangeTy : valueTy;
2799 const auto *it = llvm::find_if(results, [&](
const auto &result) {
2800 return result.getName() == name;
2802 if (it != results.end())
2803 return it->isVariadic() ? valueRangeTy : valueTy;
2804 }
else if (llvm::isDigit(name[0])) {
2809 }
else if (
auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
2812 if (llvm::isDigit(name[0]) && !name.getAsInteger(10, index) &&
2813 index < tupleType.size()) {
2814 return tupleType.getElementTypes()[index];
2818 auto elementNames = tupleType.getElementNames();
2819 const auto *it = llvm::find(elementNames, name);
2820 if (it != elementNames.end())
2821 return tupleType.getElementTypes()[it - elementNames.begin()];
2825 llvm::formatv(
"invalid member access `{0}` on expression of type `{1}`",
2829 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2831 OpResultTypeContext resultTypeContext,
2835 std::optional<StringRef> opNameRef = name->
getName();
2839 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2845 ast::Type attrType = attr->getValue()->getType();
2846 if (!isa<ast::AttributeType>(attrType)) {
2848 attr->getValue()->getLoc(),
2849 llvm::formatv(
"expected `Attr` expression, but got `{0}`", attrType));
2854 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2855 "unexpected inferrence when results were explicitly specified");
2859 if (resultTypeContext == OpResultTypeContext::Explicit) {
2860 if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2864 }
else if (resultTypeContext == OpResultTypeContext::Interface) {
2866 "expected valid operation name when inferring operation results");
2867 checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2875 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2878 return validateOperationOperandsOrResults(
2879 "operand", loc, odsOp ? odsOp->
getLoc() : std::optional<SMRange>(), name,
2880 operands, odsOp ? odsOp->
getOperands() : std::nullopt, valueTy,
2885 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2888 return validateOperationOperandsOrResults(
2889 "result", loc, odsOp ? odsOp->
getLoc() : std::optional<SMRange>(), name,
2890 results, odsOp ? odsOp->
getResults() : std::nullopt, typeTy, typeRangeTy);
2893 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2904 "operation result types are marked to be inferred, but "
2905 "`{0}` is unknown. Ensure that `{0}` supports zero "
2906 "results or implements `InferTypeOpInterface`. Include "
2907 "the ODS definition of this operation to remove this warning.",
2916 bool requiresInferrence =
2918 return !result.isVariableLength();
2923 llvm::formatv(
"operation result types are marked to be inferred, but "
2924 "`{0}` does not provide an implementation of "
2925 "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2926 "`InferTypeOpInterface` at runtime, or add support to "
2927 "the ODS definition to remove this warning.",
2929 diag->attachNote(llvm::formatv(
"see the definition of `{0}` here", opName),
2935 LogicalResult Parser::validateOperationOperandsOrResults(
2936 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2941 if (values.size() == 1) {
2942 if (failed(convertExpressionTo(values[0], rangeTy)))
2950 auto emitSizeMismatchError = [&] {
2951 return emitErrorAndNote(
2953 llvm::formatv(
"invalid number of {0} groups for `{1}`; expected "
2955 groupName, *name, odsValues.size(), values.size()),
2956 *odsOpLoc, llvm::formatv(
"see the definition of `{0}` here", *name));
2960 if (values.empty()) {
2962 if (odsValues.empty())
2967 unsigned numVariadic = 0;
2968 for (
const auto &odsValue : odsValues) {
2969 if (!odsValue.isVariableLength())
2970 return emitSizeMismatchError();
2976 if (parserContext != ParserContext::Rewrite)
2983 if (numVariadic == 1)
2988 for (
unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2990 ctx, loc, std::nullopt, rangeTy));
2997 if (odsValues.size() != values.size())
2998 return emitSizeMismatchError();
3001 diag.attachNote(llvm::formatv(
"see the definition of `{0}` here", *name),
3004 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
3005 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3006 if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
3015 ast::Type valueExprType = valueExpr->getType();
3018 if (valueExprType == rangeTy || valueExprType == singleTy)
3024 if (singleTy == valueTy) {
3025 if (isa<ast::OperationType>(valueExprType)) {
3026 valueExpr = convertOpToValue(valueExpr);
3032 if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3036 valueExpr->getLoc(),
3038 "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3039 singleTy, rangeTy, valueExprType));
3044 FailureOr<ast::TupleExpr *>
3047 for (
const ast::Expr *element : elements) {
3049 if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
3052 llvm::formatv(
"unable to build a tuple with `{0}` element", eleTy));
3061 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3065 if (!isa<ast::OperationType>(rootType))
3071 FailureOr<ast::ReplaceStmt *>
3072 Parser::createReplaceStmt(SMRange loc,
ast::Expr *rootOp,
3076 if (!isa<ast::OperationType>(rootType)) {
3079 llvm::formatv(
"expected `Op` expression, but got `{0}`", rootType));
3084 bool shouldConvertOpToValues = replValues.size() > 1;
3085 for (
ast::Expr *&replExpr : replValues) {
3086 ast::Type replType = replExpr->getType();
3089 if (isa<ast::OperationType>(replType)) {
3090 if (shouldConvertOpToValues)
3091 replExpr = convertOpToValue(replExpr);
3095 if (replType != valueTy && replType != valueRangeTy) {
3097 llvm::formatv(
"expected `Op`, `Value` or `ValueRange` "
3098 "expression, but got `{0}`",
3106 FailureOr<ast::RewriteStmt *>
3107 Parser::createRewriteStmt(SMRange loc,
ast::Expr *rootOp,
3111 if (!isa<ast::OperationType>(rootType)) {
3114 llvm::formatv(
"expected `Op` expression, but got `{0}`", rootType));
3124 LogicalResult Parser::codeCompleteMemberAccess(
ast::Expr *parentExpr) {
3128 else if (
ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
3134 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3141 Parser::codeCompleteConstraintName(
ast::Type inferredType,
3142 bool allowInlineTypeConstraints) {
3144 inferredType, allowInlineTypeConstraints, curDeclScope);
3148 LogicalResult Parser::codeCompleteDialectName() {
3153 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3158 LogicalResult Parser::codeCompletePatternMetadata() {
3163 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3168 void Parser::codeCompleteCallSignature(
ast::Node *parent,
3169 unsigned currentNumArgs) {
3177 void Parser::codeCompleteOperationOperandsSignature(
3178 std::optional<StringRef> opName,
unsigned currentNumOperands) {
3180 opName, currentNumOperands);
3183 void Parser::codeCompleteOperationResultsSignature(
3184 std::optional<StringRef> opName,
unsigned currentNumResults) {
3193 FailureOr<ast::Module *>
3195 bool enableDocumentation,
3197 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3198 return parser.parseModule();
static std::string diag(const llvm::Value &value)
This class breaks up the current file into a token stream.
This represents a token in the MLIR syntax.
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
StringRef getSpelling() const
This class provides an abstract interface into the parser for hooking in code completion events.
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteOperationOperandsSignature(std::optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
virtual void codeCompleteOperationResultsSignature(std::optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
@ code_complete_string
Token signifying a code completion location within a string.
@ code_complete
Token signifying a code completion location.
@ less
Paired punctuation.
@ kw_Attr
General keywords.
static StringRef getMemberName()
Return the member name used for the "all-results" access.
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
static AttributeType get(Context &context)
Return an instance of the Attribute type.
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
This decl represents a shared interface for all callable decls.
Type getResultType() const
Return the result type of this decl.
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
StringRef getCallableType() const
Return the callable type of this decl.
This statement represents a compound statement, which contains a collection of other statements.
ArrayRef< Stmt * >::iterator begin() const
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
ArrayRef< Stmt * >::iterator end() const
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
This class represents the base of all AST Constraint decls.
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
This class represents the main context of the PDLL AST.
DiagnosticEngine & getDiagEngine()
Return the diagnostic engine of this context.
ods::Context & getODSContext()
Return the ODS context used by the AST.
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
This class represents a scope for named AST decls.
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
void add(Decl *decl)
Add a new decl to the scope.
This class represents the base Decl node.
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
InFlightDiagnostic emitWarning(SMRange loc, const Twine &msg)
InFlightDiagnostic emitError(SMRange loc, const Twine &msg)
Emit an error to the diagnostic engine.
This class provides a simple implementation of a PDLL diagnostic.
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
This class represents a base AST Expression node.
Type getType() const
Return the type of this expression.
This class represents a diagnostic that is inflight and set to be reported.
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
This Decl represents a NamedAttribute, and contains a string name and attribute value.
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
This class represents a base AST node.
SMRange getLoc() const
Return the location of this node.
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
This Decl represents an OperationName.
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
static OpNameDecl * create(Context &ctx, const Name &name)
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
This class represents a PDLL type that corresponds to an mlir::Operation.
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
This class represents a PDLL type that corresponds to a range of elements with a given element type.
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
This class represents a base AST Statement node.
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
This class represents a PDLL tuple type, i.e. an ordered set of element types with optional names.
size_t size() const
Return the number of elements within this tuple.
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
static TypeRangeType get(Context &context)
Return an instance of the TypeRange type.
static TypeType get(Context &context)
Return an instance of the Type type.
Type refineWith(Type other) const
Try to refine this type with the one provided.
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
This class represents a Value constraint, and constrains a variable to be a Value.
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
This class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
static ValueRangeType get(Context &context)
Return an instance of the ValueRange type.
static ValueType get(Context &context)
Return an instance of the Value type.
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
This class contains all of the registered ODS operation classes.
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation into the context.
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint into the context.
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint into the context.
const Operation * lookupOperation(StringRef name) const
Look up an operation registered with the given name, or return null if no operation with that name is registered.
This class provides an ODS representation of a specific operation operand or result.
This class provides an ODS representation of a specific operation.
SMRange getLoc() const
Return the source location of this operation.
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
bool hasResultTypeInferrence() const
Return true if the operation is known to support result type inference.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
This class represents a generic ODS Type constraint.
raw_ostream subclass that simplifies indenting a sequence of code.
raw_indented_ostream & printReindented(StringRef str, StringRef extraPrefix="")
Prints a string re-indented to the current indent.
StringRef getSummary() const
std::string getUniqueDefName() const
Return a unique name for the TableGen def of this constraint.
StringRef getDescription() const
std::string getConditionTemplate() const
Format context containing substitutions for special placeholders.
FmtContext & withSelf(Twine subst)
Wrapper class that contains an MLIR op's information (e.g., operands, attributes) defined in TableGen and provides helper methods for accessing them.
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::support::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class represents a reference to a constraint, and contains a constraint and the location of the reference.
const ConstraintDecl * constraint
This class provides a convenient API for interacting with source names.
StringRef getName() const
Return the raw string name.
SMRange getLoc() const
Get the location of this name.
static const Name & create(Context &ctx, StringRef name, SMRange location)