MLIR  19.0.0git
MLIRServer.cpp
Go to the documentation of this file.
1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MLIRServer.h"
10 #include "Protocol.h"
15 #include "mlir/IR/Operation.h"
17 #include "mlir/Parser/Parser.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Base64.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include <optional>
25 
26 using namespace mlir;
27 
28 /// Returns the range of a lexical token given a SMLoc corresponding to the
29 /// start of an token location. The range is computed heuristically, and
30 /// supports identifier-like tokens, strings, etc.
31 static SMRange convertTokenLocToRange(SMLoc loc) {
32  return lsp::convertTokenLocToRange(loc, "$-.");
33 }
34 
35 /// Returns a language server location from the given MLIR file location.
36 /// `uriScheme` is the scheme to use when building new uris.
37 static std::optional<lsp::Location> getLocationFromLoc(StringRef uriScheme,
38  FileLineColLoc loc) {
40  lsp::URIForFile::fromFile(loc.getFilename(), uriScheme);
41  if (!sourceURI) {
42  lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
43  loc.getFilename(),
44  llvm::toString(sourceURI.takeError()));
45  return std::nullopt;
46  }
47 
48  lsp::Position position;
49  position.line = loc.getLine() - 1;
50  position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
51  return lsp::Location{*sourceURI, lsp::Range(position)};
52 }
53 
54 /// Returns a language server location from the given MLIR location, or
55 /// std::nullopt if one couldn't be created. `uriScheme` is the scheme to use
56 /// when building new uris. `uri` is an optional additional filter that, when
57 /// present, is used to filter sub locations that do not share the same uri.
58 static std::optional<lsp::Location>
59 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
60  StringRef uriScheme, const lsp::URIForFile *uri = nullptr) {
61  std::optional<lsp::Location> location;
62  loc->walk([&](Location nestedLoc) {
63  FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
64  if (!fileLoc)
65  return WalkResult::advance();
66 
67  std::optional<lsp::Location> sourceLoc =
68  getLocationFromLoc(uriScheme, fileLoc);
69  if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
70  location = *sourceLoc;
71  SMLoc loc = sourceMgr.FindLocForLineAndColumn(
72  sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
73 
74  // Use range of potential identifier starting at location, else length 1
75  // range.
76  location->range.end.character += 1;
77  if (std::optional<SMRange> range = convertTokenLocToRange(loc)) {
78  auto lineCol = sourceMgr.getLineAndColumn(range->End);
79  location->range.end.character =
80  std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
81  }
82  return WalkResult::interrupt();
83  }
84  return WalkResult::advance();
85  });
86  return location;
87 }
88 
89 /// Collect all of the locations from the given MLIR location that are not
90 /// contained within the given URI.
92  std::vector<lsp::Location> &locations,
93  const lsp::URIForFile &uri) {
94  SetVector<Location> visitedLocs;
95  loc->walk([&](Location nestedLoc) {
96  FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
97  if (!fileLoc || !visitedLocs.insert(nestedLoc))
98  return WalkResult::advance();
99 
100  std::optional<lsp::Location> sourceLoc =
101  getLocationFromLoc(uri.scheme(), fileLoc);
102  if (sourceLoc && sourceLoc->uri != uri)
103  locations.push_back(*sourceLoc);
104  return WalkResult::advance();
105  });
106 }
107 
108 /// Returns true if the given range contains the given source location. Note
109 /// that this has slightly different behavior than SMRange because it is
110 /// inclusive of the end location.
111 static bool contains(SMRange range, SMLoc loc) {
112  return range.Start.getPointer() <= loc.getPointer() &&
113  loc.getPointer() <= range.End.getPointer();
114 }
115 
116 /// Returns true if the given location is contained by the definition or one of
117 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
118 /// the range within `def` that the provided `loc` overlapped with.
119 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
120  SMRange *overlappedRange = nullptr) {
121  // Check the main definition.
122  if (contains(def.loc, loc)) {
123  if (overlappedRange)
124  *overlappedRange = def.loc;
125  return true;
126  }
127 
128  // Check the uses.
129  const auto *useIt = llvm::find_if(
130  def.uses, [&](const SMRange &range) { return contains(range, loc); });
131  if (useIt != def.uses.end()) {
132  if (overlappedRange)
133  *overlappedRange = *useIt;
134  return true;
135  }
136  return false;
137 }
138 
139 /// Given a location pointing to a result, return the result number it refers
140 /// to or std::nullopt if it refers to all of the results.
141 static std::optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
142  // Skip all of the identifier characters.
143  auto isIdentifierChar = [](char c) {
144  return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
145  c == '-';
146  };
147  const char *curPtr = loc.getPointer();
148  while (isIdentifierChar(*curPtr))
149  ++curPtr;
150 
151  // Check to see if this location indexes into the result group, via `#`. If it
152  // doesn't, we can't extract a sub result number.
153  if (*curPtr != '#')
154  return std::nullopt;
155 
156  // Compute the sub result number from the remaining portion of the string.
157  const char *numberStart = ++curPtr;
158  while (llvm::isDigit(*curPtr))
159  ++curPtr;
160  StringRef numberStr(numberStart, curPtr - numberStart);
161  unsigned resultNumber = 0;
162  return numberStr.consumeInteger(10, resultNumber) ? std::optional<unsigned>()
163  : resultNumber;
164 }
165 
166 /// Given a source location range, return the text covered by the given range.
167 /// If the range is invalid, returns std::nullopt.
168 static std::optional<StringRef> getTextFromRange(SMRange range) {
169  if (!range.isValid())
170  return std::nullopt;
171  const char *startPtr = range.Start.getPointer();
172  return StringRef(startPtr, range.End.getPointer() - startPtr);
173 }
174 
175 /// Given a block, return its position in its parent region.
176 static unsigned getBlockNumber(Block *block) {
177  return std::distance(block->getParent()->begin(), block->getIterator());
178 }
179 
180 /// Given a block and source location, print the source name of the block to the
181 /// given output stream.
182 static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
183  // Try to extract a name from the source location.
184  std::optional<StringRef> text = getTextFromRange(loc);
185  if (text && text->starts_with("^")) {
186  os << *text;
187  return;
188  }
189 
190  // Otherwise, we don't have a name so print the block number.
191  os << "<Block #" << getBlockNumber(block) << ">";
192 }
193 static void printDefBlockName(raw_ostream &os,
194  const AsmParserState::BlockDefinition &def) {
195  printDefBlockName(os, def.block, def.definition.loc);
196 }
197 
198 /// Convert the given MLIR diagnostic to the LSP form.
199 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
200  Diagnostic &diag,
201  const lsp::URIForFile &uri) {
202  lsp::Diagnostic lspDiag;
203  lspDiag.source = "mlir";
204 
205  // Note: Right now all of the diagnostics are treated as parser issues, but
206  // some are parser and some are verifier.
207  lspDiag.category = "Parse Error";
208 
209  // Try to grab a file location for this diagnostic.
210  // TODO: For simplicity, we just grab the first one. It may be likely that we
211  // will need a more interesting heuristic here.'
212  StringRef uriScheme = uri.scheme();
213  std::optional<lsp::Location> lspLocation =
214  getLocationFromLoc(sourceMgr, diag.getLocation(), uriScheme, &uri);
215  if (lspLocation)
216  lspDiag.range = lspLocation->range;
217 
218  // Convert the severity for the diagnostic.
219  switch (diag.getSeverity()) {
221  llvm_unreachable("expected notes to be handled separately");
224  break;
227  break;
230  break;
231  }
232  lspDiag.message = diag.str();
233 
234  // Attach any notes to the main diagnostic as related information.
235  std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
236  for (Diagnostic &note : diag.getNotes()) {
237  lsp::Location noteLoc;
238  if (std::optional<lsp::Location> loc =
239  getLocationFromLoc(sourceMgr, note.getLocation(), uriScheme))
240  noteLoc = *loc;
241  else
242  noteLoc.uri = uri;
243  relatedDiags.emplace_back(noteLoc, note.str());
244  }
245  if (!relatedDiags.empty())
246  lspDiag.relatedInformation = std::move(relatedDiags);
247 
248  return lspDiag;
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // MLIRDocument
253 //===----------------------------------------------------------------------===//
254 
255 namespace {
256 /// This class represents all of the information pertaining to a specific MLIR
257 /// document.
258 struct MLIRDocument {
259  MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
260  StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
261  MLIRDocument(const MLIRDocument &) = delete;
262  MLIRDocument &operator=(const MLIRDocument &) = delete;
263 
264  //===--------------------------------------------------------------------===//
265  // Definitions and References
266  //===--------------------------------------------------------------------===//
267 
268  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
269  std::vector<lsp::Location> &locations);
270  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
271  std::vector<lsp::Location> &references);
272 
273  //===--------------------------------------------------------------------===//
274  // Hover
275  //===--------------------------------------------------------------------===//
276 
277  std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
278  const lsp::Position &hoverPos);
279  std::optional<lsp::Hover>
280  buildHoverForOperation(SMRange hoverRange,
282  lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
283  unsigned resultStart,
284  unsigned resultEnd, SMLoc posLoc);
285  lsp::Hover buildHoverForBlock(SMRange hoverRange,
286  const AsmParserState::BlockDefinition &block);
287  lsp::Hover
288  buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
289  const AsmParserState::BlockDefinition &block);
290 
291  lsp::Hover buildHoverForAttributeAlias(
292  SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr);
293  lsp::Hover
294  buildHoverForTypeAlias(SMRange hoverRange,
296 
297  //===--------------------------------------------------------------------===//
298  // Document Symbols
299  //===--------------------------------------------------------------------===//
300 
301  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
302  void findDocumentSymbols(Operation *op,
303  std::vector<lsp::DocumentSymbol> &symbols);
304 
305  //===--------------------------------------------------------------------===//
306  // Code Completion
307  //===--------------------------------------------------------------------===//
308 
309  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
310  const lsp::Position &completePos,
311  const DialectRegistry &registry);
312 
313  //===--------------------------------------------------------------------===//
314  // Code Action
315  //===--------------------------------------------------------------------===//
316 
317  void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
318  lsp::Position &pos, StringRef severity,
319  StringRef message,
320  std::vector<lsp::TextEdit> &edits);
321 
322  //===--------------------------------------------------------------------===//
323  // Bytecode
324  //===--------------------------------------------------------------------===//
325 
327 
328  //===--------------------------------------------------------------------===//
329  // Fields
330  //===--------------------------------------------------------------------===//
331 
332  /// The high level parser state used to find definitions and references within
333  /// the source file.
334  AsmParserState asmState;
335 
336  /// The container for the IR parsed from the input file.
337  Block parsedIR;
338 
339  /// A collection of external resources, which we want to propagate up to the
340  /// user.
341  FallbackAsmResourceMap fallbackResourceMap;
342 
343  /// The source manager containing the contents of the input file.
344  llvm::SourceMgr sourceMgr;
345 };
346 } // namespace
347 
348 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
349  StringRef contents,
350  std::vector<lsp::Diagnostic> &diagnostics) {
351  ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
352  diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
353  });
354 
355  // Try to parsed the given IR string.
356  auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
357  if (!memBuffer) {
358  lsp::Logger::error("Failed to create memory buffer for file", uri.file());
359  return;
360  }
361 
362  ParserConfig config(&context, /*verifyAfterParse=*/true,
363  &fallbackResourceMap);
364  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
365  if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
366  // If parsing failed, clear out any of the current state.
367  parsedIR.clear();
368  asmState = AsmParserState();
369  fallbackResourceMap = FallbackAsmResourceMap();
370  return;
371  }
372 }
373 
374 //===----------------------------------------------------------------------===//
375 // MLIRDocument: Definitions and References
376 //===----------------------------------------------------------------------===//
377 
378 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
379  const lsp::Position &defPos,
380  std::vector<lsp::Location> &locations) {
381  SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
382 
383  // Functor used to check if an SM definition contains the position.
384  auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
385  if (!isDefOrUse(def, posLoc))
386  return false;
387  locations.emplace_back(uri, sourceMgr, def.loc);
388  return true;
389  };
390 
391  // Check all definitions related to operations.
392  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
393  if (contains(op.loc, posLoc))
394  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
395  for (const auto &result : op.resultGroups)
396  if (containsPosition(result.definition))
397  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
398  for (const auto &symUse : op.symbolUses) {
399  if (contains(symUse, posLoc)) {
400  locations.emplace_back(uri, sourceMgr, op.loc);
401  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
402  }
403  }
404  }
405 
406  // Check all definitions related to blocks.
407  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
408  if (containsPosition(block.definition))
409  return;
410  for (const AsmParserState::SMDefinition &arg : block.arguments)
411  if (containsPosition(arg))
412  return;
413  }
414 
415  // Check all alias definitions.
417  asmState.getAttributeAliasDefs()) {
418  if (containsPosition(attr.definition))
419  return;
420  }
421  for (const AsmParserState::TypeAliasDefinition &type :
422  asmState.getTypeAliasDefs()) {
423  if (containsPosition(type.definition))
424  return;
425  }
426 }
427 
428 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
429  const lsp::Position &pos,
430  std::vector<lsp::Location> &references) {
431  // Functor used to append all of the definitions/uses of the given SM
432  // definition to the reference list.
433  auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
434  references.emplace_back(uri, sourceMgr, def.loc);
435  for (const SMRange &use : def.uses)
436  references.emplace_back(uri, sourceMgr, use);
437  };
438 
439  SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
440 
441  // Check all definitions related to operations.
442  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
443  if (contains(op.loc, posLoc)) {
444  for (const auto &result : op.resultGroups)
445  appendSMDef(result.definition);
446  for (const auto &symUse : op.symbolUses)
447  if (contains(symUse, posLoc))
448  references.emplace_back(uri, sourceMgr, symUse);
449  return;
450  }
451  for (const auto &result : op.resultGroups)
452  if (isDefOrUse(result.definition, posLoc))
453  return appendSMDef(result.definition);
454  for (const auto &symUse : op.symbolUses) {
455  if (!contains(symUse, posLoc))
456  continue;
457  for (const auto &symUse : op.symbolUses)
458  references.emplace_back(uri, sourceMgr, symUse);
459  return;
460  }
461  }
462 
463  // Check all definitions related to blocks.
464  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
465  if (isDefOrUse(block.definition, posLoc))
466  return appendSMDef(block.definition);
467 
468  for (const AsmParserState::SMDefinition &arg : block.arguments)
469  if (isDefOrUse(arg, posLoc))
470  return appendSMDef(arg);
471  }
472 
473  // Check all alias definitions.
475  asmState.getAttributeAliasDefs()) {
476  if (isDefOrUse(attr.definition, posLoc))
477  return appendSMDef(attr.definition);
478  }
479  for (const AsmParserState::TypeAliasDefinition &type :
480  asmState.getTypeAliasDefs()) {
481  if (isDefOrUse(type.definition, posLoc))
482  return appendSMDef(type.definition);
483  }
484 }
485 
486 //===----------------------------------------------------------------------===//
487 // MLIRDocument: Hover
488 //===----------------------------------------------------------------------===//
489 
490 std::optional<lsp::Hover>
491 MLIRDocument::findHover(const lsp::URIForFile &uri,
492  const lsp::Position &hoverPos) {
493  SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
494  SMRange hoverRange;
495 
496  // Check for Hovers on operations and results.
497  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
498  // Check if the position points at this operation.
499  if (contains(op.loc, posLoc))
500  return buildHoverForOperation(op.loc, op);
501 
502  // Check if the position points at the symbol name.
503  for (auto &use : op.symbolUses)
504  if (contains(use, posLoc))
505  return buildHoverForOperation(use, op);
506 
507  // Check if the position points at a result group.
508  for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
509  const auto &result = op.resultGroups[i];
510  if (!isDefOrUse(result.definition, posLoc, &hoverRange))
511  continue;
512 
513  // Get the range of results covered by the over position.
514  unsigned resultStart = result.startIndex;
515  unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
516  : op.resultGroups[i + 1].startIndex;
517  return buildHoverForOperationResult(hoverRange, op.op, resultStart,
518  resultEnd, posLoc);
519  }
520  }
521 
522  // Check to see if the hover is over a block argument.
523  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
524  if (isDefOrUse(block.definition, posLoc, &hoverRange))
525  return buildHoverForBlock(hoverRange, block);
526 
527  for (const auto &arg : llvm::enumerate(block.arguments)) {
528  if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
529  continue;
530 
531  return buildHoverForBlockArgument(
532  hoverRange, block.block->getArgument(arg.index()), block);
533  }
534  }
535 
536  // Check to see if the hover is over an alias.
538  asmState.getAttributeAliasDefs()) {
539  if (isDefOrUse(attr.definition, posLoc, &hoverRange))
540  return buildHoverForAttributeAlias(hoverRange, attr);
541  }
542  for (const AsmParserState::TypeAliasDefinition &type :
543  asmState.getTypeAliasDefs()) {
544  if (isDefOrUse(type.definition, posLoc, &hoverRange))
545  return buildHoverForTypeAlias(hoverRange, type);
546  }
547 
548  return std::nullopt;
549 }
550 
551 std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
552  SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
553  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
554  llvm::raw_string_ostream os(hover.contents.value);
555 
556  // Add the operation name to the hover.
557  os << "\"" << op.op->getName() << "\"";
558  if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
559  os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
560  os << "\n\n";
561 
562  os << "Generic Form:\n\n```mlir\n";
563 
564  op.op->print(os, OpPrintingFlags()
565  .printGenericOpForm()
566  .elideLargeElementsAttrs()
567  .skipRegions());
568  os << "\n```\n";
569 
570  return hover;
571 }
572 
573 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
574  Operation *op,
575  unsigned resultStart,
576  unsigned resultEnd,
577  SMLoc posLoc) {
578  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
579  llvm::raw_string_ostream os(hover.contents.value);
580 
581  // Add the parent operation name to the hover.
582  os << "Operation: \"" << op->getName() << "\"\n\n";
583 
584  // Check to see if the location points to a specific result within the
585  // group.
586  if (std::optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
587  if ((resultStart + *resultNumber) < resultEnd) {
588  resultStart += *resultNumber;
589  resultEnd = resultStart + 1;
590  }
591  }
592 
593  // Add the range of results and their types to the hover info.
594  if ((resultStart + 1) == resultEnd) {
595  os << "Result #" << resultStart << "\n\n"
596  << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
597  } else {
598  os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
599  << "Types: ";
600  llvm::interleaveComma(
601  op->getResults().slice(resultStart, resultEnd), os,
602  [&](Value result) { os << "`" << result.getType() << "`"; });
603  }
604 
605  return hover;
606 }
607 
609 MLIRDocument::buildHoverForBlock(SMRange hoverRange,
610  const AsmParserState::BlockDefinition &block) {
611  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
612  llvm::raw_string_ostream os(hover.contents.value);
613 
614  // Print the given block to the hover output stream.
615  auto printBlockToHover = [&](Block *newBlock) {
616  if (const auto *def = asmState.getBlockDef(newBlock))
617  printDefBlockName(os, *def);
618  else
619  printDefBlockName(os, newBlock);
620  };
621 
622  // Display the parent operation, block number, predecessors, and successors.
623  os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
624  << "Block #" << getBlockNumber(block.block) << "\n\n";
625  if (!block.block->hasNoPredecessors()) {
626  os << "Predecessors: ";
627  llvm::interleaveComma(block.block->getPredecessors(), os,
628  printBlockToHover);
629  os << "\n\n";
630  }
631  if (!block.block->hasNoSuccessors()) {
632  os << "Successors: ";
633  llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
634  os << "\n\n";
635  }
636 
637  return hover;
638 }
639 
640 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
641  SMRange hoverRange, BlockArgument arg,
642  const AsmParserState::BlockDefinition &block) {
643  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
644  llvm::raw_string_ostream os(hover.contents.value);
645 
646  // Display the parent operation, block, the argument number, and the type.
647  os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
648  << "Block: ";
649  printDefBlockName(os, block);
650  os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
651  << "Type: `" << arg.getType() << "`\n\n";
652 
653  return hover;
654 }
655 
656 lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
657  SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr) {
658  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
659  llvm::raw_string_ostream os(hover.contents.value);
660 
661  os << "Attribute Alias: \"" << attr.name << "\n\n";
662  os << "Value: ```mlir\n" << attr.value << "\n```\n\n";
663 
664  return hover;
665 }
666 
667 lsp::Hover MLIRDocument::buildHoverForTypeAlias(
668  SMRange hoverRange, const AsmParserState::TypeAliasDefinition &type) {
669  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
670  llvm::raw_string_ostream os(hover.contents.value);
671 
672  os << "Type Alias: \"" << type.name << "\n\n";
673  os << "Value: ```mlir\n" << type.value << "\n```\n\n";
674 
675  return hover;
676 }
677 
678 //===----------------------------------------------------------------------===//
679 // MLIRDocument: Document Symbols
680 //===----------------------------------------------------------------------===//
681 
682 void MLIRDocument::findDocumentSymbols(
683  std::vector<lsp::DocumentSymbol> &symbols) {
684  for (Operation &op : parsedIR)
685  findDocumentSymbols(&op, symbols);
686 }
687 
688 void MLIRDocument::findDocumentSymbols(
689  Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
690  std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
691 
692  // Check for the source information of this operation.
693  if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
694  // If this operation defines a symbol, record it.
695  if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
696  symbols.emplace_back(symbol.getName(),
697  isa<FunctionOpInterface>(op)
700  lsp::Range(sourceMgr, def->scopeLoc),
701  lsp::Range(sourceMgr, def->loc));
702  childSymbols = &symbols.back().children;
703 
704  } else if (op->hasTrait<OpTrait::SymbolTable>()) {
705  // Otherwise, if this is a symbol table push an anonymous document symbol.
706  symbols.emplace_back("<" + op->getName().getStringRef() + ">",
708  lsp::Range(sourceMgr, def->scopeLoc),
709  lsp::Range(sourceMgr, def->loc));
710  childSymbols = &symbols.back().children;
711  }
712  }
713 
714  // Recurse into the regions of this operation.
715  if (!op->getNumRegions())
716  return;
717  for (Region &region : op->getRegions())
718  for (Operation &childOp : region.getOps())
719  findDocumentSymbols(&childOp, *childSymbols);
720 }
721 
722 //===----------------------------------------------------------------------===//
723 // MLIRDocument: Code Completion
724 //===----------------------------------------------------------------------===//
725 
726 namespace {
727 class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
728 public:
729  LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
730  MLIRContext *ctx)
731  : AsmParserCodeCompleteContext(completeLoc),
732  completionList(completionList), ctx(ctx) {}
733 
734  /// Signal code completion for a dialect name, with an optional prefix.
735  void completeDialectName(StringRef prefix) final {
736  for (StringRef dialect : ctx->getAvailableDialects()) {
737  lsp::CompletionItem item(prefix + dialect,
739  /*sortText=*/"3");
740  item.detail = "dialect";
741  completionList.items.emplace_back(item);
742  }
743  }
745 
746  /// Signal code completion for an operation name within the given dialect.
747  void completeOperationName(StringRef dialectName) final {
748  Dialect *dialect = ctx->getOrLoadDialect(dialectName);
749  if (!dialect)
750  return;
751 
752  for (const auto &op : ctx->getRegisteredOperations()) {
753  if (&op.getDialect() != dialect)
754  continue;
755 
756  lsp::CompletionItem item(
757  op.getStringRef().drop_front(dialectName.size() + 1),
759  /*sortText=*/"1");
760  item.detail = "operation";
761  completionList.items.emplace_back(item);
762  }
763  }
764 
765  /// Append the given SSA value as a code completion result for SSA value
766  /// completions.
767  void appendSSAValueCompletion(StringRef name, std::string typeData) final {
768  // Check if we need to insert the `%` or not.
769  bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';
770 
772  if (stripPrefix)
773  item.insertText = name.drop_front(1).str();
774  item.detail = std::move(typeData);
775  completionList.items.emplace_back(item);
776  }
777 
778  /// Append the given block as a code completion result for block name
779  /// completions.
780  void appendBlockCompletion(StringRef name) final {
781  // Check if we need to insert the `^` or not.
782  bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';
783 
785  if (stripPrefix)
786  item.insertText = name.drop_front(1).str();
787  completionList.items.emplace_back(item);
788  }
789 
790  /// Signal a completion for the given expected token.
791  void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
792  for (StringRef token : tokens) {
794  /*sortText=*/"0");
795  item.detail = optional ? "optional" : "";
796  completionList.items.emplace_back(item);
797  }
798  }
799 
800  /// Signal a completion for an attribute.
801  void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
802  appendSimpleCompletions({"affine_set", "affine_map", "dense",
803  "dense_resource", "false", "loc", "sparse", "true",
804  "unit"},
806  /*sortText=*/"1");
807 
808  completeDialectName("#");
809  completeAliases(aliases, "#");
810  }
811  void completeDialectAttributeOrAlias(
812  const llvm::StringMap<Attribute> &aliases) override {
813  completeDialectName();
814  completeAliases(aliases);
815  }
816 
817  /// Signal a completion for a type.
818  void completeType(const llvm::StringMap<Type> &aliases) override {
819  // Handle the various builtin types.
820  appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
821  "bf16", "f16", "f32", "f64", "f80", "f128",
822  "index", "none"},
824  /*sortText=*/"1");
825 
826  // Handle the builtin integer types.
827  for (StringRef type : {"i", "si", "ui"}) {
829  /*sortText=*/"1");
830  item.insertText = type.str();
831  completionList.items.emplace_back(item);
832  }
833 
834  // Insert completions for dialect types and aliases.
835  completeDialectName("!");
836  completeAliases(aliases, "!");
837  }
838  void
839  completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
840  completeDialectName();
841  completeAliases(aliases);
842  }
843 
844  /// Add completion results for the given set of aliases.
845  template <typename T>
846  void completeAliases(const llvm::StringMap<T> &aliases,
847  StringRef prefix = "") {
848  for (const auto &alias : aliases) {
849  lsp::CompletionItem item(prefix + alias.getKey(),
851  /*sortText=*/"2");
852  llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
853  completionList.items.emplace_back(item);
854  }
855  }
856 
857  /// Add a set of simple completions that all have the same kind.
858  void appendSimpleCompletions(ArrayRef<StringRef> completions,
860  StringRef sortText = "") {
861  for (StringRef completion : completions)
862  completionList.items.emplace_back(completion, kind, sortText);
863  }
864 
865 private:
866  lsp::CompletionList &completionList;
867  MLIRContext *ctx;
868 };
869 } // namespace
870 
872 MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
873  const lsp::Position &completePos,
874  const DialectRegistry &registry) {
875  SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
876  if (!posLoc.isValid())
877  return lsp::CompletionList();
878 
879  // To perform code completion, we run another parse of the module with the
880  // code completion context provided.
881  MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
882  tmpContext.allowUnregisteredDialects();
883  lsp::CompletionList completionList;
884  LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
885  &tmpContext);
886 
887  Block tmpIR;
888  AsmParserState tmpState;
889  (void)parseAsmSourceFile(sourceMgr, &tmpIR, &tmpContext, &tmpState,
890  &lspCompleteContext);
891  return completionList;
892 }
893 
894 //===----------------------------------------------------------------------===//
895 // MLIRDocument: Code Action
896 //===----------------------------------------------------------------------===//
897 
898 void MLIRDocument::getCodeActionForDiagnostic(
899  const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
900  StringRef message, std::vector<lsp::TextEdit> &edits) {
901  // Ignore diagnostics that print the current operation. These are always
902  // enabled for the language server, but not generally during normal
903  // parsing/verification.
904  if (message.starts_with("see current operation: "))
905  return;
906 
907  // Get the start of the line containing the diagnostic.
908  const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
909  const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
910  if (!lineStart)
911  return;
912  StringRef line(lineStart, pos.character);
913 
914  // Add a text edit for adding an expected-* diagnostic check for this
915  // diagnostic.
916  lsp::TextEdit edit;
917  edit.range = lsp::Range(lsp::Position(pos.line, 0));
918 
919  // Use the indent of the current line for the expected-* diagnostic.
920  size_t indent = line.find_first_not_of(" ");
921  if (indent == StringRef::npos)
922  indent = line.size();
923 
924  edit.newText.append(indent, ' ');
925  llvm::raw_string_ostream(edit.newText)
926  << "// expected-" << severity << " @below {{" << message << "}}\n";
927  edits.emplace_back(std::move(edit));
928 }
929 
930 //===----------------------------------------------------------------------===//
931 // MLIRDocument: Bytecode
932 //===----------------------------------------------------------------------===//
933 
935 MLIRDocument::convertToBytecode() {
936  // TODO: We currently require a single top-level operation, but this could
937  // conceptually be relaxed.
938  if (!llvm::hasSingleElement(parsedIR)) {
939  if (parsedIR.empty()) {
940  return llvm::make_error<lsp::LSPError>(
941  "expected a single and valid top-level operation, please ensure "
942  "there are no errors",
944  }
945  return llvm::make_error<lsp::LSPError>(
946  "expected a single top-level operation", lsp::ErrorCode::RequestFailed);
947  }
948 
950  {
951  BytecodeWriterConfig writerConfig(fallbackResourceMap);
952 
953  std::string rawBytecodeBuffer;
954  llvm::raw_string_ostream os(rawBytecodeBuffer);
955  // No desired bytecode version set, so no need to check for error.
956  (void)writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
957  result.output = llvm::encodeBase64(rawBytecodeBuffer);
958  }
959  return result;
960 }
961 
962 //===----------------------------------------------------------------------===//
963 // MLIRTextFileChunk
964 //===----------------------------------------------------------------------===//
965 
966 namespace {
967 /// This class represents a single chunk of an MLIR text file.
968 struct MLIRTextFileChunk {
969  MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
970  const lsp::URIForFile &uri, StringRef contents,
971  std::vector<lsp::Diagnostic> &diagnostics)
972  : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
973 
974  /// Adjust the line number of the given range to anchor at the beginning of
975  /// the file, instead of the beginning of this chunk.
976  void adjustLocForChunkOffset(lsp::Range &range) {
977  adjustLocForChunkOffset(range.start);
978  adjustLocForChunkOffset(range.end);
979  }
980  /// Adjust the line number of the given position to anchor at the beginning of
981  /// the file, instead of the beginning of this chunk.
982  void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
983 
984  /// The line offset of this chunk from the beginning of the file.
985  uint64_t lineOffset;
986  /// The document referred to by this chunk.
987  MLIRDocument document;
988 };
989 } // namespace
990 
991 //===----------------------------------------------------------------------===//
992 // MLIRTextFile
993 //===----------------------------------------------------------------------===//
994 
995 namespace {
996 /// This class represents a text file containing one or more MLIR documents.
997 class MLIRTextFile {
998 public:
999  MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1000  int64_t version, DialectRegistry &registry,
1001  std::vector<lsp::Diagnostic> &diagnostics);
1002 
1003  /// Return the current version of this text file.
1004  int64_t getVersion() const { return version; }
1005 
1006  //===--------------------------------------------------------------------===//
1007  // LSP Queries
1008  //===--------------------------------------------------------------------===//
1009 
1010  void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1011  std::vector<lsp::Location> &locations);
1012  void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1013  std::vector<lsp::Location> &references);
1014  std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1015  lsp::Position hoverPos);
1016  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1017  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1018  lsp::Position completePos);
1019  void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos,
1020  const lsp::CodeActionContext &context,
1021  std::vector<lsp::CodeAction> &actions);
1023 
1024 private:
1025  /// Find the MLIR document that contains the given position, and update the
1026  /// position to be anchored at the start of the found chunk instead of the
1027  /// beginning of the file.
1028  MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
1029 
1030  /// The context used to hold the state contained by the parsed document.
1031  MLIRContext context;
1032 
1033  /// The full string contents of the file.
1034  std::string contents;
1035 
1036  /// The version of this file.
1037  int64_t version;
1038 
1039  /// The number of lines in the file.
1040  int64_t totalNumLines = 0;
1041 
1042  /// The chunks of this file. The order of these chunks is the order in which
1043  /// they appear in the text file.
1044  std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
1045 };
1046 } // namespace
1047 
1048 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1049  int64_t version, DialectRegistry &registry,
1050  std::vector<lsp::Diagnostic> &diagnostics)
1051  : context(registry, MLIRContext::Threading::DISABLED),
1052  contents(fileContents.str()), version(version) {
1053  context.allowUnregisteredDialects();
1054 
1055  // Split the file into separate MLIR documents.
1056  SmallVector<StringRef, 8> subContents;
1057  StringRef(contents).split(subContents, kDefaultSplitMarker);
1058  chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
1059  context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
1060 
1061  uint64_t lineOffset = subContents.front().count('\n');
1062  for (StringRef docContents : llvm::drop_begin(subContents)) {
1063  unsigned currentNumDiags = diagnostics.size();
1064  auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
1065  docContents, diagnostics);
1066  lineOffset += docContents.count('\n');
1067 
1068  // Adjust locations used in diagnostics to account for the offset from the
1069  // beginning of the file.
1070  for (lsp::Diagnostic &diag :
1071  llvm::drop_begin(diagnostics, currentNumDiags)) {
1072  chunk->adjustLocForChunkOffset(diag.range);
1073 
1074  if (!diag.relatedInformation)
1075  continue;
1076  for (auto &it : *diag.relatedInformation)
1077  if (it.location.uri == uri)
1078  chunk->adjustLocForChunkOffset(it.location.range);
1079  }
1080  chunks.emplace_back(std::move(chunk));
1081  }
1082  totalNumLines = lineOffset;
1083 }
1084 
1085 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
1086  lsp::Position defPos,
1087  std::vector<lsp::Location> &locations) {
1088  MLIRTextFileChunk &chunk = getChunkFor(defPos);
1089  chunk.document.getLocationsOf(uri, defPos, locations);
1090 
1091  // Adjust any locations within this file for the offset of this chunk.
1092  if (chunk.lineOffset == 0)
1093  return;
1094  for (lsp::Location &loc : locations)
1095  if (loc.uri == uri)
1096  chunk.adjustLocForChunkOffset(loc.range);
1097 }
1098 
1099 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
1100  lsp::Position pos,
1101  std::vector<lsp::Location> &references) {
1102  MLIRTextFileChunk &chunk = getChunkFor(pos);
1103  chunk.document.findReferencesOf(uri, pos, references);
1104 
1105  // Adjust any locations within this file for the offset of this chunk.
1106  if (chunk.lineOffset == 0)
1107  return;
1108  for (lsp::Location &loc : references)
1109  if (loc.uri == uri)
1110  chunk.adjustLocForChunkOffset(loc.range);
1111 }
1112 
1113 std::optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
1114  lsp::Position hoverPos) {
1115  MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
1116  std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1117 
1118  // Adjust any locations within this file for the offset of this chunk.
1119  if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1120  chunk.adjustLocForChunkOffset(*hoverInfo->range);
1121  return hoverInfo;
1122 }
1123 
1124 void MLIRTextFile::findDocumentSymbols(
1125  std::vector<lsp::DocumentSymbol> &symbols) {
1126  if (chunks.size() == 1)
1127  return chunks.front()->document.findDocumentSymbols(symbols);
1128 
1129  // If there are multiple chunks in this file, we create top-level symbols for
1130  // each chunk.
1131  for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1132  MLIRTextFileChunk &chunk = *chunks[i];
1133  lsp::Position startPos(chunk.lineOffset);
1134  lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1135  : chunks[i + 1]->lineOffset);
1136  lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1137  lsp::SymbolKind::Namespace,
1138  /*range=*/lsp::Range(startPos, endPos),
1139  /*selectionRange=*/lsp::Range(startPos));
1140  chunk.document.findDocumentSymbols(symbol.children);
1141 
1142  // Fixup the locations of document symbols within this chunk.
1143  if (i != 0) {
1145  for (lsp::DocumentSymbol &childSymbol : symbol.children)
1146  symbolsToFix.push_back(&childSymbol);
1147 
1148  while (!symbolsToFix.empty()) {
1149  lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1150  chunk.adjustLocForChunkOffset(symbol->range);
1151  chunk.adjustLocForChunkOffset(symbol->selectionRange);
1152 
1153  for (lsp::DocumentSymbol &childSymbol : symbol->children)
1154  symbolsToFix.push_back(&childSymbol);
1155  }
1156  }
1157 
1158  // Push the symbol for this chunk.
1159  symbols.emplace_back(std::move(symbol));
1160  }
1161 }
1162 
1163 lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1164  lsp::Position completePos) {
1165  MLIRTextFileChunk &chunk = getChunkFor(completePos);
1166  lsp::CompletionList completionList = chunk.document.getCodeCompletion(
1167  uri, completePos, context.getDialectRegistry());
1168 
1169  // Adjust any completion locations.
1170  for (lsp::CompletionItem &item : completionList.items) {
1171  if (item.textEdit)
1172  chunk.adjustLocForChunkOffset(item.textEdit->range);
1173  for (lsp::TextEdit &edit : item.additionalTextEdits)
1174  chunk.adjustLocForChunkOffset(edit.range);
1175  }
1176  return completionList;
1177 }
1178 
1179 void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
1180  const lsp::Range &pos,
1181  const lsp::CodeActionContext &context,
1182  std::vector<lsp::CodeAction> &actions) {
1183  // Create actions for any diagnostics in this file.
1184  for (auto &diag : context.diagnostics) {
1185  if (diag.source != "mlir")
1186  continue;
1187  lsp::Position diagPos = diag.range.start;
1188  MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1189 
1190  // Add a new code action that inserts a "expected" diagnostic check.
1191  lsp::CodeAction action;
1192  action.title = "Add expected-* diagnostic checks";
1193  action.kind = lsp::CodeAction::kQuickFix.str();
1194 
1195  StringRef severity;
1196  switch (diag.severity) {
1198  severity = "error";
1199  break;
1200  case lsp::DiagnosticSeverity::Warning:
1201  severity = "warning";
1202  break;
1203  default:
1204  continue;
1205  }
1206 
1207  // Get edits for the diagnostic.
1208  std::vector<lsp::TextEdit> edits;
1209  chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1210  diag.message, edits);
1211 
1212  // Walk the related diagnostics, this is how we encode notes.
1213  if (diag.relatedInformation) {
1214  for (auto &noteDiag : *diag.relatedInformation) {
1215  if (noteDiag.location.uri != uri)
1216  continue;
1217  diagPos = noteDiag.location.range.start;
1218  diagPos.line -= chunk.lineOffset;
1219  chunk.document.getCodeActionForDiagnostic(uri, diagPos, "note",
1220  noteDiag.message, edits);
1221  }
1222  }
1223  // Fixup the locations for any edits.
1224  for (lsp::TextEdit &edit : edits)
1225  chunk.adjustLocForChunkOffset(edit.range);
1226 
1227  action.edit.emplace();
1228  action.edit->changes[uri.uri().str()] = std::move(edits);
1229  action.diagnostics = {diag};
1230 
1231  actions.emplace_back(std::move(action));
1232  }
1233 }
1234 
1236 MLIRTextFile::convertToBytecode() {
1237  // Bail out if there is more than one chunk, bytecode wants a single module.
1238  if (chunks.size() != 1) {
1239  return llvm::make_error<lsp::LSPError>(
1240  "unexpected split file, please remove all `// -----`",
1241  lsp::ErrorCode::RequestFailed);
1242  }
1243  return chunks.front()->document.convertToBytecode();
1244 }
1245 
1246 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1247  if (chunks.size() == 1)
1248  return *chunks.front();
1249 
1250  // Search for the first chunk with a greater line offset, the previous chunk
1251  // is the one that contains `pos`.
1252  auto it = llvm::upper_bound(
1253  chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1254  return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1255  });
1256  MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1257  pos.line -= chunk.lineOffset;
1258  return chunk;
1259 }
1260 
1261 //===----------------------------------------------------------------------===//
1262 // MLIRServer::Impl
1263 //===----------------------------------------------------------------------===//
1264 
1267 
1268  /// The registry containing dialects that can be recognized in parsed .mlir
1269  /// files.
1271 
1272  /// The files held by the server, mapped by their URI file name.
1273  llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
1274 };
1275 
1276 //===----------------------------------------------------------------------===//
1277 // MLIRServer
1278 //===----------------------------------------------------------------------===//
1279 
/// Construct the server. All of its state lives behind the `impl` pointer:
/// the `Impl` keeps a reference to `registry` and the files held by the
/// server.
lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
    : impl(std::make_unique<Impl>(registry)) {}
