MLIR  18.0.0git
MLIRServer.cpp
Go to the documentation of this file.
1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MLIRServer.h"
10 #include "Protocol.h"
15 #include "mlir/IR/Operation.h"
17 #include "mlir/Parser/Parser.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/Support/Base64.h"
22 #include "llvm/Support/SourceMgr.h"
23 #include <optional>
24 
25 using namespace mlir;
26 
27 /// Returns the range of a lexical token given a SMLoc corresponding to the
28 /// start of an token location. The range is computed heuristically, and
29 /// supports identifier-like tokens, strings, etc.
30 static SMRange convertTokenLocToRange(SMLoc loc) {
31  return lsp::convertTokenLocToRange(loc, "$-.");
32 }
33 
34 /// Returns a language server location from the given MLIR file location.
35 /// `uriScheme` is the scheme to use when building new uris.
36 static std::optional<lsp::Location> getLocationFromLoc(StringRef uriScheme,
37  FileLineColLoc loc) {
39  lsp::URIForFile::fromFile(loc.getFilename(), uriScheme);
40  if (!sourceURI) {
41  lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
42  loc.getFilename(),
43  llvm::toString(sourceURI.takeError()));
44  return std::nullopt;
45  }
46 
47  lsp::Position position;
48  position.line = loc.getLine() - 1;
49  position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
50  return lsp::Location{*sourceURI, lsp::Range(position)};
51 }
52 
53 /// Returns a language server location from the given MLIR location, or
54 /// std::nullopt if one couldn't be created. `uriScheme` is the scheme to use
55 /// when building new uris. `uri` is an optional additional filter that, when
56 /// present, is used to filter sub locations that do not share the same uri.
57 static std::optional<lsp::Location>
58 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
59  StringRef uriScheme, const lsp::URIForFile *uri = nullptr) {
60  std::optional<lsp::Location> location;
61  loc->walk([&](Location nestedLoc) {
62  FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
63  if (!fileLoc)
64  return WalkResult::advance();
65 
66  std::optional<lsp::Location> sourceLoc =
67  getLocationFromLoc(uriScheme, fileLoc);
68  if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
69  location = *sourceLoc;
70  SMLoc loc = sourceMgr.FindLocForLineAndColumn(
71  sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
72 
73  // Use range of potential identifier starting at location, else length 1
74  // range.
75  location->range.end.character += 1;
76  if (std::optional<SMRange> range = convertTokenLocToRange(loc)) {
77  auto lineCol = sourceMgr.getLineAndColumn(range->End);
78  location->range.end.character =
79  std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
80  }
81  return WalkResult::interrupt();
82  }
83  return WalkResult::advance();
84  });
85  return location;
86 }
87 
88 /// Collect all of the locations from the given MLIR location that are not
89 /// contained within the given URI.
91  std::vector<lsp::Location> &locations,
92  const lsp::URIForFile &uri) {
93  SetVector<Location> visitedLocs;
94  loc->walk([&](Location nestedLoc) {
95  FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
96  if (!fileLoc || !visitedLocs.insert(nestedLoc))
97  return WalkResult::advance();
98 
99  std::optional<lsp::Location> sourceLoc =
100  getLocationFromLoc(uri.scheme(), fileLoc);
101  if (sourceLoc && sourceLoc->uri != uri)
102  locations.push_back(*sourceLoc);
103  return WalkResult::advance();
104  });
105 }
106 
107 /// Returns true if the given range contains the given source location. Note
108 /// that this has slightly different behavior than SMRange because it is
109 /// inclusive of the end location.
110 static bool contains(SMRange range, SMLoc loc) {
111  return range.Start.getPointer() <= loc.getPointer() &&
112  loc.getPointer() <= range.End.getPointer();
113 }
114 
115 /// Returns true if the given location is contained by the definition or one of
116 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
117 /// the range within `def` that the provided `loc` overlapped with.
118 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
119  SMRange *overlappedRange = nullptr) {
120  // Check the main definition.
121  if (contains(def.loc, loc)) {
122  if (overlappedRange)
123  *overlappedRange = def.loc;
124  return true;
125  }
126 
127  // Check the uses.
128  const auto *useIt = llvm::find_if(
129  def.uses, [&](const SMRange &range) { return contains(range, loc); });
130  if (useIt != def.uses.end()) {
131  if (overlappedRange)
132  *overlappedRange = *useIt;
133  return true;
134  }
135  return false;
136 }
137 
138 /// Given a location pointing to a result, return the result number it refers
139 /// to or std::nullopt if it refers to all of the results.
140 static std::optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
141  // Skip all of the identifier characters.
142  auto isIdentifierChar = [](char c) {
143  return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
144  c == '-';
145  };
146  const char *curPtr = loc.getPointer();
147  while (isIdentifierChar(*curPtr))
148  ++curPtr;
149 
150  // Check to see if this location indexes into the result group, via `#`. If it
151  // doesn't, we can't extract a sub result number.
152  if (*curPtr != '#')
153  return std::nullopt;
154 
155  // Compute the sub result number from the remaining portion of the string.
156  const char *numberStart = ++curPtr;
157  while (llvm::isDigit(*curPtr))
158  ++curPtr;
159  StringRef numberStr(numberStart, curPtr - numberStart);
160  unsigned resultNumber = 0;
161  return numberStr.consumeInteger(10, resultNumber) ? std::optional<unsigned>()
162  : resultNumber;
163 }
164 
165 /// Given a source location range, return the text covered by the given range.
166 /// If the range is invalid, returns std::nullopt.
167 static std::optional<StringRef> getTextFromRange(SMRange range) {
168  if (!range.isValid())
169  return std::nullopt;
170  const char *startPtr = range.Start.getPointer();
171  return StringRef(startPtr, range.End.getPointer() - startPtr);
172 }
173 
174 /// Given a block, return its position in its parent region.
175 static unsigned getBlockNumber(Block *block) {
176  return std::distance(block->getParent()->begin(), block->getIterator());
177 }
178 
179 /// Given a block and source location, print the source name of the block to the
180 /// given output stream.
181 static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
182  // Try to extract a name from the source location.
183  std::optional<StringRef> text = getTextFromRange(loc);
184  if (text && text->startswith("^")) {
185  os << *text;
186  return;
187  }
188 
189  // Otherwise, we don't have a name so print the block number.
190  os << "<Block #" << getBlockNumber(block) << ">";
191 }
192 static void printDefBlockName(raw_ostream &os,
193  const AsmParserState::BlockDefinition &def) {
194  printDefBlockName(os, def.block, def.definition.loc);
195 }
196 
197 /// Convert the given MLIR diagnostic to the LSP form.
/// The diagnostic's range is taken from the first nested file location of
/// `diag` that resolves to `uri`; if there is none the range is left at its
/// default. Notes attached to `diag` are converted to related information,
/// falling back to a location that only carries `uri` when a note has no
/// usable file location.
/// NOTE(review): every line of this listing carries a line-number prefix and
/// some numbers are skipped (see the severity switch below), so statements
/// appear to be missing here -- confirm against the original file.
198 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
199  Diagnostic &diag,
200  const lsp::URIForFile &uri) {
201  lsp::Diagnostic lspDiag;
202  lspDiag.source = "mlir";
203 
204  // Note: Right now all of the diagnostics are treated as parser issues, but
205  // some are parser and some are verifier.
206  lspDiag.category = "Parse Error";
207 
208  // Try to grab a file location for this diagnostic.
209  // TODO: For simplicity, we just grab the first one. It may be likely that we
210  // will need a more interesting heuristic here.
211  StringRef uriScheme = uri.scheme();
212  std::optional<lsp::Location> lspLocation =
213  getLocationFromLoc(sourceMgr, diag.getLocation(), uriScheme, &uri);
214  if (lspLocation)
215  lspDiag.range = lspLocation->range;
216 
217  // Convert the severity for the diagnostic.
  // NOTE(review): the `case` labels and the severity assignments of this
  // switch are not present in this listing; only the unreachable message for
  // notes and the `break` statements remain. Restore them from the original
  // file before building.
218  switch (diag.getSeverity()) {
220  llvm_unreachable("expected notes to be handled separately");
223  break;
226  break;
229  break;
230  }
231  lspDiag.message = diag.str();
232 
233  // Attach any notes to the main diagnostic as related information.
234  std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
235  for (Diagnostic &note : diag.getNotes()) {
236  lsp::Location noteLoc;
  // Notes are not filtered by `uri`: any file location is accepted.
237  if (std::optional<lsp::Location> loc =
238  getLocationFromLoc(sourceMgr, note.getLocation(), uriScheme))
239  noteLoc = *loc;
240  else
241  noteLoc.uri = uri;
242  relatedDiags.emplace_back(noteLoc, note.str());
243  }
244  if (!relatedDiags.empty())
245  lspDiag.relatedInformation = std::move(relatedDiags);
246 
247  return lspDiag;
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // MLIRDocument
252 //===----------------------------------------------------------------------===//
253 
254 namespace {
255 /// This class represents all of the information pertaining to a specific MLIR
256 /// document.
257 struct MLIRDocument {
258  MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
259  StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
260  MLIRDocument(const MLIRDocument &) = delete;
261  MLIRDocument &operator=(const MLIRDocument &) = delete;
262 
263  //===--------------------------------------------------------------------===//
264  // Definitions and References
265  //===--------------------------------------------------------------------===//
266 
267  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
268  std::vector<lsp::Location> &locations);
269  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
270  std::vector<lsp::Location> &references);
271 
272  //===--------------------------------------------------------------------===//
273  // Hover
274  //===--------------------------------------------------------------------===//
275 
276  std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
277  const lsp::Position &hoverPos);
278  std::optional<lsp::Hover>
279  buildHoverForOperation(SMRange hoverRange,
281  lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
282  unsigned resultStart,
283  unsigned resultEnd, SMLoc posLoc);
284  lsp::Hover buildHoverForBlock(SMRange hoverRange,
285  const AsmParserState::BlockDefinition &block);
286  lsp::Hover
287  buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
288  const AsmParserState::BlockDefinition &block);
289 
290  lsp::Hover buildHoverForAttributeAlias(
291  SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr);
292  lsp::Hover
293  buildHoverForTypeAlias(SMRange hoverRange,
295 
296  //===--------------------------------------------------------------------===//
297  // Document Symbols
298  //===--------------------------------------------------------------------===//
299 
300  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
301  void findDocumentSymbols(Operation *op,
302  std::vector<lsp::DocumentSymbol> &symbols);
303 
304  //===--------------------------------------------------------------------===//
305  // Code Completion
306  //===--------------------------------------------------------------------===//
307 
308  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
309  const lsp::Position &completePos,
310  const DialectRegistry &registry);
311 
312  //===--------------------------------------------------------------------===//
313  // Code Action
314  //===--------------------------------------------------------------------===//
315 
316  void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
317  lsp::Position &pos, StringRef severity,
318  StringRef message,
319  std::vector<lsp::TextEdit> &edits);
320 
321  //===--------------------------------------------------------------------===//
322  // Bytecode
323  //===--------------------------------------------------------------------===//
324 
326 
327  //===--------------------------------------------------------------------===//
328  // Fields
329  //===--------------------------------------------------------------------===//
330 
331  /// The high level parser state used to find definitions and references within
332  /// the source file.
333  AsmParserState asmState;
334 
335  /// The container for the IR parsed from the input file.
336  Block parsedIR;
337 
338  /// A collection of external resources, which we want to propagate up to the
339  /// user.
340  FallbackAsmResourceMap fallbackResourceMap;
341 
342  /// The source manager containing the contents of the input file.
343  llvm::SourceMgr sourceMgr;
344 };
345 } // namespace
346 
347 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
348  StringRef contents,
349  std::vector<lsp::Diagnostic> &diagnostics) {
350  ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
351  diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
352  });
353 
354  // Try to parsed the given IR string.
355  auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
356  if (!memBuffer) {
357  lsp::Logger::error("Failed to create memory buffer for file", uri.file());
358  return;
359  }
360 
361  ParserConfig config(&context, /*verifyAfterParse=*/true,
362  &fallbackResourceMap);
363  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
364  if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
365  // If parsing failed, clear out any of the current state.
366  parsedIR.clear();
367  asmState = AsmParserState();
368  fallbackResourceMap = FallbackAsmResourceMap();
369  return;
370  }
371 }
372 
373 //===----------------------------------------------------------------------===//
374 // MLIRDocument: Definitions and References
375 //===----------------------------------------------------------------------===//
376 
377 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
378  const lsp::Position &defPos,
379  std::vector<lsp::Location> &locations) {
380  SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
381 
382  // Functor used to check if an SM definition contains the position.
383  auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
384  if (!isDefOrUse(def, posLoc))
385  return false;
386  locations.emplace_back(uri, sourceMgr, def.loc);
387  return true;
388  };
389 
390  // Check all definitions related to operations.
391  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
392  if (contains(op.loc, posLoc))
393  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
394  for (const auto &result : op.resultGroups)
395  if (containsPosition(result.definition))
396  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
397  for (const auto &symUse : op.symbolUses) {
398  if (contains(symUse, posLoc)) {
399  locations.emplace_back(uri, sourceMgr, op.loc);
400  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
401  }
402  }
403  }
404 
405  // Check all definitions related to blocks.
406  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
407  if (containsPosition(block.definition))
408  return;
409  for (const AsmParserState::SMDefinition &arg : block.arguments)
410  if (containsPosition(arg))
411  return;
412  }
413 
414  // Check all alias definitions.
416  asmState.getAttributeAliasDefs()) {
417  if (containsPosition(attr.definition))
418  return;
419  }
420  for (const AsmParserState::TypeAliasDefinition &type :
421  asmState.getTypeAliasDefs()) {
422  if (containsPosition(type.definition))
423  return;
424  }
425 }
426 
427 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
428  const lsp::Position &pos,
429  std::vector<lsp::Location> &references) {
430  // Functor used to append all of the definitions/uses of the given SM
431  // definition to the reference list.
432  auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
433  references.emplace_back(uri, sourceMgr, def.loc);
434  for (const SMRange &use : def.uses)
435  references.emplace_back(uri, sourceMgr, use);
436  };
437 
438  SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
439 
440  // Check all definitions related to operations.
441  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
442  if (contains(op.loc, posLoc)) {
443  for (const auto &result : op.resultGroups)
444  appendSMDef(result.definition);
445  for (const auto &symUse : op.symbolUses)
446  if (contains(symUse, posLoc))
447  references.emplace_back(uri, sourceMgr, symUse);
448  return;
449  }
450  for (const auto &result : op.resultGroups)
451  if (isDefOrUse(result.definition, posLoc))
452  return appendSMDef(result.definition);
453  for (const auto &symUse : op.symbolUses) {
454  if (!contains(symUse, posLoc))
455  continue;
456  for (const auto &symUse : op.symbolUses)
457  references.emplace_back(uri, sourceMgr, symUse);
458  return;
459  }
460  }
461 
462  // Check all definitions related to blocks.
463  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
464  if (isDefOrUse(block.definition, posLoc))
465  return appendSMDef(block.definition);
466 
467  for (const AsmParserState::SMDefinition &arg : block.arguments)
468  if (isDefOrUse(arg, posLoc))
469  return appendSMDef(arg);
470  }
471 
472  // Check all alias definitions.
474  asmState.getAttributeAliasDefs()) {
475  if (isDefOrUse(attr.definition, posLoc))
476  return appendSMDef(attr.definition);
477  }
478  for (const AsmParserState::TypeAliasDefinition &type :
479  asmState.getTypeAliasDefs()) {
480  if (isDefOrUse(type.definition, posLoc))
481  return appendSMDef(type.definition);
482  }
483 }
484 
485 //===----------------------------------------------------------------------===//
486 // MLIRDocument: Hover
487 //===----------------------------------------------------------------------===//
488 
489 std::optional<lsp::Hover>
490 MLIRDocument::findHover(const lsp::URIForFile &uri,
491  const lsp::Position &hoverPos) {
492  SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
493  SMRange hoverRange;
494 
495  // Check for Hovers on operations and results.
496  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
497  // Check if the position points at this operation.
498  if (contains(op.loc, posLoc))
499  return buildHoverForOperation(op.loc, op);
500 
501  // Check if the position points at the symbol name.
502  for (auto &use : op.symbolUses)
503  if (contains(use, posLoc))
504  return buildHoverForOperation(use, op);
505 
506  // Check if the position points at a result group.
507  for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
508  const auto &result = op.resultGroups[i];
509  if (!isDefOrUse(result.definition, posLoc, &hoverRange))
510  continue;
511 
512  // Get the range of results covered by the over position.
513  unsigned resultStart = result.startIndex;
514  unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
515  : op.resultGroups[i + 1].startIndex;
516  return buildHoverForOperationResult(hoverRange, op.op, resultStart,
517  resultEnd, posLoc);
518  }
519  }
520 
521  // Check to see if the hover is over a block argument.
522  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
523  if (isDefOrUse(block.definition, posLoc, &hoverRange))
524  return buildHoverForBlock(hoverRange, block);
525 
526  for (const auto &arg : llvm::enumerate(block.arguments)) {
527  if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
528  continue;
529 
530  return buildHoverForBlockArgument(
531  hoverRange, block.block->getArgument(arg.index()), block);
532  }
533  }
534 
535  // Check to see if the hover is over an alias.
537  asmState.getAttributeAliasDefs()) {
538  if (isDefOrUse(attr.definition, posLoc, &hoverRange))
539  return buildHoverForAttributeAlias(hoverRange, attr);
540  }
541  for (const AsmParserState::TypeAliasDefinition &type :
542  asmState.getTypeAliasDefs()) {
543  if (isDefOrUse(type.definition, posLoc, &hoverRange))
544  return buildHoverForTypeAlias(hoverRange, type);
545  }
546 
547  return std::nullopt;
548 }
549 
550 std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
551  SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
552  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
553  llvm::raw_string_ostream os(hover.contents.value);
554 
555  // Add the operation name to the hover.
556  os << "\"" << op.op->getName() << "\"";
557  if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
558  os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
559  os << "\n\n";
560 
561  os << "Generic Form:\n\n```mlir\n";
562 
563  op.op->print(os, OpPrintingFlags()
564  .printGenericOpForm()
565  .elideLargeElementsAttrs()
566  .skipRegions());
567  os << "\n```\n";
568 
569  return hover;
570 }
571 
572 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
573  Operation *op,
574  unsigned resultStart,
575  unsigned resultEnd,
576  SMLoc posLoc) {
577  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
578  llvm::raw_string_ostream os(hover.contents.value);
579 
580  // Add the parent operation name to the hover.
581  os << "Operation: \"" << op->getName() << "\"\n\n";
582 
583  // Check to see if the location points to a specific result within the
584  // group.
585  if (std::optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
586  if ((resultStart + *resultNumber) < resultEnd) {
587  resultStart += *resultNumber;
588  resultEnd = resultStart + 1;
589  }
590  }
591 
592  // Add the range of results and their types to the hover info.
593  if ((resultStart + 1) == resultEnd) {
594  os << "Result #" << resultStart << "\n\n"
595  << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
596  } else {
597  os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
598  << "Types: ";
599  llvm::interleaveComma(
600  op->getResults().slice(resultStart, resultEnd), os,
601  [&](Value result) { os << "`" << result.getType() << "`"; });
602  }
603 
604  return hover;
605 }
606 
608 MLIRDocument::buildHoverForBlock(SMRange hoverRange,
609  const AsmParserState::BlockDefinition &block) {
610  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
611  llvm::raw_string_ostream os(hover.contents.value);
612 
613  // Print the given block to the hover output stream.
614  auto printBlockToHover = [&](Block *newBlock) {
615  if (const auto *def = asmState.getBlockDef(newBlock))
616  printDefBlockName(os, *def);
617  else
618  printDefBlockName(os, newBlock);
619  };
620 
621  // Display the parent operation, block number, predecessors, and successors.
622  os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
623  << "Block #" << getBlockNumber(block.block) << "\n\n";
624  if (!block.block->hasNoPredecessors()) {
625  os << "Predecessors: ";
626  llvm::interleaveComma(block.block->getPredecessors(), os,
627  printBlockToHover);
628  os << "\n\n";
629  }
630  if (!block.block->hasNoSuccessors()) {
631  os << "Successors: ";
632  llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
633  os << "\n\n";
634  }
635 
636  return hover;
637 }
638 
639 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
640  SMRange hoverRange, BlockArgument arg,
641  const AsmParserState::BlockDefinition &block) {
642  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
643  llvm::raw_string_ostream os(hover.contents.value);
644 
645  // Display the parent operation, block, the argument number, and the type.
646  os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
647  << "Block: ";
648  printDefBlockName(os, block);
649  os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
650  << "Type: `" << arg.getType() << "`\n\n";
651 
652  return hover;
653 }
654 
655 lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
656  SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr) {
657  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
658  llvm::raw_string_ostream os(hover.contents.value);
659 
660  os << "Attribute Alias: \"" << attr.name << "\n\n";
661  os << "Value: ```mlir\n" << attr.value << "\n```\n\n";
662 
663  return hover;
664 }
665 
666 lsp::Hover MLIRDocument::buildHoverForTypeAlias(
667  SMRange hoverRange, const AsmParserState::TypeAliasDefinition &type) {
668  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
669  llvm::raw_string_ostream os(hover.contents.value);
670 
671  os << "Type Alias: \"" << type.name << "\n\n";
672  os << "Value: ```mlir\n" << type.value << "\n```\n\n";
673 
674  return hover;
675 }
676 
677 //===----------------------------------------------------------------------===//
678 // MLIRDocument: Document Symbols
679 //===----------------------------------------------------------------------===//
680 
681 void MLIRDocument::findDocumentSymbols(
682  std::vector<lsp::DocumentSymbol> &symbols) {
683  for (Operation &op : parsedIR)
684  findDocumentSymbols(&op, symbols);
685 }
686 
/// Collects the document symbols for `op` and, recursively, for the
/// operations nested in its regions. An operation that is a symbol or a
/// symbol table adds an entry to `symbols`; the symbols of its nested
/// operations are then added as children of that entry. Any other operation
/// adds nothing itself and its nested symbols go directly into `symbols`.
/// NOTE(review): every line of this listing carries a line-number prefix and
/// some numbers are skipped inside the `emplace_back` calls below, so
/// arguments appear to be missing -- confirm against the original file.
687 void MLIRDocument::findDocumentSymbols(
688  Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
  // Where the symbols of nested operations are appended; redirected below
  // when `op` itself produces a symbol.
689  std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
690 
691  // Check for the source information of this operation.
692  if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
693  // If this operation defines a symbol, record it.
694  if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
  // NOTE(review): the two branches of the conditional that follows the
  // `isa<FunctionOpInterface>` test are not present in this listing.
695  symbols.emplace_back(symbol.getName(),
696  isa<FunctionOpInterface>(op)
699  lsp::Range(sourceMgr, def->scopeLoc),
700  lsp::Range(sourceMgr, def->loc));
701  childSymbols = &symbols.back().children;
702 
703  } else if (op->hasTrait<OpTrait::SymbolTable>()) {
704  // Otherwise, if this is a symbol table push an anonymous document symbol.
  // NOTE(review): the argument between the name and the ranges is not
  // present in this listing.
705  symbols.emplace_back("<" + op->getName().getStringRef() + ">",
707  lsp::Range(sourceMgr, def->scopeLoc),
708  lsp::Range(sourceMgr, def->loc));
709  childSymbols = &symbols.back().children;
710  }
711  }
712 
713  // Recurse into the regions of this operation.
714  if (!op->getNumRegions())
715  return;
716  for (Region &region : op->getRegions())
717  for (Operation &childOp : region.getOps())
718  findDocumentSymbols(&childOp, *childSymbols);
719 }
720 
721 //===----------------------------------------------------------------------===//
722 // MLIRDocument: Code Completion
723 //===----------------------------------------------------------------------===//
724 
725 namespace {
/// Code completion context handed to the MLIR assembly parser. Each hook
/// invoked by the parser appends completion items to the `completionList`
/// supplied at construction; `ctx` is used to look up dialects and
/// registered operations.
/// NOTE(review): every line of this listing carries a line-number prefix and
/// many numbers are skipped inside this class (completion item kinds, item
/// declarations, a parameter of `appendSimpleCompletions`, a line after
/// `completeDialectName`), so several statements are incomplete as shown --
/// confirm against the original file.
726 class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
727 public:
728  LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
729  MLIRContext *ctx)
730  : AsmParserCodeCompleteContext(completeLoc),
731  completionList(completionList), ctx(ctx) {}
732 
733  /// Signal code completion for a dialect name, with an optional prefix.
  /// One item is added per dialect available in `ctx`.
