28 #include "llvm/ADT/IntervalMap.h"
29 #include "llvm/ADT/StringMap.h"
30 #include "llvm/ADT/StringSet.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/FileSystem.h"
33 #include "llvm/Support/Path.h"
43 int bufferId = mgr.FindBufferContainingLoc(loc.Start);
44 if (bufferId == 0 || bufferId ==
static_cast<int>(mgr.getMainFileID()))
47 mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
58 return mgr.FindBufferContainingLoc(loc.Start) == mgr.getMainFileID();
68 static std::optional<lsp::Diagnostic>
87 switch (
diag.getSeverity()) {
88 case ast::Diagnostic::Severity::DK_Note:
89 llvm_unreachable(
"expected notes to be handled separately");
90 case ast::Diagnostic::Severity::DK_Warning:
93 case ast::Diagnostic::Severity::DK_Error:
96 case ast::Diagnostic::Severity::DK_Remark:
103 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
105 relatedDiags.emplace_back(
107 note.getMessage().str());
109 if (!relatedDiags.empty())
116 static std::optional<std::string>
132 struct PDLIndexSymbol {
133 explicit PDLIndexSymbol(
const ast::Decl *definition)
134 : definition(definition) {}
136 : definition(definition) {}
139 SMRange getDefLoc()
const {
141 llvm::dyn_cast_if_present<const ast::Decl *>(definition)) {
143 return declName ? declName->
getLoc() : decl->getLoc();
151 std::vector<SMRange> references;
158 PDLIndex() : intervalMap(allocator) {}
166 const PDLIndexSymbol *lookup(SMLoc loc,
167 SMRange *overlappedRange =
nullptr)
const;
173 llvm::IntervalMap<
const char *,
const PDLIndexSymbol *,
174 llvm::IntervalMapImpl::NodeSizer<
175 const char *,
const PDLIndexSymbol *>::LeafSize,
176 llvm::IntervalMapHalfOpenInfo<const char *>>;
179 MapT::Allocator allocator;
190 void PDLIndex::initialize(
const ast::Module &module,
192 auto getOrInsertDef = [&](
const auto *def) -> PDLIndexSymbol * {
193 auto it = defToSymbol.try_emplace(def,
nullptr);
195 it.first->second = std::make_unique<PDLIndexSymbol>(def);
196 return &*it.first->second;
198 auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc,
199 bool isDef =
false) {
200 const char *startLoc = refLoc.Start.getPointer();
201 const char *endLoc = refLoc.End.getPointer();
202 if (!intervalMap.overlaps(startLoc, endLoc)) {
203 intervalMap.insert(startLoc, endLoc, sym);
205 sym->references.push_back(refLoc);
208 auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) {
213 PDLIndexSymbol *symbol = getOrInsertDef(odsOp);
214 insertDeclRef(symbol, odsOp->
getLoc(),
true);
215 insertDeclRef(symbol, refLoc);
220 if (
const auto *decl = dyn_cast<ast::OpNameDecl>(node)) {
221 if (std::optional<StringRef> name = decl->getName())
222 insertODSOpRef(*name, decl->getLoc());
223 }
else if (
const ast::Decl *decl = dyn_cast<ast::Decl>(node)) {
227 PDLIndexSymbol *declSym = getOrInsertDef(decl);
228 insertDeclRef(declSym, name->
getLoc(),
true);
230 if (
const auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) {
232 for (
const auto &it : varDecl->getConstraints())
233 insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc);
235 }
else if (
const auto *expr = dyn_cast<ast::DeclRefExpr>(node)) {
236 insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc());
241 const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc,
242 SMRange *overlappedRange)
const {
243 auto it = intervalMap.find(loc.getPointer());
244 if (!it.valid() || loc.getPointer() < it.start())
247 if (overlappedRange) {
248 *overlappedRange = SMRange(SMLoc::getFromPointer(it.start()),
249 SMLoc::getFromPointer(it.stop()));
263 const std::vector<std::string> &extraDirs,
264 std::vector<lsp::Diagnostic> &diagnostics);
265 PDLDocument(
const PDLDocument &) =
delete;
266 PDLDocument &operator=(
const PDLDocument &) =
delete;
273 std::vector<lsp::Location> &locations);
275 std::vector<lsp::Location> &references);
282 std::vector<lsp::DocumentLink> &links);
290 std::optional<lsp::Hover> findHover(
const ast::Decl *decl,
291 const SMRange &hoverRange);
293 const SMRange &hoverRange);
295 const SMRange &hoverRange);
297 const SMRange &hoverRange);
299 const SMRange &hoverRange);
300 template <
typename T>
301 lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName,
303 const SMRange &hoverRange);
309 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
330 std::vector<lsp::InlayHint> &inlayHints);
333 std::vector<lsp::InlayHint> &inlayHints);
335 std::vector<lsp::InlayHint> &inlayHints);
338 std::vector<lsp::InlayHint> &inlayHints);
341 void addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
355 std::vector<std::string> includeDirs;
358 llvm::SourceMgr sourceMgr;
365 FailureOr<ast::Module *> astModule;
375 PDLDocument::PDLDocument(
const lsp::URIForFile &uri, StringRef contents,
376 const std::vector<std::string> &extraDirs,
377 std::vector<lsp::Diagnostic> &diagnostics)
378 : astContext(odsContext) {
379 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.
file());
381 lsp::Logger::error(
"Failed to create memory buffer for file", uri.
file());
387 llvm::sys::path::remove_filename(uriDirectory);
388 includeDirs.push_back(uriDirectory.str().str());
389 includeDirs.insert(includeDirs.end(), extraDirs.begin(), extraDirs.end());
391 sourceMgr.setIncludeDirs(includeDirs);
392 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
396 diagnostics.push_back(std::move(*lspDiag));
404 if (failed(astModule))
408 index.initialize(**astModule, odsContext);
417 std::vector<lsp::Location> &locations) {
419 const PDLIndexSymbol *symbol = index.lookup(posLoc);
428 std::vector<lsp::Location> &references) {
430 const PDLIndexSymbol *symbol = index.lookup(posLoc);
435 for (SMRange refLoc : symbol->references)
444 std::vector<lsp::DocumentLink> &links) {
446 links.emplace_back(include.range, include.uri);
453 std::optional<lsp::Hover>
456 SMLoc posLoc = hoverPos.
getAsSMLoc(sourceMgr);
460 if (include.range.contains(hoverPos))
461 return include.buildHover();
465 const PDLIndexSymbol *symbol = index.lookup(posLoc, &hoverRange);
471 llvm::dyn_cast_if_present<const ods::Operation *>(symbol->definition))
472 return buildHoverForOpName(op, hoverRange);
473 const auto *decl = symbol->definition.get<
const ast::Decl *>();
474 return findHover(decl, hoverRange);
477 std::optional<lsp::Hover> PDLDocument::findHover(
const ast::Decl *decl,
478 const SMRange &hoverRange) {
480 if (
const auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
481 return buildHoverForVariable(varDecl, hoverRange);
484 if (
const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl))
485 return buildHoverForPattern(patternDecl, hoverRange);
488 if (
const auto *cst = dyn_cast<ast::CoreConstraintDecl>(decl))
489 return buildHoverForCoreConstraint(cst, hoverRange);
492 if (
const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl))
493 return buildHoverForUserConstraintOrRewrite(
"Constraint", cst, hoverRange);
496 if (
const auto *
rewrite = dyn_cast<ast::UserRewriteDecl>(decl))
497 return buildHoverForUserConstraintOrRewrite(
"Rewrite",
rewrite, hoverRange);
503 const SMRange &hoverRange) {
506 llvm::raw_string_ostream hoverOS(hover.contents.value);
507 hoverOS <<
"**OpName**: `" << op->
getName() <<
"`\n***\n"
508 << op->getSummary() <<
"\n***\n"
509 << op->getDescription();
515 const SMRange &hoverRange) {
518 llvm::raw_string_ostream hoverOS(hover.contents.value);
519 hoverOS <<
"**Variable**: `" << varDecl->
getName().
getName() <<
"`\n***\n"
520 <<
"Type: `" << varDecl->
getType() <<
"`\n";
526 const SMRange &hoverRange) {
529 llvm::raw_string_ostream hoverOS(hover.contents.value);
530 hoverOS <<
"**Pattern**";
531 if (
const ast::Name *name = decl->getName())
532 hoverOS <<
": `" << name->
getName() <<
"`";
533 hoverOS <<
"\n***\n";
534 if (std::optional<uint16_t> benefit = decl->
getBenefit())
535 hoverOS <<
"Benefit: " << *benefit <<
"\n";
537 hoverOS <<
"HasBoundedRewriteRecursion\n";
538 hoverOS <<
"RootOp: `"
543 hoverOS <<
"\n" << *doc <<
"\n";
550 const SMRange &hoverRange) {
553 llvm::raw_string_ostream hoverOS(hover.contents.value);
554 hoverOS <<
"**Constraint**: `";
559 if (std::optional<StringRef> name = opCst->
getName())
560 hoverOS <<
"<" << *name <<
">";
564 hoverOS <<
"TypeRange";
568 hoverOS <<
"ValueRange";
575 template <
typename T>
576 lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
577 StringRef typeName,
const T *decl,
const SMRange &hoverRange) {
580 llvm::raw_string_ostream hoverOS(hover.contents.value);
581 hoverOS <<
"**" << typeName <<
"**: `" << decl->getName().getName()
584 if (!inputs.empty()) {
585 hoverOS <<
"Parameters:\n";
587 hoverOS <<
"* " << input->getName().getName() <<
": `"
588 << input->getType() <<
"`\n";
591 ast::Type resultType = decl->getResultType();
592 if (
auto resultTupleTy = dyn_cast<ast::TupleType>(resultType)) {
593 if (!resultTupleTy.empty()) {
594 hoverOS <<
"Results:\n";
595 for (
auto it : llvm::zip(resultTupleTy.getElementNames(),
596 resultTupleTy.getElementTypes())) {
597 StringRef name = std::get<0>(it);
598 hoverOS <<
"* " << (name.empty() ?
"" : (name +
": ")) <<
"`"
599 << std::get<1>(it) <<
"`\n";
604 hoverOS <<
"Results:\n* `" << resultType <<
"`\n";
610 hoverOS <<
"\n" << *doc <<
"\n";
619 void PDLDocument::findDocumentSymbols(
620 std::vector<lsp::DocumentSymbol> &symbols) {
621 if (failed(astModule))
624 for (
const ast::Decl *decl : (*astModule)->getChildren()) {
628 if (
const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl)) {
631 SMRange nameLoc = name ? name->
getLoc() : patternDecl->getLoc();
632 SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
634 symbols.emplace_back(
635 name ? name->
getName() :
"<pattern>", lsp::SymbolKind::Class,
637 }
else if (
const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(decl)) {
639 SMRange nameLoc = cDecl->getName().getLoc();
640 SMRange bodyLoc = nameLoc;
642 symbols.emplace_back(
643 cDecl->getName().getName(), lsp::SymbolKind::Function,
645 }
else if (
const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(decl)) {
647 SMRange nameLoc = cDecl->getName().getLoc();
648 SMRange bodyLoc = nameLoc;
650 symbols.emplace_back(
651 cDecl->getName().getName(), lsp::SymbolKind::Function,
664 LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
669 completionList(completionList), odsContext(odsContext),
670 includeDirs(includeDirs) {}
675 for (
unsigned i = 0, e = tupleType.size(); i < e; ++i) {
678 item.
label = llvm::formatv(
"{0} (field #{0})", i).str();
682 item.
detail = llvm::formatv(
"{0}: {1}", i, elementTypes[i]);
684 completionList.items.emplace_back(item);
687 if (!elementNames[i].empty()) {
689 llvm::formatv(
"{1} (field #{0})", i, elementNames[i]).str();
692 completionList.items.emplace_back(item);
709 item.
label = llvm::formatv(
"{0} (field #{0})", it.index()).str();
715 item.
detail = llvm::formatv(
"{0}: Value", it.index()).str();
718 item.
detail = llvm::formatv(
"{0}: Value?", it.index()).str();
721 item.
detail = llvm::formatv(
"{0}: ValueRange", it.index()).str();
726 llvm::formatv(
"{0}\n\n```c++\n{1}\n```\n", constraint.
getSummary(),
730 completionList.items.emplace_back(item);
734 if (!result.
getName().empty()) {
736 llvm::formatv(
"{1} (field #{0})", it.index(), result.
getName())
740 completionList.items.emplace_back(item);
754 item.
label = attr.getName().str();
756 item.
detail = attr.isOptional() ?
"optional" :
"";
759 llvm::formatv(
"{0}\n\n```c++\n{1}\n```\n", constraint.
getSummary(),
763 completionList.items.emplace_back(item);
768 bool allowInlineTypeConstraints,
770 auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
771 StringRef snippetText =
"") {
773 item.
label = constraint.str();
775 item.
detail = (constraint +
" constraint").str();
778 (
"A single entity core constraint of type `" + mlirType +
"`").str()};
784 completionList.items.emplace_back(item);
791 addCoreConstraint(
"Attr",
"mlir::Attribute");
792 addCoreConstraint(
"Op",
"mlir::Operation *");
793 addCoreConstraint(
"Value",
"mlir::Value");
794 addCoreConstraint(
"ValueRange",
"mlir::ValueRange");
795 addCoreConstraint(
"Type",
"mlir::Type");
796 addCoreConstraint(
"TypeRange",
"mlir::TypeRange");
798 if (allowInlineTypeConstraints) {
800 if (!currentType || isa<ast::AttributeType>(currentType))
801 addCoreConstraint(
"Attr<type>",
"mlir::Attribute",
"Attr<$1>");
803 if (!currentType || isa<ast::ValueType>(currentType))
804 addCoreConstraint(
"Value<type>",
"mlir::Value",
"Value<$1>");
806 if (!currentType || isa<ast::ValueRangeType>(currentType))
807 addCoreConstraint(
"ValueRange<type>",
"mlir::ValueRange",
813 for (
const ast::Decl *decl : scope->getDecls()) {
814 if (
const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) {
816 item.
label = cst->getName().getName().str();
822 if (cst->getInputs().size() != 1)
826 ast::Type constraintType = cst->getInputs()[0]->getType();
827 if (currentType && !currentType.refineWith(constraintType))
832 llvm::raw_string_ostream strOS(item.
detail);
834 llvm::interleaveComma(
836 strOS << var->getName().getName() <<
": " << var->getType();
838 strOS <<
") -> " << cst->getResultType();
842 if (std::optional<std::string> doc =
848 completionList.items.emplace_back(item);
852 scope = scope->getParentScope();
860 item.
label = dialect.getName().str();
863 completionList.items.emplace_back(item);
876 item.
label = op.
getName().drop_front(dialectName.size() + 1).str();
879 completionList.items.emplace_back(item);
884 auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
885 StringRef snippetText =
"") {
887 item.
label = constraint.str();
889 item.
detail =
"pattern metadata";
896 completionList.items.emplace_back(item);
899 addSimpleConstraint(
"benefit",
"The `benefit` of matching the pattern.",
901 addSimpleConstraint(
"recursion",
902 "The pattern properly handles recursive application.");
909 llvm::sys::path::native(nativeRelDir);
915 auto addIncludeCompletion = [&](StringRef path,
bool isDirectory) {
917 item.
label = path.str();
920 if (seenResults.insert(item.
label).second)
921 completionList.items.emplace_back(item);
926 for (StringRef includeDir : includeDirs) {
928 if (!nativeRelDir.empty())
929 llvm::sys::path::append(dir, nativeRelDir);
931 std::error_code errorCode;
932 for (
auto it = llvm::sys::fs::directory_iterator(dir, errorCode),
933 e = llvm::sys::fs::directory_iterator();
934 !errorCode && it != e; it.increment(errorCode)) {
935 StringRef filename = llvm::sys::path::filename(it->path());
940 llvm::sys::fs::file_type fileType = it->type();
941 if (fileType == llvm::sys::fs::file_type::symlink_file) {
942 if (
auto fileStatus = it->status())
943 fileType = fileStatus->type();
947 case llvm::sys::fs::file_type::directory_file:
948 addIncludeCompletion(filename,
true);
950 case llvm::sys::fs::file_type::regular_file: {
952 if (filename.ends_with(
".pdll") || filename.ends_with(
".td"))
953 addIncludeCompletion(filename,
false);
966 return lhs.label < rhs.label;
971 llvm::SourceMgr &sourceMgr;
981 SMLoc posLoc = completePos.
getAsSMLoc(sourceMgr);
982 if (!posLoc.isValid())
989 LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
991 sourceMgr.getIncludeDirs());
995 &lspCompleteContext);
997 return completionList;
1007 LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
1011 signatureHelp(signatureHelp), odsContext(odsContext) {}
1014 unsigned currentNumArgs)
final {
1015 signatureHelp.activeParameter = currentNumArgs;
1019 llvm::raw_string_ostream strOS(signatureInfo.
label);
1020 strOS << callable->getName()->getName() <<
"(";
1022 unsigned paramStart = strOS.str().size();
1023 strOS << var->getName().getName() <<
": " << var->getType();
1024 unsigned paramEnd = strOS.str().size();
1026 StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1027 std::make_pair(paramStart, paramEnd), std::string()});
1029 llvm::interleaveComma(callable->getInputs(), strOS, formatParamFn);
1030 strOS <<
") -> " << callable->getResultType();
1034 if (std::optional<std::string> doc =
1038 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1043 unsigned currentNumOperands)
final {
1046 codeCompleteOperationOperandOrResultSignature(
1047 opName, odsOp, odsOp ? odsOp->
getOperands() : std::nullopt,
1048 currentNumOperands,
"operand",
"Value");
1052 unsigned currentNumResults)
final {
1055 codeCompleteOperationOperandOrResultSignature(
1056 opName, odsOp, odsOp ? odsOp->
getResults() : std::nullopt,
1057 currentNumResults,
"result",
"Type");
1060 void codeCompleteOperationOperandOrResultSignature(
1063 StringRef label, StringRef dataType) {
1064 signatureHelp.activeParameter = currentValue;
1070 if (odsOp && currentValue < values.size()) {
1075 llvm::raw_string_ostream strOS(signatureInfo.
label);
1078 unsigned paramStart = strOS.str().size();
1080 strOS << value.getName() <<
": ";
1082 StringRef constraintDoc = value.getConstraint().getSummary();
1083 std::string paramDoc;
1084 switch (value.getVariableLengthKind()) {
1087 paramDoc = constraintDoc.str();
1090 strOS << dataType <<
"?";
1091 paramDoc = (
"optional: " + constraintDoc).str();
1094 strOS << dataType <<
"Range";
1095 paramDoc = (
"variadic: " + constraintDoc).str();
1099 unsigned paramEnd = strOS.str().size();
1101 StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1102 std::make_pair(paramStart, paramEnd), paramDoc});
1104 llvm::interleaveComma(values, strOS, formatFn);
1108 llvm::formatv(
"`op<{0}>` ODS {1} specification", *opName, label)
1110 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1114 if (currentValue == 0 && (!odsOp || !values.empty())) {
1116 signatureInfo.
label =
1117 llvm::formatv(
"(<{0}s>: {1}Range)", label, dataType).str();
1119 (
"Generic operation " + label +
" specification").str();
1121 StringRef(signatureInfo.
label).drop_front().drop_back().str(),
1122 std::pair<unsigned, unsigned>(1, signatureInfo.
label.size() - 1),
1123 (
"All of the " + label +
"s of the operation.").str()});
1124 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1129 llvm::SourceMgr &sourceMgr;
1137 SMLoc posLoc = helpPos.
getAsSMLoc(sourceMgr);
1138 if (!posLoc.isValid())
1145 LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp,
1152 return signatureHelp;
1165 if (
auto *ref = dyn_cast<ast::DeclRefExpr>(expr)) {
1167 if (declName && declName->
getName() == name)
1176 std::vector<lsp::InlayHint> &inlayHints) {
1177 if (failed(astModule))
1180 if (!rangeLoc.isValid())
1182 (*astModule)->walk([&](
const ast::Node *node) {
1183 SMRange loc = node->
getLoc();
1193 [&](
const auto *node) {
1194 this->getInlayHintsFor(node, uri, inlayHints);
1201 std::vector<lsp::InlayHint> &inlayHints) {
1212 if (isa<ast::OperationExpr>(expr))
1219 llvm::raw_string_ostream labelOS(hint.label);
1220 labelOS <<
": " << decl->
getType();
1223 inlayHints.emplace_back(std::move(hint));
1226 void PDLDocument::getInlayHintsFor(
const ast::CallExpr *expr,
1228 std::vector<lsp::InlayHint> &inlayHints) {
1230 const auto *callableRef = dyn_cast<ast::DeclRefExpr>(expr->
getCallableExpr());
1231 const auto *callable =
1232 callableRef ? dyn_cast<ast::CallableDecl>(callableRef->getDecl())
1238 for (
const auto &it : llvm::zip(expr->
getArguments(), callable->getInputs()))
1239 addParameterHintFor(inlayHints, std::get<0>(it),
1240 std::get<1>(it)->getName().getName());
1245 std::vector<lsp::InlayHint> &inlayHints) {
1250 auto addOpHint = [&](
const ast::Expr *valueExpr, StringRef label) {
1253 if (expr->getLoc().Start == valueExpr->
getLoc().Start)
1255 addParameterHintFor(inlayHints, valueExpr, label);
1263 StringRef allValuesName) {
1269 if (values.size() != odsValues.size()) {
1271 if (values.size() == 1)
1272 return addOpHint(values.front(), allValuesName);
1276 for (
const auto &it : llvm::zip(values, odsValues))
1277 addOpHint(std::get<0>(it), std::get<1>(it).getName());
1291 void PDLDocument::addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
1292 const ast::Expr *expr, StringRef label) {
1298 hint.label = (label +
":").str();
1299 hint.paddingRight =
true;
1300 inlayHints.emplace_back(std::move(hint));
1307 void PDLDocument::getPDLLViewOutput(raw_ostream &os,
1309 if (failed(astModule))
1311 if (kind == lsp::PDLLViewOutputKind::AST) {
1312 (*astModule)->print(os);
1325 if (kind == lsp::PDLLViewOutputKind::MLIR) {
1331 assert(kind == lsp::PDLLViewOutputKind::CPP &&
1332 "unexpected PDLLViewOutputKind");
1342 struct PDLTextFileChunk {
1345 const std::vector<std::string> &extraDirs,
1346 std::vector<lsp::Diagnostic> &diagnostics)
1347 : lineOffset(lineOffset),
1348 document(uri, contents, extraDirs, diagnostics) {}
1352 void adjustLocForChunkOffset(
lsp::Range &range) {
1353 adjustLocForChunkOffset(range.
start);
1354 adjustLocForChunkOffset(range.
end);
1361 uint64_t lineOffset;
1363 PDLDocument document;
1376 int64_t version,
const std::vector<std::string> &extraDirs,
1377 std::vector<lsp::Diagnostic> &diagnostics);
1380 int64_t getVersion()
const {
return version; }
1386 std::vector<lsp::Diagnostic> &diagnostics);
1393 std::vector<lsp::Location> &locations);
1395 std::vector<lsp::Location> &references);
1397 std::vector<lsp::DocumentLink> &links);
1400 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1406 std::vector<lsp::InlayHint> &inlayHints);
1410 using ChunkIterator = llvm::pointee_iterator<
1411 std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>;
1415 std::vector<lsp::Diagnostic> &diagnostics);
1422 return *getChunkItFor(pos);
1426 std::string contents;
1429 int64_t version = 0;
1432 int64_t totalNumLines = 0;
1436 std::vector<std::unique_ptr<PDLTextFileChunk>> chunks;
1439 std::vector<std::string> extraIncludeDirs;
1443 PDLTextFile::PDLTextFile(
const lsp::URIForFile &uri, StringRef fileContents,
1445 const std::vector<std::string> &extraDirs,
1446 std::vector<lsp::Diagnostic> &diagnostics)
1447 : contents(fileContents.str()), extraIncludeDirs(extraDirs) {
1448 initialize(uri, version, diagnostics);
1454 std::vector<lsp::Diagnostic> &diagnostics) {
1455 if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) {
1456 lsp::Logger::error(
"Failed to update contents of {0}", uri.
file());
1461 initialize(uri, newVersion, diagnostics);
1467 std::vector<lsp::Location> &locations) {
1468 PDLTextFileChunk &chunk = getChunkFor(defPos);
1469 chunk.document.getLocationsOf(uri, defPos, locations);
1472 if (chunk.lineOffset == 0)
1476 chunk.adjustLocForChunkOffset(loc.range);
1481 std::vector<lsp::Location> &references) {
1482 PDLTextFileChunk &chunk = getChunkFor(pos);
1483 chunk.document.findReferencesOf(uri, pos, references);
1486 if (chunk.lineOffset == 0)
1490 chunk.adjustLocForChunkOffset(loc.range);
1494 std::vector<lsp::DocumentLink> &links) {
1495 chunks.front()->document.getDocumentLinks(uri, links);
1496 for (
const auto &it : llvm::drop_begin(chunks)) {
1497 size_t currentNumLinks = links.size();
1498 it->document.getDocumentLinks(uri, links);
1502 for (
auto &link : llvm::drop_begin(links, currentNumLinks))
1503 it->adjustLocForChunkOffset(link.range);
1507 std::optional<lsp::Hover> PDLTextFile::findHover(
const lsp::URIForFile &uri,
1509 PDLTextFileChunk &chunk = getChunkFor(hoverPos);
1510 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1513 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1514 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1518 void PDLTextFile::findDocumentSymbols(
1519 std::vector<lsp::DocumentSymbol> &symbols) {
1520 if (chunks.size() == 1)
1521 return chunks.front()->document.findDocumentSymbols(symbols);
1525 for (
unsigned i = 0, e = chunks.size(); i < e; ++i) {
1526 PDLTextFileChunk &chunk = *chunks[i];
1529 : chunks[i + 1]->lineOffset);
1531 lsp::SymbolKind::Namespace,
1534 chunk.document.findDocumentSymbols(symbol.children);
1540 symbolsToFix.push_back(&childSymbol);
1542 while (!symbolsToFix.empty()) {
1544 chunk.adjustLocForChunkOffset(symbol->
range);
1548 symbolsToFix.push_back(&childSymbol);
1553 symbols.emplace_back(std::move(symbol));
1559 PDLTextFileChunk &chunk = getChunkFor(completePos);
1561 chunk.document.getCodeCompletion(uri, completePos);
1566 chunk.adjustLocForChunkOffset(item.
textEdit->range);
1568 chunk.adjustLocForChunkOffset(edit.
range);
1570 return completionList;
1575 return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos);
1579 std::vector<lsp::InlayHint> &inlayHints) {
1580 auto startIt = getChunkItFor(range.
start);
1581 auto endIt = getChunkItFor(range.
end);
1584 auto getHintsForChunk = [&](ChunkIterator chunkIt,
lsp::Range range) {
1585 size_t currentNumHints = inlayHints.size();
1586 chunkIt->document.getInlayHints(uri, range, inlayHints);
1590 if (&*chunkIt != &*chunks.front()) {
1591 for (
auto &hint : llvm::drop_begin(inlayHints, currentNumHints))
1592 chunkIt->adjustLocForChunkOffset(hint.position);
1596 auto getNumLines = [](ChunkIterator chunkIt) {
1597 return (chunkIt + 1)->lineOffset - chunkIt->lineOffset;
1601 if (startIt == endIt)
1602 return getHintsForChunk(startIt, range);
1606 getHintsForChunk(startIt,
lsp::Range(range.
start, getNumLines(startIt)));
1609 for (++startIt; startIt != endIt; ++startIt)
1610 getHintsForChunk(startIt,
lsp::Range(0, getNumLines(startIt)));
1621 llvm::raw_string_ostream outputOS(result.
output);
1623 llvm::make_pointee_range(chunks),
1624 [&](PDLTextFileChunk &chunk) {
1625 chunk.document.getPDLLViewOutput(outputOS, kind);
1627 [&] { outputOS <<
"\n"
1633 void PDLTextFile::initialize(
const lsp::URIForFile &uri, int64_t newVersion,
1634 std::vector<lsp::Diagnostic> &diagnostics) {
1635 version = newVersion;
1641 chunks.emplace_back(std::make_unique<PDLTextFileChunk>(
1642 0, uri, subContents.front(), extraIncludeDirs,
1645 uint64_t lineOffset = subContents.front().count(
'\n');
1646 for (StringRef docContents : llvm::drop_begin(subContents)) {
1647 unsigned currentNumDiags = diagnostics.size();
1648 auto chunk = std::make_unique<PDLTextFileChunk>(
1649 lineOffset, uri, docContents, extraIncludeDirs, diagnostics);
1650 lineOffset += docContents.count(
'\n');
1655 llvm::drop_begin(diagnostics, currentNumDiags)) {
1656 chunk->adjustLocForChunkOffset(
diag.range);
1658 if (!
diag.relatedInformation)
1660 for (
auto &it : *
diag.relatedInformation)
1661 if (it.location.uri == uri)
1662 chunk->adjustLocForChunkOffset(it.location.range);
1664 chunks.emplace_back(std::move(chunk));
1666 totalNumLines = lineOffset;
1669 PDLTextFile::ChunkIterator PDLTextFile::getChunkItFor(
lsp::Position &pos) {
1670 if (chunks.size() == 1)
1671 return chunks.begin();
1675 auto it = llvm::upper_bound(
1676 chunks, pos, [](
const lsp::Position &pos,
const auto &chunk) {
1677 return static_cast<uint64_t
>(pos.
line) < chunk->lineOffset;
1679 ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it);
1680 pos.
line -= chunkIt->lineOffset;
1700 llvm::StringMap<std::unique_ptr<PDLTextFile>>
files;
1713 std::vector<Diagnostic> &diagnostics) {
1715 std::vector<std::string> additionalIncludeDirs =
impl->options.extraDirs;
1716 const auto &fileInfo =
impl->compilationDatabase.getFileInfo(uri.
file());
1717 llvm::append_range(additionalIncludeDirs, fileInfo.includeDirs);
1719 impl->files[uri.
file()] = std::make_unique<PDLTextFile>(
1720 uri, contents, version, additionalIncludeDirs, diagnostics);
1725 int64_t version, std::vector<Diagnostic> &diagnostics) {
1727 auto it =
impl->files.find(uri.
file());
1728 if (it ==
impl->files.end())
1733 if (failed(it->second->update(uri, version, changes, diagnostics)))
1734 impl->files.erase(it);
1738 auto it =
impl->files.find(uri.
file());
1739 if (it ==
impl->files.end())
1740 return std::nullopt;
1742 int64_t version = it->second->getVersion();
1743 impl->files.erase(it);
1749 std::vector<Location> &locations) {
1750 auto fileIt =
impl->files.find(uri.
file());
1751 if (fileIt !=
impl->files.end())
1752 fileIt->second->getLocationsOf(uri, defPos, locations);
1757 std::vector<Location> &references) {
1758 auto fileIt =
impl->files.find(uri.
file());
1759 if (fileIt !=
impl->files.end())
1760 fileIt->second->findReferencesOf(uri, pos, references);
1764 const URIForFile &uri, std::vector<DocumentLink> &documentLinks) {
1765 auto fileIt =
impl->files.find(uri.
file());
1766 if (fileIt !=
impl->files.end())
1767 return fileIt->second->getDocumentLinks(uri, documentLinks);
1772 auto fileIt =
impl->files.find(uri.
file());
1773 if (fileIt !=
impl->files.end())
1774 return fileIt->second->findHover(uri, hoverPos);
1775 return std::nullopt;
1779 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1780 auto fileIt =
impl->files.find(uri.
file());
1781 if (fileIt !=
impl->files.end())
1782 fileIt->second->findDocumentSymbols(symbols);
1788 auto fileIt =
impl->files.find(uri.
file());
1789 if (fileIt !=
impl->files.end())
1790 return fileIt->second->getCodeCompletion(uri, completePos);
1796 auto fileIt =
impl->files.find(uri.
file());
1797 if (fileIt !=
impl->files.end())
1798 return fileIt->second->getSignatureHelp(uri, helpPos);
1803 std::vector<InlayHint> &inlayHints) {
1804 auto fileIt =
impl->files.find(uri.
file());
1805 if (fileIt ==
impl->files.end())
1807 fileIt->second->getInlayHints(uri, range, inlayHints);
1810 llvm::sort(inlayHints);
1811 inlayHints.erase(llvm::unique(inlayHints), inlayHints.end());
1814 std::optional<lsp::PDLLViewOutputResult>
1817 auto fileIt =
impl->files.find(uri.
file());
1818 if (fileIt !=
impl->files.end())
1819 return fileIt->second->getPDLLViewOutput(kind);
1820 return std::nullopt;
static std::string toString(bytecode::Section::ID sectionID)
Stringify the given section ID.
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
static std::string diag(const llvm::Value &value)
static std::optional< std::string > getDocumentationFor(llvm::SourceMgr &sourceMgr, const ast::Decl *decl)
Get or extract the documentation for the given decl.
static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range, const lsp::URIForFile &uri)
Returns a language server location from the given source range.
static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc, const lsp::URIForFile &mainFileURI)
Returns a language server uri for the given source location.
static bool shouldAddHintFor(const ast::Expr *expr, StringRef name)
Returns true if the given name should be added as a hint for expr.
static std::optional< lsp::Diagnostic > getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, const lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc)
Returns true if the given location is in the main file of the source manager.
static llvm::ManagedStatic< PassManagerOptions > options
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
MLIRContext is the top-level object for a collection of MLIR operations.
Set of flags used to control the behavior of the various IR print methods (e.g.
OperationName getName()
The name of an operation is the key identifier for it.
This class is a utility diagnostic handler for use with llvm::SourceMgr.
This class contains a collection of compilation information for files provided to the language server...
static void error(const char *fmt, Ts &&...vals)
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
void getDocumentLinks(const URIForFile &uri, std::vector< DocumentLink > &documentLinks)
Return the document links referenced by the given file.
void getInlayHints(const URIForFile &uri, const Range &range, std::vector< InlayHint > &inlayHints)
Get the inlay hints for the range within the given file.
void addDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add the document, with the provided version, at the given URI.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
SignatureHelp getSignatureHelp(const URIForFile &uri, const Position &helpPos)
Get the signature help for the position within the given file.
void updateDocument(const URIForFile &uri, ArrayRef< TextDocumentContentChangeEvent > changes, int64_t version, std::vector< Diagnostic > &diagnostics)
Update the document, with the provided version, at the given URI.
std::optional< PDLLViewOutputResult > getPDLLViewOutput(const URIForFile &uri, PDLLViewOutputKind kind)
Get the output of the given PDLL file, or std::nullopt if there is no valid output.
URI in "file" scheme for a file.
static llvm::Expected< URIForFile > fromFile(StringRef absoluteFilepath, StringRef scheme="file")
Try to build a URIForFile from the given absolute file path and optional scheme.
StringRef file() const
Returns the absolute path to the file.
This class provides an abstract interface into the parser for hooking in code completion events.
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteOperationOperandsSignature(std::optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
virtual void codeCompleteOperationResultsSignature(std::optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Expr * getCallableExpr() const
Return the callable of this call.
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
This decl represents a shared interface for all callable decls.
This class represents the main context of the PDLL AST.
This class represents the base of all "core" constraints.
This class represents a scope for named AST decls.
This class represents the base Decl node.
std::optional< StringRef > getDocComment() const
Return the documentation comment attached to this decl if it has been set.
This class provides a simple implementation of a PDLL diagnostic.
This class represents a base AST Expression node.
Type getType() const
Return the type of this expression.
This class represents a top-level AST module.
This class represents a base AST node.
SMRange getLoc() const
Return the location of this node.
The class represents an Operation constraint, and constrains a variable to be an Operation.
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
This expression represents the structural form of an MLIR Operation.
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
This class represents a PDLL type that corresponds to an mlir::Operation.
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
This Decl represents a single Pattern.
const OpRewriteStmt * getRootRewriteStmt() const
Return the root rewrite statement of this pattern.
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
bool hasBoundedRewriteRecursion() const
Return if this pattern has bounded rewrite recursion.
This class represents a PDLL tuple type, i.e.
The class represents a Type constraint, and constrains a variable to be a Type.
The class represents a TypeRange constraint, and constrains a variable to be a TypeRange.
The class represents a Value constraint, and constrains a variable to be a Value.
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
This Decl represents the definition of a PDLL variable.
const Name & getName() const
Return the name of the decl.
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Type getType() const
Return the type of the decl.
This class represents a generic ODS Attribute constraint.
StringRef getCppClass() const
Return the name of the underlying c++ class of this constraint.
This class provides an ODS representation of a specific operation attribute.
StringRef getSummary() const
Return the summary of this constraint.
This class contains all of the registered ODS operation classes.
auto getDialects() const
Return a range of all of the registered dialects.
const Dialect * lookupDialect(StringRef name) const
Lookup a dialect registered with the given name, or null if no dialect with that name was inserted.
const Operation * lookupOperation(StringRef name) const
Lookup an operation registered with the given name, or null if no operation with that name is registe...
This class represents an ODS dialect, and contains information on the constructs held within the dial...
const llvm::StringMap< std::unique_ptr< Operation > > & getOperations() const
Return a map of all of the operations registered to this dialect.
This class provides an ODS representation of a specific operation operand or result.
const TypeConstraint & getConstraint() const
Return the constraint of this value.
VariableLengthKind getVariableLengthKind() const
Returns the variable length kind of this value.
StringRef getName() const
Return the name of this value.
This class provides an ODS representation of a specific operation.
ArrayRef< Attribute > getAttributes() const
Returns the attributes of this operation.
SMRange getLoc() const
Return the source location of this operation.
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
This class represents a generic ODS Type constraint.
StringRef getCppClass() const
Return the name of the underlying c++ class of this constraint.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
PDLLViewOutputKind
The type of output to view from PDLL.
void gatherIncludeFiles(llvm::SourceMgr &sourceMgr, SmallVectorImpl< SourceMgrInclude > &includes)
Given a source manager, gather all of the processed include files.
std::optional< std::string > extractSourceDocComment(llvm::SourceMgr &sourceMgr, SMLoc loc)
Extract a documentation comment for the given location within the source manager.
@ PlainText
The primary text to be inserted is treated as a plain string.
@ Snippet
The primary text to be inserted is treated as a snippet.
void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, raw_ostream &os)
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Include the generated interface declarations.
const char *const kDefaultSplitMarker
Impl(const Options &options)
lsp::CompilationDatabase compilationDatabase
The compilation database containing additional information for files passed to the server.
llvm::StringMap< std::unique_ptr< PDLTextFile > > files
The files held by the server, mapped by their URI file name.
const Options & options
PDLL LSP options.
std::string detail
A human-readable string with additional information about this item, like type or symbol information.
std::string filterText
A string that should be used when filtering a set of completion items.
std::optional< TextEdit > textEdit
An edit which is applied to a document when selecting this completion.
std::optional< MarkupContent > documentation
A human-readable string that represents a doc-comment.
std::string insertText
A string that should be inserted to a document when selecting this completion.
std::string label
The label of this completion item.
CompletionItemKind kind
The kind of this completion item.
std::vector< TextEdit > additionalTextEdits
An optional array of additional text edits that are applied when selecting this completion.
InsertTextFormat insertTextFormat
The format of the insert text.
std::string sortText
A string that should be used when comparing this item with other items.
Represents a collection of completion items to be presented in the editor.
std::vector< CompletionItem > items
The completion items.
std::string source
A human-readable string describing the source of this diagnostic, e.g.
DiagnosticSeverity severity
The diagnostic's severity.
Range range
The source range where the message applies.
std::optional< std::vector< DiagnosticRelatedInformation > > relatedInformation
An array of related diagnostic information, e.g.
std::string message
The diagnostic's message.
std::optional< std::string > category
The diagnostic's category.
Represents programming constructs like variables, classes, interfaces etc.
Range range
The range enclosing this symbol not including leading/trailing whitespace but everything else like co...
Range selectionRange
The range that should be selected and revealed when this symbol is being picked, e....
std::vector< DocumentSymbol > children
Children of this symbol, e.g. properties of a class.
URIForFile uri
The text document's URI.
Represents the result of viewing the output of a PDLL file.
std::string output
The string representation of the output.
int line
Line position in a document (zero-based).
SMLoc getAsSMLoc(llvm::SourceMgr &mgr) const
Convert this position into a source location in the main file of the given source manager.
Position end
The range's end position.
Position start
The range's start position.
SMRange getAsSMRange(llvm::SourceMgr &mgr) const
Convert this range into a source range in the main file of the given source manager.
Represents the signature of a callable.
This class represents a single include within a root file.
Range range
The range of the text document to be manipulated.
This class provides a convenient API for interacting with source names.
StringRef getName() const
Return the raw string name.
SMRange getLoc() const
Get the location of this name.