27#include "llvm/ADT/IntervalMap.h"
28#include "llvm/ADT/StringMap.h"
29#include "llvm/ADT/StringSet.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/FileSystem.h"
32#include "llvm/Support/LSP/Logging.h"
33#include "llvm/Support/Path.h"
34#include "llvm/Support/VirtualFileSystem.h"
42static llvm::lsp::URIForFile
44 const llvm::lsp::URIForFile &mainFileURI) {
45 int bufferId = mgr.FindBufferContainingLoc(loc.Start);
46 if (bufferId == 0 || bufferId ==
static_cast<int>(mgr.getMainFileID()))
49 llvm::lsp::URIForFile::fromFile(
50 mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
53 llvm::lsp::Logger::error(
"Failed to create URI for include file: {0}",
54 llvm::toString(fileForLoc.takeError()));
61 return mgr.FindBufferContainingLoc(loc.Start) == mgr.getMainFileID();
65static llvm::lsp::Location
67 const llvm::lsp::URIForFile &uri) {
69 llvm::lsp::Range(mgr, range));
73static std::optional<llvm::lsp::Diagnostic>
75 const llvm::lsp::URIForFile &uri) {
76 llvm::lsp::Diagnostic lspDiag;
77 lspDiag.source =
"pdll";
81 lspDiag.category =
"Parse Error";
84 llvm::lsp::Location loc =
86 lspDiag.range = loc.range;
93 switch (
diag.getSeverity()) {
94 case ast::Diagnostic::Severity::DK_Note:
95 llvm_unreachable(
"expected notes to be handled separately");
96 case ast::Diagnostic::Severity::DK_Warning:
97 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
99 case ast::Diagnostic::Severity::DK_Error:
100 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
102 case ast::Diagnostic::Severity::DK_Remark:
103 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
106 lspDiag.message =
diag.getMessage().str();
109 std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
111 relatedDiags.emplace_back(
113 note.getMessage().str());
115 if (!relatedDiags.empty())
116 lspDiag.relatedInformation = std::move(relatedDiags);
122static std::optional<std::string>
138struct PDLIndexSymbol {
139 explicit PDLIndexSymbol(
const ast::Decl *definition)
140 : definition(definition) {}
141 explicit PDLIndexSymbol(
const ods::Operation *definition)
142 : definition(definition) {}
145 SMRange getDefLoc()
const {
146 if (
const ast::Decl *decl =
147 llvm::dyn_cast_if_present<const ast::Decl *>(definition)) {
148 const ast::Name *declName = decl->getName();
149 return declName ? declName->
getLoc() : decl->getLoc();
151 return cast<const ods::Operation *>(definition)->getLoc();
155 PointerUnion<const ast::Decl *, const ods::Operation *> definition;
157 std::vector<SMRange> references;
164 PDLIndex() : intervalMap(allocator) {}
167 void initialize(
const ast::Module &module,
const ods::Context &odsContext);
172 const PDLIndexSymbol *lookup(SMLoc loc,
173 SMRange *overlappedRange =
nullptr)
const;
179 llvm::IntervalMap<
const char *,
const PDLIndexSymbol *,
180 llvm::IntervalMapImpl::NodeSizer<
181 const char *,
const PDLIndexSymbol *>::LeafSize,
182 llvm::IntervalMapHalfOpenInfo<const char *>>;
185 MapT::Allocator allocator;
196void PDLIndex::initialize(
const ast::Module &module,
198 auto getOrInsertDef = [&](
const auto *def) -> PDLIndexSymbol * {
199 auto it = defToSymbol.try_emplace(def,
nullptr);
201 it.first->second = std::make_unique<PDLIndexSymbol>(def);
202 return &*it.first->second;
204 auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc,
205 bool isDef =
false) {
206 const char *startLoc = refLoc.Start.getPointer();
207 const char *endLoc = refLoc.End.getPointer();
208 if (!intervalMap.overlaps(startLoc, endLoc)) {
209 intervalMap.insert(startLoc, endLoc, sym);
211 sym->references.push_back(refLoc);
214 auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) {
219 PDLIndexSymbol *symbol = getOrInsertDef(odsOp);
220 insertDeclRef(symbol, odsOp->
getLoc(),
true);
221 insertDeclRef(symbol, refLoc);
224 module.walk([&](const ast::Node *node) {
226 if (const auto *decl = dyn_cast<ast::OpNameDecl>(node)) {
227 if (std::optional<StringRef> name = decl->getName())
228 insertODSOpRef(*name, decl->getLoc());
229 }
else if (
const ast::Decl *decl = dyn_cast<ast::Decl>(node)) {
230 const ast::Name *name = decl->getName();
233 PDLIndexSymbol *declSym = getOrInsertDef(decl);
234 insertDeclRef(declSym, name->
getLoc(),
true);
236 if (
const auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) {
238 for (
const auto &it : varDecl->getConstraints())
239 insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc);
241 }
else if (
const auto *expr = dyn_cast<ast::DeclRefExpr>(node)) {
242 insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc());
247const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc,
248 SMRange *overlappedRange)
const {
249 auto it = intervalMap.find(loc.getPointer());
250 if (!it.valid() || loc.getPointer() < it.start())
253 if (overlappedRange) {
254 *overlappedRange = SMRange(SMLoc::getFromPointer(it.start()),
255 SMLoc::getFromPointer(it.stop()));
268 PDLDocument(
const llvm::lsp::URIForFile &uri, StringRef contents,
269 const std::vector<std::string> &extraDirs,
270 std::vector<llvm::lsp::Diagnostic> &diagnostics);
271 PDLDocument(
const PDLDocument &) =
delete;
272 PDLDocument &operator=(
const PDLDocument &) =
delete;
278 void getLocationsOf(
const llvm::lsp::URIForFile &uri,
279 const llvm::lsp::Position &defPos,
280 std::vector<llvm::lsp::Location> &locations);
281 void findReferencesOf(
const llvm::lsp::URIForFile &uri,
282 const llvm::lsp::Position &pos,
283 std::vector<llvm::lsp::Location> &references);
289 void getDocumentLinks(
const llvm::lsp::URIForFile &uri,
290 std::vector<llvm::lsp::DocumentLink> &links);
296 std::optional<llvm::lsp::Hover>
297 findHover(
const llvm::lsp::URIForFile &uri,
298 const llvm::lsp::Position &hoverPos);
299 std::optional<llvm::lsp::Hover> findHover(
const ast::Decl *decl,
300 const SMRange &hoverRange);
301 llvm::lsp::Hover buildHoverForOpName(
const ods::Operation *op,
302 const SMRange &hoverRange);
303 llvm::lsp::Hover buildHoverForVariable(
const ast::VariableDecl *varDecl,
304 const SMRange &hoverRange);
305 llvm::lsp::Hover buildHoverForPattern(
const ast::PatternDecl *decl,
306 const SMRange &hoverRange);
308 buildHoverForCoreConstraint(
const ast::CoreConstraintDecl *decl,
309 const SMRange &hoverRange);
310 template <
typename T>
312 buildHoverForUserConstraintOrRewrite(StringRef typeName,
const T *decl,
313 const SMRange &hoverRange);
319 void findDocumentSymbols(std::vector<llvm::lsp::DocumentSymbol> &symbols);
325 llvm::lsp::CompletionList
326 getCodeCompletion(
const llvm::lsp::URIForFile &uri,
327 const llvm::lsp::Position &completePos);
333 llvm::lsp::SignatureHelp getSignatureHelp(
const llvm::lsp::URIForFile &uri,
334 const llvm::lsp::Position &helpPos);
340 void getInlayHints(
const llvm::lsp::URIForFile &uri,
341 const llvm::lsp::Range &range,
342 std::vector<llvm::lsp::InlayHint> &inlayHints);
343 void getInlayHintsFor(
const ast::VariableDecl *decl,
344 const llvm::lsp::URIForFile &uri,
345 std::vector<llvm::lsp::InlayHint> &inlayHints);
346 void getInlayHintsFor(
const ast::CallExpr *expr,
347 const llvm::lsp::URIForFile &uri,
348 std::vector<llvm::lsp::InlayHint> &inlayHints);
349 void getInlayHintsFor(
const ast::OperationExpr *expr,
350 const llvm::lsp::URIForFile &uri,
351 std::vector<llvm::lsp::InlayHint> &inlayHints);
354 void addParameterHintFor(std::vector<llvm::lsp::InlayHint> &inlayHints,
355 const ast::Expr *expr, StringRef label);
368 std::vector<std::string> includeDirs;
371 llvm::SourceMgr sourceMgr;
374 ods::Context odsContext;
375 ast::Context astContext;
378 FailureOr<ast::Module *> astModule;
384 SmallVector<lsp::SourceMgrInclude> parsedIncludes;
388PDLDocument::PDLDocument(
const llvm::lsp::URIForFile &uri, StringRef contents,
389 const std::vector<std::string> &extraDirs,
390 std::vector<llvm::lsp::Diagnostic> &diagnostics)
391 : astContext(odsContext) {
392 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
394 llvm::lsp::Logger::error(
"Failed to create memory buffer for file",
401 llvm::sys::path::remove_filename(uriDirectory);
402 includeDirs.push_back(uriDirectory.str().str());
403 llvm::append_range(includeDirs, extraDirs);
405 sourceMgr.setIncludeDirs(includeDirs);
406 sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
407 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
411 diagnostics.push_back(std::move(*lspDiag));
423 index.initialize(**astModule, odsContext);
430void PDLDocument::getLocationsOf(
const llvm::lsp::URIForFile &uri,
431 const llvm::lsp::Position &defPos,
432 std::vector<llvm::lsp::Location> &locations) {
433 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
434 const PDLIndexSymbol *symbol = index.lookup(posLoc);
441void PDLDocument::findReferencesOf(
442 const llvm::lsp::URIForFile &uri,
const llvm::lsp::Position &pos,
443 std::vector<llvm::lsp::Location> &references) {
444 SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
445 const PDLIndexSymbol *symbol = index.lookup(posLoc);
450 for (SMRange refLoc : symbol->references)
458void PDLDocument::getDocumentLinks(
459 const llvm::lsp::URIForFile &uri,
460 std::vector<llvm::lsp::DocumentLink> &links) {
461 for (
const lsp::SourceMgrInclude &include : parsedIncludes)
462 links.emplace_back(include.range, include.uri);
469std::optional<llvm::lsp::Hover>
470PDLDocument::findHover(
const llvm::lsp::URIForFile &uri,
471 const llvm::lsp::Position &hoverPos) {
472 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
475 for (
const lsp::SourceMgrInclude &include : parsedIncludes)
476 if (include.range.contains(hoverPos))
477 return include.buildHover();
481 const PDLIndexSymbol *symbol = index.lookup(posLoc, &hoverRange);
487 llvm::dyn_cast_if_present<const ods::Operation *>(symbol->definition))
488 return buildHoverForOpName(op, hoverRange);
489 const auto *decl = cast<const ast::Decl *>(symbol->definition);
490 return findHover(decl, hoverRange);
493std::optional<llvm::lsp::Hover>
494PDLDocument::findHover(
const ast::Decl *decl,
const SMRange &hoverRange) {
496 if (
const auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
497 return buildHoverForVariable(varDecl, hoverRange);
500 if (
const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl))
501 return buildHoverForPattern(patternDecl, hoverRange);
504 if (
const auto *cst = dyn_cast<ast::CoreConstraintDecl>(decl))
505 return buildHoverForCoreConstraint(cst, hoverRange);
508 if (
const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl))
509 return buildHoverForUserConstraintOrRewrite(
"Constraint", cst, hoverRange);
512 if (
const auto *
rewrite = dyn_cast<ast::UserRewriteDecl>(decl))
513 return buildHoverForUserConstraintOrRewrite(
"Rewrite",
rewrite, hoverRange);
518llvm::lsp::Hover PDLDocument::buildHoverForOpName(
const ods::Operation *op,
519 const SMRange &hoverRange) {
520 llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
522 llvm::raw_string_ostream hoverOS(hover.contents.value);
523 hoverOS <<
"**OpName**: `" << op->
getName() <<
"`\n***\n"
531PDLDocument::buildHoverForVariable(
const ast::VariableDecl *varDecl,
532 const SMRange &hoverRange) {
533 llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
535 llvm::raw_string_ostream hoverOS(hover.contents.value);
536 hoverOS <<
"**Variable**: `" << varDecl->
getName().
getName() <<
"`\n***\n"
537 <<
"Type: `" << varDecl->
getType() <<
"`\n";
542llvm::lsp::Hover PDLDocument::buildHoverForPattern(
const ast::PatternDecl *decl,
543 const SMRange &hoverRange) {
544 llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
546 llvm::raw_string_ostream hoverOS(hover.contents.value);
547 hoverOS <<
"**Pattern**";
548 if (
const ast::Name *name = decl->
getName())
549 hoverOS <<
": `" << name->
getName() <<
"`";
550 hoverOS <<
"\n***\n";
551 if (std::optional<uint16_t> benefit = decl->
getBenefit())
552 hoverOS <<
"Benefit: " << *benefit <<
"\n";
554 hoverOS <<
"HasBoundedRewriteRecursion\n";
555 hoverOS <<
"RootOp: `"
560 hoverOS <<
"\n" << *doc <<
"\n";
566PDLDocument::buildHoverForCoreConstraint(
const ast::CoreConstraintDecl *decl,
567 const SMRange &hoverRange) {
568 llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
570 llvm::raw_string_ostream hoverOS(hover.contents.value);
571 hoverOS <<
"**Constraint**: `";
573 .Case([&](
const ast::AttrConstraintDecl *) { hoverOS <<
"Attr"; })
574 .Case([&](
const ast::OpConstraintDecl *opCst) {
576 if (std::optional<StringRef> name = opCst->
getName())
577 hoverOS <<
"<" << *name <<
">";
579 .Case([&](
const ast::TypeConstraintDecl *) { hoverOS <<
"Type"; })
580 .Case([&](
const ast::TypeRangeConstraintDecl *) {
581 hoverOS <<
"TypeRange";
583 .Case([&](
const ast::ValueConstraintDecl *) { hoverOS <<
"Value"; })
584 .Case([&](
const ast::ValueRangeConstraintDecl *) {
585 hoverOS <<
"ValueRange";
593llvm::lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
594 StringRef typeName,
const T *decl,
const SMRange &hoverRange) {
595 llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
597 llvm::raw_string_ostream hoverOS(hover.contents.value);
598 hoverOS <<
"**" << typeName <<
"**: `" << decl->getName().getName()
600 ArrayRef<ast::VariableDecl *> inputs = decl->getInputs();
601 if (!inputs.empty()) {
602 hoverOS <<
"Parameters:\n";
603 for (
const ast::VariableDecl *input : inputs)
604 hoverOS <<
"* " << input->getName().getName() <<
": `"
605 << input->getType() <<
"`\n";
608 ast::Type resultType = decl->getResultType();
609 if (
auto resultTupleTy = dyn_cast<ast::TupleType>(resultType)) {
610 if (!resultTupleTy.empty()) {
611 hoverOS <<
"Results:\n";
612 for (
auto it : llvm::zip(resultTupleTy.getElementNames(),
613 resultTupleTy.getElementTypes())) {
614 StringRef name = std::get<0>(it);
615 hoverOS <<
"* " << (name.empty() ?
"" : (name +
": ")) <<
"`"
616 << std::get<1>(it) <<
"`\n";
621 hoverOS <<
"Results:\n* `" << resultType <<
"`\n";
627 hoverOS <<
"\n" << *doc <<
"\n";
636void PDLDocument::findDocumentSymbols(
637 std::vector<llvm::lsp::DocumentSymbol> &symbols) {
641 for (
const ast::Decl *decl : (*astModule)->getChildren()) {
645 if (
const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl)) {
646 const ast::Name *name = patternDecl->
getName();
648 SMRange nameLoc = name ? name->
getLoc() : patternDecl->getLoc();
649 SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
651 symbols.emplace_back(name ? name->
getName() :
"<pattern>",
652 llvm::lsp::SymbolKind::Class,
653 llvm::lsp::Range(sourceMgr, bodyLoc),
654 llvm::lsp::Range(sourceMgr, nameLoc));
655 }
else if (
const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(decl)) {
657 SMRange nameLoc = cDecl->getName().getLoc();
658 SMRange bodyLoc = nameLoc;
660 symbols.emplace_back(cDecl->getName().getName(),
661 llvm::lsp::SymbolKind::Function,
662 llvm::lsp::Range(sourceMgr, bodyLoc),
663 llvm::lsp::Range(sourceMgr, nameLoc));
664 }
else if (
const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(decl)) {
666 SMRange nameLoc = cDecl->getName().getLoc();
667 SMRange bodyLoc = nameLoc;
669 symbols.emplace_back(cDecl->getName().getName(),
670 llvm::lsp::SymbolKind::Function,
671 llvm::lsp::Range(sourceMgr, bodyLoc),
672 llvm::lsp::Range(sourceMgr, nameLoc));
682class LSPCodeCompleteContext :
public CodeCompleteContext {
684 LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
685 llvm::lsp::CompletionList &completionList,
686 ods::Context &odsContext,
687 ArrayRef<std::string> includeDirs)
688 : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
689 completionList(completionList), odsContext(odsContext),
690 includeDirs(includeDirs) {}
692 void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
final {
693 ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes();
694 ArrayRef<StringRef> elementNames = tupleType.getElementNames();
695 for (
unsigned i = 0, e = tupleType.size(); i < e; ++i) {
697 llvm::lsp::CompletionItem item;
698 item.label = llvm::formatv(
"{0} (field #{0})", i).str();
699 item.insertText = Twine(i).str();
700 item.filterText = item.sortText = item.insertText;
701 item.kind = llvm::lsp::CompletionItemKind::Field;
702 item.detail = llvm::formatv(
"{0}: {1}", i, elementTypes[i]);
703 item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
704 completionList.items.emplace_back(item);
707 if (!elementNames[i].empty()) {
709 llvm::formatv(
"{1} (field #{0})", i, elementNames[i]).str();
710 item.filterText = item.label;
711 item.insertText = elementNames[i].str();
712 completionList.items.emplace_back(item);
717 void codeCompleteOperationMemberAccess(ast::OperationType opType)
final {
718 const ods::Operation *odsOp = opType.getODSOperation();
722 ArrayRef<ods::OperandOrResult> results = odsOp->
getResults();
723 for (
const auto &it : llvm::enumerate(results)) {
724 const ods::OperandOrResult &
result = it.value();
725 const ods::TypeConstraint &constraint =
result.getConstraint();
728 llvm::lsp::CompletionItem item;
729 item.label = llvm::formatv(
"{0} (field #{0})", it.index()).str();
730 item.insertText = Twine(it.index()).str();
731 item.filterText = item.sortText = item.insertText;
732 item.kind = llvm::lsp::CompletionItemKind::Field;
733 switch (
result.getVariableLengthKind()) {
734 case ods::VariableLengthKind::Single:
735 item.detail = llvm::formatv(
"{0}: Value", it.index()).str();
737 case ods::VariableLengthKind::Optional:
738 item.detail = llvm::formatv(
"{0}: Value?", it.index()).str();
740 case ods::VariableLengthKind::Variadic:
741 item.detail = llvm::formatv(
"{0}: ValueRange", it.index()).str();
744 item.documentation = llvm::lsp::MarkupContent{
745 llvm::lsp::MarkupKind::Markdown,
746 llvm::formatv(
"{0}\n\n```c++\n{1}\n```\n", constraint.
getSummary(),
749 item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
750 completionList.items.emplace_back(item);
754 if (!
result.getName().empty()) {
756 llvm::formatv(
"{1} (field #{0})", it.index(),
result.getName())
758 item.filterText = item.label;
759 item.insertText =
result.getName().str();
760 completionList.items.emplace_back(item);
765 void codeCompleteOperationAttributeName(StringRef opName)
final {
771 const ods::AttributeConstraint &constraint = attr.getConstraint();
773 llvm::lsp::CompletionItem item;
774 item.label = attr.
getName().str();
775 item.kind = llvm::lsp::CompletionItemKind::Field;
776 item.detail = attr.isOptional() ?
"optional" :
"";
777 item.documentation = llvm::lsp::MarkupContent{
778 llvm::lsp::MarkupKind::Markdown,
779 llvm::formatv(
"{0}\n\n```c++\n{1}\n```\n", constraint.
getSummary(),
782 item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
783 completionList.items.emplace_back(item);
787 void codeCompleteConstraintName(ast::Type currentType,
788 bool allowInlineTypeConstraints,
789 const ast::DeclScope *scope)
final {
790 auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
791 StringRef snippetText =
"") {
792 llvm::lsp::CompletionItem item;
793 item.label = constraint.str();
794 item.kind = llvm::lsp::CompletionItemKind::Class;
795 item.detail = (constraint +
" constraint").str();
796 item.documentation = llvm::lsp::MarkupContent{
797 llvm::lsp::MarkupKind::Markdown,
798 (
"A single entity core constraint of type `" + mlirType +
"`").str()};
800 item.insertText = snippetText.str();
801 item.insertTextFormat = snippetText.empty()
802 ? llvm::lsp::InsertTextFormat::PlainText
803 : llvm::lsp::InsertTextFormat::Snippet;
804 completionList.items.emplace_back(item);
811 addCoreConstraint(
"Attr",
"mlir::Attribute");
812 addCoreConstraint(
"Op",
"mlir::Operation *");
813 addCoreConstraint(
"Value",
"mlir::Value");
814 addCoreConstraint(
"ValueRange",
"mlir::ValueRange");
815 addCoreConstraint(
"Type",
"mlir::Type");
816 addCoreConstraint(
"TypeRange",
"mlir::TypeRange");
818 if (allowInlineTypeConstraints) {
820 if (!currentType || isa<ast::AttributeType>(currentType))
821 addCoreConstraint(
"Attr<type>",
"mlir::Attribute",
"Attr<$1>");
823 if (!currentType || isa<ast::ValueType>(currentType))
824 addCoreConstraint(
"Value<type>",
"mlir::Value",
"Value<$1>");
826 if (!currentType || isa<ast::ValueRangeType>(currentType))
827 addCoreConstraint(
"ValueRange<type>",
"mlir::ValueRange",
833 for (
const ast::Decl *decl : scope->getDecls()) {
834 if (
const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) {
835 llvm::lsp::CompletionItem item;
836 item.label = cst->getName().getName().str();
837 item.kind = llvm::lsp::CompletionItemKind::Interface;
838 item.sortText =
"2_" + item.label;
842 if (cst->getInputs().size() != 1)
846 ast::Type constraintType = cst->getInputs()[0]->getType();
847 if (currentType && !currentType.refineWith(constraintType))
852 llvm::raw_string_ostream strOS(item.detail);
854 llvm::interleaveComma(
855 cst->getInputs(), strOS, [&](
const ast::VariableDecl *var) {
856 strOS << var->getName().getName() <<
": " << var->getType();
858 strOS <<
") -> " << cst->getResultType();
862 if (std::optional<std::string> doc =
864 item.documentation = llvm::lsp::MarkupContent{
865 llvm::lsp::MarkupKind::Markdown, std::move(*doc)};
868 completionList.items.emplace_back(item);
872 scope = scope->getParentScope();
876 void codeCompleteDialectName() final {
878 for (
const ods::Dialect &dialect : odsContext.
getDialects()) {
879 llvm::lsp::CompletionItem item;
880 item.label = dialect.getName().str();
881 item.kind = llvm::lsp::CompletionItemKind::Class;
882 item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
883 completionList.items.emplace_back(item);
887 void codeCompleteOperationName(StringRef dialectName)
final {
888 const ods::Dialect *dialect = odsContext.
lookupDialect(dialectName);
893 const ods::Operation &op = *it.second;
895 llvm::lsp::CompletionItem item;
896 item.label = op.
getName().drop_front(dialectName.size() + 1).str();
897 item.kind = llvm::lsp::CompletionItemKind::Field;
898 item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
899 completionList.items.emplace_back(item);
903 void codeCompletePatternMetadata() final {
904 auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
905 StringRef snippetText =
"") {
906 llvm::lsp::CompletionItem item;
907 item.label = constraint.str();
908 item.kind = llvm::lsp::CompletionItemKind::Class;
909 item.detail =
"pattern metadata";
911 llvm::lsp::MarkupContent{llvm::lsp::MarkupKind::Markdown, desc.str()};
912 item.insertText = snippetText.str();
913 item.insertTextFormat = snippetText.empty()
914 ? llvm::lsp::InsertTextFormat::PlainText
915 : llvm::lsp::InsertTextFormat::Snippet;
916 completionList.items.emplace_back(item);
919 addSimpleConstraint(
"benefit",
"The `benefit` of matching the pattern.",
921 addSimpleConstraint(
"recursion",
922 "The pattern properly handles recursive application.");
925 void codeCompleteIncludeFilename(StringRef curPath)
final {
928 SmallString<128> nativeRelDir(llvm::sys::path::convert_to_slash(curPath));
929 llvm::sys::path::native(nativeRelDir);
935 auto addIncludeCompletion = [&](StringRef path,
bool isDirectory) {
936 llvm::lsp::CompletionItem item;
937 item.label = path.str();
938 item.kind = isDirectory ? llvm::lsp::CompletionItemKind::Folder
939 : llvm::lsp::CompletionItemKind::File;
940 if (seenResults.insert(item.label).second)
941 completionList.items.emplace_back(item);
946 for (StringRef includeDir : includeDirs) {
947 llvm::SmallString<128> dir = includeDir;
948 if (!nativeRelDir.empty())
949 llvm::sys::path::append(dir, nativeRelDir);
951 std::error_code errorCode;
952 for (
auto it = llvm::sys::fs::directory_iterator(dir, errorCode),
953 e = llvm::sys::fs::directory_iterator();
954 !errorCode && it != e; it.increment(errorCode)) {
955 StringRef filename = llvm::sys::path::filename(it->path());
960 llvm::sys::fs::file_type fileType = it->type();
961 if (fileType == llvm::sys::fs::file_type::symlink_file) {
962 if (
auto fileStatus = it->status())
963 fileType = fileStatus->type();
967 case llvm::sys::fs::file_type::directory_file:
968 addIncludeCompletion(filename,
true);
970 case llvm::sys::fs::file_type::regular_file: {
972 if (filename.ends_with(
".pdll") || filename.ends_with(
".td"))
973 addIncludeCompletion(filename,
false);
984 llvm::sort(completionList.items, [](
const llvm::lsp::CompletionItem &
lhs,
985 const llvm::lsp::CompletionItem &
rhs) {
986 return lhs.label < rhs.label;
991 llvm::SourceMgr &sourceMgr;
992 llvm::lsp::CompletionList &completionList;
993 ods::Context &odsContext;
994 ArrayRef<std::string> includeDirs;
998llvm::lsp::CompletionList
999PDLDocument::getCodeCompletion(
const llvm::lsp::URIForFile &uri,
1000 const llvm::lsp::Position &completePos) {
1001 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
1002 if (!posLoc.isValid())
1003 return llvm::lsp::CompletionList();
1007 ods::Context tmpODSContext;
1008 llvm::lsp::CompletionList completionList;
1009 LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
1011 sourceMgr.getIncludeDirs());
1013 ast::Context tmpContext(tmpODSContext);
1015 &lspCompleteContext);
1017 return completionList;
1025class LSPSignatureHelpContext :
public CodeCompleteContext {
1027 LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
1028 llvm::lsp::SignatureHelp &signatureHelp,
1029 ods::Context &odsContext)
1030 : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
1031 signatureHelp(signatureHelp), odsContext(odsContext) {}
1033 void codeCompleteCallSignature(
const ast::CallableDecl *callable,
1034 unsigned currentNumArgs)
final {
1035 signatureHelp.activeParameter = currentNumArgs;
1037 llvm::lsp::SignatureInformation signatureInfo;
1039 llvm::raw_string_ostream strOS(signatureInfo.label);
1040 strOS << callable->getName()->getName() <<
"(";
1041 auto formatParamFn = [&](
const ast::VariableDecl *var) {
1042 unsigned paramStart = strOS.str().size();
1043 strOS << var->getName().getName() <<
": " << var->getType();
1044 unsigned paramEnd = strOS.str().size();
1045 signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{
1046 StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1047 std::make_pair(paramStart, paramEnd), std::string()});
1049 llvm::interleaveComma(callable->getInputs(), strOS, formatParamFn);
1050 strOS <<
") -> " << callable->getResultType();
1054 if (std::optional<std::string> doc =
1056 signatureInfo.documentation = std::move(*doc);
1058 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1062 codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
1063 unsigned currentNumOperands)
final {
1064 const ods::Operation *odsOp =
1066 codeCompleteOperationOperandOrResultSignature(
1068 odsOp ? odsOp->
getOperands() : ArrayRef<ods::OperandOrResult>(),
1069 currentNumOperands,
"operand",
"Value");
1072 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
1073 unsigned currentNumResults)
final {
1074 const ods::Operation *odsOp =
1076 codeCompleteOperationOperandOrResultSignature(
1078 odsOp ? odsOp->
getResults() : ArrayRef<ods::OperandOrResult>(),
1079 currentNumResults,
"result",
"Type");
1082 void codeCompleteOperationOperandOrResultSignature(
1083 std::optional<StringRef> opName,
const ods::Operation *odsOp,
1084 ArrayRef<ods::OperandOrResult> values,
unsigned currentValue,
1085 StringRef label, StringRef dataType) {
1086 signatureHelp.activeParameter = currentValue;
1092 if (odsOp && currentValue < values.size()) {
1093 llvm::lsp::SignatureInformation signatureInfo;
1097 llvm::raw_string_ostream strOS(signatureInfo.label);
1099 auto formatFn = [&](
const ods::OperandOrResult &value) {
1100 unsigned paramStart = strOS.str().size();
1102 strOS << value.getName() <<
": ";
1104 StringRef constraintDoc = value.getConstraint().getSummary();
1105 std::string paramDoc;
1106 switch (value.getVariableLengthKind()) {
1107 case ods::VariableLengthKind::Single:
1109 paramDoc = constraintDoc.str();
1111 case ods::VariableLengthKind::Optional:
1112 strOS << dataType <<
"?";
1113 paramDoc = (
"optional: " + constraintDoc).str();
1115 case ods::VariableLengthKind::Variadic:
1116 strOS << dataType <<
"Range";
1117 paramDoc = (
"variadic: " + constraintDoc).str();
1121 unsigned paramEnd = strOS.str().size();
1122 signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{
1123 StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1124 std::make_pair(paramStart, paramEnd), paramDoc});
1126 llvm::interleaveComma(values, strOS, formatFn);
1129 signatureInfo.documentation =
1130 llvm::formatv(
"`op<{0}>` ODS {1} specification", *opName, label)
1132 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1136 if (currentValue == 0 && (!odsOp || !values.empty())) {
1137 llvm::lsp::SignatureInformation signatureInfo;
1138 signatureInfo.label =
1139 llvm::formatv(
"(<{0}s>: {1}Range)", label, dataType).str();
1140 signatureInfo.documentation =
1141 (
"Generic operation " + label +
" specification").str();
1142 signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{
1143 StringRef(signatureInfo.label).drop_front().drop_back().str(),
1144 std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1),
1145 (
"All of the " + label +
"s of the operation.").str()});
1146 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1151 llvm::SourceMgr &sourceMgr;
1152 llvm::lsp::SignatureHelp &signatureHelp;
1153 ods::Context &odsContext;
1157llvm::lsp::SignatureHelp
1158PDLDocument::getSignatureHelp(
const llvm::lsp::URIForFile &uri,
1159 const llvm::lsp::Position &helpPos) {
1160 SMLoc posLoc = helpPos.getAsSMLoc(sourceMgr);
1161 if (!posLoc.isValid())
1162 return llvm::lsp::SignatureHelp();
1166 ods::Context tmpODSContext;
1167 llvm::lsp::SignatureHelp signatureHelp;
1168 LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp,
1171 ast::Context tmpContext(tmpODSContext);
1175 return signatureHelp;
1188 if (
auto *ref = dyn_cast<ast::DeclRefExpr>(expr)) {
1190 if (declName && declName->
getName() == name)
1197void PDLDocument::getInlayHints(
const llvm::lsp::URIForFile &uri,
1198 const llvm::lsp::Range &range,
1199 std::vector<llvm::lsp::InlayHint> &inlayHints) {
1202 SMRange rangeLoc = range.getAsSMRange(sourceMgr);
1203 if (!rangeLoc.isValid())
1205 (*astModule)->walk([&](
const ast::Node *node) {
1206 SMRange loc = node->
getLoc();
1214 llvm::TypeSwitch<const ast::Node *>(node)
1215 .Case<ast::VariableDecl, ast::CallExpr, ast::OperationExpr>(
1216 [&](
const auto *node) {
1217 this->getInlayHintsFor(node, uri, inlayHints);
1222void PDLDocument::getInlayHintsFor(
1223 const ast::VariableDecl *decl,
const llvm::lsp::URIForFile &uri,
1224 std::vector<llvm::lsp::InlayHint> &inlayHints) {
1231 if (
const ast::Expr *expr = decl->
getInitExpr()) {
1235 if (isa<ast::OperationExpr>(expr))
1239 llvm::lsp::InlayHint hint(llvm::lsp::InlayHintKind::Type,
1240 llvm::lsp::Position(sourceMgr, decl->
getLoc().End));
1242 llvm::raw_string_ostream labelOS(hint.label);
1243 labelOS <<
": " << decl->
getType();
1246 inlayHints.emplace_back(std::move(hint));
1249void PDLDocument::getInlayHintsFor(
1250 const ast::CallExpr *expr,
const llvm::lsp::URIForFile &uri,
1251 std::vector<llvm::lsp::InlayHint> &inlayHints) {
1253 const auto *callableRef = dyn_cast<ast::DeclRefExpr>(expr->
getCallableExpr());
1254 const auto *callable =
1255 callableRef ? dyn_cast<ast::CallableDecl>(callableRef->getDecl())
1261 for (
const auto &it : llvm::zip(expr->
getArguments(), callable->getInputs()))
1262 addParameterHintFor(inlayHints, std::get<0>(it),
1263 std::get<1>(it)->getName().getName());
1266void PDLDocument::getInlayHintsFor(
1267 const ast::OperationExpr *expr,
const llvm::lsp::URIForFile &uri,
1268 std::vector<llvm::lsp::InlayHint> &inlayHints) {
1270 ast::OperationType opType = dyn_cast<ast::OperationType>(expr->
getType());
1273 auto addOpHint = [&](
const ast::Expr *valueExpr, StringRef label) {
1278 addParameterHintFor(inlayHints, valueExpr, label);
1284 auto addOperandOrResultHints = [&](ArrayRef<ast::Expr *> values,
1285 ArrayRef<ods::OperandOrResult> odsValues,
1286 StringRef allValuesName) {
1292 if (values.size() != odsValues.size()) {
1294 if (values.size() == 1)
1295 return addOpHint(values.front(), allValuesName);
1299 for (
const auto &it : llvm::zip(values, odsValues))
1300 addOpHint(std::get<0>(it), std::get<1>(it).getName());
1306 : ArrayRef<ods::OperandOrResult>(),
1310 : ArrayRef<ods::OperandOrResult>(),
1314void PDLDocument::addParameterHintFor(
1315 std::vector<llvm::lsp::InlayHint> &inlayHints,
const ast::Expr *expr,
1320 llvm::lsp::InlayHint hint(
1321 llvm::lsp::InlayHintKind::Parameter,
1322 llvm::lsp::Position(sourceMgr, expr->
getLoc().Start));
1323 hint.label = (label +
":").str();
1324 hint.paddingRight =
true;
1325 inlayHints.emplace_back(std::move(hint));
1332void PDLDocument::getPDLLViewOutput(raw_ostream &os,
1336 if (kind == lsp::PDLLViewOutputKind::AST) {
1337 (*astModule)->print(os);
1344 MLIRContext mlirContext;
1345 SourceMgrDiagnosticHandler diagHandler(sourceMgr, &mlirContext, os);
1346 OwningOpRef<ModuleOp> pdlModule =
1350 if (kind == lsp::PDLLViewOutputKind::MLIR) {
1351 pdlModule->print(os, OpPrintingFlags().enableDebugInfo());
1356 assert(kind == lsp::PDLLViewOutputKind::CPP &&
1357 "unexpected PDLLViewOutputKind");
1367struct PDLTextFileChunk {
1368 PDLTextFileChunk(uint64_t lineOffset,
const llvm::lsp::URIForFile &uri,
1370 const std::vector<std::string> &extraDirs,
1371 std::vector<llvm::lsp::Diagnostic> &diagnostics)
1372 : lineOffset(lineOffset),
1373 document(uri, contents, extraDirs, diagnostics) {}
1377 void adjustLocForChunkOffset(llvm::lsp::Range &range) {
1378 adjustLocForChunkOffset(range.start);
1379 adjustLocForChunkOffset(range.end);
1383 void adjustLocForChunkOffset(llvm::lsp::Position &pos) {
1384 pos.line += lineOffset;
1388 uint64_t lineOffset;
1390 PDLDocument document;
1402 PDLTextFile(
const llvm::lsp::URIForFile &uri, StringRef fileContents,
1403 int64_t version,
const std::vector<std::string> &extraDirs,
1404 std::vector<llvm::lsp::Diagnostic> &diagnostics);
1407 int64_t getVersion()
const {
return version; }
1412 update(
const llvm::lsp::URIForFile &uri, int64_t newVersion,
1413 ArrayRef<llvm::lsp::TextDocumentContentChangeEvent> changes,
1414 std::vector<llvm::lsp::Diagnostic> &diagnostics);
1420 void getLocationsOf(
const llvm::lsp::URIForFile &uri,
1421 llvm::lsp::Position defPos,
1422 std::vector<llvm::lsp::Location> &locations);
1423 void findReferencesOf(
const llvm::lsp::URIForFile &uri,
1424 llvm::lsp::Position pos,
1425 std::vector<llvm::lsp::Location> &references);
1426 void getDocumentLinks(
const llvm::lsp::URIForFile &uri,
1427 std::vector<llvm::lsp::DocumentLink> &links);
1428 std::optional<llvm::lsp::Hover> findHover(
const llvm::lsp::URIForFile &uri,
1429 llvm::lsp::Position hoverPos);
1430 void findDocumentSymbols(std::vector<llvm::lsp::DocumentSymbol> &symbols);
1431 llvm::lsp::CompletionList getCodeCompletion(
const llvm::lsp::URIForFile &uri,
1432 llvm::lsp::Position completePos);
1433 llvm::lsp::SignatureHelp getSignatureHelp(
const llvm::lsp::URIForFile &uri,
1434 llvm::lsp::Position helpPos);
1435 void getInlayHints(
const llvm::lsp::URIForFile &uri, llvm::lsp::Range range,
1436 std::vector<llvm::lsp::InlayHint> &inlayHints);
1440 using ChunkIterator = llvm::pointee_iterator<
1441 std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>;
1444 void initialize(
const llvm::lsp::URIForFile &uri, int64_t newVersion,
1445 std::vector<llvm::lsp::Diagnostic> &diagnostics);
1450 ChunkIterator getChunkItFor(llvm::lsp::Position &pos);
1451 PDLTextFileChunk &getChunkFor(llvm::lsp::Position &pos) {
1452 return *getChunkItFor(pos);
1456 std::string contents;
1459 int64_t version = 0;
1462 int64_t totalNumLines = 0;
1466 std::vector<std::unique_ptr<PDLTextFileChunk>> chunks;
1469 std::vector<std::string> extraIncludeDirs;
1473PDLTextFile::PDLTextFile(
const llvm::lsp::URIForFile &uri,
1474 StringRef fileContents, int64_t version,
1475 const std::vector<std::string> &extraDirs,
1476 std::vector<llvm::lsp::Diagnostic> &diagnostics)
1477 : contents(fileContents.str()), extraIncludeDirs(extraDirs) {
1482PDLTextFile::update(
const llvm::lsp::URIForFile &uri,
int64_t newVersion,
1484 std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1485 if (
failed(llvm::lsp::TextDocumentContentChangeEvent::applyTo(changes,
1487 llvm::lsp::Logger::error(
"Failed to update contents of {0}", uri.file());
1496void PDLTextFile::getLocationsOf(
const llvm::lsp::URIForFile &uri,
1497 llvm::lsp::Position defPos,
1498 std::vector<llvm::lsp::Location> &locations) {
1499 PDLTextFileChunk &chunk = getChunkFor(defPos);
1500 chunk.document.getLocationsOf(uri, defPos, locations);
1503 if (chunk.lineOffset == 0)
1505 for (llvm::lsp::Location &loc : locations)
1507 chunk.adjustLocForChunkOffset(loc.range);
1510void PDLTextFile::findReferencesOf(
1511 const llvm::lsp::URIForFile &uri, llvm::lsp::Position pos,
1512 std::vector<llvm::lsp::Location> &references) {
1513 PDLTextFileChunk &chunk = getChunkFor(pos);
1514 chunk.document.findReferencesOf(uri, pos, references);
1517 if (chunk.lineOffset == 0)
1519 for (llvm::lsp::Location &loc : references)
1521 chunk.adjustLocForChunkOffset(loc.range);
1524void PDLTextFile::getDocumentLinks(
1525 const llvm::lsp::URIForFile &uri,
1526 std::vector<llvm::lsp::DocumentLink> &links) {
1527 chunks.front()->document.getDocumentLinks(uri, links);
1528 for (
const auto &it : llvm::drop_begin(chunks)) {
1529 size_t currentNumLinks = links.size();
1530 it->document.getDocumentLinks(uri, links);
1534 for (
auto &link : llvm::drop_begin(links, currentNumLinks))
1535 it->adjustLocForChunkOffset(link.range);
1539std::optional<llvm::lsp::Hover>
1540PDLTextFile::findHover(
const llvm::lsp::URIForFile &uri,
1541 llvm::lsp::Position hoverPos) {
1542 PDLTextFileChunk &chunk = getChunkFor(hoverPos);
1543 std::optional<llvm::lsp::Hover> hoverInfo =
1544 chunk.document.findHover(uri, hoverPos);
1547 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1548 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1552void PDLTextFile::findDocumentSymbols(
1553 std::vector<llvm::lsp::DocumentSymbol> &symbols) {
1554 if (chunks.size() == 1)
1555 return chunks.front()->document.findDocumentSymbols(symbols);
1559 for (
unsigned i = 0, e = chunks.size(); i < e; ++i) {
1560 PDLTextFileChunk &chunk = *chunks[i];
1561 llvm::lsp::Position startPos(chunk.lineOffset);
1562 llvm::lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1563 : chunks[i + 1]->lineOffset);
1564 llvm::lsp::DocumentSymbol symbol(
1565 "<file-split-" + Twine(i) +
">", llvm::lsp::SymbolKind::Namespace,
1566 llvm::lsp::Range(startPos, endPos),
1567 llvm::lsp::Range(startPos));
1568 chunk.document.findDocumentSymbols(symbol.children);
1572 SmallVector<llvm::lsp::DocumentSymbol *> symbolsToFix;
1573 for (llvm::lsp::DocumentSymbol &childSymbol : symbol.children)
1574 symbolsToFix.push_back(&childSymbol);
1576 while (!symbolsToFix.empty()) {
1577 llvm::lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1578 chunk.adjustLocForChunkOffset(symbol->range);
1579 chunk.adjustLocForChunkOffset(symbol->selectionRange);
1581 for (llvm::lsp::DocumentSymbol &childSymbol : symbol->children)
1582 symbolsToFix.push_back(&childSymbol);
1587 symbols.emplace_back(std::move(symbol));
1591llvm::lsp::CompletionList
1592PDLTextFile::getCodeCompletion(
const llvm::lsp::URIForFile &uri,
1593 llvm::lsp::Position completePos) {
1594 PDLTextFileChunk &chunk = getChunkFor(completePos);
1595 llvm::lsp::CompletionList completionList =
1596 chunk.document.getCodeCompletion(uri, completePos);
1599 for (llvm::lsp::CompletionItem &item : completionList.items) {
1601 chunk.adjustLocForChunkOffset(item.textEdit->range);
1602 for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
1603 chunk.adjustLocForChunkOffset(edit.range);
1605 return completionList;
1608llvm::lsp::SignatureHelp
1609PDLTextFile::getSignatureHelp(
const llvm::lsp::URIForFile &uri,
1610 llvm::lsp::Position helpPos) {
1611 return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos);
1614void PDLTextFile::getInlayHints(
const llvm::lsp::URIForFile &uri,
1615 llvm::lsp::Range range,
1616 std::vector<llvm::lsp::InlayHint> &inlayHints) {
1617 auto startIt = getChunkItFor(range.start);
1618 auto endIt = getChunkItFor(range.end);
1621 auto getHintsForChunk = [&](ChunkIterator chunkIt, llvm::lsp::Range range) {
1622 size_t currentNumHints = inlayHints.size();
1623 chunkIt->document.getInlayHints(uri, range, inlayHints);
1627 if (&*chunkIt != &*chunks.front()) {
1628 for (
auto &hint : llvm::drop_begin(inlayHints, currentNumHints))
1629 chunkIt->adjustLocForChunkOffset(hint.position);
1633 auto getNumLines = [](ChunkIterator chunkIt) {
1634 return (chunkIt + 1)->lineOffset - chunkIt->lineOffset;
1638 if (startIt == endIt)
1639 return getHintsForChunk(startIt, range);
1643 getHintsForChunk(startIt,
1644 llvm::lsp::Range(range.start, getNumLines(startIt)));
1647 for (++startIt; startIt != endIt; ++startIt)
1648 getHintsForChunk(startIt, llvm::lsp::Range(0, getNumLines(startIt)));
1652 getHintsForChunk(startIt, llvm::lsp::Range(0, range.end));
1655lsp::PDLLViewOutputResult
1657 lsp::PDLLViewOutputResult
result;
1659 llvm::raw_string_ostream outputOS(
result.output);
1661 llvm::make_pointee_range(chunks),
1662 [&](PDLTextFileChunk &chunk) {
1663 chunk.document.getPDLLViewOutput(outputOS, kind);
1665 [&] { outputOS <<
"\n"
1671void PDLTextFile::initialize(
const llvm::lsp::URIForFile &uri,
1673 std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1674 version = newVersion;
1678 SmallVector<StringRef, 8> subContents;
1680 chunks.emplace_back(std::make_unique<PDLTextFileChunk>(
1681 0, uri, subContents.front(), extraIncludeDirs,
1684 uint64_t lineOffset = subContents.front().count(
'\n');
1685 for (StringRef docContents : llvm::drop_begin(subContents)) {
1686 unsigned currentNumDiags = diagnostics.size();
1687 auto chunk = std::make_unique<PDLTextFileChunk>(
1688 lineOffset, uri, docContents, extraIncludeDirs, diagnostics);
1689 lineOffset += docContents.count(
'\n');
1693 for (llvm::lsp::Diagnostic &
diag :
1694 llvm::drop_begin(diagnostics, currentNumDiags)) {
1695 chunk->adjustLocForChunkOffset(
diag.range);
1697 if (!
diag.relatedInformation)
1699 for (
auto &it : *
diag.relatedInformation)
1700 if (it.location.uri == uri)
1701 chunk->adjustLocForChunkOffset(it.location.range);
1703 chunks.emplace_back(std::move(chunk));
1705 totalNumLines = lineOffset;
1708PDLTextFile::ChunkIterator
1709PDLTextFile::getChunkItFor(llvm::lsp::Position &pos) {
1710 if (chunks.size() == 1)
1711 return chunks.begin();
1715 auto it = llvm::upper_bound(
1716 chunks, pos, [](
const llvm::lsp::Position &pos,
const auto &chunk) {
1717 return static_cast<uint64_t
>(pos.line) < chunk->lineOffset;
1719 ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it);
1720 pos.line -= chunkIt->lineOffset;
1740 llvm::StringMap<std::unique_ptr<PDLTextFile>>
files;
1752 const URIForFile &uri, StringRef contents,
int64_t version,
1753 std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1755 std::vector<std::string> additionalIncludeDirs =
impl->options.extraDirs;
1756 const auto &fileInfo =
impl->compilationDatabase.getFileInfo(uri.file());
1757 llvm::append_range(additionalIncludeDirs, fileInfo.includeDirs);
1759 impl->files[uri.file()] = std::make_unique<PDLTextFile>(
1760 uri, contents, version, additionalIncludeDirs, diagnostics);
1765 int64_t version, std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1767 auto it =
impl->files.find(uri.file());
1768 if (it ==
impl->files.end())
1773 if (failed(it->second->update(uri, version, changes, diagnostics)))
1774 impl->files.erase(it);
1778 auto it =
impl->files.find(uri.file());
1779 if (it ==
impl->files.end())
1780 return std::nullopt;
1782 int64_t version = it->second->getVersion();
1783 impl->files.erase(it);
1788 const URIForFile &uri,
const Position &defPos,
1789 std::vector<llvm::lsp::Location> &locations) {
1790 auto fileIt =
impl->files.find(uri.file());
1791 if (fileIt !=
impl->files.end())
1792 fileIt->second->getLocationsOf(uri, defPos, locations);
1796 const URIForFile &uri,
const Position &pos,
1797 std::vector<llvm::lsp::Location> &references) {
1798 auto fileIt =
impl->files.find(uri.file());
1799 if (fileIt !=
impl->files.end())
1800 fileIt->second->findReferencesOf(uri, pos, references);
1804 const URIForFile &uri, std::vector<DocumentLink> &documentLinks) {
1805 auto fileIt =
impl->files.find(uri.file());
1806 if (fileIt !=
impl->files.end())
1807 return fileIt->second->getDocumentLinks(uri, documentLinks);
1810std::optional<llvm::lsp::Hover>
1812 auto fileIt =
impl->files.find(uri.file());
1813 if (fileIt !=
impl->files.end())
1814 return fileIt->second->findHover(uri, hoverPos);
1815 return std::nullopt;
1819 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1820 auto fileIt =
impl->files.find(uri.file());
1821 if (fileIt !=
impl->files.end())
1822 fileIt->second->findDocumentSymbols(symbols);
1827 const Position &completePos) {
1828 auto fileIt =
impl->files.find(uri.file());
1829 if (fileIt !=
impl->files.end())
1830 return fileIt->second->getCodeCompletion(uri, completePos);
1831 return CompletionList();
1834llvm::lsp::SignatureHelp
1836 const Position &helpPos) {
1837 auto fileIt =
impl->files.find(uri.file());
1838 if (fileIt !=
impl->files.end())
1839 return fileIt->second->getSignatureHelp(uri, helpPos);
1840 return SignatureHelp();
1844 std::vector<InlayHint> &inlayHints) {
1845 auto fileIt =
impl->files.find(uri.file());
1846 if (fileIt ==
impl->files.end())
1848 fileIt->second->getInlayHints(uri, range, inlayHints);
1851 llvm::sort(inlayHints);
1852 inlayHints.erase(llvm::unique(inlayHints), inlayHints.end());
1855std::optional<lsp::PDLLViewOutputResult>
1858 auto fileIt =
impl->files.find(uri.file());
1859 if (fileIt !=
impl->files.end())
1860 return fileIt->second->getPDLLViewOutput(kind);
1861 return std::nullopt;
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static std::optional< lsp::Location > getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc)
Returns a language server location from the given MLIR file location.
static std::string diag(const llvm::Value &value)
static std::optional< std::string > getDocumentationFor(llvm::SourceMgr &sourceMgr, const ast::Decl *decl)
Get or extract the documentation for the given decl.
static llvm::lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range, const llvm::lsp::URIForFile &uri)
Returns a language server location from the given source range.
static std::optional< llvm::lsp::Diagnostic > getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, const llvm::lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
static bool shouldAddHintFor(const ast::Expr *expr, StringRef name)
Returns true if the given name should be added as a hint for expr.
static llvm::lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc, const llvm::lsp::URIForFile &mainFileURI)
Returns a language server uri for the given source location.
static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc)
Returns true if the given location is in the main file of the source manager.
static llvm::ManagedStatic< PassManagerOptions > options
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This class contains a collection of compilation information for files provided to the language server...
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
void getDocumentLinks(const URIForFile &uri, std::vector< DocumentLink > &documentLinks)
Return the document links referenced by the given file.
void getInlayHints(const URIForFile &uri, const Range &range, std::vector< InlayHint > &inlayHints)
Get the inlay hints for the range within the given file.
void addDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add the document, with the provided version, at the given URI.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
SignatureHelp getSignatureHelp(const URIForFile &uri, const Position &helpPos)
Get the signature help for the position within the given file.
void updateDocument(const URIForFile &uri, ArrayRef< TextDocumentContentChangeEvent > changes, int64_t version, std::vector< Diagnostic > &diagnostics)
Update the document, with the provided version, at the given URI.
PDLLServer(const Options &options)
std::optional< PDLLViewOutputResult > getPDLLViewOutput(const URIForFile &uri, PDLLViewOutputKind kind)
Get the output of the given PDLL file, or std::nullopt if there is no valid output.
Expr * getCallableExpr() const
Return the callable of this call.
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
This class represents the base Decl node.
std::optional< StringRef > getDocComment() const
Return the documentation comment attached to this decl if it has been set.
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
This class provides a simple implementation of a PDLL diagnostic.
This class represents a base AST Expression node.
Type getType() const
Return the type of this expression.
This class represents a top-level AST module.
SMRange getLoc() const
Return the location of this node.
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
const OpRewriteStmt * getRootRewriteStmt() const
Return the root rewrite statement of this pattern.
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
bool hasBoundedRewriteRecursion() const
Return if this pattern has bounded rewrite recursion.
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
const Name & getName() const
Return the name of the decl.
Type getType() const
Return the type of the decl.
StringRef getCppClass() const
Return the name of the underlying c++ class of this constraint.
StringRef getSummary() const
Return the summary of this constraint.
StringRef getName() const
Return the unique name of this constraint.
This class contains all of the registered ODS operation classes.
auto getDialects() const
Return a range of all of the registered dialects.
const Dialect * lookupDialect(StringRef name) const
Lookup a dialect registered with the given name, or null if no dialect with that name was inserted.
const Operation * lookupOperation(StringRef name) const
Lookup an operation registered with the given name, or null if no operation with that name is registe...
const llvm::StringMap< std::unique_ptr< Operation > > & getOperations() const
Return a map of all of the operations registered to this dialect.
StringRef getDescription() const
Returns the description of the operation.
StringRef getSummary() const
Returns the summary of the operation.
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
StringRef getName() const
Returns the name of the operation.
SMRange getLoc() const
Return the source location of this operation.
ArrayRef< Attribute > getAttributes() const
Returns the attributes of this operation.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
StringRef getCppClass() const
Return the name of the underlying c++ class of this constraint.
PDLLViewOutputKind
The type of output to view from PDLL.
void gatherIncludeFiles(llvm::SourceMgr &sourceMgr, SmallVectorImpl< SourceMgrInclude > &includes)
Given a source manager, gather all of the processed include files.
bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
std::optional< std::string > extractSourceDocComment(llvm::SourceMgr &sourceMgr, SMLoc loc)
Extract a documentation comment for the given location within the source manager.
void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, raw_ostream &os)
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Include the generated interface declarations.
const char *const kDefaultSplitMarker
llvm::StringSet< AllocatorTy > StringSet
llvm::TypeSwitch< T, ResultT > TypeSwitch
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Impl(const Options &options)
lsp::CompilationDatabase compilationDatabase
The compilation database containing additional information for files passed to the server.
llvm::StringMap< std::unique_ptr< PDLTextFile > > files
The files held by the server, mapped by their URI file name.
const Options & options
PDLL LSP options.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
This class provides a convenient API for interacting with source names.
StringRef getName() const
Return the raw string name.
SMRange getLoc() const
Get the location of this name.