734  void completeDialectName(StringRef prefix) final {
735  for (StringRef dialect : ctx->getAvailableDialects()) {
  // NOTE(review): an argument between the label and the sort text is not
  // present in this listing.
736  lsp::CompletionItem item(prefix + dialect,
738  /*sortText=*/"3");
739  item.detail = "dialect";
740  completionList.items.emplace_back(item);
741  }
742  }
744 
745  /// Signal code completion for an operation name within the given dialect.
  /// Does nothing if the dialect cannot be found or loaded. The dialect name
  /// and the following separator character are dropped from each label.
746  void completeOperationName(StringRef dialectName) final {
747  Dialect *dialect = ctx->getOrLoadDialect(dialectName);
748  if (!dialect)
749  return;
750 
751  for (const auto &op : ctx->getRegisteredOperations()) {
752  if (&op.getDialect() != dialect)
753  continue;
754 
755  lsp::CompletionItem item(
756  op.getStringRef().drop_front(dialectName.size() + 1),
758  /*sortText=*/"1");
759  item.detail = "operation";
760  completionList.items.emplace_back(item);
761  }
762  }
763 
764  /// Append the given SSA value as a code completion result for SSA value
765  /// completions.
766  void appendSSAValueCompletion(StringRef name, std::string typeData) final {
767  // Check if we need to insert the `%` or not.
  // The character just before the completion location is inspected; if the
  // `%` is already there, the first character of `name` is dropped from the
  // inserted text.
768  bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';
769 
  // NOTE(review): the declaration of `item` is not present in this listing.
771  if (stripPrefix)
772  item.insertText = name.drop_front(1).str();
773  item.detail = std::move(typeData);
774  completionList.items.emplace_back(item);
775  }
776 
777  /// Append the given block as a code completion result for block name
778  /// completions.
779  void appendBlockCompletion(StringRef name) final {
780  // Check if we need to insert the `^` or not.
781  bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';
782 
  // NOTE(review): the declaration of `item` is not present in this listing.
784  if (stripPrefix)
785  item.insertText = name.drop_front(1).str();
786  completionList.items.emplace_back(item);
787  }
788 
789  /// Signal a completion for the given expected token.
  /// Each item's detail is "optional" when `optional` is set, empty otherwise.
