25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/ManagedStatic.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/ScopedPrinter.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Parser.h"
48 : ctx(ctx), lexer(sourceMgr, ctx.
getDiagEngine(), codeCompleteContext),
49 curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
54 codeCompleteContext(codeCompleteContext) {}
57 FailureOr<ast::Module *> parseModule();
65 enum class ParserContext {
82 enum class OpResultTypeContext {
101 return (curDeclScope = newScope);
/// Makes `scope` the active declaration scope by storing it in
/// `curDeclScope`.
103 void pushDeclScope(
ast::DeclScope *scope) { curDeclScope = scope; }
/// Leaves the active declaration scope by replacing `curDeclScope` with its
/// parent scope. `curDeclScope` is dereferenced without a null check, so a
/// scope must be active when this is called.
106 void popDeclScope() { curDeclScope = curDeclScope->
getParentScope(); }
115 LogicalResult convertExpressionTo(
122 LogicalResult convertTupleExpressionTo(
133 const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
/// Returns `doc` unchanged when `enableDocumentation` is set, and an empty
/// StringRef otherwise.
139 StringRef processDoc(StringRef doc) {
140 return enableDocumentation ? doc : StringRef();
/// Produces a formatted documentation string from `doc`. The
/// `enableDocumentation` flag is tested first; the `doc` twine is then
/// materialized into `tmpDocStr`, stripped of trailing spaces and tabs, and
/// output is written through `docOS`, a string stream over `docStr`.
/// NOTE(review): the statement guarded by the flag test, the declaration of
/// `docStr`, the callee that receives the trimmed text, and the return
/// statement are not visible in this excerpt.
145 std::string processAndFormatDoc(
const Twine &doc) {
146 if (!enableDocumentation)
150 llvm::raw_string_ostream docOS(docStr);
151 std::string tmpDocStr = doc.str();
153 StringRef(tmpDocStr).rtrim(
" \t"));
163 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
167 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
172 template <
typename Constra
intT>
174 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
176 StringRef nativeType, StringRef docString);
177 template <
typename Constra
intT>
181 StringRef nativeType);
/// Metadata gathered for a pattern declaration; filled in by
/// parsePatternDeclMetadata and handed to createPatternDecl.
187 struct ParsedPatternMetadata {
/// Pattern benefit; stays empty unless a `benefit` entry is parsed.
188 std::optional<uint16_t> benefit;
/// Set to true when a `recursion` entry is parsed; false by default.
189 bool hasBoundedRecursion =
false;
192 FailureOr<ast::Decl *> parseTopLevelDecl();
193 FailureOr<ast::NamedAttributeDecl *>
194 parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
198 FailureOr<ast::VariableDecl *> parseArgumentDecl();
202 FailureOr<ast::VariableDecl *> parseResultDecl(
unsigned resultNum);
206 FailureOr<ast::UserConstraintDecl *>
207 parseUserConstraintDecl(
bool isInline =
false);
211 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
215 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
222 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(
bool isInline =
false);
226 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
230 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
238 template <
typename T,
typename ParseUserPDLLDeclFnT>
239 FailureOr<T *> parseUserConstraintOrRewriteDecl(
240 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
241 StringRef anonymousNamePrefix,
bool isInline);
245 template <
typename T>
246 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
253 LogicalResult parseUserConstraintOrRewriteSignature(
260 LogicalResult validateUserConstraintOrRewriteReturn(
266 FailureOr<ast::CompoundStmt *>
268 bool expectTerminalSemicolon =
true);
269 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
270 FailureOr<ast::Decl *> parsePatternDecl();
271 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
275 LogicalResult checkDefineNamedDecl(
const ast::Name &name);
279 FailureOr<ast::VariableDecl *>
280 defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
283 FailureOr<ast::VariableDecl *>
284 defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
288 LogicalResult parseVariableDeclConstraintList(
292 FailureOr<ast::Expr *> parseTypeConstraintExpr();
300 FailureOr<ast::ConstraintRef>
301 parseConstraint(std::optional<SMRange> &typeConstraint,
303 bool allowInlineTypeConstraints);
308 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
313 FailureOr<ast::Expr *> parseExpr();
316 FailureOr<ast::Expr *> parseAttributeExpr();
317 FailureOr<ast::Expr *> parseCallExpr(
ast::Expr *parentExpr,
318 bool isNegated =
false);
319 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
320 FailureOr<ast::Expr *> parseIdentifierExpr();
321 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
322 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
323 FailureOr<ast::Expr *> parseMemberAccessExpr(
ast::Expr *parentExpr);
324 FailureOr<ast::Expr *> parseNegatedExpr();
325 FailureOr<ast::OpNameDecl *> parseOperationName(
bool allowEmptyName =
false);
326 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(
bool allowEmptyName);
327 FailureOr<ast::Expr *>
328 parseOperationExpr(OpResultTypeContext inputResultTypeContext =
329 OpResultTypeContext::Explicit);
330 FailureOr<ast::Expr *> parseTupleExpr();
331 FailureOr<ast::Expr *> parseTypeExpr();
332 FailureOr<ast::Expr *> parseUnderscoreExpr();
337 FailureOr<ast::Stmt *> parseStmt(
bool expectTerminalSemicolon =
true);
338 FailureOr<ast::CompoundStmt *> parseCompoundStmt();
339 FailureOr<ast::EraseStmt *> parseEraseStmt();
340 FailureOr<ast::LetStmt *> parseLetStmt();
341 FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
342 FailureOr<ast::ReturnStmt *> parseReturnStmt();
343 FailureOr<ast::RewriteStmt *> parseRewriteStmt();
358 FailureOr<ast::PatternDecl *>
359 createPatternDecl(SMRange loc,
const ast::Name *name,
360 const ParsedPatternMetadata &metadata,
369 template <
typename T>
370 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
377 FailureOr<ast::VariableDecl *>
378 createVariableDecl(StringRef name, SMRange loc,
ast::Expr *initializer,
383 FailureOr<ast::VariableDecl *>
384 createArgOrResultVariableDecl(StringRef name, SMRange loc,
401 LogicalResult validateTypeConstraintExpr(
const ast::Expr *typeExpr);
402 LogicalResult validateTypeRangeConstraintExpr(
const ast::Expr *typeExpr);
407 FailureOr<ast::CallExpr *>
408 createCallExpr(SMRange loc,
ast::Expr *parentExpr,
410 bool isNegated =
false);
411 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc,
ast::Decl *decl);
412 FailureOr<ast::DeclRefExpr *>
413 createInlineVariableExpr(
ast::Type type, StringRef name, SMRange loc,
415 FailureOr<ast::MemberAccessExpr *>
416 createMemberAccessExpr(
ast::Expr *parentExpr, StringRef name, SMRange loc);
420 FailureOr<ast::Type> validateMemberAccess(
ast::Expr *parentExpr,
421 StringRef name, SMRange loc);
422 FailureOr<ast::OperationExpr *>
424 OpResultTypeContext resultTypeContext,
429 validateOperationOperands(SMRange loc, std::optional<StringRef> name,
432 LogicalResult validateOperationResults(SMRange loc,
433 std::optional<StringRef> name,
436 void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
438 LogicalResult validateOperationOperandsOrResults(
439 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
443 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
450 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc,
ast::Expr *rootOp);
451 FailureOr<ast::ReplaceStmt *>
452 createReplaceStmt(SMRange loc,
ast::Expr *rootOp,
454 FailureOr<ast::RewriteStmt *>
455 createRewriteStmt(SMRange loc,
ast::Expr *rootOp,
466 LogicalResult codeCompleteMemberAccess(
ast::Expr *parentExpr);
467 LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
468 LogicalResult codeCompleteConstraintName(
ast::Type inferredType,
469 bool allowInlineTypeConstraints);
470 LogicalResult codeCompleteDialectName();
471 LogicalResult codeCompleteOperationName(StringRef dialectName);
472 LogicalResult codeCompletePatternMetadata();
473 LogicalResult codeCompleteIncludeFilename(StringRef curPath);
475 void codeCompleteCallSignature(
ast::Node *parent,
unsigned currentNumArgs);
476 void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
477 unsigned currentNumOperands);
478 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
479 unsigned currentNumResults);
488 if (curToken.isNot(kind))
/// Advances the parser by lexing the next token into `curToken`. The string
/// below is the message of a guard against advancing past EOF or error tokens.
/// NOTE(review): the guard's condition is not visible in this excerpt.
495 void consumeToken() {
497 "shouldn't advance past EOF or errors");
498 curToken = lexer.lexToken();
505 assert(curToken.is(kind) &&
"consumed an unexpected token");
/// Repositions the lexer at the start of `tokLoc` and lexes again, so that
/// `curToken` becomes the token beginning at that location.
510 void resetToken(SMRange tokLoc) {
511 lexer.resetPointer(tokLoc.Start.getPointer());
512 curToken = lexer.lexToken();
/// Checks that the current token has kind `kind`. On a mismatch, `msg` is
/// reported at the current token's location and that failure is returned.
/// NOTE(review): the matching path (presumably consuming the token and
/// returning success) is not visible in this excerpt.
517 LogicalResult parseToken(
Token::Kind kind,
const Twine &msg) {
518 if (curToken.getKind() != kind)
519 return emitError(curToken.getLoc(), msg);
/// Reports the error `msg` at source range `loc` by forwarding to the lexer.
/// NOTE(review): the return statement is not visible in this excerpt.
523 LogicalResult
emitError(SMRange loc,
const Twine &msg) {
524 lexer.emitError(loc, msg);
/// Convenience overload: reports `msg` at the location of the current token
/// and returns the result of the ranged overload.
527 LogicalResult
emitError(
const Twine &msg) {
528 return emitError(curToken.getLoc(), msg);
/// Reports the error `msg` at `loc` together with a note anchored at
/// `noteLoc`, forwarding all four values to the lexer.
/// NOTE(review): the declaration of the `note` parameter and the return
/// statement are not visible in this excerpt.
530 LogicalResult emitErrorAndNote(SMRange loc,
const Twine &msg, SMRange noteLoc,
532 lexer.emitErrorAndNote(loc, msg, noteLoc, note);
551 bool enableDocumentation;
555 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
558 ParserContext parserContext = ParserContext::Global;
566 unsigned anonymousDeclNameCounter = 0;
/// Parses a module. The start location of the current token is recorded as
/// the module location, then the module body is parsed into `decls`. If the
/// body fails to parse, the declaration scope is popped before failure is
/// returned (the comma expression performs both in one statement).
/// NOTE(review): the scope push, the declaration of `decls`, and the success
/// path are not visible in this excerpt.
573 FailureOr<ast::Module *> Parser::parseModule() {
574 SMLoc moduleLoc = curToken.getStartLoc();
579 if (failed(parseModuleBody(decls)))
580 return popDeclScope(), failure();
589 if (failed(parseDirective(decls)))
594 FailureOr<ast::Decl *> decl = parseTopLevelDecl();
597 decls.push_back(*decl);
/// Tries to convert an expression to `type`. Visible checks, in order:
///  - the expression type already equals `type`;
///  - operation-typed expressions are delegated to convertOpExpressionTo;
///  - Value/ValueRange and Type/TypeRange pairs are tested as related kinds;
///  - tuple-typed expressions are delegated to convertTupleExpressionTo.
/// Anything else ends in emitConvertError(), whose message begins
/// "unable to convert expression of type".
/// NOTE(review): the parameter list and the statements guarded by the
/// equality and range checks are not visible in this excerpt.
607 LogicalResult Parser::convertExpressionTo(
611 if (exprType == type)
616 expr->
getLoc(), llvm::formatv(
"unable to convert expression of type "
617 "`{0}` to the expected type of "
625 if (
auto exprOpType = dyn_cast<ast::OperationType>(exprType))
626 return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
632 if ((exprType == valueTy || exprType == valueRangeTy) &&
633 (type == valueTy || type == valueRangeTy))
635 if ((exprType == typeTy || exprType == typeRangeTy) &&
636 (type == typeTy || type == typeRangeTy))
640 if (
auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
641 return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
644 return emitConvertError();
/// Converts an operation-typed expression to `type`.
///  - Operation target: rejected through emitErrorFn() when the target
///    operation type carries a name.
///  - `valueRangeTy` and `valueTy` targets each have a dedicated branch.
///  - An error with an attached note is produced when the ODS operation has
///    no results, and when more than one of its results has the `Single`
///    variable-length kind. NOTE(review): these checks appear to belong to
///    the `valueTy` branch; the enclosing braces are not visible here.
///  - Every other target falls through to emitErrorFn().
647 LogicalResult Parser::convertOpExpressionTo(
652 if (
auto opType = dyn_cast<ast::OperationType>(type)) {
653 if (opType.getName())
654 return emitErrorFn();
659 if (type == valueRangeTy) {
666 if (type == valueTy) {
670 if (odsOp->getResults().empty()) {
671 return emitErrorFn()->attachNote(
672 llvm::formatv(
"see the definition of `{0}`, which was defined "
678 unsigned numSingleResults = llvm::count_if(
680 return result.getVariableLengthKind() ==
681 ods::VariableLengthKind::Single;
683 if (numSingleResults > 1) {
684 return emitErrorFn()->attachNote(
685 llvm::formatv(
"see the definition of `{0}`, which was defined "
686 "with at least {1} results",
687 odsOp->getName(), numSingleResults),
696 return emitErrorFn();
/// Converts a tuple-typed expression to `type`.
///  - Tuple target: the element counts must match, otherwise emitErrorFn().
///    Each element is then addressed by its index (rendered with
///    llvm::to_string) and converted to the corresponding target element
///    type through convertExpressionTo; a note naming the element index is
///    attached to diagnostics raised during that step.
///  - Range target: rejected outside of ParserContext::Rewrite, and rejected
///    when an element type is not in the allowed list. A `valueRangeTy`
///    target accepts `valueTy` and `valueRangeTy` elements; a `typeRangeTy`
///    target accepts `typeTy` and `typeRangeTy` elements.
///  - Any other target ends in emitErrorFn().
/// NOTE(review): the parameter list, the construction of the per-element
/// expressions, and the definition of `convertToRange` are only partly
/// visible in this excerpt.
699 LogicalResult Parser::convertTupleExpressionTo(
704 if (
auto tupleType = dyn_cast<ast::TupleType>(type)) {
705 if (tupleType.size() != exprType.
size())
706 return emitErrorFn();
711 for (
unsigned i = 0, e = exprType.
size(); i < e; ++i) {
713 ctx, expr->
getLoc(), expr, llvm::to_string(i),
717 diag.attachNote(llvm::formatv(
"when converting element #{0} of `{1}`",
722 if (failed(convertExpressionTo(newExprs.back(),
723 tupleType.getElementTypes()[i], diagFn)))
727 tupleType.getElementNames());
735 if (parserContext != ParserContext::Rewrite) {
736 return emitErrorFn()->attachNote(
"Tuple to Range conversion is currently "
737 "only allowed within a rewrite context");
742 if (!llvm::is_contained(allowedElementTypes, elementType))
743 return emitErrorFn();
748 for (
unsigned i = 0, e = exprType.
size(); i < e; ++i) {
750 ctx, expr->
getLoc(), expr, llvm::to_string(i),
756 if (type == valueRangeTy)
757 return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
758 if (type == typeRangeTy)
759 return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
761 return emitErrorFn();
768 StringRef directive = curToken.getSpelling();
769 if (directive ==
"#include")
770 return parseInclude(decls);
772 return emitError(
"unknown directive `" + directive +
"`");
776 SMRange loc = curToken.getLoc();
781 return codeCompleteIncludeFilename(curToken.getStringValue());
784 if (!curToken.isString())
786 "expected string file name after `include` directive");
787 SMRange fileLoc = curToken.getLoc();
788 std::string filenameStr = curToken.getStringValue();
789 StringRef filename = filenameStr;
794 if (filename.ends_with(
".pdll")) {
795 if (failed(lexer.pushInclude(filename, fileLoc)))
797 "unable to open include file `" + filename +
"`");
802 curToken = lexer.lexToken();
803 LogicalResult result = parseModuleBody(decls);
804 curToken = lexer.lexToken();
809 if (filename.ends_with(
".td"))
810 return parseTdInclude(filename, fileLoc, decls);
813 "expected include filename to end with `.pdll` or `.td`");
/// Processes an included TableGen (`.td`) file.
/// Steps visible here:
///  1. The file is opened through the parser's own source manager; a failure
///     is reported at `fileLoc` as "unable to open include file".
///  2. A separate source manager (`tdSrcMgr`) receives the buffer and the
///     same include directories.
///  3. A diagnostic handler is installed on `tdSrcMgr` that re-emits each
///     TableGen diagnostic through this parser, prefixed with
///     "error while processing include file". The handler is a captureless
///     lambda, so the parser, filename and location travel through the
///     `handlerContext` object passed as the opaque pointer.
///  4. The file is parsed with llvm::TableGenParseFile into `tdRecords`,
///     which are then handed to processTdIncludeRecords.
///  5. The buffers of `tdSrcMgr` are transferred into the parser's source
///     manager, anchored at `fileLoc.End`.
/// NOTE(review): the `decls` parameter, the members of DiagHandlerContext,
/// and the return statements are not visible in this excerpt.
816 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
818 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
821 std::string includedFile;
822 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
823 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
825 return emitError(fileLoc,
"unable to open include file `" + filename +
"`");
828 llvm::SourceMgr tdSrcMgr;
829 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
830 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
834 struct DiagHandlerContext {
838 } handlerContext{*
this, filename, fileLoc};
841 tdSrcMgr.setDiagHandler(
842 [](
const llvm::SMDiagnostic &
diag,
void *rawHandlerContext) {
843 auto *ctx =
reinterpret_cast<DiagHandlerContext *
>(rawHandlerContext);
844 (void)ctx->parser.emitError(
846 llvm::formatv(
"error while processing include file `{0}`: {1}",
847 ctx->filename,
diag.getMessage()));
852 llvm::RecordKeeper tdRecords;
853 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
857 processTdIncludeRecords(tdRecords, decls);
862 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
/// Imports the contents of parsed TableGen records.
///  - Every definition derived from "Op" is registered with its operation
///    name, summary, formatted description, qualified C++ class name,
///    whether it has the InferTypeOpInterface trait, and its first source
///    location; its attributes, operands and results are appended to the
///    resulting ODS operation.
///  - Definitions derived from "Attr", "Type" and "OpInterface" become
///    native PDLL constraint declarations pushed onto `decls`. A definition
///    is skipped when it is anonymous, when its name is already found in
///    `curDeclScope`, or when it subclasses "DeclareInterfaceMethods".
///  - For "OpInterface" the C++ class name is built from the record's
///    `cppNamespace` and `cppInterfaceName` values, and the generated
///    constraint body is an `llvm::isa` test on `self`.
/// `convertLocToRange` widens a single location into a range that is one
/// character long.
/// NOTE(review): the `decls` parameter, the `addTypeConstraint` helper and
/// several call heads are not visible in this excerpt.
866 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
869 auto getLengthKind = [](
const auto &value) {
870 if (value.isOptional())
881 cst.constraint.getUniqueDefName(),
882 processDoc(cst.constraint.getSummary()),
883 cst.constraint.getCPPClassName());
885 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
886 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
891 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Op")) {
895 bool supportsResultTypeInferrence =
896 op.getTrait(
"::mlir::InferTypeOpInterface::Trait");
899 op.getOperationName(), processDoc(op.getSummary()),
900 processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
901 supportsResultTypeInferrence, op.
getLoc().front());
908 odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
910 attr.attr.getUniqueDefName(),
911 processDoc(attr.attr.getSummary()),
912 attr.attr.getStorageType()));
915 odsOp->appendOperand(operand.name, getLengthKind(operand),
916 addTypeConstraint(operand));
919 odsOp->appendResult(result.name, getLengthKind(result),
920 addTypeConstraint(result));
924 auto shouldBeSkipped = [
this](llvm::Record *def) {
925 return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
926 def->isSubClassOf(
"DeclareInterfaceMethods");
930 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Attr")) {
931 if (shouldBeSkipped(def))
935 decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
936 constraint, convertLocToRange(def->getLoc().front()), attrTy,
937 constraint.getStorageType()));
940 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Type")) {
941 if (shouldBeSkipped(def))
945 decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
946 constraint, convertLocToRange(def->getLoc().front()), typeTy,
947 constraint.getCPPClassName()));
951 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"OpInterface")) {
952 if (shouldBeSkipped(def))
955 SMRange loc = convertLocToRange(def->getLoc().front());
957 std::string cppClassName =
958 llvm::formatv(
"{0}::{1}", def->getValueAsString(
"cppNamespace"),
959 def->getValueAsString(
"cppInterfaceName"))
961 std::string codeBlock =
962 llvm::formatv(
"return ::mlir::success(llvm::isa<{0}>(self));",
967 processAndFormatDoc(def->getValueAsString(
"description"));
968 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
969 def->getName(), codeBlock, loc, opTy, cppClassName, desc));
/// Builds a native constraint declaration for an ODS entity. In the visible
/// part, a parameter variable is added to the argument scope, `docString` is
/// attached to the declaration as its doc comment, and the declaration is
/// registered in `curDeclScope` before being returned.
/// NOTE(review): the creation of `argScope`, `paramVar` and `constraintDecl`
/// is not visible in this excerpt.
973 template <
typename Constra
intT>
974 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
975 StringRef name, StringRef codeBlock, SMRange loc,
ast::Type type,
976 StringRef nativeType, StringRef docString) {
982 argScope->
add(paramVar);
990 constraintDecl->setDocComment(ctx, docString);
991 curDeclScope->add(constraintDecl);
992 return constraintDecl;
995 template <
typename Constra
intT>
999 StringRef nativeType) {
1010 std::string docString;
1011 if (enableDocumentation) {
1013 docString = processAndFormatDoc(
1018 return createODSNativePDLLConstraintDecl<ConstraintT>(
/// Parses one top-level declaration. The current token kind selects between
/// a user constraint, a pattern, and a user rewrite; any other token yields
/// "expected top-level declaration". A named declaration is checked against
/// existing definitions with checkDefineNamedDecl and then added to
/// `curDeclScope`.
/// NOTE(review): the case labels, the origin of `name`, and the return
/// statements are not visible in this excerpt.
1026 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1027 FailureOr<ast::Decl *> decl;
1028 switch (curToken.getKind()) {
1030 decl = parseUserConstraintDecl();
1033 decl = parsePatternDecl();
1036 decl = parseUserRewriteDecl();
1039 return emitError(
"expected top-level declaration, such as a `Pattern`");
1046 if (failed(checkDefineNamedDecl(*name)))
1048 curDeclScope->add(*decl);
/// Parses a named attribute. The name is the string value of a string token
/// or otherwise the spelling of the current token; an unsuitable token yields
/// "expected identifier or string attribute name". The attribute value, when
/// parsed, comes from parseExpr. A code-completion request is delegated to
/// codeCompleteAttributeName with `parentOpName`.
/// NOTE(review): the conditions selecting each path and the construction of
/// the result are not visible in this excerpt.
1053 FailureOr<ast::NamedAttributeDecl *>
1054 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1057 return codeCompleteAttributeName(parentOpName);
1059 std::string attrNameStr;
1060 if (curToken.isString())
1061 attrNameStr = curToken.getStringValue();
1063 attrNameStr = curToken.getSpelling().str();
1065 return emitError(
"expected identifier or string attribute name");
1072 FailureOr<ast::Expr *> attrExpr = parseExpr();
1073 if (failed(attrExpr))
1075 attrValue = *attrExpr;
/// Parses a lambda body made of a single statement. The statement is parsed
/// with parseStmt(expectTerminalSemicolon) and then passed to
/// `processStatementFn`; the callback only runs when parsing succeeded,
/// because the `||` short-circuits. The body range runs from the start of the
/// first body token to the start of the token current after parsing.
/// NOTE(review): the callback parameter declaration and the handling of
/// `failedToParse` are not visible in this excerpt.
1085 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1087 bool expectTerminalSemicolon) {
1091 SMLoc bodyStartLoc = curToken.getStartLoc();
1093 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1094 bool failedToParse =
1095 failed(singleStatement) || failed(processStatementFn(*singleStatement));
1100 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
/// Parses one argument of the form `name : constraint`. The name and its
/// location come from the current token, a `:` is required, and the
/// constraint is parsed with parseArgOrResultConstraint. The variable is
/// created through createArgOrResultVariableDecl.
/// NOTE(review): the identifier check guarding the first error and the
/// failure returns are not visible in this excerpt.
1104 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1107 return emitError(
"expected identifier argument name");
1110 StringRef name = curToken.getSpelling();
1111 SMRange nameLoc = curToken.getLoc();
1115 parseToken(
Token::colon,
"expected `:` before argument constraint")))
1118 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1122 return createArgOrResultVariableDecl(name, nameLoc, *cst);
/// Parses one result declaration. Two forms are visible: a named result
/// (`name : constraint`), and a bare constraint, for which the variable gets
/// an empty name and the constraint's reference location.
/// NOTE(review): the test that selects between the two forms, and any use of
/// `resultNum`, are not visible in this excerpt.
1125 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(
unsigned resultNum) {
1133 StringRef name = curToken.getSpelling();
1134 SMRange nameLoc = curToken.getLoc();
1138 "expected `:` before result constraint")))
1141 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1145 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1151 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1155 return createArgOrResultVariableDecl(
"", cst->referenceLoc, *cst);
/// Parses a user constraint declaration by instantiating the shared
/// constraint/rewrite routine for ast::UserConstraintDecl. The PDLL-body
/// callback forwards to parseUserPDLLConstraintDecl, the parser context is
/// ParserContext::Constraint, and "constraint" is the prefix used for
/// anonymous names.
1158 FailureOr<ast::UserConstraintDecl *>
1159 Parser::parseUserConstraintDecl(
bool isInline) {
1162 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1163 [&](
auto &&...args) {
1164 return this->parseUserPDLLConstraintDecl(args...);
1166 ParserContext::Constraint,
"constraint", isInline);
/// Parses a user constraint in inline position (`isInline` is true), checks
/// its name against existing definitions, and adds it to `curDeclScope`.
/// NOTE(review): the failure and success returns are not visible in this
/// excerpt.
1169 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1170 FailureOr<ast::UserConstraintDecl *> decl =
1171 parseUserConstraintDecl(
true);
1172 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1175 curDeclScope->add(*decl);
/// Parses the PDLL body of a user constraint. The argument scope is made
/// current first. Two body forms are visible: a lambda body, whose single
/// statement is cast to an expression (the diagnostic text requires a single
/// expression), and a compound statement body. For the compound form the
/// statements are scanned for a `return`, and the position found is validated
/// with validateUserConstraintOrRewriteReturn under the name "Constraint".
/// The declaration is finally built with
/// createUserPDLLConstraintOrRewriteDecl.
/// NOTE(review): the parameter list, the test choosing the body form, and
/// the scope restoration are not visible in this excerpt.
1179 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1185 pushDeclScope(argumentScope);
1191 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1192 [&](
ast::Stmt *&stmt) -> LogicalResult {
1193 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1196 "expected `Constraint` lambda body to contain a "
1197 "single expression");
1203 if (failed(bodyResult))
1207 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1208 if (failed(bodyResult))
1213 auto bodyIt = body->
begin(), bodyE = body->
end();
1214 for (; bodyIt != bodyE; ++bodyIt)
1215 if (isa<ast::ReturnStmt>(*bodyIt))
1217 if (failed(validateUserConstraintOrRewriteReturn(
1218 "Constraint", body, bodyIt, bodyE, results, resultType)))
1223 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1224 name, arguments, results, resultType, body);
/// Parses a user rewrite declaration by instantiating the shared
/// constraint/rewrite routine for ast::UserRewriteDecl. The PDLL-body
/// callback forwards to parseUserPDLLRewriteDecl, the parser context is
/// ParserContext::Rewrite, and "rewrite" is the prefix used for anonymous
/// names.
1227 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(
bool isInline) {
1230 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1231 [&](
auto &&...args) {
return this->parseUserPDLLRewriteDecl(args...); },
1232 ParserContext::Rewrite,
"rewrite", isInline);
/// Parses a user rewrite in inline position (`isInline` is true), checks its
/// name against existing definitions, and adds it to `curDeclScope`.
/// NOTE(review): the failure and success returns are not visible in this
/// excerpt.
1235 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1236 FailureOr<ast::UserRewriteDecl *> decl =
1237 parseUserRewriteDecl(
true);
1238 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1241 curDeclScope->add(*decl);
/// Parses the PDLL body of a user rewrite. The argument scope is made current
/// by direct assignment, which has the same effect as the
/// pushDeclScope(argumentScope) call used by the Constraint counterpart.
/// Two body forms are visible: a lambda body, whose single statement must be
/// an operation rewrite statement or an expression, and a compound statement
/// body. For the compound form the statements are scanned for a `return`,
/// and the position found is validated with
/// validateUserConstraintOrRewriteReturn under the name "Rewrite". The
/// declaration is finally built with createUserPDLLConstraintOrRewriteDecl.
/// NOTE(review): the parameter list, the test choosing the body form, and
/// the scope restoration are not visible in this excerpt.
1245 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1251 curDeclScope = argumentScope;
1254 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1255 [&](
ast::Stmt *&statement) -> LogicalResult {
1256 if (isa<ast::OpRewriteStmt>(statement))
1259 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1260 if (!statementExpr) {
1263 "expected `Rewrite` lambda body to contain a single expression "
1264 "or an operation rewrite statement; such as `erase`, "
1265 "`replace`, or `rewrite`");
1272 if (failed(bodyResult))
1276 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1277 if (failed(bodyResult))
1284 auto bodyIt = body->
begin(), bodyE = body->
end();
1285 for (; bodyIt != bodyE; ++bodyIt)
1286 if (isa<ast::ReturnStmt>(*bodyIt))
1288 if (failed(validateUserConstraintOrRewriteReturn(
"Rewrite", body, bodyIt,
1289 bodyE, results, resultType)))
1291 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1292 name, arguments, results, resultType, body);
/// Shared driver for user constraint and rewrite declarations.
///  - `parserContext` is switched to `declContext` for the duration of the
///    call and restored on exit by the SaveAndRestore guard.
///  - A declaration without a name gets a generated one of the form
///    `<anonymous_PREFIX_N>`, where N comes from the post-incremented
///    `anonymousDeclNameCounter`; the other name path can fail with
///    "expected identifier name".
///  - The signature is parsed with parseUserConstraintOrRewriteSignature.
///  - The body is parsed either by `parseUserPDLLFn` (PDLL body) or by
///    parseUserNativeConstraintOrRewriteDecl (native declaration).
/// NOTE(review): the conditions selecting the name path and the body path
/// are not visible in this excerpt.
1295 template <
typename T,
typename ParseUserPDLLDeclFnT>
1296 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1297 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1298 StringRef anonymousNamePrefix,
bool isInline) {
1299 SMRange loc = curToken.getLoc();
1301 llvm::SaveAndRestore saveCtx(parserContext, declContext);
1309 return emitError(
"expected identifier name");
1313 std::string anonName =
1314 llvm::formatv(
"<anonymous_{0}_{1}>", anonymousNamePrefix,
1315 anonymousDeclNameCounter++)
1328 if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1329 argumentScope, resultType)))
1335 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1339 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1340 results, resultType);
/// Parses a native (externally implemented) constraint or rewrite. An
/// optional string token supplies the native code; `optCodeStr` is a view
/// into `codeStrStorage`, so the storage string must stay alive while the
/// view is used. Without a code string, an inline declaration is rejected
/// with "external declarations must be declared in global scope". A
/// terminating `;` is required, and the declaration is built with
/// T::createNative.
/// NOTE(review): the parameter list and the failure returns are not visible
/// in this excerpt.
1343 template <
typename T>
1344 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1349 std::string codeStrStorage;
1350 std::optional<StringRef> optCodeStr;
1351 if (curToken.isString()) {
1352 codeStrStorage = curToken.getStringValue();
1353 optCodeStr = codeStrStorage;
1355 }
else if (isInline) {
1357 "external declarations must be declared in global scope");
1362 "expected `;` after native declaration")))
1364 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
/// Parses the signature of a user constraint or rewrite.
///  - Arguments: a parenthesised list. A fresh declaration scope is pushed
///    and stored in `argumentScope`, and each parsed argument is appended to
///    `arguments`.
///  - Results: either a parenthesised list or a single result; each goes
///    through parseResultDecl, which receives the current result count as
///    the result number.
///  - `resultType` is computed from the collected results.
///  - A single result that carries a name is rejected with "cannot create a
///    single-element tuple with an element label".
/// NOTE(review): the parameter list, the list loops, and the token that
/// introduces the result list are not visible in this excerpt.
1367 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1372 if (failed(parseToken(
Token::l_paren,
"expected `(` to start argument list")))
1375 argumentScope = pushDeclScope();
1378 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1379 if (failed(argument))
1381 arguments.emplace_back(*argument);
1385 if (failed(parseToken(
Token::r_paren,
"expected `)` to end argument list")))
1391 auto parseResultFn = [&]() -> LogicalResult {
1392 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1395 results.emplace_back(*result);
1402 if (failed(parseResultFn()))
1405 if (failed(parseToken(
Token::r_paren,
"expected `)` to end result list")))
1409 }
else if (failed(parseResultFn())) {
1416 resultType = createUserConstraintRewriteResultType(results);
1419 if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1421 results.front()->getLoc(),
1422 "cannot create a single-element tuple with an element label");
/// Validates how a constraint or rewrite body returns. `bodyIt` is the
/// position of the `return` statement, or `bodyE` when there is none.
///  - With a `return`: it must be the last statement; a statement following
///    it is reported at that statement's location.
///  - Without a `return`: an error is reported at the end of the body when
///    results are expected.
/// NOTE(review): the parameter list and the return statements are not
/// visible in this excerpt.
1427 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1433 if (bodyIt != bodyE) {
1435 if (std::next(bodyIt) != bodyE) {
1437 (*std::next(bodyIt))->getLoc(),
1438 llvm::formatv(
"`return` terminated the `{0}` body, but found "
1439 "trailing statements afterwards",
1445 }
else if (!results.empty()) {
1447 {body->getLoc().End, body->getLoc().End},
1448 llvm::formatv(
"missing return in a `{0}` expected to return `{1}`",
1449 declType, resultType));
/// Parses the lambda form of a pattern body via parseLambdaBody. The callback
/// tests whether the single statement is an operation rewrite statement; the
/// diagnostic text names `erase`, `replace` and `rewrite` as examples.
/// NOTE(review): the results returned by the callback are not visible in
/// this excerpt.
1454 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1455 return parseLambdaBody([&](
ast::Stmt *&statement) -> LogicalResult {
1456 if (isa<ast::OpRewriteStmt>(statement))
1460 "expected Pattern lambda body to contain a single operation "
1461 "rewrite statement, such as `erase`, `replace`, or `rewrite`");
/// Parses a pattern declaration.
///  - `parserContext` is ParserContext::PatternMatch for the duration of the
///    call (restored on exit by the SaveAndRestore guard).
///  - Metadata is parsed only when a `with` keyword is consumed.
///  - The body is either a lambda body or a compound statement; another
///    token yields "expected `{` or `=>` to start pattern body".
///  - For a compound body the statements are scanned: a `return` statement
///    is diagnosed, the scan looks for an operation rewrite statement, a
///    body without one is diagnosed, and so are statements that follow it.
///  - The declaration is built with createPatternDecl.
/// NOTE(review): name parsing, the test choosing the body form, and the
/// failure returns are not visible in this excerpt.
1465 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1466 SMRange loc = curToken.getLoc();
1468 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1478 ParsedPatternMetadata metadata;
1479 if (consumeIf(
Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1487 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1488 if (failed(bodyResult))
1493 return emitError(
"expected `{` or `=>` to start pattern body");
1494 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1495 if (failed(bodyResult))
1500 auto bodyIt = body->
begin(), bodyE = body->
end();
1501 for (; bodyIt != bodyE; ++bodyIt) {
1502 if (isa<ast::ReturnStmt>(*bodyIt)) {
1504 "`return` statements are only permitted within a "
1505 "`Constraint` or `Rewrite` body");
1508 if (isa<ast::OpRewriteStmt>(*bodyIt))
1511 if (bodyIt == bodyE) {
1513 "expected Pattern body to terminate with an operation "
1514 "rewrite statement, such as `erase`");
1516 if (std::next(bodyIt) != bodyE) {
1517 return emitError((*std::next(bodyIt))->getLoc(),
1518 "Pattern body was terminated by an operation "
1519 "rewrite statement, but found trailing statements");
1523 return createPatternDecl(loc, name, metadata, body);
/// Parses the metadata entries of a pattern into `metadata`.
///  - `benefit`: a parenthesised integer, parsed in base 10 into a 16-bit
///    unsigned value; a value that does not parse is reported as not fitting
///    within a 16-bit integer.
///  - `recursion`: sets `metadata.hasBoundedRecursion` to true.
///  - Any other identifier yields "unknown pattern metadata".
/// The location of each entry is remembered so that a repeated entry can be
/// reported together with a note pointing at the earlier one. A
/// code-completion request is delegated to codeCompletePatternMetadata.
/// NOTE(review): the return type line, the surrounding loop, and several
/// conditions are not visible in this excerpt.
1527 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1528 std::optional<SMRange> benefitLoc;
1529 std::optional<SMRange> hasBoundedRecursionLoc;
1534 return codeCompletePatternMetadata();
1537 return emitError(
"expected pattern metadata identifier");
1538 StringRef metadataStr = curToken.getSpelling();
1539 SMRange metadataLoc = curToken.getLoc();
1543 if (metadataStr ==
"benefit") {
1545 return emitErrorAndNote(metadataLoc,
1546 "pattern benefit has already been specified",
1547 *benefitLoc,
"see previous definition here");
1550 "expected `(` before pattern benefit")))
1553 uint16_t benefitValue = 0;
1555 return emitError(
"expected integral pattern benefit");
1556 if (curToken.getSpelling().getAsInteger(10, benefitValue))
1558 "expected pattern benefit to fit within a 16-bit integer");
1561 metadata.benefit = benefitValue;
1562 benefitLoc = metadataLoc;
1565 parseToken(
Token::r_paren,
"expected `)` after pattern benefit")))
1571 if (metadataStr ==
"recursion") {
1572 if (hasBoundedRecursionLoc) {
1573 return emitErrorAndNote(
1575 "pattern recursion metadata has already been specified",
1576 *hasBoundedRecursionLoc,
"see previous definition here");
1578 metadata.hasBoundedRecursion =
true;
1579 hasBoundedRecursionLoc = metadataLoc;
1583 return emitError(metadataLoc,
"unknown pattern metadata");
/// Parses the expression of a variable type constraint with parseExpr and
/// then requires a closing `>`.
/// NOTE(review): the handling of the opening token and the return statements
/// are not visible in this excerpt.
1589 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1592 FailureOr<ast::Expr *> typeExpr = parseExpr();
1593 if (failed(typeExpr) ||
1595 "expected `>` after variable type constraint")))
/// Checks whether `name` may be defined in the current declaration scope.
/// A current scope is required (asserted). When an earlier declaration
/// (`lastDecl`) is found, an "already been defined" error is reported at
/// `name` with a note at the earlier declaration's name.
/// NOTE(review): the lookup producing `lastDecl` and the success return are
/// not visible in this excerpt.
1600 LogicalResult Parser::checkDefineNamedDecl(
const ast::Name &name) {
1601 assert(curDeclScope &&
"defining decl outside of a decl scope");
1603 return emitErrorAndNote(
1604 name.
getLoc(),
"`" + name.
getName() +
"` has already been defined",
1605 lastDecl->getName()->getLoc(),
"see previous definition here");
/// Defines a variable in the current declaration scope, which is required
/// (asserted). An empty name and the name `_` take a dedicated branch; a
/// redefinition check is run via checkDefineNamedDecl, and the variable is
/// added to `curDeclScope`.
/// NOTE(review): the remaining parameters, the body of the empty-name
/// branch, and the creation of `varDecl` are not visible in this excerpt.
1610 FailureOr<ast::VariableDecl *>
1611 Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
1614 assert(curDeclScope &&
"defining variable outside of decl scope");
1619 if (name.empty() || name ==
"_") {
1623 if (failed(checkDefineNamedDecl(nameDecl)))
1628 curDeclScope->add(varDecl);
/// Overload that forwards to the other defineVariableDecl overload, passing
/// `nullptr` as the argument that follows `type`.
/// NOTE(review): the remaining parameters and arguments are not visible in
/// this excerpt.
1632 FailureOr<ast::VariableDecl *>
1633 Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
ast::Type type,
1635 return defineVariableDecl(name, nameLoc, type,
nullptr,
/// Parses the constraints of a variable declaration into `constraints`.
/// Either a single constraint is parsed, or a list terminated by `]`. Each
/// constraint is parsed with inline type constraints allowed.
/// `typeConstraint` is shared across all constraints of the list and passed
/// by reference to parseConstraint.
/// NOTE(review): the parameter list, the test for the list form, and the
/// list loop are not visible in this excerpt.
1639 LogicalResult Parser::parseVariableDeclConstraintList(
1641 std::optional<SMRange> typeConstraint;
1642 auto parseSingleConstraint = [&] {
1643 FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1644 typeConstraint, constraints,
true);
1645 if (failed(constraint))
1647 constraints.push_back(*constraint);
1653 return parseSingleConstraint();
1656 if (failed(parseSingleConstraint()))
1659 return parseToken(
Token::r_square,
"expected `]` after constraint list");
/// Parses a single constraint reference, dispatching on the current token.
///  - `parseTypeConstraint` handles an inline `<...>` type constraint. It is
///    rejected when `allowInlineTypeConstraints` is false, it reports a type
///    that "has already been constrained" with a note at the earlier
///    location held in `typeConstraint`, and on success it stores the
///    location of the parsed expression into `typeConstraint`.
///  - Several cases accept an optional inline type constraint when the
///    current token is `<`; one case parses a wrapped operation name that
///    may be empty; one parses an inline user constraint declaration.
///  - An identifier is taken as a constraint name: an unknown name and a name
///    that refers to a non-constraint declaration are both diagnosed.
///  - On a code-completion request the existing constraints are validated
///    first to infer a type, then codeCompleteConstraintName is invoked.
///  - Any other token yields "expected identifier constraint".
/// NOTE(review): the case labels, the `existingConstraints` parameter, and
/// the values returned on success are not visible in this excerpt.
1662 FailureOr<ast::ConstraintRef>
1663 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1665 bool allowInlineTypeConstraints) {
1666 auto parseTypeConstraint = [&](
ast::Expr *&typeExpr) -> LogicalResult {
1667 if (!allowInlineTypeConstraints) {
1670 "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1671 "permitted on arguments or results");
1674 return emitErrorAndNote(
1676 "the type of this variable has already been constrained",
1677 *typeConstraint,
"see previous constraint location here");
1678 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1679 if (failed(constraintExpr))
1681 typeExpr = *constraintExpr;
1682 typeConstraint = typeExpr->getLoc();
1686 SMRange loc = curToken.getLoc();
1687 switch (curToken.getKind()) {
1693 if (curToken.is(
Token::less) && failed(parseTypeConstraint(typeExpr)))
1703 FailureOr<ast::OpNameDecl *> opName =
1704 parseWrappedOperationName(
true);
1723 if (curToken.is(
Token::less) && failed(parseTypeConstraint(typeExpr)))
1734 if (curToken.is(
Token::less) && failed(parseTypeConstraint(typeExpr)))
1743 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1749 StringRef constraintName = curToken.getSpelling();
1755 return emitError(loc,
"unknown reference to constraint `" +
1756 constraintName +
"`");
1760 if (
auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1763 return emitErrorAndNote(
1764 loc,
"invalid reference to non-constraint", cstDecl->
getLoc(),
1765 "see the definition of `" + constraintName +
"` here");
1771 if (failed(validateVariableConstraints(existingConstraints, inferredType)))
1774 return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1779 return emitError(loc,
"expected identifier constraint");
1782 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1783 std::optional<SMRange> typeConstraint;
1784 return parseConstraint(typeConstraint, std::nullopt,
1791 FailureOr<ast::Expr *> Parser::parseExpr() {
1793 return parseUnderscoreExpr();
1796 FailureOr<ast::Expr *> lhsExpr;
1797 switch (curToken.getKind()) {
1799 lhsExpr = parseAttributeExpr();
1802 lhsExpr = parseInlineConstraintLambdaExpr();
1805 lhsExpr = parseNegatedExpr();
1808 lhsExpr = parseIdentifierExpr();
1811 lhsExpr = parseOperationExpr();
1814 lhsExpr = parseInlineRewriteLambdaExpr();
1817 lhsExpr = parseTypeExpr();
1820 lhsExpr = parseTupleExpr();
1823 return emitError(
"expected expression");
1825 if (failed(lhsExpr))
1830 switch (curToken.getKind()) {
1832 lhsExpr = parseMemberAccessExpr(*lhsExpr);
1835 lhsExpr = parseCallExpr(*lhsExpr);
1840 if (failed(lhsExpr))
1845 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1846 SMRange loc = curToken.getLoc();
1853 return parseIdentifierExpr();
1856 if (!curToken.isString())
1857 return emitError(
"expected string literal containing MLIR attribute");
1858 std::string attrExpr = curToken.getStringValue();
1861 loc.End = curToken.getEndLoc();
1863 parseToken(
Token::greater,
"expected `>` after attribute literal")))
1868 FailureOr<ast::Expr *> Parser::parseCallExpr(
ast::Expr *parentExpr,
1878 codeCompleteCallSignature(parentExpr, arguments.size());
1882 FailureOr<ast::Expr *> argument = parseExpr();
1883 if (failed(argument))
1885 arguments.push_back(*argument);
1889 SMRange loc(parentExpr->
getLoc().Start, curToken.getEndLoc());
1890 if (failed(parseToken(
Token::r_paren,
"expected `)` after argument list")))
1893 return createCallExpr(loc, parentExpr, arguments, isNegated);
1896 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1897 ast::Decl *decl = curDeclScope->lookup(name);
1899 return emitError(loc,
"undefined reference to `" + name +
"`");
1901 return createDeclRefExpr(loc, decl);
1904 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1905 StringRef name = curToken.getSpelling();
1906 SMRange nameLoc = curToken.getLoc();
1913 if (failed(parseVariableDeclConstraintList(constraints)))
1916 if (failed(validateVariableConstraints(constraints, type)))
1918 return createInlineVariableExpr(type, name, nameLoc, constraints);
1921 return parseDeclRefExpr(name, nameLoc);
1924 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1925 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1933 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1934 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1942 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(
ast::Expr *parentExpr) {
1943 SMRange dotLoc = curToken.getLoc();
1948 return codeCompleteMemberAccess(parentExpr);
1951 Token memberNameTok = curToken;
1954 return emitError(dotLoc,
"expected identifier or numeric member name");
1955 StringRef memberName = memberNameTok.
getSpelling();
1956 SMRange loc(parentExpr->
getLoc().Start, curToken.getEndLoc());
1959 return createMemberAccessExpr(parentExpr, memberName, loc);
1962 FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1966 return emitError(
"expected native constraint");
1967 FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1968 if (failed(identifierExpr))
1971 return emitError(
"expected `(` after function name");
1972 return parseCallExpr(*identifierExpr,
true);
1975 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(
bool allowEmptyName) {
1976 SMRange loc = curToken.getLoc();
1980 return codeCompleteDialectName();
1986 return emitError(
"expected dialect namespace");
1988 StringRef name = curToken.getSpelling();
1992 if (failed(parseToken(
Token::dot,
"expected `.` after dialect namespace")))
1997 return codeCompleteOperationName(name);
2000 return emitError(
"expected operation name after dialect namespace");
2002 name = StringRef(name.data(), name.size() + 1);
2004 name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
2005 loc.End = curToken.getEndLoc();
2008 curToken.isKeyword());
2012 FailureOr<ast::OpNameDecl *>
2013 Parser::parseWrappedOperationName(
bool allowEmptyName) {
2017 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2018 if (failed(opNameDecl))
2021 if (failed(parseToken(
Token::greater,
"expected `>` after operation name")))
2026 FailureOr<ast::Expr *>
2027 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2028 SMRange loc = curToken.getLoc();
2035 return parseIdentifierExpr();
2041 bool allowEmptyName = parserContext != ParserContext::Rewrite;
2042 FailureOr<ast::OpNameDecl *> opNameDecl =
2043 parseWrappedOperationName(allowEmptyName);
2044 if (failed(opNameDecl))
2046 std::optional<StringRef> opName = (*opNameDecl)->getName();
2051 FailureOr<ast::VariableDecl *> rangeVar =
2053 assert(succeeded(rangeVar) &&
"expected range variable to be valid");
2064 if (parserContext != ParserContext::Rewrite) {
2065 operands.push_back(createImplicitRangeVar(
2073 codeCompleteOperationOperandsSignature(opName, operands.size());
2077 FailureOr<ast::Expr *> operand = parseExpr();
2078 if (failed(operand))
2080 operands.push_back(*operand);
2084 "expected `)` after operation operand list")))
2092 FailureOr<ast::NamedAttributeDecl *> decl =
2093 parseNamedAttributeDecl(opName);
2096 attributes.emplace_back(*decl);
2100 "expected `}` after operation attribute list")))
2106 OpResultTypeContext resultTypeContext = inputResultTypeContext;
2111 "expected `(` before operation result type list")))
2119 resultTypeContext = OpResultTypeContext::Explicit;
2126 codeCompleteOperationResultsSignature(opName, resultTypes.size());
2130 FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2131 if (failed(resultTypeExpr))
2133 resultTypes.push_back(*resultTypeExpr);
2137 "expected `)` after operation result type list")))
2140 }
else if (parserContext != ParserContext::Rewrite) {
2145 resultTypes.push_back(createImplicitRangeVar(
2147 }
else if (resultTypeContext == OpResultTypeContext::Explicit) {
2150 resultTypeContext = OpResultTypeContext::Interface;
2153 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2154 attributes, resultTypes);
2157 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2158 SMRange loc = curToken.getLoc();
2167 StringRef elementName;
2169 Token elementNameTok = curToken;
2177 auto elementNameIt =
2178 usedNames.try_emplace(elementName, elementNameTok.
getLoc());
2179 if (!elementNameIt.second) {
2180 return emitErrorAndNote(
2182 llvm::formatv(
"duplicate tuple element label `{0}`",
2184 elementNameIt.first->getSecond(),
2185 "see previous label use here");
2190 resetToken(elementNameTok.
getLoc());
2193 elementNames.push_back(elementName);
2196 FailureOr<ast::Expr *> element = parseExpr();
2197 if (failed(element))
2199 elements.push_back(*element);
2202 loc.End = curToken.getEndLoc();
2204 parseToken(
Token::r_paren,
"expected `)` after tuple element list")))
2206 return createTupleExpr(loc, elements, elementNames);
2209 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2210 SMRange loc = curToken.getLoc();
2217 return parseIdentifierExpr();
2220 if (!curToken.isString())
2221 return emitError(
"expected string literal containing MLIR type");
2222 std::string attrExpr = curToken.getStringValue();
2225 loc.End = curToken.getEndLoc();
2226 if (failed(parseToken(
Token::greater,
"expected `>` after type literal")))
2231 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2232 StringRef name = curToken.getSpelling();
2233 SMRange nameLoc = curToken.getLoc();
2237 if (failed(parseToken(
Token::colon,
"expected `:` after `_` variable")))
2242 if (failed(parseVariableDeclConstraintList(constraints)))
2246 if (failed(validateVariableConstraints(constraints, type)))
2248 return createInlineVariableExpr(type, name, nameLoc, constraints);
2254 FailureOr<ast::Stmt *> Parser::parseStmt(
bool expectTerminalSemicolon) {
2255 FailureOr<ast::Stmt *> stmt;
2256 switch (curToken.getKind()) {
2258 stmt = parseEraseStmt();
2261 stmt = parseLetStmt();
2264 stmt = parseReplaceStmt();
2267 stmt = parseReturnStmt();
2270 stmt = parseRewriteStmt();
2277 (expectTerminalSemicolon &&
2283 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2284 SMLoc startLoc = curToken.getStartLoc();
2291 FailureOr<ast::Stmt *> statement = parseStmt();
2292 if (failed(statement))
2293 return popDeclScope(), failure();
2294 statements.push_back(*statement);
2299 SMRange location(startLoc, curToken.getEndLoc());
2305 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2306 if (parserContext == ParserContext::Constraint)
2307 return emitError(
"`erase` cannot be used within a Constraint");
2308 SMRange loc = curToken.getLoc();
2312 FailureOr<ast::Expr *> rootOp = parseExpr();
2316 return createEraseStmt(loc, *rootOp);
2319 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2320 SMRange loc = curToken.getLoc();
2324 SMRange varLoc = curToken.getLoc();
2329 "`_` may only be used to define \"inline\" variables");
2332 "expected identifier after `let` to name a new variable");
2334 StringRef varName = curToken.getSpelling();
2340 failed(parseVariableDeclConstraintList(constraints)))
2346 FailureOr<ast::Expr *> initOrFailure = parseExpr();
2347 if (failed(initOrFailure))
2349 initializer = *initOrFailure;
2354 LogicalResult result =
2358 if (cst->getTypeExpr()) {
2360 constraint.referenceLoc,
2361 "type constraints are not permitted on variables with "
2366 .Default(success());
2372 FailureOr<ast::VariableDecl *> varDecl =
2373 createVariableDecl(varName, varLoc, initializer, constraints);
2374 if (failed(varDecl))
2379 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2380 if (parserContext == ParserContext::Constraint)
2381 return emitError(
"`replace` cannot be used within a Constraint");
2382 SMRange loc = curToken.getLoc();
2386 FailureOr<ast::Expr *> rootOp = parseExpr();
2391 parseToken(
Token::kw_with,
"expected `with` after root operation")))
2395 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2402 loc,
"expected at least one replacement value, consider using "
2403 "`erase` if no replacement values are desired");
2407 FailureOr<ast::Expr *> replExpr = parseExpr();
2408 if (failed(replExpr))
2410 replValues.emplace_back(*replExpr);
2414 "expected `)` after replacement values")))
2419 FailureOr<ast::Expr *> replExpr;
2421 replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2423 replExpr = parseExpr();
2424 if (failed(replExpr))
2426 replValues.emplace_back(*replExpr);
2429 return createReplaceStmt(loc, *rootOp, replValues);
2432 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2433 SMRange loc = curToken.getLoc();
2437 FailureOr<ast::Expr *> resultExpr = parseExpr();
2438 if (failed(resultExpr))
2444 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2445 if (parserContext == ParserContext::Constraint)
2446 return emitError(
"`rewrite` cannot be used within a Constraint");
2447 SMRange loc = curToken.getLoc();
2451 FailureOr<ast::Expr *> rootOp = parseExpr();
2455 if (failed(parseToken(
Token::kw_with,
"expected `with` before rewrite body")))
2459 return emitError(
"expected `{` to start rewrite body");
2462 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2464 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2465 if (failed(rewriteBody))
2469 for (
const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2470 if (isa<ast::ReturnStmt>(stmt)) {
2472 "`return` statements are only permitted within a "
2473 "`Constraint` or `Rewrite` body");
2477 return createRewriteStmt(loc, *rootOp, *rewriteBody);
2489 if (
auto *init = dyn_cast<ast::DeclRefExpr>(node))
2490 node = init->getDecl();
2491 return dyn_cast<ast::CallableDecl>(node);
2494 FailureOr<ast::PatternDecl *>
2495 Parser::createPatternDecl(SMRange loc,
const ast::Name *name,
2496 const ParsedPatternMetadata &metadata,
2499 metadata.hasBoundedRecursion, body);
2502 ast::Type Parser::createUserConstraintRewriteResultType(
2505 if (results.size() == 1)
2506 return results[0]->getType();
2510 auto resultTypes = llvm::map_range(
2511 results, [&](
const auto *result) {
return result->getType(); });
2512 auto resultNames = llvm::map_range(
2513 results, [&](
const auto *result) {
return result->getName().getName(); });
2515 llvm::to_vector(resultNames));
2518 template <
typename T>
2519 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2524 if (
auto *retStmt = dyn_cast<ast::ReturnStmt>(body->
getChildren().back())) {
2525 ast::Expr *resultExpr = retStmt->getResultExpr();
2530 if (results.empty())
2531 resultType = resultExpr->
getType();
2532 else if (failed(convertExpressionTo(resultExpr, resultType)))
2535 retStmt->setResultExpr(resultExpr);
2538 return T::createPDLL(ctx, name, arguments, results, body, resultType);
2541 FailureOr<ast::VariableDecl *>
2542 Parser::createVariableDecl(StringRef name, SMRange loc,
ast::Expr *initializer,
2547 if (failed(validateVariableConstraints(constraints, type)))
2554 type = initializer->
getType();
2557 else if (failed(convertExpressionTo(initializer, type)))
2563 return emitErrorAndNote(
2564 loc,
"unable to infer type for variable `" + name +
"`", loc,
2565 "the type of a variable must be inferable from the constraint "
2566 "list or the initializer");
2570 if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
2572 loc, llvm::formatv(
"unable to define variable of `{0}` type", type));
2576 FailureOr<ast::VariableDecl *> varDecl =
2577 defineVariableDecl(name, loc, type, initializer, constraints);
2578 if (failed(varDecl))
2584 FailureOr<ast::VariableDecl *>
2585 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2588 if (failed(validateVariableConstraint(constraint, argType)))
2590 return defineVariableDecl(name, loc, argType, constraint);
2597 if (failed(validateVariableConstraint(ref, inferredType)))
2605 if (
const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.
constraint)) {
2606 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2607 if (failed(validateTypeConstraintExpr(typeExpr)))
2611 }
else if (
const auto *cst =
2612 dyn_cast<ast::OpConstraintDecl>(ref.
constraint)) {
2615 }
else if (isa<ast::TypeConstraintDecl>(ref.
constraint)) {
2616 constraintType = typeTy;
2617 }
else if (isa<ast::TypeRangeConstraintDecl>(ref.
constraint)) {
2618 constraintType = typeRangeTy;
2619 }
else if (
const auto *cst =
2620 dyn_cast<ast::ValueConstraintDecl>(ref.
constraint)) {
2621 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2622 if (failed(validateTypeConstraintExpr(typeExpr)))
2625 constraintType = valueTy;
2626 }
else if (
const auto *cst =
2627 dyn_cast<ast::ValueRangeConstraintDecl>(ref.
constraint)) {
2628 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2629 if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2632 constraintType = valueRangeTy;
2633 }
else if (
const auto *cst =
2634 dyn_cast<ast::UserConstraintDecl>(ref.
constraint)) {
2636 if (inputs.size() != 1) {
2638 "`Constraint`s applied via a variable constraint "
2639 "list must take a single input, but got " +
2640 Twine(inputs.size()),
2642 "see definition of constraint here");
2644 constraintType = inputs.front()->getType();
2646 llvm_unreachable(
"unknown constraint type");
2651 if (!inferredType) {
2652 inferredType = constraintType;
2654 inferredType = mergedTy;
2657 llvm::formatv(
"constraint type `{0}` is incompatible "
2658 "with the previously inferred type `{1}`",
2659 constraintType, inferredType));
2664 LogicalResult Parser::validateTypeConstraintExpr(
const ast::Expr *typeExpr) {
2666 if (typeExprType != typeTy) {
2668 "expected expression of `Type` in type constraint");
2674 Parser::validateTypeRangeConstraintExpr(
const ast::Expr *typeExpr) {
2676 if (typeExprType != typeRangeTy) {
2678 "expected expression of `TypeRange` in type constraint");
2686 FailureOr<ast::CallExpr *>
2687 Parser::createCallExpr(SMRange loc,
ast::Expr *parentExpr,
2692 if (!callableDecl) {
2694 llvm::formatv(
"expected a reference to a callable "
2695 "`Constraint` or `Rewrite`, but got: `{0}`",
2698 if (parserContext == ParserContext::Rewrite) {
2699 if (isa<ast::UserConstraintDecl>(callableDecl))
2701 loc,
"unable to invoke `Constraint` within a rewrite section");
2703 return emitError(loc,
"unable to negate a Rewrite");
2705 if (isa<ast::UserRewriteDecl>(callableDecl))
2707 "unable to invoke `Rewrite` within a match section");
2708 if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2709 return emitError(loc,
"unable to negate non native constraints");
2715 if (callArgs.size() != arguments.size()) {
2716 return emitErrorAndNote(
2718 llvm::formatv(
"invalid number of arguments for {0} call; expected "
2723 llvm::formatv(
"see the definition of {0} here",
2729 diag.attachNote(llvm::formatv(
"see the definition of `{0}` here",
2733 for (
auto it : llvm::zip(callArgs, arguments)) {
2734 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->
getType(),
2743 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2747 if (isa<ast::ConstraintDecl>(decl))
2749 else if (isa<ast::UserRewriteDecl>(decl))
2751 else if (
auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2752 declType = varDecl->getType();
2754 return emitError(loc,
"invalid reference to `" +
2760 FailureOr<ast::DeclRefExpr *>
2761 Parser::createInlineVariableExpr(
ast::Type type, StringRef name, SMRange loc,
2763 FailureOr<ast::VariableDecl *> decl =
2764 defineVariableDecl(name, loc, type, constraints);
2770 FailureOr<ast::MemberAccessExpr *>
2771 Parser::createMemberAccessExpr(
ast::Expr *parentExpr, StringRef name,
2774 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2775 if (failed(memberType))
2781 FailureOr<ast::Type> Parser::validateMemberAccess(
ast::Expr *parentExpr,
2782 StringRef name, SMRange loc) {
2786 return valueRangeTy;
2790 auto results = odsOp->getResults();
2794 if (llvm::isDigit(name[0]) && !name.getAsInteger(10, index) &&
2795 index < results.size()) {
2796 return results[index].isVariadic() ? valueRangeTy : valueTy;
2800 const auto *it = llvm::find_if(results, [&](
const auto &result) {
2801 return result.getName() == name;
2803 if (it != results.end())
2804 return it->isVariadic() ? valueRangeTy : valueTy;
2805 }
else if (llvm::isDigit(name[0])) {
2810 }
else if (
auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
2813 if (llvm::isDigit(name[0]) && !name.getAsInteger(10, index) &&
2814 index < tupleType.size()) {
2815 return tupleType.getElementTypes()[index];
2819 auto elementNames = tupleType.getElementNames();
2820 const auto *it = llvm::find(elementNames, name);
2821 if (it != elementNames.end())
2822 return tupleType.getElementTypes()[it - elementNames.begin()];
2826 llvm::formatv(
"invalid member access `{0}` on expression of type `{1}`",
2830 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2832 OpResultTypeContext resultTypeContext,
2836 std::optional<StringRef> opNameRef = name->
getName();
2840 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2846 ast::Type attrType = attr->getValue()->getType();
2847 if (!isa<ast::AttributeType>(attrType)) {
2849 attr->getValue()->getLoc(),
2850 llvm::formatv(
"expected `Attr` expression, but got `{0}`", attrType));
2855 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2856 "unexpected inferrence when results were explicitly specified");
2860 if (resultTypeContext == OpResultTypeContext::Explicit) {
2861 if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2865 }
else if (resultTypeContext == OpResultTypeContext::Interface) {
2867 "expected valid operation name when inferring operation results");
2868 checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2876 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2879 return validateOperationOperandsOrResults(
2880 "operand", loc, odsOp ? odsOp->
getLoc() : std::optional<SMRange>(), name,
2881 operands, odsOp ? odsOp->
getOperands() : std::nullopt, valueTy,
2886 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2889 return validateOperationOperandsOrResults(
2890 "result", loc, odsOp ? odsOp->
getLoc() : std::optional<SMRange>(), name,
2891 results, odsOp ? odsOp->
getResults() : std::nullopt, typeTy, typeRangeTy);
2894 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2905 "operation result types are marked to be inferred, but "
2906 "`{0}` is unknown. Ensure that `{0}` supports zero "
2907 "results or implements `InferTypeOpInterface`. Include "
2908 "the ODS definition of this operation to remove this warning.",
2917 bool requiresInferrence =
2919 return !result.isVariableLength();
2924 llvm::formatv(
"operation result types are marked to be inferred, but "
2925 "`{0}` does not provide an implementation of "
2926 "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2927 "`InferTypeOpInterface` at runtime, or add support to "
2928 "the ODS definition to remove this warning.",
2930 diag->attachNote(llvm::formatv(
"see the definition of `{0}` here", opName),
2936 LogicalResult Parser::validateOperationOperandsOrResults(
2937 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2942 if (values.size() == 1) {
2943 if (failed(convertExpressionTo(values[0], rangeTy)))
2951 auto emitSizeMismatchError = [&] {
2952 return emitErrorAndNote(
2954 llvm::formatv(
"invalid number of {0} groups for `{1}`; expected "
2956 groupName, *name, odsValues.size(), values.size()),
2957 *odsOpLoc, llvm::formatv(
"see the definition of `{0}` here", *name));
2961 if (values.empty()) {
2963 if (odsValues.empty())
2968 unsigned numVariadic = 0;
2969 for (
const auto &odsValue : odsValues) {
2970 if (!odsValue.isVariableLength())
2971 return emitSizeMismatchError();
2977 if (parserContext != ParserContext::Rewrite)
2984 if (numVariadic == 1)
2989 for (
unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2991 ctx, loc, std::nullopt, rangeTy));
2998 if (odsValues.size() != values.size())
2999 return emitSizeMismatchError();
3002 diag.attachNote(llvm::formatv(
"see the definition of `{0}` here", *name),
3005 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
3006 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3007 if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
3016 ast::Type valueExprType = valueExpr->getType();
3019 if (valueExprType == rangeTy || valueExprType == singleTy)
3025 if (singleTy == valueTy) {
3026 if (isa<ast::OperationType>(valueExprType)) {
3027 valueExpr = convertOpToValue(valueExpr);
3033 if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3037 valueExpr->getLoc(),
3039 "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3040 singleTy, rangeTy, valueExprType));
3045 FailureOr<ast::TupleExpr *>
3048 for (
const ast::Expr *element : elements) {
3050 if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
3053 llvm::formatv(
"unable to build a tuple with `{0}` element", eleTy));
3062 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3066 if (!isa<ast::OperationType>(rootType))
3072 FailureOr<ast::ReplaceStmt *>
3073 Parser::createReplaceStmt(SMRange loc,
ast::Expr *rootOp,
3077 if (!isa<ast::OperationType>(rootType)) {
3080 llvm::formatv(
"expected `Op` expression, but got `{0}`", rootType));
3085 bool shouldConvertOpToValues = replValues.size() > 1;
3086 for (
ast::Expr *&replExpr : replValues) {
3087 ast::Type replType = replExpr->getType();
3090 if (isa<ast::OperationType>(replType)) {
3091 if (shouldConvertOpToValues)
3092 replExpr = convertOpToValue(replExpr);
3096 if (replType != valueTy && replType != valueRangeTy) {
3098 llvm::formatv(
"expected `Op`, `Value` or `ValueRange` "
3099 "expression, but got `{0}`",
3107 FailureOr<ast::RewriteStmt *>
3108 Parser::createRewriteStmt(SMRange loc,
ast::Expr *rootOp,
3112 if (!isa<ast::OperationType>(rootType)) {
3115 llvm::formatv(
"expected `Op` expression, but got `{0}`", rootType));
3125 LogicalResult Parser::codeCompleteMemberAccess(
ast::Expr *parentExpr) {
3129 else if (
ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
3135 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3142 Parser::codeCompleteConstraintName(
ast::Type inferredType,
3143 bool allowInlineTypeConstraints) {
3145 inferredType, allowInlineTypeConstraints, curDeclScope);
3149 LogicalResult Parser::codeCompleteDialectName() {
3154 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3159 LogicalResult Parser::codeCompletePatternMetadata() {
3164 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3169 void Parser::codeCompleteCallSignature(
ast::Node *parent,
3170 unsigned currentNumArgs) {
3178 void Parser::codeCompleteOperationOperandsSignature(
3179 std::optional<StringRef> opName,
unsigned currentNumOperands) {
3181 opName, currentNumOperands);
3184 void Parser::codeCompleteOperationResultsSignature(
3185 std::optional<StringRef> opName,
unsigned currentNumResults) {
3194 FailureOr<ast::Module *>
3196 bool enableDocumentation,
3198 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3199 return parser.parseModule();
static std::string diag(const llvm::Value &value)
This class breaks up the current file into a token stream.
Location getLoc()
The source location the operation was defined or derived from.
operand_range getOperands()
Returns an iterator over the underlying Values.
result_range getResults()
This represents a token in the MLIR syntax.
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
StringRef getSpelling() const
This class provides an abstract interface into the parser for hooking in code completion events.
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for an attribute name of the given operation.
virtual void codeCompleteOperationOperandsSignature(std::optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
virtual void codeCompleteOperationResultsSignature(std::optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
@ code_complete_string
Token signifying a code completion location within a string.
@ code_complete
Token signifying a code completion location.
@ less
Paired punctuation.
@ kw_Attr
General keywords.
static StringRef getMemberName()
Return the member name used for the "all-results" access.
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
static AttributeType get(Context &context)
Return an instance of the Attribute type.
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
This decl represents a shared interface for all callable decls.
Type getResultType() const
Return the result type of this decl.
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
StringRef getCallableType() const
Return the callable type of this decl.
This statement represents a compound statement, which contains a collection of other statements.
ArrayRef< Stmt * >::iterator begin() const
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
ArrayRef< Stmt * >::iterator end() const
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
This class represents the base of all AST Constraint decls.
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
This class represents the main context of the PDLL AST.
DiagnosticEngine & getDiagEngine()
Return the diagnostic engine of this context.
ods::Context & getODSContext()
Return the ODS context used by the AST.
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
This class represents a scope for named AST decls.
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
void add(Decl *decl)
Add a new decl to the scope.
This class represents the base Decl node.
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
InFlightDiagnostic emitWarning(SMRange loc, const Twine &msg)
InFlightDiagnostic emitError(SMRange loc, const Twine &msg)
Emit an error to the diagnostic engine.
This class provides a simple implementation of a PDLL diagnostic.
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
This class represents a base AST Expression node.
Type getType() const
Return the type of this expression.
This class represents a diagnostic that is inflight and set to be reported.
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
This Decl represents a NamedAttribute, and contains a string name and attribute value.
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
This class represents a base AST node.
SMRange getLoc() const
Return the location of this node.
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
This Decl represents an OperationName.
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
static OpNameDecl * create(Context &ctx, const Name &name)
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
This class represents a PDLL type that corresponds to an mlir::Operation.
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
This class represents a PDLL type that corresponds to a range of elements with a given element type.
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
This class represents a base AST Statement node.
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
This class represents a PDLL tuple type, i.e.
size_t size() const
Return the number of elements within this tuple.
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
static TypeRangeType get(Context &context)
Return an instance of the TypeRange type.
static TypeType get(Context &context)
Return an instance of the Type type.
Type refineWith(Type other) const
Try to refine this type with the one provided.
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
The class represents a Value constraint, and constrains a variable to be a Value.
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
static ValueRangeType get(Context &context)
Return an instance of the ValueRange type.
static ValueType get(Context &context)
Return an instance of the Value type.
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
This class contains all of the registered ODS operation classes.
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation with the context.
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint with the context.
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint with the context.
const Operation * lookupOperation(StringRef name) const
Lookup an operation registered with the given name, or null if no operation with that name is registered.
This class provides an ODS representation of a specific operation operand or result.
This class provides an ODS representation of a specific operation.
SMRange getLoc() const
Return the source location of this operation.
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
bool hasResultTypeInferrence() const
Return whether the operation is known to support result type inference.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
This class represents a generic ODS Type constraint.
raw_ostream subclass that simplifies indenting a sequence of code.
raw_indented_ostream & printReindented(StringRef str, StringRef extraPrefix="")
Prints a string re-indented to the current indent.
StringRef getSummary() const
std::string getUniqueDefName() const
Returns a unique name for the TableGen def of this constraint.
StringRef getDescription() const
std::string getConditionTemplate() const
Format context containing substitutions for special placeholders.
FmtContext & withSelf(Twine subst)
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen and provides helper methods for accessing them.
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::support::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class represents a reference to a constraint, and contains a constraint and the location of the reference.
const ConstraintDecl * constraint
This class provides a convenient API for interacting with source names.
StringRef getName() const
Return the raw string name.
SMRange getLoc() const
Get the location of this name.
static const Name & create(Context &ctx, StringRef name, SMRange location)