MLIR  22.0.0git
PDLLServer.cpp
Go to the documentation of this file.
1 //===- PDLLServer.cpp - PDLL Language Server ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PDLLServer.h"
10 
11 #include "Protocol.h"
12 #include "mlir/IR/BuiltinOps.h"
27 #include "llvm/ADT/IntervalMap.h"
28 #include "llvm/ADT/StringMap.h"
29 #include "llvm/ADT/StringSet.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/FileSystem.h"
32 #include "llvm/Support/LSP/Logging.h"
33 #include "llvm/Support/Path.h"
34 #include <optional>
35 
36 using namespace mlir;
37 using namespace mlir::pdll;
38 
39 /// Returns a language server uri for the given source location. `mainFileURI`
40 /// corresponds to the uri for the main file of the source manager.
41 static llvm::lsp::URIForFile
42 getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc,
43  const llvm::lsp::URIForFile &mainFileURI) {
44  int bufferId = mgr.FindBufferContainingLoc(loc.Start);
45  if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
46  return mainFileURI;
48  llvm::lsp::URIForFile::fromFile(
49  mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
50  if (fileForLoc)
51  return *fileForLoc;
52  llvm::lsp::Logger::error("Failed to create URI for include file: {0}",
53  llvm::toString(fileForLoc.takeError()));
54  return mainFileURI;
55 }
56 
57 /// Returns true if the given location is in the main file of the source
58 /// manager.
59 static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) {
60  return mgr.FindBufferContainingLoc(loc.Start) == mgr.getMainFileID();
61 }
62 
63 /// Returns a language server location from the given source range.
64 static llvm::lsp::Location
65 getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range,
66  const llvm::lsp::URIForFile &uri) {
67  return llvm::lsp::Location(getURIFromLoc(mgr, range, uri),
68  llvm::lsp::Range(mgr, range));
69 }
70 
71 /// Convert the given MLIR diagnostic to the LSP form.
72 static std::optional<llvm::lsp::Diagnostic>
73 getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag,
74  const llvm::lsp::URIForFile &uri) {
75  llvm::lsp::Diagnostic lspDiag;
76  lspDiag.source = "pdll";
77 
78  // FIXME: Right now all of the diagnostics are treated as parser issues, but
79  // some are parser and some are verifier.
80  lspDiag.category = "Parse Error";
81 
82  // Try to grab a file location for this diagnostic.
83  llvm::lsp::Location loc =
84  getLocationFromLoc(sourceMgr, diag.getLocation(), uri);
85  lspDiag.range = loc.range;
86 
87  // Skip diagnostics that weren't emitted within the main file.
88  if (loc.uri != uri)
89  return std::nullopt;
90 
91  // Convert the severity for the diagnostic.
92  switch (diag.getSeverity()) {
93  case ast::Diagnostic::Severity::DK_Note:
94  llvm_unreachable("expected notes to be handled separately");
95  case ast::Diagnostic::Severity::DK_Warning:
96  lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
97  break;
98  case ast::Diagnostic::Severity::DK_Error:
99  lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
100  break;
101  case ast::Diagnostic::Severity::DK_Remark:
102  lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
103  break;
104  }
105  lspDiag.message = diag.getMessage().str();
106 
107  // Attach any notes to the main diagnostic as related information.
108  std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
109  for (const ast::Diagnostic &note : diag.getNotes()) {
110  relatedDiags.emplace_back(
111  getLocationFromLoc(sourceMgr, note.getLocation(), uri),
112  note.getMessage().str());
113  }
114  if (!relatedDiags.empty())
115  lspDiag.relatedInformation = std::move(relatedDiags);
116 
117  return lspDiag;
118 }
119 
120 /// Get or extract the documentation for the given decl.
121 static std::optional<std::string>
122 getDocumentationFor(llvm::SourceMgr &sourceMgr, const ast::Decl *decl) {
123  // If the decl already had documentation set, use it.
124  if (std::optional<StringRef> doc = decl->getDocComment())
125  return doc->str();
126 
127  // If the decl doesn't yet have documentation, try to extract it from the
128  // source file.
129  return lsp::extractSourceDocComment(sourceMgr, decl->getLoc().Start);
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // PDLIndex
134 //===----------------------------------------------------------------------===//
135 
136 namespace {
137 struct PDLIndexSymbol {
138  explicit PDLIndexSymbol(const ast::Decl *definition)
139  : definition(definition) {}
140  explicit PDLIndexSymbol(const ods::Operation *definition)
141  : definition(definition) {}
142 
143  /// Return the location of the definition of this symbol.
144  SMRange getDefLoc() const {
145  if (const ast::Decl *decl =
146  llvm::dyn_cast_if_present<const ast::Decl *>(definition)) {
147  const ast::Name *declName = decl->getName();
148  return declName ? declName->getLoc() : decl->getLoc();
149  }
150  return cast<const ods::Operation *>(definition)->getLoc();
151  }
152 
153  /// The main definition of the symbol.
155  /// The set of references to the symbol.
156  std::vector<SMRange> references;
157 };
158 
159 /// This class provides an index for definitions/uses within a PDL document.
160 /// It provides efficient lookup of a definition given an input source range.
161 class PDLIndex {
162 public:
163  PDLIndex() : intervalMap(allocator) {}
164 
165  /// Initialize the index with the given ast::Module.
166  void initialize(const ast::Module &module, const ods::Context &odsContext);
167 
168  /// Lookup a symbol for the given location. Returns nullptr if no symbol could
169  /// be found. If provided, `overlappedRange` is set to the range that the
170  /// provided `loc` overlapped with.
171  const PDLIndexSymbol *lookup(SMLoc loc,
172  SMRange *overlappedRange = nullptr) const;
173 
174 private:
175  /// The type of interval map used to store source references. SMRange is
176  /// half-open, so we also need to use a half-open interval map.
177  using MapT =
178  llvm::IntervalMap<const char *, const PDLIndexSymbol *,
179  llvm::IntervalMapImpl::NodeSizer<
180  const char *, const PDLIndexSymbol *>::LeafSize,
181  llvm::IntervalMapHalfOpenInfo<const char *>>;
182 
183  /// An allocator for the interval map.
184  MapT::Allocator allocator;
185 
186  /// An interval map containing a corresponding definition mapped to a source
187  /// interval.
188  MapT intervalMap;
189 
190  /// A mapping between definitions and their corresponding symbol.
192 };
193 } // namespace
194 
195 void PDLIndex::initialize(const ast::Module &module,
196  const ods::Context &odsContext) {
197  auto getOrInsertDef = [&](const auto *def) -> PDLIndexSymbol * {
198  auto it = defToSymbol.try_emplace(def, nullptr);
199  if (it.second)
200  it.first->second = std::make_unique<PDLIndexSymbol>(def);
201  return &*it.first->second;
202  };
203  auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc,
204  bool isDef = false) {
205  const char *startLoc = refLoc.Start.getPointer();
206  const char *endLoc = refLoc.End.getPointer();
207  if (!intervalMap.overlaps(startLoc, endLoc)) {
208  intervalMap.insert(startLoc, endLoc, sym);
209  if (!isDef)
210  sym->references.push_back(refLoc);
211  }
212  };
213  auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) {
214  const ods::Operation *odsOp = odsContext.lookupOperation(opName);
215  if (!odsOp)
216  return;
217 
218  PDLIndexSymbol *symbol = getOrInsertDef(odsOp);
219  insertDeclRef(symbol, odsOp->getLoc(), /*isDef=*/true);
220  insertDeclRef(symbol, refLoc);
221  };
222 
223  module.walk([&](const ast::Node *node) {
224  // Handle references to PDL decls.
225  if (const auto *decl = dyn_cast<ast::OpNameDecl>(node)) {
226  if (std::optional<StringRef> name = decl->getName())
227  insertODSOpRef(*name, decl->getLoc());
228  } else if (const ast::Decl *decl = dyn_cast<ast::Decl>(node)) {
229  const ast::Name *name = decl->getName();
230  if (!name)
231  return;
232  PDLIndexSymbol *declSym = getOrInsertDef(decl);
233  insertDeclRef(declSym, name->getLoc(), /*isDef=*/true);
234 
235  if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) {
236  // Record references to any constraints.
237  for (const auto &it : varDecl->getConstraints())
238  insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc);
239  }
240  } else if (const auto *expr = dyn_cast<ast::DeclRefExpr>(node)) {
241  insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc());
242  }
243  });
244 }
245 
246 const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc,
247  SMRange *overlappedRange) const {
248  auto it = intervalMap.find(loc.getPointer());
249  if (!it.valid() || loc.getPointer() < it.start())
250  return nullptr;
251 
252  if (overlappedRange) {
253  *overlappedRange = SMRange(SMLoc::getFromPointer(it.start()),
254  SMLoc::getFromPointer(it.stop()));
255  }
256  return it.value();
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // PDLDocument
261 //===----------------------------------------------------------------------===//
262 
263 namespace {
264 /// This class represents all of the information pertaining to a specific PDL
265 /// document.
266 struct PDLDocument {
267  PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents,
268  const std::vector<std::string> &extraDirs,
269  std::vector<llvm::lsp::Diagnostic> &diagnostics);
270  PDLDocument(const PDLDocument &) = delete;
271  PDLDocument &operator=(const PDLDocument &) = delete;
272 
273  //===--------------------------------------------------------------------===//
274  // Definitions and References
275  //===--------------------------------------------------------------------===//
276 
277  void getLocationsOf(const llvm::lsp::URIForFile &uri,
278  const llvm::lsp::Position &defPos,
279  std::vector<llvm::lsp::Location> &locations);
280  void findReferencesOf(const llvm::lsp::URIForFile &uri,
281  const llvm::lsp::Position &pos,
282  std::vector<llvm::lsp::Location> &references);
283 
284  //===--------------------------------------------------------------------===//
285  // Document Links
286  //===--------------------------------------------------------------------===//
287 
288  void getDocumentLinks(const llvm::lsp::URIForFile &uri,
289  std::vector<llvm::lsp::DocumentLink> &links);
290 
291  //===--------------------------------------------------------------------===//
292  // Hover
293  //===--------------------------------------------------------------------===//
294 
295  std::optional<llvm::lsp::Hover>
296  findHover(const llvm::lsp::URIForFile &uri,
297  const llvm::lsp::Position &hoverPos);
298  std::optional<llvm::lsp::Hover> findHover(const ast::Decl *decl,
299  const SMRange &hoverRange);
300  llvm::lsp::Hover buildHoverForOpName(const ods::Operation *op,
301  const SMRange &hoverRange);
302  llvm::lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl,
303  const SMRange &hoverRange);
304  llvm::lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl,
305  const SMRange &hoverRange);
306  llvm::lsp::Hover
307  buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
308  const SMRange &hoverRange);
309  template <typename T>
310  llvm::lsp::Hover
311  buildHoverForUserConstraintOrRewrite(StringRef typeName, const T *decl,
312  const SMRange &hoverRange);
313 
314  //===--------------------------------------------------------------------===//
315  // Document Symbols
316  //===--------------------------------------------------------------------===//
317 
318  void findDocumentSymbols(std::vector<llvm::lsp::DocumentSymbol> &symbols);
319 
320  //===--------------------------------------------------------------------===//
321  // Code Completion
322  //===--------------------------------------------------------------------===//
323 
324  llvm::lsp::CompletionList
325  getCodeCompletion(const llvm::lsp::URIForFile &uri,
326  const llvm::lsp::Position &completePos);
327 
328  //===--------------------------------------------------------------------===//
329  // Signature Help
330  //===--------------------------------------------------------------------===//
331 
332  llvm::lsp::SignatureHelp getSignatureHelp(const llvm::lsp::URIForFile &uri,
333  const llvm::lsp::Position &helpPos);
334 
335  //===--------------------------------------------------------------------===//
336  // Inlay Hints
337  //===--------------------------------------------------------------------===//
338 
339  void getInlayHints(const llvm::lsp::URIForFile &uri,
340  const llvm::lsp::Range &range,
341  std::vector<llvm::lsp::InlayHint> &inlayHints);
342  void getInlayHintsFor(const ast::VariableDecl *decl,
343  const llvm::lsp::URIForFile &uri,
344  std::vector<llvm::lsp::InlayHint> &inlayHints);
345  void getInlayHintsFor(const ast::CallExpr *expr,
346  const llvm::lsp::URIForFile &uri,
347  std::vector<llvm::lsp::InlayHint> &inlayHints);
348  void getInlayHintsFor(const ast::OperationExpr *expr,
349  const llvm::lsp::URIForFile &uri,
350  std::vector<llvm::lsp::InlayHint> &inlayHints);
351 
352  /// Add a parameter hint for the given expression using `label`.
353  void addParameterHintFor(std::vector<llvm::lsp::InlayHint> &inlayHints,
354  const ast::Expr *expr, StringRef label);
355 
356  //===--------------------------------------------------------------------===//
357  // PDLL ViewOutput
358  //===--------------------------------------------------------------------===//
359 
360  void getPDLLViewOutput(raw_ostream &os, lsp::PDLLViewOutputKind kind);
361 
362  //===--------------------------------------------------------------------===//
363  // Fields
364  //===--------------------------------------------------------------------===//
365 
366  /// The include directories for this file.
367  std::vector<std::string> includeDirs;
368 
369  /// The source manager containing the contents of the input file.
370  llvm::SourceMgr sourceMgr;
371 
372  /// The ODS and AST contexts.
373  ods::Context odsContext;
374  ast::Context astContext;
375 
376  /// The parsed AST module, or failure if the file wasn't valid.
377  FailureOr<ast::Module *> astModule;
378 
379  /// The index of the parsed module.
380  PDLIndex index;
381 
382  /// The set of includes of the parsed module.
383  SmallVector<lsp::SourceMgrInclude> parsedIncludes;
384 };
385 } // namespace
386 
387 PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents,
388  const std::vector<std::string> &extraDirs,
389  std::vector<llvm::lsp::Diagnostic> &diagnostics)
390  : astContext(odsContext) {
391  auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
392  if (!memBuffer) {
393  llvm::lsp::Logger::error("Failed to create memory buffer for file",
394  uri.file());
395  return;
396  }
397 
398  // Build the set of include directories for this file.
399  llvm::SmallString<32> uriDirectory(uri.file());
400  llvm::sys::path::remove_filename(uriDirectory);
401  includeDirs.push_back(uriDirectory.str().str());
402  llvm::append_range(includeDirs, extraDirs);
403 
404  sourceMgr.setIncludeDirs(includeDirs);
405  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
406 
407  astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
408  if (auto lspDiag = getLspDiagnoticFromDiag(sourceMgr, diag, uri))
409  diagnostics.push_back(std::move(*lspDiag));
410  });
411  astModule = parsePDLLAST(astContext, sourceMgr, /*enableDocumentation=*/true);
412 
413  // Initialize the set of parsed includes.
414  lsp::gatherIncludeFiles(sourceMgr, parsedIncludes);
415 
416  // If we failed to parse the module, there is nothing left to initialize.
417  if (failed(astModule))
418  return;
419 
420  // Prepare the AST index with the parsed module.
421  index.initialize(**astModule, odsContext);
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // PDLDocument: Definitions and References
426 //===----------------------------------------------------------------------===//
427 
428 void PDLDocument::getLocationsOf(const llvm::lsp::URIForFile &uri,
429  const llvm::lsp::Position &defPos,
430  std::vector<llvm::lsp::Location> &locations) {
431  SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
432  const PDLIndexSymbol *symbol = index.lookup(posLoc);
433  if (!symbol)
434  return;
435 
436  locations.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri));
437 }
438 
439 void PDLDocument::findReferencesOf(
440  const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &pos,
441  std::vector<llvm::lsp::Location> &references) {
442  SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
443  const PDLIndexSymbol *symbol = index.lookup(posLoc);
444  if (!symbol)
445  return;
446 
447  references.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri));
448  for (SMRange refLoc : symbol->references)
449  references.push_back(getLocationFromLoc(sourceMgr, refLoc, uri));
450 }
451 
452 //===--------------------------------------------------------------------===//
453 // PDLDocument: Document Links
454 //===--------------------------------------------------------------------===//
455 
456 void PDLDocument::getDocumentLinks(
457  const llvm::lsp::URIForFile &uri,
458  std::vector<llvm::lsp::DocumentLink> &links) {
459  for (const lsp::SourceMgrInclude &include : parsedIncludes)
460  links.emplace_back(include.range, include.uri);
461 }
462 
463 //===----------------------------------------------------------------------===//
464 // PDLDocument: Hover
465 //===----------------------------------------------------------------------===//
466 
467 std::optional<llvm::lsp::Hover>
468 PDLDocument::findHover(const llvm::lsp::URIForFile &uri,
469  const llvm::lsp::Position &hoverPos) {
470  SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
471 
472  // Check for a reference to an include.
473  for (const lsp::SourceMgrInclude &include : parsedIncludes)
474  if (include.range.contains(hoverPos))
475  return include.buildHover();
476 
477  // Find the symbol at the given location.
478  SMRange hoverRange;
479  const PDLIndexSymbol *symbol = index.lookup(posLoc, &hoverRange);
480  if (!symbol)
481  return std::nullopt;
482 
483  // Add hover for operation names.
484  if (const auto *op =
485  llvm::dyn_cast_if_present<const ods::Operation *>(symbol->definition))
486  return buildHoverForOpName(op, hoverRange);
487  const auto *decl = cast<const ast::Decl *>(symbol->definition);
488  return findHover(decl, hoverRange);
489 }
490 
491 std::optional<llvm::lsp::Hover>
492 PDLDocument::findHover(const ast::Decl *decl, const SMRange &hoverRange) {
493  // Add hover for variables.
494  if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
495  return buildHoverForVariable(varDecl, hoverRange);
496 
497  // Add hover for patterns.
498  if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl))
499  return buildHoverForPattern(patternDecl, hoverRange);
500 
501  // Add hover for core constraints.
502  if (const auto *cst = dyn_cast<ast::CoreConstraintDecl>(decl))
503  return buildHoverForCoreConstraint(cst, hoverRange);
504 
505  // Add hover for user constraints.
506  if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl))
507  return buildHoverForUserConstraintOrRewrite("Constraint", cst, hoverRange);
508 
509  // Add hover for user rewrites.
510  if (const auto *rewrite = dyn_cast<ast::UserRewriteDecl>(decl))
511  return buildHoverForUserConstraintOrRewrite("Rewrite", rewrite, hoverRange);
512 
513  return std::nullopt;
514 }
515 
516 llvm::lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op,
517  const SMRange &hoverRange) {
518  llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
519  {
520  llvm::raw_string_ostream hoverOS(hover.contents.value);
521  hoverOS << "**OpName**: `" << op->getName() << "`\n***\n"
522  << op->getSummary() << "\n***\n"
523  << op->getDescription();
524  }
525  return hover;
526 }
527 
528 llvm::lsp::Hover
529 PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl,
530  const SMRange &hoverRange) {
531  llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
532  {
533  llvm::raw_string_ostream hoverOS(hover.contents.value);
534  hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n"
535  << "Type: `" << varDecl->getType() << "`\n";
536  }
537  return hover;
538 }
539 
540 llvm::lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl,
541  const SMRange &hoverRange) {
542  llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
543  {
544  llvm::raw_string_ostream hoverOS(hover.contents.value);
545  hoverOS << "**Pattern**";
546  if (const ast::Name *name = decl->getName())
547  hoverOS << ": `" << name->getName() << "`";
548  hoverOS << "\n***\n";
549  if (std::optional<uint16_t> benefit = decl->getBenefit())
550  hoverOS << "Benefit: " << *benefit << "\n";
551  if (decl->hasBoundedRewriteRecursion())
552  hoverOS << "HasBoundedRewriteRecursion\n";
553  hoverOS << "RootOp: `"
554  << decl->getRootRewriteStmt()->getRootOpExpr()->getType() << "`\n";
555 
556  // Format the documentation for the decl.
557  if (std::optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
558  hoverOS << "\n" << *doc << "\n";
559  }
560  return hover;
561 }
562 
563 llvm::lsp::Hover
564 PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
565  const SMRange &hoverRange) {
566  llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
567  {
568  llvm::raw_string_ostream hoverOS(hover.contents.value);
569  hoverOS << "**Constraint**: `";
571  .Case([&](const ast::AttrConstraintDecl *) { hoverOS << "Attr"; })
572  .Case([&](const ast::OpConstraintDecl *opCst) {
573  hoverOS << "Op";
574  if (std::optional<StringRef> name = opCst->getName())
575  hoverOS << "<" << *name << ">";
576  })
577  .Case([&](const ast::TypeConstraintDecl *) { hoverOS << "Type"; })
578  .Case([&](const ast::TypeRangeConstraintDecl *) {
579  hoverOS << "TypeRange";
580  })
581  .Case([&](const ast::ValueConstraintDecl *) { hoverOS << "Value"; })
582  .Case([&](const ast::ValueRangeConstraintDecl *) {
583  hoverOS << "ValueRange";
584  });
585  hoverOS << "`\n";
586  }
587  return hover;
588 }
589 
590 template <typename T>
591 llvm::lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
592  StringRef typeName, const T *decl, const SMRange &hoverRange) {
593  llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
594  {
595  llvm::raw_string_ostream hoverOS(hover.contents.value);
596  hoverOS << "**" << typeName << "**: `" << decl->getName().getName()
597  << "`\n***\n";
598  ArrayRef<ast::VariableDecl *> inputs = decl->getInputs();
599  if (!inputs.empty()) {
600  hoverOS << "Parameters:\n";
601  for (const ast::VariableDecl *input : inputs)
602  hoverOS << "* " << input->getName().getName() << ": `"
603  << input->getType() << "`\n";
604  hoverOS << "***\n";
605  }
606  ast::Type resultType = decl->getResultType();
607  if (auto resultTupleTy = dyn_cast<ast::TupleType>(resultType)) {
608  if (!resultTupleTy.empty()) {
609  hoverOS << "Results:\n";
610  for (auto it : llvm::zip(resultTupleTy.getElementNames(),
611  resultTupleTy.getElementTypes())) {
612  StringRef name = std::get<0>(it);
613  hoverOS << "* " << (name.empty() ? "" : (name + ": ")) << "`"
614  << std::get<1>(it) << "`\n";
615  }
616  hoverOS << "***\n";
617  }
618  } else {
619  hoverOS << "Results:\n* `" << resultType << "`\n";
620  hoverOS << "***\n";
621  }
622 
623  // Format the documentation for the decl.
624  if (std::optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
625  hoverOS << "\n" << *doc << "\n";
626  }
627  return hover;
628 }
629 
630 //===----------------------------------------------------------------------===//
631 // PDLDocument: Document Symbols
632 //===----------------------------------------------------------------------===//
633 
634 void PDLDocument::findDocumentSymbols(
635  std::vector<llvm::lsp::DocumentSymbol> &symbols) {
636  if (failed(astModule))
637  return;
638 
639  for (const ast::Decl *decl : (*astModule)->getChildren()) {
640  if (!isMainFileLoc(sourceMgr, decl->getLoc()))
641  continue;
642 
643  if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl)) {
644  const ast::Name *name = patternDecl->getName();
645 
646  SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc();
647  SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
648 
649  symbols.emplace_back(name ? name->getName() : "<pattern>",
650  llvm::lsp::SymbolKind::Class,
651  llvm::lsp::Range(sourceMgr, bodyLoc),
652  llvm::lsp::Range(sourceMgr, nameLoc));
653  } else if (const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(decl)) {
654  // TODO: Add source information for the code block body.
655  SMRange nameLoc = cDecl->getName().getLoc();
656  SMRange bodyLoc = nameLoc;
657 
658  symbols.emplace_back(cDecl->getName().getName(),
659  llvm::lsp::SymbolKind::Function,
660  llvm::lsp::Range(sourceMgr, bodyLoc),
661  llvm::lsp::Range(sourceMgr, nameLoc));
662  } else if (const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(decl)) {
663  // TODO: Add source information for the code block body.
664  SMRange nameLoc = cDecl->getName().getLoc();
665  SMRange bodyLoc = nameLoc;
666 
667  symbols.emplace_back(cDecl->getName().getName(),
668  llvm::lsp::SymbolKind::Function,
669  llvm::lsp::Range(sourceMgr, bodyLoc),
670  llvm::lsp::Range(sourceMgr, nameLoc));
671  }
672  }
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // PDLDocument: Code Completion
677 //===----------------------------------------------------------------------===//
678 
679 namespace {
680 class LSPCodeCompleteContext : public CodeCompleteContext {
681 public:
  /// Create a completion context for the location `completeLoc`. Completion
  /// items are appended to `completionList`. The remaining arguments are
  /// stored on the context for use by the completion callbacks.
  LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
                         llvm::lsp::CompletionList &completionList,
                         ods::Context &odsContext,
                         ArrayRef<std::string> includeDirs)
      : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
        completionList(completionList), odsContext(odsContext),
        includeDirs(includeDirs) {}
689 
690  void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final {
691  ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes();
692  ArrayRef<StringRef> elementNames = tupleType.getElementNames();
693  for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
694  // Push back a completion item that uses the result index.
695  llvm::lsp::CompletionItem item;
696  item.label = llvm::formatv("{0} (field #{0})", i).str();
697  item.insertText = Twine(i).str();
698  item.filterText = item.sortText = item.insertText;
699  item.kind = llvm::lsp::CompletionItemKind::Field;
700  item.detail = llvm::formatv("{0}: {1}", i, elementTypes[i]);
701  item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
702  completionList.items.emplace_back(item);
703 
704  // If the element has a name, push back a completion item with that name.
705  if (!elementNames[i].empty()) {
706  item.label =
707  llvm::formatv("{1} (field #{0})", i, elementNames[i]).str();
708  item.filterText = item.label;
709  item.insertText = elementNames[i].str();
710  completionList.items.emplace_back(item);
711  }
712  }
713  }
714 
716  const ods::Operation *odsOp = opType.getODSOperation();
717  if (!odsOp)
718  return;
719 
720  ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
721  for (const auto &it : llvm::enumerate(results)) {
722  const ods::OperandOrResult &result = it.value();
723  const ods::TypeConstraint &constraint = result.getConstraint();
724 
725  // Push back a completion item that uses the result index.
726  llvm::lsp::CompletionItem item;
727  item.label = llvm::formatv("{0} (field #{0})", it.index()).str();
728  item.insertText = Twine(it.index()).str();
729  item.filterText = item.sortText = item.insertText;
730  item.kind = llvm::lsp::CompletionItemKind::Field;
731  switch (result.getVariableLengthKind()) {
733  item.detail = llvm::formatv("{0}: Value", it.index()).str();
734  break;
736  item.detail = llvm::formatv("{0}: Value?", it.index()).str();
737  break;
739  item.detail = llvm::formatv("{0}: ValueRange", it.index()).str();
740  break;
741  }
742  item.documentation = llvm::lsp::MarkupContent{
743  llvm::lsp::MarkupKind::Markdown,
744  llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
745  constraint.getCppClass())
746  .str()};
747  item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
748  completionList.items.emplace_back(item);
749 
750  // If the result has a name, push back a completion item with the result
751  // name.
752  if (!result.getName().empty()) {
753  item.label =
754  llvm::formatv("{1} (field #{0})", it.index(), result.getName())
755  .str();
756  item.filterText = item.label;
757  item.insertText = result.getName().str();
758  completionList.items.emplace_back(item);
759  }
760  }
761  }
762 
763  void codeCompleteOperationAttributeName(StringRef opName) final {
764  const ods::Operation *odsOp = odsContext.lookupOperation(opName);
765  if (!odsOp)
766  return;
767 
768  for (const ods::Attribute &attr : odsOp->getAttributes()) {
769  const ods::AttributeConstraint &constraint = attr.getConstraint();
770 
771  llvm::lsp::CompletionItem item;
772  item.label = attr.getName().str();
773  item.kind = llvm::lsp::CompletionItemKind::Field;
774  item.detail = attr.isOptional() ? "optional" : "";
775  item.documentation = llvm::lsp::MarkupContent{
776  llvm::lsp::MarkupKind::Markdown,
777  llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
778  constraint.getCppClass())
779  .str()};
780  item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
781  completionList.items.emplace_back(item);
782  }
783  }
784 
785  void codeCompleteConstraintName(ast::Type currentType,
786  bool allowInlineTypeConstraints,
787  const ast::DeclScope *scope) final {
788  auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
789  StringRef snippetText = "") {
790  llvm::lsp::CompletionItem item;
791  item.label = constraint.str();
792  item.kind = llvm::lsp::CompletionItemKind::Class;
793  item.detail = (constraint + " constraint").str();
794  item.documentation = llvm::lsp::MarkupContent{
795  llvm::lsp::MarkupKind::Markdown,
796  ("A single entity core constraint of type `" + mlirType + "`").str()};
797  item.sortText = "0";
798  item.insertText = snippetText.str();
799  item.insertTextFormat = snippetText.empty()
800  ? llvm::lsp::InsertTextFormat::PlainText
801  : llvm::lsp::InsertTextFormat::Snippet;
802  completionList.items.emplace_back(item);
803  };
804 
805  // Insert completions for the core constraints. Some core constraints have
806  // additional characteristics, so we may add then even if a type has been
807  // inferred.
808  if (!currentType) {
809  addCoreConstraint("Attr", "mlir::Attribute");
810  addCoreConstraint("Op", "mlir::Operation *");
811  addCoreConstraint("Value", "mlir::Value");
812  addCoreConstraint("ValueRange", "mlir::ValueRange");
813  addCoreConstraint("Type", "mlir::Type");
814  addCoreConstraint("TypeRange", "mlir::TypeRange");
815  }
816  if (allowInlineTypeConstraints) {
817  /// Attr<Type>.
818  if (!currentType || isa<ast::AttributeType>(currentType))
819  addCoreConstraint("Attr<type>", "mlir::Attribute", "Attr<$1>");
820  /// Value<Type>.
821  if (!currentType || isa<ast::ValueType>(currentType))
822  addCoreConstraint("Value<type>", "mlir::Value", "Value<$1>");
823  /// ValueRange<TypeRange>.
824  if (!currentType || isa<ast::ValueRangeType>(currentType))
825  addCoreConstraint("ValueRange<type>", "mlir::ValueRange",
826  "ValueRange<$1>");
827  }
828 
829  // If a scope was provided, check it for potential constraints.
830  while (scope) {
831  for (const ast::Decl *decl : scope->getDecls()) {
832  if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) {
833  llvm::lsp::CompletionItem item;
834  item.label = cst->getName().getName().str();
835  item.kind = llvm::lsp::CompletionItemKind::Interface;
836  item.sortText = "2_" + item.label;
837 
838  // Skip constraints that are not single-arg. We currently only
839  // complete variable constraints.
840  if (cst->getInputs().size() != 1)
841  continue;
842 
843  // Ensure the input type matched the given type.
844  ast::Type constraintType = cst->getInputs()[0]->getType();
845  if (currentType && !currentType.refineWith(constraintType))
846  continue;
847 
848  // Format the constraint signature.
849  {
850  llvm::raw_string_ostream strOS(item.detail);
851  strOS << "(";
852  llvm::interleaveComma(
853  cst->getInputs(), strOS, [&](const ast::VariableDecl *var) {
854  strOS << var->getName().getName() << ": " << var->getType();
855  });
856  strOS << ") -> " << cst->getResultType();
857  }
858 
859  // Format the documentation for the constraint.
860  if (std::optional<std::string> doc =
861  getDocumentationFor(sourceMgr, cst)) {
862  item.documentation = llvm::lsp::MarkupContent{
863  llvm::lsp::MarkupKind::Markdown, std::move(*doc)};
864  }
865 
866  completionList.items.emplace_back(item);
867  }
868  }
869 
870  scope = scope->getParentScope();
871  }
872  }
873 
874  void codeCompleteDialectName() final {
875  // Code complete known dialects.
876  for (const ods::Dialect &dialect : odsContext.getDialects()) {
877  llvm::lsp::CompletionItem item;
878  item.label = dialect.getName().str();
879  item.kind = llvm::lsp::CompletionItemKind::Class;
880  item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
881  completionList.items.emplace_back(item);
882  }
883  }
884 
885  void codeCompleteOperationName(StringRef dialectName) final {
886  const ods::Dialect *dialect = odsContext.lookupDialect(dialectName);
887  if (!dialect)
888  return;
889 
890  for (const auto &it : dialect->getOperations()) {
891  const ods::Operation &op = *it.second;
892 
893  llvm::lsp::CompletionItem item;
894  item.label = op.getName().drop_front(dialectName.size() + 1).str();
895  item.kind = llvm::lsp::CompletionItemKind::Field;
896  item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
897  completionList.items.emplace_back(item);
898  }
899  }
900 
901  void codeCompletePatternMetadata() final {
902  auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
903  StringRef snippetText = "") {
904  llvm::lsp::CompletionItem item;
905  item.label = constraint.str();
906  item.kind = llvm::lsp::CompletionItemKind::Class;
907  item.detail = "pattern metadata";
908  item.documentation =
909  llvm::lsp::MarkupContent{llvm::lsp::MarkupKind::Markdown, desc.str()};
910  item.insertText = snippetText.str();
911  item.insertTextFormat = snippetText.empty()
912  ? llvm::lsp::InsertTextFormat::PlainText
913  : llvm::lsp::InsertTextFormat::Snippet;
914  completionList.items.emplace_back(item);
915  };
916 
917  addSimpleConstraint("benefit", "The `benefit` of matching the pattern.",
918  "benefit($1)");
919  addSimpleConstraint("recursion",
920  "The pattern properly handles recursive application.");
921  }
922 
923  void codeCompleteIncludeFilename(StringRef curPath) final {
924  // Normalize the path to allow for interacting with the file system
925  // utilities.
926  SmallString<128> nativeRelDir(llvm::sys::path::convert_to_slash(curPath));
927  llvm::sys::path::native(nativeRelDir);
928 
929  // Set of already included completion paths.
930  StringSet<> seenResults;
931 
932  // Functor used to add a single include completion item.
933  auto addIncludeCompletion = [&](StringRef path, bool isDirectory) {
934  llvm::lsp::CompletionItem item;
935  item.label = path.str();
936  item.kind = isDirectory ? llvm::lsp::CompletionItemKind::Folder
937  : llvm::lsp::CompletionItemKind::File;
938  if (seenResults.insert(item.label).second)
939  completionList.items.emplace_back(item);
940  };
941 
942  // Process the include directories for this file, adding any potential
943  // nested include files or directories.
944  for (StringRef includeDir : includeDirs) {
945  llvm::SmallString<128> dir = includeDir;
946  if (!nativeRelDir.empty())
947  llvm::sys::path::append(dir, nativeRelDir);
948 
949  std::error_code errorCode;
950  for (auto it = llvm::sys::fs::directory_iterator(dir, errorCode),
951  e = llvm::sys::fs::directory_iterator();
952  !errorCode && it != e; it.increment(errorCode)) {
953  StringRef filename = llvm::sys::path::filename(it->path());
954 
955  // To know whether a symlink should be treated as file or a directory,
956  // we have to stat it. This should be cheap enough as there shouldn't be
957  // many symlinks.
958  llvm::sys::fs::file_type fileType = it->type();
959  if (fileType == llvm::sys::fs::file_type::symlink_file) {
960  if (auto fileStatus = it->status())
961  fileType = fileStatus->type();
962  }
963 
964  switch (fileType) {
965  case llvm::sys::fs::file_type::directory_file:
966  addIncludeCompletion(filename, /*isDirectory=*/true);
967  break;
968  case llvm::sys::fs::file_type::regular_file: {
969  // Only consider concrete files that can actually be included by PDLL.
970  if (filename.ends_with(".pdll") || filename.ends_with(".td"))
971  addIncludeCompletion(filename, /*isDirectory=*/false);
972  break;
973  }
974  default:
975  break;
976  }
977  }
978  }
979 
980  // Sort the completion results to make sure the output is deterministic in
981  // the face of different iteration schemes for different platforms.
982  llvm::sort(completionList.items, [](const llvm::lsp::CompletionItem &lhs,
983  const llvm::lsp::CompletionItem &rhs) {
984  return lhs.label < rhs.label;
985  });
986  }
987 
988 private:
989  llvm::SourceMgr &sourceMgr;
990  llvm::lsp::CompletionList &completionList;
991  ods::Context &odsContext;
992  ArrayRef<std::string> includeDirs;
993 };
994 } // namespace
995 
996 llvm::lsp::CompletionList
997 PDLDocument::getCodeCompletion(const llvm::lsp::URIForFile &uri,
998  const llvm::lsp::Position &completePos) {
999  SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
1000  if (!posLoc.isValid())
1001  return llvm::lsp::CompletionList();
1002 
1003  // To perform code completion, we run another parse of the module with the
1004  // code completion context provided.
1005  ods::Context tmpODSContext;
1006  llvm::lsp::CompletionList completionList;
1007  LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
1008  tmpODSContext,
1009  sourceMgr.getIncludeDirs());
1010 
1011  ast::Context tmpContext(tmpODSContext);
1012  (void)parsePDLLAST(tmpContext, sourceMgr, /*enableDocumentation=*/true,
1013  &lspCompleteContext);
1014 
1015  return completionList;
1016 }
1017 
1018 //===----------------------------------------------------------------------===//
1019 // PDLDocument: Signature Help
1020 //===----------------------------------------------------------------------===//
1021 
1022 namespace {
1023 class LSPSignatureHelpContext : public CodeCompleteContext {
1024 public:
1025  LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
1026  llvm::lsp::SignatureHelp &signatureHelp,
1027  ods::Context &odsContext)
1028  : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
1029  signatureHelp(signatureHelp), odsContext(odsContext) {}
1030 
1031  void codeCompleteCallSignature(const ast::CallableDecl *callable,
1032  unsigned currentNumArgs) final {
1033  signatureHelp.activeParameter = currentNumArgs;
1034 
1035  llvm::lsp::SignatureInformation signatureInfo;
1036  {
1037  llvm::raw_string_ostream strOS(signatureInfo.label);
1038  strOS << callable->getName()->getName() << "(";
1039  auto formatParamFn = [&](const ast::VariableDecl *var) {
1040  unsigned paramStart = strOS.str().size();
1041  strOS << var->getName().getName() << ": " << var->getType();
1042  unsigned paramEnd = strOS.str().size();
1043  signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{
1044  StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1045  std::make_pair(paramStart, paramEnd), /*paramDoc*/ std::string()});
1046  };
1047  llvm::interleaveComma(callable->getInputs(), strOS, formatParamFn);
1048  strOS << ") -> " << callable->getResultType();
1049  }
1050 
1051  // Format the documentation for the callable.
1052  if (std::optional<std::string> doc =
1053  getDocumentationFor(sourceMgr, callable))
1054  signatureInfo.documentation = std::move(*doc);
1055 
1056  signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1057  }
1058 
1059  void
1060  codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
1061  unsigned currentNumOperands) final {
1062  const ods::Operation *odsOp =
1063  opName ? odsContext.lookupOperation(*opName) : nullptr;
1064  codeCompleteOperationOperandOrResultSignature(
1065  opName, odsOp,
1066  odsOp ? odsOp->getOperands() : ArrayRef<ods::OperandOrResult>(),
1067  currentNumOperands, "operand", "Value");
1068  }
1069 
1070  void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
1071  unsigned currentNumResults) final {
1072  const ods::Operation *odsOp =
1073  opName ? odsContext.lookupOperation(*opName) : nullptr;
1074  codeCompleteOperationOperandOrResultSignature(
1075  opName, odsOp,
1076  odsOp ? odsOp->getResults() : ArrayRef<ods::OperandOrResult>(),
1077  currentNumResults, "result", "Type");
1078  }
1079 
1080  void codeCompleteOperationOperandOrResultSignature(
1081  std::optional<StringRef> opName, const ods::Operation *odsOp,
1082  ArrayRef<ods::OperandOrResult> values, unsigned currentValue,
1083  StringRef label, StringRef dataType) {
1084  signatureHelp.activeParameter = currentValue;
1085 
1086  // If we have ODS information for the operation, add in the ODS signature
1087  // for the operation. We also verify that the current number of values is
1088  // not more than what is defined in ODS, as this will result in an error
1089  // anyways.
1090  if (odsOp && currentValue < values.size()) {
1091  llvm::lsp::SignatureInformation signatureInfo;
1092 
1093  // Build the signature label.
1094  {
1095  llvm::raw_string_ostream strOS(signatureInfo.label);
1096  strOS << "(";
1097  auto formatFn = [&](const ods::OperandOrResult &value) {
1098  unsigned paramStart = strOS.str().size();
1099 
1100  strOS << value.getName() << ": ";
1101 
1102  StringRef constraintDoc = value.getConstraint().getSummary();
1103  std::string paramDoc;
1104  switch (value.getVariableLengthKind()) {
1106  strOS << dataType;
1107  paramDoc = constraintDoc.str();
1108  break;
1110  strOS << dataType << "?";
1111  paramDoc = ("optional: " + constraintDoc).str();
1112  break;
1114  strOS << dataType << "Range";
1115  paramDoc = ("variadic: " + constraintDoc).str();
1116  break;
1117  }
1118 
1119  unsigned paramEnd = strOS.str().size();
1120  signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{
1121  StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1122  std::make_pair(paramStart, paramEnd), paramDoc});
1123  };
1124  llvm::interleaveComma(values, strOS, formatFn);
1125  strOS << ")";
1126  }
1127  signatureInfo.documentation =
1128  llvm::formatv("`op<{0}>` ODS {1} specification", *opName, label)
1129  .str();
1130  signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1131  }
1132 
1133  // If there aren't any arguments yet, we also add the generic signature.
1134  if (currentValue == 0 && (!odsOp || !values.empty())) {
1135  llvm::lsp::SignatureInformation signatureInfo;
1136  signatureInfo.label =
1137  llvm::formatv("(<{0}s>: {1}Range)", label, dataType).str();
1138  signatureInfo.documentation =
1139  ("Generic operation " + label + " specification").str();
1140  signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{
1141  StringRef(signatureInfo.label).drop_front().drop_back().str(),
1142  std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1),
1143  ("All of the " + label + "s of the operation.").str()});
1144  signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1145  }
1146  }
1147 
1148 private:
1149  llvm::SourceMgr &sourceMgr;
1150  llvm::lsp::SignatureHelp &signatureHelp;
1151  ods::Context &odsContext;
1152 };
1153 } // namespace
1154 
1155 llvm::lsp::SignatureHelp
1156 PDLDocument::getSignatureHelp(const llvm::lsp::URIForFile &uri,
1157  const llvm::lsp::Position &helpPos) {
1158  SMLoc posLoc = helpPos.getAsSMLoc(sourceMgr);
1159  if (!posLoc.isValid())
1160  return llvm::lsp::SignatureHelp();
1161 
1162  // To perform code completion, we run another parse of the module with the
1163  // code completion context provided.
1164  ods::Context tmpODSContext;
1165  llvm::lsp::SignatureHelp signatureHelp;
1166  LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp,
1167  tmpODSContext);
1168 
1169  ast::Context tmpContext(tmpODSContext);
1170  (void)parsePDLLAST(tmpContext, sourceMgr, /*enableDocumentation=*/true,
1171  &completeContext);
1172 
1173  return signatureHelp;
1174 }
1175 
1176 //===----------------------------------------------------------------------===//
1177 // PDLDocument: Inlay Hints
1178 //===----------------------------------------------------------------------===//
1179 
1180 /// Returns true if the given name should be added as a hint for `expr`.
1181 static bool shouldAddHintFor(const ast::Expr *expr, StringRef name) {
1182  if (name.empty())
1183  return false;
1184 
1185  // If the argument is a reference of the same name, don't add it as a hint.
1186  if (auto *ref = dyn_cast<ast::DeclRefExpr>(expr)) {
1187  const ast::Name *declName = ref->getDecl()->getName();
1188  if (declName && declName->getName() == name)
1189  return false;
1190  }
1191 
1192  return true;
1193 }
1194 
1195 void PDLDocument::getInlayHints(const llvm::lsp::URIForFile &uri,
1196  const llvm::lsp::Range &range,
1197  std::vector<llvm::lsp::InlayHint> &inlayHints) {
1198  if (failed(astModule))
1199  return;
1200  SMRange rangeLoc = range.getAsSMRange(sourceMgr);
1201  if (!rangeLoc.isValid())
1202  return;
1203  (*astModule)->walk([&](const ast::Node *node) {
1204  SMRange loc = node->getLoc();
1205 
1206  // Check that the location of this node is within the input range.
1207  if (!lsp::contains(rangeLoc, loc.Start) &&
1208  !lsp::contains(rangeLoc, loc.End))
1209  return;
1210 
1211  // Handle hints for various types of nodes.
1214  [&](const auto *node) {
1215  this->getInlayHintsFor(node, uri, inlayHints);
1216  });
1217  });
1218 }
1219 
1220 void PDLDocument::getInlayHintsFor(
1221  const ast::VariableDecl *decl, const llvm::lsp::URIForFile &uri,
1222  std::vector<llvm::lsp::InlayHint> &inlayHints) {
1223  // Check to see if the variable has a constraint list, if it does we don't
1224  // provide initializer hints.
1225  if (!decl->getConstraints().empty())
1226  return;
1227 
1228  // Check to see if the variable has an initializer.
1229  if (const ast::Expr *expr = decl->getInitExpr()) {
1230  // Don't add hints for operation expression initialized variables given that
1231  // the type of the variable is easily inferred by the expression operation
1232  // name.
1233  if (isa<ast::OperationExpr>(expr))
1234  return;
1235  }
1236 
1237  llvm::lsp::InlayHint hint(llvm::lsp::InlayHintKind::Type,
1238  llvm::lsp::Position(sourceMgr, decl->getLoc().End));
1239  {
1240  llvm::raw_string_ostream labelOS(hint.label);
1241  labelOS << ": " << decl->getType();
1242  }
1243 
1244  inlayHints.emplace_back(std::move(hint));
1245 }
1246 
1247 void PDLDocument::getInlayHintsFor(
1248  const ast::CallExpr *expr, const llvm::lsp::URIForFile &uri,
1249  std::vector<llvm::lsp::InlayHint> &inlayHints) {
1250  // Try to extract the callable of this call.
1251  const auto *callableRef = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
1252  const auto *callable =
1253  callableRef ? dyn_cast<ast::CallableDecl>(callableRef->getDecl())
1254  : nullptr;
1255  if (!callable)
1256  return;
1257 
1258  // Add hints for the arguments to the call.
1259  for (const auto &it : llvm::zip(expr->getArguments(), callable->getInputs()))
1260  addParameterHintFor(inlayHints, std::get<0>(it),
1261  std::get<1>(it)->getName().getName());
1262 }
1263 
1264 void PDLDocument::getInlayHintsFor(
1265  const ast::OperationExpr *expr, const llvm::lsp::URIForFile &uri,
1266  std::vector<llvm::lsp::InlayHint> &inlayHints) {
1267  // Check for ODS information.
1268  ast::OperationType opType = dyn_cast<ast::OperationType>(expr->getType());
1269  const auto *odsOp = opType ? opType.getODSOperation() : nullptr;
1270 
1271  auto addOpHint = [&](const ast::Expr *valueExpr, StringRef label) {
1272  // If the value expression used the same location as the operation, don't
1273  // add a hint. This expression was materialized during parsing.
1274  if (expr->getLoc().Start == valueExpr->getLoc().Start)
1275  return;
1276  addParameterHintFor(inlayHints, valueExpr, label);
1277  };
1278 
1279  // Functor used to process hints for the operands and results of the
1280  // operation. They effectively have the same format, and thus can be processed
1281  // using the same logic.
1282  auto addOperandOrResultHints = [&](ArrayRef<ast::Expr *> values,
1284  StringRef allValuesName) {
1285  if (values.empty())
1286  return;
1287 
1288  // The values should either map to a single range, or be equivalent to the
1289  // ODS values.
1290  if (values.size() != odsValues.size()) {
1291  // Handle the case of a single element that covers the full range.
1292  if (values.size() == 1)
1293  return addOpHint(values.front(), allValuesName);
1294  return;
1295  }
1296 
1297  for (const auto &it : llvm::zip(values, odsValues))
1298  addOpHint(std::get<0>(it), std::get<1>(it).getName());
1299  };
1300 
1301  // Add hints for the operands and results of the operation.
1302  addOperandOrResultHints(expr->getOperands(),
1303  odsOp ? odsOp->getOperands()
1305  "operands");
1306  addOperandOrResultHints(expr->getResultTypes(),
1307  odsOp ? odsOp->getResults()
1309  "results");
1310 }
1311 
1312 void PDLDocument::addParameterHintFor(
1313  std::vector<llvm::lsp::InlayHint> &inlayHints, const ast::Expr *expr,
1314  StringRef label) {
1315  if (!shouldAddHintFor(expr, label))
1316  return;
1317 
1318  llvm::lsp::InlayHint hint(
1319  llvm::lsp::InlayHintKind::Parameter,
1320  llvm::lsp::Position(sourceMgr, expr->getLoc().Start));
1321  hint.label = (label + ":").str();
1322  hint.paddingRight = true;
1323  inlayHints.emplace_back(std::move(hint));
1324 }
1325 
1326 //===----------------------------------------------------------------------===//
1327 // PDLL ViewOutput
1328 //===----------------------------------------------------------------------===//
1329 
1330 void PDLDocument::getPDLLViewOutput(raw_ostream &os,
1332  if (failed(astModule))
1333  return;
1334  if (kind == lsp::PDLLViewOutputKind::AST) {
1335  (*astModule)->print(os);
1336  return;
1337  }
1338 
1339  // Generate the MLIR for the ast module. We also capture diagnostics here to
1340  // show to the user, which may be useful if PDLL isn't capturing constraints
1341  // expected by PDL.
1342  MLIRContext mlirContext;
1343  SourceMgrDiagnosticHandler diagHandler(sourceMgr, &mlirContext, os);
1344  OwningOpRef<ModuleOp> pdlModule =
1345  codegenPDLLToMLIR(&mlirContext, astContext, sourceMgr, **astModule);
1346  if (!pdlModule)
1347  return;
1348  if (kind == lsp::PDLLViewOutputKind::MLIR) {
1349  pdlModule->print(os, OpPrintingFlags().enableDebugInfo());
1350  return;
1351  }
1352 
1353  // Otherwise, generate the output for C++.
1354  assert(kind == lsp::PDLLViewOutputKind::CPP &&
1355  "unexpected PDLLViewOutputKind");
1356  codegenPDLLToCPP(**astModule, *pdlModule, os);
1357 }
1358 
1359 //===----------------------------------------------------------------------===//
1360 // PDLTextFileChunk
1361 //===----------------------------------------------------------------------===//
1362 
1363 namespace {
1364 /// This class represents a single chunk of an PDL text file.
1365 struct PDLTextFileChunk {
1366  PDLTextFileChunk(uint64_t lineOffset, const llvm::lsp::URIForFile &uri,
1367  StringRef contents,
1368  const std::vector<std::string> &extraDirs,
1369  std::vector<llvm::lsp::Diagnostic> &diagnostics)
1370  : lineOffset(lineOffset),
1371  document(uri, contents, extraDirs, diagnostics) {}
1372 
1373  /// Adjust the line number of the given range to anchor at the beginning of
1374  /// the file, instead of the beginning of this chunk.
1375  void adjustLocForChunkOffset(llvm::lsp::Range &range) {
1376  adjustLocForChunkOffset(range.start);
1377  adjustLocForChunkOffset(range.end);
1378  }
1379  /// Adjust the line number of the given position to anchor at the beginning of
1380  /// the file, instead of the beginning of this chunk.
1381  void adjustLocForChunkOffset(llvm::lsp::Position &pos) {
1382  pos.line += lineOffset;
1383  }
1384 
1385  /// The line offset of this chunk from the beginning of the file.
1386  uint64_t lineOffset;
1387  /// The document referred to by this chunk.
1388  PDLDocument document;
1389 };
1390 } // namespace
1391 
1392 //===----------------------------------------------------------------------===//
1393 // PDLTextFile
1394 //===----------------------------------------------------------------------===//
1395 
1396 namespace {
1397 /// This class represents a text file containing one or more PDL documents.
1398 class PDLTextFile {
1399 public:
1400  PDLTextFile(const llvm::lsp::URIForFile &uri, StringRef fileContents,
1401  int64_t version, const std::vector<std::string> &extraDirs,
1402  std::vector<llvm::lsp::Diagnostic> &diagnostics);
1403 
1404  /// Return the current version of this text file.
1405  int64_t getVersion() const { return version; }
1406 
1407  /// Update the file to the new version using the provided set of content
1408  /// changes. Returns failure if the update was unsuccessful.
1409  LogicalResult
1410  update(const llvm::lsp::URIForFile &uri, int64_t newVersion,
1412  std::vector<llvm::lsp::Diagnostic> &diagnostics);
1413 
1414  //===--------------------------------------------------------------------===//
1415  // LSP Queries
1416  //===--------------------------------------------------------------------===//
1417 
1418  void getLocationsOf(const llvm::lsp::URIForFile &uri,
1419  llvm::lsp::Position defPos,
1420  std::vector<llvm::lsp::Location> &locations);
1421  void findReferencesOf(const llvm::lsp::URIForFile &uri,
1422  llvm::lsp::Position pos,
1423  std::vector<llvm::lsp::Location> &references);
1424  void getDocumentLinks(const llvm::lsp::URIForFile &uri,
1425  std::vector<llvm::lsp::DocumentLink> &links);
1426  std::optional<llvm::lsp::Hover> findHover(const llvm::lsp::URIForFile &uri,
1427  llvm::lsp::Position hoverPos);
1428  void findDocumentSymbols(std::vector<llvm::lsp::DocumentSymbol> &symbols);
1429  llvm::lsp::CompletionList getCodeCompletion(const llvm::lsp::URIForFile &uri,
1430  llvm::lsp::Position completePos);
1431  llvm::lsp::SignatureHelp getSignatureHelp(const llvm::lsp::URIForFile &uri,
1432  llvm::lsp::Position helpPos);
1433  void getInlayHints(const llvm::lsp::URIForFile &uri, llvm::lsp::Range range,
1434  std::vector<llvm::lsp::InlayHint> &inlayHints);
1436 
1437 private:
1438  using ChunkIterator = llvm::pointee_iterator<
1439  std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>;
1440 
1441  /// Initialize the text file from the given file contents.
1442  void initialize(const llvm::lsp::URIForFile &uri, int64_t newVersion,
1443  std::vector<llvm::lsp::Diagnostic> &diagnostics);
1444 
1445  /// Find the PDL document that contains the given position, and update the
1446  /// position to be anchored at the start of the found chunk instead of the
1447  /// beginning of the file.
1448  ChunkIterator getChunkItFor(llvm::lsp::Position &pos);
1449  PDLTextFileChunk &getChunkFor(llvm::lsp::Position &pos) {
1450  return *getChunkItFor(pos);
1451  }
1452 
1453  /// The full string contents of the file.
1454  std::string contents;
1455 
1456  /// The version of this file.
1457  int64_t version = 0;
1458 
1459  /// The number of lines in the file.
1460  int64_t totalNumLines = 0;
1461 
1462  /// The chunks of this file. The order of these chunks is the order in which
1463  /// they appear in the text file.
1464  std::vector<std::unique_ptr<PDLTextFileChunk>> chunks;
1465 
1466  /// The extra set of include directories for this file.
1467  std::vector<std::string> extraIncludeDirs;
1468 };
1469 } // namespace
1470 
1471 PDLTextFile::PDLTextFile(const llvm::lsp::URIForFile &uri,
1472  StringRef fileContents, int64_t version,
1473  const std::vector<std::string> &extraDirs,
1474  std::vector<llvm::lsp::Diagnostic> &diagnostics)
1475  : contents(fileContents.str()), extraIncludeDirs(extraDirs) {
1476  initialize(uri, version, diagnostics);
1477 }
1478 
1479 LogicalResult
1480 PDLTextFile::update(const llvm::lsp::URIForFile &uri, int64_t newVersion,
1482  std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1483  if (failed(llvm::lsp::TextDocumentContentChangeEvent::applyTo(changes,
1484  contents))) {
1485  llvm::lsp::Logger::error("Failed to update contents of {0}", uri.file());
1486  return failure();
1487  }
1488 
1489  // If the file contents were properly changed, reinitialize the text file.
1490  initialize(uri, newVersion, diagnostics);
1491  return success();
1492 }
1493 
1494 void PDLTextFile::getLocationsOf(const llvm::lsp::URIForFile &uri,
1495  llvm::lsp::Position defPos,
1496  std::vector<llvm::lsp::Location> &locations) {
1497  PDLTextFileChunk &chunk = getChunkFor(defPos);
1498  chunk.document.getLocationsOf(uri, defPos, locations);
1499 
1500  // Adjust any locations within this file for the offset of this chunk.
1501  if (chunk.lineOffset == 0)
1502  return;
1503  for (llvm::lsp::Location &loc : locations)
1504  if (loc.uri == uri)
1505  chunk.adjustLocForChunkOffset(loc.range);
1506 }
1507 
1508 void PDLTextFile::findReferencesOf(
1509  const llvm::lsp::URIForFile &uri, llvm::lsp::Position pos,
1510  std::vector<llvm::lsp::Location> &references) {
1511  PDLTextFileChunk &chunk = getChunkFor(pos);
1512  chunk.document.findReferencesOf(uri, pos, references);
1513 
1514  // Adjust any locations within this file for the offset of this chunk.
1515  if (chunk.lineOffset == 0)
1516  return;
1517  for (llvm::lsp::Location &loc : references)
1518  if (loc.uri == uri)
1519  chunk.adjustLocForChunkOffset(loc.range);
1520 }
1521 
1522 void PDLTextFile::getDocumentLinks(
1523  const llvm::lsp::URIForFile &uri,
1524  std::vector<llvm::lsp::DocumentLink> &links) {
1525  chunks.front()->document.getDocumentLinks(uri, links);
1526  for (const auto &it : llvm::drop_begin(chunks)) {
1527  size_t currentNumLinks = links.size();
1528  it->document.getDocumentLinks(uri, links);
1529 
1530  // Adjust any links within this file to account for the offset of this
1531  // chunk.
1532  for (auto &link : llvm::drop_begin(links, currentNumLinks))
1533  it->adjustLocForChunkOffset(link.range);
1534  }
1535 }
1536 
1537 std::optional<llvm::lsp::Hover>
1538 PDLTextFile::findHover(const llvm::lsp::URIForFile &uri,
1539  llvm::lsp::Position hoverPos) {
1540  PDLTextFileChunk &chunk = getChunkFor(hoverPos);
1541  std::optional<llvm::lsp::Hover> hoverInfo =
1542  chunk.document.findHover(uri, hoverPos);
1543 
1544  // Adjust any locations within this file for the offset of this chunk.
1545  if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1546  chunk.adjustLocForChunkOffset(*hoverInfo->range);
1547  return hoverInfo;
1548 }
1549 
1550 void PDLTextFile::findDocumentSymbols(
1551  std::vector<llvm::lsp::DocumentSymbol> &symbols) {
1552  if (chunks.size() == 1)
1553  return chunks.front()->document.findDocumentSymbols(symbols);
1554 
1555  // If there are multiple chunks in this file, we create top-level symbols for
1556  // each chunk.
1557  for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1558  PDLTextFileChunk &chunk = *chunks[i];
1559  llvm::lsp::Position startPos(chunk.lineOffset);
1560  llvm::lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1561  : chunks[i + 1]->lineOffset);
1562  llvm::lsp::DocumentSymbol symbol(
1563  "<file-split-" + Twine(i) + ">", llvm::lsp::SymbolKind::Namespace,
1564  /*range=*/llvm::lsp::Range(startPos, endPos),
1565  /*selectionRange=*/llvm::lsp::Range(startPos));
1566  chunk.document.findDocumentSymbols(symbol.children);
1567 
1568  // Fixup the locations of document symbols within this chunk.
1569  if (i != 0) {
1571  for (llvm::lsp::DocumentSymbol &childSymbol : symbol.children)
1572  symbolsToFix.push_back(&childSymbol);
1573 
1574  while (!symbolsToFix.empty()) {
1575  llvm::lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1576  chunk.adjustLocForChunkOffset(symbol->range);
1577  chunk.adjustLocForChunkOffset(symbol->selectionRange);
1578 
1579  for (llvm::lsp::DocumentSymbol &childSymbol : symbol->children)
1580  symbolsToFix.push_back(&childSymbol);
1581  }
1582  }
1583 
1584  // Push the symbol for this chunk.
1585  symbols.emplace_back(std::move(symbol));
1586  }
1587 }
1588 
1589 llvm::lsp::CompletionList
1590 PDLTextFile::getCodeCompletion(const llvm::lsp::URIForFile &uri,
1591  llvm::lsp::Position completePos) {
1592  PDLTextFileChunk &chunk = getChunkFor(completePos);
1593  llvm::lsp::CompletionList completionList =
1594  chunk.document.getCodeCompletion(uri, completePos);
1595 
1596  // Adjust any completion locations.
1597  for (llvm::lsp::CompletionItem &item : completionList.items) {
1598  if (item.textEdit)
1599  chunk.adjustLocForChunkOffset(item.textEdit->range);
1600  for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
1601  chunk.adjustLocForChunkOffset(edit.range);
1602  }
1603  return completionList;
1604 }
1605 
1606 llvm::lsp::SignatureHelp
1607 PDLTextFile::getSignatureHelp(const llvm::lsp::URIForFile &uri,
1608  llvm::lsp::Position helpPos) {
1609  return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos);
1610 }
1611 
1612 void PDLTextFile::getInlayHints(const llvm::lsp::URIForFile &uri,
1613  llvm::lsp::Range range,
1614  std::vector<llvm::lsp::InlayHint> &inlayHints) {
1615  auto startIt = getChunkItFor(range.start);
1616  auto endIt = getChunkItFor(range.end);
1617 
1618  // Functor used to get the chunks for a given file, and fixup any locations
1619  auto getHintsForChunk = [&](ChunkIterator chunkIt, llvm::lsp::Range range) {
1620  size_t currentNumHints = inlayHints.size();
1621  chunkIt->document.getInlayHints(uri, range, inlayHints);
1622 
1623  // If this isn't the first chunk, update any positions to account for line
1624  // number differences.
1625  if (&*chunkIt != &*chunks.front()) {
1626  for (auto &hint : llvm::drop_begin(inlayHints, currentNumHints))
1627  chunkIt->adjustLocForChunkOffset(hint.position);
1628  }
1629  };
1630  // Returns the number of lines held by a given chunk.
1631  auto getNumLines = [](ChunkIterator chunkIt) {
1632  return (chunkIt + 1)->lineOffset - chunkIt->lineOffset;
1633  };
1634 
1635  // Check if the range is fully within a single chunk.
1636  if (startIt == endIt)
1637  return getHintsForChunk(startIt, range);
1638 
1639  // Otherwise, the range is split between multiple chunks. The first chunk
1640  // has the correct range start, but covers the total document.
1641  getHintsForChunk(startIt,
1642  llvm::lsp::Range(range.start, getNumLines(startIt)));
1643 
1644  // Every chunk in between uses the full document.
1645  for (++startIt; startIt != endIt; ++startIt)
1646  getHintsForChunk(startIt, llvm::lsp::Range(0, getNumLines(startIt)));
1647 
1648  // The range for the last chunk starts at the beginning of the document, up
1649  // through the end of the input range.
1650  getHintsForChunk(startIt, llvm::lsp::Range(0, range.end));
1651 }
1652 
1654 PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) {
1656  {
1657  llvm::raw_string_ostream outputOS(result.output);
1658  llvm::interleave(
1659  llvm::make_pointee_range(chunks),
1660  [&](PDLTextFileChunk &chunk) {
1661  chunk.document.getPDLLViewOutput(outputOS, kind);
1662  },
1663  [&] { outputOS << "\n"
1664  << kDefaultSplitMarker << "\n\n"; });
1665  }
1666  return result;
1667 }
1668 
1669 void PDLTextFile::initialize(const llvm::lsp::URIForFile &uri,
1670  int64_t newVersion,
1671  std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1672  version = newVersion;
1673  chunks.clear();
1674 
1675  // Split the file into separate PDL documents.
1676  SmallVector<StringRef, 8> subContents;
1677  StringRef(contents).split(subContents, kDefaultSplitMarker);
1678  chunks.emplace_back(std::make_unique<PDLTextFileChunk>(
1679  /*lineOffset=*/0, uri, subContents.front(), extraIncludeDirs,
1680  diagnostics));
1681 
1682  uint64_t lineOffset = subContents.front().count('\n');
1683  for (StringRef docContents : llvm::drop_begin(subContents)) {
1684  unsigned currentNumDiags = diagnostics.size();
1685  auto chunk = std::make_unique<PDLTextFileChunk>(
1686  lineOffset, uri, docContents, extraIncludeDirs, diagnostics);
1687  lineOffset += docContents.count('\n');
1688 
1689  // Adjust locations used in diagnostics to account for the offset from the
1690  // beginning of the file.
1691  for (llvm::lsp::Diagnostic &diag :
1692  llvm::drop_begin(diagnostics, currentNumDiags)) {
1693  chunk->adjustLocForChunkOffset(diag.range);
1694 
1695  if (!diag.relatedInformation)
1696  continue;
1697  for (auto &it : *diag.relatedInformation)
1698  if (it.location.uri == uri)
1699  chunk->adjustLocForChunkOffset(it.location.range);
1700  }
1701  chunks.emplace_back(std::move(chunk));
1702  }
1703  totalNumLines = lineOffset;
1704 }
1705 
1706 PDLTextFile::ChunkIterator
1707 PDLTextFile::getChunkItFor(llvm::lsp::Position &pos) {
1708  if (chunks.size() == 1)
1709  return chunks.begin();
1710 
1711  // Search for the first chunk with a greater line offset, the previous chunk
1712  // is the one that contains `pos`.
1713  auto it = llvm::upper_bound(
1714  chunks, pos, [](const llvm::lsp::Position &pos, const auto &chunk) {
1715  return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1716  });
1717  ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it);
1718  pos.line -= chunkIt->lineOffset;
1719  return chunkIt;
1720 }
1721 
1722 //===----------------------------------------------------------------------===//
1723 // PDLLServer::Impl
1724 //===----------------------------------------------------------------------===//
1725 
1727  explicit Impl(const Options &options)
1728  : options(options), compilationDatabase(options.compilationDatabases) {}
1729 
1730  /// PDLL LSP options.
1732 
1733  /// The compilation database containing additional information for files
1734  /// passed to the server.
1736 
1737  /// The files held by the server, mapped by their URI file name.
1738  llvm::StringMap<std::unique_ptr<PDLTextFile>> files;
1739 };
1740 
1741 //===----------------------------------------------------------------------===//
1742 // PDLLServer
1743 //===----------------------------------------------------------------------===//
1744 
1745 lsp::PDLLServer::PDLLServer(const Options &options)
1746  : impl(std::make_unique<Impl>(options)) {}
1747 lsp::PDLLServer::~PDLLServer() = default;
1748 
1750  const URIForFile &uri, StringRef contents, int64_t version,
1751  std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1752  // Build the set of additional include directories.
1753  std::vector<std::string> additionalIncludeDirs = impl->options.extraDirs;
1754  const auto &fileInfo = impl->compilationDatabase.getFileInfo(uri.file());
1755  llvm::append_range(additionalIncludeDirs, fileInfo.includeDirs);
1756 
1757  impl->files[uri.file()] = std::make_unique<PDLTextFile>(
1758  uri, contents, version, additionalIncludeDirs, diagnostics);
1759 }
1760 
1762  const URIForFile &uri, ArrayRef<TextDocumentContentChangeEvent> changes,
1763  int64_t version, std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1764  // Check that we actually have a document for this uri.
1765  auto it = impl->files.find(uri.file());
1766  if (it == impl->files.end())
1767  return;
1768 
1769  // Try to update the document. If we fail, erase the file from the server. A
1770  // failed updated generally means we've fallen out of sync somewhere.
1771  if (failed(it->second->update(uri, version, changes, diagnostics)))
1772  impl->files.erase(it);
1773 }
1774 
1775 std::optional<int64_t> lsp::PDLLServer::removeDocument(const URIForFile &uri) {
1776  auto it = impl->files.find(uri.file());
1777  if (it == impl->files.end())
1778  return std::nullopt;
1779 
1780  int64_t version = it->second->getVersion();
1781  impl->files.erase(it);
1782  return version;
1783 }
1784 
1786  const URIForFile &uri, const Position &defPos,
1787  std::vector<llvm::lsp::Location> &locations) {
1788  auto fileIt = impl->files.find(uri.file());
1789  if (fileIt != impl->files.end())
1790  fileIt->second->getLocationsOf(uri, defPos, locations);
1791 }
1792 
1794  const URIForFile &uri, const Position &pos,
1795  std::vector<llvm::lsp::Location> &references) {
1796  auto fileIt = impl->files.find(uri.file());
1797  if (fileIt != impl->files.end())
1798  fileIt->second->findReferencesOf(uri, pos, references);
1799 }
1800 
1802  const URIForFile &uri, std::vector<DocumentLink> &documentLinks) {
1803  auto fileIt = impl->files.find(uri.file());
1804  if (fileIt != impl->files.end())
1805  return fileIt->second->getDocumentLinks(uri, documentLinks);
1806 }
1807 
1808 std::optional<llvm::lsp::Hover>
1809 lsp::PDLLServer::findHover(const URIForFile &uri, const Position &hoverPos) {
1810  auto fileIt = impl->files.find(uri.file());
1811  if (fileIt != impl->files.end())
1812  return fileIt->second->findHover(uri, hoverPos);
1813  return std::nullopt;
1814 }
1815 
1817  const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1818  auto fileIt = impl->files.find(uri.file());
1819  if (fileIt != impl->files.end())
1820  fileIt->second->findDocumentSymbols(symbols);
1821 }
1822 
1823 lsp::CompletionList
1825  const Position &completePos) {
1826  auto fileIt = impl->files.find(uri.file());
1827  if (fileIt != impl->files.end())
1828  return fileIt->second->getCodeCompletion(uri, completePos);
1829  return CompletionList();
1830 }
1831 
1832 llvm::lsp::SignatureHelp
1834  const Position &helpPos) {
1835  auto fileIt = impl->files.find(uri.file());
1836  if (fileIt != impl->files.end())
1837  return fileIt->second->getSignatureHelp(uri, helpPos);
1838  return SignatureHelp();
1839 }
1840 
1841 void lsp::PDLLServer::getInlayHints(const URIForFile &uri, const Range &range,
1842  std::vector<InlayHint> &inlayHints) {
1843  auto fileIt = impl->files.find(uri.file());
1844  if (fileIt == impl->files.end())
1845  return;
1846  fileIt->second->getInlayHints(uri, range, inlayHints);
1847 
1848  // Drop any duplicated hints that may have cropped up.
1849  llvm::sort(inlayHints);
1850  inlayHints.erase(llvm::unique(inlayHints), inlayHints.end());
1851 }
1852 
1853 std::optional<lsp::PDLLViewOutputResult>
1856  auto fileIt = impl->files.find(uri.file());
1857  if (fileIt != impl->files.end())
1858  return fileIt->second->getPDLLViewOutput(kind);
1859  return std::nullopt;
1860 }
static std::string toString(bytecode::Section::ID sectionID)
Stringify the given section ID.
union mlir::linalg::@1243::ArityGroupAndKind::Kind kind
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
Definition: MLIRServer.cpp:111
@ Error
static std::string diag(const llvm::Value &value)
static std::optional< std::string > getDocumentationFor(llvm::SourceMgr &sourceMgr, const ast::Decl *decl)
Get or extract the documentation for the given decl.
Definition: PDLLServer.cpp:122
static llvm::lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range, const llvm::lsp::URIForFile &uri)
Returns a language server location from the given source range.
Definition: PDLLServer.cpp:65
static bool shouldAddHintFor(const ast::Expr *expr, StringRef name)
Returns true if the given name should be added as a hint for expr.
static std::optional< llvm::lsp::Diagnostic > getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, const llvm::lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
Definition: PDLLServer.cpp:73
static llvm::lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc, const llvm::lsp::URIForFile &mainFileURI)
Returns a language server uri for the given source location.
Definition: PDLLServer.cpp:42
static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc)
Returns true if the given location is in the main file of the source manager.
Definition: PDLLServer.cpp:59
static llvm::ManagedStatic< PassManagerOptions > options
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Set of flags used to control the behavior of the various IR print methods (e.g.
This class is a utility diagnostic handler for use with llvm::SourceMgr.
Definition: Diagnostics.h:559
This class contains a collection of compilation information for files provided to the language server...
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
void getDocumentLinks(const URIForFile &uri, std::vector< DocumentLink > &documentLinks)
Return the document links referenced by the given file.
void getInlayHints(const URIForFile &uri, const Range &range, std::vector< InlayHint > &inlayHints)
Get the inlay hints for the range within the given file.
void addDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add the document, with the provided version, at the given URI.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
SignatureHelp getSignatureHelp(const URIForFile &uri, const Position &helpPos)
Get the signature help for the position within the given file.
void updateDocument(const URIForFile &uri, ArrayRef< TextDocumentContentChangeEvent > changes, int64_t version, std::vector< Diagnostic > &diagnostics)
Update the document, with the provided version, at the given URI.
std::optional< PDLLViewOutputResult > getPDLLViewOutput(const URIForFile &uri, PDLLViewOutputKind kind)
Get the output of the given PDLL file, or std::nullopt if there is no valid output.
This class provides an abstract interface into the parser for hooking in code completion events.
Definition: CodeComplete.h:30
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for a member access into the given operation type.
Definition: CodeComplete.h:48
virtual void codeCompleteOperationOperandsSignature(std::optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
Definition: CodeComplete.h:80
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
Definition: CodeComplete.h:62
virtual void codeCompleteOperationResultsSignature(std::optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
Definition: CodeComplete.h:85
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
Definition: CodeComplete.h:59
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
Definition: CodeComplete.h:65
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
Definition: CodeComplete.h:75
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
Definition: CodeComplete.h:68
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:750
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Definition: Nodes.h:393
Expr * getCallableExpr() const
Return the callable of this call.
Definition: Nodes.h:400
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
Definition: Nodes.h:403
This decl represents a shared interface for all callable decls.
Definition: Nodes.h:1194
This class represents the main context of the PDLL AST.
Definition: Context.h:25
This class represents the base of all "core" constraints.
Definition: Nodes.h:733
This class represents a scope for named AST decls.
Definition: Nodes.h:64
This class represents the base Decl node.
Definition: Nodes.h:669
std::optional< StringRef > getDocComment() const
Return the documentation comment attached to this decl if it has been set.
Definition: Nodes.h:682
This class provides a simple implementation of a PDLL diagnostic.
Definition: Diagnostic.h:29
This class represents a base AST Expression node.
Definition: Nodes.h:348
Type getType() const
Return the type of this expression.
Definition: Nodes.h:351
This class represents a top-level AST module.
Definition: Nodes.h:1297
This class represents a base AST node.
Definition: Nodes.h:108
SMRange getLoc() const
Return the location of this node.
Definition: Nodes.h:131
The class represents an Operation constraint, and constrains a variable to be an Operation.
Definition: Nodes.h:774
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Definition: Nodes.cpp:404
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
Definition: Nodes.h:237
This expression represents the structural form of an MLIR Operation.
Definition: Nodes.h:512
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
Definition: Nodes.h:540
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
Definition: Nodes.h:532
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:134
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
Definition: Types.cpp:87
This Decl represents a single Pattern.
Definition: Nodes.h:1043
const OpRewriteStmt * getRootRewriteStmt() const
Return the root rewrite statement of this pattern.
Definition: Nodes.h:1060
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
Definition: Nodes.h:1051
bool hasBoundedRewriteRecursion() const
Return if this pattern has bounded rewrite recursion.
Definition: Nodes.h:1054
This class represents a PDLL tuple type, i.e.
Definition: Types.h:222
The class represents a Type constraint, and constrains a variable to be a Type.
Definition: Nodes.h:800
The class represents a TypeRange constraint, and constrains a variable to be a TypeRange.
Definition: Nodes.h:815
The class represents a Value constraint, and constrains a variable to be a Value.
Definition: Nodes.h:830
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:853
This Decl represents the definition of a PDLL variable.
Definition: Nodes.h:1248
const Name & getName() const
Return the name of the decl.
Definition: Nodes.h:1267
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
Definition: Nodes.h:1264
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition: Nodes.h:1255
Type getType() const
Return the type of the decl.
Definition: Nodes.h:1270
This class represents a generic ODS Attribute constraint.
Definition: Constraint.h:63
StringRef getCppClass() const
Return the name of the underlying c++ class of this constraint.
Definition: Constraint.h:66
This class provides an ODS representation of a specific operation attribute.
Definition: Operation.h:39
StringRef getSummary() const
Return the summary of this constraint.
Definition: Constraint.h:44
StringRef getName() const
Return the unique name of this constraint.
Definition: Constraint.h:35
This class contains all of the registered ODS operation classes.
Definition: Context.h:32
auto getDialects() const
Return a range of all of the registered dialects.
Definition: Context.h:57
const Dialect * lookupDialect(StringRef name) const
Lookup a dialect registered with the given name, or null if no dialect with that name was inserted.
Definition: Context.cpp:57
const Operation * lookupOperation(StringRef name) const
Lookup an operation registered with the given name, or null if no operation with that name is registe...
Definition: Context.cpp:72
This class represents an ODS dialect, and contains information on the constructs held within the dial...
Definition: Dialect.h:26
const llvm::StringMap< std::unique_ptr< Operation > > & getOperations() const
Return a map of all of the operations registered to this dialect.
Definition: Dialect.h:46
This class provides an ODS representation of a specific operation operand or result.
Definition: Operation.h:74
const TypeConstraint & getConstraint() const
Return the constraint of this value.
Definition: Operation.h:97
VariableLengthKind getVariableLengthKind() const
Returns the variable length kind of this value.
Definition: Operation.h:92
StringRef getName() const
Return the name of this value.
Definition: Operation.h:77
This class provides an ODS representation of a specific operation.
Definition: Operation.h:125
StringRef getDescription() const
Returns the description of the operation.
Definition: Operation.h:156
StringRef getSummary() const
Returns the summary of the operation.
Definition: Operation.h:153
ArrayRef< Attribute > getAttributes() const
Returns the attributes of this operation.
Definition: Operation.h:162
StringRef getName() const
Returns the name of the operation.
Definition: Operation.h:150
SMRange getLoc() const
Return the source location of this operation.
Definition: Operation.h:128
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
Definition: Operation.h:165
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
Definition: Operation.h:168
This class represents a generic ODS Type constraint.
Definition: Constraint.h:84
StringRef getCppClass() const
Return the name of the underlying c++ class of this constraint.
Definition: Constraint.h:87
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
PDLLViewOutputKind
The type of output to view from PDLL.
Definition: Protocol.h:34
void gatherIncludeFiles(llvm::SourceMgr &sourceMgr, SmallVectorImpl< SourceMgrInclude > &includes)
Given a source manager, gather all of the processed include files.
std::optional< std::string > extractSourceDocComment(llvm::SourceMgr &sourceMgr, SMLoc loc)
Extract a documentation comment for the given location within the source manager.
void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, raw_ostream &os)
Definition: CPPGen.cpp:251
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
Definition: Parser.cpp:3203
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Definition: MLIRGen.cpp:624
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const char *const kDefaultSplitMarker
Impl(const Options &options)
lsp::CompilationDatabase compilationDatabase
The compilation database containing additional information for files passed to the server.
llvm::StringMap< std::unique_ptr< PDLTextFile > > files
The files held by the server, mapped by their URI file name.
const Options & options
PDLL LSP options.
Represents the result of viewing the output of a PDLL file.
Definition: Protocol.h:60
std::string output
The string representation of the output.
Definition: Protocol.h:62
This class represents a single include within a root file.
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41
SMRange getLoc() const
Get the location of this name.
Definition: Nodes.h:44