790  void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
791  for (StringRef token : tokens) {
  // NOTE(review): the start of the `item` declaration is not present in this
  // listing.
793  /*sortText=*/"0");
794  item.detail = optional ? "optional" : "";
795  completionList.items.emplace_back(item);
796  }
797  }
798 
799  /// Signal a completion for an attribute.
  /// Adds a fixed list of attribute keywords, then dialect names and aliases
  /// prefixed with `#`.
800  void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
801  appendSimpleCompletions({"affine_set", "affine_map", "dense",
802  "dense_resource", "false", "loc", "sparse", "true",
803  "unit"},
805  /*sortText=*/"1");
806 
807  completeDialectName("#");
808  completeAliases(aliases, "#");
809  }
  /// Completes dialect names and attribute aliases without adding a prefix.
  // NOTE(review): `completeDialectName()` is called with no arguments here,
  // while the only overload visible in this class takes a prefix; the line
  // skipped after that overload presumably makes a base-class overload
  // visible -- confirm.
810  void completeDialectAttributeOrAlias(
811  const llvm::StringMap<Attribute> &aliases) override {
812  completeDialectName();
813  completeAliases(aliases);
814  }
815 
816  /// Signal a completion for a type.
817  void completeType(const llvm::StringMap<Type> &aliases) override {
818  // Handle the various builtin types.
819  appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
820  "bf16", "f16", "f32", "f64", "f80", "f128",
821  "index", "none"},
823  /*sortText=*/"1");
824 
825  // Handle the builtin integer types.
  // NOTE(review): the start of the `item` declaration is not present in this
  // listing.
