MLIR 22.0.0git
PDLLServer.cpp
Go to the documentation of this file.
1//===- PDLLServer.cpp - PDLL Language Server ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PDLLServer.h"
10
11#include "Protocol.h"
12#include "mlir/IR/BuiltinOps.h"
27#include "llvm/ADT/IntervalMap.h"
28#include "llvm/ADT/StringMap.h"
29#include "llvm/ADT/StringSet.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/FileSystem.h"
32#include "llvm/Support/LSP/Logging.h"
33#include "llvm/Support/Path.h"
34#include "llvm/Support/VirtualFileSystem.h"
35#include <optional>
36
37using namespace mlir;
38using namespace mlir::pdll;
39
40/// Returns a language server uri for the given source location. `mainFileURI`
41/// corresponds to the uri for the main file of the source manager.
42static llvm::lsp::URIForFile
43getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc,
44 const llvm::lsp::URIForFile &mainFileURI) {
45 int bufferId = mgr.FindBufferContainingLoc(loc.Start);
46 if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
47 return mainFileURI;
49 llvm::lsp::URIForFile::fromFile(
50 mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
51 if (fileForLoc)
52 return *fileForLoc;
53 llvm::lsp::Logger::error("Failed to create URI for include file: {0}",
54 llvm::toString(fileForLoc.takeError()));
55 return mainFileURI;
56}
57
58/// Returns true if the given location is in the main file of the source
59/// manager.
60static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) {
61 return mgr.FindBufferContainingLoc(loc.Start) == mgr.getMainFileID();
62}
63
64/// Returns a language server location from the given source range.
65static llvm::lsp::Location
66getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range,
67 const llvm::lsp::URIForFile &uri) {
68 return llvm::lsp::Location(getURIFromLoc(mgr, range, uri),
69 llvm::lsp::Range(mgr, range));
70}
71
72/// Convert the given MLIR diagnostic to the LSP form.
73static std::optional<llvm::lsp::Diagnostic>
74getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag,
75 const llvm::lsp::URIForFile &uri) {
76 llvm::lsp::Diagnostic lspDiag;
77 lspDiag.source = "pdll";
78
79 // FIXME: Right now all of the diagnostics are treated as parser issues, but
80 // some are parser and some are verifier.
81 lspDiag.category = "Parse Error";
82
83 // Try to grab a file location for this diagnostic.
84 llvm::lsp::Location loc =
85 getLocationFromLoc(sourceMgr, diag.getLocation(), uri);
86 lspDiag.range = loc.range;
87
88 // Skip diagnostics that weren't emitted within the main file.
89 if (loc.uri != uri)
90 return std::nullopt;
91
92 // Convert the severity for the diagnostic.
93 switch (diag.getSeverity()) {
94 case ast::Diagnostic::Severity::DK_Note:
95 llvm_unreachable("expected notes to be handled separately");
96 case ast::Diagnostic::Severity::DK_Warning:
97 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
98 break;
99 case ast::Diagnostic::Severity::DK_Error:
100 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
101 break;
102 case ast::Diagnostic::Severity::DK_Remark:
103 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
104 break;
105 }
106 lspDiag.message = diag.getMessage().str();
107
108 // Attach any notes to the main diagnostic as related information.
109 std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
110 for (const ast::Diagnostic &note : diag.getNotes()) {
111 relatedDiags.emplace_back(
112 getLocationFromLoc(sourceMgr, note.getLocation(), uri),
113 note.getMessage().str());
114 }
115 if (!relatedDiags.empty())
116 lspDiag.relatedInformation = std::move(relatedDiags);
117
118 return lspDiag;
119}
120
121/// Get or extract the documentation for the given decl.
122static std::optional<std::string>
123getDocumentationFor(llvm::SourceMgr &sourceMgr, const ast::Decl *decl) {
124 // If the decl already had documentation set, use it.
125 if (std::optional<StringRef> doc = decl->getDocComment())
126 return doc->str();
127
128 // If the decl doesn't yet have documentation, try to extract it from the
129 // source file.
130 return lsp::extractSourceDocComment(sourceMgr, decl->getLoc().Start);
131}
132
133//===----------------------------------------------------------------------===//
134// PDLIndex
135//===----------------------------------------------------------------------===//
136
137namespace {
138struct PDLIndexSymbol {
139 explicit PDLIndexSymbol(const ast::Decl *definition)
140 : definition(definition) {}
141 explicit PDLIndexSymbol(const ods::Operation *definition)
142 : definition(definition) {}
143
144 /// Return the location of the definition of this symbol.
145 SMRange getDefLoc() const {
146 if (const ast::Decl *decl =
147 llvm::dyn_cast_if_present<const ast::Decl *>(definition)) {
148 const ast::Name *declName = decl->getName();
149 return declName ? declName->getLoc() : decl->getLoc();
150 }
151 return cast<const ods::Operation *>(definition)->getLoc();
152 }
153
154 /// The main definition of the symbol.
155 PointerUnion<const ast::Decl *, const ods::Operation *> definition;
156 /// The set of references to the symbol.
157 std::vector<SMRange> references;
158};
159
160/// This class provides an index for definitions/uses within a PDL document.
161/// It provides efficient lookup of a definition given an input source range.
162class PDLIndex {
163public:
164 PDLIndex() : intervalMap(allocator) {}
165
166 /// Initialize the index with the given ast::Module.
167 void initialize(const ast::Module &module, const ods::Context &odsContext);
168
169 /// Lookup a symbol for the given location. Returns nullptr if no symbol could
170 /// be found. If provided, `overlappedRange` is set to the range that the
171 /// provided `loc` overlapped with.
172 const PDLIndexSymbol *lookup(SMLoc loc,
173 SMRange *overlappedRange = nullptr) const;
174
175private:
176 /// The type of interval map used to store source references. SMRange is
177 /// half-open, so we also need to use a half-open interval map.
178 using MapT =
179 llvm::IntervalMap<const char *, const PDLIndexSymbol *,
180 llvm::IntervalMapImpl::NodeSizer<
181 const char *, const PDLIndexSymbol *>::LeafSize,
182 llvm::IntervalMapHalfOpenInfo<const char *>>;
183
184 /// An allocator for the interval map.
185 MapT::Allocator allocator;
186
187 /// An interval map containing a corresponding definition mapped to a source
188 /// interval.
189 MapT intervalMap;
190
191 /// A mapping between definitions and their corresponding symbol.
193};
194} // namespace
195
196void PDLIndex::initialize(const ast::Module &module,
197 const ods::Context &odsContext) {
198 auto getOrInsertDef = [&](const auto *def) -> PDLIndexSymbol * {
199 auto it = defToSymbol.try_emplace(def, nullptr);
200 if (it.second)
201 it.first->second = std::make_unique<PDLIndexSymbol>(def);
202 return &*it.first->second;
203 };
204 auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc,
205 bool isDef = false) {
206 const char *startLoc = refLoc.Start.getPointer();
207 const char *endLoc = refLoc.End.getPointer();
208 if (!intervalMap.overlaps(startLoc, endLoc)) {
209 intervalMap.insert(startLoc, endLoc, sym);
210 if (!isDef)
211 sym->references.push_back(refLoc);
212 }
213 };
214 auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) {
215 const ods::Operation *odsOp = odsContext.lookupOperation(opName);
216 if (!odsOp)
217 return;
218
219 PDLIndexSymbol *symbol = getOrInsertDef(odsOp);
220 insertDeclRef(symbol, odsOp->getLoc(), /*isDef=*/true);
221 insertDeclRef(symbol, refLoc);
222 };
223
224 module.walk([&](const ast::Node *node) {
225 // Handle references to PDL decls.
226 if (const auto *decl = dyn_cast<ast::OpNameDecl>(node)) {
227 if (std::optional<StringRef> name = decl->getName())
228 insertODSOpRef(*name, decl->getLoc());
229 } else if (const ast::Decl *decl = dyn_cast<ast::Decl>(node)) {
230 const ast::Name *name = decl->getName();
231 if (!name)
232 return;
233 PDLIndexSymbol *declSym = getOrInsertDef(decl);
234 insertDeclRef(declSym, name->getLoc(), /*isDef=*/true);
235
236 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) {
237 // Record references to any constraints.
238 for (const auto &it : varDecl->getConstraints())
239 insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc);
240 }
241 } else if (const auto *expr = dyn_cast<ast::DeclRefExpr>(node)) {
242 insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc());
243 }
244 });
245}
246
247const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc,
248 SMRange *overlappedRange) const {
249 auto it = intervalMap.find(loc.getPointer());
250 if (!it.valid() || loc.getPointer() < it.start())
251 return nullptr;
252
253 if (overlappedRange) {
254 *overlappedRange = SMRange(SMLoc::getFromPointer(it.start()),
255 SMLoc::getFromPointer(it.stop()));
256 }
257 return it.value();
258}
259
260//===----------------------------------------------------------------------===//
261// PDLDocument
262//===----------------------------------------------------------------------===//
263
264namespace {
265/// This class represents all of the information pertaining to a specific PDL
266/// document.
267struct PDLDocument {
268 PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents,
269 const std::vector<std::string> &extraDirs,
270 std::vector<llvm::lsp::Diagnostic> &diagnostics);
271 PDLDocument(const PDLDocument &) = delete;
272 PDLDocument &operator=(const PDLDocument &) = delete;
273
274 //===--------------------------------------------------------------------===//
275 // Definitions and References
276 //===--------------------------------------------------------------------===//
277
278 void getLocationsOf(const llvm::lsp::URIForFile &uri,
279 const llvm::lsp::Position &defPos,
280 std::vector<llvm::lsp::Location> &locations);
281 void findReferencesOf(const llvm::lsp::URIForFile &uri,
282 const llvm::lsp::Position &pos,
283 std::vector<llvm::lsp::Location> &references);
284
285 //===--------------------------------------------------------------------===//
286 // Document Links
287 //===--------------------------------------------------------------------===//
288
289 void getDocumentLinks(const llvm::lsp::URIForFile &uri,
290 std::vector<llvm::lsp::DocumentLink> &links);
291
292 //===--------------------------------------------------------------------===//
293 // Hover
294 //===--------------------------------------------------------------------===//
295
296 std::optional<llvm::lsp::Hover>
297 findHover(const llvm::lsp::URIForFile &uri,
298 const llvm::lsp::Position &hoverPos);
299 std::optional<llvm::lsp::Hover> findHover(const ast::Decl *decl,
300 const SMRange &hoverRange);
301 llvm::lsp::Hover buildHoverForOpName(const ods::Operation *op,
302 const SMRange &hoverRange);
303 llvm::lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl,
304 const SMRange &hoverRange);
305 llvm::lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl,
306 const SMRange &hoverRange);
307 llvm::lsp::Hover
308 buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
309 const SMRange &hoverRange);
310 template <typename T>
311 llvm::lsp::Hover
312 buildHoverForUserConstraintOrRewrite(StringRef typeName, const T *decl,
313 const SMRange &hoverRange);
314
315 //===--------------------------------------------------------------------===//
316 // Document Symbols
317 //===--------------------------------------------------------------------===//
318
319 void findDocumentSymbols(std::vector<llvm::lsp::DocumentSymbol> &symbols);
320
321 //===--------------------------------------------------------------------===//
322 // Code Completion
323 //===--------------------------------------------------------------------===//
324
325 llvm::lsp::CompletionList
326 getCodeCompletion(const llvm::lsp::URIForFile &uri,
327 const llvm::lsp::Position &completePos);
328
329 //===--------------------------------------------------------------------===//
330 // Signature Help
331 //===--------------------------------------------------------------------===//
332
333 llvm::lsp::SignatureHelp getSignatureHelp(const llvm::lsp::URIForFile &uri,
334 const llvm::lsp::Position &helpPos);
335
336 //===--------------------------------------------------------------------===//
337 // Inlay Hints
338 //===--------------------------------------------------------------------===//
339
340 void getInlayHints(const llvm::lsp::URIForFile &uri,
341 const llvm::lsp::Range &range,
342 std::vector<llvm::lsp::InlayHint> &inlayHints);
343 void getInlayHintsFor(const ast::VariableDecl *decl,
344 const llvm::lsp::URIForFile &uri,
345 std::vector<llvm::lsp::InlayHint> &inlayHints);
346 void getInlayHintsFor(const ast::CallExpr *expr,
347 const llvm::lsp::URIForFile &uri,
348 std::vector<llvm::lsp::InlayHint> &inlayHints);
349 void getInlayHintsFor(const ast::OperationExpr *expr,
350 const llvm::lsp::URIForFile &uri,
351 std::vector<llvm::lsp::InlayHint> &inlayHints);
352
353 /// Add a parameter hint for the given expression using `label`.
354 void addParameterHintFor(std::vector<llvm::lsp::InlayHint> &inlayHints,
355 const ast::Expr *expr, StringRef label);
356
357 //===--------------------------------------------------------------------===//
358 // PDLL ViewOutput
359 //===--------------------------------------------------------------------===//
360
361 void getPDLLViewOutput(raw_ostream &os, lsp::PDLLViewOutputKind kind);
362
363 //===--------------------------------------------------------------------===//
364 // Fields
365 //===--------------------------------------------------------------------===//
366
367 /// The include directories for this file.
368 std::vector<std::string> includeDirs;
369
370 /// The source manager containing the contents of the input file.
371 llvm::SourceMgr sourceMgr;
372
373 /// The ODS and AST contexts.
374 ods::Context odsContext;
375 ast::Context astContext;
376
377 /// The parsed AST module, or failure if the file wasn't valid.
378 FailureOr<ast::Module *> astModule;
379
380 /// The index of the parsed module.
381 PDLIndex index;
382
383 /// The set of includes of the parsed module.
384 SmallVector<lsp::SourceMgrInclude> parsedIncludes;
385};
386} // namespace
387
388PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents,
389 const std::vector<std::string> &extraDirs,
390 std::vector<llvm::lsp::Diagnostic> &diagnostics)
391 : astContext(odsContext) {
392 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
393 if (!memBuffer) {
394 llvm::lsp::Logger::error("Failed to create memory buffer for file",
395 uri.file());
396 return;
397 }
398
399 // Build the set of include directories for this file.
400 llvm::SmallString<32> uriDirectory(uri.file());
401 llvm::sys::path::remove_filename(uriDirectory);
402 includeDirs.push_back(uriDirectory.str().str());
403 llvm::append_range(includeDirs, extraDirs);
404
405 sourceMgr.setIncludeDirs(includeDirs);
406 sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
407 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
408
409 astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
410 if (auto lspDiag = getLspDiagnoticFromDiag(sourceMgr, diag, uri))
411 diagnostics.push_back(std::move(*lspDiag));
412 });
413 astModule = parsePDLLAST(astContext, sourceMgr, /*enableDocumentation=*/true);
414
415 // Initialize the set of parsed includes.
416 lsp::gatherIncludeFiles(sourceMgr, parsedIncludes);
417
418 // If we failed to parse the module, there is nothing left to initialize.
419 if (failed(astModule))
420 return;
421
422 // Prepare the AST index with the parsed module.
423 index.initialize(**astModule, odsContext);
424}
425
426//===----------------------------------------------------------------------===//
427// PDLDocument: Definitions and References
428//===----------------------------------------------------------------------===//
429
430void PDLDocument::getLocationsOf(const llvm::lsp::URIForFile &uri,
431 const llvm::lsp::Position &defPos,
432 std::vector<llvm::lsp::Location> &locations) {
433 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
434 const PDLIndexSymbol *symbol = index.lookup(posLoc);
435 if (!symbol)
436 return;
437
438 locations.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri));
439}
440
441void PDLDocument::findReferencesOf(
442 const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &pos,
443 std::vector<llvm::lsp::Location> &references) {
444 SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
445 const PDLIndexSymbol *symbol = index.lookup(posLoc);
446 if (!symbol)
447 return;
448
449 references.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri));
450 for (SMRange refLoc : symbol->references)
451 references.push_back(getLocationFromLoc(sourceMgr, refLoc, uri));
452}
453
454//===--------------------------------------------------------------------===//
455// PDLDocument: Document Links
456//===--------------------------------------------------------------------===//
457
458void PDLDocument::getDocumentLinks(
459 const llvm::lsp::URIForFile &uri,
460 std::vector<llvm::lsp::DocumentLink> &links) {
461 for (const lsp::SourceMgrInclude &include : parsedIncludes)
462 links.emplace_back(include.range, include.uri);
463}
464
465//===----------------------------------------------------------------------===//
466// PDLDocument: Hover
467//===----------------------------------------------------------------------===//
468
469std::optional<llvm::lsp::Hover>
470PDLDocument::findHover(const llvm::lsp::URIForFile &uri,
471 const llvm::lsp::Position &hoverPos) {
472 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
473
474 // Check for a reference to an include.
475 for (const lsp::SourceMgrInclude &include : parsedIncludes)
476 if (include.range.contains(hoverPos))
477 return include.buildHover();
478
479 // Find the symbol at the given location.
480 SMRange hoverRange;
481 const PDLIndexSymbol *symbol = index.lookup(posLoc, &hoverRange);
482 if (!symbol)
483 return std::nullopt;
484
485 // Add hover for operation names.
486 if (const auto *op =
487 llvm::dyn_cast_if_present<const ods::Operation *>(symbol->definition))
488 return buildHoverForOpName(op, hoverRange);
489 const auto *decl = cast<const ast::Decl *>(symbol->definition);
490 return findHover(decl, hoverRange);
491}
492
493std::optional<llvm::lsp::Hover>
494PDLDocument::findHover(const ast::Decl *decl, const SMRange &hoverRange) {
495 // Add hover for variables.
496 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
497 return buildHoverForVariable(varDecl, hoverRange);
498
499 // Add hover for patterns.
500 if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl))
501 return buildHoverForPattern(patternDecl, hoverRange);
502
503 // Add hover for core constraints.
504 if (const auto *cst = dyn_cast<ast::CoreConstraintDecl>(decl))
505 return buildHoverForCoreConstraint(cst, hoverRange);
506
507 // Add hover for user constraints.
508 if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl))
509 return buildHoverForUserConstraintOrRewrite("Constraint", cst, hoverRange);
510
511 // Add hover for user rewrites.
512 if (const auto *rewrite = dyn_cast<ast::UserRewriteDecl>(decl))
513 return buildHoverForUserConstraintOrRewrite("Rewrite", rewrite, hoverRange);
514
515 return std::nullopt;
516}
517
518llvm::lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op,
519 const SMRange &hoverRange) {
520 llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
521 {
522 llvm::raw_string_ostream hoverOS(hover.contents.value);
523 hoverOS << "**OpName**: `" << op->getName() << "`\n***\n"
524 << op->getSummary() << "\n***\n"
525 << op->getDescription();
526 }
527 return hover;
528}
529
530llvm::lsp::Hover
531PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl,
532 const SMRange &hoverRange) {
533 llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
534 {
535 llvm::raw_string_ostream hoverOS(hover.contents.value);
536 hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n"
537 << "Type: `" << varDecl->getType() << "`\n";
538 }
539 return hover;
540}
541
542llvm::lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl,
543 const SMRange &hoverRange) {
544 llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
545 {
546 llvm::raw_string_ostream hoverOS(hover.contents.value);
547 hoverOS << "**Pattern**";
548 if (const ast::Name *name = decl->getName())
549 hoverOS << ": `" << name->getName() << "`";
550 hoverOS << "\n***\n";
551 if (std::optional<uint16_t> benefit = decl->getBenefit())
552 hoverOS << "Benefit: " << *benefit << "\n";
553 if (decl->hasBoundedRewriteRecursion())
554 hoverOS << "HasBoundedRewriteRecursion\n";
555 hoverOS << "RootOp: `"
556 << decl->getRootRewriteStmt()->getRootOpExpr()->getType() << "`\n";
557
558 // Format the documentation for the decl.
559 if (std::optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
560 hoverOS << "\n" << *doc << "\n";
561 }
562 return hover;
563}
564
565llvm::lsp::Hover
566PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
567 const SMRange &hoverRange) {
568 llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
569 {
570 llvm::raw_string_ostream hoverOS(hover.contents.value);
571 hoverOS << "**Constraint**: `";
573 .Case([&](const ast::AttrConstraintDecl *) { hoverOS << "Attr"; })
574 .Case([&](const ast::OpConstraintDecl *opCst) {
575 hoverOS << "Op";
576 if (std::optional<StringRef> name = opCst->getName())
577 hoverOS << "<" << *name << ">";
578 })
579 .Case([&](const ast::TypeConstraintDecl *) { hoverOS << "Type"; })
580 .Case([&](const ast::TypeRangeConstraintDecl *) {
581 hoverOS << "TypeRange";
582 })
583 .Case([&](const ast::ValueConstraintDecl *) { hoverOS << "Value"; })
584 .Case([&](const ast::ValueRangeConstraintDecl *) {
585 hoverOS << "ValueRange";
586 });
587 hoverOS << "`\n";
588 }
589 return hover;
590}
591
592template <typename T>
593llvm::lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
594 StringRef typeName, const T *decl, const SMRange &hoverRange) {
595 llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
596 {
597 llvm::raw_string_ostream hoverOS(hover.contents.value);
598 hoverOS << "**" << typeName << "**: `" << decl->getName().getName()
599 << "`\n***\n";
600 ArrayRef<ast::VariableDecl *> inputs = decl->getInputs();
601 if (!inputs.empty()) {
602 hoverOS << "Parameters:\n";
603 for (const ast::VariableDecl *input : inputs)
604 hoverOS << "* " << input->getName().getName() << ": `"
605 << input->getType() << "`\n";
606 hoverOS << "***\n";
607 }
608 ast::Type resultType = decl->getResultType();
609 if (auto resultTupleTy = dyn_cast<ast::TupleType>(resultType)) {
610 if (!resultTupleTy.empty()) {
611 hoverOS << "Results:\n";
612 for (auto it : llvm::zip(resultTupleTy.getElementNames(),
613 resultTupleTy.getElementTypes())) {
614 StringRef name = std::get<0>(it);
615 hoverOS << "* " << (name.empty() ? "" : (name + ": ")) << "`"
616 << std::get<1>(it) << "`\n";
617 }
618 hoverOS << "***\n";
619 }
620 } else {
621 hoverOS << "Results:\n* `" << resultType << "`\n";
622 hoverOS << "***\n";
623 }
624
625 // Format the documentation for the decl.
626 if (std::optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
627 hoverOS << "\n" << *doc << "\n";
628 }
629 return hover;
630}
631
632//===----------------------------------------------------------------------===//
633// PDLDocument: Document Symbols
634//===----------------------------------------------------------------------===//
635
636void PDLDocument::findDocumentSymbols(
637 std::vector<llvm::lsp::DocumentSymbol> &symbols) {
638 if (failed(astModule))
639 return;
640
641 for (const ast::Decl *decl : (*astModule)->getChildren()) {
642 if (!isMainFileLoc(sourceMgr, decl->getLoc()))
643 continue;
644
645 if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl)) {
646 const ast::Name *name = patternDecl->getName();
647
648 SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc();
649 SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
650
651 symbols.emplace_back(name ? name->getName() : "<pattern>",
652 llvm::lsp::SymbolKind::Class,
653 llvm::lsp::Range(sourceMgr, bodyLoc),
654 llvm::lsp::Range(sourceMgr, nameLoc));
655 } else if (const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(decl)) {
656 // TODO: Add source information for the code block body.
657 SMRange nameLoc = cDecl->getName().getLoc();
658 SMRange bodyLoc = nameLoc;
659
660 symbols.emplace_back(cDecl->getName().getName(),
661 llvm::lsp::SymbolKind::Function,
662 llvm::lsp::Range(sourceMgr, bodyLoc),
663 llvm::lsp::Range(sourceMgr, nameLoc));
664 } else if (const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(decl)) {
665 // TODO: Add source information for the code block body.
666 SMRange nameLoc = cDecl->getName().getLoc();
667 SMRange bodyLoc = nameLoc;
668
669 symbols.emplace_back(cDecl->getName().getName(),
670 llvm::lsp::SymbolKind::Function,
671 llvm::lsp::Range(sourceMgr, bodyLoc),
672 llvm::lsp::Range(sourceMgr, nameLoc));
673 }
674 }
675}
676
677//===----------------------------------------------------------------------===//
678// PDLDocument: Code Completion
679//===----------------------------------------------------------------------===//
680
681namespace {
682class LSPCodeCompleteContext : public CodeCompleteContext {
683public:
684 LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
685 llvm::lsp::CompletionList &completionList,
686 ods::Context &odsContext,
687 ArrayRef<std::string> includeDirs)
688 : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
689 completionList(completionList), odsContext(odsContext),
690 includeDirs(includeDirs) {}
691
692 void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final {
693 ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes();
694 ArrayRef<StringRef> elementNames = tupleType.getElementNames();
695 for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
696 // Push back a completion item that uses the result index.
697 llvm::lsp::CompletionItem item;
698 item.label = llvm::formatv("{0} (field #{0})", i).str();
699 item.insertText = Twine(i).str();
700 item.filterText = item.sortText = item.insertText;
701 item.kind = llvm::lsp::CompletionItemKind::Field;
702 item.detail = llvm::formatv("{0}: {1}", i, elementTypes[i]);
703 item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
704 completionList.items.emplace_back(item);
705
706 // If the element has a name, push back a completion item with that name.
707 if (!elementNames[i].empty()) {
708 item.label =
709 llvm::formatv("{1} (field #{0})", i, elementNames[i]).str();
710 item.filterText = item.label;
711 item.insertText = elementNames[i].str();
712 completionList.items.emplace_back(item);
713 }
714 }
715 }
716
717 void codeCompleteOperationMemberAccess(ast::OperationType opType) final {
718 const ods::Operation *odsOp = opType.getODSOperation();
719 if (!odsOp)
720 return;
721
722 ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
723 for (const auto &it : llvm::enumerate(results)) {
724 const ods::OperandOrResult &result = it.value();
725 const ods::TypeConstraint &constraint = result.getConstraint();
726
727 // Push back a completion item that uses the result index.
728 llvm::lsp::CompletionItem item;
729 item.label = llvm::formatv("{0} (field #{0})", it.index()).str();
730 item.insertText = Twine(it.index()).str();
731 item.filterText = item.sortText = item.insertText;
732 item.kind = llvm::lsp::CompletionItemKind::Field;
733 switch (result.getVariableLengthKind()) {
734 case ods::VariableLengthKind::Single:
735 item.detail = llvm::formatv("{0}: Value", it.index()).str();
736 break;
737 case ods::VariableLengthKind::Optional:
738 item.detail = llvm::formatv("{0}: Value?", it.index()).str();
739 break;
740 case ods::VariableLengthKind::Variadic:
741 item.detail = llvm::formatv("{0}: ValueRange", it.index()).str();
742 break;
743 }
744 item.documentation = llvm::lsp::MarkupContent{
745 llvm::lsp::MarkupKind::Markdown,
746 llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
747 constraint.getCppClass())
748 .str()};
749 item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
750 completionList.items.emplace_back(item);
751
752 // If the result has a name, push back a completion item with the result
753 // name.
754 if (!result.getName().empty()) {
755 item.label =
756 llvm::formatv("{1} (field #{0})", it.index(), result.getName())
757 .str();
758 item.filterText = item.label;
759 item.insertText = result.getName().str();
760 completionList.items.emplace_back(item);
761 }
762 }
763 }
764
765 void codeCompleteOperationAttributeName(StringRef opName) final {
766 const ods::Operation *odsOp = odsContext.lookupOperation(opName);
767 if (!odsOp)
768 return;
769
770 for (const ods::Attribute &attr : odsOp->getAttributes()) {
771 const ods::AttributeConstraint &constraint = attr.getConstraint();
772
773 llvm::lsp::CompletionItem item;
774 item.label = attr.getName().str();
775 item.kind = llvm::lsp::CompletionItemKind::Field;
776 item.detail = attr.isOptional() ? "optional" : "";
777 item.documentation = llvm::lsp::MarkupContent{
778 llvm::lsp::MarkupKind::Markdown,
779 llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
780 constraint.getCppClass())
781 .str()};
782 item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
783 completionList.items.emplace_back(item);
784 }
785 }
786
787 void codeCompleteConstraintName(ast::Type currentType,
788 bool allowInlineTypeConstraints,
789 const ast::DeclScope *scope) final {
790 auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
791 StringRef snippetText = "") {
792 llvm::lsp::CompletionItem item;
793 item.label = constraint.str();
794 item.kind = llvm::lsp::CompletionItemKind::Class;
795 item.detail = (constraint + " constraint").str();
796 item.documentation = llvm::lsp::MarkupContent{
797 llvm::lsp::MarkupKind::Markdown,
798 ("A single entity core constraint of type `" + mlirType + "`").str()};
799 item.sortText = "0";
800 item.insertText = snippetText.str();
801 item.insertTextFormat = snippetText.empty()
802 ? llvm::lsp::InsertTextFormat::PlainText
803 : llvm::lsp::InsertTextFormat::Snippet;
804 completionList.items.emplace_back(item);
805 };
806
807 // Insert completions for the core constraints. Some core constraints have
808 // additional characteristics, so we may add then even if a type has been
809 // inferred.
810 if (!currentType) {
811 addCoreConstraint("Attr", "mlir::Attribute");
812 addCoreConstraint("Op", "mlir::Operation *");
813 addCoreConstraint("Value", "mlir::Value");
814 addCoreConstraint("ValueRange", "mlir::ValueRange");
815 addCoreConstraint("Type", "mlir::Type");
816 addCoreConstraint("TypeRange", "mlir::TypeRange");
817 }
818 if (allowInlineTypeConstraints) {
819 /// Attr<Type>.
820 if (!currentType || isa<ast::AttributeType>(currentType))
821 addCoreConstraint("Attr<type>", "mlir::Attribute", "Attr<$1>");
822 /// Value<Type>.
823 if (!currentType || isa<ast::ValueType>(currentType))
824 addCoreConstraint("Value<type>", "mlir::Value", "Value<$1>");
825 /// ValueRange<TypeRange>.
826 if (!currentType || isa<ast::ValueRangeType>(currentType))
827 addCoreConstraint("ValueRange<type>", "mlir::ValueRange",
828 "ValueRange<$1>");
829 }
830
831 // If a scope was provided, check it for potential constraints.
832 while (scope) {
833 for (const ast::Decl *decl : scope->getDecls()) {
834 if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) {
835 llvm::lsp::CompletionItem item;
836 item.label = cst->getName().getName().str();
837 item.kind = llvm::lsp::CompletionItemKind::Interface;
838 item.sortText = "2_" + item.label;
839
840 // Skip constraints that are not single-arg. We currently only
841 // complete variable constraints.
842 if (cst->getInputs().size() != 1)
843 continue;
844
845 // Ensure the input type matched the given type.
846 ast::Type constraintType = cst->getInputs()[0]->getType();
847 if (currentType && !currentType.refineWith(constraintType))
848 continue;
849
850 // Format the constraint signature.
851 {
852 llvm::raw_string_ostream strOS(item.detail);
853 strOS << "(";
854 llvm::interleaveComma(
855 cst->getInputs(), strOS, [&](const ast::VariableDecl *var) {
856 strOS << var->getName().getName() << ": " << var->getType();
857 });
858 strOS << ") -> " << cst->getResultType();
859 }
860
861 // Format the documentation for the constraint.
862 if (std::optional<std::string> doc =
863 getDocumentationFor(sourceMgr, cst)) {
864 item.documentation = llvm::lsp::MarkupContent{
865 llvm::lsp::MarkupKind::Markdown, std::move(*doc)};
866 }
867
868 completionList.items.emplace_back(item);
869 }
870 }
871
872 scope = scope->getParentScope();
873 }
874 }
875
876 void codeCompleteDialectName() final {
877 // Code complete known dialects.
878 for (const ods::Dialect &dialect : odsContext.getDialects()) {
879 llvm::lsp::CompletionItem item;
880 item.label = dialect.getName().str();
881 item.kind = llvm::lsp::CompletionItemKind::Class;
882 item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
883 completionList.items.emplace_back(item);
884 }
885 }
886
887 void codeCompleteOperationName(StringRef dialectName) final {
888 const ods::Dialect *dialect = odsContext.lookupDialect(dialectName);
889 if (!dialect)
890 return;
891
892 for (const auto &it : dialect->getOperations()) {
893 const ods::Operation &op = *it.second;
894
895 llvm::lsp::CompletionItem item;
896 item.label = op.getName().drop_front(dialectName.size() + 1).str();
897 item.kind = llvm::lsp::CompletionItemKind::Field;
898 item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
899 completionList.items.emplace_back(item);
900 }
901 }
902
903 void codeCompletePatternMetadata() final {
904 auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
905 StringRef snippetText = "") {
906 llvm::lsp::CompletionItem item;
907 item.label = constraint.str();
908 item.kind = llvm::lsp::CompletionItemKind::Class;
909 item.detail = "pattern metadata";
910 item.documentation =
911 llvm::lsp::MarkupContent{llvm::lsp::MarkupKind::Markdown, desc.str()};
912 item.insertText = snippetText.str();
913 item.insertTextFormat = snippetText.empty()
914 ? llvm::lsp::InsertTextFormat::PlainText
915 : llvm::lsp::InsertTextFormat::Snippet;
916 completionList.items.emplace_back(item);
917 };
918
919 addSimpleConstraint("benefit", "The `benefit` of matching the pattern.",
920 "benefit($1)");
921 addSimpleConstraint("recursion",
922 "The pattern properly handles recursive application.");
923 }
924
925 void codeCompleteIncludeFilename(StringRef curPath) final {
926 // Normalize the path to allow for interacting with the file system
927 // utilities.
928 SmallString<128> nativeRelDir(llvm::sys::path::convert_to_slash(curPath));
929 llvm::sys::path::native(nativeRelDir);
930
931 // Set of already included completion paths.
932 StringSet<> seenResults;
933
934 // Functor used to add a single include completion item.
935 auto addIncludeCompletion = [&](StringRef path, bool isDirectory) {
936 llvm::lsp::CompletionItem item;
937 item.label = path.str();
938 item.kind = isDirectory ? llvm::lsp::CompletionItemKind::Folder
939 : llvm::lsp::CompletionItemKind::File;
940 if (seenResults.insert(item.label).second)
941 completionList.items.emplace_back(item);
942 };
943
944 // Process the include directories for this file, adding any potential
945 // nested include files or directories.
946 for (StringRef includeDir : includeDirs) {
947 llvm::SmallString<128> dir = includeDir;
948 if (!nativeRelDir.empty())
949 llvm::sys::path::append(dir, nativeRelDir);
950
951 std::error_code errorCode;
952 for (auto it = llvm::sys::fs::directory_iterator(dir, errorCode),
953 e = llvm::sys::fs::directory_iterator();
954 !errorCode && it != e; it.increment(errorCode)) {
955 StringRef filename = llvm::sys::path::filename(it->path());
956
957 // To know whether a symlink should be treated as file or a directory,
958 // we have to stat it. This should be cheap enough as there shouldn't be
959 // many symlinks.
960 llvm::sys::fs::file_type fileType = it->type();
961 if (fileType == llvm::sys::fs::file_type::symlink_file) {
962 if (auto fileStatus = it->status())
963 fileType = fileStatus->type();
964 }
965
966 switch (fileType) {
967 case llvm::sys::fs::file_type::directory_file:
968 addIncludeCompletion(filename, /*isDirectory=*/true);
969 break;
970 case llvm::sys::fs::file_type::regular_file: {
971 // Only consider concrete files that can actually be included by PDLL.
972 if (filename.ends_with(".pdll") || filename.ends_with(".td"))
973 addIncludeCompletion(filename, /*isDirectory=*/false);
974 break;
975 }
976 default:
977 break;
978 }
979 }
980 }
981
982 // Sort the completion results to make sure the output is deterministic in
983 // the face of different iteration schemes for different platforms.
984 llvm::sort(completionList.items, [](const llvm::lsp::CompletionItem &lhs,
985 const llvm::lsp::CompletionItem &rhs) {
986 return lhs.label < rhs.label;
987 });
988 }
989
990private:
991 llvm::SourceMgr &sourceMgr;
992 llvm::lsp::CompletionList &completionList;
993 ods::Context &odsContext;
994 ArrayRef<std::string> includeDirs;
995};
996} // namespace
997
998llvm::lsp::CompletionList
999PDLDocument::getCodeCompletion(const llvm::lsp::URIForFile &uri,
1000 const llvm::lsp::Position &completePos) {
1001 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
1002 if (!posLoc.isValid())
1003 return llvm::lsp::CompletionList();
1004
1005 // To perform code completion, we run another parse of the module with the
1006 // code completion context provided.
1007 ods::Context tmpODSContext;
1008 llvm::lsp::CompletionList completionList;
1009 LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
1010 tmpODSContext,
1011 sourceMgr.getIncludeDirs());
1012
1013 ast::Context tmpContext(tmpODSContext);
1014 (void)parsePDLLAST(tmpContext, sourceMgr, /*enableDocumentation=*/true,
1015 &lspCompleteContext);
1016
1017 return completionList;
1018}
1019
1020//===----------------------------------------------------------------------===//
1021// PDLDocument: Signature Help
1022//===----------------------------------------------------------------------===//
1023
1024namespace {
1025class LSPSignatureHelpContext : public CodeCompleteContext {
1026public:
1027 LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
1028 llvm::lsp::SignatureHelp &signatureHelp,
1029 ods::Context &odsContext)
1030 : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
1031 signatureHelp(signatureHelp), odsContext(odsContext) {}
1032
1033 void codeCompleteCallSignature(const ast::CallableDecl *callable,
1034 unsigned currentNumArgs) final {
1035 signatureHelp.activeParameter = currentNumArgs;
1036
1037 llvm::lsp::SignatureInformation signatureInfo;
1038 {
1039 llvm::raw_string_ostream strOS(signatureInfo.label);
1040 strOS << callable->getName()->getName() << "(";
1041 auto formatParamFn = [&](const ast::VariableDecl *var) {
1042 unsigned paramStart = strOS.str().size();
1043 strOS << var->getName().getName() << ": " << var->getType();
1044 unsigned paramEnd = strOS.str().size();
1045 signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{
1046 StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1047 std::make_pair(paramStart, paramEnd), /*paramDoc*/ std::string()});
1048 };
1049 llvm::interleaveComma(callable->getInputs(), strOS, formatParamFn);
1050 strOS << ") -> " << callable->getResultType();
1051 }
1052
1053 // Format the documentation for the callable.
1054 if (std::optional<std::string> doc =
1055 getDocumentationFor(sourceMgr, callable))
1056 signatureInfo.documentation = std::move(*doc);
1057
1058 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1059 }
1060
1061 void
1062 codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
1063 unsigned currentNumOperands) final {
1064 const ods::Operation *odsOp =
1065 opName ? odsContext.lookupOperation(*opName) : nullptr;
1066 codeCompleteOperationOperandOrResultSignature(
1067 opName, odsOp,
1068 odsOp ? odsOp->getOperands() : ArrayRef<ods::OperandOrResult>(),
1069 currentNumOperands, "operand", "Value");
1070 }
1071
1072 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
1073 unsigned currentNumResults) final {
1074 const ods::Operation *odsOp =
1075 opName ? odsContext.lookupOperation(*opName) : nullptr;
1076 codeCompleteOperationOperandOrResultSignature(
1077 opName, odsOp,
1078 odsOp ? odsOp->getResults() : ArrayRef<ods::OperandOrResult>(),
1079 currentNumResults, "result", "Type");
1080 }
1081
1082 void codeCompleteOperationOperandOrResultSignature(
1083 std::optional<StringRef> opName, const ods::Operation *odsOp,
1084 ArrayRef<ods::OperandOrResult> values, unsigned currentValue,
1085 StringRef label, StringRef dataType) {
1086 signatureHelp.activeParameter = currentValue;
1087
1088 // If we have ODS information for the operation, add in the ODS signature
1089 // for the operation. We also verify that the current number of values is
1090 // not more than what is defined in ODS, as this will result in an error
1091 // anyways.
1092 if (odsOp && currentValue < values.size()) {
1093 llvm::lsp::SignatureInformation signatureInfo;
1094
1095 // Build the signature label.
1096 {
1097 llvm::raw_string_ostream strOS(signatureInfo.label);
1098 strOS << "(";
1099 auto formatFn = [&](const ods::OperandOrResult &value) {
1100 unsigned paramStart = strOS.str().size();
1101
1102 strOS << value.getName() << ": ";
1103
1104 StringRef constraintDoc = value.getConstraint().getSummary();
1105 std::string paramDoc;
1106 switch (value.getVariableLengthKind()) {
1107 case ods::VariableLengthKind::Single:
1108 strOS << dataType;
1109 paramDoc = constraintDoc.str();
1110 break;
1111 case ods::VariableLengthKind::Optional:
1112 strOS << dataType << "?";
1113 paramDoc = ("optional: " + constraintDoc).str();
1114 break;
1115 case ods::VariableLengthKind::Variadic:
1116 strOS << dataType << "Range";
1117 paramDoc = ("variadic: " + constraintDoc).str();
1118 break;
1119 }
1120
1121 unsigned paramEnd = strOS.str().size();
1122 signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{
1123 StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1124 std::make_pair(paramStart, paramEnd), paramDoc});
1125 };
1126 llvm::interleaveComma(values, strOS, formatFn);
1127 strOS << ")";
1128 }
1129 signatureInfo.documentation =
1130 llvm::formatv("`op<{0}>` ODS {1} specification", *opName, label)
1131 .str();
1132 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1133 }
1134
1135 // If there aren't any arguments yet, we also add the generic signature.
1136 if (currentValue == 0 && (!odsOp || !values.empty())) {
1137 llvm::lsp::SignatureInformation signatureInfo;
1138 signatureInfo.label =
1139 llvm::formatv("(<{0}s>: {1}Range)", label, dataType).str();
1140 signatureInfo.documentation =
1141 ("Generic operation " + label + " specification").str();
1142 signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{
1143 StringRef(signatureInfo.label).drop_front().drop_back().str(),
1144 std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1),
1145 ("All of the " + label + "s of the operation.").str()});
1146 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1147 }
1148 }
1149
1150private:
1151 llvm::SourceMgr &sourceMgr;
1152 llvm::lsp::SignatureHelp &signatureHelp;
1153 ods::Context &odsContext;
1154};
1155} // namespace
1156
1157llvm::lsp::SignatureHelp
1158PDLDocument::getSignatureHelp(const llvm::lsp::URIForFile &uri,
1159 const llvm::lsp::Position &helpPos) {
1160 SMLoc posLoc = helpPos.getAsSMLoc(sourceMgr);
1161 if (!posLoc.isValid())
1162 return llvm::lsp::SignatureHelp();
1163
1164 // To perform code completion, we run another parse of the module with the
1165 // code completion context provided.
1166 ods::Context tmpODSContext;
1167 llvm::lsp::SignatureHelp signatureHelp;
1168 LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp,
1169 tmpODSContext);
1170
1171 ast::Context tmpContext(tmpODSContext);
1172 (void)parsePDLLAST(tmpContext, sourceMgr, /*enableDocumentation=*/true,
1173 &completeContext);
1174
1175 return signatureHelp;
1176}
1177
1178//===----------------------------------------------------------------------===//
1179// PDLDocument: Inlay Hints
1180//===----------------------------------------------------------------------===//
1181
1182/// Returns true if the given name should be added as a hint for `expr`.
1183static bool shouldAddHintFor(const ast::Expr *expr, StringRef name) {
1184 if (name.empty())
1185 return false;
1186
1187 // If the argument is a reference of the same name, don't add it as a hint.
1188 if (auto *ref = dyn_cast<ast::DeclRefExpr>(expr)) {
1189 const ast::Name *declName = ref->getDecl()->getName();
1190 if (declName && declName->getName() == name)
1191 return false;
1192 }
1193
1194 return true;
1195}
1196
1197void PDLDocument::getInlayHints(const llvm::lsp::URIForFile &uri,
1198 const llvm::lsp::Range &range,
1199 std::vector<llvm::lsp::InlayHint> &inlayHints) {
1200 if (failed(astModule))
1201 return;
1202 SMRange rangeLoc = range.getAsSMRange(sourceMgr);
1203 if (!rangeLoc.isValid())
1204 return;
1205 (*astModule)->walk([&](const ast::Node *node) {
1206 SMRange loc = node->getLoc();
1207
1208 // Check that the location of this node is within the input range.
1209 if (!lsp::contains(rangeLoc, loc.Start) &&
1210 !lsp::contains(rangeLoc, loc.End))
1211 return;
1212
1213 // Handle hints for various types of nodes.
1214 llvm::TypeSwitch<const ast::Node *>(node)
1215 .Case<ast::VariableDecl, ast::CallExpr, ast::OperationExpr>(
1216 [&](const auto *node) {
1217 this->getInlayHintsFor(node, uri, inlayHints);
1218 });
1219 });
1220}
1221
1222void PDLDocument::getInlayHintsFor(
1223 const ast::VariableDecl *decl, const llvm::lsp::URIForFile &uri,
1224 std::vector<llvm::lsp::InlayHint> &inlayHints) {
1225 // Check to see if the variable has a constraint list, if it does we don't
1226 // provide initializer hints.
1227 if (!decl->getConstraints().empty())
1228 return;
1229
1230 // Check to see if the variable has an initializer.
1231 if (const ast::Expr *expr = decl->getInitExpr()) {
1232 // Don't add hints for operation expression initialized variables given that
1233 // the type of the variable is easily inferred by the expression operation
1234 // name.
1235 if (isa<ast::OperationExpr>(expr))
1236 return;
1237 }
1238
1239 llvm::lsp::InlayHint hint(llvm::lsp::InlayHintKind::Type,
1240 llvm::lsp::Position(sourceMgr, decl->getLoc().End));
1241 {
1242 llvm::raw_string_ostream labelOS(hint.label);
1243 labelOS << ": " << decl->getType();
1244 }
1245
1246 inlayHints.emplace_back(std::move(hint));
1247}
1248
1249void PDLDocument::getInlayHintsFor(
1250 const ast::CallExpr *expr, const llvm::lsp::URIForFile &uri,
1251 std::vector<llvm::lsp::InlayHint> &inlayHints) {
1252 // Try to extract the callable of this call.
1253 const auto *callableRef = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
1254 const auto *callable =
1255 callableRef ? dyn_cast<ast::CallableDecl>(callableRef->getDecl())
1256 : nullptr;
1257 if (!callable)
1258 return;
1259
1260 // Add hints for the arguments to the call.
1261 for (const auto &it : llvm::zip(expr->getArguments(), callable->getInputs()))
1262 addParameterHintFor(inlayHints, std::get<0>(it),
1263 std::get<1>(it)->getName().getName());
1264}
1265
1266void PDLDocument::getInlayHintsFor(
1267 const ast::OperationExpr *expr, const llvm::lsp::URIForFile &uri,
1268 std::vector<llvm::lsp::InlayHint> &inlayHints) {
1269 // Check for ODS information.
1270 ast::OperationType opType = dyn_cast<ast::OperationType>(expr->getType());
1271 const auto *odsOp = opType ? opType.getODSOperation() : nullptr;
1272
1273 auto addOpHint = [&](const ast::Expr *valueExpr, StringRef label) {
1274 // If the value expression used the same location as the operation, don't
1275 // add a hint. This expression was materialized during parsing.
1276 if (expr->getLoc().Start == valueExpr->getLoc().Start)
1277 return;
1278 addParameterHintFor(inlayHints, valueExpr, label);
1279 };
1280
1281 // Functor used to process hints for the operands and results of the
1282 // operation. They effectively have the same format, and thus can be processed
1283 // using the same logic.
1284 auto addOperandOrResultHints = [&](ArrayRef<ast::Expr *> values,
1285 ArrayRef<ods::OperandOrResult> odsValues,
1286 StringRef allValuesName) {
1287 if (values.empty())
1288 return;
1289
1290 // The values should either map to a single range, or be equivalent to the
1291 // ODS values.
1292 if (values.size() != odsValues.size()) {
1293 // Handle the case of a single element that covers the full range.
1294 if (values.size() == 1)
1295 return addOpHint(values.front(), allValuesName);
1296 return;
1297 }
1298
1299 for (const auto &it : llvm::zip(values, odsValues))
1300 addOpHint(std::get<0>(it), std::get<1>(it).getName());
1301 };
1302
1303 // Add hints for the operands and results of the operation.
1304 addOperandOrResultHints(expr->getOperands(),
1305 odsOp ? odsOp->getOperands()
1306 : ArrayRef<ods::OperandOrResult>(),
1307 "operands");
1308 addOperandOrResultHints(expr->getResultTypes(),
1309 odsOp ? odsOp->getResults()
1310 : ArrayRef<ods::OperandOrResult>(),
1311 "results");
1312}
1313
1314void PDLDocument::addParameterHintFor(
1315 std::vector<llvm::lsp::InlayHint> &inlayHints, const ast::Expr *expr,
1316 StringRef label) {
1317 if (!shouldAddHintFor(expr, label))
1318 return;
1319
1320 llvm::lsp::InlayHint hint(
1321 llvm::lsp::InlayHintKind::Parameter,
1322 llvm::lsp::Position(sourceMgr, expr->getLoc().Start));
1323 hint.label = (label + ":").str();
1324 hint.paddingRight = true;
1325 inlayHints.emplace_back(std::move(hint));
1326}
1327
1328//===----------------------------------------------------------------------===//
1329// PDLL ViewOutput
1330//===----------------------------------------------------------------------===//
1331
1332void PDLDocument::getPDLLViewOutput(raw_ostream &os,
1334 if (failed(astModule))
1335 return;
1336 if (kind == lsp::PDLLViewOutputKind::AST) {
1337 (*astModule)->print(os);
1338 return;
1339 }
1340
1341 // Generate the MLIR for the ast module. We also capture diagnostics here to
1342 // show to the user, which may be useful if PDLL isn't capturing constraints
1343 // expected by PDL.
1344 MLIRContext mlirContext;
1345 SourceMgrDiagnosticHandler diagHandler(sourceMgr, &mlirContext, os);
1346 OwningOpRef<ModuleOp> pdlModule =
1347 codegenPDLLToMLIR(&mlirContext, astContext, sourceMgr, **astModule);
1348 if (!pdlModule)
1349 return;
1350 if (kind == lsp::PDLLViewOutputKind::MLIR) {
1351 pdlModule->print(os, OpPrintingFlags().enableDebugInfo());
1352 return;
1353 }
1354
1355 // Otherwise, generate the output for C++.
1356 assert(kind == lsp::PDLLViewOutputKind::CPP &&
1357 "unexpected PDLLViewOutputKind");
1358 codegenPDLLToCPP(**astModule, *pdlModule, os);
1359}
1360
1361//===----------------------------------------------------------------------===//
1362// PDLTextFileChunk
1363//===----------------------------------------------------------------------===//
1364
1365namespace {
1366/// This class represents a single chunk of an PDL text file.
1367struct PDLTextFileChunk {
1368 PDLTextFileChunk(uint64_t lineOffset, const llvm::lsp::URIForFile &uri,
1369 StringRef contents,
1370 const std::vector<std::string> &extraDirs,
1371 std::vector<llvm::lsp::Diagnostic> &diagnostics)
1372 : lineOffset(lineOffset),
1373 document(uri, contents, extraDirs, diagnostics) {}
1374
1375 /// Adjust the line number of the given range to anchor at the beginning of
1376 /// the file, instead of the beginning of this chunk.
1377 void adjustLocForChunkOffset(llvm::lsp::Range &range) {
1378 adjustLocForChunkOffset(range.start);
1379 adjustLocForChunkOffset(range.end);
1380 }
1381 /// Adjust the line number of the given position to anchor at the beginning of
1382 /// the file, instead of the beginning of this chunk.
1383 void adjustLocForChunkOffset(llvm::lsp::Position &pos) {
1384 pos.line += lineOffset;
1385 }
1386
1387 /// The line offset of this chunk from the beginning of the file.
1388 uint64_t lineOffset;
1389 /// The document referred to by this chunk.
1390 PDLDocument document;
1391};
1392} // namespace
1393
1394//===----------------------------------------------------------------------===//
1395// PDLTextFile
1396//===----------------------------------------------------------------------===//
1397
1398namespace {
1399/// This class represents a text file containing one or more PDL documents.
1400class PDLTextFile {
1401public:
1402 PDLTextFile(const llvm::lsp::URIForFile &uri, StringRef fileContents,
1403 int64_t version, const std::vector<std::string> &extraDirs,
1404 std::vector<llvm::lsp::Diagnostic> &diagnostics);
1405
1406 /// Return the current version of this text file.
1407 int64_t getVersion() const { return version; }
1408
1409 /// Update the file to the new version using the provided set of content
1410 /// changes. Returns failure if the update was unsuccessful.
1411 LogicalResult
1412 update(const llvm::lsp::URIForFile &uri, int64_t newVersion,
1413 ArrayRef<llvm::lsp::TextDocumentContentChangeEvent> changes,
1414 std::vector<llvm::lsp::Diagnostic> &diagnostics);
1415
1416 //===--------------------------------------------------------------------===//
1417 // LSP Queries
1418 //===--------------------------------------------------------------------===//
1419
1420 void getLocationsOf(const llvm::lsp::URIForFile &uri,
1421 llvm::lsp::Position defPos,
1422 std::vector<llvm::lsp::Location> &locations);
1423 void findReferencesOf(const llvm::lsp::URIForFile &uri,
1424 llvm::lsp::Position pos,
1425 std::vector<llvm::lsp::Location> &references);
1426 void getDocumentLinks(const llvm::lsp::URIForFile &uri,
1427 std::vector<llvm::lsp::DocumentLink> &links);
1428 std::optional<llvm::lsp::Hover> findHover(const llvm::lsp::URIForFile &uri,
1429 llvm::lsp::Position hoverPos);
1430 void findDocumentSymbols(std::vector<llvm::lsp::DocumentSymbol> &symbols);
1431 llvm::lsp::CompletionList getCodeCompletion(const llvm::lsp::URIForFile &uri,
1432 llvm::lsp::Position completePos);
1433 llvm::lsp::SignatureHelp getSignatureHelp(const llvm::lsp::URIForFile &uri,
1434 llvm::lsp::Position helpPos);
1435 void getInlayHints(const llvm::lsp::URIForFile &uri, llvm::lsp::Range range,
1436 std::vector<llvm::lsp::InlayHint> &inlayHints);
1437 lsp::PDLLViewOutputResult getPDLLViewOutput(lsp::PDLLViewOutputKind kind);
1438
1439private:
1440 using ChunkIterator = llvm::pointee_iterator<
1441 std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>;
1442
1443 /// Initialize the text file from the given file contents.
1444 void initialize(const llvm::lsp::URIForFile &uri, int64_t newVersion,
1445 std::vector<llvm::lsp::Diagnostic> &diagnostics);
1446
1447 /// Find the PDL document that contains the given position, and update the
1448 /// position to be anchored at the start of the found chunk instead of the
1449 /// beginning of the file.
1450 ChunkIterator getChunkItFor(llvm::lsp::Position &pos);
1451 PDLTextFileChunk &getChunkFor(llvm::lsp::Position &pos) {
1452 return *getChunkItFor(pos);
1453 }
1454
1455 /// The full string contents of the file.
1456 std::string contents;
1457
1458 /// The version of this file.
1459 int64_t version = 0;
1460
1461 /// The number of lines in the file.
1462 int64_t totalNumLines = 0;
1463
1464 /// The chunks of this file. The order of these chunks is the order in which
1465 /// they appear in the text file.
1466 std::vector<std::unique_ptr<PDLTextFileChunk>> chunks;
1467
1468 /// The extra set of include directories for this file.
1469 std::vector<std::string> extraIncludeDirs;
1470};
1471} // namespace
1472
1473PDLTextFile::PDLTextFile(const llvm::lsp::URIForFile &uri,
1474 StringRef fileContents, int64_t version,
1475 const std::vector<std::string> &extraDirs,
1476 std::vector<llvm::lsp::Diagnostic> &diagnostics)
1477 : contents(fileContents.str()), extraIncludeDirs(extraDirs) {
1478 initialize(uri, version, diagnostics);
1479}
1480
1481LogicalResult
1482PDLTextFile::update(const llvm::lsp::URIForFile &uri, int64_t newVersion,
1484 std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1485 if (failed(llvm::lsp::TextDocumentContentChangeEvent::applyTo(changes,
1486 contents))) {
1487 llvm::lsp::Logger::error("Failed to update contents of {0}", uri.file());
1488 return failure();
1489 }
1490
1491 // If the file contents were properly changed, reinitialize the text file.
1492 initialize(uri, newVersion, diagnostics);
1493 return success();
1494}
1495
1496void PDLTextFile::getLocationsOf(const llvm::lsp::URIForFile &uri,
1497 llvm::lsp::Position defPos,
1498 std::vector<llvm::lsp::Location> &locations) {
1499 PDLTextFileChunk &chunk = getChunkFor(defPos);
1500 chunk.document.getLocationsOf(uri, defPos, locations);
1501
1502 // Adjust any locations within this file for the offset of this chunk.
1503 if (chunk.lineOffset == 0)
1504 return;
1505 for (llvm::lsp::Location &loc : locations)
1506 if (loc.uri == uri)
1507 chunk.adjustLocForChunkOffset(loc.range);
1508}
1509
1510void PDLTextFile::findReferencesOf(
1511 const llvm::lsp::URIForFile &uri, llvm::lsp::Position pos,
1512 std::vector<llvm::lsp::Location> &references) {
1513 PDLTextFileChunk &chunk = getChunkFor(pos);
1514 chunk.document.findReferencesOf(uri, pos, references);
1515
1516 // Adjust any locations within this file for the offset of this chunk.
1517 if (chunk.lineOffset == 0)
1518 return;
1519 for (llvm::lsp::Location &loc : references)
1520 if (loc.uri == uri)
1521 chunk.adjustLocForChunkOffset(loc.range);
1522}
1523
1524void PDLTextFile::getDocumentLinks(
1525 const llvm::lsp::URIForFile &uri,
1526 std::vector<llvm::lsp::DocumentLink> &links) {
1527 chunks.front()->document.getDocumentLinks(uri, links);
1528 for (const auto &it : llvm::drop_begin(chunks)) {
1529 size_t currentNumLinks = links.size();
1530 it->document.getDocumentLinks(uri, links);
1531
1532 // Adjust any links within this file to account for the offset of this
1533 // chunk.
1534 for (auto &link : llvm::drop_begin(links, currentNumLinks))
1535 it->adjustLocForChunkOffset(link.range);
1536 }
1537}
1538
1539std::optional<llvm::lsp::Hover>
1540PDLTextFile::findHover(const llvm::lsp::URIForFile &uri,
1541 llvm::lsp::Position hoverPos) {
1542 PDLTextFileChunk &chunk = getChunkFor(hoverPos);
1543 std::optional<llvm::lsp::Hover> hoverInfo =
1544 chunk.document.findHover(uri, hoverPos);
1545
1546 // Adjust any locations within this file for the offset of this chunk.
1547 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1548 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1549 return hoverInfo;
1550}
1551
1552void PDLTextFile::findDocumentSymbols(
1553 std::vector<llvm::lsp::DocumentSymbol> &symbols) {
1554 if (chunks.size() == 1)
1555 return chunks.front()->document.findDocumentSymbols(symbols);
1556
1557 // If there are multiple chunks in this file, we create top-level symbols for
1558 // each chunk.
1559 for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1560 PDLTextFileChunk &chunk = *chunks[i];
1561 llvm::lsp::Position startPos(chunk.lineOffset);
1562 llvm::lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1563 : chunks[i + 1]->lineOffset);
1564 llvm::lsp::DocumentSymbol symbol(
1565 "<file-split-" + Twine(i) + ">", llvm::lsp::SymbolKind::Namespace,
1566 /*range=*/llvm::lsp::Range(startPos, endPos),
1567 /*selectionRange=*/llvm::lsp::Range(startPos));
1568 chunk.document.findDocumentSymbols(symbol.children);
1569
1570 // Fixup the locations of document symbols within this chunk.
1571 if (i != 0) {
1572 SmallVector<llvm::lsp::DocumentSymbol *> symbolsToFix;
1573 for (llvm::lsp::DocumentSymbol &childSymbol : symbol.children)
1574 symbolsToFix.push_back(&childSymbol);
1575
1576 while (!symbolsToFix.empty()) {
1577 llvm::lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1578 chunk.adjustLocForChunkOffset(symbol->range);
1579 chunk.adjustLocForChunkOffset(symbol->selectionRange);
1580
1581 for (llvm::lsp::DocumentSymbol &childSymbol : symbol->children)
1582 symbolsToFix.push_back(&childSymbol);
1583 }
1584 }
1585
1586 // Push the symbol for this chunk.
1587 symbols.emplace_back(std::move(symbol));
1588 }
1589}
1590
1591llvm::lsp::CompletionList
1592PDLTextFile::getCodeCompletion(const llvm::lsp::URIForFile &uri,
1593 llvm::lsp::Position completePos) {
1594 PDLTextFileChunk &chunk = getChunkFor(completePos);
1595 llvm::lsp::CompletionList completionList =
1596 chunk.document.getCodeCompletion(uri, completePos);
1597
1598 // Adjust any completion locations.
1599 for (llvm::lsp::CompletionItem &item : completionList.items) {
1600 if (item.textEdit)
1601 chunk.adjustLocForChunkOffset(item.textEdit->range);
1602 for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
1603 chunk.adjustLocForChunkOffset(edit.range);
1604 }
1605 return completionList;
1606}
1607
1608llvm::lsp::SignatureHelp
1609PDLTextFile::getSignatureHelp(const llvm::lsp::URIForFile &uri,
1610 llvm::lsp::Position helpPos) {
1611 return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos);
1612}
1613
1614void PDLTextFile::getInlayHints(const llvm::lsp::URIForFile &uri,
1615 llvm::lsp::Range range,
1616 std::vector<llvm::lsp::InlayHint> &inlayHints) {
1617 auto startIt = getChunkItFor(range.start);
1618 auto endIt = getChunkItFor(range.end);
1619
1620 // Functor used to get the chunks for a given file, and fixup any locations
1621 auto getHintsForChunk = [&](ChunkIterator chunkIt, llvm::lsp::Range range) {
1622 size_t currentNumHints = inlayHints.size();
1623 chunkIt->document.getInlayHints(uri, range, inlayHints);
1624
1625 // If this isn't the first chunk, update any positions to account for line
1626 // number differences.
1627 if (&*chunkIt != &*chunks.front()) {
1628 for (auto &hint : llvm::drop_begin(inlayHints, currentNumHints))
1629 chunkIt->adjustLocForChunkOffset(hint.position);
1630 }
1631 };
1632 // Returns the number of lines held by a given chunk.
1633 auto getNumLines = [](ChunkIterator chunkIt) {
1634 return (chunkIt + 1)->lineOffset - chunkIt->lineOffset;
1635 };
1636
1637 // Check if the range is fully within a single chunk.
1638 if (startIt == endIt)
1639 return getHintsForChunk(startIt, range);
1640
1641 // Otherwise, the range is split between multiple chunks. The first chunk
1642 // has the correct range start, but covers the total document.
1643 getHintsForChunk(startIt,
1644 llvm::lsp::Range(range.start, getNumLines(startIt)));
1645
1646 // Every chunk in between uses the full document.
1647 for (++startIt; startIt != endIt; ++startIt)
1648 getHintsForChunk(startIt, llvm::lsp::Range(0, getNumLines(startIt)));
1649
1650 // The range for the last chunk starts at the beginning of the document, up
1651 // through the end of the input range.
1652 getHintsForChunk(startIt, llvm::lsp::Range(0, range.end));
1653}
1654
1655lsp::PDLLViewOutputResult
1656PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) {
1657 lsp::PDLLViewOutputResult result;
1658 {
1659 llvm::raw_string_ostream outputOS(result.output);
1660 llvm::interleave(
1661 llvm::make_pointee_range(chunks),
1662 [&](PDLTextFileChunk &chunk) {
1663 chunk.document.getPDLLViewOutput(outputOS, kind);
1664 },
1665 [&] { outputOS << "\n"
1666 << kDefaultSplitMarker << "\n\n"; });
1667 }
1668 return result;
1669}
1670
1671void PDLTextFile::initialize(const llvm::lsp::URIForFile &uri,
1672 int64_t newVersion,
1673 std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1674 version = newVersion;
1675 chunks.clear();
1676
1677 // Split the file into separate PDL documents.
1678 SmallVector<StringRef, 8> subContents;
1679 StringRef(contents).split(subContents, kDefaultSplitMarker);
1680 chunks.emplace_back(std::make_unique<PDLTextFileChunk>(
1681 /*lineOffset=*/0, uri, subContents.front(), extraIncludeDirs,
1682 diagnostics));
1683
1684 uint64_t lineOffset = subContents.front().count('\n');
1685 for (StringRef docContents : llvm::drop_begin(subContents)) {
1686 unsigned currentNumDiags = diagnostics.size();
1687 auto chunk = std::make_unique<PDLTextFileChunk>(
1688 lineOffset, uri, docContents, extraIncludeDirs, diagnostics);
1689 lineOffset += docContents.count('\n');
1690
1691 // Adjust locations used in diagnostics to account for the offset from the
1692 // beginning of the file.
1693 for (llvm::lsp::Diagnostic &diag :
1694 llvm::drop_begin(diagnostics, currentNumDiags)) {
1695 chunk->adjustLocForChunkOffset(diag.range);
1696
1697 if (!diag.relatedInformation)
1698 continue;
1699 for (auto &it : *diag.relatedInformation)
1700 if (it.location.uri == uri)
1701 chunk->adjustLocForChunkOffset(it.location.range);
1702 }
1703 chunks.emplace_back(std::move(chunk));
1704 }
1705 totalNumLines = lineOffset;
1706}
1707
1708PDLTextFile::ChunkIterator
1709PDLTextFile::getChunkItFor(llvm::lsp::Position &pos) {
1710 if (chunks.size() == 1)
1711 return chunks.begin();
1712
1713 // Search for the first chunk with a greater line offset, the previous chunk
1714 // is the one that contains `pos`.
1715 auto it = llvm::upper_bound(
1716 chunks, pos, [](const llvm::lsp::Position &pos, const auto &chunk) {
1717 return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1718 });
1719 ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it);
1720 pos.line -= chunkIt->lineOffset;
1721 return chunkIt;
1722}
1723
1724//===----------------------------------------------------------------------===//
1725// PDLLServer::Impl
1726//===----------------------------------------------------------------------===//
1727
1729 explicit Impl(const Options &options)
1730 : options(options), compilationDatabase(options.compilationDatabases) {}
1731
1732 /// PDLL LSP options.
1734
1735 /// The compilation database containing additional information for files
1736 /// passed to the server.
1738
1739 /// The files held by the server, mapped by their URI file name.
1740 llvm::StringMap<std::unique_ptr<PDLTextFile>> files;
1741};
1742
1743//===----------------------------------------------------------------------===//
1744// PDLLServer
1745//===----------------------------------------------------------------------===//
1746
1748 : impl(std::make_unique<Impl>(options)) {}
1750
1752 const URIForFile &uri, StringRef contents, int64_t version,
1753 std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1754 // Build the set of additional include directories.
1755 std::vector<std::string> additionalIncludeDirs = impl->options.extraDirs;
1756 const auto &fileInfo = impl->compilationDatabase.getFileInfo(uri.file());
1757 llvm::append_range(additionalIncludeDirs, fileInfo.includeDirs);
1758
1759 impl->files[uri.file()] = std::make_unique<PDLTextFile>(
1760 uri, contents, version, additionalIncludeDirs, diagnostics);
1761}
1762
1764 const URIForFile &uri, ArrayRef<TextDocumentContentChangeEvent> changes,
1765 int64_t version, std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1766 // Check that we actually have a document for this uri.
1767 auto it = impl->files.find(uri.file());
1768 if (it == impl->files.end())
1769 return;
1770
1771 // Try to update the document. If we fail, erase the file from the server. A
1772 // failed updated generally means we've fallen out of sync somewhere.
1773 if (failed(it->second->update(uri, version, changes, diagnostics)))
1774 impl->files.erase(it);
1775}
1776
1777std::optional<int64_t> lsp::PDLLServer::removeDocument(const URIForFile &uri) {
1778 auto it = impl->files.find(uri.file());
1779 if (it == impl->files.end())
1780 return std::nullopt;
1781
1782 int64_t version = it->second->getVersion();
1783 impl->files.erase(it);
1784 return version;
1785}
1786
1788 const URIForFile &uri, const Position &defPos,
1789 std::vector<llvm::lsp::Location> &locations) {
1790 auto fileIt = impl->files.find(uri.file());
1791 if (fileIt != impl->files.end())
1792 fileIt->second->getLocationsOf(uri, defPos, locations);
1793}
1794
1796 const URIForFile &uri, const Position &pos,
1797 std::vector<llvm::lsp::Location> &references) {
1798 auto fileIt = impl->files.find(uri.file());
1799 if (fileIt != impl->files.end())
1800 fileIt->second->findReferencesOf(uri, pos, references);
1801}
1802
1804 const URIForFile &uri, std::vector<DocumentLink> &documentLinks) {
1805 auto fileIt = impl->files.find(uri.file());
1806 if (fileIt != impl->files.end())
1807 return fileIt->second->getDocumentLinks(uri, documentLinks);
1808}
1809
1810std::optional<llvm::lsp::Hover>
1811lsp::PDLLServer::findHover(const URIForFile &uri, const Position &hoverPos) {
1812 auto fileIt = impl->files.find(uri.file());
1813 if (fileIt != impl->files.end())
1814 return fileIt->second->findHover(uri, hoverPos);
1815 return std::nullopt;
1816}
1817
1819 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1820 auto fileIt = impl->files.find(uri.file());
1821 if (fileIt != impl->files.end())
1822 fileIt->second->findDocumentSymbols(symbols);
1823}
1824
1825lsp::CompletionList
1827 const Position &completePos) {
1828 auto fileIt = impl->files.find(uri.file());
1829 if (fileIt != impl->files.end())
1830 return fileIt->second->getCodeCompletion(uri, completePos);
1831 return CompletionList();
1832}
1833
1834llvm::lsp::SignatureHelp
1836 const Position &helpPos) {
1837 auto fileIt = impl->files.find(uri.file());
1838 if (fileIt != impl->files.end())
1839 return fileIt->second->getSignatureHelp(uri, helpPos);
1840 return SignatureHelp();
1841}
1842
1843void lsp::PDLLServer::getInlayHints(const URIForFile &uri, const Range &range,
1844 std::vector<InlayHint> &inlayHints) {
1845 auto fileIt = impl->files.find(uri.file());
1846 if (fileIt == impl->files.end())
1847 return;
1848 fileIt->second->getInlayHints(uri, range, inlayHints);
1849
1850 // Drop any duplicated hints that may have cropped up.
1851 llvm::sort(inlayHints);
1852 inlayHints.erase(llvm::unique(inlayHints), inlayHints.end());
1853}
1854
1855std::optional<lsp::PDLLViewOutputResult>
1857 PDLLViewOutputKind kind) {
1858 auto fileIt = impl->files.find(uri.file());
1859 if (fileIt != impl->files.end())
1860 return fileIt->second->getPDLLViewOutput(kind);
1861 return std::nullopt;
1862}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
lhs
static std::optional< lsp::Location > getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc)
Returns a language server location from the given MLIR file location.
static std::string diag(const llvm::Value &value)
static std::optional< std::string > getDocumentationFor(llvm::SourceMgr &sourceMgr, const ast::Decl *decl)
Get or extract the documentation for the given decl.
static llvm::lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range, const llvm::lsp::URIForFile &uri)
Returns a language server location from the given source range.
static std::optional< llvm::lsp::Diagnostic > getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, const llvm::lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
static bool shouldAddHintFor(const ast::Expr *expr, StringRef name)
Returns true if the given name should be added as a hint for expr.
static llvm::lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc, const llvm::lsp::URIForFile &mainFileURI)
Returns a language server uri for the given source location.
static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc)
Returns true if the given location is in the main file of the source manager.
static llvm::ManagedStatic< PassManagerOptions > options
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition SCCP.cpp:67
This class contains a collection of compilation information for files provided to the language server...
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
void getDocumentLinks(const URIForFile &uri, std::vector< DocumentLink > &documentLinks)
Return the document links referenced by the given file.
void getInlayHints(const URIForFile &uri, const Range &range, std::vector< InlayHint > &inlayHints)
Get the inlay hints for the range within the given file.
void addDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add the document, with the provided version, at the given URI.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
SignatureHelp getSignatureHelp(const URIForFile &uri, const Position &helpPos)
Get the signature help for the position within the given file.
void updateDocument(const URIForFile &uri, ArrayRef< TextDocumentContentChangeEvent > changes, int64_t version, std::vector< Diagnostic > &diagnostics)
Update the document, with the provided version, at the given URI.
PDLLServer(const Options &options)
std::optional< PDLLViewOutputResult > getPDLLViewOutput(const URIForFile &uri, PDLLViewOutputKind kind)
Get the output of the given PDLL file, or std::nullopt if there is no valid output.
Expr * getCallableExpr() const
Return the callable of this call.
Definition Nodes.h:400
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
Definition Nodes.h:403
This class represents the base Decl node.
Definition Nodes.h:669
std::optional< StringRef > getDocComment() const
Return the documentation comment attached to this decl if it has been set.
Definition Nodes.h:682
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition Nodes.h:672
This class provides a simple implementation of a PDLL diagnostic.
Definition Diagnostic.h:29
This class represents a base AST Expression node.
Definition Nodes.h:348
Type getType() const
Return the type of this expression.
Definition Nodes.h:351
This class represents a top-level AST module.
Definition Nodes.h:1297
SMRange getLoc() const
Return the location of this node.
Definition Nodes.h:131
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Definition Nodes.cpp:404
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
Definition Nodes.h:237
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
Definition Nodes.h:532
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
Definition Nodes.h:540
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
Definition Types.cpp:87
const OpRewriteStmt * getRootRewriteStmt() const
Return the root rewrite statement of this pattern.
Definition Nodes.h:1060
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
Definition Nodes.h:1051
bool hasBoundedRewriteRecursion() const
Return if this pattern has bounded rewrite recursion.
Definition Nodes.h:1054
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition Nodes.h:1255
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
Definition Nodes.h:1264
const Name & getName() const
Return the name of the decl.
Definition Nodes.h:1267
Type getType() const
Return the type of the decl.
Definition Nodes.h:1270
StringRef getCppClass() const
Return the name of the underlying c++ class of this constraint.
Definition Constraint.h:66
StringRef getSummary() const
Return the summary of this constraint.
Definition Constraint.h:44
StringRef getName() const
Return the unique name of this constraint.
Definition Constraint.h:35
This class contains all of the registered ODS operation classes.
Definition Context.h:32
auto getDialects() const
Return a range of all of the registered dialects.
Definition Context.h:57
const Dialect * lookupDialect(StringRef name) const
Lookup a dialect registered with the given name, or null if no dialect with that name was inserted.
Definition Context.cpp:57
const Operation * lookupOperation(StringRef name) const
Lookup an operation registered with the given name, or null if no operation with that name is registe...
Definition Context.cpp:72
const llvm::StringMap< std::unique_ptr< Operation > > & getOperations() const
Return a map of all of the operations registered to this dialect.
Definition Dialect.h:46
StringRef getDescription() const
Returns the description of the operation.
Definition Operation.h:156
StringRef getSummary() const
Returns the summary of the operation.
Definition Operation.h:153
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
Definition Operation.h:165
StringRef getName() const
Returns the name of the operation.
Definition Operation.h:150
SMRange getLoc() const
Return the source location of this operation.
Definition Operation.h:128
ArrayRef< Attribute > getAttributes() const
Returns the attributes of this operation.
Definition Operation.h:162
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
Definition Operation.h:168
StringRef getCppClass() const
Return the name of the underlying c++ class of this constraint.
Definition Constraint.h:87
PDLLViewOutputKind
The type of output to view from PDLL.
Definition Protocol.h:34
void gatherIncludeFiles(llvm::SourceMgr &sourceMgr, SmallVectorImpl< SourceMgrInclude > &includes)
Given a source manager, gather all of the processed include files.
bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
std::optional< std::string > extractSourceDocComment(llvm::SourceMgr &sourceMgr, SMLoc loc)
Extract a documentation comment for the given location within the source manager.
void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, raw_ostream &os)
Definition CPPGen.cpp:246
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
Definition Parser.cpp:3205
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Definition MLIRGen.cpp:624
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
const char *const kDefaultSplitMarker
llvm::StringSet< AllocatorTy > StringSet
Definition LLVM.h:133
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
Impl(const Options &options)
lsp::CompilationDatabase compilationDatabase
The compilation database containing additional information for files passed to the server.
llvm::StringMap< std::unique_ptr< PDLTextFile > > files
The files held by the server, mapped by their URI file name.
const Options & options
PDLL LSP options.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
This class provides a convenient API for interacting with source names.
Definition Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition Nodes.h:41
SMRange getLoc() const
Get the location of this name.
Definition Nodes.h:44