25#include "llvm/ADT/StringExtras.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
28#include "llvm/Support/SaveAndRestore.h"
29#include "llvm/Support/ScopedPrinter.h"
30#include "llvm/Support/VirtualFileSystem.h"
31#include "llvm/TableGen/Error.h"
32#include "llvm/TableGen/Parser.h"
48 : ctx(ctx), lexer(sourceMgr, ctx.
getDiagEngine(), codeCompleteContext),
49 curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
54 codeCompleteContext(codeCompleteContext) {}
57 FailureOr<ast::Module *> parseModule();
65 enum class ParserContext {
82 enum class OpResultTypeContext {
101 return (curDeclScope = newScope);
/// Install an already-existing scope as the current declaration scope (for
/// example, an argument scope built while parsing a signature). This only
/// reassigns `curDeclScope`; it does not create a scope or modify the given
/// scope's parent link.
103 void pushDeclScope(
ast::DeclScope *scope) { curDeclScope = scope; }
/// Leave the current declaration scope by making its parent the current
/// scope. `curDeclScope` is dereferenced unconditionally, so it must be
/// non-null when this is called.
106 void popDeclScope() { curDeclScope = curDeclScope->
getParentScope(); }
115 LogicalResult convertExpressionTo(
122 LogicalResult convertTupleExpressionTo(
133 const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
139 StringRef processDoc(StringRef doc) {
140 return enableDocumentation ? doc : StringRef();
145 std::string processAndFormatDoc(
const Twine &doc) {
146 if (!enableDocumentation)
150 llvm::raw_string_ostream docOS(docStr);
151 std::string tmpDocStr = doc.str();
152 raw_indented_ostream(docOS).printReindented(
153 StringRef(tmpDocStr).rtrim(
" \t"));
/// Parse a top-level directive and append any declarations it produces to
/// `decls`. Only `#include` is recognized; any other directive is reported
/// as an "unknown directive" error.
161 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
/// Parse an `#include` directive whose operand must be a string filename.
/// `.pdll` files are pushed onto the lexer and parsed as part of the current
/// module body; `.td` files are handed to `parseTdInclude`; any other
/// extension is an error.
162 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
/// Open the TableGen file `filename` (referenced at `fileLoc`) via the
/// parser's source manager, parse it with the TableGen parser, and convert
/// its records into PDLL declarations appended to `decls`. Diagnostics from
/// TableGen are re-emitted as parser errors of the form
/// "error while processing include file `<file>`: <message>".
163 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
164 SmallVectorImpl<ast::Decl *> &decls);
/// Convert the records of a parsed TableGen file into PDLL state.
///
/// `Op` definitions are used to populate ODS operation descriptions, including
/// their attributes, operands, and results. `Attr`, `Type`, and `OpInterface`
/// definitions are turned into native constraint declarations and appended to
/// `decls`. When creating those constraints, the following records are
/// skipped:
///   - anonymous records,
///   - records whose name is already defined in the current scope, and
///   - `DeclareInterfaceMethods` records.
167 void processTdIncludeRecords(
const llvm::RecordKeeper &tdRecords,
168 SmallVectorImpl<ast::Decl *> &decls);
172 template <
typename Constra
intT>
174 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
175 SMRange loc, ast::Type type,
176 StringRef nativeType, StringRef docString);
177 template <
typename Constra
intT>
179 createODSNativePDLLConstraintDecl(
const tblgen::Constraint &constraint,
180 SMRange loc, ast::Type type,
181 StringRef nativeType);
187 struct ParsedPatternMetadata {
188 std::optional<uint16_t> benefit;
189 bool hasBoundedRecursion =
false;
192 FailureOr<ast::Decl *> parseTopLevelDecl();
193 FailureOr<ast::NamedAttributeDecl *>
194 parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
198 FailureOr<ast::VariableDecl *> parseArgumentDecl();
202 FailureOr<ast::VariableDecl *> parseResultDecl(
unsigned resultNum);
206 FailureOr<ast::UserConstraintDecl *>
207 parseUserConstraintDecl(
bool isInline =
false);
211 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
215 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
216 const ast::Name &name,
bool isInline,
217 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
218 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
222 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(
bool isInline =
false);
226 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
230 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
231 const ast::Name &name,
bool isInline,
232 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
233 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
238 template <
typename T,
typename ParseUserPDLLDeclFnT>
239 FailureOr<T *> parseUserConstraintOrRewriteDecl(
240 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
241 StringRef anonymousNamePrefix,
bool isInline);
245 template <
typename T>
246 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
247 const ast::Name &name,
bool isInline,
248 ArrayRef<ast::VariableDecl *> arguments,
249 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
253 LogicalResult parseUserConstraintOrRewriteSignature(
254 SmallVectorImpl<ast::VariableDecl *> &arguments,
255 SmallVectorImpl<ast::VariableDecl *> &results,
256 ast::DeclScope *&argumentScope, ast::Type &resultType);
260 LogicalResult validateUserConstraintOrRewriteReturn(
261 StringRef declType, ast::CompoundStmt *body,
262 ArrayRef<ast::Stmt *>::iterator bodyIt,
263 ArrayRef<ast::Stmt *>::iterator bodyE,
264 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
266 FailureOr<ast::CompoundStmt *>
267 parseLambdaBody(
function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
268 bool expectTerminalSemicolon =
true);
269 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
270 FailureOr<ast::Decl *> parsePatternDecl();
271 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
/// Verify that `name` is not already defined in the current declaration
/// scope. On a conflict, emits an error at `name` with a note pointing at the
/// previous definition and returns failure. Asserts that a current scope
/// exists.
275 LogicalResult checkDefineNamedDecl(
const ast::Name &name);
279 FailureOr<ast::VariableDecl *>
280 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
282 ArrayRef<ast::ConstraintRef> constraints);
283 FailureOr<ast::VariableDecl *>
284 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
285 ArrayRef<ast::ConstraintRef> constraints);
288 LogicalResult parseVariableDeclConstraintList(
289 SmallVectorImpl<ast::ConstraintRef> &constraints);
292 FailureOr<ast::Expr *> parseTypeConstraintExpr();
300 FailureOr<ast::ConstraintRef>
301 parseConstraint(std::optional<SMRange> &typeConstraint,
302 ArrayRef<ast::ConstraintRef> existingConstraints,
303 bool allowInlineTypeConstraints);
308 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
313 FailureOr<ast::Expr *> parseExpr();
316 FailureOr<ast::Expr *> parseAttributeExpr();
317 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
318 bool isNegated =
false);
319 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
320 FailureOr<ast::Expr *> parseIdentifierExpr();
321 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
322 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
323 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
324 FailureOr<ast::Expr *> parseNegatedExpr();
325 FailureOr<ast::OpNameDecl *> parseOperationName(
bool allowEmptyName =
false);
326 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(
bool allowEmptyName);
327 FailureOr<ast::Expr *>
328 parseOperationExpr(OpResultTypeContext inputResultTypeContext =
329 OpResultTypeContext::Explicit);
330 FailureOr<ast::Expr *> parseTupleExpr();
331 FailureOr<ast::Expr *> parseTypeExpr();
332 FailureOr<ast::Expr *> parseUnderscoreExpr();
337 FailureOr<ast::Stmt *> parseStmt(
bool expectTerminalSemicolon =
true);
338 FailureOr<ast::CompoundStmt *> parseCompoundStmt();
339 FailureOr<ast::EraseStmt *> parseEraseStmt();
340 FailureOr<ast::LetStmt *> parseLetStmt();
341 FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
342 FailureOr<ast::ReturnStmt *> parseReturnStmt();
343 FailureOr<ast::RewriteStmt *> parseRewriteStmt();
354 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
358 FailureOr<ast::PatternDecl *>
359 createPatternDecl(SMRange loc,
const ast::Name *name,
360 const ParsedPatternMetadata &metadata,
361 ast::CompoundStmt *body);
366 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
369 template <
typename T>
370 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
371 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
372 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
373 ast::CompoundStmt *body);
377 FailureOr<ast::VariableDecl *>
378 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
379 ArrayRef<ast::ConstraintRef> constraints);
383 FailureOr<ast::VariableDecl *>
384 createArgOrResultVariableDecl(StringRef name, SMRange loc,
385 const ast::ConstraintRef &constraint);
393 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
394 ast::Type &inferredType);
399 LogicalResult validateVariableConstraint(
const ast::ConstraintRef &ref,
400 ast::Type &inferredType);
401 LogicalResult validateTypeConstraintExpr(
const ast::Expr *typeExpr);
402 LogicalResult validateTypeRangeConstraintExpr(
const ast::Expr *typeExpr);
407 FailureOr<ast::CallExpr *>
408 createCallExpr(SMRange loc, ast::Expr *parentExpr,
409 MutableArrayRef<ast::Expr *> arguments,
410 bool isNegated =
false);
411 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
412 FailureOr<ast::DeclRefExpr *>
413 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
414 ArrayRef<ast::ConstraintRef> constraints);
415 FailureOr<ast::MemberAccessExpr *>
416 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
420 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
421 StringRef name, SMRange loc);
422 FailureOr<ast::OperationExpr *>
423 createOperationExpr(SMRange loc,
const ast::OpNameDecl *name,
424 OpResultTypeContext resultTypeContext,
425 SmallVectorImpl<ast::Expr *> &operands,
426 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
427 SmallVectorImpl<ast::Expr *> &results);
429 validateOperationOperands(SMRange loc, std::optional<StringRef> name,
430 const ods::Operation *odsOp,
431 SmallVectorImpl<ast::Expr *> &operands);
432 LogicalResult validateOperationResults(SMRange loc,
433 std::optional<StringRef> name,
434 const ods::Operation *odsOp,
435 SmallVectorImpl<ast::Expr *> &results);
436 void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
437 const ods::Operation *odsOp);
438 LogicalResult validateOperationOperandsOrResults(
439 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
440 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
441 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
442 ast::RangeType rangeTy);
443 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
444 ArrayRef<ast::Expr *> elements,
445 ArrayRef<StringRef> elementNames);
450 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
451 FailureOr<ast::ReplaceStmt *>
452 createReplaceStmt(SMRange loc, ast::Expr *rootOp,
453 MutableArrayRef<ast::Expr *> replValues);
454 FailureOr<ast::RewriteStmt *>
455 createRewriteStmt(SMRange loc, ast::Expr *rootOp,
456 ast::CompoundStmt *rewriteBody);
466 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
467 LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
468 LogicalResult codeCompleteConstraintName(ast::Type inferredType,
469 bool allowInlineTypeConstraints);
470 LogicalResult codeCompleteDialectName();
471 LogicalResult codeCompleteOperationName(StringRef dialectName);
472 LogicalResult codeCompletePatternMetadata();
473 LogicalResult codeCompleteIncludeFilename(StringRef curPath);
475 void codeCompleteCallSignature(ast::Node *parent,
unsigned currentNumArgs);
476 void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
477 unsigned currentNumOperands);
478 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
479 unsigned currentNumResults);
488 if (curToken.isNot(kind))
495 void consumeToken() {
497 "shouldn't advance past EOF or errors");
498 curToken = lexer.lexToken();
505 assert(curToken.is(kind) &&
"consumed an unexpected token");
510 void resetToken(SMRange tokLoc) {
511 lexer.resetPointer(tokLoc.Start.getPointer());
512 curToken = lexer.lexToken();
517 LogicalResult parseToken(
Token::Kind kind,
const Twine &msg) {
518 if (curToken.getKind() != kind)
519 return emitError(curToken.getLoc(), msg);
523 LogicalResult
emitError(SMRange loc,
const Twine &msg) {
524 lexer.emitError(loc, msg);
527 LogicalResult
emitError(
const Twine &msg) {
528 return emitError(curToken.getLoc(), msg);
530 LogicalResult emitErrorAndNote(SMRange loc,
const Twine &msg, SMRange noteLoc,
532 lexer.emitErrorAndNote(loc, msg, noteLoc, note);
551 bool enableDocumentation;
554 ast::DeclScope *curDeclScope =
nullptr;
555 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
558 ParserContext parserContext = ParserContext::Global;
561 ast::Type typeTy, valueTy;
562 ast::RangeType typeRangeTy, valueRangeTy;
566 unsigned anonymousDeclNameCounter = 0;
569 CodeCompleteContext *codeCompleteContext;
573FailureOr<ast::Module *> Parser::parseModule() {
574 SMLoc moduleLoc = curToken.getStartLoc();
578 SmallVector<ast::Decl *> decls;
579 if (
failed(parseModuleBody(decls)))
580 return popDeclScope(), failure();
586LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
589 if (
failed(parseDirective(decls)))
594 FailureOr<ast::Decl *> decl = parseTopLevelDecl();
597 decls.push_back(*decl);
602ast::Expr *Parser::convertOpToValue(
const ast::Expr *opExpr) {
607LogicalResult Parser::convertExpressionTo(
608 ast::Expr *&expr, ast::Type type,
610 ast::Type exprType = expr->
getType();
611 if (exprType == type)
614 auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
616 expr->
getLoc(), llvm::formatv(
"unable to convert expression of type "
617 "`{0}` to the expected type of "
625 if (
auto exprOpType = dyn_cast<ast::OperationType>(exprType))
626 return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
632 if ((exprType == valueTy || exprType == valueRangeTy) &&
633 (type == valueTy || type == valueRangeTy))
635 if ((exprType == typeTy || exprType == typeRangeTy) &&
636 (type == typeTy || type == typeRangeTy))
640 if (
auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
641 return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
644 return emitConvertError();
647LogicalResult Parser::convertOpExpressionTo(
648 ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
652 if (
auto opType = dyn_cast<ast::OperationType>(type)) {
653 if (opType.getName())
654 return emitErrorFn();
659 if (type == valueRangeTy) {
666 if (type == valueTy) {
670 if (odsOp->getResults().empty()) {
671 return emitErrorFn()->attachNote(
672 llvm::formatv(
"see the definition of `{0}`, which was defined "
678 unsigned numSingleResults = llvm::count_if(
679 odsOp->getResults(), [](
const ods::OperandOrResult &
result) {
680 return result.getVariableLengthKind() ==
681 ods::VariableLengthKind::Single;
683 if (numSingleResults > 1) {
684 return emitErrorFn()->attachNote(
685 llvm::formatv(
"see the definition of `{0}`, which was defined "
686 "with at least {1} results",
687 odsOp->getName(), numSingleResults),
696 return emitErrorFn();
699LogicalResult Parser::convertTupleExpressionTo(
700 ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
704 if (
auto tupleType = dyn_cast<ast::TupleType>(type)) {
705 if (tupleType.size() != exprType.
size())
706 return emitErrorFn();
710 SmallVector<ast::Expr *> newExprs;
711 for (
unsigned i = 0, e = exprType.
size(); i < e; ++i) {
713 ctx, expr->
getLoc(), expr, llvm::to_string(i),
716 auto diagFn = [&](ast::Diagnostic &
diag) {
717 diag.attachNote(llvm::formatv(
"when converting element #{0} of `{1}`",
722 if (
failed(convertExpressionTo(newExprs.back(),
723 tupleType.getElementTypes()[i], diagFn)))
727 tupleType.getElementNames());
732 auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
733 ast::RangeType resultTy) -> LogicalResult {
735 if (parserContext != ParserContext::Rewrite) {
736 return emitErrorFn()->attachNote(
"Tuple to Range conversion is currently "
737 "only allowed within a rewrite context");
742 if (!llvm::is_contained(allowedElementTypes, elementType))
743 return emitErrorFn();
747 SmallVector<ast::Expr *> newExprs;
748 for (
unsigned i = 0, e = exprType.
size(); i < e; ++i) {
750 ctx, expr->
getLoc(), expr, llvm::to_string(i),
756 if (type == valueRangeTy)
757 return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
758 if (type == typeRangeTy)
759 return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
761 return emitErrorFn();
768LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
770 if (directive ==
"#include")
771 return parseInclude(decls);
773 return emitError(
"unknown directive `" + directive +
"`");
776LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
777 SMRange loc = curToken.
getLoc();
785 if (!curToken.isString())
787 "expected string file name after `include` directive");
788 SMRange fileLoc = curToken.
getLoc();
790 StringRef filename = filenameStr;
795 if (filename.ends_with(
".pdll")) {
796 if (
failed(lexer.pushInclude(filename, fileLoc)))
798 "unable to open include file `" + filename +
"`");
804 LogicalResult
result = parseModuleBody(decls);
810 if (filename.ends_with(
".td"))
811 return parseTdInclude(filename, fileLoc, decls);
814 "expected include filename to end with `.pdll` or `.td`");
817LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
818 SmallVectorImpl<ast::Decl *> &decls) {
822 std::string includedFile;
823 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
824 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
826 return emitError(fileLoc,
"unable to open include file `" + filename +
"`");
829 llvm::SourceMgr tdSrcMgr;
830 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
831 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
832 tdSrcMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
836 struct DiagHandlerContext {
840 } handlerContext{*
this, filename, fileLoc};
843 tdSrcMgr.setDiagHandler(
844 [](
const llvm::SMDiagnostic &
diag,
void *rawHandlerContext) {
845 auto *ctx =
reinterpret_cast<DiagHandlerContext *
>(rawHandlerContext);
846 (void)ctx->parser.emitError(
848 llvm::formatv(
"error while processing include file `{0}`: {1}",
849 ctx->filename,
diag.getMessage()));
854 llvm::RecordKeeper tdRecords;
855 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
859 processTdIncludeRecords(tdRecords, decls);
864 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
868void Parser::processTdIncludeRecords(
const llvm::RecordKeeper &tdRecords,
869 SmallVectorImpl<ast::Decl *> &decls) {
871 auto getLengthKind = [](
const auto &value) {
872 if (value.isOptional())
873 return ods::VariableLengthKind::Optional;
874 return value.isVariadic() ? ods::VariableLengthKind::Variadic
875 : ods::VariableLengthKind::Single;
880 auto addTypeConstraint = [&](
const tblgen::NamedTypeConstraint &cst)
881 ->
const ods::TypeConstraint & {
883 cst.constraint.getUniqueDefName(),
884 processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
886 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
887 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
892 for (
const llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Op")) {
893 tblgen::Operator op(def);
896 bool supportsResultTypeInferrence =
897 op.getTrait(
"::mlir::InferTypeOpInterface::Trait");
900 op.getOperationName(), processDoc(op.getSummary()),
901 processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
902 supportsResultTypeInferrence, op.getLoc().front());
908 for (
const tblgen::NamedAttribute &attr : op.getAttributes()) {
909 odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
911 attr.attr.getUniqueDefName(),
912 processDoc(attr.attr.getSummary()),
913 attr.attr.getStorageType()));
915 for (
const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
916 odsOp->appendOperand(operand.name, getLengthKind(operand),
917 addTypeConstraint(operand));
919 for (
const tblgen::NamedTypeConstraint &
result : op.getResults()) {
921 addTypeConstraint(
result));
925 auto shouldBeSkipped = [
this](
const llvm::Record *def) {
926 return def->isAnonymous() || curDeclScope->
lookup(def->getName()) ||
927 def->isSubClassOf(
"DeclareInterfaceMethods");
931 for (
const llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Attr")) {
932 if (shouldBeSkipped(def))
935 tblgen::Attribute constraint(def);
936 decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
937 constraint, convertLocToRange(def->getLoc().front()), attrTy,
938 constraint.getStorageType()));
941 for (
const llvm::Record *def : tdRecords.getAllDerivedDefinitions(
"Type")) {
942 if (shouldBeSkipped(def))
945 tblgen::TypeConstraint constraint(def);
946 decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
947 constraint, convertLocToRange(def->getLoc().front()), typeTy,
948 constraint.getCppType()));
952 for (
const llvm::Record *def :
953 tdRecords.getAllDerivedDefinitions(
"OpInterface")) {
954 if (shouldBeSkipped(def))
957 SMRange loc = convertLocToRange(def->getLoc().front());
959 std::string cppClassName =
960 llvm::formatv(
"{0}::{1}", def->getValueAsString(
"cppNamespace"),
961 def->getValueAsString(
"cppInterfaceName"))
963 std::string codeBlock =
964 llvm::formatv(
"return ::mlir::success(llvm::isa<{0}>(self));",
969 processAndFormatDoc(def->getValueAsString(
"description"));
970 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
971 def->getName(), codeBlock, loc, opTy, cppClassName, desc));
975template <
typename Constra
intT>
976ast::Decl *Parser::createODSNativePDLLConstraintDecl(
977 StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
978 StringRef nativeType, StringRef docString) {
980 ast::DeclScope *argScope = pushDeclScope();
983 nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
984 argScope->
add(paramVar);
992 curDeclScope->
add(constraintDecl);
993 return constraintDecl;
996template <
typename Constra
intT>
998Parser::createODSNativePDLLConstraintDecl(
const tblgen::Constraint &constraint,
999 SMRange loc, ast::Type type,
1000 StringRef nativeType) {
1002 tblgen::FmtContext fmtContext;
1011 std::string docString;
1012 if (enableDocumentation) {
1014 docString = processAndFormatDoc(
1019 return createODSNativePDLLConstraintDecl<ConstraintT>(
1028FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1029 FailureOr<ast::Decl *> decl;
1032 decl = parseUserConstraintDecl();
1035 decl = parsePatternDecl();
1038 decl = parseUserRewriteDecl();
1041 return emitError(
"expected top-level declaration, such as a `Pattern`");
1047 if (
const ast::Name *name = (*decl)->getName()) {
1048 if (
failed(checkDefineNamedDecl(*name)))
1050 curDeclScope->
add(*decl);
1055FailureOr<ast::NamedAttributeDecl *>
1056Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1059 return codeCompleteAttributeName(parentOpName);
1061 std::string attrNameStr;
1062 if (curToken.isString())
1067 return emitError(
"expected identifier or string attribute name");
1072 ast::Expr *attrValue =
nullptr;
1074 FailureOr<ast::Expr *> attrExpr = parseExpr();
1077 attrValue = *attrExpr;
1087FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1088 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1089 bool expectTerminalSemicolon) {
1093 SMLoc bodyStartLoc = curToken.getStartLoc();
1095 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1096 bool failedToParse =
1097 failed(singleStatement) ||
failed(processStatementFn(*singleStatement));
1102 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1106FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1109 return emitError(
"expected identifier argument name");
1113 SMRange nameLoc = curToken.
getLoc();
1117 parseToken(
Token::colon,
"expected `:` before argument constraint")))
1120 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1124 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1127FailureOr<ast::VariableDecl *> Parser::parseResultDecl(
unsigned resultNum) {
1136 SMRange nameLoc = curToken.
getLoc();
1140 "expected `:` before result constraint")))
1143 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1147 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1153 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1157 return createArgOrResultVariableDecl(
"", cst->referenceLoc, *cst);
1160FailureOr<ast::UserConstraintDecl *>
1161Parser::parseUserConstraintDecl(
bool isInline) {
1164 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1165 [&](
auto &&...args) {
1166 return this->parseUserPDLLConstraintDecl(args...);
1168 ParserContext::Constraint,
"constraint", isInline);
1171FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1172 FailureOr<ast::UserConstraintDecl *> decl =
1173 parseUserConstraintDecl(
true);
1174 if (
failed(decl) ||
failed(checkDefineNamedDecl((*decl)->getName())))
1177 curDeclScope->
add(*decl);
1181FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1182 const ast::Name &name,
bool isInline,
1183 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1184 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1187 pushDeclScope(argumentScope);
1191 ast::CompoundStmt *body;
1193 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1194 [&](ast::Stmt *&stmt) -> LogicalResult {
1195 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1198 "expected `Constraint` lambda body to contain a "
1199 "single expression");
1209 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1215 auto bodyIt = body->
begin(), bodyE = body->
end();
1216 for (; bodyIt != bodyE; ++bodyIt)
1217 if (isa<ast::ReturnStmt>(*bodyIt))
1219 if (
failed(validateUserConstraintOrRewriteReturn(
1220 "Constraint", body, bodyIt, bodyE, results, resultType)))
1225 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1226 name, arguments, results, resultType, body);
1229FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(
bool isInline) {
1232 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1233 [&](
auto &&...args) {
return this->parseUserPDLLRewriteDecl(args...); },
1234 ParserContext::Rewrite,
"rewrite", isInline);
1237FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1238 FailureOr<ast::UserRewriteDecl *> decl =
1239 parseUserRewriteDecl(
true);
1240 if (
failed(decl) ||
failed(checkDefineNamedDecl((*decl)->getName())))
1243 curDeclScope->
add(*decl);
1247FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1248 const ast::Name &name,
bool isInline,
1249 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1250 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1253 curDeclScope = argumentScope;
1254 ast::CompoundStmt *body;
1256 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1257 [&](ast::Stmt *&statement) -> LogicalResult {
1258 if (isa<ast::OpRewriteStmt>(statement))
1261 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1262 if (!statementExpr) {
1265 "expected `Rewrite` lambda body to contain a single expression "
1266 "or an operation rewrite statement; such as `erase`, "
1267 "`replace`, or `rewrite`");
1278 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1286 auto bodyIt = body->
begin(), bodyE = body->
end();
1287 for (; bodyIt != bodyE; ++bodyIt)
1288 if (isa<ast::ReturnStmt>(*bodyIt))
1290 if (
failed(validateUserConstraintOrRewriteReturn(
"Rewrite", body, bodyIt,
1291 bodyE, results, resultType)))
1293 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1294 name, arguments, results, resultType, body);
1297template <
typename T,
typename ParseUserPDLLDeclFnT>
1298FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1299 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1300 StringRef anonymousNamePrefix,
bool isInline) {
1301 SMRange loc = curToken.
getLoc();
1303 llvm::SaveAndRestore saveCtx(parserContext, declContext);
1306 const ast::Name *name =
nullptr;
1311 return emitError(
"expected identifier name");
1315 std::string anonName =
1316 llvm::formatv(
"<anonymous_{0}_{1}>", anonymousNamePrefix,
1317 anonymousDeclNameCounter++)
1327 SmallVector<ast::VariableDecl *> arguments, results;
1328 ast::DeclScope *argumentScope;
1329 ast::Type resultType;
1330 if (
failed(parseUserConstraintOrRewriteSignature(arguments, results,
1331 argumentScope, resultType)))
1337 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1341 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1342 results, resultType);
1345template <
typename T>
1346FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1347 const ast::Name &name,
bool isInline,
1348 ArrayRef<ast::VariableDecl *> arguments,
1349 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1351 std::string codeStrStorage;
1352 std::optional<StringRef> optCodeStr;
1353 if (curToken.isString()) {
1355 optCodeStr = codeStrStorage;
1357 }
else if (isInline) {
1359 "external declarations must be declared in global scope");
1364 "expected `;` after native declaration")))
1366 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1369LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1370 SmallVectorImpl<ast::VariableDecl *> &arguments,
1371 SmallVectorImpl<ast::VariableDecl *> &results,
1372 ast::DeclScope *&argumentScope, ast::Type &resultType) {
1377 argumentScope = pushDeclScope();
1380 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1383 arguments.emplace_back(*argument);
1393 auto parseResultFn = [&]() -> LogicalResult {
1394 FailureOr<ast::VariableDecl *>
result = parseResultDecl(results.size());
1397 results.emplace_back(*
result);
1404 if (
failed(parseResultFn()))
1411 }
else if (
failed(parseResultFn())) {
1418 resultType = createUserConstraintRewriteResultType(results);
1421 if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1423 results.front()->getLoc(),
1424 "cannot create a single-element tuple with an element label");
1429LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1430 StringRef declType, ast::CompoundStmt *body,
1431 ArrayRef<ast::Stmt *>::iterator bodyIt,
1432 ArrayRef<ast::Stmt *>::iterator bodyE,
1433 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1435 if (bodyIt != bodyE) {
1437 if (std::next(bodyIt) != bodyE) {
1439 (*std::next(bodyIt))->getLoc(),
1440 llvm::formatv(
"`return` terminated the `{0}` body, but found "
1441 "trailing statements afterwards",
1447 }
else if (!results.empty()) {
1450 llvm::formatv(
"missing return in a `{0}` expected to return `{1}`",
1451 declType, resultType));
1456FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1457 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1458 if (isa<ast::OpRewriteStmt>(statement))
1462 "expected Pattern lambda body to contain a single operation "
1463 "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1467FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1468 SMRange loc = curToken.
getLoc();
1470 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1473 const ast::Name *name =
nullptr;
1480 ParsedPatternMetadata metadata;
1485 ast::CompoundStmt *body;
1489 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1495 return emitError(
"expected `{` or `=>` to start pattern body");
1496 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1502 auto bodyIt = body->
begin(), bodyE = body->
end();
1503 for (; bodyIt != bodyE; ++bodyIt) {
1504 if (isa<ast::ReturnStmt>(*bodyIt)) {
1506 "`return` statements are only permitted within a "
1507 "`Constraint` or `Rewrite` body");
1510 if (isa<ast::OpRewriteStmt>(*bodyIt))
1513 if (bodyIt == bodyE) {
1515 "expected Pattern body to terminate with an operation "
1516 "rewrite statement, such as `erase`");
1518 if (std::next(bodyIt) != bodyE) {
1519 return emitError((*std::next(bodyIt))->getLoc(),
1520 "Pattern body was terminated by an operation "
1521 "rewrite statement, but found trailing statements");
1525 return createPatternDecl(loc, name, metadata, body);
1529Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1530 std::optional<SMRange> benefitLoc;
1531 std::optional<SMRange> hasBoundedRecursionLoc;
1536 return codeCompletePatternMetadata();
1539 return emitError(
"expected pattern metadata identifier");
1541 SMRange metadataLoc = curToken.
getLoc();
1545 if (metadataStr ==
"benefit") {
1547 return emitErrorAndNote(metadataLoc,
1548 "pattern benefit has already been specified",
1549 *benefitLoc,
"see previous definition here");
1552 "expected `(` before pattern benefit")))
1555 uint16_t benefitValue = 0;
1557 return emitError(
"expected integral pattern benefit");
1558 if (curToken.
getSpelling().getAsInteger(10, benefitValue))
1560 "expected pattern benefit to fit within a 16-bit integer");
1563 metadata.benefit = benefitValue;
1564 benefitLoc = metadataLoc;
1567 parseToken(
Token::r_paren,
"expected `)` after pattern benefit")))
1573 if (metadataStr ==
"recursion") {
1574 if (hasBoundedRecursionLoc) {
1575 return emitErrorAndNote(
1577 "pattern recursion metadata has already been specified",
1578 *hasBoundedRecursionLoc,
"see previous definition here");
1580 metadata.hasBoundedRecursion =
true;
1581 hasBoundedRecursionLoc = metadataLoc;
1585 return emitError(metadataLoc,
"unknown pattern metadata");
1591FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1594 FailureOr<ast::Expr *> typeExpr = parseExpr();
1597 "expected `>` after variable type constraint")))
1602LogicalResult Parser::checkDefineNamedDecl(
const ast::Name &name) {
1603 assert(curDeclScope &&
"defining decl outside of a decl scope");
1604 if (ast::Decl *lastDecl = curDeclScope->
lookup(name.
getName())) {
1605 return emitErrorAndNote(
1606 name.
getLoc(),
"`" + name.
getName() +
"` has already been defined",
1607 lastDecl->getName()->getLoc(),
"see previous definition here");
1612FailureOr<ast::VariableDecl *>
1613Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1614 ast::Expr *initExpr,
1615 ArrayRef<ast::ConstraintRef> constraints) {
1616 assert(curDeclScope &&
"defining variable outside of decl scope");
1621 if (name.empty() || name ==
"_") {
1625 if (
failed(checkDefineNamedDecl(nameDecl)))
1630 curDeclScope->
add(varDecl);
1634FailureOr<ast::VariableDecl *>
1635Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1636 ArrayRef<ast::ConstraintRef> constraints) {
1637 return defineVariableDecl(name, nameLoc, type,
nullptr,
1641LogicalResult Parser::parseVariableDeclConstraintList(
1642 SmallVectorImpl<ast::ConstraintRef> &constraints) {
1643 std::optional<SMRange> typeConstraint;
1644 auto parseSingleConstraint = [&] {
1645 FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1646 typeConstraint, constraints,
true);
1649 constraints.push_back(*constraint);
1655 return parseSingleConstraint();
1658 if (
failed(parseSingleConstraint()))
1661 return parseToken(
Token::r_square,
"expected `]` after constraint list");
1664FailureOr<ast::ConstraintRef>
1665Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1666 ArrayRef<ast::ConstraintRef> existingConstraints,
1667 bool allowInlineTypeConstraints) {
1668 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1669 if (!allowInlineTypeConstraints) {
1672 "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1673 "permitted on arguments or results");
1676 return emitErrorAndNote(
1678 "the type of this variable has already been constrained",
1679 *typeConstraint,
"see previous constraint location here");
1680 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1681 if (
failed(constraintExpr))
1683 typeExpr = *constraintExpr;
1684 typeConstraint = typeExpr->getLoc();
1688 SMRange loc = curToken.
getLoc();
1694 ast::Expr *typeExpr =
nullptr;
1697 return ast::ConstraintRef(
1705 FailureOr<ast::OpNameDecl *> opName =
1706 parseWrappedOperationName(
true);
1724 ast::Expr *typeExpr =
nullptr;
1728 return ast::ConstraintRef(
1735 ast::Expr *typeExpr =
nullptr;
1739 return ast::ConstraintRef(
1745 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1748 return ast::ConstraintRef(*decl, loc);
1751 StringRef constraintName = curToken.
getSpelling();
1755 ast::Decl *cstDecl = curDeclScope->
lookup<ast::Decl>(constraintName);
1757 return emitError(loc,
"unknown reference to constraint `" +
1758 constraintName +
"`");
1762 if (
auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1763 return ast::ConstraintRef(cst, loc);
1765 return emitErrorAndNote(
1766 loc,
"invalid reference to non-constraint", cstDecl->
getLoc(),
1767 "see the definition of `" + constraintName +
"` here");
1772 ast::Type inferredType;
1773 if (
failed(validateVariableConstraints(existingConstraints, inferredType)))
1776 return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1781 return emitError(loc,
"expected identifier constraint");
1784FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1785 std::optional<SMRange> typeConstraint;
1786 return parseConstraint(typeConstraint, {},
1794FailureOr<ast::Expr *> Parser::parseExpr() {
1796 return parseUnderscoreExpr();
1799 FailureOr<ast::Expr *> lhsExpr;
1802 lhsExpr = parseAttributeExpr();
1805 lhsExpr = parseInlineConstraintLambdaExpr();
1808 lhsExpr = parseNegatedExpr();
1811 lhsExpr = parseIdentifierExpr();
1814 lhsExpr = parseOperationExpr();
1817 lhsExpr = parseInlineRewriteLambdaExpr();
1820 lhsExpr = parseTypeExpr();
1823 lhsExpr = parseTupleExpr();
1826 return emitError(
"expected expression");
1835 lhsExpr = parseMemberAccessExpr(*lhsExpr);
1838 lhsExpr = parseCallExpr(*lhsExpr);
1848FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1849 SMRange loc = curToken.
getLoc();
1856 return parseIdentifierExpr();
1859 if (!curToken.isString())
1860 return emitError(
"expected string literal containing MLIR attribute");
1866 parseToken(
Token::greater,
"expected `>` after attribute literal")))
1871FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1876 SmallVector<ast::Expr *> arguments;
1881 codeCompleteCallSignature(parentExpr, arguments.size());
1885 FailureOr<ast::Expr *> argument = parseExpr();
1888 arguments.push_back(*argument);
1896 return createCallExpr(loc, parentExpr, arguments, isNegated);
1899FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1900 ast::Decl *decl = curDeclScope->
lookup(name);
1902 return emitError(loc,
"undefined reference to `" + name +
"`");
1904 return createDeclRefExpr(loc, decl);
1907FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1909 SMRange nameLoc = curToken.
getLoc();
1915 SmallVector<ast::ConstraintRef> constraints;
1916 if (
failed(parseVariableDeclConstraintList(constraints)))
1919 if (
failed(validateVariableConstraints(constraints, type)))
1921 return createInlineVariableExpr(type, name, nameLoc, constraints);
1924 return parseDeclRefExpr(name, nameLoc);
1927FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1928 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1936FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1937 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1945FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1946 SMRange dotLoc = curToken.
getLoc();
1951 return codeCompleteMemberAccess(parentExpr);
1954 Token memberNameTok = curToken;
1957 return emitError(dotLoc,
"expected identifier or numeric member name");
1958 StringRef memberName = memberNameTok.
getSpelling();
1962 return createMemberAccessExpr(parentExpr, memberName, loc);
1965FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1969 return emitError(
"expected native constraint");
1970 FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1971 if (
failed(identifierExpr))
1974 return emitError(
"expected `(` after function name");
1975 return parseCallExpr(*identifierExpr,
true);
1978FailureOr<ast::OpNameDecl *> Parser::parseOperationName(
bool allowEmptyName) {
1979 SMRange loc = curToken.
getLoc();
1983 return codeCompleteDialectName();
1989 return emitError(
"expected dialect namespace");
1995 if (
failed(parseToken(
Token::dot,
"expected `.` after dialect namespace")))
2000 return codeCompleteOperationName(name);
2003 return emitError(
"expected operation name after dialect namespace");
2005 name = StringRef(name.data(), name.size() + 1);
2007 name = StringRef(name.data(), name.size() + curToken.
getSpelling().size());
2015FailureOr<ast::OpNameDecl *>
2016Parser::parseWrappedOperationName(
bool allowEmptyName) {
2020 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2029FailureOr<ast::Expr *>
2030Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2031 SMRange loc = curToken.
getLoc();
2038 return parseIdentifierExpr();
2044 bool allowEmptyName = parserContext != ParserContext::Rewrite;
2045 FailureOr<ast::OpNameDecl *> opNameDecl =
2046 parseWrappedOperationName(allowEmptyName);
2049 std::optional<StringRef> opName = (*opNameDecl)->getName();
2053 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
2054 FailureOr<ast::VariableDecl *> rangeVar =
2055 defineVariableDecl(
"_", loc, type, ast::ConstraintRef(cst, loc));
2056 assert(succeeded(rangeVar) &&
"expected range variable to be valid");
2061 SmallVector<ast::Expr *> operands;
2067 if (parserContext != ParserContext::Rewrite) {
2068 operands.push_back(createImplicitRangeVar(
2076 codeCompleteOperationOperandsSignature(opName, operands.size());
2080 FailureOr<ast::Expr *> operand = parseExpr();
2083 operands.push_back(*operand);
2087 "expected `)` after operation operand list")))
2092 SmallVector<ast::NamedAttributeDecl *> attributes;
2095 FailureOr<ast::NamedAttributeDecl *> decl =
2096 parseNamedAttributeDecl(opName);
2099 attributes.emplace_back(*decl);
2103 "expected `}` after operation attribute list")))
2108 SmallVector<ast::Expr *> resultTypes;
2109 OpResultTypeContext resultTypeContext = inputResultTypeContext;
2114 "expected `(` before operation result type list")))
2122 resultTypeContext = OpResultTypeContext::Explicit;
2129 codeCompleteOperationResultsSignature(opName, resultTypes.size());
2133 FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2134 if (
failed(resultTypeExpr))
2136 resultTypes.push_back(*resultTypeExpr);
2140 "expected `)` after operation result type list")))
2143 }
else if (parserContext != ParserContext::Rewrite) {
2148 resultTypes.push_back(createImplicitRangeVar(
2150 }
else if (resultTypeContext == OpResultTypeContext::Explicit) {
2153 resultTypeContext = OpResultTypeContext::Interface;
2156 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2157 attributes, resultTypes);
2160FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2161 SMRange loc = curToken.
getLoc();
2165 SmallVector<StringRef> elementNames;
2166 SmallVector<ast::Expr *> elements;
2170 StringRef elementName;
2172 Token elementNameTok = curToken;
2180 auto elementNameIt =
2181 usedNames.try_emplace(elementName, elementNameTok.
getLoc());
2182 if (!elementNameIt.second) {
2183 return emitErrorAndNote(
2185 llvm::formatv(
"duplicate tuple element label `{0}`",
2187 elementNameIt.first->getSecond(),
2188 "see previous label use here");
2193 resetToken(elementNameTok.
getLoc());
2196 elementNames.push_back(elementName);
2199 FailureOr<ast::Expr *> element = parseExpr();
2202 elements.push_back(*element);
2207 parseToken(
Token::r_paren,
"expected `)` after tuple element list")))
2209 return createTupleExpr(loc, elements, elementNames);
2212FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2213 SMRange loc = curToken.
getLoc();
2220 return parseIdentifierExpr();
2223 if (!curToken.isString())
2224 return emitError(
"expected string literal containing MLIR type");
2234FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2236 SMRange nameLoc = curToken.
getLoc();
2244 SmallVector<ast::ConstraintRef> constraints;
2245 if (
failed(parseVariableDeclConstraintList(constraints)))
2249 if (
failed(validateVariableConstraints(constraints, type)))
2251 return createInlineVariableExpr(type, name, nameLoc, constraints);
2258FailureOr<ast::Stmt *> Parser::parseStmt(
bool expectTerminalSemicolon) {
2259 FailureOr<ast::Stmt *> stmt;
2262 stmt = parseEraseStmt();
2265 stmt = parseLetStmt();
2268 stmt = parseReplaceStmt();
2271 stmt = parseReturnStmt();
2274 stmt = parseRewriteStmt();
2281 (expectTerminalSemicolon &&
2287FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2288 SMLoc startLoc = curToken.getStartLoc();
2293 SmallVector<ast::Stmt *> statements;
2295 FailureOr<ast::Stmt *> statement = parseStmt();
2297 return popDeclScope(), failure();
2298 statements.push_back(*statement);
2303 SMRange location(startLoc, curToken.
getEndLoc());
2309FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2310 if (parserContext == ParserContext::Constraint)
2311 return emitError(
"`erase` cannot be used within a Constraint");
2312 SMRange loc = curToken.
getLoc();
2316 FailureOr<ast::Expr *> rootOp = parseExpr();
2320 return createEraseStmt(loc, *rootOp);
2323FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2324 SMRange loc = curToken.
getLoc();
2328 SMRange varLoc = curToken.
getLoc();
2333 "`_` may only be used to define \"inline\" variables");
2336 "expected identifier after `let` to name a new variable");
2342 SmallVector<ast::ConstraintRef> constraints;
2344 failed(parseVariableDeclConstraintList(constraints)))
2348 ast::Expr *initializer =
nullptr;
2350 FailureOr<ast::Expr *> initOrFailure = parseExpr();
2351 if (
failed(initOrFailure))
2353 initializer = *initOrFailure;
2357 for (ast::ConstraintRef constraint : constraints) {
2360 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
2361 ast::ValueRangeConstraintDecl>([&](
const auto *cst) {
2362 if (cst->getTypeExpr()) {
2364 constraint.referenceLoc,
2365 "type constraints are not permitted on variables with "
2376 FailureOr<ast::VariableDecl *> varDecl =
2377 createVariableDecl(varName, varLoc, initializer, constraints);
2383FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2384 if (parserContext == ParserContext::Constraint)
2385 return emitError(
"`replace` cannot be used within a Constraint");
2386 SMRange loc = curToken.
getLoc();
2390 FailureOr<ast::Expr *> rootOp = parseExpr();
2395 parseToken(
Token::kw_with,
"expected `with` after root operation")))
2399 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2402 SmallVector<ast::Expr *> replValues;
2406 loc,
"expected at least one replacement value, consider using "
2407 "`erase` if no replacement values are desired");
2411 FailureOr<ast::Expr *> replExpr = parseExpr();
2414 replValues.emplace_back(*replExpr);
2418 "expected `)` after replacement values")))
2423 FailureOr<ast::Expr *> replExpr;
2425 replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2427 replExpr = parseExpr();
2430 replValues.emplace_back(*replExpr);
2433 return createReplaceStmt(loc, *rootOp, replValues);
2436FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2437 SMRange loc = curToken.
getLoc();
2441 FailureOr<ast::Expr *> resultExpr = parseExpr();
2448FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2449 if (parserContext == ParserContext::Constraint)
2450 return emitError(
"`rewrite` cannot be used within a Constraint");
2451 SMRange loc = curToken.
getLoc();
2455 FailureOr<ast::Expr *> rootOp = parseExpr();
2463 return emitError(
"expected `{` to start rewrite body");
2466 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2468 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2473 for (
const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2474 if (isa<ast::ReturnStmt>(stmt)) {
2476 "`return` statements are only permitted within a "
2477 "`Constraint` or `Rewrite` body");
2481 return createRewriteStmt(loc, *rootOp, *rewriteBody);
2492ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2494 if (
auto *init = dyn_cast<ast::DeclRefExpr>(node))
2495 node = init->getDecl();
2496 return dyn_cast<ast::CallableDecl>(node);
2499FailureOr<ast::PatternDecl *>
2500Parser::createPatternDecl(SMRange loc,
const ast::Name *name,
2501 const ParsedPatternMetadata &metadata,
2502 ast::CompoundStmt *body) {
2504 metadata.hasBoundedRecursion, body);
2507ast::Type Parser::createUserConstraintRewriteResultType(
2508 ArrayRef<ast::VariableDecl *> results) {
2510 if (results.size() == 1)
2511 return results[0]->getType();
2515 auto resultTypes = llvm::map_range(
2516 results, [&](
const auto *
result) {
return result->getType(); });
2517 auto resultNames = llvm::map_range(
2518 results, [&](
const auto *
result) {
return result->getName().getName(); });
2520 llvm::to_vector(resultNames));
2523template <
typename T>
2524FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2525 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2526 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2527 ast::CompoundStmt *body) {
2529 if (
auto *retStmt = dyn_cast<ast::ReturnStmt>(body->
getChildren().back())) {
2530 ast::Expr *resultExpr = retStmt->getResultExpr();
2535 if (results.empty())
2536 resultType = resultExpr->
getType();
2537 else if (
failed(convertExpressionTo(resultExpr, resultType)))
2540 retStmt->setResultExpr(resultExpr);
2543 return T::createPDLL(ctx, name, arguments, results, body, resultType);
2546FailureOr<ast::VariableDecl *>
2547Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2548 ArrayRef<ast::ConstraintRef> constraints) {
2552 if (
failed(validateVariableConstraints(constraints, type)))
2559 type = initializer->
getType();
2562 else if (
failed(convertExpressionTo(initializer, type)))
2568 return emitErrorAndNote(
2569 loc,
"unable to infer type for variable `" + name +
"`", loc,
2570 "the type of a variable must be inferable from the constraint "
2571 "list or the initializer");
2575 if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
2577 loc, llvm::formatv(
"unable to define variable of `{0}` type", type));
2581 FailureOr<ast::VariableDecl *> varDecl =
2582 defineVariableDecl(name, loc, type, initializer, constraints);
2589FailureOr<ast::VariableDecl *>
2590Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2591 const ast::ConstraintRef &constraint) {
2593 if (
failed(validateVariableConstraint(constraint, argType)))
2595 return defineVariableDecl(name, loc, argType, constraint);
2599Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2600 ast::Type &inferredType) {
2601 for (
const ast::ConstraintRef &ref : constraints)
2602 if (
failed(validateVariableConstraint(ref, inferredType)))
2607LogicalResult Parser::validateVariableConstraint(
const ast::ConstraintRef &ref,
2608 ast::Type &inferredType) {
2609 ast::Type constraintType;
2610 if (
const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.
constraint)) {
2611 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2612 if (
failed(validateTypeConstraintExpr(typeExpr)))
2616 }
else if (
const auto *cst =
2617 dyn_cast<ast::OpConstraintDecl>(ref.
constraint)) {
2620 }
else if (isa<ast::TypeConstraintDecl>(ref.
constraint)) {
2621 constraintType = typeTy;
2622 }
else if (isa<ast::TypeRangeConstraintDecl>(ref.
constraint)) {
2623 constraintType = typeRangeTy;
2624 }
else if (
const auto *cst =
2625 dyn_cast<ast::ValueConstraintDecl>(ref.
constraint)) {
2626 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2627 if (
failed(validateTypeConstraintExpr(typeExpr)))
2630 constraintType = valueTy;
2631 }
else if (
const auto *cst =
2632 dyn_cast<ast::ValueRangeConstraintDecl>(ref.
constraint)) {
2633 if (
const ast::Expr *typeExpr = cst->getTypeExpr()) {
2634 if (
failed(validateTypeRangeConstraintExpr(typeExpr)))
2637 constraintType = valueRangeTy;
2638 }
else if (
const auto *cst =
2639 dyn_cast<ast::UserConstraintDecl>(ref.
constraint)) {
2640 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2641 if (inputs.size() != 1) {
2643 "`Constraint`s applied via a variable constraint "
2644 "list must take a single input, but got " +
2645 Twine(inputs.size()),
2647 "see definition of constraint here");
2649 constraintType = inputs.front()->getType();
2651 llvm_unreachable(
"unknown constraint type");
2656 if (!inferredType) {
2657 inferredType = constraintType;
2658 }
else if (ast::Type mergedTy = inferredType.
refineWith(constraintType)) {
2659 inferredType = mergedTy;
2662 llvm::formatv(
"constraint type `{0}` is incompatible "
2663 "with the previously inferred type `{1}`",
2664 constraintType, inferredType));
2669LogicalResult Parser::validateTypeConstraintExpr(
const ast::Expr *typeExpr) {
2670 ast::Type typeExprType = typeExpr->
getType();
2671 if (typeExprType != typeTy) {
2673 "expected expression of `Type` in type constraint");
2679Parser::validateTypeRangeConstraintExpr(
const ast::Expr *typeExpr) {
2680 ast::Type typeExprType = typeExpr->
getType();
2681 if (typeExprType != typeRangeTy) {
2683 "expected expression of `TypeRange` in type constraint");
2692FailureOr<ast::CallExpr *>
2693Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2694 MutableArrayRef<ast::Expr *> arguments,
bool isNegated) {
2695 ast::Type parentType = parentExpr->
getType();
2697 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2698 if (!callableDecl) {
2700 llvm::formatv(
"expected a reference to a callable "
2701 "`Constraint` or `Rewrite`, but got: `{0}`",
2704 if (parserContext == ParserContext::Rewrite) {
2705 if (isa<ast::UserConstraintDecl>(callableDecl))
2707 loc,
"unable to invoke `Constraint` within a rewrite section");
2709 return emitError(loc,
"unable to negate a Rewrite");
2711 if (isa<ast::UserRewriteDecl>(callableDecl))
2713 "unable to invoke `Rewrite` within a match section");
2714 if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2715 return emitError(loc,
"unable to negate non native constraints");
2720 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->
getInputs();
2721 if (callArgs.size() != arguments.size()) {
2722 return emitErrorAndNote(
2724 llvm::formatv(
"invalid number of arguments for {0} call; expected "
2729 llvm::formatv(
"see the definition of {0} here",
2734 auto attachDiagFn = [&](ast::Diagnostic &
diag) {
2735 diag.attachNote(llvm::formatv(
"see the definition of `{0}` here",
2739 for (
auto it : llvm::zip(callArgs, arguments)) {
2740 if (
failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->
getType(),
2749FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2753 if (isa<ast::ConstraintDecl>(decl))
2755 else if (isa<ast::UserRewriteDecl>(decl))
2757 else if (
auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2758 declType = varDecl->getType();
2760 return emitError(loc,
"invalid reference to `" +
2766FailureOr<ast::DeclRefExpr *>
2767Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2768 ArrayRef<ast::ConstraintRef> constraints) {
2769 FailureOr<ast::VariableDecl *> decl =
2770 defineVariableDecl(name, loc, type, constraints);
2776FailureOr<ast::MemberAccessExpr *>
2777Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2780 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2787FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2788 StringRef name, SMRange loc) {
2789 ast::Type parentType = parentExpr->
getType();
2790 if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
2792 return valueRangeTy;
2795 if (
const ods::Operation *odsOp = opType.getODSOperation()) {
2796 auto results = odsOp->getResults();
2800 if (llvm::isDigit(name[0]) && !name.getAsInteger(10, index) &&
2801 index < results.size()) {
2802 return results[index].isVariadic() ? valueRangeTy : valueTy;
2806 const auto *it = llvm::find_if(results, [&](
const auto &
result) {
2807 return result.getName() == name;
2809 if (it != results.end())
2810 return it->isVariadic() ? valueRangeTy : valueTy;
2811 }
else if (llvm::isDigit(name[0])) {
2816 }
else if (
auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
2819 if (llvm::isDigit(name[0]) && !name.getAsInteger(10, index) &&
2820 index < tupleType.size()) {
2821 return tupleType.getElementTypes()[index];
2825 auto elementNames = tupleType.getElementNames();
2826 const auto *it = llvm::find(elementNames, name);
2827 if (it != elementNames.end())
2828 return tupleType.getElementTypes()[it - elementNames.begin()];
2832 llvm::formatv(
"invalid member access `{0}` on expression of type `{1}`",
2836FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2837 SMRange loc,
const ast::OpNameDecl *name,
2838 OpResultTypeContext resultTypeContext,
2839 SmallVectorImpl<ast::Expr *> &operands,
2840 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2841 SmallVectorImpl<ast::Expr *> &results) {
2842 std::optional<StringRef> opNameRef = name->
getName();
2843 const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2846 if (
failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2850 for (ast::NamedAttributeDecl *attr : attributes) {
2852 ast::Type attrType = attr->getValue()->getType();
2853 if (!isa<ast::AttributeType>(attrType)) {
2855 attr->getValue()->getLoc(),
2856 llvm::formatv(
"expected `Attr` expression, but got `{0}`", attrType));
2861 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2862 "unexpected inferrence when results were explicitly specified");
2866 if (resultTypeContext == OpResultTypeContext::Explicit) {
2867 if (
failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2871 }
else if (resultTypeContext == OpResultTypeContext::Interface) {
2873 "expected valid operation name when inferring operation results");
2874 checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2882Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2883 const ods::Operation *odsOp,
2884 SmallVectorImpl<ast::Expr *> &operands) {
2885 return validateOperationOperandsOrResults(
2886 "operand", loc, odsOp ? odsOp->
getLoc() : std::optional<SMRange>(), name,
2888 odsOp ? odsOp->
getOperands() : ArrayRef<pdll::ods::OperandOrResult>(),
2889 valueTy, valueRangeTy);
2893Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2894 const ods::Operation *odsOp,
2895 SmallVectorImpl<ast::Expr *> &results) {
2896 return validateOperationOperandsOrResults(
2897 "result", loc, odsOp ? odsOp->
getLoc() : std::optional<SMRange>(), name,
2899 odsOp ? odsOp->
getResults() : ArrayRef<pdll::ods::OperandOrResult>(),
2900 typeTy, typeRangeTy);
2903void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2904 const ods::Operation *odsOp) {
2914 "operation result types are marked to be inferred, but "
2915 "`{0}` is unknown. Ensure that `{0}` supports zero "
2916 "results or implements `InferTypeOpInterface`. Include "
2917 "the ODS definition of this operation to remove this warning.",
2926 bool requiresInferrence =
2928 return !result.isVariableLength();
2933 llvm::formatv(
"operation result types are marked to be inferred, but "
2934 "`{0}` does not provide an implementation of "
2935 "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2936 "`InferTypeOpInterface` at runtime, or add support to "
2937 "the ODS definition to remove this warning.",
2939 diag->
attachNote(llvm::formatv(
"see the definition of `{0}` here", opName),
2945LogicalResult Parser::validateOperationOperandsOrResults(
2946 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2947 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2948 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2949 ast::RangeType rangeTy) {
2951 if (values.size() == 1) {
2952 if (
failed(convertExpressionTo(values[0], rangeTy)))
2960 auto emitSizeMismatchError = [&] {
2961 return emitErrorAndNote(
2963 llvm::formatv(
"invalid number of {0} groups for `{1}`; expected "
2965 groupName, *name, odsValues.size(), values.size()),
2966 *odsOpLoc, llvm::formatv(
"see the definition of `{0}` here", *name));
2970 if (values.empty()) {
2972 if (odsValues.empty())
2977 unsigned numVariadic = 0;
2978 for (
const auto &odsValue : odsValues) {
2979 if (!odsValue.isVariableLength())
2980 return emitSizeMismatchError();
2986 if (parserContext != ParserContext::Rewrite)
2993 if (numVariadic == 1)
2998 for (
unsigned i = 0, e = odsValues.size(); i < e; ++i) {
3007 if (odsValues.size() != values.size())
3008 return emitSizeMismatchError();
3010 auto diagFn = [&](ast::Diagnostic &
diag) {
3011 diag.attachNote(llvm::formatv(
"see the definition of `{0}` here", *name),
3014 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
3015 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3016 if (
failed(convertExpressionTo(values[i], expectedType, diagFn)))
3024 for (ast::Expr *&valueExpr : values) {
3025 ast::Type valueExprType = valueExpr->getType();
3028 if (valueExprType == rangeTy || valueExprType == singleTy)
3034 if (singleTy == valueTy) {
3035 if (isa<ast::OperationType>(valueExprType)) {
3036 valueExpr = convertOpToValue(valueExpr);
3042 if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3046 valueExpr->getLoc(),
3048 "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3049 singleTy, rangeTy, valueExprType));
3054FailureOr<ast::TupleExpr *>
3055Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3056 ArrayRef<StringRef> elementNames) {
3057 for (
const ast::Expr *element : elements) {
3058 ast::Type eleTy = element->getType();
3059 if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
3062 llvm::formatv(
"unable to build a tuple with `{0}` element", eleTy));
3072FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3073 ast::Expr *rootOp) {
3075 ast::Type rootType = rootOp->
getType();
3076 if (!isa<ast::OperationType>(rootType))
3082FailureOr<ast::ReplaceStmt *>
3083Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3084 MutableArrayRef<ast::Expr *> replValues) {
3086 ast::Type rootType = rootOp->
getType();
3087 if (!isa<ast::OperationType>(rootType)) {
3090 llvm::formatv(
"expected `Op` expression, but got `{0}`", rootType));
3095 bool shouldConvertOpToValues = replValues.size() > 1;
3096 for (ast::Expr *&replExpr : replValues) {
3097 ast::Type replType = replExpr->getType();
3100 if (isa<ast::OperationType>(replType)) {
3101 if (shouldConvertOpToValues)
3102 replExpr = convertOpToValue(replExpr);
3106 if (replType != valueTy && replType != valueRangeTy) {
3108 llvm::formatv(
"expected `Op`, `Value` or `ValueRange` "
3109 "expression, but got `{0}`",
3117FailureOr<ast::RewriteStmt *>
3118Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3119 ast::CompoundStmt *rewriteBody) {
3121 ast::Type rootType = rootOp->
getType();
3122 if (!isa<ast::OperationType>(rootType)) {
3125 llvm::formatv(
"expected `Op` expression, but got `{0}`", rootType));
3135LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3136 ast::Type parentType = parentExpr->
getType();
3137 if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType))
3139 else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
3145Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3152Parser::codeCompleteConstraintName(ast::Type inferredType,
3153 bool allowInlineTypeConstraints) {
3155 inferredType, allowInlineTypeConstraints, curDeclScope);
3159LogicalResult Parser::codeCompleteDialectName() {
3164LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3169LogicalResult Parser::codeCompletePatternMetadata() {
3174LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3179void Parser::codeCompleteCallSignature(ast::Node *parent,
3180 unsigned currentNumArgs) {
3181 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3188void Parser::codeCompleteOperationOperandsSignature(
3189 std::optional<StringRef> opName,
unsigned currentNumOperands) {
3191 opName, currentNumOperands);
3194void Parser::codeCompleteOperationResultsSignature(
3195 std::optional<StringRef> opName,
unsigned currentNumResults) {
3204FailureOr<ast::Module *>
3206 bool enableDocumentation,
3208 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3209 return parser.parseModule();
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::string diag(const llvm::Value &value)
const llvm::SourceMgr & getSourceMgr()
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
std::string getStringValue() const
Given a token containing a string literal, return its value, including removing the quote characters ...
bool isAny(Kind k1, Kind k2) const
StringRef getSpelling() const
This class provides an abstract interface into the parser for hooking in code completion events.
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteOperationOperandsSignature(std::optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
virtual void codeCompleteOperationResultsSignature(std::optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
@ code_complete_string
Token signifying a code completion location within a string.
@ code_complete
Token signifying a code completion location.
@ less
Paired punctuation.
@ kw_Attr
General keywords.
static StringRef getMemberName()
Return the member name used for the "all-results" access.
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
static AttributeType get(Context &context)
Return an instance of the Attribute type.
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
Type getResultType() const
Return the result type of this decl.
StringRef getCallableType() const
Return the callable type of this decl.
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
ArrayRef< Stmt * >::iterator end() const
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
ArrayRef< Stmt * >::iterator begin() const
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
This class represents the main context of the PDLL AST.
DiagnosticEngine & getDiagEngine()
Return the diagnostic engine of this context.
ods::Context & getODSContext()
Return the ODS context used by the AST.
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
This class represents a scope for named AST decls.
Decl * lookup(StringRef name)
Lookup a decl with the given name starting from this scope.
void add(Decl *decl)
Add a new decl to the scope.
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
void setDocComment(Context &ctx, StringRef comment)
Set the documentation comment for this decl.
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
InFlightDiagnostic emitWarning(SMRange loc, const Twine &msg)
InFlightDiagnostic emitError(SMRange loc, const Twine &msg)
Emit an error to the diagnostic engine.
This class provides a simple implementation of a PDLL diagnostic.
Diagnostic & attachNote(const Twine &msg, std::optional< SMRange > noteLoc=std::nullopt)
Attach a note to this diagnostic.
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
This class represents a base AST Expression node.
Type getType() const
Return the type of this expression.
This class represents a diagnostic that is inflight and set to be reported.
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
SMRange getLoc() const
Return the location of this node.
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
static OpNameDecl * create(Context &ctx, const Name &name)
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
This class represents a PDLL type that corresponds to an mlir::Operation.
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
This class represents a PDLL tuple type, i.e.
size_t size() const
Return the number of elements within this tuple.
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
static TypeRangeType get(Context &context)
Return an instance of the TypeRange type.
static TypeType get(Context &context)
Return an instance of the Type type.
Type refineWith(Type other) const
Try to refine this type with the one provided.
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
static ValueRangeType get(Context &context)
Return an instance of the ValueRange type.
static ValueType get(Context &context)
Return an instance of the Value type.
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation with the context.
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint with the context.
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint with the context.
const Operation * lookupOperation(StringRef name) const
Lookup an operation registered with the given name, or null if no operation with that name is registe...
This class provides an ODS representation of a specific operation.
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
SMRange getLoc() const
Return the source location of this operation.
bool hasResultTypeInferrence() const
Return if the operation is known to support result type inferrence.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
StringRef getSummary() const
std::string getUniqueDefName() const
Returns a unique name for the TablGen def of this constraint.
StringRef getDescription() const
std::string getConditionTemplate() const
FmtContext & withSelf(Twine subst)
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::support::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
const ConstraintDecl * constraint
StringRef getName() const
Return the raw string name.
SMRange getLoc() const
Get the location of this name.
static const Name & create(Context &ctx, StringRef name, SMRange location)