826  for (StringRef type : {"i", "si", "ui"}) {
828  /*sortText=*/"1");
829  item.insertText = type.str();
830  completionList.items.emplace_back(item);
831  }
832 
833  // Insert completions for dialect types and aliases.
834  completeDialectName("!");
835  completeAliases(aliases, "!");
836  }
  /// Completes dialect names and type aliases without adding a prefix.
837  void
838  completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
839  completeDialectName();
840  completeAliases(aliases);
841  }
842 
843  /// Add completion results for the given set of aliases.
  /// The label of each item is `prefix` followed by the alias name; the
  /// detail is "alias: " followed by the printed alias value.
844  template <typename T>
845  void completeAliases(const llvm::StringMap<T> &aliases,
846  StringRef prefix = "") {
847  for (const auto &alias : aliases) {
848  lsp::CompletionItem item(prefix + alias.getKey(),
850  /*sortText=*/"2");
851  llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
852  completionList.items.emplace_back(item);
853  }
854  }
855 
856  /// Add a set of simple completions that all have the same kind.
  // NOTE(review): the declaration of the `kind` parameter used in the body is
  // not present in this listing.
857  void appendSimpleCompletions(ArrayRef<StringRef> completions,
859  StringRef sortText = "") {
860  for (StringRef completion : completions)
861  completionList.items.emplace_back(completion, kind, sortText);
862  }
863 
864 private:
  /// The list that receives every completion item produced by the hooks above.
