28 #include "llvm/ADT/IntervalMap.h"
29 #include "llvm/ADT/StringMap.h"
30 #include "llvm/ADT/StringSet.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/FileSystem.h"
33 #include "llvm/Support/Path.h"
43 int bufferId = mgr.FindBufferContainingLoc(loc.Start);
44 if (bufferId == 0 || bufferId ==
static_cast<int>(mgr.getMainFileID()))
47 mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
58 return mgr.FindBufferContainingLoc(loc.Start) == mgr.getMainFileID();
68 static std::optional<lsp::Diagnostic>
87 switch (
diag.getSeverity()) {
88 case ast::Diagnostic::Severity::DK_Note:
89 llvm_unreachable(
"expected notes to be handled separately");
90 case ast::Diagnostic::Severity::DK_Warning:
93 case ast::Diagnostic::Severity::DK_Error:
96 case ast::Diagnostic::Severity::DK_Remark:
103 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
105 relatedDiags.emplace_back(
107 note.getMessage().str());
109 if (!relatedDiags.empty())
116 static std::optional<std::string>
132 struct PDLIndexSymbol {
/// Creates an index symbol whose `definition` is the given `ast::Decl`.
133 explicit PDLIndexSymbol(
const ast::Decl *definition)
134 : definition(definition) {}
136 : definition(definition) {}
139 SMRange getDefLoc()
const {
140 if (
const ast::Decl *decl = llvm::dyn_cast_if_present<const ast::Decl *>(definition)) {
142 return declName ? declName->
getLoc() : decl->getLoc();
150 std::vector<SMRange> references;
/// Constructs the index, passing `allocator` to the constructor of
/// `intervalMap`.
157 PDLIndex() : intervalMap(allocator) {}
/// Looks up the symbol whose indexed source interval contains `loc`.
/// If `overlappedRange` is non-null, the definition writes the start and stop
/// of the matching interval into it.
/// NOTE(review): the result when no interval matches is not visible in this
/// excerpt -- presumably nullptr; confirm against the definition.
165 const PDLIndexSymbol *lookup(SMLoc loc,
166 SMRange *overlappedRange =
nullptr)
const;
172 llvm::IntervalMap<
const char *,
const PDLIndexSymbol *,
173 llvm::IntervalMapImpl::NodeSizer<
174 const char *,
const PDLIndexSymbol *>::LeafSize,
175 llvm::IntervalMapHalfOpenInfo<const char *>>;
178 MapT::Allocator allocator;
189 void PDLIndex::initialize(
const ast::Module &module,
191 auto getOrInsertDef = [&](
const auto *def) -> PDLIndexSymbol * {
192 auto it = defToSymbol.try_emplace(def,
nullptr);
194 it.first->second = std::make_unique<PDLIndexSymbol>(def);
195 return &*it.first->second;
197 auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc,
198 bool isDef =
false) {
199 const char *startLoc = refLoc.Start.getPointer();
200 const char *endLoc = refLoc.End.getPointer();
201 if (!intervalMap.overlaps(startLoc, endLoc)) {
202 intervalMap.insert(startLoc, endLoc, sym);
204 sym->references.push_back(refLoc);
207 auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) {
212 PDLIndexSymbol *symbol = getOrInsertDef(odsOp);
213 insertDeclRef(symbol, odsOp->
getLoc(),
true);
214 insertDeclRef(symbol, refLoc);
219 if (
const auto *decl = dyn_cast<ast::OpNameDecl>(node)) {
220 if (std::optional<StringRef> name = decl->getName())
221 insertODSOpRef(*name, decl->getLoc());
222 }
else if (
const ast::Decl *decl = dyn_cast<ast::Decl>(node)) {
226 PDLIndexSymbol *declSym = getOrInsertDef(decl);
227 insertDeclRef(declSym, name->
getLoc(),
true);
229 if (
const auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) {
231 for (
const auto &it : varDecl->getConstraints())
232 insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc);
234 }
else if (
const auto *expr = dyn_cast<ast::DeclRefExpr>(node)) {
235 insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc());
240 const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc,
241 SMRange *overlappedRange)
const {
242 auto it = intervalMap.find(loc.getPointer());
243 if (!it.valid() || loc.getPointer() < it.start())
246 if (overlappedRange) {
247 *overlappedRange = SMRange(SMLoc::getFromPointer(it.start()),
248 SMLoc::getFromPointer(it.stop()));
262 const std::vector<std::string> &extraDirs,
263 std::vector<lsp::Diagnostic> &diagnostics);
/// A PDLDocument cannot be copied: both copy construction and copy assignment
/// are deleted.
264 PDLDocument(
const PDLDocument &) =
delete;
265 PDLDocument &operator=(
const PDLDocument &) =
delete;
272 std::vector<lsp::Location> &locations);
274 std::vector<lsp::Location> &references);
281 std::vector<lsp::DocumentLink> &links);
289 std::optional<lsp::Hover> findHover(
const ast::Decl *decl,
290 const SMRange &hoverRange);
292 const SMRange &hoverRange);
294 const SMRange &hoverRange);
296 const SMRange &hoverRange);
298 const SMRange &hoverRange);
299 template <
typename T>
300 lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName,
302 const SMRange &hoverRange);
/// Appends entries to `symbols` for the pattern, user constraint, and user
/// rewrite declarations found among the children of the parsed module.
308 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
329 std::vector<lsp::InlayHint> &inlayHints);
332 std::vector<lsp::InlayHint> &inlayHints);
334 std::vector<lsp::InlayHint> &inlayHints);
337 std::vector<lsp::InlayHint> &inlayHints);
340 void addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
354 std::vector<std::string> includeDirs;
357 llvm::SourceMgr sourceMgr;
374 PDLDocument::PDLDocument(
const lsp::URIForFile &uri, StringRef contents,
375 const std::vector<std::string> &extraDirs,
376 std::vector<lsp::Diagnostic> &diagnostics)
377 : astContext(odsContext) {
378 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.
file());
380 lsp::Logger::error(
"Failed to create memory buffer for file", uri.
file());
386 llvm::sys::path::remove_filename(uriDirectory);
387 includeDirs.push_back(uriDirectory.str().str());
388 includeDirs.insert(includeDirs.end(), extraDirs.begin(), extraDirs.end());
390 sourceMgr.setIncludeDirs(includeDirs);
391 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
395 diagnostics.push_back(std::move(*lspDiag));
407 index.initialize(**astModule, odsContext);
416 std::vector<lsp::Location> &locations) {
418 const PDLIndexSymbol *symbol = index.lookup(posLoc);
427 std::vector<lsp::Location> &references) {
429 const PDLIndexSymbol *symbol = index.lookup(posLoc);
434 for (SMRange refLoc : symbol->references)
443 std::vector<lsp::DocumentLink> &links) {
445 links.emplace_back(include.range, include.uri);
452 std::optional<lsp::Hover>
455 SMLoc posLoc = hoverPos.
getAsSMLoc(sourceMgr);
459 if (include.range.contains(hoverPos))
460 return include.buildHover();
464 const PDLIndexSymbol *symbol = index.lookup(posLoc, &hoverRange);
469 if (
const auto *op = llvm::dyn_cast_if_present<const ods::Operation *>(symbol->definition))
470 return buildHoverForOpName(op, hoverRange);
471 const auto *decl = symbol->definition.get<
const ast::Decl *>();
472 return findHover(decl, hoverRange);
475 std::optional<lsp::Hover> PDLDocument::findHover(
const ast::Decl *decl,
476 const SMRange &hoverRange) {
478 if (
const auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
479 return buildHoverForVariable(varDecl, hoverRange);
482 if (
const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl))
483 return buildHoverForPattern(patternDecl, hoverRange);
486 if (
const auto *cst = dyn_cast<ast::CoreConstraintDecl>(decl))
487 return buildHoverForCoreConstraint(cst, hoverRange);
490 if (
const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl))
491 return buildHoverForUserConstraintOrRewrite(
"Constraint", cst, hoverRange);
494 if (
const auto *
rewrite = dyn_cast<ast::UserRewriteDecl>(decl))
495 return buildHoverForUserConstraintOrRewrite(
"Rewrite",
rewrite, hoverRange);
501 const SMRange &hoverRange) {
504 llvm::raw_string_ostream hoverOS(hover.contents.value);
505 hoverOS <<
"**OpName**: `" << op->
getName() <<
"`\n***\n"
506 << op->getSummary() <<
"\n***\n"
507 << op->getDescription();
513 const SMRange &hoverRange) {
516 llvm::raw_string_ostream hoverOS(hover.contents.value);
517 hoverOS <<
"**Variable**: `" << varDecl->
getName().
getName() <<
"`\n***\n"
518 <<
"Type: `" << varDecl->
getType() <<
"`\n";
524 const SMRange &hoverRange) {
527 llvm::raw_string_ostream hoverOS(hover.contents.value);
528 hoverOS <<
"**Pattern**";
529 if (
const ast::Name *name = decl->getName())
530 hoverOS <<
": `" << name->
getName() <<
"`";
531 hoverOS <<
"\n***\n";
532 if (std::optional<uint16_t> benefit = decl->
getBenefit())
533 hoverOS <<
"Benefit: " << *benefit <<
"\n";
535 hoverOS <<
"HasBoundedRewriteRecursion\n";
536 hoverOS <<
"RootOp: `"
541 hoverOS <<
"\n" << *doc <<
"\n";
548 const SMRange &hoverRange) {
551 llvm::raw_string_ostream hoverOS(hover.contents.value);
552 hoverOS <<
"**Constraint**: `";
557 if (std::optional<StringRef> name = opCst->
getName())
558 hoverOS <<
"<" << *name <<
">";
562 hoverOS <<
"TypeRange";
566 hoverOS <<
"ValueRange";
573 template <
typename T>
574 lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
575 StringRef typeName,
const T *decl,
const SMRange &hoverRange) {
578 llvm::raw_string_ostream hoverOS(hover.contents.value);
579 hoverOS <<
"**" << typeName <<
"**: `" << decl->getName().getName()
582 if (!inputs.empty()) {
583 hoverOS <<
"Parameters:\n";
585 hoverOS <<
"* " << input->getName().getName() <<
": `"
586 << input->getType() <<
"`\n";
589 ast::Type resultType = decl->getResultType();
591 if (!resultTupleTy.empty()) {
592 hoverOS <<
"Results:\n";
593 for (
auto it : llvm::zip(resultTupleTy.getElementNames(),
594 resultTupleTy.getElementTypes())) {
595 StringRef name = std::get<0>(it);
596 hoverOS <<
"* " << (name.empty() ?
"" : (name +
": ")) <<
"`"
597 << std::get<1>(it) <<
"`\n";
602 hoverOS <<
"Results:\n* `" << resultType <<
"`\n";
608 hoverOS <<
"\n" << *doc <<
"\n";
617 void PDLDocument::findDocumentSymbols(
618 std::vector<lsp::DocumentSymbol> &symbols) {
622 for (
const ast::Decl *decl : (*astModule)->getChildren()) {
626 if (
const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl)) {
629 SMRange nameLoc = name ? name->
getLoc() : patternDecl->getLoc();
630 SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
632 symbols.emplace_back(
633 name ? name->
getName() :
"<pattern>", lsp::SymbolKind::Class,
635 }
else if (
const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(decl)) {
637 SMRange nameLoc = cDecl->getName().getLoc();
638 SMRange bodyLoc = nameLoc;
640 symbols.emplace_back(
641 cDecl->getName().getName(), lsp::SymbolKind::Function,
643 }
else if (
const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(decl)) {
645 SMRange nameLoc = cDecl->getName().getLoc();
646 SMRange bodyLoc = nameLoc;
648 symbols.emplace_back(
649 cDecl->getName().getName(), lsp::SymbolKind::Function,
662 LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
667 completionList(completionList), odsContext(odsContext),
668 includeDirs(includeDirs) {}
673 for (
unsigned i = 0, e = tupleType.size(); i < e; ++i) {
676 item.
label = llvm::formatv(
"{0} (field #{0})", i).str();
680 item.
detail = llvm::formatv(
"{0}: {1}", i, elementTypes[i]);
682 completionList.items.emplace_back(item);
685 if (!elementNames[i].empty()) {
687 llvm::formatv(
"{1} (field #{0})", i, elementNames[i]).str();
690 completionList.items.emplace_back(item);
707 item.
label = llvm::formatv(
"{0} (field #{0})", it.index()).str();
713 item.
detail = llvm::formatv(
"{0}: Value", it.index()).str();
716 item.
detail = llvm::formatv(
"{0}: Value?", it.index()).str();
719 item.
detail = llvm::formatv(
"{0}: ValueRange", it.index()).str();
724 llvm::formatv(
"{0}\n\n```c++\n{1}\n```\n", constraint.
getSummary(),
728 completionList.items.emplace_back(item);
732 if (!result.
getName().empty()) {
734 llvm::formatv(
"{1} (field #{0})", it.index(), result.
getName())
738 completionList.items.emplace_back(item);
752 item.
label = attr.getName().str();
754 item.
detail = attr.isOptional() ?
"optional" :
"";
757 llvm::formatv(
"{0}\n\n```c++\n{1}\n```\n", constraint.
getSummary(),
761 completionList.items.emplace_back(item);
766 bool allowInlineTypeConstraints,
768 auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
769 StringRef snippetText =
"") {
771 item.
label = constraint.str();
773 item.
detail = (constraint +
" constraint").str();
776 (
"A single entity core constraint of type `" + mlirType +
"`").str()};
782 completionList.items.emplace_back(item);
789 addCoreConstraint(
"Attr",
"mlir::Attribute");
790 addCoreConstraint(
"Op",
"mlir::Operation *");
791 addCoreConstraint(
"Value",
"mlir::Value");
792 addCoreConstraint(
"ValueRange",
"mlir::ValueRange");
793 addCoreConstraint(
"Type",
"mlir::Type");
794 addCoreConstraint(
"TypeRange",
"mlir::TypeRange");
796 if (allowInlineTypeConstraints) {
799 addCoreConstraint(
"Attr<type>",
"mlir::Attribute",
"Attr<$1>");
802 addCoreConstraint(
"Value<type>",
"mlir::Value",
"Value<$1>");
805 addCoreConstraint(
"ValueRange<type>",
"mlir::ValueRange",
811 for (
const ast::Decl *decl : scope->getDecls()) {
812 if (
const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) {
814 item.
label = cst->getName().getName().str();
820 if (cst->getInputs().size() != 1)
824 ast::Type constraintType = cst->getInputs()[0]->getType();
825 if (currentType && !currentType.refineWith(constraintType))
830 llvm::raw_string_ostream strOS(item.
detail);
832 llvm::interleaveComma(
834 strOS << var->getName().getName() <<
": " << var->getType();
836 strOS <<
") -> " << cst->getResultType();
840 if (std::optional<std::string> doc =
846 completionList.items.emplace_back(item);
850 scope = scope->getParentScope();
858 item.
label = dialect.getName().str();
861 completionList.items.emplace_back(item);
874 item.
label = op.
getName().drop_front(dialectName.size() + 1).str();
877 completionList.items.emplace_back(item);
882 auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
883 StringRef snippetText =
"") {
885 item.
label = constraint.str();
887 item.
detail =
"pattern metadata";
894 completionList.items.emplace_back(item);
897 addSimpleConstraint(
"benefit",
"The `benefit` of matching the pattern.",
899 addSimpleConstraint(
"recursion",
900 "The pattern properly handles recursive application.");
907 llvm::sys::path::native(nativeRelDir);
913 auto addIncludeCompletion = [&](StringRef path,
bool isDirectory) {
915 item.
label = path.str();
918 if (seenResults.insert(item.
label).second)
919 completionList.items.emplace_back(item);
924 for (StringRef includeDir : includeDirs) {
926 if (!nativeRelDir.empty())
927 llvm::sys::path::append(dir, nativeRelDir);
929 std::error_code errorCode;
930 for (
auto it = llvm::sys::fs::directory_iterator(dir, errorCode),
931 e = llvm::sys::fs::directory_iterator();
932 !errorCode && it != e; it.increment(errorCode)) {
933 StringRef filename = llvm::sys::path::filename(it->path());
938 llvm::sys::fs::file_type fileType = it->type();
939 if (fileType == llvm::sys::fs::file_type::symlink_file) {
940 if (
auto fileStatus = it->status())
941 fileType = fileStatus->type();
945 case llvm::sys::fs::file_type::directory_file:
946 addIncludeCompletion(filename,
true);
948 case llvm::sys::fs::file_type::regular_file: {
950 if (filename.ends_with(
".pdll") || filename.ends_with(
".td"))
951 addIncludeCompletion(filename,
false);
964 return lhs.label < rhs.label;
969 llvm::SourceMgr &sourceMgr;
979 SMLoc posLoc = completePos.
getAsSMLoc(sourceMgr);
980 if (!posLoc.isValid())
987 LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
989 sourceMgr.getIncludeDirs());
993 &lspCompleteContext);
995 return completionList;
1005 LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
1009 signatureHelp(signatureHelp), odsContext(odsContext) {}
1012 unsigned currentNumArgs)
final {
1013 signatureHelp.activeParameter = currentNumArgs;
1017 llvm::raw_string_ostream strOS(signatureInfo.
label);
1018 strOS << callable->getName()->getName() <<
"(";
1020 unsigned paramStart = strOS.str().size();
1021 strOS << var->getName().getName() <<
": " << var->getType();
1022 unsigned paramEnd = strOS.str().size();
1024 StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1025 std::make_pair(paramStart, paramEnd), std::string()});
1027 llvm::interleaveComma(callable->getInputs(), strOS, formatParamFn);
1028 strOS <<
") -> " << callable->getResultType();
1032 if (std::optional<std::string> doc =
1036 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1041 unsigned currentNumOperands)
final {
1044 codeCompleteOperationOperandOrResultSignature(
1045 opName, odsOp, odsOp ? odsOp->
getOperands() : std::nullopt,
1046 currentNumOperands,
"operand",
"Value");
1050 unsigned currentNumResults)
final {
1053 codeCompleteOperationOperandOrResultSignature(
1054 opName, odsOp, odsOp ? odsOp->
getResults() : std::nullopt,
1055 currentNumResults,
"result",
"Type");
1058 void codeCompleteOperationOperandOrResultSignature(
1061 StringRef label, StringRef dataType) {
1062 signatureHelp.activeParameter = currentValue;
1068 if (odsOp && currentValue < values.size()) {
1073 llvm::raw_string_ostream strOS(signatureInfo.
label);
1076 unsigned paramStart = strOS.str().size();
1078 strOS << value.getName() <<
": ";
1080 StringRef constraintDoc = value.getConstraint().getSummary();
1081 std::string paramDoc;
1082 switch (value.getVariableLengthKind()) {
1085 paramDoc = constraintDoc.str();
1088 strOS << dataType <<
"?";
1089 paramDoc = (
"optional: " + constraintDoc).str();
1092 strOS << dataType <<
"Range";
1093 paramDoc = (
"variadic: " + constraintDoc).str();
1097 unsigned paramEnd = strOS.str().size();
1099 StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1100 std::make_pair(paramStart, paramEnd), paramDoc});
1102 llvm::interleaveComma(values, strOS, formatFn);
1106 llvm::formatv(
"`op<{0}>` ODS {1} specification", *opName, label)
1108 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1112 if (currentValue == 0 && (!odsOp || !values.empty())) {
1114 signatureInfo.
label =
1115 llvm::formatv(
"(<{0}s>: {1}Range)", label, dataType).str();
1117 (
"Generic operation " + label +
" specification").str();
1119 StringRef(signatureInfo.
label).drop_front().drop_back().str(),
1120 std::pair<unsigned, unsigned>(1, signatureInfo.
label.size() - 1),
1121 (
"All of the " + label +
"s of the operation.").str()});
1122 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1127 llvm::SourceMgr &sourceMgr;
1135 SMLoc posLoc = helpPos.
getAsSMLoc(sourceMgr);
1136 if (!posLoc.isValid())
1143 LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp,
1150 return signatureHelp;
1163 if (
auto *ref = dyn_cast<ast::DeclRefExpr>(expr)) {
1165 if (declName && declName->
getName() == name)
1174 std::vector<lsp::InlayHint> &inlayHints) {
1178 if (!rangeLoc.isValid())
1180 (*astModule)->walk([&](
const ast::Node *node) {
1181 SMRange loc = node->
getLoc();
1191 [&](
const auto *node) {
1192 this->getInlayHintsFor(node, uri, inlayHints);
1199 std::vector<lsp::InlayHint> &inlayHints) {
1210 if (isa<ast::OperationExpr>(expr))
1217 llvm::raw_string_ostream labelOS(hint.label);
1218 labelOS <<
": " << decl->
getType();
1221 inlayHints.emplace_back(std::move(hint));
1224 void PDLDocument::getInlayHintsFor(
const ast::CallExpr *expr,
1226 std::vector<lsp::InlayHint> &inlayHints) {
1228 const auto *callableRef = dyn_cast<ast::DeclRefExpr>(expr->
getCallableExpr());
1229 const auto *callable =
1230 callableRef ? dyn_cast<ast::CallableDecl>(callableRef->getDecl())
1236 for (
const auto &it : llvm::zip(expr->
getArguments(), callable->getInputs()))
1237 addParameterHintFor(inlayHints, std::get<0>(it),
1238 std::get<1>(it)->getName().getName());
1243 std::vector<lsp::InlayHint> &inlayHints) {
1248 auto addOpHint = [&](
const ast::Expr *valueExpr, StringRef label) {
1251 if (expr->getLoc().Start == valueExpr->
getLoc().Start)
1253 addParameterHintFor(inlayHints, valueExpr, label);
1261 StringRef allValuesName) {
1267 if (values.size() != odsValues.size()) {
1269 if (values.size() == 1)
1270 return addOpHint(values.front(), allValuesName);
1274 for (
const auto &it : llvm::zip(values, odsValues))
1275 addOpHint(std::get<0>(it), std::get<1>(it).getName());
1280 odsOp ? odsOp->getOperands()
1284 odsOp ? odsOp->getResults()
1289 void PDLDocument::addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
1290 const ast::Expr *expr, StringRef label) {
1296 hint.label = (label +
":").str();
1297 hint.paddingRight =
true;
1298 inlayHints.emplace_back(std::move(hint));
1305 void PDLDocument::getPDLLViewOutput(raw_ostream &os,
1309 if (kind == lsp::PDLLViewOutputKind::AST) {
1310 (*astModule)->print(os);
1323 if (kind == lsp::PDLLViewOutputKind::MLIR) {
1329 assert(kind == lsp::PDLLViewOutputKind::CPP &&
1330 "unexpected PDLLViewOutputKind");
1340 struct PDLTextFileChunk {
1343 const std::vector<std::string> &extraDirs,
1344 std::vector<lsp::Diagnostic> &diagnostics)
1345 : lineOffset(lineOffset),
1346 document(uri, contents, extraDirs, diagnostics) {}
1350 void adjustLocForChunkOffset(
lsp::Range &range) {
1351 adjustLocForChunkOffset(range.
start);
1352 adjustLocForChunkOffset(range.
end);
1359 uint64_t lineOffset;
1361 PDLDocument document;
1374 int64_t version,
const std::vector<std::string> &extraDirs,
1375 std::vector<lsp::Diagnostic> &diagnostics);
/// Returns the version number currently stored for this text file.
1378 int64_t getVersion()
const {
return version; }
1384 std::vector<lsp::Diagnostic> &diagnostics);
1391 std::vector<lsp::Location> &locations);
1393 std::vector<lsp::Location> &references);
1395 std::vector<lsp::DocumentLink> &links);
/// Collects the document symbols of this file into `symbols`. With a single
/// chunk the request is forwarded directly to that chunk's document; with
/// several chunks, each chunk's symbols are gathered as the children of a
/// per-chunk symbol and their ranges are passed through
/// `adjustLocForChunkOffset`.
1398 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1404 std::vector<lsp::InlayHint> &inlayHints);
/// Iterator over `chunks` that dereferences to the `PDLTextFileChunk` itself
/// rather than to the owning `std::unique_ptr`.
1408 using ChunkIterator = llvm::pointee_iterator<
1409 std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>;
1413 std::vector<lsp::Diagnostic> &diagnostics);
1420 return *getChunkItFor(pos);
1424 std::string contents;
1427 int64_t version = 0;
1430 int64_t totalNumLines = 0;
1434 std::vector<std::unique_ptr<PDLTextFileChunk>> chunks;
1437 std::vector<std::string> extraIncludeDirs;
1441 PDLTextFile::PDLTextFile(
const lsp::URIForFile &uri, StringRef fileContents,
1443 const std::vector<std::string> &extraDirs,
1444 std::vector<lsp::Diagnostic> &diagnostics)
1445 : contents(fileContents.str()), extraIncludeDirs(extraDirs) {
1446 initialize(uri, version, diagnostics);
1452 std::vector<lsp::Diagnostic> &diagnostics) {
1453 if (
failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) {
1454 lsp::Logger::error(
"Failed to update contents of {0}", uri.
file());
1459 initialize(uri, newVersion, diagnostics);
1465 std::vector<lsp::Location> &locations) {
1466 PDLTextFileChunk &chunk = getChunkFor(defPos);
1467 chunk.document.getLocationsOf(uri, defPos, locations);
1470 if (chunk.lineOffset == 0)
1474 chunk.adjustLocForChunkOffset(loc.range);
1479 std::vector<lsp::Location> &references) {
1480 PDLTextFileChunk &chunk = getChunkFor(pos);
1481 chunk.document.findReferencesOf(uri, pos, references);
1484 if (chunk.lineOffset == 0)
1488 chunk.adjustLocForChunkOffset(loc.range);
1492 std::vector<lsp::DocumentLink> &links) {
1493 chunks.front()->document.getDocumentLinks(uri, links);
1494 for (
const auto &it : llvm::drop_begin(chunks)) {
1495 size_t currentNumLinks = links.size();
1496 it->document.getDocumentLinks(uri, links);
1500 for (
auto &link : llvm::drop_begin(links, currentNumLinks))
1501 it->adjustLocForChunkOffset(link.range);
1505 std::optional<lsp::Hover> PDLTextFile::findHover(
const lsp::URIForFile &uri,
1507 PDLTextFileChunk &chunk = getChunkFor(hoverPos);
1508 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1511 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1512 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1516 void PDLTextFile::findDocumentSymbols(
1517 std::vector<lsp::DocumentSymbol> &symbols) {
1518 if (chunks.size() == 1)
1519 return chunks.front()->document.findDocumentSymbols(symbols);
1523 for (
unsigned i = 0, e = chunks.size(); i < e; ++i) {
1524 PDLTextFileChunk &chunk = *chunks[i];
1527 : chunks[i + 1]->lineOffset);
1529 lsp::SymbolKind::Namespace,
1532 chunk.document.findDocumentSymbols(symbol.children);
1538 symbolsToFix.push_back(&childSymbol);
1540 while (!symbolsToFix.empty()) {
1542 chunk.adjustLocForChunkOffset(symbol->
range);
1546 symbolsToFix.push_back(&childSymbol);
1551 symbols.emplace_back(std::move(symbol));
1557 PDLTextFileChunk &chunk = getChunkFor(completePos);
1559 chunk.document.getCodeCompletion(uri, completePos);
1564 chunk.adjustLocForChunkOffset(item.
textEdit->range);
1566 chunk.adjustLocForChunkOffset(edit.
range);
1568 return completionList;
1573 return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos);
1577 std::vector<lsp::InlayHint> &inlayHints) {
1578 auto startIt = getChunkItFor(range.
start);
1579 auto endIt = getChunkItFor(range.
end);
1582 auto getHintsForChunk = [&](ChunkIterator chunkIt,
lsp::Range range) {
1583 size_t currentNumHints = inlayHints.size();
1584 chunkIt->document.getInlayHints(uri, range, inlayHints);
1588 if (&*chunkIt != &*chunks.front()) {
1589 for (
auto &hint : llvm::drop_begin(inlayHints, currentNumHints))
1590 chunkIt->adjustLocForChunkOffset(hint.position);
1594 auto getNumLines = [](ChunkIterator chunkIt) {
1595 return (chunkIt + 1)->lineOffset - chunkIt->lineOffset;
1599 if (startIt == endIt)
1600 return getHintsForChunk(startIt, range);
1604 getHintsForChunk(startIt,
lsp::Range(range.
start, getNumLines(startIt)));
1607 for (++startIt; startIt != endIt; ++startIt)
1608 getHintsForChunk(startIt,
lsp::Range(0, getNumLines(startIt)));
1619 llvm::raw_string_ostream outputOS(result.
output);
1621 llvm::make_pointee_range(chunks),
1622 [&](PDLTextFileChunk &chunk) {
1623 chunk.document.getPDLLViewOutput(outputOS, kind);
1625 [&] { outputOS <<
"\n"
1631 void PDLTextFile::initialize(
const lsp::URIForFile &uri, int64_t newVersion,
1632 std::vector<lsp::Diagnostic> &diagnostics) {
1633 version = newVersion;
1639 chunks.emplace_back(std::make_unique<PDLTextFileChunk>(
1640 0, uri, subContents.front(), extraIncludeDirs,
1643 uint64_t lineOffset = subContents.front().count(
'\n');
1644 for (StringRef docContents : llvm::drop_begin(subContents)) {
1645 unsigned currentNumDiags = diagnostics.size();
1646 auto chunk = std::make_unique<PDLTextFileChunk>(
1647 lineOffset, uri, docContents, extraIncludeDirs, diagnostics);
1648 lineOffset += docContents.count(
'\n');
1653 llvm::drop_begin(diagnostics, currentNumDiags)) {
1654 chunk->adjustLocForChunkOffset(
diag.range);
1656 if (!
diag.relatedInformation)
1658 for (
auto &it : *
diag.relatedInformation)
1659 if (it.location.uri == uri)
1660 chunk->adjustLocForChunkOffset(it.location.range);
1662 chunks.emplace_back(std::move(chunk));
1664 totalNumLines = lineOffset;
1667 PDLTextFile::ChunkIterator PDLTextFile::getChunkItFor(
lsp::Position &pos) {
1668 if (chunks.size() == 1)
1669 return chunks.begin();
1673 auto it = llvm::upper_bound(
1674 chunks, pos, [](
const lsp::Position &pos,
const auto &chunk) {
1675 return static_cast<uint64_t
>(pos.
line) < chunk->lineOffset;
1677 ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it);
1678 pos.
line -= chunkIt->lineOffset;
1698 llvm::StringMap<std::unique_ptr<PDLTextFile>>
files;
1711 std::vector<Diagnostic> &diagnostics) {
1713 std::vector<std::string> additionalIncludeDirs =
impl->options.extraDirs;
1714 const auto &fileInfo =
impl->compilationDatabase.getFileInfo(uri.
file());
1715 llvm::append_range(additionalIncludeDirs, fileInfo.includeDirs);
1717 impl->files[uri.
file()] = std::make_unique<PDLTextFile>(
1718 uri, contents, version, additionalIncludeDirs, diagnostics);
1723 int64_t version, std::vector<Diagnostic> &diagnostics) {
1725 auto it =
impl->files.find(uri.
file());
1726 if (it ==
impl->files.end())
1731 if (
failed(it->second->update(uri, version, changes, diagnostics)))
1732 impl->files.erase(it);
1736 auto it =
impl->files.find(uri.
file());
1737 if (it ==
impl->files.end())
1738 return std::nullopt;
1740 int64_t version = it->second->getVersion();
1741 impl->files.erase(it);
1747 std::vector<Location> &locations) {
1748 auto fileIt =
impl->files.find(uri.
file());
1749 if (fileIt !=
impl->files.end())
1750 fileIt->second->getLocationsOf(uri, defPos, locations);
1755 std::vector<Location> &references) {
1756 auto fileIt =
impl->files.find(uri.
file());
1757 if (fileIt !=
impl->files.end())
1758 fileIt->second->findReferencesOf(uri, pos, references);
1762 const URIForFile &uri, std::vector<DocumentLink> &documentLinks) {
1763 auto fileIt =
impl->files.find(uri.
file());
1764 if (fileIt !=
impl->files.end())
1765 return fileIt->second->getDocumentLinks(uri, documentLinks);
1770 auto fileIt =
impl->files.find(uri.
file());
1771 if (fileIt !=
impl->files.end())
1772 return fileIt->second->findHover(uri, hoverPos);
1773 return std::nullopt;
1777 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1778 auto fileIt =
impl->files.find(uri.
file());
1779 if (fileIt !=
impl->files.end())
1780 fileIt->second->findDocumentSymbols(symbols);
1786 auto fileIt =
impl->files.find(uri.
file());
1787 if (fileIt !=
impl->files.end())
1788 return fileIt->second->getCodeCompletion(uri, completePos);
1794 auto fileIt =
impl->files.find(uri.
file());
1795 if (fileIt !=
impl->files.end())
1796 return fileIt->second->getSignatureHelp(uri, helpPos);
1801 std::vector<InlayHint> &inlayHints) {
1802 auto fileIt =
impl->files.find(uri.
file());
1803 if (fileIt ==
impl->files.end())
1805 fileIt->second->getInlayHints(uri, range, inlayHints);
1808 llvm::sort(inlayHints);
1809 inlayHints.erase(std::unique(inlayHints.begin(), inlayHints.end()),
1813 std::optional<lsp::PDLLViewOutputResult>
1816 auto fileIt =
impl->files.find(uri.
file());
1817 if (fileIt !=
impl->files.end())
1818 return fileIt->second->getPDLLViewOutput(kind);
1819 return std::nullopt;
static std::string toString(bytecode::Section::ID sectionID)
Stringify the given section ID.
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
static std::string diag(const llvm::Value &value)
static std::optional< std::string > getDocumentationFor(llvm::SourceMgr &sourceMgr, const ast::Decl *decl)
Get or extract the documentation for the given decl.
static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range, const lsp::URIForFile &uri)
Returns a language server location from the given source range.
static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc, const lsp::URIForFile &mainFileURI)
Returns a language server uri for the given source location.
static bool shouldAddHintFor(const ast::Expr *expr, StringRef name)
Returns true if the given name should be added as a hint for expr.
static std::optional< lsp::Diagnostic > getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, const lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc)
Returns true if the given location is in the main file of the source manager.
static llvm::ManagedStatic< PassManagerOptions > options
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This class provides support for representing a failure result, or a valid value of type T.
MLIRContext is the top-level object for a collection of MLIR operations.
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
OperationName getName()
The name of an operation is the key identifier for it.
This class is a utility diagnostic handler for use with llvm::SourceMgr.
This class contains a collection of compilation information for files provided to the language server, such as the include directories available to each file.
static void error(const char *fmt, Ts &&...vals)
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
void getDocumentLinks(const URIForFile &uri, std::vector< DocumentLink > &documentLinks)
Return the document links referenced by the given file.
void getInlayHints(const URIForFile &uri, const Range &range, std::vector< InlayHint > &inlayHints)
Get the inlay hints for the range within the given file.
void addDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add the document, with the provided version, at the given URI.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
SignatureHelp getSignatureHelp(const URIForFile &uri, const Position &helpPos)
Get the signature help for the position within the given file.
void updateDocument(const URIForFile &uri, ArrayRef< TextDocumentContentChangeEvent > changes, int64_t version, std::vector< Diagnostic > &diagnostics)
Update the document, with the provided version, at the given URI.
std::optional< PDLLViewOutputResult > getPDLLViewOutput(const URIForFile &uri, PDLLViewOutputKind kind)
Get the output of the given PDLL file, or std::nullopt if there is no valid output.
URI in "file" scheme for a file.
static llvm::Expected< URIForFile > fromFile(StringRef absoluteFilepath, StringRef scheme="file")
Try to build a URIForFile from the given absolute file path and optional scheme.
StringRef file() const
Returns the absolute path to the file.
This class provides an abstract interface into the parser for hooking in code completion events.
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for an attribute name of the given operation.
virtual void codeCompleteOperationOperandsSignature(std::optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
virtual void codeCompleteOperationResultsSignature(std::optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
This class represents a PDLL type that corresponds to an mlir::Attribute.
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Expr * getCallableExpr() const
Return the callable of this call.
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
This decl represents a shared interface for all callable decls.
This class represents the main context of the PDLL AST.
This class represents the base of all "core" constraints.
This class represents a scope for named AST decls.
This class represents the base Decl node.
std::optional< StringRef > getDocComment() const
Return the documentation comment attached to this decl if it has been set.
This class provides a simple implementation of a PDLL diagnostic.
This class represents a base AST Expression node.
Type getType() const
Return the type of this expression.
This class represents a top-level AST module.
This class represents a base AST node.
SMRange getLoc() const
Return the location of this node.
The class represents an Operation constraint, and constrains a variable to be an Operation.
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
This expression represents the structural form of an MLIR Operation.
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
This class represents a PDLL type that corresponds to an mlir::Operation.
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
This Decl represents a single Pattern.
const OpRewriteStmt * getRootRewriteStmt() const
Return the root rewrite statement of this pattern.
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
bool hasBoundedRewriteRecursion() const
Return if this pattern has bounded rewrite recursion.
This class represents a PDLL tuple type, i.e. an ordered set of element types with optional names.
The class represents a Type constraint, and constrains a variable to be a Type.
The class represents a TypeRange constraint, and constrains a variable to be a TypeRange.
The class represents a Value constraint, and constrains a variable to be a Value.
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
This class represents a PDLL type that corresponds to an mlir::ValueRange.
This class represents a PDLL type that corresponds to an mlir::Value.
This Decl represents the definition of a PDLL variable.
const Name & getName() const
Return the name of the decl.
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Type getType() const
Return the type of the decl.
This class represents a generic ODS Attribute constraint.
StringRef getCppClass() const
Return the name of the underlying c++ class of this constraint.
This class provides an ODS representation of a specific operation attribute.
StringRef getSummary() const
Return the summary of this constraint.
This class contains all of the registered ODS operation classes.
auto getDialects() const
Return a range of all of the registered dialects.
const Dialect * lookupDialect(StringRef name) const
Lookup a dialect registered with the given name, or null if no dialect with that name was inserted.
const Operation * lookupOperation(StringRef name) const
Lookup an operation registered with the given name, or null if no operation with that name is registered.
This class represents an ODS dialect, and contains information on the constructs held within the dialect.
const llvm::StringMap< std::unique_ptr< Operation > > & getOperations() const
Return a map of all of the operations registered to this dialect.
This class provides an ODS representation of a specific operation operand or result.
const TypeConstraint & getConstraint() const
Return the constraint of this value.
VariableLengthKind getVariableLengthKind() const
Returns the variable length kind of this value.
StringRef getName() const
Return the name of this value.
This class provides an ODS representation of a specific operation.
ArrayRef< Attribute > getAttributes() const
Returns the attributes of this operation.
SMRange getLoc() const
Return the source location of this operation.
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
This class represents a generic ODS Type constraint.
StringRef getCppClass() const
Return the name of the underlying c++ class of this constraint.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
PDLLViewOutputKind
The type of output to view from PDLL.
void gatherIncludeFiles(llvm::SourceMgr &sourceMgr, SmallVectorImpl< SourceMgrInclude > &includes)
Given a source manager, gather all of the processed include files.
std::optional< std::string > extractSourceDocComment(llvm::SourceMgr &sourceMgr, SMLoc loc)
Extract a documentation comment for the given location within the source manager.
@ PlainText
The primary text to be inserted is treated as a plain string.
@ Snippet
The primary text to be inserted is treated as a snippet.
void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, raw_ostream &os)
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
const char *const kDefaultSplitMarker
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Impl(const Options &options)
lsp::CompilationDatabase compilationDatabase
The compilation database containing additional information for files passed to the server.
llvm::StringMap< std::unique_ptr< PDLTextFile > > files
The files held by the server, mapped by their URI file name.
const Options & options
PDLL LSP options.
This class represents an efficient way to signal success or failure.
std::string detail
A human-readable string with additional information about this item, like type or symbol information.
std::string filterText
A string that should be used when filtering a set of completion items.
std::optional< TextEdit > textEdit
An edit which is applied to a document when selecting this completion.
std::optional< MarkupContent > documentation
A human-readable string that represents a doc-comment.
std::string insertText
A string that should be inserted to a document when selecting this completion.
std::string label
The label of this completion item.
CompletionItemKind kind
The kind of this completion item.
std::vector< TextEdit > additionalTextEdits
An optional array of additional text edits that are applied when selecting this completion.
InsertTextFormat insertTextFormat
The format of the insert text.
std::string sortText
A string that should be used when comparing this item with other items.
Represents a collection of completion items to be presented in the editor.
std::vector< CompletionItem > items
The completion items.
std::string source
A human-readable string describing the source of this diagnostic, e.g. 'typescript' or 'super lint'.
DiagnosticSeverity severity
The diagnostic's severity.
Range range
The source range where the message applies.
std::optional< std::vector< DiagnosticRelatedInformation > > relatedInformation
An array of related diagnostic information, e.g. when symbol names within a scope collide, all definitions can be marked via this property.
std::string message
The diagnostic's message.
std::optional< std::string > category
The diagnostic's category.
Represents programming constructs like variables, classes, interfaces etc.
Range range
The range enclosing this symbol not including leading/trailing whitespace but everything else like comments.
Range selectionRange
The range that should be selected and revealed when this symbol is being picked, e.g. the name of a function.
std::vector< DocumentSymbol > children
Children of this symbol, e.g. properties of a class.
URIForFile uri
The text document's URI.
Represents the result of viewing the output of a PDLL file.
std::string output
The string representation of the output.
int line
Line position in a document (zero-based).
SMLoc getAsSMLoc(llvm::SourceMgr &mgr) const
Convert this position into a source location in the main file of the given source manager.
Position end
The range's end position.
Position start
The range's start position.
SMRange getAsSMRange(llvm::SourceMgr &mgr) const
Convert this range into a source range in the main file of the given source manager.
Represents the signature of a callable.
This class represents a single include within a root file.
Range range
The range of the text document to be manipulated.
This class provides a convenient API for interacting with source names.
StringRef getName() const
Return the raw string name.
SMRange getLoc() const
Get the location of this name.