/// Defaulted out of line, presumably so that `Impl` only needs to be a
/// complete type in this file (NOTE(review): confirm against MLIRServer.h).
lsp::MLIRServer::~MLIRServer() = default;
1283 
1285  const URIForFile &uri, StringRef contents, int64_t version,
1286  std::vector<Diagnostic> &diagnostics) {
1287  impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1288  uri, contents, version, impl->registry, diagnostics);
1289 }
1290 
1291 std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
1292  auto it = impl->files.find(uri.file());
1293  if (it == impl->files.end())
1294  return std::nullopt;
1295 
1296  int64_t version = it->second->getVersion();
1297  impl->files.erase(it);
1298  return version;
1299 }
1300 
1302  const Position &defPos,
1303  std::vector<Location> &locations) {
1304  auto fileIt = impl->files.find(uri.file());
1305  if (fileIt != impl->files.end())
1306  fileIt->second->getLocationsOf(uri, defPos, locations);
1307 }
1308 
1310  const Position &pos,
1311  std::vector<Location> &references) {
1312  auto fileIt = impl->files.find(uri.file());
1313  if (fileIt != impl->files.end())
1314  fileIt->second->findReferencesOf(uri, pos, references);
1315 }
1316 
1317 std::optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
1318  const Position &hoverPos) {
1319  auto fileIt = impl->files.find(uri.file());
1320  if (fileIt != impl->files.end())
1321  return fileIt->second->findHover(uri, hoverPos);
1322  return std::nullopt;
1323 }
1324 
1326  const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1327  auto fileIt = impl->files.find(uri.file());
1328  if (fileIt != impl->files.end())
1329  fileIt->second->findDocumentSymbols(symbols);
1330 }
1331 
1334  const Position &completePos) {
1335  auto fileIt = impl->files.find(uri.file());
1336  if (fileIt != impl->files.end())
1337  return fileIt->second->getCodeCompletion(uri, completePos);
1338  return CompletionList();
1339 }
1340 
1342  const CodeActionContext &context,
1343  std::vector<CodeAction> &actions) {
1344  auto fileIt = impl->files.find(uri.file());
1345  if (fileIt != impl->files.end())
1346  fileIt->second->getCodeActions(uri, pos, context, actions);
1347 }
1348 
1351  MLIRContext tempContext(impl->registry);
1352  tempContext.allowUnregisteredDialects();
1353 
1354  // Collect any errors during parsing.
1355  std::string errorMsg;
1356  ScopedDiagnosticHandler diagHandler(
1357  &tempContext,
1358  [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
1359 
1360  // Handling for external resources, which we want to propagate up to the user.
1361  FallbackAsmResourceMap fallbackResourceMap;
1362 
1363  // Setup the parser config.
1364  ParserConfig parserConfig(&tempContext, /*verifyAfterParse=*/true,
1365  &fallbackResourceMap);
1366 
1367  // Try to parse the given source file.
1368  Block parsedBlock;
1369  if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
1370  return llvm::make_error<lsp::LSPError>(
1371  "failed to parse bytecode source file: " + errorMsg,
1373  }
1374 
1375  // TODO: We currently expect a single top-level operation, but this could
1376  // conceptually be relaxed.
1377  if (!llvm::hasSingleElement(parsedBlock)) {
1378  return llvm::make_error<lsp::LSPError>(
1379  "expected bytecode to contain a single top-level operation",
1381  }
1382 
1383  // Print the module to a buffer.
1385  {
1386  // Extract the top-level op so that aliases get printed.
1387  // FIXME: We should be able to enable aliases without having to do this!
1388  OwningOpRef<Operation *> topOp = &parsedBlock.front();
1389  topOp->remove();
1390 
1391  AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
1392  /*locationMap=*/nullptr, &fallbackResourceMap);
1393 
1394  llvm::raw_string_ostream os(result.output);
1395  topOp->print(os, state);
1396  }
1397  return std::move(result);
1398 }
1399 
1402  auto fileIt = impl->files.find(uri.file());
1403  if (fileIt == impl->files.end()) {
1404  return llvm::make_error<lsp::LSPError>(
1405  "language server does not contain an entry for this source file",
1407  }
1408  return fileIt->second->convertToBytecode();
1409 }
static std::string toString(bytecode::Section::ID sectionID)
Stringify the given section ID.
static void collectLocationsFromLoc(Location loc, std::vector< lsp::Location > &locations, const lsp::URIForFile &uri)
Collect all of the locations from the given MLIR location that are not contained within the given URI...
Definition: MLIRServer.cpp:91
static unsigned getBlockNumber(Block *block)
Given a block, return its position in its parent region.
Definition: MLIRServer.cpp:176
static std::optional< lsp::Location > getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc)
Returns a language server location from the given MLIR file location.
Definition: MLIRServer.cpp:37
static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, SMRange *overlappedRange=nullptr)
Returns true if the given location is contained by the definition or one of the uses of the given SMD...
Definition: MLIRServer.cpp:119
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
Definition: MLIRServer.cpp:111
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, Diagnostic &diag, const lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
Definition: MLIRServer.cpp:199
static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc={})
Given a block and source location, print the source name of the block to the given output stream.
Definition: MLIRServer.cpp:182
static std::optional< unsigned > getResultNumberFromLoc(SMLoc loc)
Given a location pointing to a result, return the result number it refers to or std::nullopt if it re...
Definition: MLIRServer.cpp:141
static SMRange convertTokenLocToRange(SMLoc loc)
Returns the range of a lexical token given an SMLoc corresponding to the start of a token location.
Definition: MLIRServer.cpp:31
static std::optional< StringRef > getTextFromRange(SMRange range)
Given a source location range, return the text covered by the given range.
Definition: MLIRServer.cpp:168
@ Error
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class provides an abstract interface into the parser for hooking in code completion events.
Definition: CodeComplete.h:24
This class represents state from a parsed MLIR textual format string.
This class provides management for the lifetime of the state used when printing the IR.
Definition: AsmState.h:533
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition: Block.h:242
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
SuccessorRange getSuccessors()
Definition: Block.h:264
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:234
Operation & front()
Definition: Block.h:150
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:239
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class contains the configuration used for the bytecode writer.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
A fallback map containing external resources not explicitly handled by another parser/printer.
Definition: AsmState.h:412
WalkResult walk(function_ref< WalkResult(Location)> walkFn)
Walk all of the locations nested under, and including, the current.
Definition: Location.cpp:40
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
void allowUnregisteredDialects(bool allow=true)
Enables creating operations in unregistered dialects.
Set of flags used to control the behavior of the various IR print methods (e.g.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition: OwningOpRef.h:29
This class represents a configuration for the MLIR assembly parser.
Definition: AsmState.h:460
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator begin()
Definition: Region.h:55
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Definition: Diagnostics.h:516
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static WalkResult advance()
Definition: Visitors.h:52
static WalkResult interrupt()
Definition: Visitors.h:51
static void error(const char *fmt, Ts &&...vals)
Definition: Logging.h:42
void addOrUpdateDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add or update the document, with the provided version, at the given URI.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
llvm::Expected< MLIRConvertBytecodeResult > convertFromBytecode(const URIForFile &uri)
Convert the given bytecode file to the textual format.
llvm::Expected< MLIRConvertBytecodeResult > convertToBytecode(const URIForFile &uri)
Convert the given textual file to the bytecode format.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
void getCodeActions(const URIForFile &uri, const Range &pos, const CodeActionContext &context, std::vector< CodeAction > &actions)
Get the set of code actions within the file.
URI in "file" scheme for a file.
Definition: Protocol.h:100
StringRef uri() const
Returns the original uri of the file.
Definition: Protocol.h:116
static llvm::Expected< URIForFile > fromFile(StringRef absoluteFilepath, StringRef scheme="file")
Try to build a URIForFile from the given absolute file path and optional scheme.
Definition: Protocol.cpp:233
StringRef scheme() const
Return the scheme of the uri.
Definition: Protocol.cpp:242
StringRef file() const
Returns the absolute path to the file.
Definition: Protocol.h:113
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
SMRange convertTokenLocToRange(SMLoc loc, StringRef identifierChars="")
Returns the range of a lexical token given an SMLoc corresponding to the start of a token location.
CompletionItemKind
The kind of a completion entry.
Definition: Protocol.h:758
Include the generated interface declarations.
const char *const kDefaultSplitMarker
LogicalResult parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, AsmParserState *asmState=nullptr, AsmParserCodeCompleteContext *codeCompleteContext=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition: Parser.cpp:2778
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition: Parser.cpp:20
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config={})
Write the bytecode for the given operation to the provided output stream.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Impl(DialectRegistry &registry)
DialectRegistry & registry
The registry containing dialects that can be recognized in parsed .mlir files.
llvm::StringMap< std::unique_ptr< MLIRTextFile > > files
The files held by the server, mapped by their URI file name.
This class represents the information for an attribute alias definition within the input file.
StringRef name
The name of the attribute alias.
Attribute value
The value of the alias.
This class represents the information for a block definition within the input file.
Block * block
The block representing this definition.
SMDefinition definition
The source location for the block, i.e.
This class represents the information for an operation definition within an input file.
This class represents a definition within the source manager, containing its defining location and l...
SmallVector< SMRange > uses
The source location of all uses of the definition.
SMRange loc
The source location of the definition.
This class represents the information for type definition within the input file.
StringRef name
The name of the type alias.
std::vector< Diagnostic > diagnostics
An array of diagnostics known on the client side overlapping the range provided to the textDocument/c...
Definition: Protocol.h:1141
A code action represents a change that can be performed in code, e.g.
Definition: Protocol.h:1199
std::optional< WorkspaceEdit > edit
The workspace edit this code action performs.
Definition: Protocol.h:1221
std::optional< std::string > kind
The kind of the code action.
Definition: Protocol.h:1205
std::optional< std::vector< Diagnostic > > diagnostics
The diagnostics that this code action resolves.
Definition: Protocol.h:1211
std::string title
A short, human-readable, title for this code action.
Definition: Protocol.h:1201
std::optional< TextEdit > textEdit
An edit which is applied to a document when selecting this completion.
Definition: Protocol.h:866
std::vector< TextEdit > additionalTextEdits
An optional array of additional text edits that are applied when selecting this completion.
Definition: Protocol.h:871
Represents a collection of completion items to be presented in the editor.
Definition: Protocol.h:887
std::vector< CompletionItem > items
The completion items.
Definition: Protocol.h:893
std::string source
A human-readable string describing the source of this diagnostic, e.g.
Definition: Protocol.h:690
DiagnosticSeverity severity
The diagnostic's severity.
Definition: Protocol.h:686
Range range
The source range where the message applies.
Definition: Protocol.h:682
std::optional< std::vector< DiagnosticRelatedInformation > > relatedInformation
An array of related diagnostic information, e.g.
Definition: Protocol.h:697
std::string message
The diagnostic's message.
Definition: Protocol.h:693
std::optional< std::string > category
The diagnostic's category.
Definition: Protocol.h:703
Represents programming constructs like variables, classes, interfaces etc.
Definition: Protocol.h:596
Range range
The range enclosing this symbol not including leading/trailing whitespace but everything else like co...
Definition: Protocol.h:617
Range selectionRange
The range that should be selected and revealed when this symbol is being picked, e....
Definition: Protocol.h:621
std::vector< DocumentSymbol > children
Children of this symbol, e.g. properties of a class.
Definition: Protocol.h:624
URIForFile uri
The text document's URI.
Definition: Protocol.h:394
This class represents the result of converting between MLIR's bytecode and textual format.
Definition: Protocol.h:48
std::string output
The resultant output of the conversion.
Definition: Protocol.h:50
int line
Line position in a document (zero-based).
Definition: Protocol.h:293
int character
Character offset on a line in a document (zero-based).
Definition: Protocol.h:296
SMLoc getAsSMLoc(llvm::SourceMgr &mgr) const
Convert this position into a source location in the main file of the given source manager.
Definition: Protocol.h:316
Position end
The range's end position.
Definition: Protocol.h:345
Position start
The range's start position.
Definition: Protocol.h:342
std::string newText
The string to be inserted.
Definition: Protocol.h:741
Range range
The range of the text document to be manipulated.
Definition: Protocol.h:737