865  lsp::CompletionList &completionList;
  /// The context used to query available dialects and registered operations.
866  MLIRContext *ctx;
867 };
868 } // namespace
869 
871 MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
872  const lsp::Position &completePos,
873  const DialectRegistry &registry) {
874  SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
875  if (!posLoc.isValid())
876  return lsp::CompletionList();
877 
878  // To perform code completion, we run another parse of the module with the
879  // code completion context provided.
880  MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
881  tmpContext.allowUnregisteredDialects();
882  lsp::CompletionList completionList;
883  LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
884  &tmpContext);
885 
886  Block tmpIR;
887  AsmParserState tmpState;
888  (void)parseAsmSourceFile(sourceMgr, &tmpIR, &tmpContext, &tmpState,
889  &lspCompleteContext);
890  return completionList;
891 }
892 
893 //===----------------------------------------------------------------------===//
894 // MLIRDocument: Code Action
895 //===----------------------------------------------------------------------===//
896 
897 void MLIRDocument::getCodeActionForDiagnostic(
898  const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
899  StringRef message, std::vector<lsp::TextEdit> &edits) {
900  // Ignore diagnostics that print the current operation. These are always
901  // enabled for the language server, but not generally during normal
902  // parsing/verification.
903  if (message.startswith("see current operation: "))
904  return;
905 
906  // Get the start of the line containing the diagnostic.
907  const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
908  const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
909  if (!lineStart)
910  return;
911  StringRef line(lineStart, pos.character);
912 
913  // Add a text edit for adding an expected-* diagnostic check for this
914  // diagnostic.
915  lsp::TextEdit edit;
916  edit.range = lsp::Range(lsp::Position(pos.line, 0));
917 
918  // Use the indent of the current line for the expected-* diagnostic.
919  size_t indent = line.find_first_not_of(" ");
920  if (indent == StringRef::npos)
921  indent = line.size();
922 
923  edit.newText.append(indent, ' ');
924  llvm::raw_string_ostream(edit.newText)
925  << "// expected-" << severity << " @below {{" << message << "}}\n";
926  edits.emplace_back(std::move(edit));
927 }
928 
929 //===----------------------------------------------------------------------===//
930 // MLIRDocument: Bytecode
931 //===----------------------------------------------------------------------===//
932 
934 MLIRDocument::convertToBytecode() {
935  // TODO: We currently require a single top-level operation, but this could
936  // conceptually be relaxed.
937  if (!llvm::hasSingleElement(parsedIR)) {
938  if (parsedIR.empty()) {
939  return llvm::make_error<lsp::LSPError>(
940  "expected a single and valid top-level operation, please ensure "
941  "there are no errors",
943  }
944  return llvm::make_error<lsp::LSPError>(
945  "expected a single top-level operation", lsp::ErrorCode::RequestFailed);
946  }
947 
949  {
950  BytecodeWriterConfig writerConfig(fallbackResourceMap);
951 
952  std::string rawBytecodeBuffer;
953  llvm::raw_string_ostream os(rawBytecodeBuffer);
954  // No desired bytecode version set, so no need to check for error.
955  (void)writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
956  result.output = llvm::encodeBase64(rawBytecodeBuffer);
957  }
958  return result;
959 }
960 
961 //===----------------------------------------------------------------------===//
962 // MLIRTextFileChunk
963 //===----------------------------------------------------------------------===//
964 
965 namespace {
966 /// This class represents a single chunk of an MLIR text file.
967 struct MLIRTextFileChunk {
968  MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
969  const lsp::URIForFile &uri, StringRef contents,
970  std::vector<lsp::Diagnostic> &diagnostics)
971  : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
972 
973  /// Adjust the line number of the given range to anchor at the beginning of
974  /// the file, instead of the beginning of this chunk.
975  void adjustLocForChunkOffset(lsp::Range &range) {
976  adjustLocForChunkOffset(range.start);
977  adjustLocForChunkOffset(range.end);
978  }
979  /// Adjust the line number of the given position to anchor at the beginning of
980  /// the file, instead of the beginning of this chunk.
981  void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
982 
983  /// The line offset of this chunk from the beginning of the file.
984  uint64_t lineOffset;
985  /// The document referred to by this chunk.
986  MLIRDocument document;
987 };
988 } // namespace
989 
990 //===----------------------------------------------------------------------===//
991 // MLIRTextFile
992 //===----------------------------------------------------------------------===//
993 
994 namespace {
995 /// This class represents a text file containing one or more MLIR documents.
996 class MLIRTextFile {
997 public:
998  MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
999  int64_t version, DialectRegistry &registry,
1000  std::vector<lsp::Diagnostic> &diagnostics);
1001 
1002  /// Return the current version of this text file.
1003  int64_t getVersion() const { return version; }
1004 
1005  //===--------------------------------------------------------------------===//
1006  // LSP Queries
1007  //===--------------------------------------------------------------------===//
1008 
1009  void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1010  std::vector<lsp::Location> &locations);
1011  void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1012  std::vector<lsp::Location> &references);
1013  std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1014  lsp::Position hoverPos);
1015  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1016  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1017  lsp::Position completePos);
1018  void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos,
1019  const lsp::CodeActionContext &context,
1020  std::vector<lsp::CodeAction> &actions);
1022 
1023 private:
1024  /// Find the MLIR document that contains the given position, and update the
1025  /// position to be anchored at the start of the found chunk instead of the
1026  /// beginning of the file.
1027  MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
1028 
1029  /// The context used to hold the state contained by the parsed document.
1030  MLIRContext context;
1031 
1032  /// The full string contents of the file.
1033  std::string contents;
1034 
1035  /// The version of this file.
1036  int64_t version;
1037 
1038  /// The number of lines in the file.
1039  int64_t totalNumLines = 0;
1040 
1041  /// The chunks of this file. The order of these chunks is the order in which
1042  /// they appear in the text file.
1043  std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
1044 };
1045 } // namespace
1046 
1047 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1048  int64_t version, DialectRegistry &registry,
1049  std::vector<lsp::Diagnostic> &diagnostics)
1050  : context(registry, MLIRContext::Threading::DISABLED),
1051  contents(fileContents.str()), version(version) {
1052  context.allowUnregisteredDialects();
1053 
1054  // Split the file into separate MLIR documents.
1055  // TODO: Find a way to share the split file marker with other tools. We don't
1056  // want to use `splitAndProcessBuffer` here, but we do want to make sure this
1057  // marker doesn't go out of sync.
1058  SmallVector<StringRef, 8> subContents;
1059  StringRef(contents).split(subContents, "// -----");
1060  chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
1061  context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
1062 
1063  uint64_t lineOffset = subContents.front().count('\n');
1064  for (StringRef docContents : llvm::drop_begin(subContents)) {
1065  unsigned currentNumDiags = diagnostics.size();
1066  auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
1067  docContents, diagnostics);
1068  lineOffset += docContents.count('\n');
1069 
1070  // Adjust locations used in diagnostics to account for the offset from the
1071  // beginning of the file.
1072  for (lsp::Diagnostic &diag :
1073  llvm::drop_begin(diagnostics, currentNumDiags)) {
1074  chunk->adjustLocForChunkOffset(diag.range);
1075 
1076  if (!diag.relatedInformation)
1077  continue;
1078  for (auto &it : *diag.relatedInformation)
1079  if (it.location.uri == uri)
1080  chunk->adjustLocForChunkOffset(it.location.range);
1081  }
1082  chunks.emplace_back(std::move(chunk));
1083  }
1084  totalNumLines = lineOffset;
1085 }
1086 
1087 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
1088  lsp::Position defPos,
1089  std::vector<lsp::Location> &locations) {
1090  MLIRTextFileChunk &chunk = getChunkFor(defPos);
1091  chunk.document.getLocationsOf(uri, defPos, locations);
1092 
1093  // Adjust any locations within this file for the offset of this chunk.
1094  if (chunk.lineOffset == 0)
1095  return;
1096  for (lsp::Location &loc : locations)
1097  if (loc.uri == uri)
1098  chunk.adjustLocForChunkOffset(loc.range);
1099 }
1100 
1101 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
1102  lsp::Position pos,
1103  std::vector<lsp::Location> &references) {
1104  MLIRTextFileChunk &chunk = getChunkFor(pos);
1105  chunk.document.findReferencesOf(uri, pos, references);
1106 
1107  // Adjust any locations within this file for the offset of this chunk.
1108  if (chunk.lineOffset == 0)
1109  return;
1110  for (lsp::Location &loc : references)
1111  if (loc.uri == uri)
1112  chunk.adjustLocForChunkOffset(loc.range);
1113 }
1114 
1115 std::optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
1116  lsp::Position hoverPos) {
1117  MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
1118  std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1119 
1120  // Adjust any locations within this file for the offset of this chunk.
1121  if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1122  chunk.adjustLocForChunkOffset(*hoverInfo->range);
1123  return hoverInfo;
1124 }
1125 
1126 void MLIRTextFile::findDocumentSymbols(
1127  std::vector<lsp::DocumentSymbol> &symbols) {
1128  if (chunks.size() == 1)
1129  return chunks.front()->document.findDocumentSymbols(symbols);
1130 
1131  // If there are multiple chunks in this file, we create top-level symbols for
1132  // each chunk.
1133  for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1134  MLIRTextFileChunk &chunk = *chunks[i];
1135  lsp::Position startPos(chunk.lineOffset);
1136  lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1137  : chunks[i + 1]->lineOffset);
1138  lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1139  lsp::SymbolKind::Namespace,
1140  /*range=*/lsp::Range(startPos, endPos),
1141  /*selectionRange=*/lsp::Range(startPos));
1142  chunk.document.findDocumentSymbols(symbol.children);
1143 
1144  // Fixup the locations of document symbols within this chunk.
1145  if (i != 0) {
1147  for (lsp::DocumentSymbol &childSymbol : symbol.children)
1148  symbolsToFix.push_back(&childSymbol);
1149 
1150  while (!symbolsToFix.empty()) {
1151  lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1152  chunk.adjustLocForChunkOffset(symbol->range);
1153  chunk.adjustLocForChunkOffset(symbol->selectionRange);
1154 
1155  for (lsp::DocumentSymbol &childSymbol : symbol->children)
1156  symbolsToFix.push_back(&childSymbol);
1157  }
1158  }
1159 
1160  // Push the symbol for this chunk.
1161  symbols.emplace_back(std::move(symbol));
1162  }
1163 }
1164 
1165 lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1166  lsp::Position completePos) {
1167  MLIRTextFileChunk &chunk = getChunkFor(completePos);
1168  lsp::CompletionList completionList = chunk.document.getCodeCompletion(
1169  uri, completePos, context.getDialectRegistry());
1170 
1171  // Adjust any completion locations.
1172  for (lsp::CompletionItem &item : completionList.items) {
1173  if (item.textEdit)
1174  chunk.adjustLocForChunkOffset(item.textEdit->range);
1175  for (lsp::TextEdit &edit : item.additionalTextEdits)
1176  chunk.adjustLocForChunkOffset(edit.range);
1177  }
1178  return completionList;
1179 }
1180 
1181 void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
1182  const lsp::Range &pos,
1183  const lsp::CodeActionContext &context,
1184  std::vector<lsp::CodeAction> &actions) {
1185  // Create actions for any diagnostics in this file.
1186  for (auto &diag : context.diagnostics) {
1187  if (diag.source != "mlir")
1188  continue;
1189  lsp::Position diagPos = diag.range.start;
1190  MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1191 
1192  // Add a new code action that inserts a "expected" diagnostic check.
1193  lsp::CodeAction action;
1194  action.title = "Add expected-* diagnostic checks";
1195  action.kind = lsp::CodeAction::kQuickFix.str();
1196 
1197  StringRef severity;
1198  switch (diag.severity) {
1200  severity = "error";
1201  break;
1202  case lsp::DiagnosticSeverity::Warning:
1203  severity = "warning";
1204  break;
1205  default:
1206  continue;
1207  }
1208 
1209  // Get edits for the diagnostic.
1210  std::vector<lsp::TextEdit> edits;
1211  chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1212  diag.message, edits);
1213 
1214  // Walk the related diagnostics, this is how we encode notes.
1215  if (diag.relatedInformation) {
1216  for (auto &noteDiag : *diag.relatedInformation) {
1217  if (noteDiag.location.uri != uri)
1218  continue;
1219  diagPos = noteDiag.location.range.start;
1220  diagPos.line -= chunk.lineOffset;
1221  chunk.document.getCodeActionForDiagnostic(uri, diagPos, "note",
1222  noteDiag.message, edits);
1223  }
1224  }
1225  // Fixup the locations for any edits.
1226  for (lsp::TextEdit &edit : edits)
1227  chunk.adjustLocForChunkOffset(edit.range);
1228 
1229  action.edit.emplace();
1230  action.edit->changes[uri.uri().str()] = std::move(edits);
1231  action.diagnostics = {diag};
1232 
1233  actions.emplace_back(std::move(action));
1234  }
1235 }
1236 
1238 MLIRTextFile::convertToBytecode() {
1239  // Bail out if there is more than one chunk, bytecode wants a single module.
1240  if (chunks.size() != 1) {
1241  return llvm::make_error<lsp::LSPError>(
1242  "unexpected split file, please remove all `// -----`",
1243  lsp::ErrorCode::RequestFailed);
1244  }
1245  return chunks.front()->document.convertToBytecode();
1246 }
1247 
1248 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1249  if (chunks.size() == 1)
1250  return *chunks.front();
1251 
1252  // Search for the first chunk with a greater line offset, the previous chunk
1253  // is the one that contains `pos`.
1254  auto it = llvm::upper_bound(
1255  chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1256  return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1257  });
1258  MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1259  pos.line -= chunk.lineOffset;
1260  return chunk;
1261 }
1262 
1263 //===----------------------------------------------------------------------===//
1264 // MLIRServer::Impl
1265 //===----------------------------------------------------------------------===//
1266 
1269 
1270  /// The registry containing dialects that can be recognized in parsed .mlir
1271  /// files.
1273 
1274  /// The files held by the server, mapped by their URI file name.
1275  llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
1276 };
1277 
1278 //===----------------------------------------------------------------------===//
1279 // MLIRServer
1280 //===----------------------------------------------------------------------===//
1281 
1282 lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
1283  : impl(std::make_unique<Impl>(registry)) {}
1284 lsp::MLIRServer::~MLIRServer() = default;
1285 
1287  const URIForFile &uri, StringRef contents, int64_t version,
1288  std::vector<Diagnostic> &diagnostics) {
1289  impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1290  uri, contents, version, impl->registry, diagnostics);
1291 }
1292 
1293 std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
1294  auto it = impl->files.find(uri.file());
1295  if (it == impl->files.end())
1296  return std::nullopt;
1297 
1298  int64_t version = it->second->getVersion();
1299  impl->files.erase(it);
1300  return version;
1301 }
1302 
1304  const Position &defPos,
1305  std::vector<Location> &locations) {
1306  auto fileIt = impl->files.find(uri.file());
1307  if (fileIt != impl->files.end())
1308  fileIt->second->getLocationsOf(uri, defPos, locations);
1309 }
1310 
1312  const Position &pos,
1313  std::vector<Location> &references) {
1314  auto fileIt = impl->files.find(uri.file());
1315  if (fileIt != impl->files.end())
1316  fileIt->second->findReferencesOf(uri, pos, references);
1317 }
1318 
1319 std::optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
1320  const Position &hoverPos) {
1321  auto fileIt = impl->files.find(uri.file());
1322  if (fileIt != impl->files.end())
1323  return fileIt->second->findHover(uri, hoverPos);
1324  return std::nullopt;
1325 }
1326 
1328  const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1329  auto fileIt = impl->files.find(uri.file());
1330  if (fileIt != impl->files.end())
1331  fileIt->second->findDocumentSymbols(symbols);
1332 }
1333 
1336  const Position &completePos) {
1337  auto fileIt = impl->files.find(uri.file());
1338  if (fileIt != impl->files.end())
1339  return fileIt->second->getCodeCompletion(uri, completePos);
1340  return CompletionList();
1341 }
1342 
1344  const CodeActionContext &context,
1345  std::vector<CodeAction> &actions) {
1346  auto fileIt = impl->files.find(uri.file());
1347  if (fileIt != impl->files.end())
1348  fileIt->second->getCodeActions(uri, pos, context, actions);
1349 }
1350 
1353  MLIRContext tempContext(impl->registry);
1354  tempContext.allowUnregisteredDialects();
1355 
1356  // Collect any errors during parsing.
1357  std::string errorMsg;
1358  ScopedDiagnosticHandler diagHandler(
1359  &tempContext,
1360  [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
1361 
1362  // Handling for external resources, which we want to propagate up to the user.
1363  FallbackAsmResourceMap fallbackResourceMap;
1364 
1365  // Setup the parser config.
1366  ParserConfig parserConfig(&tempContext, /*verifyAfterParse=*/true,
1367  &fallbackResourceMap);
1368 
1369  // Try to parse the given source file.
1370  Block parsedBlock;
1371  if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
1372  return llvm::make_error<lsp::LSPError>(
1373  "failed to parse bytecode source file: " + errorMsg,
1375  }
1376 
1377  // TODO: We currently expect a single top-level operation, but this could
1378  // conceptually be relaxed.
1379  if (!llvm::hasSingleElement(parsedBlock)) {
1380  return llvm::make_error<lsp::LSPError>(
1381  "expected bytecode to contain a single top-level operation",
1383  }
1384 
1385  // Print the module to a buffer.
1387  {
1388  // Extract the top-level op so that aliases get printed.
1389  // FIXME: We should be able to enable aliases without having to do this!
1390  OwningOpRef<Operation *> topOp = &parsedBlock.front();
1391  topOp->remove();
1392 
1393  AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
1394  /*locationMap=*/nullptr, &fallbackResourceMap);
1395 
1396  llvm::raw_string_ostream os(result.output);
1397  topOp->print(os, state);
1398  }
1399  return std::move(result);
1400 }
1401 
1404  auto fileIt = impl->files.find(uri.file());
1405  if (fileIt == impl->files.end()) {
1406  return llvm::make_error<lsp::LSPError>(
1407  "language server does not contain an entry for this source file",
1409  }
1410  return fileIt->second->convertToBytecode();
1411 }
static std::string toString(bytecode::Section::ID sectionID)
Stringify the given section ID.
static void collectLocationsFromLoc(Location loc, std::vector< lsp::Location > &locations, const lsp::URIForFile &uri)
Collect all of the locations from the given MLIR location that are not contained within the given URI...
Definition: MLIRServer.cpp:90
static unsigned getBlockNumber(Block *block)
Given a block, return its position in its parent region.
Definition: MLIRServer.cpp:175
static std::optional< lsp::Location > getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc)
Returns a language server location from the given MLIR file location.
Definition: MLIRServer.cpp:36
static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, SMRange *overlappedRange=nullptr)
Returns true if the given location is contained by the definition or one of the uses of the given SMD...
Definition: MLIRServer.cpp:118
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
Definition: MLIRServer.cpp:110
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, Diagnostic &diag, const lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
Definition: MLIRServer.cpp:198
static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc={})
Given a block and source location, print the source name of the block to the given output stream.
Definition: MLIRServer.cpp:181
static std::optional< unsigned > getResultNumberFromLoc(SMLoc loc)
Given a location pointing to a result, return the result number it refers to or std::nullopt if it re...
Definition: MLIRServer.cpp:140
static SMRange convertTokenLocToRange(SMLoc loc)
Returns the range of a lexical token given a SMLoc corresponding to the start of a token location.
Definition: MLIRServer.cpp:30
static std::optional< StringRef > getTextFromRange(SMRange range)
Given a source location range, return the text covered by the given range.
Definition: MLIRServer.cpp:167
@ Error
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class provides an abstract interface into the parser for hooking in code completion events.
Definition: CodeComplete.h:24
This class represents state from a parsed MLIR textual format string.
This class provides management for the lifetime of the state used when printing the IR.
Definition: AsmState.h:533
This class represents an argument of a Block.
Definition: Value.h:315
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:327
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition: Block.h:238
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
SuccessorRange getSuccessors()
Definition: Block.h:260
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:230
Operation & front()
Definition: Block.h:146
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:235
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class contains the configuration used for the bytecode writer.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
A fallback map containing external resources not explicitly handled by another parser/printer.
Definition: AsmState.h:412
WalkResult walk(function_ref< WalkResult(Location)> walkFn)
Walk all of the locations nested under, and including, the current.
Definition: Location.cpp:40
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
void allowUnregisteredDialects(bool allow=true)
Enables creating operations in unregistered dialects.
Set of flags used to control the behavior of the various IR print methods (e.g.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:728
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:652
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition: OwningOpRef.h:28
This class represents a configuration for the MLIR assembly parser.
Definition: AsmState.h:460
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator begin()
Definition: Region.h:55
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Definition: Diagnostics.h:516
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
static WalkResult advance()
Definition: Visitors.h:52
static WalkResult interrupt()
Definition: Visitors.h:51
static void error(const char *fmt, Ts &&...vals)
Definition: Logging.h:42
void addOrUpdateDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add or update the document, with the provided version, at the given URI.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
llvm::Expected< MLIRConvertBytecodeResult > convertFromBytecode(const URIForFile &uri)
Convert the given bytecode file to the textual format.
llvm::Expected< MLIRConvertBytecodeResult > convertToBytecode(const URIForFile &uri)
Convert the given textual file to the bytecode format.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
void getCodeActions(const URIForFile &uri, const Range &pos, const CodeActionContext &context, std::vector< CodeAction > &actions)
Get the set of code actions within the file.
URI in "file" scheme for a file.
Definition: Protocol.h:100
StringRef uri() const
Returns the original uri of the file.
Definition: Protocol.h:116
static llvm::Expected< URIForFile > fromFile(StringRef absoluteFilepath, StringRef scheme="file")
Try to build a URIForFile from the given absolute file path and optional scheme.
Definition: Protocol.cpp:233
StringRef scheme() const
Return the scheme of the uri.
Definition: Protocol.cpp:242
StringRef file() const
Returns the absolute path to the file.
Definition: Protocol.h:113
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
SMRange convertTokenLocToRange(SMLoc loc, StringRef identifierChars="")
Returns the range of a lexical token given a SMLoc corresponding to the start of a token location.
CompletionItemKind
The kind of a completion entry.
Definition: Protocol.h:758
Include the generated interface declarations.
LogicalResult parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, AsmParserState *asmState=nullptr, AsmParserCodeCompleteContext *codeCompleteContext=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition: Parser.cpp:2776
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition: Parser.cpp:20
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config={})
Write the bytecode for the given operation to the provided output stream.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Impl(DialectRegistry &registry)
DialectRegistry & registry
The registry containing dialects that can be recognized in parsed .mlir files.
llvm::StringMap< std::unique_ptr< MLIRTextFile > > files
The files held by the server, mapped by their URI file name.
This class represents the information for an attribute alias definition within the input file.
StringRef name
The name of the attribute alias.
Attribute value
The value of the alias.
This class represents the information for a block definition within the input file.
Block * block
The block representing this definition.
SMDefinition definition
The source location for the block, i.e. the location of its name, and any uses it has.
This class represents the information for an operation definition within an input file.
This class represents a definition within the source manager, containing its defining location and the locations of any uses.
SmallVector< SMRange > uses
The source location of all uses of the definition.
SMRange loc
The source location of the definition.
This class represents the information for a type definition within the input file.
StringRef name
The name of the type alias.
std::vector< Diagnostic > diagnostics
An array of diagnostics known on the client side overlapping the range provided to the textDocument/codeAction request.
Definition: Protocol.h:1141
A code action represents a change that can be performed in code, e.g. to fix a problem or to refactor code.
Definition: Protocol.h:1199
std::optional< WorkspaceEdit > edit
The workspace edit this code action performs.
Definition: Protocol.h:1221
std::optional< std::string > kind
The kind of the code action.
Definition: Protocol.h:1205
std::optional< std::vector< Diagnostic > > diagnostics
The diagnostics that this code action resolves.
Definition: Protocol.h:1211
std::string title
A short, human-readable, title for this code action.
Definition: Protocol.h:1201
std::optional< TextEdit > textEdit
An edit which is applied to a document when selecting this completion.
Definition: Protocol.h:866
std::vector< TextEdit > additionalTextEdits
An optional array of additional text edits that are applied when selecting this completion.
Definition: Protocol.h:871
Represents a collection of completion items to be presented in the editor.
Definition: Protocol.h:887
std::vector< CompletionItem > items
The completion items.
Definition: Protocol.h:893
std::string source
A human-readable string describing the source of this diagnostic, e.g. 'typescript' or 'super lint'.
Definition: Protocol.h:690
DiagnosticSeverity severity
The diagnostic's severity.
Definition: Protocol.h:686
Range range
The source range where the message applies.
Definition: Protocol.h:682
std::optional< std::vector< DiagnosticRelatedInformation > > relatedInformation
An array of related diagnostic information, e.g. when symbol names within a scope collide, all definitions can be marked via this property.
Definition: Protocol.h:697
std::string message
The diagnostic's message.
Definition: Protocol.h:693
std::optional< std::string > category
The diagnostic's category.
Definition: Protocol.h:703
Represents programming constructs like variables, classes, interfaces etc.
Definition: Protocol.h:596
Range range
The range enclosing this symbol not including leading/trailing whitespace but everything else like comments.
Definition: Protocol.h:617
Range selectionRange
The range that should be selected and revealed when this symbol is being picked, e.g. the name of a function.
Definition: Protocol.h:621
std::vector< DocumentSymbol > children
Children of this symbol, e.g. properties of a class.
Definition: Protocol.h:624
URIForFile uri
The text document's URI.
Definition: Protocol.h:394
This class represents the result of converting between MLIR's bytecode and textual format.
Definition: Protocol.h:48
std::string output
The resultant output of the conversion.
Definition: Protocol.h:50
int line
Line position in a document (zero-based).
Definition: Protocol.h:293
int character
Character offset on a line in a document (zero-based).
Definition: Protocol.h:296
SMLoc getAsSMLoc(llvm::SourceMgr &mgr) const
Convert this position into a source location in the main file of the given source manager.
Definition: Protocol.h:316
Position end
The range's end position.
Definition: Protocol.h:345
Position start
The range's start position.
Definition: Protocol.h:342
std::string newText
The string to be inserted.
Definition: Protocol.h:741
Range range
The range of the text document to be manipulated.
Definition: Protocol.h:737