25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/ManagedStatic.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/ScopedPrinter.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Parser.h"
48 : ctx(ctx), lexer(sourceMgr, ctx.
getDiagEngine(), codeCompleteContext),
49 curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
54 codeCompleteContext(codeCompleteContext) {}
57 FailureOr<ast::Module *> parseModule();
65 enum class ParserContext {
82 enum class OpResultTypeContext {
101 return (curDeclScope = newScope);
103 void pushDeclScope(
ast::DeclScope *scope) { curDeclScope = scope; }
106 void popDeclScope() { curDeclScope = curDeclScope->
getParentScope(); }
115 LogicalResult convertExpressionTo(
122 LogicalResult convertTupleExpressionTo(
133 const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
139 StringRef processDoc(StringRef doc) {
140 return enableDocumentation ? doc : StringRef();
145 std::string processAndFormatDoc(
const Twine &doc) {
146 if (!enableDocumentation)
150 llvm::raw_string_ostream docOS(docStr);
152 StringRef(docStr).rtrim(
" \t"));
162 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
166 void processTdIncludeRecords(
const llvm::RecordKeeper &tdRecords,
171 template <
typename Constra
intT>
173 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
175 StringRef nativeType, StringRef docString);
176 template <
typename Constra
intT>
180 StringRef nativeType);
186 struct ParsedPatternMetadata {
187 std::optional<uint16_t> benefit;
188 bool hasBoundedRecursion =
false;
191 FailureOr<ast::Decl *> parseTopLevelDecl();
192 FailureOr<ast::NamedAttributeDecl *>
193 parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
197 FailureOr<ast::VariableDecl *> parseArgumentDecl();
201 FailureOr<ast::VariableDecl *> parseResultDecl(
unsigned resultNum);
205 FailureOr<ast::UserConstraintDecl *>
206 parseUserConstraintDecl(
bool isInline =
false);
210 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
214 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
221 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(
bool isInline =
false);
225 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
229 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
237 template <
typename T,
typename ParseUserPDLLDeclFnT>
238 FailureOr<T *> parseUserConstraintOrRewriteDecl(
239 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
240 StringRef anonymousNamePrefix,
bool isInline);
244 template <
typename T>
245 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
252 LogicalResult parseUserConstraintOrRewriteSignature(
259 LogicalResult validateUserConstraintOrRewriteReturn(
265 FailureOr<ast::CompoundStmt *>
267 bool expectTerminalSemicolon =
true);
268 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
269 FailureOr<ast::Decl *> parsePatternDecl();
270 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
274 LogicalResult checkDefineNamedDecl(
const ast::Name &name);
278 FailureOr<ast::VariableDecl *>
279 defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
282 FailureOr<ast::VariableDecl *>
283 defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
287 LogicalResult parseVariableDeclConstraintList(
291 FailureOr<ast::Expr *> parseTypeConstraintExpr();
299 FailureOr<ast::ConstraintRef>
300 parseConstraint(std::optional<SMRange> &typeConstraint,
302 bool allowInlineTypeConstraints);
307 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
312 FailureOr<ast::Expr *> parseExpr();
315 FailureOr<ast::Expr *> parseAttributeExpr();
316 FailureOr<ast::Expr *> parseCallExpr(
ast::Expr *parentExpr,
317 bool isNegated =
false);
318 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
319 FailureOr<ast::Expr *> parseIdentifierExpr();
320 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
321 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
322 FailureOr<ast::Expr *> parseMemberAccessExpr(
ast::Expr *parentExpr);
323 FailureOr<ast::Expr *> parseNegatedExpr();
324 FailureOr<ast::OpNameDecl *> parseOperationName(
bool allowEmptyName =
false);
325 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(
bool allowEmptyName);
326 FailureOr<ast::Expr *>
327 parseOperationExpr(OpResultTypeContext inputResultTypeContext =
328 OpResultTypeContext::Explicit);
329 FailureOr<ast::Expr *> parseTupleExpr();
330 FailureOr<ast::Expr *> parseTypeExpr();
331 FailureOr<ast::Expr *> parseUnderscoreExpr();
336 FailureOr<ast::Stmt *> parseStmt(
bool expectTerminalSemicolon =
true);
337 FailureOr<ast::CompoundStmt *> parseCompoundStmt();
338 FailureOr<ast::EraseStmt *> parseEraseStmt();
339 FailureOr<ast::LetStmt *> parseLetStmt();
340 FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
341 FailureOr<ast::ReturnStmt *> parseReturnStmt();
342 FailureOr<ast::RewriteStmt *> parseRewriteStmt();
357 FailureOr<ast::PatternDecl *>
358 createPatternDecl(SMRange loc,
const ast::Name *name,
359 const ParsedPatternMetadata &metadata,
368 template <
typename T>
369 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
376 FailureOr<ast::VariableDecl *>
377 createVariableDecl(StringRef name, SMRange loc,
ast::Expr *initializer,
382 FailureOr<ast::VariableDecl *>
383 createArgOrResultVariableDecl(StringRef name, SMRange loc,
400 LogicalResult validateTypeConstraintExpr(
const ast::Expr *typeExpr);
401 LogicalResult validateTypeRangeConstraintExpr(
const ast::Expr *typeExpr);
406 FailureOr<ast::CallExpr *>
407 createCallExpr(SMRange loc,
ast::Expr *parentExpr,
409 bool isNegated =
false);
410 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc,
ast::Decl *decl);
411 FailureOr<ast::DeclRefExpr *>
412 createInlineVariableExpr(
ast::Type type, StringRef name, SMRange loc,
414 FailureOr<ast::MemberAccessExpr *>
415 createMemberAccessExpr(
ast::Expr *parentExpr, StringRef name, SMRange loc);
419 FailureOr<ast::Type> validateMemberAccess(
ast::Expr *parentExpr,
420 StringRef name, SMRange loc);
421 FailureOr<ast::OperationExpr *>
423 OpResultTypeContext resultTypeContext,
428 validateOperationOperands(SMRange loc, std::optional<StringRef> name,
431 LogicalResult validateOperationResults(SMRange loc,
432 std::optional<StringRef> name,
435 void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
437 LogicalResult validateOperationOperandsOrResults(
438 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
442 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
449 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc,
ast::Expr *rootOp);
450 FailureOr<ast::ReplaceStmt *>
451 createReplaceStmt(SMRange loc,
ast::Expr *rootOp,
453 FailureOr<ast::RewriteStmt *>
454 createRewriteStmt(SMRange loc,
ast::Expr *rootOp,
465 LogicalResult codeCompleteMemberAccess(
ast::Expr *parentExpr);
466 LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
467 LogicalResult codeCompleteConstraintName(
ast::Type inferredType,
468 bool allowInlineTypeConstraints);
469 LogicalResult codeCompleteDialectName();
470 LogicalResult codeCompleteOperationName(StringRef dialectName);
471 LogicalResult codeCompletePatternMetadata();
472 LogicalResult codeCompleteIncludeFilename(StringRef curPath);
474 void codeCompleteCallSignature(
ast::Node *parent,
unsigned currentNumArgs);
475 void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
476 unsigned currentNumOperands);
477 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
478 unsigned currentNumResults);
487 if (curToken.isNot(
kind))
494 void consumeToken() {
496 "shouldn't advance past EOF or errors");
497 curToken = lexer.lexToken();
504 assert(curToken.is(
kind) &&
"consumed an unexpected token");
509 void resetToken(SMRange tokLoc) {
510 lexer.resetPointer(tokLoc.Start.getPointer());
511 curToken = lexer.lexToken();
517 if (curToken.getKind() !=
kind)
518 return emitError(curToken.getLoc(), msg);
522 LogicalResult
emitError(SMRange loc,
const Twine &msg) {
523 lexer.emitError(loc, msg);
526 LogicalResult
emitError(
const Twine &msg) {
527 return emitError(curToken.getLoc(), msg);
529 LogicalResult emitErrorAndNote(SMRange loc,
const Twine &msg, SMRange noteLoc,
531 lexer.emitErrorAndNote(loc, msg, noteLoc, note);
550 bool enableDocumentation;
554 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
557 ParserContext parserContext = ParserContext::Global;
565 unsigned anonymousDeclNameCounter = 0;
572 FailureOr<ast::Module *> Parser::parseModule() {
573 SMLoc moduleLoc = curToken.getStartLoc();
578 if (failed(parseModuleBody(decls)))
579 return popDeclScope(), failure();
588 if (failed(parseDirective(decls)))
593 FailureOr<ast::Decl *> decl = parseTopLevelDecl();
596 decls.push_back(*decl);
606 LogicalResult Parser::convertExpressionTo(
610 if (exprType == type)
615 expr->
getLoc(), llvm::formatv(
"unable to convert expression of type "
616 "`{0}` to the expected type of "
624 if (
auto exprOpType = dyn_cast<ast::OperationType>(exprType))
625 return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
631 if ((exprType == valueTy || exprType == valueRangeTy) &&
632 (type == valueTy || type == valueRangeTy))
634 if ((exprType == typeTy || exprType == typeRangeTy) &&
635 (type == typeTy || type == typeRangeTy))
639 if (
auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
640 return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
643 return emitConvertError();
646 LogicalResult Parser::convertOpExpressionTo(
651 if (
auto opType = dyn_cast<ast::OperationType>(type)) {
652 if (opType.getName())
653 return emitErrorFn();
658 if (type == valueRangeTy) {
665 if (type == valueTy) {
669 if (odsOp->getResults().empty()) {
670 return emitErrorFn()->attachNote(
671 llvm::formatv(
"see the definition of `{0}`, which was defined "
677 unsigned numSingleResults = llvm::count_if(
679 return result.getVariableLengthKind() ==
680 ods::VariableLengthKind::Single;
682 if (numSingleResults > 1) {
683 return emitErrorFn()->attachNote(
684 llvm::formatv(
"see the definition of `{0}`, which was defined "
685 "with at least {1} results",
686 odsOp->getName(), numSingleResults),
695 return emitErrorFn();
698 LogicalResult Parser::convertTupleExpressionTo(
703 if (
auto tupleType = dyn_cast<ast::TupleType>(type)) {
704 if (tupleType.size() != exprType.
size())
705 return emitErrorFn();
710 for (
unsigned i = 0, e = exprType.
size(); i < e; ++i) {
712 ctx, expr->
getLoc(), expr, llvm::to_string(i),
716 diag.attachNote(llvm::formatv(
"when converting element #{0} of `{1}`",
721 if (failed(convertExpressionTo(newExprs.back(),
722 tupleType.getElementTypes()[i], diagFn)))
726 tupleType.getElementNames());
734 if (parserContext != ParserContext::Rewrite) {
735 return emitErrorFn()->attachNote(
"Tuple to Range conversion is currently "
736 "only allowed within a rewrite context");
741 if (!llvm::is_contained(allowedElementTypes, elementType))
742 return emitErrorFn();
747 for (
unsigned i = 0, e = exprType.
size(); i < e; ++i) {
749 ctx, expr->
getLoc(), expr, llvm::to_string(i),
755 if (type == valueRangeTy)
756 return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
757 if (type == typeRangeTy)
758 return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
760 return emitErrorFn();
768 StringRef directive = curToken.getSpelling();
769 if (directive ==
"#include")
770 return parseInclude(decls);
772 return emitError(
"unknown directive `" + directive +
"`");
776 SMRange loc = curToken.getLoc();
781 return codeCompleteIncludeFilename(curToken.getStringValue());
784 if (!curToken.isString())
786 "expected string file name after `include` directive");
787 SMRange fileLoc = curToken.getLoc();
788 std::string filenameStr = curToken.getStringValue();
789 StringRef filename = filenameStr;
794 if (filename.ends_with(
".pdll")) {
795 if (failed(lexer.pushInclude(filename, fileLoc)))
797 "unable to open include file `" + filename +
"`");
802 curToken = lexer.lexToken();
803 LogicalResult result = parseModuleBody(decls);
804 curToken = lexer.lexToken();
809 if (filename.ends_with(
".td"))
810 return parseTdInclude(filename, fileLoc, decls);
813 "expected include filename to end with `.pdll` or `.td`");
816 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
818 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
821 std::string includedFile;
822 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
823 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
825 return emitError(fileLoc,
"unable to open include file `" + filename +
"`");
828 llvm::SourceMgr tdSrcMgr;
829 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
830 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
834 struct DiagHandlerContext {
838 } handlerContext{*
this, filename, fileLoc};
841 tdSrcMgr.setDiagHandler(
842 [](
const llvm::SMDiagnostic &
diag,
void *rawHandlerContext) {
843 auto *ctx =
reinterpret_cast<DiagHandlerContext *
>(rawHandlerContext);
844 (void)ctx->parser.emitError(
846 llvm::formatv(
"error while processing include file `{0}`: {1}",
847 ctx->filename,
diag.getMessage()));
852 llvm::RecordKeeper tdRecords;
853 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
857 processTdIncludeRecords(tdRecords, decls);
862 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
866 void Parser::processTdIncludeRecords(
const llvm::RecordKeeper &tdRecords,
869 auto getLengthKind = [](
const auto &value) {
870 if (value.isOptional())
881 cst.constraint.getUniqueDefName(),
882 processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
884 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
885 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
890 for (
const llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Op")) {
894 bool supportsResultTypeInferrence =
895 op.getTrait(
"::mlir::InferTypeOpInterface::Trait");
898 op.getOperationName(), processDoc(op.getSummary()),
899 processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
900 supportsResultTypeInferrence, op.getLoc().front());
907 odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
909 attr.attr.getUniqueDefName(),
910 processDoc(attr.attr.getSummary()),
911 attr.attr.getStorageType()));
914 odsOp->appendOperand(operand.name, getLengthKind(operand),
915 addTypeConstraint(operand));
918 odsOp->appendResult(result.name, getLengthKind(result),
919 addTypeConstraint(result));
923 auto shouldBeSkipped = [
this](
const llvm::Record *def) {
924 return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
925 def->isSubClassOf(
"DeclareInterfaceMethods");
929 for (
const llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Attr")) {
930 if (shouldBeSkipped(def))
934 decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
935 constraint, convertLocToRange(def->getLoc().front()), attrTy,
936 constraint.getStorageType()));
939 for (
const llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Type")) {
940 if (shouldBeSkipped(def))
944 decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
945 constraint, convertLocToRange(def->getLoc().front()), typeTy,
946 constraint.getCppType()));
950 for (
const llvm::Record *def :
951 tdRecords.getAllDerivedDefinitions(
"OpInterface")) {
952 if (shouldBeSkipped(def))
955 SMRange loc = convertLocToRange(def->getLoc().front());
957 std::string cppClassName =
958 llvm::formatv(
"{0}::{1}", def->getValueAsString(
"cppNamespace"),
959 def->getValueAsString(
"cppInterfaceName"))
961 std::string codeBlock =
962 llvm::formatv(
"return ::mlir::success(llvm::isa<{0}>(self));",
967 processAndFormatDoc(def->getValueAsString(
"description"));
968 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
969 def->getName(), codeBlock, loc, opTy, cppClassName, desc));
973 template <
typename Constra
intT>
974 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
975 StringRef name, StringRef codeBlock, SMRange loc,
ast::Type type,
976 StringRef nativeType, StringRef docString) {
982 argScope->
add(paramVar);
990 constraintDecl->setDocComment(ctx, docString);
991 curDeclScope->add(constraintDecl);
992 return constraintDecl;
995 template <
typename Constra
intT>
999 StringRef nativeType) {
1010 std::string docString;
1011 if (enableDocumentation) {
1013 docString = processAndFormatDoc(
1018 return createODSNativePDLLConstraintDecl<ConstraintT>(
1027 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1028 FailureOr<ast::Decl *> decl;
1029 switch (curToken.getKind()) {
1031 decl = parseUserConstraintDecl();
1034 decl = parsePatternDecl();
1037 decl = parseUserRewriteDecl();
1040 return emitError(
"expected top-level declaration, such as a `Pattern`");
1047 if (failed(checkDefineNamedDecl(*name)))
1049 curDeclScope->add(*decl);
1054 FailureOr<ast::NamedAttributeDecl *>
1055 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1058 return codeCompleteAttributeName(parentOpName);
1060 std::string attrNameStr;
1061 if (curToken.isString())
1062 attrNameStr = curToken.getStringValue();
1064 attrNameStr = curToken.getSpelling().str();
1066 return emitError(
"expected identifier or string attribute name");
1073 FailureOr<ast::Expr *> attrExpr = parseExpr();
1074 if (failed(attrExpr))
1076 attrValue = *attrExpr;
1086 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1088 bool expectTerminalSemicolon) {
1092 SMLoc bodyStartLoc = curToken.getStartLoc();
1094 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1095 bool failedToParse =
1096 failed(singleStatement) || failed(processStatementFn(*singleStatement));
1101 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1105 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1108 return emitError(
"expected identifier argument name");
1111 StringRef name = curToken.getSpelling();
1112 SMRange nameLoc = curToken.getLoc();
1116 parseToken(
Token::colon,
"expected `:` before argument constraint")))
1119 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1123 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1126 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(
unsigned resultNum) {
1134 StringRef name = curToken.getSpelling();
1135 SMRange nameLoc = curToken.getLoc();
1139 "expected `:` before result constraint")))
1142 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1146 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1152 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1156 return createArgOrResultVariableDecl(
"", cst->referenceLoc, *cst);
1159 FailureOr<ast::UserConstraintDecl *>
1160 Parser::parseUserConstraintDecl(
bool isInline) {
1163 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1164 [&](
auto &&...args) {
1165 return this->parseUserPDLLConstraintDecl(args...);
1167 ParserContext::Constraint,
"constraint", isInline);
1170 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1171 FailureOr<ast::UserConstraintDecl *> decl =
1172 parseUserConstraintDecl(
true);
1173 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1176 curDeclScope->add(*decl);
1180 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1186 pushDeclScope(argumentScope);
1192 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1193 [&](
ast::Stmt *&stmt) -> LogicalResult {
1194 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1197 "expected `Constraint` lambda body to contain a "
1198 "single expression");
1204 if (failed(bodyResult))
1208 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1209 if (failed(bodyResult))
1214 auto bodyIt = body->
begin(), bodyE = body->
end();
1215 for (; bodyIt != bodyE; ++bodyIt)
1216 if (isa<ast::ReturnStmt>(*bodyIt))
1218 if (failed(validateUserConstraintOrRewriteReturn(
1219 "Constraint", body, bodyIt, bodyE, results, resultType)))
1224 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1225 name, arguments, results, resultType, body);
1228 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(
bool isInline) {
1231 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1232 [&](
auto &&...args) {
return this->parseUserPDLLRewriteDecl(args...); },
1233 ParserContext::Rewrite,
"rewrite", isInline);
1236 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1237 FailureOr<ast::UserRewriteDecl *> decl =
1238 parseUserRewriteDecl(
true);
1239 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1242 curDeclScope->add(*decl);
1246 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1252 curDeclScope = argumentScope;
1255 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1256 [&](
ast::Stmt *&statement) -> LogicalResult {
1257 if (isa<ast::OpRewriteStmt>(statement))
1260 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1261 if (!statementExpr) {
1264 "expected `Rewrite` lambda body to contain a single expression "
1265 "or an operation rewrite statement; such as `erase`, "
1266 "`replace`, or `rewrite`");
1273 if (failed(bodyResult))
1277 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1278 if (failed(bodyResult))
1285 auto bodyIt = body->
begin(), bodyE = body->
end();
1286 for (; bodyIt != bodyE; ++bodyIt)
1287 if (isa<ast::ReturnStmt>(*bodyIt))
1289 if (failed(validateUserConstraintOrRewriteReturn(
"Rewrite", body, bodyIt,
1290 bodyE, results, resultType)))
1292 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1293 name, arguments, results, resultType, body);
1296 template <
typename T,
typename ParseUserPDLLDeclFnT>
1297 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1298 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1299 StringRef anonymousNamePrefix,
bool isInline) {
1300 SMRange loc = curToken.getLoc();
1302 llvm::SaveAndRestore saveCtx(parserContext, declContext);
1310 return emitError(
"expected identifier name");
1314 std::string anonName =
1315 llvm::formatv(
"<anonymous_{0}_{1}>", anonymousNamePrefix,
1316 anonymousDeclNameCounter++)
1329 if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1330 argumentScope, resultType)))
1336 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1340 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1341 results, resultType);
1344 template <
typename T>
1345 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1350 std::string codeStrStorage;
1351 std::optional<StringRef> optCodeStr;
1352 if (curToken.isString()) {
1353 codeStrStorage = curToken.getStringValue();
1354 optCodeStr = codeStrStorage;
1356 }
else if (isInline) {
1358 "external declarations must be declared in global scope");
1363 "expected `;` after native declaration")))
1365 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1368 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1373 if (failed(parseToken(
Token::l_paren,
"expected `(` to start argument list")))
1376 argumentScope = pushDeclScope();
1379 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1380 if (failed(argument))
1382 arguments.emplace_back(*argument);
1386 if (failed(parseToken(
Token::r_paren,
"expected `)` to end argument list")))
1392 auto parseResultFn = [&]() -> LogicalResult {
1393 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1396 results.emplace_back(*result);
1403 if (failed(parseResultFn()))
1406 if (failed(parseToken(
Token::r_paren,
"expected `)` to end result list")))
1410 }
else if (failed(parseResultFn())) {
1417 resultType = createUserConstraintRewriteResultType(results);
1420 if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1422 results.front()->getLoc(),
1423 "cannot create a single-element tuple with an element label");
1428 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1434 if (bodyIt != bodyE) {
1436 if (std::next(bodyIt) != bodyE) {
1438 (*std::next(bodyIt))->getLoc(),
1439 llvm::formatv(
"`return` terminated the `{0}` body, but found "
1440 "trailing statements afterwards",
1446 }
else if (!results.empty()) {
1448 {body->getLoc().End, body->getLoc().End},
1449 llvm::formatv(
"missing return in a `{0}` expected to return `{1}`",
1450 declType, resultType));
1455 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1456 return parseLambdaBody([&](
ast::Stmt *&statement) -> LogicalResult {
1457 if (isa<ast::OpRewriteStmt>(statement))
1461 "expected Pattern lambda body to contain a single operation "
1462 "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1466 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1467 SMRange loc = curToken.getLoc();
1469 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1479 ParsedPatternMetadata metadata;
1480 if (consumeIf(
Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1488 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1489 if (failed(bodyResult))
1494 return emitError(
"expected `{` or `=>` to start pattern body");
1495 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1496 if (failed(bodyResult))
1501 auto bodyIt = body->
begin(), bodyE = body->
end();
1502 for (; bodyIt != bodyE; ++bodyIt) {
1503 if (isa<ast::ReturnStmt>(*bodyIt)) {
1505 "`return` statements are only permitted within a "
1506 "`Constraint` or `Rewrite` body");
1509 if (isa<ast::OpRewriteStmt>(*bodyIt))
1512 if (bodyIt == bodyE) {
1514 "expected Pattern body to terminate with an operation "
1515 "rewrite statement, such as `erase`");
1517 if (std::next(bodyIt) != bodyE) {
1518 return emitError((*std::next(bodyIt))->getLoc(),
1519 "Pattern body was terminated by an operation "
1520 "rewrite statement, but found trailing statements");
1524 return createPatternDecl(loc, name, metadata, body);
1528 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1529 std::optional<SMRange> benefitLoc;
1530 std::optional<SMRange> hasBoundedRecursionLoc;
1535 return codeCompletePatternMetadata();
1538 return emitError(
"expected pattern metadata identifier");
1539 StringRef metadataStr = curToken.getSpelling();
1540 SMRange metadataLoc = curToken.getLoc();
1544 if (metadataStr ==
"benefit") {
1546 return emitErrorAndNote(metadataLoc,
1547 "pattern benefit has already been specified",
1548 *benefitLoc,
"see previous definition here");
1551 "expected `(` before pattern benefit")))
1554 uint16_t benefitValue = 0;
1556 return emitError(
"expected integral pattern benefit");
1557 if (curToken.getSpelling().getAsInteger(10, benefitValue))
1559 "expected pattern benefit to fit within a 16-bit integer");
1562 metadata.benefit = benefitValue;
1563 benefitLoc = metadataLoc;
1566 parseToken(
Token::r_paren,
"expected `)` after pattern benefit")))
1572 if (metadataStr ==
"recursion") {
1573 if (hasBoundedRecursionLoc) {
1574 return emitErrorAndNote(
1576 "pattern recursion metadata has already been specified",
1577 *hasBoundedRecursionLoc,
"see previous definition here");
1579 metadata.hasBoundedRecursion =
true;
1580 hasBoundedRecursionLoc = metadataLoc;
1584 return emitError(metadataLoc,
"unknown pattern metadata");
1590 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1593 FailureOr<ast::Expr *> typeExpr = parseExpr();
1594 if (failed(typeExpr) ||
1596 "expected `>` after variable type constraint")))
1601 LogicalResult Parser::checkDefineNamedDecl(
const ast::Name &name) {
1602 assert(curDeclScope &&
"defining decl outside of a decl scope");
1604 return emitErrorAndNote(
1605 name.
getLoc(),
"`" + name.
getName() +
"` has already been defined",
1606 lastDecl->getName()->getLoc(),
"see previous definition here");
1611 FailureOr<ast::VariableDecl *>
1612 Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
1615 assert(curDeclScope &&
"defining variable outside of decl scope");
1620 if (name.empty() || name ==
"_") {
1624 if (failed(checkDefineNamedDecl(nameDecl)))
1629 curDeclScope->add(varDecl);
1633 FailureOr<ast::VariableDecl *>
1634 Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
1636 return defineVariableDecl(name, nameLoc, type,
nullptr,
1640 LogicalResult Parser::parseVariableDeclConstraintList(
1642 std::optional<SMRange> typeConstraint;
1643 auto parseSingleConstraint = [&] {
1644 FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1645 typeConstraint, constraints,
true);
1646 if (failed(constraint))
1648 constraints.push_back(*constraint);
1654 return parseSingleConstraint();
1657 if (failed(parseSingleConstraint()))
1660 return parseToken(
Token::r_square,
"expected `]` after constraint list");
1663 FailureOr<ast::ConstraintRef>
1664 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1666 bool allowInlineTypeConstraints) {
1667 auto parseTypeConstraint = [&](
ast::Expr *&typeExpr) -> LogicalResult {
1668 if (!allowInlineTypeConstraints) {
1671 "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1672 "permitted on arguments or results");
1675 return emitErrorAndNote(
1677 "the type of this variable has already been constrained",
1678 *typeConstraint,
"see previous constraint location here");
1679 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1680 if (failed(constraintExpr))
1682 typeExpr = *constraintExpr;
1683 typeConstraint = typeExpr->getLoc();
1687 SMRange loc = curToken.getLoc();
1688 switch (curToken.getKind()) {
1694 if (curToken.is(
Token::less) && failed(parseTypeConstraint(typeExpr)))
1704 FailureOr<ast::OpNameDecl *> opName =
1705 parseWrappedOperationName(
true);
1724 if (curToken.is(
Token::less) && failed(parseTypeConstraint(typeExpr)))
1735 if (curToken.is(
Token::less) && failed(parseTypeConstraint(typeExpr)))
1744 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1750 StringRef constraintName = curToken.getSpelling();
1756 return emitError(loc,
"unknown reference to constraint `" +
1757 constraintName +
"`");
1761 if (
auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1764 return emitErrorAndNote(
1765 loc,
"invalid reference to non-constraint", cstDecl->
getLoc(),
1766 "see the definition of `" + constraintName +
"` here");
1772 if (failed(validateVariableConstraints(existingConstraints, inferredType)))
1775 return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1780 return emitError(loc,
"expected identifier constraint");
1783 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1784 std::optional<SMRange> typeConstraint;
1785 return parseConstraint(typeConstraint, std::nullopt,
1793 FailureOr<ast::Expr *> Parser::parseExpr() {
1795 return parseUnderscoreExpr();
1798 FailureOr<ast::Expr *> lhsExpr;
1799 switch (curToken.getKind()) {
1801 lhsExpr = parseAttributeExpr();
1804 lhsExpr = parseInlineConstraintLambdaExpr();
1807 lhsExpr = parseNegatedExpr();
1810 lhsExpr = parseIdentifierExpr();
1813 lhsExpr = parseOperationExpr();
1816 lhsExpr = parseInlineRewriteLambdaExpr();
1819 lhsExpr = parseTypeExpr();
1822 lhsExpr = parseTupleExpr();
1825 return emitError(
"expected expression");
1827 if (failed(lhsExpr))
1832 switch (curToken.getKind()) {
1834 lhsExpr = parseMemberAccessExpr(*lhsExpr);
1837 lhsExpr = parseCallExpr(*lhsExpr);
1842 if (failed(lhsExpr))
1847 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1848 SMRange loc = curToken.getLoc();
1855 return parseIdentifierExpr();
1858 if (!curToken.isString())
1859 return emitError(
"expected string literal containing MLIR attribute");
1860 std::string attrExpr = curToken.getStringValue();
1863 loc.End = curToken.getEndLoc();
1865 parseToken(
Token::greater,
"expected `>` after attribute literal")))
1870 FailureOr<ast::Expr *> Parser::parseCallExpr(
ast::Expr *parentExpr,
1880 codeCompleteCallSignature(parentExpr, arguments.size());
1884 FailureOr<ast::Expr *> argument = parseExpr();
1885 if (failed(argument))
1887 arguments.push_back(*argument);
1891 SMRange loc(parentExpr->
getLoc().Start, curToken.getEndLoc());
1892 if (failed(parseToken(
Token::r_paren,
"expected `)` after argument list")))
1895 return createCallExpr(loc, parentExpr, arguments, isNegated);
1898 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1899 ast::Decl *decl = curDeclScope->lookup(name);
1901 return emitError(loc,
"undefined reference to `" + name +
"`");
1903 return createDeclRefExpr(loc, decl);
1906 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1907 StringRef name = curToken.getSpelling();
1908 SMRange nameLoc = curToken.getLoc();
1915 if (failed(parseVariableDeclConstraintList(constraints)))
1918 if (failed(validateVariableConstraints(constraints, type)))
1920 return createInlineVariableExpr(type, name, nameLoc, constraints);
1923 return parseDeclRefExpr(name, nameLoc);
1926 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1927 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1935 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1936 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1944 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(
ast::Expr *parentExpr) {
1945 SMRange dotLoc = curToken.getLoc();
1950 return codeCompleteMemberAccess(parentExpr);
1953 Token memberNameTok = curToken;
1956 return emitError(dotLoc,
"expected identifier or numeric member name");
1957 StringRef memberName = memberNameTok.
getSpelling();
1958 SMRange loc(parentExpr->
getLoc().Start, curToken.getEndLoc());
1961 return createMemberAccessExpr(parentExpr, memberName, loc);
1964 FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1968 return emitError(
"expected native constraint");
1969 FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1970 if (failed(identifierExpr))
1973 return emitError(
"expected `(` after function name");
1974 return parseCallExpr(*identifierExpr,
true);
1977 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(
bool allowEmptyName) {
1978 SMRange loc = curToken.getLoc();
1982 return codeCompleteDialectName();
1988 return emitError(
"expected dialect namespace");
1990 StringRef name = curToken.getSpelling();
1994 if (failed(parseToken(
Token::dot,
"expected `.` after dialect namespace")))
1999 return codeCompleteOperationName(name);
2002 return emitError(
"expected operation name after dialect namespace");
2004 name = StringRef(name.data(), name.size() + 1);
2006 name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
2007 loc.End = curToken.getEndLoc();
2010 curToken.isKeyword());
2014 FailureOr<ast::OpNameDecl *>
2015 Parser::parseWrappedOperationName(
bool allowEmptyName) {
2019 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2020 if (failed(opNameDecl))
2023 if (failed(parseToken(
Token::greater,
"expected `>` after operation name")))
2028 FailureOr<ast::Expr *>
2029 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2030 SMRange loc = curToken.getLoc();
2037 return parseIdentifierExpr();
2043 bool allowEmptyName = parserContext != ParserContext::Rewrite;
2044 FailureOr<ast::OpNameDecl *> opNameDecl =
2045 parseWrappedOperationName(allowEmptyName);
2046 if (failed(opNameDecl))
2048 std::optional<StringRef> opName = (*opNameDecl)->getName();
2053 FailureOr<ast::VariableDecl *> rangeVar =
2055 assert(succeeded(rangeVar) &&
"expected range variable to be valid");
2066 if (parserContext != ParserContext::Rewrite) {
2067 operands.push_back(createImplicitRangeVar(
2075 codeCompleteOperationOperandsSignature(opName, operands.size());
2079 FailureOr<ast::Expr *> operand = parseExpr();
2080 if (failed(operand))
2082 operands.push_back(*operand);
2086 "expected `)` after operation operand list")))
2094 FailureOr<ast::NamedAttributeDecl *> decl =
2095 parseNamedAttributeDecl(opName);
2098 attributes.emplace_back(*decl);
2102 "expected `}` after operation attribute list")))
2108 OpResultTypeContext resultTypeContext = inputResultTypeContext;
2113 "expected `(` before operation result type list")))
2121 resultTypeContext = OpResultTypeContext::Explicit;
2128 codeCompleteOperationResultsSignature(opName, resultTypes.size());
2132 FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2133 if (failed(resultTypeExpr))
2135 resultTypes.push_back(*resultTypeExpr);
2139 "expected `)` after operation result type list")))
2142 }
else if (parserContext != ParserContext::Rewrite) {
2147 resultTypes.push_back(createImplicitRangeVar(
2149 }
else if (resultTypeContext == OpResultTypeContext::Explicit) {
2152 resultTypeContext = OpResultTypeContext::Interface;
2155 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2156 attributes, resultTypes);
2159 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2160 SMRange loc = curToken.getLoc();
2169 StringRef elementName;
2171 Token elementNameTok = curToken;
2179 auto elementNameIt =
2180 usedNames.try_emplace(elementName, elementNameTok.
getLoc());
2181 if (!elementNameIt.second) {
2182 return emitErrorAndNote(
2184 llvm::formatv(
"duplicate tuple element label `{0}`",
2186 elementNameIt.first->getSecond(),
2187 "see previous label use here");
2192 resetToken(elementNameTok.
getLoc());
2195 elementNames.push_back(elementName);
2198 FailureOr<ast::Expr *> element = parseExpr();
2199 if (failed(element))
2201 elements.push_back(*element);
2204 loc.End = curToken.getEndLoc();
2206 parseToken(
Token::r_paren,
"expected `)` after tuple element list")))
2208 return createTupleExpr(loc, elements, elementNames);
2211 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2212 SMRange loc = curToken.getLoc();
2219 return parseIdentifierExpr();
2222 if (!curToken.isString())
2223 return emitError(
"expected string literal containing MLIR type");
2224 std::string attrExpr = curToken.getStringValue();
2227 loc.End = curToken.getEndLoc();
2228 if (failed(parseToken(
Token::greater,
"expected `>` after type literal")))
2233 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2234 StringRef name = curToken.getSpelling();
2235 SMRange nameLoc = curToken.getLoc();
2239 if (failed(parseToken(
Token::colon,
"expected `:` after `_` variable")))
2244 if (failed(parseVariableDeclConstraintList(constraints)))
2248 if (failed(validateVariableConstraints(constraints, type)))
2250 return createInlineVariableExpr(type, name, nameLoc, constraints);
2257 FailureOr<ast::Stmt *> Parser::parseStmt(
bool expectTerminalSemicolon) {
2258 FailureOr<ast::Stmt *> stmt;
2259 switch (curToken.getKind()) {
2261 stmt = parseEraseStmt();
2264 stmt = parseLetStmt();
2267 stmt = parseReplaceStmt();
2270 stmt = parseReturnStmt();
2273 stmt = parseRewriteStmt();
2280 (expectTerminalSemicolon &&
2286 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2287 SMLoc startLoc = curToken.getStartLoc();
2294 FailureOr<ast::Stmt *> statement = parseStmt();
2295 if (failed(statement))
2296 return popDeclScope(), failure();
2297 statements.push_back(*statement);
2302 SMRange location(startLoc, curToken.getEndLoc());
2308 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2309 if (parserContext == ParserContext::Constraint)
2310 return emitError(
"`erase` cannot be used within a Constraint");
2311 SMRange loc = curToken.getLoc();
2315 FailureOr<ast::Expr *> rootOp = parseExpr();
2319 return createEraseStmt(loc, *rootOp);
2322 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2323 SMRange loc = curToken.getLoc();
2327 SMRange varLoc = curToken.getLoc();
2332 "`_` may only be used to define \"inline\" variables");
2335 "expected identifier after `let` to name a new variable");
2337 StringRef varName = curToken.getSpelling();
2343 failed(parseVariableDeclConstraintList(constraints)))
2349 FailureOr<ast::Expr *> initOrFailure = parseExpr();
2350 if (failed(initOrFailure))
2352 initializer = *initOrFailure;
2357 LogicalResult result =
2361 if (cst->getTypeExpr()) {
2363 constraint.referenceLoc,
2364 "type constraints are not permitted on variables with "
2369 .Default(success());
2375 FailureOr<ast::VariableDecl *> varDecl =
2376 createVariableDecl(varName, varLoc, initializer, constraints);
2377 if (failed(varDecl))
2382 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2383 if (parserContext == ParserContext::Constraint)
2384 return emitError(
"`replace` cannot be used within a Constraint");
2385 SMRange loc = curToken.getLoc();
2389 FailureOr<ast::Expr *> rootOp = parseExpr();
2394 parseToken(
Token::kw_with,
"expected `with` after root operation")))
2398 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2405 loc,
"expected at least one replacement value, consider using "
2406 "`erase` if no replacement values are desired");
2410 FailureOr<ast::Expr *> replExpr = parseExpr();
2411 if (failed(replExpr))
2413 replValues.emplace_back(*replExpr);
2417 "expected `)` after replacement values")))
2422 FailureOr<ast::Expr *> replExpr;
2424 replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2426 replExpr = parseExpr();
2427 if (failed(replExpr))
2429 replValues.emplace_back(*replExpr);
2432 return createReplaceStmt(loc, *rootOp, replValues);
2435 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2436 SMRange loc = curToken.getLoc();
2440 FailureOr<ast::Expr *> resultExpr = parseExpr();
2441 if (failed(resultExpr))
2447 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2448 if (parserContext == ParserContext::Constraint)
2449 return emitError(
"`rewrite` cannot be used within a Constraint");
2450 SMRange loc = curToken.getLoc();
2454 FailureOr<ast::Expr *> rootOp = parseExpr();
2458 if (failed(parseToken(
Token::kw_with,
"expected `with` before rewrite body")))
2462 return emitError(
"expected `{` to start rewrite body");
2465 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2467 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2468 if (failed(rewriteBody))
2472 for (
const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2473 if (isa<ast::ReturnStmt>(stmt)) {
2475 "`return` statements are only permitted within a "
2476 "`Constraint` or `Rewrite` body");
2480 return createRewriteStmt(loc, *rootOp, *rewriteBody);
2493 if (
auto *init = dyn_cast<ast::DeclRefExpr>(node))
2494 node = init->getDecl();
2495 return dyn_cast<ast::CallableDecl>(node);
2498 FailureOr<ast::PatternDecl *>
2499 Parser::createPatternDecl(SMRange loc,
const ast::Name *name,
2500 const ParsedPatternMetadata &metadata,
2503 metadata.hasBoundedRecursion, body);
2506 ast::Type Parser::createUserConstraintRewriteResultType(
2509 if (results.size() == 1)
2510 return results[0]->getType();
2514 auto resultTypes = llvm::map_range(
2515 results, [&](
const auto *result) {
return result->getType(); });
2516 auto resultNames = llvm::map_range(
2517 results, [&](
const auto *result) {
return result->getName().getName(); });
2519 llvm::to_vector(resultNames));
2522 template <
typename T>
2523 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2528 if (
auto *retStmt = dyn_cast<ast::ReturnStmt>(body->
getChildren().back())) {
2529 ast::Expr *resultExpr = retStmt->getResultExpr();
2534 if (results.empty())
2535 resultType = resultExpr->
getType();
2536 else if (failed(convertExpressionTo(resultExpr, resultType)))
2539 retStmt->setResultExpr(resultExpr);
2542 return T::createPDLL(ctx, name, arguments, results, body, resultType);
2545 FailureOr<ast::VariableDecl *>
2546 Parser::createVariableDecl(StringRef name, SMRange loc,
ast::Expr *initializer,
2551 if (failed(validateVariableConstraints(constraints, type)))
2558 type = initializer->
getType();
2561 else if (failed(convertExpressionTo(initializer, type)))
2567 return emitErrorAndNote(
2568 loc,
"unable to infer type for variable `" + name +
"`", loc,
2569 "the type of a variable must be inferable from the constraint "
2570 "list or the initializer");
2574 if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
2576 loc, llvm::formatv(
"unable to define variable of `{0}` type", type));
2580 FailureOr<ast::VariableDecl *> varDecl =
2581 defineVariableDecl(name, loc, type, initializer, constraints);
2582 if (failed(varDecl))
2588 FailureOr<ast::VariableDecl *>
2589 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2592 if (failed(validateVariableConstraint(constraint, argType)))
2594 return defineVariableDecl(name, loc, argType, constraint);
2601 if (failed(validateVariableConstraint(ref, inferredType)))
2609 if (
const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.
constraint)) {
2610 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2611 if (failed(validateTypeConstraintExpr(typeExpr)))
2615 }
else if (
const auto *cst =
2616 dyn_cast<ast::OpConstraintDecl>(ref.
constraint)) {
2619 }
else if (isa<ast::TypeConstraintDecl>(ref.
constraint)) {
2620 constraintType = typeTy;
2621 }
else if (isa<ast::TypeRangeConstraintDecl>(ref.
constraint)) {
2622 constraintType = typeRangeTy;
2623 }
else if (
const auto *cst =
2624 dyn_cast<ast::ValueConstraintDecl>(ref.
constraint)) {
2625 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2626 if (failed(validateTypeConstraintExpr(typeExpr)))
2629 constraintType = valueTy;
2630 }
else if (
const auto *cst =
2631 dyn_cast<ast::ValueRangeConstraintDecl>(ref.
constraint)) {
2632 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2633 if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2636 constraintType = valueRangeTy;
2637 }
else if (
const auto *cst =
2638 dyn_cast<ast::UserConstraintDecl>(ref.
constraint)) {
2640 if (inputs.size() != 1) {
2642 "`Constraint`s applied via a variable constraint "
2643 "list must take a single input, but got " +
2644 Twine(inputs.size()),
2646 "see definition of constraint here");
2648 constraintType = inputs.front()->getType();
2650 llvm_unreachable(
"unknown constraint type");
2655 if (!inferredType) {
2656 inferredType = constraintType;
2658 inferredType = mergedTy;
2661 llvm::formatv(
"constraint type `{0}` is incompatible "
2662 "with the previously inferred type `{1}`",
2663 constraintType, inferredType));
2668 LogicalResult Parser::validateTypeConstraintExpr(
const ast::Expr *typeExpr) {
2670 if (typeExprType != typeTy) {
2672 "expected expression of `Type` in type constraint");
2678 Parser::validateTypeRangeConstraintExpr(
const ast::Expr *typeExpr) {
2680 if (typeExprType != typeRangeTy) {
2682 "expected expression of `TypeRange` in type constraint");
2691 FailureOr<ast::CallExpr *>
2692 Parser::createCallExpr(SMRange loc,
ast::Expr *parentExpr,
2697 if (!callableDecl) {
2699 llvm::formatv(
"expected a reference to a callable "
2700 "`Constraint` or `Rewrite`, but got: `{0}`",
2703 if (parserContext == ParserContext::Rewrite) {
2704 if (isa<ast::UserConstraintDecl>(callableDecl))
2706 loc,
"unable to invoke `Constraint` within a rewrite section");
2708 return emitError(loc,
"unable to negate a Rewrite");
2710 if (isa<ast::UserRewriteDecl>(callableDecl))
2712 "unable to invoke `Rewrite` within a match section");
2713 if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2714 return emitError(loc,
"unable to negate non native constraints");
2720 if (callArgs.size() != arguments.size()) {
2721 return emitErrorAndNote(
2723 llvm::formatv(
"invalid number of arguments for {0} call; expected "
2728 llvm::formatv(
"see the definition of {0} here",
2734 diag.attachNote(llvm::formatv(
"see the definition of `{0}` here",
2738 for (
auto it : llvm::zip(callArgs, arguments)) {
2739 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->
getType(),
2748 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2752 if (isa<ast::ConstraintDecl>(decl))
2754 else if (isa<ast::UserRewriteDecl>(decl))
2756 else if (
auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2757 declType = varDecl->getType();
2759 return emitError(loc,
"invalid reference to `" +
2765 FailureOr<ast::DeclRefExpr *>
2766 Parser::createInlineVariableExpr(
ast::Type type, StringRef name, SMRange loc,
2768 FailureOr<ast::VariableDecl *> decl =
2769 defineVariableDecl(name, loc, type, constraints);
2775 FailureOr<ast::MemberAccessExpr *>
2776 Parser::createMemberAccessExpr(
ast::Expr *parentExpr, StringRef name,
2779 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2780 if (failed(memberType))
2786 FailureOr<ast::Type> Parser::validateMemberAccess(
ast::Expr *parentExpr,
2787 StringRef name, SMRange loc) {
2791 return valueRangeTy;
2795 auto results = odsOp->getResults();
2799 if (llvm::isDigit(name[0]) && !name.getAsInteger(10, index) &&
2800 index < results.size()) {
2801 return results[index].isVariadic() ? valueRangeTy : valueTy;
2805 const auto *it = llvm::find_if(results, [&](
const auto &result) {
2806 return result.getName() == name;
2808 if (it != results.end())
2809 return it->isVariadic() ? valueRangeTy : valueTy;
2810 }
else if (llvm::isDigit(name[0])) {
2815 }
else if (
auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
2818 if (llvm::isDigit(name[0]) && !name.getAsInteger(10, index) &&
2819 index < tupleType.size()) {
2820 return tupleType.getElementTypes()[index];
2824 auto elementNames = tupleType.getElementNames();
2825 const auto *it = llvm::find(elementNames, name);
2826 if (it != elementNames.end())
2827 return tupleType.getElementTypes()[it - elementNames.begin()];
2831 llvm::formatv(
"invalid member access `{0}` on expression of type `{1}`",
2835 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2837 OpResultTypeContext resultTypeContext,
2841 std::optional<StringRef> opNameRef = name->
getName();
2845 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2851 ast::Type attrType = attr->getValue()->getType();
2852 if (!isa<ast::AttributeType>(attrType)) {
2854 attr->getValue()->getLoc(),
2855 llvm::formatv(
"expected `Attr` expression, but got `{0}`", attrType));
2860 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2861 "unexpected inferrence when results were explicitly specified");
2865 if (resultTypeContext == OpResultTypeContext::Explicit) {
2866 if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2870 }
else if (resultTypeContext == OpResultTypeContext::Interface) {
2872 "expected valid operation name when inferring operation results");
2873 checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2881 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2884 return validateOperationOperandsOrResults(
2885 "operand", loc, odsOp ? odsOp->
getLoc() : std::optional<SMRange>(), name,
2886 operands, odsOp ? odsOp->
getOperands() : std::nullopt, valueTy,
2891 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2894 return validateOperationOperandsOrResults(
2895 "result", loc, odsOp ? odsOp->
getLoc() : std::optional<SMRange>(), name,
2896 results, odsOp ? odsOp->
getResults() : std::nullopt, typeTy, typeRangeTy);
2899 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2910 "operation result types are marked to be inferred, but "
2911 "`{0}` is unknown. Ensure that `{0}` supports zero "
2912 "results or implements `InferTypeOpInterface`. Include "
2913 "the ODS definition of this operation to remove this warning.",
2922 bool requiresInferrence =
2924 return !result.isVariableLength();
2929 llvm::formatv(
"operation result types are marked to be inferred, but "
2930 "`{0}` does not provide an implementation of "
2931 "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2932 "`InferTypeOpInterface` at runtime, or add support to "
2933 "the ODS definition to remove this warning.",
2935 diag->attachNote(llvm::formatv(
"see the definition of `{0}` here", opName),
2941 LogicalResult Parser::validateOperationOperandsOrResults(
2942 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2947 if (values.size() == 1) {
2948 if (failed(convertExpressionTo(values[0], rangeTy)))
2956 auto emitSizeMismatchError = [&] {
2957 return emitErrorAndNote(
2959 llvm::formatv(
"invalid number of {0} groups for `{1}`; expected "
2961 groupName, *name, odsValues.size(), values.size()),
2962 *odsOpLoc, llvm::formatv(
"see the definition of `{0}` here", *name));
2966 if (values.empty()) {
2968 if (odsValues.empty())
2973 unsigned numVariadic = 0;
2974 for (
const auto &odsValue : odsValues) {
2975 if (!odsValue.isVariableLength())
2976 return emitSizeMismatchError();
2982 if (parserContext != ParserContext::Rewrite)
2989 if (numVariadic == 1)
2994 for (
unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2996 ctx, loc, std::nullopt, rangeTy));
3003 if (odsValues.size() != values.size())
3004 return emitSizeMismatchError();
3007 diag.attachNote(llvm::formatv(
"see the definition of `{0}` here", *name),
3010 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
3011 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3012 if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
3021 ast::Type valueExprType = valueExpr->getType();
3024 if (valueExprType == rangeTy || valueExprType == singleTy)
3030 if (singleTy == valueTy) {
3031 if (isa<ast::OperationType>(valueExprType)) {
3032 valueExpr = convertOpToValue(valueExpr);
3038 if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3042 valueExpr->getLoc(),
3044 "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3045 singleTy, rangeTy, valueExprType));
3050 FailureOr<ast::TupleExpr *>
3053 for (
const ast::Expr *element : elements) {
3055 if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
3058 llvm::formatv(
"unable to build a tuple with `{0}` element", eleTy));
3068 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3072 if (!isa<ast::OperationType>(rootType))
3078 FailureOr<ast::ReplaceStmt *>
3079 Parser::createReplaceStmt(SMRange loc,
ast::Expr *rootOp,
3083 if (!isa<ast::OperationType>(rootType)) {
3086 llvm::formatv(
"expected `Op` expression, but got `{0}`", rootType));
3091 bool shouldConvertOpToValues = replValues.size() > 1;
3092 for (
ast::Expr *&replExpr : replValues) {
3093 ast::Type replType = replExpr->getType();
3096 if (isa<ast::OperationType>(replType)) {
3097 if (shouldConvertOpToValues)
3098 replExpr = convertOpToValue(replExpr);
3102 if (replType != valueTy && replType != valueRangeTy) {
3104 llvm::formatv(
"expected `Op`, `Value` or `ValueRange` "
3105 "expression, but got `{0}`",
3113 FailureOr<ast::RewriteStmt *>
3114 Parser::createRewriteStmt(SMRange loc,
ast::Expr *rootOp,
3118 if (!isa<ast::OperationType>(rootType)) {
3121 llvm::formatv(
"expected `Op` expression, but got `{0}`", rootType));
3131 LogicalResult Parser::codeCompleteMemberAccess(
ast::Expr *parentExpr) {
3135 else if (
ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
3141 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3148 Parser::codeCompleteConstraintName(
ast::Type inferredType,
3149 bool allowInlineTypeConstraints) {
3151 inferredType, allowInlineTypeConstraints, curDeclScope);
3155 LogicalResult Parser::codeCompleteDialectName() {
3160 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3165 LogicalResult Parser::codeCompletePatternMetadata() {
3170 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3175 void Parser::codeCompleteCallSignature(
ast::Node *parent,
3176 unsigned currentNumArgs) {
3184 void Parser::codeCompleteOperationOperandsSignature(
3185 std::optional<StringRef> opName,
unsigned currentNumOperands) {
3187 opName, currentNumOperands);
3190 void Parser::codeCompleteOperationResultsSignature(
3191 std::optional<StringRef> opName,
unsigned currentNumResults) {
3200 FailureOr<ast::Module *>
3202 bool enableDocumentation,
3204 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3205 return parser.parseModule();
union mlir::linalg::@1196::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
This class breaks up the current file into a token stream.
This represents a token in the MLIR syntax.
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
StringRef getSpelling() const
This class provides an abstract interface into the parser for hooking in code completion events.
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for an attribute name of the given operation.
virtual void codeCompleteOperationOperandsSignature(std::optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
virtual void codeCompleteOperationResultsSignature(std::optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
@ code_complete_string
Token signifying a code completion location within a string.
@ code_complete
Token signifying a code completion location.
@ less
Paired punctuation.
@ kw_Attr
General keywords.
static StringRef getMemberName()
Return the member name used for the "all-results" access.
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
This class represents an Attribute constraint, and constrains a variable to be an Attribute.
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
static AttributeType get(Context &context)
Return an instance of the Attribute type.
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
This decl represents a shared interface for all callable decls.
Type getResultType() const
Return the result type of this decl.
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
StringRef getCallableType() const
Return the callable type of this decl.
This statement represents a compound statement, which contains a collection of other statements.
ArrayRef< Stmt * >::iterator begin() const
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
ArrayRef< Stmt * >::iterator end() const
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
This class represents the base of all AST Constraint decls.
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
This class represents the main context of the PDLL AST.
DiagnosticEngine & getDiagEngine()
Return the diagnostic engine of this context.
ods::Context & getODSContext()
Return the ODS context used by the AST.
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
This class represents a scope for named AST decls.
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
void add(Decl *decl)
Add a new decl to the scope.
This class represents the base Decl node.
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
InFlightDiagnostic emitWarning(SMRange loc, const Twine &msg)
InFlightDiagnostic emitError(SMRange loc, const Twine &msg)
Emit an error to the diagnostic engine.
This class provides a simple implementation of a PDLL diagnostic.
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
This class represents a base AST Expression node.
Type getType() const
Return the type of this expression.
This class represents a diagnostic that is inflight and set to be reported.
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
This Decl represents a NamedAttribute, and contains a string name and attribute value.
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
This class represents a base AST node.
SMRange getLoc() const
Return the location of this node.
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
This Decl represents an OperationName.
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
static OpNameDecl * create(Context &ctx, const Name &name)
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
This class represents a PDLL type that corresponds to an mlir::Operation.
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
This class represents a PDLL type that corresponds to a range of elements with a given element type.
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
This class represents a base AST Statement node.
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
This class represents a PDLL tuple type, i.e.
size_t size() const
Return the number of elements within this tuple.
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
static TypeRangeType get(Context &context)
Return an instance of the TypeRange type.
static TypeType get(Context &context)
Return an instance of the Type type.
Type refineWith(Type other) const
Try to refine this type with the one provided.
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
The class represents a Value constraint, and constrains a variable to be a Value.
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
static ValueRangeType get(Context &context)
Return an instance of the ValueRange type.
static ValueType get(Context &context)
Return an instance of the Value type.
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
This class contains all of the registered ODS operation classes.
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation with the context.
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint with the context.
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint with the context.
const Operation * lookupOperation(StringRef name) const
Lookup an operation registered with the given name, or null if no operation with that name is registe...
This class provides an ODS representation of a specific operation operand or result.
This class provides an ODS representation of a specific operation.
SMRange getLoc() const
Return the source location of this operation.
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
bool hasResultTypeInferrence() const
Return if the operation is known to support result type inferrence.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
This class represents a generic ODS Type constraint.
raw_ostream subclass that simplifies indention a sequence of code.
raw_indented_ostream & printReindented(StringRef str, StringRef extraPrefix="")
Prints a string re-indented to the current indent.
StringRef getSummary() const
std::string getUniqueDefName() const
Returns a unique name for the TablGen def of this constraint.
StringRef getDescription() const
std::string getConditionTemplate() const
Format context containing substitutions for special placeholders.
FmtContext & withSelf(Twine subst)
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::support::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class represents a reference to a constraint, and contains a constraint and the location of the ...
const ConstraintDecl * constraint
This class provides a convenient API for interacting with source names.
StringRef getName() const
Return the raw string name.
SMRange getLoc() const
Get the location of this name.
static const Name & create(Context &ctx, StringRef name, SMRange location)