28 #include "llvm/ADT/IntervalMap.h"
29 #include "llvm/ADT/StringMap.h"
30 #include "llvm/ADT/StringSet.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/FileSystem.h"
33 #include "llvm/Support/Path.h"
43 int bufferId = mgr.FindBufferContainingLoc(loc.Start);
44 if (bufferId == 0 || bufferId ==
static_cast<int>(mgr.getMainFileID()))
47 mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
58 return mgr.FindBufferContainingLoc(loc.Start) == mgr.getMainFileID();
68 static std::optional<lsp::Diagnostic>
87 switch (
diag.getSeverity()) {
88 case ast::Diagnostic::Severity::DK_Note:
89 llvm_unreachable(
"expected notes to be handled separately");
90 case ast::Diagnostic::Severity::DK_Warning:
93 case ast::Diagnostic::Severity::DK_Error:
96 case ast::Diagnostic::Severity::DK_Remark:
103 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
105 relatedDiags.emplace_back(
107 note.getMessage().str());
109 if (!relatedDiags.empty())
116 static std::optional<std::string>
132 struct PDLIndexSymbol {
133 explicit PDLIndexSymbol(
const ast::Decl *definition)
134 : definition(definition) {}
136 : definition(definition) {}
139 SMRange getDefLoc()
const {
141 llvm::dyn_cast_if_present<const ast::Decl *>(definition)) {
143 return declName ? declName->
getLoc() : decl->getLoc();
145 return cast<const ods::Operation *>(definition)->getLoc();
151 std::vector<SMRange> references;
158 PDLIndex() : intervalMap(allocator) {}
166 const PDLIndexSymbol *lookup(SMLoc loc,
167 SMRange *overlappedRange =
nullptr)
const;
173 llvm::IntervalMap<
const char *,
const PDLIndexSymbol *,
174 llvm::IntervalMapImpl::NodeSizer<
175 const char *,
const PDLIndexSymbol *>::LeafSize,
176 llvm::IntervalMapHalfOpenInfo<const char *>>;
179 MapT::Allocator allocator;
190 void PDLIndex::initialize(
const ast::Module &module,
192 auto getOrInsertDef = [&](
const auto *def) -> PDLIndexSymbol * {
193 auto it = defToSymbol.try_emplace(def,
nullptr);
195 it.first->second = std::make_unique<PDLIndexSymbol>(def);
196 return &*it.first->second;
198 auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc,
199 bool isDef =
false) {
200 const char *startLoc = refLoc.Start.getPointer();
201 const char *endLoc = refLoc.End.getPointer();
202 if (!intervalMap.overlaps(startLoc, endLoc)) {
203 intervalMap.insert(startLoc, endLoc, sym);
205 sym->references.push_back(refLoc);
208 auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) {
213 PDLIndexSymbol *symbol = getOrInsertDef(odsOp);
214 insertDeclRef(symbol, odsOp->
getLoc(),
true);
215 insertDeclRef(symbol, refLoc);
220 if (
const auto *decl = dyn_cast<ast::OpNameDecl>(node)) {
221 if (std::optional<StringRef> name = decl->getName())
222 insertODSOpRef(*name, decl->getLoc());
223 }
else if (
const ast::Decl *decl = dyn_cast<ast::Decl>(node)) {
227 PDLIndexSymbol *declSym = getOrInsertDef(decl);
228 insertDeclRef(declSym, name->
getLoc(),
true);
230 if (
const auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) {
232 for (
const auto &it : varDecl->getConstraints())
233 insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc);
235 }
else if (
const auto *expr = dyn_cast<ast::DeclRefExpr>(node)) {
236 insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc());
241 const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc,
242 SMRange *overlappedRange)
const {
243 auto it = intervalMap.find(loc.getPointer());
244 if (!it.valid() || loc.getPointer() < it.start())
247 if (overlappedRange) {
248 *overlappedRange = SMRange(SMLoc::getFromPointer(it.start()),
249 SMLoc::getFromPointer(it.stop()));
263 const std::vector<std::string> &extraDirs,
264 std::vector<lsp::Diagnostic> &diagnostics);
265 PDLDocument(
const PDLDocument &) =
delete;
266 PDLDocument &operator=(
const PDLDocument &) =
delete;
273 std::vector<lsp::Location> &locations);
275 std::vector<lsp::Location> &references);
282 std::vector<lsp::DocumentLink> &links);
290 std::optional<lsp::Hover> findHover(
const ast::Decl *decl,
291 const SMRange &hoverRange);
293 const SMRange &hoverRange);
295 const SMRange &hoverRange);
297 const SMRange &hoverRange);
299 const SMRange &hoverRange);
300 template <
typename T>
301 lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName,
303 const SMRange &hoverRange);
309 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
330 std::vector<lsp::InlayHint> &inlayHints);
333 std::vector<lsp::InlayHint> &inlayHints);
335 std::vector<lsp::InlayHint> &inlayHints);
338 std::vector<lsp::InlayHint> &inlayHints);
341 void addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
355 std::vector<std::string> includeDirs;
358 llvm::SourceMgr sourceMgr;
365 FailureOr<ast::Module *> astModule;
375 PDLDocument::PDLDocument(
const lsp::URIForFile &uri, StringRef contents,
376 const std::vector<std::string> &extraDirs,
377 std::vector<lsp::Diagnostic> &diagnostics)
378 : astContext(odsContext) {
379 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.
file());
381 lsp::Logger::error(
"Failed to create memory buffer for file", uri.
file());
387 llvm::sys::path::remove_filename(uriDirectory);
388 includeDirs.push_back(uriDirectory.str().str());
389 llvm::append_range(includeDirs, extraDirs);
391 sourceMgr.setIncludeDirs(includeDirs);
392 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
396 diagnostics.push_back(std::move(*lspDiag));
404 if (failed(astModule))
408 index.initialize(**astModule, odsContext);
417 std::vector<lsp::Location> &locations) {
419 const PDLIndexSymbol *symbol = index.lookup(posLoc);
428 std::vector<lsp::Location> &references) {
430 const PDLIndexSymbol *symbol = index.lookup(posLoc);
435 for (SMRange refLoc : symbol->references)
444 std::vector<lsp::DocumentLink> &links) {
446 links.emplace_back(include.range, include.uri);
453 std::optional<lsp::Hover>
456 SMLoc posLoc = hoverPos.
getAsSMLoc(sourceMgr);
460 if (include.range.contains(hoverPos))
461 return include.buildHover();
465 const PDLIndexSymbol *symbol = index.lookup(posLoc, &hoverRange);
471 llvm::dyn_cast_if_present<const ods::Operation *>(symbol->definition))
472 return buildHoverForOpName(op, hoverRange);
473 const auto *decl = cast<const ast::Decl *>(symbol->definition);
474 return findHover(decl, hoverRange);
477 std::optional<lsp::Hover> PDLDocument::findHover(
const ast::Decl *decl,
478 const SMRange &hoverRange) {
480 if (
const auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
481 return buildHoverForVariable(varDecl, hoverRange);
484 if (
const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl))
485 return buildHoverForPattern(patternDecl, hoverRange);
488 if (
const auto *cst = dyn_cast<ast::CoreConstraintDecl>(decl))
489 return buildHoverForCoreConstraint(cst, hoverRange);
492 if (
const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl))
493 return buildHoverForUserConstraintOrRewrite(
"Constraint", cst, hoverRange);
496 if (
const auto *
rewrite = dyn_cast<ast::UserRewriteDecl>(decl))
497 return buildHoverForUserConstraintOrRewrite(
"Rewrite",
rewrite, hoverRange);
503 const SMRange &hoverRange) {
506 llvm::raw_string_ostream hoverOS(hover.contents.value);
507 hoverOS <<
"**OpName**: `" << op->
getName() <<
"`\n***\n"
515 const SMRange &hoverRange) {
518 llvm::raw_string_ostream hoverOS(hover.contents.value);
519 hoverOS <<
"**Variable**: `" << varDecl->
getName().
getName() <<
"`\n***\n"
520 <<
"Type: `" << varDecl->
getType() <<
"`\n";
526 const SMRange &hoverRange) {
529 llvm::raw_string_ostream hoverOS(hover.contents.value);
530 hoverOS <<
"**Pattern**";
531 if (
const ast::Name *name = decl->getName())
532 hoverOS <<
": `" << name->
getName() <<
"`";
533 hoverOS <<
"\n***\n";
534 if (std::optional<uint16_t> benefit = decl->
getBenefit())
535 hoverOS <<
"Benefit: " << *benefit <<
"\n";
537 hoverOS <<
"HasBoundedRewriteRecursion\n";
538 hoverOS <<
"RootOp: `"
543 hoverOS <<
"\n" << *doc <<
"\n";
550 const SMRange &hoverRange) {
553 llvm::raw_string_ostream hoverOS(hover.contents.value);
554 hoverOS <<
"**Constraint**: `";
559 if (std::optional<StringRef> name = opCst->
getName())
560 hoverOS <<
"<" << *name <<
">";
564 hoverOS <<
"TypeRange";
568 hoverOS <<
"ValueRange";
575 template <
typename T>
576 lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
577 StringRef typeName,
const T *decl,
const SMRange &hoverRange) {
580 llvm::raw_string_ostream hoverOS(hover.contents.value);
581 hoverOS <<
"**" << typeName <<
"**: `" << decl->getName().getName()
584 if (!inputs.empty()) {
585 hoverOS <<
"Parameters:\n";
587 hoverOS <<
"* " << input->getName().getName() <<
": `"
588 << input->getType() <<
"`\n";
591 ast::Type resultType = decl->getResultType();
592 if (
auto resultTupleTy = dyn_cast<ast::TupleType>(resultType)) {
593 if (!resultTupleTy.empty()) {
594 hoverOS <<
"Results:\n";
595 for (
auto it : llvm::zip(resultTupleTy.getElementNames(),
596 resultTupleTy.getElementTypes())) {
597 StringRef name = std::get<0>(it);
598 hoverOS <<
"* " << (name.empty() ?
"" : (name +
": ")) <<
"`"
599 << std::get<1>(it) <<
"`\n";
604 hoverOS <<
"Results:\n* `" << resultType <<
"`\n";
610 hoverOS <<
"\n" << *doc <<
"\n";
619 void PDLDocument::findDocumentSymbols(
620 std::vector<lsp::DocumentSymbol> &symbols) {
621 if (failed(astModule))
624 for (
const ast::Decl *decl : (*astModule)->getChildren()) {
628 if (
const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl)) {
631 SMRange nameLoc = name ? name->
getLoc() : patternDecl->getLoc();
632 SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
634 symbols.emplace_back(
635 name ? name->
getName() :
"<pattern>", lsp::SymbolKind::Class,
637 }
else if (
const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(decl)) {
639 SMRange nameLoc = cDecl->getName().getLoc();
640 SMRange bodyLoc = nameLoc;
642 symbols.emplace_back(
643 cDecl->getName().getName(), lsp::SymbolKind::Function,
645 }
else if (
const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(decl)) {
647 SMRange nameLoc = cDecl->getName().getLoc();
648 SMRange bodyLoc = nameLoc;
650 symbols.emplace_back(
651 cDecl->getName().getName(), lsp::SymbolKind::Function,
664 LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
669 completionList(completionList), odsContext(odsContext),
670 includeDirs(includeDirs) {}
675 for (
unsigned i = 0, e = tupleType.size(); i < e; ++i) {
678 item.
label = llvm::formatv(
"{0} (field #{0})", i).str();
682 item.
detail = llvm::formatv(
"{0}: {1}", i, elementTypes[i]);
684 completionList.items.emplace_back(item);
687 if (!elementNames[i].empty()) {
689 llvm::formatv(
"{1} (field #{0})", i, elementNames[i]).str();
692 completionList.items.emplace_back(item);
709 item.
label = llvm::formatv(
"{0} (field #{0})", it.index()).str();
715 item.
detail = llvm::formatv(
"{0}: Value", it.index()).str();
718 item.
detail = llvm::formatv(
"{0}: Value?", it.index()).str();
721 item.
detail = llvm::formatv(
"{0}: ValueRange", it.index()).str();
726 llvm::formatv(
"{0}\n\n```c++\n{1}\n```\n", constraint.
getSummary(),
730 completionList.items.emplace_back(item);
734 if (!result.
getName().empty()) {
736 llvm::formatv(
"{1} (field #{0})", it.index(), result.
getName())
740 completionList.items.emplace_back(item);
754 item.
label = attr.getName().str();
756 item.
detail = attr.isOptional() ?
"optional" :
"";
759 llvm::formatv(
"{0}\n\n```c++\n{1}\n```\n", constraint.
getSummary(),
763 completionList.items.emplace_back(item);
768 bool allowInlineTypeConstraints,
770 auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
771 StringRef snippetText =
"") {
773 item.
label = constraint.str();
775 item.
detail = (constraint +
" constraint").str();
778 (
"A single entity core constraint of type `" + mlirType +
"`").str()};
784 completionList.items.emplace_back(item);
791 addCoreConstraint(
"Attr",
"mlir::Attribute");
792 addCoreConstraint(
"Op",
"mlir::Operation *");
793 addCoreConstraint(
"Value",
"mlir::Value");
794 addCoreConstraint(
"ValueRange",
"mlir::ValueRange");
795 addCoreConstraint(
"Type",
"mlir::Type");
796 addCoreConstraint(
"TypeRange",
"mlir::TypeRange");
798 if (allowInlineTypeConstraints) {
800 if (!currentType || isa<ast::AttributeType>(currentType))
801 addCoreConstraint(
"Attr<type>",
"mlir::Attribute",
"Attr<$1>");
803 if (!currentType || isa<ast::ValueType>(currentType))
804 addCoreConstraint(
"Value<type>",
"mlir::Value",
"Value<$1>");
806 if (!currentType || isa<ast::ValueRangeType>(currentType))
807 addCoreConstraint(
"ValueRange<type>",
"mlir::ValueRange",
813 for (
const ast::Decl *decl : scope->getDecls()) {
814 if (
const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) {
816 item.
label = cst->getName().getName().str();
822 if (cst->getInputs().size() != 1)
826 ast::Type constraintType = cst->getInputs()[0]->getType();
827 if (currentType && !currentType.refineWith(constraintType))
832 llvm::raw_string_ostream strOS(item.
detail);
834 llvm::interleaveComma(
836 strOS << var->getName().getName() <<
": " << var->getType();
838 strOS <<
") -> " << cst->getResultType();
842 if (std::optional<std::string> doc =
848 completionList.items.emplace_back(item);
852 scope = scope->getParentScope();
860 item.
label = dialect.getName().str();
863 completionList.items.emplace_back(item);
876 item.
label = op.
getName().drop_front(dialectName.size() + 1).str();
879 completionList.items.emplace_back(item);
884 auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
885 StringRef snippetText =
"") {
887 item.
label = constraint.str();
889 item.
detail =
"pattern metadata";
896 completionList.items.emplace_back(item);
899 addSimpleConstraint(
"benefit",
"The `benefit` of matching the pattern.",
901 addSimpleConstraint(
"recursion",
902 "The pattern properly handles recursive application.");
909 llvm::sys::path::native(nativeRelDir);
915 auto addIncludeCompletion = [&](StringRef path,
bool isDirectory) {
917 item.
label = path.str();
920 if (seenResults.insert(item.
label).second)
921 completionList.items.emplace_back(item);
926 for (StringRef includeDir : includeDirs) {
928 if (!nativeRelDir.empty())
929 llvm::sys::path::append(dir, nativeRelDir);
931 std::error_code errorCode;
932 for (
auto it = llvm::sys::fs::directory_iterator(dir, errorCode),
933 e = llvm::sys::fs::directory_iterator();
934 !errorCode && it != e; it.increment(errorCode)) {
935 StringRef filename = llvm::sys::path::filename(it->path());
940 llvm::sys::fs::file_type fileType = it->type();
941 if (fileType == llvm::sys::fs::file_type::symlink_file) {
942 if (
auto fileStatus = it->status())
943 fileType = fileStatus->type();
947 case llvm::sys::fs::file_type::directory_file:
948 addIncludeCompletion(filename,
true);
950 case llvm::sys::fs::file_type::regular_file: {
952 if (filename.ends_with(
".pdll") || filename.ends_with(
".td"))
953 addIncludeCompletion(filename,
false);
966 return lhs.label < rhs.label;
971 llvm::SourceMgr &sourceMgr;
981 SMLoc posLoc = completePos.
getAsSMLoc(sourceMgr);
982 if (!posLoc.isValid())
989 LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
991 sourceMgr.getIncludeDirs());
995 &lspCompleteContext);
997 return completionList;
1007 LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
1011 signatureHelp(signatureHelp), odsContext(odsContext) {}
1014 unsigned currentNumArgs)
final {
1015 signatureHelp.activeParameter = currentNumArgs;
1019 llvm::raw_string_ostream strOS(signatureInfo.
label);
1020 strOS << callable->getName()->getName() <<
"(";
1022 unsigned paramStart = strOS.str().size();
1023 strOS << var->getName().getName() <<
": " << var->getType();
1024 unsigned paramEnd = strOS.str().size();
1026 StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1027 std::make_pair(paramStart, paramEnd), std::string()});
1029 llvm::interleaveComma(callable->getInputs(), strOS, formatParamFn);
1030 strOS <<
") -> " << callable->getResultType();
1034 if (std::optional<std::string> doc =
1038 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1043 unsigned currentNumOperands)
final {
1046 codeCompleteOperationOperandOrResultSignature(
1049 currentNumOperands,
"operand",
"Value");
1053 unsigned currentNumResults)
final {
1056 codeCompleteOperationOperandOrResultSignature(
1059 currentNumResults,
"result",
"Type");
1062 void codeCompleteOperationOperandOrResultSignature(
1065 StringRef label, StringRef dataType) {
1066 signatureHelp.activeParameter = currentValue;
1072 if (odsOp && currentValue < values.size()) {
1077 llvm::raw_string_ostream strOS(signatureInfo.
label);
1080 unsigned paramStart = strOS.str().size();
1082 strOS << value.getName() <<
": ";
1084 StringRef constraintDoc = value.getConstraint().getSummary();
1085 std::string paramDoc;
1086 switch (value.getVariableLengthKind()) {
1089 paramDoc = constraintDoc.str();
1092 strOS << dataType <<
"?";
1093 paramDoc = (
"optional: " + constraintDoc).str();
1096 strOS << dataType <<
"Range";
1097 paramDoc = (
"variadic: " + constraintDoc).str();
1101 unsigned paramEnd = strOS.str().size();
1103 StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1104 std::make_pair(paramStart, paramEnd), paramDoc});
1106 llvm::interleaveComma(values, strOS, formatFn);
1110 llvm::formatv(
"`op<{0}>` ODS {1} specification", *opName, label)
1112 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1116 if (currentValue == 0 && (!odsOp || !values.empty())) {
1118 signatureInfo.
label =
1119 llvm::formatv(
"(<{0}s>: {1}Range)", label, dataType).str();
1121 (
"Generic operation " + label +
" specification").str();
1123 StringRef(signatureInfo.
label).drop_front().drop_back().str(),
1124 std::pair<unsigned, unsigned>(1, signatureInfo.
label.size() - 1),
1125 (
"All of the " + label +
"s of the operation.").str()});
1126 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1131 llvm::SourceMgr &sourceMgr;
1139 SMLoc posLoc = helpPos.
getAsSMLoc(sourceMgr);
1140 if (!posLoc.isValid())
1147 LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp,
1154 return signatureHelp;
1167 if (
auto *ref = dyn_cast<ast::DeclRefExpr>(expr)) {
1169 if (declName && declName->
getName() == name)
1178 std::vector<lsp::InlayHint> &inlayHints) {
1179 if (failed(astModule))
1182 if (!rangeLoc.isValid())
1184 (*astModule)->walk([&](
const ast::Node *node) {
1185 SMRange loc = node->
getLoc();
1195 [&](
const auto *node) {
1196 this->getInlayHintsFor(node, uri, inlayHints);
1203 std::vector<lsp::InlayHint> &inlayHints) {
1214 if (isa<ast::OperationExpr>(expr))
1221 llvm::raw_string_ostream labelOS(hint.label);
1222 labelOS <<
": " << decl->
getType();
1225 inlayHints.emplace_back(std::move(hint));
1228 void PDLDocument::getInlayHintsFor(
const ast::CallExpr *expr,
1230 std::vector<lsp::InlayHint> &inlayHints) {
1232 const auto *callableRef = dyn_cast<ast::DeclRefExpr>(expr->
getCallableExpr());
1233 const auto *callable =
1234 callableRef ? dyn_cast<ast::CallableDecl>(callableRef->getDecl())
1240 for (
const auto &it : llvm::zip(expr->
getArguments(), callable->getInputs()))
1241 addParameterHintFor(inlayHints, std::get<0>(it),
1242 std::get<1>(it)->getName().getName());
1247 std::vector<lsp::InlayHint> &inlayHints) {
1252 auto addOpHint = [&](
const ast::Expr *valueExpr, StringRef label) {
1255 if (expr->getLoc().Start == valueExpr->
getLoc().Start)
1257 addParameterHintFor(inlayHints, valueExpr, label);
1265 StringRef allValuesName) {
1271 if (values.size() != odsValues.size()) {
1273 if (values.size() == 1)
1274 return addOpHint(values.front(), allValuesName);
1278 for (
const auto &it : llvm::zip(values, odsValues))
1279 addOpHint(std::get<0>(it), std::get<1>(it).getName());
1293 void PDLDocument::addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
1294 const ast::Expr *expr, StringRef label) {
1300 hint.label = (label +
":").str();
1301 hint.paddingRight =
true;
1302 inlayHints.emplace_back(std::move(hint));
1309 void PDLDocument::getPDLLViewOutput(raw_ostream &os,
1311 if (failed(astModule))
1313 if (
kind == lsp::PDLLViewOutputKind::AST) {
1314 (*astModule)->print(os);
1327 if (
kind == lsp::PDLLViewOutputKind::MLIR) {
1333 assert(
kind == lsp::PDLLViewOutputKind::CPP &&
1334 "unexpected PDLLViewOutputKind");
1344 struct PDLTextFileChunk {
1347 const std::vector<std::string> &extraDirs,
1348 std::vector<lsp::Diagnostic> &diagnostics)
1349 : lineOffset(lineOffset),
1350 document(uri, contents, extraDirs, diagnostics) {}
1354 void adjustLocForChunkOffset(
lsp::Range &range) {
1355 adjustLocForChunkOffset(range.
start);
1356 adjustLocForChunkOffset(range.
end);
1363 uint64_t lineOffset;
1365 PDLDocument document;
1378 int64_t version,
const std::vector<std::string> &extraDirs,
1379 std::vector<lsp::Diagnostic> &diagnostics);
1382 int64_t getVersion()
const {
return version; }
1388 std::vector<lsp::Diagnostic> &diagnostics);
1395 std::vector<lsp::Location> &locations);
1397 std::vector<lsp::Location> &references);
1399 std::vector<lsp::DocumentLink> &links);
1402 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1408 std::vector<lsp::InlayHint> &inlayHints);
1412 using ChunkIterator = llvm::pointee_iterator<
1413 std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>;
1417 std::vector<lsp::Diagnostic> &diagnostics);
1424 return *getChunkItFor(pos);
1428 std::string contents;
1431 int64_t version = 0;
1434 int64_t totalNumLines = 0;
1438 std::vector<std::unique_ptr<PDLTextFileChunk>> chunks;
1441 std::vector<std::string> extraIncludeDirs;
1445 PDLTextFile::PDLTextFile(
const lsp::URIForFile &uri, StringRef fileContents,
1447 const std::vector<std::string> &extraDirs,
1448 std::vector<lsp::Diagnostic> &diagnostics)
1449 : contents(fileContents.str()), extraIncludeDirs(extraDirs) {
1450 initialize(uri, version, diagnostics);
1456 std::vector<lsp::Diagnostic> &diagnostics) {
1457 if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) {
1458 lsp::Logger::error(
"Failed to update contents of {0}", uri.
file());
1463 initialize(uri, newVersion, diagnostics);
1469 std::vector<lsp::Location> &locations) {
1470 PDLTextFileChunk &chunk = getChunkFor(defPos);
1471 chunk.document.getLocationsOf(uri, defPos, locations);
1474 if (chunk.lineOffset == 0)
1478 chunk.adjustLocForChunkOffset(loc.range);
1483 std::vector<lsp::Location> &references) {
1484 PDLTextFileChunk &chunk = getChunkFor(pos);
1485 chunk.document.findReferencesOf(uri, pos, references);
1488 if (chunk.lineOffset == 0)
1492 chunk.adjustLocForChunkOffset(loc.range);
1496 std::vector<lsp::DocumentLink> &links) {
1497 chunks.front()->document.getDocumentLinks(uri, links);
1498 for (
const auto &it : llvm::drop_begin(chunks)) {
1499 size_t currentNumLinks = links.size();
1500 it->document.getDocumentLinks(uri, links);
1504 for (
auto &link : llvm::drop_begin(links, currentNumLinks))
1505 it->adjustLocForChunkOffset(link.range);
1509 std::optional<lsp::Hover> PDLTextFile::findHover(
const lsp::URIForFile &uri,
1511 PDLTextFileChunk &chunk = getChunkFor(hoverPos);
1512 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1515 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1516 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1520 void PDLTextFile::findDocumentSymbols(
1521 std::vector<lsp::DocumentSymbol> &symbols) {
1522 if (chunks.size() == 1)
1523 return chunks.front()->document.findDocumentSymbols(symbols);
1527 for (
unsigned i = 0, e = chunks.size(); i < e; ++i) {
1528 PDLTextFileChunk &chunk = *chunks[i];
1531 : chunks[i + 1]->lineOffset);
1533 lsp::SymbolKind::Namespace,
1536 chunk.document.findDocumentSymbols(symbol.children);
1542 symbolsToFix.push_back(&childSymbol);
1544 while (!symbolsToFix.empty()) {
1546 chunk.adjustLocForChunkOffset(symbol->
range);
1550 symbolsToFix.push_back(&childSymbol);
1555 symbols.emplace_back(std::move(symbol));
1561 PDLTextFileChunk &chunk = getChunkFor(completePos);
1563 chunk.document.getCodeCompletion(uri, completePos);
1568 chunk.adjustLocForChunkOffset(item.
textEdit->range);
1570 chunk.adjustLocForChunkOffset(edit.
range);
1572 return completionList;
1577 return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos);
1581 std::vector<lsp::InlayHint> &inlayHints) {
1582 auto startIt = getChunkItFor(range.
start);
1583 auto endIt = getChunkItFor(range.
end);
1586 auto getHintsForChunk = [&](ChunkIterator chunkIt,
lsp::Range range) {
1587 size_t currentNumHints = inlayHints.size();
1588 chunkIt->document.getInlayHints(uri, range, inlayHints);
1592 if (&*chunkIt != &*chunks.front()) {
1593 for (
auto &hint : llvm::drop_begin(inlayHints, currentNumHints))
1594 chunkIt->adjustLocForChunkOffset(hint.position);
1598 auto getNumLines = [](ChunkIterator chunkIt) {
1599 return (chunkIt + 1)->lineOffset - chunkIt->lineOffset;
1603 if (startIt == endIt)
1604 return getHintsForChunk(startIt, range);
1608 getHintsForChunk(startIt,
lsp::Range(range.
start, getNumLines(startIt)));
1611 for (++startIt; startIt != endIt; ++startIt)
1612 getHintsForChunk(startIt,
lsp::Range(0, getNumLines(startIt)));
1623 llvm::raw_string_ostream outputOS(result.
output);
1625 llvm::make_pointee_range(chunks),
1626 [&](PDLTextFileChunk &chunk) {
1627 chunk.document.getPDLLViewOutput(outputOS,
kind);
1629 [&] { outputOS <<
"\n"
1635 void PDLTextFile::initialize(
const lsp::URIForFile &uri, int64_t newVersion,
1636 std::vector<lsp::Diagnostic> &diagnostics) {
1637 version = newVersion;
1643 chunks.emplace_back(std::make_unique<PDLTextFileChunk>(
1644 0, uri, subContents.front(), extraIncludeDirs,
1647 uint64_t lineOffset = subContents.front().count(
'\n');
1648 for (StringRef docContents : llvm::drop_begin(subContents)) {
1649 unsigned currentNumDiags = diagnostics.size();
1650 auto chunk = std::make_unique<PDLTextFileChunk>(
1651 lineOffset, uri, docContents, extraIncludeDirs, diagnostics);
1652 lineOffset += docContents.count(
'\n');
1657 llvm::drop_begin(diagnostics, currentNumDiags)) {
1658 chunk->adjustLocForChunkOffset(
diag.range);
1660 if (!
diag.relatedInformation)
1662 for (
auto &it : *
diag.relatedInformation)
1663 if (it.location.uri == uri)
1664 chunk->adjustLocForChunkOffset(it.location.range);
1666 chunks.emplace_back(std::move(chunk));
1668 totalNumLines = lineOffset;
1671 PDLTextFile::ChunkIterator PDLTextFile::getChunkItFor(
lsp::Position &pos) {
1672 if (chunks.size() == 1)
1673 return chunks.begin();
1677 auto it = llvm::upper_bound(
1678 chunks, pos, [](
const lsp::Position &pos,
const auto &chunk) {
1679 return static_cast<uint64_t
>(pos.
line) < chunk->lineOffset;
1681 ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it);
1682 pos.
line -= chunkIt->lineOffset;
1702 llvm::StringMap<std::unique_ptr<PDLTextFile>>
files;
1715 std::vector<Diagnostic> &diagnostics) {
1717 std::vector<std::string> additionalIncludeDirs =
impl->options.extraDirs;
1718 const auto &fileInfo =
impl->compilationDatabase.getFileInfo(uri.
file());
1719 llvm::append_range(additionalIncludeDirs, fileInfo.includeDirs);
1721 impl->files[uri.
file()] = std::make_unique<PDLTextFile>(
1722 uri, contents, version, additionalIncludeDirs, diagnostics);
1727 int64_t version, std::vector<Diagnostic> &diagnostics) {
1729 auto it =
impl->files.find(uri.
file());
1730 if (it ==
impl->files.end())
1735 if (failed(it->second->update(uri, version, changes, diagnostics)))
1736 impl->files.erase(it);
1740 auto it =
impl->files.find(uri.
file());
1741 if (it ==
impl->files.end())
1742 return std::nullopt;
1744 int64_t version = it->second->getVersion();
1745 impl->files.erase(it);
1751 std::vector<Location> &locations) {
1752 auto fileIt =
impl->files.find(uri.
file());
1753 if (fileIt !=
impl->files.end())
1754 fileIt->second->getLocationsOf(uri, defPos, locations);
1759 std::vector<Location> &references) {
1760 auto fileIt =
impl->files.find(uri.
file());
1761 if (fileIt !=
impl->files.end())
1762 fileIt->second->findReferencesOf(uri, pos, references);
1766 const URIForFile &uri, std::vector<DocumentLink> &documentLinks) {
1767 auto fileIt =
impl->files.find(uri.
file());
1768 if (fileIt !=
impl->files.end())
1769 return fileIt->second->getDocumentLinks(uri, documentLinks);
1774 auto fileIt =
impl->files.find(uri.
file());
1775 if (fileIt !=
impl->files.end())
1776 return fileIt->second->findHover(uri, hoverPos);
1777 return std::nullopt;
1781 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1782 auto fileIt =
impl->files.find(uri.
file());
1783 if (fileIt !=
impl->files.end())
1784 fileIt->second->findDocumentSymbols(symbols);
1790 auto fileIt =
impl->files.find(uri.
file());
1791 if (fileIt !=
impl->files.end())
1792 return fileIt->second->getCodeCompletion(uri, completePos);
1798 auto fileIt =
impl->files.find(uri.
file());
1799 if (fileIt !=
impl->files.end())
1800 return fileIt->second->getSignatureHelp(uri, helpPos);
1805 std::vector<InlayHint> &inlayHints) {
1806 auto fileIt =
impl->files.find(uri.
file());
1807 if (fileIt ==
impl->files.end())
1809 fileIt->second->getInlayHints(uri, range, inlayHints);
1812 llvm::sort(inlayHints);
1813 inlayHints.erase(llvm::unique(inlayHints), inlayHints.end());
1816 std::optional<lsp::PDLLViewOutputResult>
1819 auto fileIt =
impl->files.find(uri.
file());
1820 if (fileIt !=
impl->files.end())
1821 return fileIt->second->getPDLLViewOutput(
kind);
1822 return std::nullopt;
static std::string toString(bytecode::Section::ID sectionID)
Stringify the given section ID.
union mlir::linalg::@1223::ArityGroupAndKind::Kind kind
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
static std::string diag(const llvm::Value &value)
static std::optional< std::string > getDocumentationFor(llvm::SourceMgr &sourceMgr, const ast::Decl *decl)
Get or extract the documentation for the given decl.
static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range, const lsp::URIForFile &uri)
Returns a language server location from the given source range.
static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc, const lsp::URIForFile &mainFileURI)
Returns a language server uri for the given source location.
static bool shouldAddHintFor(const ast::Expr *expr, StringRef name)
Returns true if the given name should be added as a hint for expr.
static std::optional< lsp::Diagnostic > getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, const lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc)
Returns true if the given location is in the main file of the source manager.
static llvm::ManagedStatic< PassManagerOptions > options
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
MLIRContext is the top-level object for a collection of MLIR operations.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class is a utility diagnostic handler for use with llvm::SourceMgr.
This class contains a collection of compilation information for files provided to the language server...
static void error(const char *fmt, Ts &&...vals)
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
void getDocumentLinks(const URIForFile &uri, std::vector< DocumentLink > &documentLinks)
Return the document links referenced by the given file.
void getInlayHints(const URIForFile &uri, const Range &range, std::vector< InlayHint > &inlayHints)
Get the inlay hints for the range within the given file.
void addDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add the document, with the provided version, at the given URI.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
SignatureHelp getSignatureHelp(const URIForFile &uri, const Position &helpPos)
Get the signature help for the position within the given file.
void updateDocument(const URIForFile &uri, ArrayRef< TextDocumentContentChangeEvent > changes, int64_t version, std::vector< Diagnostic > &diagnostics)
Update the document, with the provided version, at the given URI.
std::optional< PDLLViewOutputResult > getPDLLViewOutput(const URIForFile &uri, PDLLViewOutputKind kind)
Get the output of the given PDLL file, or std::nullopt if there is no valid output.
URI in "file" scheme for a file.
static llvm::Expected< URIForFile > fromFile(StringRef absoluteFilepath, StringRef scheme="file")
Try to build a URIForFile from the given absolute file path and optional scheme.
StringRef file() const
Returns the absolute path to the file.
This class provides an abstract interface into the parser for hooking in code completion events.
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for an attribute name of the given operation.
virtual void codeCompleteOperationOperandsSignature(std::optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
virtual void codeCompleteOperationResultsSignature(std::optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Expr * getCallableExpr() const
Return the callable of this call.
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
This decl represents a shared interface for all callable decls.
This class represents the main context of the PDLL AST.
This class represents the base of all "core" constraints.
This class represents a scope for named AST decls.
This class represents the base Decl node.
std::optional< StringRef > getDocComment() const
Return the documentation comment attached to this decl if it has been set.
This class provides a simple implementation of a PDLL diagnostic.
This class represents a base AST Expression node.
Type getType() const
Return the type of this expression.
This class represents a top-level AST module.
This class represents a base AST node.
SMRange getLoc() const
Return the location of this node.
The class represents an Operation constraint, and constrains a variable to be an Operation.
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
This expression represents the structural form of an MLIR Operation.
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
This class represents a PDLL type that corresponds to an mlir::Operation.
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
This Decl represents a single Pattern.
const OpRewriteStmt * getRootRewriteStmt() const
Return the root rewrite statement of this pattern.
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
bool hasBoundedRewriteRecursion() const
Return if this pattern has bounded rewrite recursion.
This class represents a PDLL tuple type, i.e. an ordered set of element types with optional names.
The class represents a Type constraint, and constrains a variable to be a Type.
The class represents a TypeRange constraint, and constrains a variable to be a TypeRange.
The class represents a Value constraint, and constrains a variable to be a Value.
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
This Decl represents the definition of a PDLL variable.
const Name & getName() const
Return the name of the decl.
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Type getType() const
Return the type of the decl.
This class represents a generic ODS Attribute constraint.
StringRef getCppClass() const
Return the name of the underlying c++ class of this constraint.
This class provides an ODS representation of a specific operation attribute.
StringRef getSummary() const
Return the summary of this constraint.
This class contains all of the registered ODS operation classes.
auto getDialects() const
Return a range of all of the registered dialects.
const Dialect * lookupDialect(StringRef name) const
Lookup a dialect registered with the given name, or null if no dialect with that name was inserted.
const Operation * lookupOperation(StringRef name) const
Lookup an operation registered with the given name, or null if no operation with that name is registered.
This class represents an ODS dialect, and contains information on the constructs held within the dialect.
const llvm::StringMap< std::unique_ptr< Operation > > & getOperations() const
Return a map of all of the operations registered to this dialect.
This class provides an ODS representation of a specific operation operand or result.
const TypeConstraint & getConstraint() const
Return the constraint of this value.
VariableLengthKind getVariableLengthKind() const
Returns the variable length kind of this value.
StringRef getName() const
Return the name of this value.
This class provides an ODS representation of a specific operation.
StringRef getDescription() const
Returns the description of the operation.
StringRef getSummary() const
Returns the summary of the operation.
ArrayRef< Attribute > getAttributes() const
Returns the attributes of this operation.
StringRef getName() const
Returns the name of the operation.
SMRange getLoc() const
Return the source location of this operation.
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
This class represents a generic ODS Type constraint.
StringRef getCppClass() const
Return the name of the underlying c++ class of this constraint.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
PDLLViewOutputKind
The type of output to view from PDLL.
void gatherIncludeFiles(llvm::SourceMgr &sourceMgr, SmallVectorImpl< SourceMgrInclude > &includes)
Given a source manager, gather all of the processed include files.
std::optional< std::string > extractSourceDocComment(llvm::SourceMgr &sourceMgr, SMLoc loc)
Extract a documentation comment for the given location within the source manager.
@ PlainText
The primary text to be inserted is treated as a plain string.
@ Snippet
The primary text to be inserted is treated as a snippet.
void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, raw_ostream &os)
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Include the generated interface declarations.
const char *const kDefaultSplitMarker
Impl(const Options &options)
lsp::CompilationDatabase compilationDatabase
The compilation database containing additional information for files passed to the server.
llvm::StringMap< std::unique_ptr< PDLTextFile > > files
The files held by the server, mapped by their URI file name.
const Options & options
PDLL LSP options.
std::string detail
A human-readable string with additional information about this item, like type or symbol information.
std::string filterText
A string that should be used when filtering a set of completion items.
std::optional< TextEdit > textEdit
An edit which is applied to a document when selecting this completion.
std::optional< MarkupContent > documentation
A human-readable string that represents a doc-comment.
std::string insertText
A string that should be inserted to a document when selecting this completion.
std::string label
The label of this completion item.
CompletionItemKind kind
The kind of this completion item.
std::vector< TextEdit > additionalTextEdits
An optional array of additional text edits that are applied when selecting this completion.
InsertTextFormat insertTextFormat
The format of the insert text.
std::string sortText
A string that should be used when comparing this item with other items.
Represents a collection of completion items to be presented in the editor.
std::vector< CompletionItem > items
The completion items.
std::string source
A human-readable string describing the source of this diagnostic, e.g. 'typescript' or 'super lint'.
DiagnosticSeverity severity
The diagnostic's severity.
Range range
The source range where the message applies.
std::optional< std::vector< DiagnosticRelatedInformation > > relatedInformation
An array of related diagnostic information, e.g. when symbol names within a scope collide, all definitions can be marked via this property.
std::string message
The diagnostic's message.
std::optional< std::string > category
The diagnostic's category.
Represents programming constructs like variables, classes, interfaces etc.
Range range
The range enclosing this symbol not including leading/trailing whitespace but everything else like comments.
Range selectionRange
The range that should be selected and revealed when this symbol is being picked, e.g. the name of a function.
std::vector< DocumentSymbol > children
Children of this symbol, e.g. properties of a class.
URIForFile uri
The text document's URI.
Represents the result of viewing the output of a PDLL file.
std::string output
The string representation of the output.
int line
Line position in a document (zero-based).
SMLoc getAsSMLoc(llvm::SourceMgr &mgr) const
Convert this position into a source location in the main file of the given source manager.
Position end
The range's end position.
Position start
The range's start position.
SMRange getAsSMRange(llvm::SourceMgr &mgr) const
Convert this range into a source range in the main file of the given source manager.
Represents the signature of a callable.
This class represents a single include within a root file.
Range range
The range of the text document to be manipulated.
This class provides a convenient API for interacting with source names.
StringRef getName() const
Return the raw string name.
SMRange getLoc() const
Get the location of this name.