MLIR  22.0.0git
MLIRServer.cpp
Go to the documentation of this file.
1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MLIRServer.h"
10 #include "Protocol.h"
15 #include "mlir/IR/Operation.h"
17 #include "mlir/Parser/Parser.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/Support/Base64.h"
22 #include "llvm/Support/LSP/Logging.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include <optional>
25 
26 using namespace mlir;
27 
28 /// Returns the range of a lexical token given a SMLoc corresponding to the
29 /// start of an token location. The range is computed heuristically, and
30 /// supports identifier-like tokens, strings, etc.
31 static SMRange convertTokenLocToRange(SMLoc loc) {
32  return lsp::convertTokenLocToRange(loc, "$-.");
33 }
34 
35 /// Returns a language server location from the given MLIR file location.
36 /// `uriScheme` is the scheme to use when building new uris.
37 static std::optional<lsp::Location> getLocationFromLoc(StringRef uriScheme,
38  FileLineColLoc loc) {
40  lsp::URIForFile::fromFile(loc.getFilename(), uriScheme);
41  if (!sourceURI) {
42  llvm::lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
43  loc.getFilename(),
44  llvm::toString(sourceURI.takeError()));
45  return std::nullopt;
46  }
47 
48  lsp::Position position;
49  position.line = loc.getLine() - 1;
50  position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
51  return lsp::Location{*sourceURI, lsp::Range(position)};
52 }
53 
54 /// Returns a language server location from the given MLIR location, or
55 /// std::nullopt if one couldn't be created. `uriScheme` is the scheme to use
56 /// when building new uris. `uri` is an optional additional filter that, when
57 /// present, is used to filter sub locations that do not share the same uri.
58 static std::optional<lsp::Location>
59 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
60  StringRef uriScheme, const lsp::URIForFile *uri = nullptr) {
61  std::optional<lsp::Location> location;
62  loc->walk([&](Location nestedLoc) {
63  FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
64  if (!fileLoc)
65  return WalkResult::advance();
66 
67  std::optional<lsp::Location> sourceLoc =
68  getLocationFromLoc(uriScheme, fileLoc);
69  if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
70  location = *sourceLoc;
71  SMLoc loc = sourceMgr.FindLocForLineAndColumn(
72  sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
73 
74  // Use range of potential identifier starting at location, else length 1
75  // range.
76  location->range.end.character += 1;
77  if (std::optional<SMRange> range = convertTokenLocToRange(loc)) {
78  auto lineCol = sourceMgr.getLineAndColumn(range->End);
79  location->range.end.character =
80  std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
81  }
82  return WalkResult::interrupt();
83  }
84  return WalkResult::advance();
85  });
86  return location;
87 }
88 
89 /// Collect all of the locations from the given MLIR location that are not
90 /// contained within the given URI.
92  std::vector<lsp::Location> &locations,
93  const lsp::URIForFile &uri) {
94  SetVector<Location> visitedLocs;
95  loc->walk([&](Location nestedLoc) {
96  FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
97  if (!fileLoc || !visitedLocs.insert(nestedLoc))
98  return WalkResult::advance();
99 
100  std::optional<lsp::Location> sourceLoc =
101  getLocationFromLoc(uri.scheme(), fileLoc);
102  if (sourceLoc && sourceLoc->uri != uri)
103  locations.push_back(*sourceLoc);
104  return WalkResult::advance();
105  });
106 }
107 
108 /// Returns true if the given range contains the given source location. Note
109 /// that this has slightly different behavior than SMRange because it is
110 /// inclusive of the end location.
111 static bool contains(SMRange range, SMLoc loc) {
112  return range.Start.getPointer() <= loc.getPointer() &&
113  loc.getPointer() <= range.End.getPointer();
114 }
115 
116 /// Returns true if the given location is contained by the definition or one of
117 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
118 /// the range within `def` that the provided `loc` overlapped with.
119 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
120  SMRange *overlappedRange = nullptr) {
121  // Check the main definition.
122  if (contains(def.loc, loc)) {
123  if (overlappedRange)
124  *overlappedRange = def.loc;
125  return true;
126  }
127 
128  // Check the uses.
129  const auto *useIt = llvm::find_if(
130  def.uses, [&](const SMRange &range) { return contains(range, loc); });
131  if (useIt != def.uses.end()) {
132  if (overlappedRange)
133  *overlappedRange = *useIt;
134  return true;
135  }
136  return false;
137 }
138 
139 /// Given a location pointing to a result, return the result number it refers
140 /// to or std::nullopt if it refers to all of the results.
141 static std::optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
142  // Skip all of the identifier characters.
143  auto isIdentifierChar = [](char c) {
144  return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
145  c == '-';
146  };
147  const char *curPtr = loc.getPointer();
148  while (isIdentifierChar(*curPtr))
149  ++curPtr;
150 
151  // Check to see if this location indexes into the result group, via `#`. If it
152  // doesn't, we can't extract a sub result number.
153  if (*curPtr != '#')
154  return std::nullopt;
155 
156  // Compute the sub result number from the remaining portion of the string.
157  const char *numberStart = ++curPtr;
158  while (llvm::isDigit(*curPtr))
159  ++curPtr;
160  StringRef numberStr(numberStart, curPtr - numberStart);
161  unsigned resultNumber = 0;
162  return numberStr.consumeInteger(10, resultNumber) ? std::optional<unsigned>()
163  : resultNumber;
164 }
165 
166 /// Given a source location range, return the text covered by the given range.
167 /// If the range is invalid, returns std::nullopt.
168 static std::optional<StringRef> getTextFromRange(SMRange range) {
169  if (!range.isValid())
170  return std::nullopt;
171  const char *startPtr = range.Start.getPointer();
172  return StringRef(startPtr, range.End.getPointer() - startPtr);
173 }
174 
175 /// Given a block, return its position in its parent region.
176 static unsigned getBlockNumber(Block *block) {
177  return std::distance(block->getParent()->begin(), block->getIterator());
178 }
179 
180 /// Given a block and source location, print the source name of the block to the
181 /// given output stream.
182 static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
183  // Try to extract a name from the source location.
184  std::optional<StringRef> text = getTextFromRange(loc);
185  if (text && text->starts_with("^")) {
186  os << *text;
187  return;
188  }
189 
190  // Otherwise, we don't have a name so print the block number.
191  os << "<Block #" << getBlockNumber(block) << ">";
192 }
193 static void printDefBlockName(raw_ostream &os,
194  const AsmParserState::BlockDefinition &def) {
195  printDefBlockName(os, def.block, def.definition.loc);
196 }
197 
198 /// Convert the given MLIR diagnostic to the LSP form.
199 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
200  Diagnostic &diag,
201  const lsp::URIForFile &uri) {
202  lsp::Diagnostic lspDiag;
203  lspDiag.source = "mlir";
204 
205  // Note: Right now all of the diagnostics are treated as parser issues, but
206  // some are parser and some are verifier.
207  lspDiag.category = "Parse Error";
208 
209  // Try to grab a file location for this diagnostic.
210  // TODO: For simplicity, we just grab the first one. It may be likely that we
211  // will need a more interesting heuristic here.'
212  StringRef uriScheme = uri.scheme();
213  std::optional<lsp::Location> lspLocation =
214  getLocationFromLoc(sourceMgr, diag.getLocation(), uriScheme, &uri);
215  if (lspLocation)
216  lspDiag.range = lspLocation->range;
217 
218  // Convert the severity for the diagnostic.
219  switch (diag.getSeverity()) {
221  llvm_unreachable("expected notes to be handled separately");
223  lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
224  break;
226  lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
227  break;
229  lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
230  break;
231  }
232  lspDiag.message = diag.str();
233 
234  // Attach any notes to the main diagnostic as related information.
235  std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
236  for (Diagnostic &note : diag.getNotes()) {
237  lsp::Location noteLoc;
238  if (std::optional<lsp::Location> loc =
239  getLocationFromLoc(sourceMgr, note.getLocation(), uriScheme))
240  noteLoc = *loc;
241  else
242  noteLoc.uri = uri;
243  relatedDiags.emplace_back(noteLoc, note.str());
244  }
245  if (!relatedDiags.empty())
246  lspDiag.relatedInformation = std::move(relatedDiags);
247 
248  return lspDiag;
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // MLIRDocument
253 //===----------------------------------------------------------------------===//
254 
255 namespace {
256 /// This class represents all of the information pertaining to a specific MLIR
257 /// document.
258 struct MLIRDocument {
259  MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
260  StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
261  MLIRDocument(const MLIRDocument &) = delete;
262  MLIRDocument &operator=(const MLIRDocument &) = delete;
263 
264  //===--------------------------------------------------------------------===//
265  // Definitions and References
266  //===--------------------------------------------------------------------===//
267 
268  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
269  std::vector<lsp::Location> &locations);
270  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
271  std::vector<lsp::Location> &references);
272 
273  //===--------------------------------------------------------------------===//
274  // Hover
275  //===--------------------------------------------------------------------===//
276 
277  std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
278  const lsp::Position &hoverPos);
279  std::optional<lsp::Hover>
280  buildHoverForOperation(SMRange hoverRange,
282  lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
283  unsigned resultStart,
284  unsigned resultEnd, SMLoc posLoc);
285  lsp::Hover buildHoverForBlock(SMRange hoverRange,
286  const AsmParserState::BlockDefinition &block);
287  lsp::Hover
288  buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
289  const AsmParserState::BlockDefinition &block);
290 
291  lsp::Hover buildHoverForAttributeAlias(
292  SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr);
293  lsp::Hover
294  buildHoverForTypeAlias(SMRange hoverRange,
296 
297  //===--------------------------------------------------------------------===//
298  // Document Symbols
299  //===--------------------------------------------------------------------===//
300 
301  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
302  void findDocumentSymbols(Operation *op,
303  std::vector<lsp::DocumentSymbol> &symbols);
304 
305  //===--------------------------------------------------------------------===//
306  // Code Completion
307  //===--------------------------------------------------------------------===//
308 
309  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
310  const lsp::Position &completePos,
311  const DialectRegistry &registry);
312 
313  //===--------------------------------------------------------------------===//
314  // Code Action
315  //===--------------------------------------------------------------------===//
316 
317  void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
318  lsp::Position &pos, StringRef severity,
319  StringRef message,
320  std::vector<llvm::lsp::TextEdit> &edits);
321 
322  //===--------------------------------------------------------------------===//
323  // Bytecode
324  //===--------------------------------------------------------------------===//
325 
327 
328  //===--------------------------------------------------------------------===//
329  // Fields
330  //===--------------------------------------------------------------------===//
331 
332  /// The high level parser state used to find definitions and references within
333  /// the source file.
334  AsmParserState asmState;
335 
336  /// The container for the IR parsed from the input file.
337  Block parsedIR;
338 
339  /// A collection of external resources, which we want to propagate up to the
340  /// user.
341  FallbackAsmResourceMap fallbackResourceMap;
342 
343  /// The source manager containing the contents of the input file.
344  llvm::SourceMgr sourceMgr;
345 };
346 } // namespace
347 
348 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
349  StringRef contents,
350  std::vector<lsp::Diagnostic> &diagnostics) {
351  ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
352  diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
353  });
354 
355  // Try to parsed the given IR string.
356  auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
357  if (!memBuffer) {
358  llvm::lsp::Logger::error("Failed to create memory buffer for file",
359  uri.file());
360  return;
361  }
362 
363  ParserConfig config(&context, /*verifyAfterParse=*/true,
364  &fallbackResourceMap);
365  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
366  if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
367  // If parsing failed, clear out any of the current state.
368  parsedIR.clear();
369  asmState = AsmParserState();
370  fallbackResourceMap = FallbackAsmResourceMap();
371  return;
372  }
373 }
374 
375 //===----------------------------------------------------------------------===//
376 // MLIRDocument: Definitions and References
377 //===----------------------------------------------------------------------===//
378 
379 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
380  const lsp::Position &defPos,
381  std::vector<lsp::Location> &locations) {
382  SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
383 
384  // Functor used to check if an SM definition contains the position.
385  auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
386  if (!isDefOrUse(def, posLoc))
387  return false;
388  locations.emplace_back(uri, sourceMgr, def.loc);
389  return true;
390  };
391 
392  // Check all definitions related to operations.
393  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
394  if (contains(op.loc, posLoc))
395  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
396  for (const auto &result : op.resultGroups)
397  if (containsPosition(result.definition))
398  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
399  for (const auto &symUse : op.symbolUses) {
400  if (contains(symUse, posLoc)) {
401  locations.emplace_back(uri, sourceMgr, op.loc);
402  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
403  }
404  }
405  }
406 
407  // Check all definitions related to blocks.
408  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
409  if (containsPosition(block.definition))
410  return;
411  for (const AsmParserState::SMDefinition &arg : block.arguments)
412  if (containsPosition(arg))
413  return;
414  }
415 
416  // Check all alias definitions.
418  asmState.getAttributeAliasDefs()) {
419  if (containsPosition(attr.definition))
420  return;
421  }
422  for (const AsmParserState::TypeAliasDefinition &type :
423  asmState.getTypeAliasDefs()) {
424  if (containsPosition(type.definition))
425  return;
426  }
427 }
428 
429 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
430  const lsp::Position &pos,
431  std::vector<lsp::Location> &references) {
432  // Functor used to append all of the definitions/uses of the given SM
433  // definition to the reference list.
434  auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
435  references.emplace_back(uri, sourceMgr, def.loc);
436  for (const SMRange &use : def.uses)
437  references.emplace_back(uri, sourceMgr, use);
438  };
439 
440  SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
441 
442  // Check all definitions related to operations.
443  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
444  if (contains(op.loc, posLoc)) {
445  for (const auto &result : op.resultGroups)
446  appendSMDef(result.definition);
447  for (const auto &symUse : op.symbolUses)
448  if (contains(symUse, posLoc))
449  references.emplace_back(uri, sourceMgr, symUse);
450  return;
451  }
452  for (const auto &result : op.resultGroups)
453  if (isDefOrUse(result.definition, posLoc))
454  return appendSMDef(result.definition);
455  for (const auto &symUse : op.symbolUses) {
456  if (!contains(symUse, posLoc))
457  continue;
458  for (const auto &symUse : op.symbolUses)
459  references.emplace_back(uri, sourceMgr, symUse);
460  return;
461  }
462  }
463 
464  // Check all definitions related to blocks.
465  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
466  if (isDefOrUse(block.definition, posLoc))
467  return appendSMDef(block.definition);
468 
469  for (const AsmParserState::SMDefinition &arg : block.arguments)
470  if (isDefOrUse(arg, posLoc))
471  return appendSMDef(arg);
472  }
473 
474  // Check all alias definitions.
476  asmState.getAttributeAliasDefs()) {
477  if (isDefOrUse(attr.definition, posLoc))
478  return appendSMDef(attr.definition);
479  }
480  for (const AsmParserState::TypeAliasDefinition &type :
481  asmState.getTypeAliasDefs()) {
482  if (isDefOrUse(type.definition, posLoc))
483  return appendSMDef(type.definition);
484  }
485 }
486 
487 //===----------------------------------------------------------------------===//
488 // MLIRDocument: Hover
489 //===----------------------------------------------------------------------===//
490 
491 std::optional<lsp::Hover>
492 MLIRDocument::findHover(const lsp::URIForFile &uri,
493  const lsp::Position &hoverPos) {
494  SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
495  SMRange hoverRange;
496 
497  // Check for Hovers on operations and results.
498  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
499  // Check if the position points at this operation.
500  if (contains(op.loc, posLoc))
501  return buildHoverForOperation(op.loc, op);
502 
503  // Check if the position points at the symbol name.
504  for (auto &use : op.symbolUses)
505  if (contains(use, posLoc))
506  return buildHoverForOperation(use, op);
507 
508  // Check if the position points at a result group.
509  for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
510  const auto &result = op.resultGroups[i];
511  if (!isDefOrUse(result.definition, posLoc, &hoverRange))
512  continue;
513 
514  // Get the range of results covered by the over position.
515  unsigned resultStart = result.startIndex;
516  unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
517  : op.resultGroups[i + 1].startIndex;
518  return buildHoverForOperationResult(hoverRange, op.op, resultStart,
519  resultEnd, posLoc);
520  }
521  }
522 
523  // Check to see if the hover is over a block argument.
524  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
525  if (isDefOrUse(block.definition, posLoc, &hoverRange))
526  return buildHoverForBlock(hoverRange, block);
527 
528  for (const auto &arg : llvm::enumerate(block.arguments)) {
529  if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
530  continue;
531 
532  return buildHoverForBlockArgument(
533  hoverRange, block.block->getArgument(arg.index()), block);
534  }
535  }
536 
537  // Check to see if the hover is over an alias.
539  asmState.getAttributeAliasDefs()) {
540  if (isDefOrUse(attr.definition, posLoc, &hoverRange))
541  return buildHoverForAttributeAlias(hoverRange, attr);
542  }
543  for (const AsmParserState::TypeAliasDefinition &type :
544  asmState.getTypeAliasDefs()) {
545  if (isDefOrUse(type.definition, posLoc, &hoverRange))
546  return buildHoverForTypeAlias(hoverRange, type);
547  }
548 
549  return std::nullopt;
550 }
551 
552 std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
553  SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
554  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
555  llvm::raw_string_ostream os(hover.contents.value);
556 
557  // Add the operation name to the hover.
558  os << "\"" << op.op->getName() << "\"";
559  if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
560  os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
561  os << "\n\n";
562 
563  os << "Generic Form:\n\n```mlir\n";
564 
565  op.op->print(os, OpPrintingFlags()
566  .printGenericOpForm()
567  .elideLargeElementsAttrs()
568  .skipRegions());
569  os << "\n```\n";
570 
571  return hover;
572 }
573 
574 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
575  Operation *op,
576  unsigned resultStart,
577  unsigned resultEnd,
578  SMLoc posLoc) {
579  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
580  llvm::raw_string_ostream os(hover.contents.value);
581 
582  // Add the parent operation name to the hover.
583  os << "Operation: \"" << op->getName() << "\"\n\n";
584 
585  // Check to see if the location points to a specific result within the
586  // group.
587  if (std::optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
588  if ((resultStart + *resultNumber) < resultEnd) {
589  resultStart += *resultNumber;
590  resultEnd = resultStart + 1;
591  }
592  }
593 
594  // Add the range of results and their types to the hover info.
595  if ((resultStart + 1) == resultEnd) {
596  os << "Result #" << resultStart << "\n\n"
597  << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
598  } else {
599  os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
600  << "Types: ";
601  llvm::interleaveComma(
602  op->getResults().slice(resultStart, resultEnd), os,
603  [&](Value result) { os << "`" << result.getType() << "`"; });
604  }
605 
606  return hover;
607 }
608 
609 lsp::Hover
610 MLIRDocument::buildHoverForBlock(SMRange hoverRange,
611  const AsmParserState::BlockDefinition &block) {
612  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
613  llvm::raw_string_ostream os(hover.contents.value);
614 
615  // Print the given block to the hover output stream.
616  auto printBlockToHover = [&](Block *newBlock) {
617  if (const auto *def = asmState.getBlockDef(newBlock))
618  printDefBlockName(os, *def);
619  else
620  printDefBlockName(os, newBlock);
621  };
622 
623  // Display the parent operation, block number, predecessors, and successors.
624  os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
625  << "Block #" << getBlockNumber(block.block) << "\n\n";
626  if (!block.block->hasNoPredecessors()) {
627  os << "Predecessors: ";
628  llvm::interleaveComma(block.block->getPredecessors(), os,
629  printBlockToHover);
630  os << "\n\n";
631  }
632  if (!block.block->hasNoSuccessors()) {
633  os << "Successors: ";
634  llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
635  os << "\n\n";
636  }
637 
638  return hover;
639 }
640 
641 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
642  SMRange hoverRange, BlockArgument arg,
643  const AsmParserState::BlockDefinition &block) {
644  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
645  llvm::raw_string_ostream os(hover.contents.value);
646 
647  // Display the parent operation, block, the argument number, and the type.
648  os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
649  << "Block: ";
650  printDefBlockName(os, block);
651  os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
652  << "Type: `" << arg.getType() << "`\n\n";
653 
654  return hover;
655 }
656 
657 lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
658  SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr) {
659  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
660  llvm::raw_string_ostream os(hover.contents.value);
661 
662  os << "Attribute Alias: \"" << attr.name << "\n\n";
663  os << "Value: ```mlir\n" << attr.value << "\n```\n\n";
664 
665  return hover;
666 }
667 
668 lsp::Hover MLIRDocument::buildHoverForTypeAlias(
669  SMRange hoverRange, const AsmParserState::TypeAliasDefinition &type) {
670  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
671  llvm::raw_string_ostream os(hover.contents.value);
672 
673  os << "Type Alias: \"" << type.name << "\n\n";
674  os << "Value: ```mlir\n" << type.value << "\n```\n\n";
675 
676  return hover;
677 }
678 
679 //===----------------------------------------------------------------------===//
680 // MLIRDocument: Document Symbols
681 //===----------------------------------------------------------------------===//
682 
683 void MLIRDocument::findDocumentSymbols(
684  std::vector<lsp::DocumentSymbol> &symbols) {
685  for (Operation &op : parsedIR)
686  findDocumentSymbols(&op, symbols);
687 }
688 
689 void MLIRDocument::findDocumentSymbols(
690  Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
691  std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
692 
693  // Check for the source information of this operation.
694  if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
695  // If this operation defines a symbol, record it.
696  if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
697  symbols.emplace_back(symbol.getName(),
698  isa<FunctionOpInterface>(op)
699  ? llvm::lsp::SymbolKind::Function
700  : llvm::lsp::SymbolKind::Class,
701  lsp::Range(sourceMgr, def->scopeLoc),
702  lsp::Range(sourceMgr, def->loc));
703  childSymbols = &symbols.back().children;
704 
705  } else if (op->hasTrait<OpTrait::SymbolTable>()) {
706  // Otherwise, if this is a symbol table push an anonymous document symbol.
707  symbols.emplace_back("<" + op->getName().getStringRef() + ">",
708  llvm::lsp::SymbolKind::Namespace,
709  llvm::lsp::Range(sourceMgr, def->scopeLoc),
710  llvm::lsp::Range(sourceMgr, def->loc));
711  childSymbols = &symbols.back().children;
712  }
713  }
714 
715  // Recurse into the regions of this operation.
716  if (!op->getNumRegions())
717  return;
718  for (Region &region : op->getRegions())
719  for (Operation &childOp : region.getOps())
720  findDocumentSymbols(&childOp, *childSymbols);
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // MLIRDocument: Code Completion
725 //===----------------------------------------------------------------------===//
726 
727 namespace {
728 class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
729 public:
730  LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
731  MLIRContext *ctx)
732  : AsmParserCodeCompleteContext(completeLoc),
733  completionList(completionList), ctx(ctx) {}
734 
735  /// Signal code completion for a dialect name, with an optional prefix.
736  void completeDialectName(StringRef prefix) final {
737  for (StringRef dialect : ctx->getAvailableDialects()) {
738  llvm::lsp::CompletionItem item(prefix + dialect,
739  llvm::lsp::CompletionItemKind::Module,
740  /*sortText=*/"3");
741  item.detail = "dialect";
742  completionList.items.emplace_back(item);
743  }
744  }
746 
747  /// Signal code completion for an operation name within the given dialect.
748  void completeOperationName(StringRef dialectName) final {
749  Dialect *dialect = ctx->getOrLoadDialect(dialectName);
750  if (!dialect)
751  return;
752 
753  for (const auto &op : ctx->getRegisteredOperations()) {
754  if (&op.getDialect() != dialect)
755  continue;
756 
757  llvm::lsp::CompletionItem item(
758  op.getStringRef().drop_front(dialectName.size() + 1),
759  llvm::lsp::CompletionItemKind::Field,
760  /*sortText=*/"1");
761  item.detail = "operation";
762  completionList.items.emplace_back(item);
763  }
764  }
765 
766  /// Append the given SSA value as a code completion result for SSA value
767  /// completions.
768  void appendSSAValueCompletion(StringRef name, std::string typeData) final {
769  // Check if we need to insert the `%` or not.
770  bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';
771 
772  llvm::lsp::CompletionItem item(name,
773  llvm::lsp::CompletionItemKind::Variable);
774  if (stripPrefix)
775  item.insertText = name.drop_front(1).str();
776  item.detail = std::move(typeData);
777  completionList.items.emplace_back(item);
778  }
779 
780  /// Append the given block as a code completion result for block name
781  /// completions.
782  void appendBlockCompletion(StringRef name) final {
783  // Check if we need to insert the `^` or not.
784  bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';
785 
786  llvm::lsp::CompletionItem item(name, llvm::lsp::CompletionItemKind::Field);
787  if (stripPrefix)
788  item.insertText = name.drop_front(1).str();
789  completionList.items.emplace_back(item);
790  }
791 
792  /// Signal a completion for the given expected token.
793  void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
794  for (StringRef token : tokens) {
795  llvm::lsp::CompletionItem item(token,
796  llvm::lsp::CompletionItemKind::Keyword,
797  /*sortText=*/"0");
798  item.detail = optional ? "optional" : "";
799  completionList.items.emplace_back(item);
800  }
801  }
802 
803  /// Signal a completion for an attribute.
804  void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
805  appendSimpleCompletions({"affine_set", "affine_map", "dense",
806  "dense_resource", "false", "loc", "sparse", "true",
807  "unit"},
808  llvm::lsp::CompletionItemKind::Field,
809  /*sortText=*/"1");
810 
811  completeDialectName("#");
812  completeAliases(aliases, "#");
813  }
814  void completeDialectAttributeOrAlias(
815  const llvm::StringMap<Attribute> &aliases) override {
816  completeDialectName();
817  completeAliases(aliases);
818  }
819 
820  /// Signal a completion for a type.
821  void completeType(const llvm::StringMap<Type> &aliases) override {
822  // Handle the various builtin types.
823  appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
824  "bf16", "f16", "f32", "f64", "f80", "f128",
825  "index", "none"},
826  llvm::lsp::CompletionItemKind::Field,
827  /*sortText=*/"1");
828 
829  // Handle the builtin integer types.
830  for (StringRef type : {"i", "si", "ui"}) {
831  llvm::lsp::CompletionItem item(type + "<N>",
832  llvm::lsp::CompletionItemKind::Field,
833  /*sortText=*/"1");
834  item.insertText = type.str();
835  completionList.items.emplace_back(item);
836  }
837 
838  // Insert completions for dialect types and aliases.
839  completeDialectName("!");
840  completeAliases(aliases, "!");
841  }
842  void
843  completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
844  completeDialectName();
845  completeAliases(aliases);
846  }
847 
848  /// Add completion results for the given set of aliases.
849  template <typename T>
850  void completeAliases(const llvm::StringMap<T> &aliases,
851  StringRef prefix = "") {
852  for (const auto &alias : aliases) {
853  llvm::lsp::CompletionItem item(prefix + alias.getKey(),
854  llvm::lsp::CompletionItemKind::Field,
855  /*sortText=*/"2");
856  llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
857  completionList.items.emplace_back(item);
858  }
859  }
860 
861  /// Add a set of simple completions that all have the same kind.
862  void appendSimpleCompletions(ArrayRef<StringRef> completions,
863  llvm::lsp::CompletionItemKind kind,
864  StringRef sortText = "") {
865  for (StringRef completion : completions)
866  completionList.items.emplace_back(completion, kind, sortText);
867  }
868 
869 private:
870  lsp::CompletionList &completionList;
871  MLIRContext *ctx;
872 };
873 } // namespace
874 
875 lsp::CompletionList
876 MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
877  const lsp::Position &completePos,
878  const DialectRegistry &registry) {
879  SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
880  if (!posLoc.isValid())
881  return lsp::CompletionList();
882 
883  // To perform code completion, we run another parse of the module with the
884  // code completion context provided.
885  MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
886  tmpContext.allowUnregisteredDialects();
887  lsp::CompletionList completionList;
888  LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
889  &tmpContext);
890 
891  Block tmpIR;
892  AsmParserState tmpState;
893  (void)parseAsmSourceFile(sourceMgr, &tmpIR, &tmpContext, &tmpState,
894  &lspCompleteContext);
895  return completionList;
896 }
897 
898 //===----------------------------------------------------------------------===//
899 // MLIRDocument: Code Action
900 //===----------------------------------------------------------------------===//
901 
902 void MLIRDocument::getCodeActionForDiagnostic(
903  const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
904  StringRef message, std::vector<llvm::lsp::TextEdit> &edits) {
905  // Ignore diagnostics that print the current operation. These are always
906  // enabled for the language server, but not generally during normal
907  // parsing/verification.
908  if (message.starts_with("see current operation: "))
909  return;
910 
911  // Get the start of the line containing the diagnostic.
912  const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
913  const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
914  if (!lineStart)
915  return;
916  StringRef line(lineStart, pos.character);
917 
918  // Add a text edit for adding an expected-* diagnostic check for this
919  // diagnostic.
920  llvm::lsp::TextEdit edit;
921  edit.range = lsp::Range(lsp::Position(pos.line, 0));
922 
923  // Use the indent of the current line for the expected-* diagnostic.
924  size_t indent = line.find_first_not_of(' ');
925  if (indent == StringRef::npos)
926  indent = line.size();
927 
928  edit.newText.append(indent, ' ');
929  llvm::raw_string_ostream(edit.newText)
930  << "// expected-" << severity << " @below {{" << message << "}}\n";
931  edits.emplace_back(std::move(edit));
932 }
933 
934 //===----------------------------------------------------------------------===//
935 // MLIRDocument: Bytecode
936 //===----------------------------------------------------------------------===//
937 
939 MLIRDocument::convertToBytecode() {
940  // TODO: We currently require a single top-level operation, but this could
941  // conceptually be relaxed.
942  if (!llvm::hasSingleElement(parsedIR)) {
943  if (parsedIR.empty()) {
944  return llvm::make_error<llvm::lsp::LSPError>(
945  "expected a single and valid top-level operation, please ensure "
946  "there are no errors",
947  llvm::lsp::ErrorCode::RequestFailed);
948  }
949  return llvm::make_error<llvm::lsp::LSPError>(
950  "expected a single top-level operation",
951  llvm::lsp::ErrorCode::RequestFailed);
952  }
953 
954  lsp::MLIRConvertBytecodeResult result;
955  {
956  BytecodeWriterConfig writerConfig(fallbackResourceMap);
957 
958  std::string rawBytecodeBuffer;
959  llvm::raw_string_ostream os(rawBytecodeBuffer);
960  // No desired bytecode version set, so no need to check for error.
961  (void)writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
962  result.output = llvm::encodeBase64(rawBytecodeBuffer);
963  }
964  return result;
965 }
966 
967 //===----------------------------------------------------------------------===//
968 // MLIRTextFileChunk
969 //===----------------------------------------------------------------------===//
970 
971 namespace {
972 /// This class represents a single chunk of an MLIR text file.
973 struct MLIRTextFileChunk {
974  MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
975  const lsp::URIForFile &uri, StringRef contents,
976  std::vector<lsp::Diagnostic> &diagnostics)
977  : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
978 
979  /// Adjust the line number of the given range to anchor at the beginning of
980  /// the file, instead of the beginning of this chunk.
981  void adjustLocForChunkOffset(lsp::Range &range) {
982  adjustLocForChunkOffset(range.start);
983  adjustLocForChunkOffset(range.end);
984  }
985  /// Adjust the line number of the given position to anchor at the beginning of
986  /// the file, instead of the beginning of this chunk.
987  void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
988 
989  /// The line offset of this chunk from the beginning of the file.
990  uint64_t lineOffset;
991  /// The document referred to by this chunk.
992  MLIRDocument document;
993 };
994 } // namespace
995 
996 //===----------------------------------------------------------------------===//
997 // MLIRTextFile
998 //===----------------------------------------------------------------------===//
999 
1000 namespace {
1001 /// This class represents a text file containing one or more MLIR documents.
1002 class MLIRTextFile {
1003 public:
1004  MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1005  int64_t version, lsp::DialectRegistryFn registry_fn,
1006  std::vector<lsp::Diagnostic> &diagnostics);
1007 
1008  /// Return the current version of this text file.
1009  int64_t getVersion() const { return version; }
1010 
1011  //===--------------------------------------------------------------------===//
1012  // LSP Queries
1013  //===--------------------------------------------------------------------===//
1014 
1015  void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1016  std::vector<lsp::Location> &locations);
1017  void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1018  std::vector<lsp::Location> &references);
1019  std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1020  lsp::Position hoverPos);
1021  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1022  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1023  lsp::Position completePos);
1024  void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos,
1025  const lsp::CodeActionContext &context,
1026  std::vector<lsp::CodeAction> &actions);
1028 
1029 private:
1030  /// Find the MLIR document that contains the given position, and update the
1031  /// position to be anchored at the start of the found chunk instead of the
1032  /// beginning of the file.
1033  MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
1034 
1035  /// The context used to hold the state contained by the parsed document.
1036  MLIRContext context;
1037 
1038  /// The full string contents of the file.
1039  std::string contents;
1040 
1041  /// The version of this file.
1042  int64_t version;
1043 
1044  /// The number of lines in the file.
1045  int64_t totalNumLines = 0;
1046 
1047  /// The chunks of this file. The order of these chunks is the order in which
1048  /// they appear in the text file.
1049  std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
1050 };
1051 } // namespace
1052 
1053 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1054  int64_t version, lsp::DialectRegistryFn registry_fn,
1055  std::vector<lsp::Diagnostic> &diagnostics)
1056  : context(registry_fn(uri), MLIRContext::Threading::DISABLED),
1057  contents(fileContents.str()), version(version) {
1058  context.allowUnregisteredDialects();
1059 
1060  // Split the file into separate MLIR documents.
1061  SmallVector<StringRef, 8> subContents;
1062  StringRef(contents).split(subContents, kDefaultSplitMarker);
1063  chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
1064  context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
1065 
1066  uint64_t lineOffset = subContents.front().count('\n');
1067  for (StringRef docContents : llvm::drop_begin(subContents)) {
1068  unsigned currentNumDiags = diagnostics.size();
1069  auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
1070  docContents, diagnostics);
1071  lineOffset += docContents.count('\n');
1072 
1073  // Adjust locations used in diagnostics to account for the offset from the
1074  // beginning of the file.
1075  for (lsp::Diagnostic &diag :
1076  llvm::drop_begin(diagnostics, currentNumDiags)) {
1077  chunk->adjustLocForChunkOffset(diag.range);
1078 
1079  if (!diag.relatedInformation)
1080  continue;
1081  for (auto &it : *diag.relatedInformation)
1082  if (it.location.uri == uri)
1083  chunk->adjustLocForChunkOffset(it.location.range);
1084  }
1085  chunks.emplace_back(std::move(chunk));
1086  }
1087  totalNumLines = lineOffset;
1088 }
1089 
1090 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
1091  lsp::Position defPos,
1092  std::vector<lsp::Location> &locations) {
1093  MLIRTextFileChunk &chunk = getChunkFor(defPos);
1094  chunk.document.getLocationsOf(uri, defPos, locations);
1095 
1096  // Adjust any locations within this file for the offset of this chunk.
1097  if (chunk.lineOffset == 0)
1098  return;
1099  for (lsp::Location &loc : locations)
1100  if (loc.uri == uri)
1101  chunk.adjustLocForChunkOffset(loc.range);
1102 }
1103 
1104 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
1105  lsp::Position pos,
1106  std::vector<lsp::Location> &references) {
1107  MLIRTextFileChunk &chunk = getChunkFor(pos);
1108  chunk.document.findReferencesOf(uri, pos, references);
1109 
1110  // Adjust any locations within this file for the offset of this chunk.
1111  if (chunk.lineOffset == 0)
1112  return;
1113  for (lsp::Location &loc : references)
1114  if (loc.uri == uri)
1115  chunk.adjustLocForChunkOffset(loc.range);
1116 }
1117 
1118 std::optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
1119  lsp::Position hoverPos) {
1120  MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
1121  std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1122 
1123  // Adjust any locations within this file for the offset of this chunk.
1124  if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1125  chunk.adjustLocForChunkOffset(*hoverInfo->range);
1126  return hoverInfo;
1127 }
1128 
1129 void MLIRTextFile::findDocumentSymbols(
1130  std::vector<lsp::DocumentSymbol> &symbols) {
1131  if (chunks.size() == 1)
1132  return chunks.front()->document.findDocumentSymbols(symbols);
1133 
1134  // If there are multiple chunks in this file, we create top-level symbols for
1135  // each chunk.
1136  for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1137  MLIRTextFileChunk &chunk = *chunks[i];
1138  lsp::Position startPos(chunk.lineOffset);
1139  lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1140  : chunks[i + 1]->lineOffset);
1141  lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1142  llvm::lsp::SymbolKind::Namespace,
1143  /*range=*/lsp::Range(startPos, endPos),
1144  /*selectionRange=*/lsp::Range(startPos));
1145  chunk.document.findDocumentSymbols(symbol.children);
1146 
1147  // Fixup the locations of document symbols within this chunk.
1148  if (i != 0) {
1150  for (lsp::DocumentSymbol &childSymbol : symbol.children)
1151  symbolsToFix.push_back(&childSymbol);
1152 
1153  while (!symbolsToFix.empty()) {
1154  lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1155  chunk.adjustLocForChunkOffset(symbol->range);
1156  chunk.adjustLocForChunkOffset(symbol->selectionRange);
1157 
1158  for (lsp::DocumentSymbol &childSymbol : symbol->children)
1159  symbolsToFix.push_back(&childSymbol);
1160  }
1161  }
1162 
1163  // Push the symbol for this chunk.
1164  symbols.emplace_back(std::move(symbol));
1165  }
1166 }
1167 
1168 lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1169  lsp::Position completePos) {
1170  MLIRTextFileChunk &chunk = getChunkFor(completePos);
1171  lsp::CompletionList completionList = chunk.document.getCodeCompletion(
1172  uri, completePos, context.getDialectRegistry());
1173 
1174  // Adjust any completion locations.
1175  for (llvm::lsp::CompletionItem &item : completionList.items) {
1176  if (item.textEdit)
1177  chunk.adjustLocForChunkOffset(item.textEdit->range);
1178  for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
1179  chunk.adjustLocForChunkOffset(edit.range);
1180  }
1181  return completionList;
1182 }
1183 
1184 void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
1185  const lsp::Range &pos,
1186  const lsp::CodeActionContext &context,
1187  std::vector<lsp::CodeAction> &actions) {
1188  // Create actions for any diagnostics in this file.
1189  for (auto &diag : context.diagnostics) {
1190  if (diag.source != "mlir")
1191  continue;
1192  lsp::Position diagPos = diag.range.start;
1193  MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1194 
1195  // Add a new code action that inserts a "expected" diagnostic check.
1196  lsp::CodeAction action;
1197  action.title = "Add expected-* diagnostic checks";
1198  action.kind = lsp::CodeAction::kQuickFix.str();
1199 
1200  StringRef severity;
1201  switch (diag.severity) {
1203  severity = "error";
1204  break;
1205  case llvm::lsp::DiagnosticSeverity::Warning:
1206  severity = "warning";
1207  break;
1208  default:
1209  continue;
1210  }
1211 
1212  // Get edits for the diagnostic.
1213  std::vector<llvm::lsp::TextEdit> edits;
1214  chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1215  diag.message, edits);
1216 
1217  // Walk the related diagnostics, this is how we encode notes.
1218  if (diag.relatedInformation) {
1219  for (auto &noteDiag : *diag.relatedInformation) {
1220  if (noteDiag.location.uri != uri)
1221  continue;
1222  diagPos = noteDiag.location.range.start;
1223  diagPos.line -= chunk.lineOffset;
1224  chunk.document.getCodeActionForDiagnostic(uri, diagPos, "note",
1225  noteDiag.message, edits);
1226  }
1227  }
1228  // Fixup the locations for any edits.
1229  for (llvm::lsp::TextEdit &edit : edits)
1230  chunk.adjustLocForChunkOffset(edit.range);
1231 
1232  action.edit.emplace();
1233  action.edit->changes[uri.uri().str()] = std::move(edits);
1234  action.diagnostics = {diag};
1235 
1236  actions.emplace_back(std::move(action));
1237  }
1238 }
1239 
1241 MLIRTextFile::convertToBytecode() {
1242  // Bail out if there is more than one chunk, bytecode wants a single module.
1243  if (chunks.size() != 1) {
1244  return llvm::make_error<llvm::lsp::LSPError>(
1245  "unexpected split file, please remove all `// -----`",
1246  llvm::lsp::ErrorCode::RequestFailed);
1247  }
1248  return chunks.front()->document.convertToBytecode();
1249 }
1250 
1251 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1252  if (chunks.size() == 1)
1253  return *chunks.front();
1254 
1255  // Search for the first chunk with a greater line offset, the previous chunk
1256  // is the one that contains `pos`.
1257  auto it = llvm::upper_bound(
1258  chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1259  return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1260  });
1261  MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1262  pos.line -= chunk.lineOffset;
1263  return chunk;
1264 }
1265 
1266 //===----------------------------------------------------------------------===//
1267 // MLIRServer::Impl
1268 //===----------------------------------------------------------------------===//
1269 
1272 
1273  /// The registry factory for containing dialects that can be recognized in
1274  /// parsed .mlir files.
1276 
1277  /// The files held by the server, mapped by their URI file name.
1278  llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
1279 };
1280 
1281 //===----------------------------------------------------------------------===//
1282 // MLIRServer
1283 //===----------------------------------------------------------------------===//
1284 
1285 lsp::MLIRServer::MLIRServer(lsp::DialectRegistryFn registry_fn)
1286  : impl(std::make_unique<Impl>(registry_fn)) {}
1287 lsp::MLIRServer::~MLIRServer() = default;
1288 
1290  const URIForFile &uri, StringRef contents, int64_t version,
1291  std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1292  impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1293  uri, contents, version, impl->registry_fn, diagnostics);
1294 }
1295 
1296 std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
1297  auto it = impl->files.find(uri.file());
1298  if (it == impl->files.end())
1299  return std::nullopt;
1300 
1301  int64_t version = it->second->getVersion();
1302  impl->files.erase(it);
1303  return version;
1304 }
1305 
1307  const URIForFile &uri, const Position &defPos,
1308  std::vector<llvm::lsp::Location> &locations) {
1309  auto fileIt = impl->files.find(uri.file());
1310  if (fileIt != impl->files.end())
1311  fileIt->second->getLocationsOf(uri, defPos, locations);
1312 }
1313 
1315  const URIForFile &uri, const Position &pos,
1316  std::vector<llvm::lsp::Location> &references) {
1317  auto fileIt = impl->files.find(uri.file());
1318  if (fileIt != impl->files.end())
1319  fileIt->second->findReferencesOf(uri, pos, references);
1320 }
1321 
1322 std::optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
1323  const Position &hoverPos) {
1324  auto fileIt = impl->files.find(uri.file());
1325  if (fileIt != impl->files.end())
1326  return fileIt->second->findHover(uri, hoverPos);
1327  return std::nullopt;
1328 }
1329 
1331  const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1332  auto fileIt = impl->files.find(uri.file());
1333  if (fileIt != impl->files.end())
1334  fileIt->second->findDocumentSymbols(symbols);
1335 }
1336 
1337 lsp::CompletionList
1339  const Position &completePos) {
1340  auto fileIt = impl->files.find(uri.file());
1341  if (fileIt != impl->files.end())
1342  return fileIt->second->getCodeCompletion(uri, completePos);
1343  return CompletionList();
1344 }
1345 
1346 void lsp::MLIRServer::getCodeActions(const URIForFile &uri, const Range &pos,
1347  const CodeActionContext &context,
1348  std::vector<CodeAction> &actions) {
1349  auto fileIt = impl->files.find(uri.file());
1350  if (fileIt != impl->files.end())
1351  fileIt->second->getCodeActions(uri, pos, context, actions);
1352 }
1353 
1356  MLIRContext tempContext(impl->registry_fn(uri));
1357  tempContext.allowUnregisteredDialects();
1358 
1359  // Collect any errors during parsing.
1360  std::string errorMsg;
1361  ScopedDiagnosticHandler diagHandler(
1362  &tempContext,
1363  [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
1364 
1365  // Handling for external resources, which we want to propagate up to the user.
1366  FallbackAsmResourceMap fallbackResourceMap;
1367 
1368  // Setup the parser config.
1369  ParserConfig parserConfig(&tempContext, /*verifyAfterParse=*/true,
1370  &fallbackResourceMap);
1371 
1372  // Try to parse the given source file.
1373  Block parsedBlock;
1374  if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
1375  return llvm::make_error<llvm::lsp::LSPError>(
1376  "failed to parse bytecode source file: " + errorMsg,
1377  llvm::lsp::ErrorCode::RequestFailed);
1378  }
1379 
1380  // TODO: We currently expect a single top-level operation, but this could
1381  // conceptually be relaxed.
1382  if (!llvm::hasSingleElement(parsedBlock)) {
1383  return llvm::make_error<llvm::lsp::LSPError>(
1384  "expected bytecode to contain a single top-level operation",
1385  llvm::lsp::ErrorCode::RequestFailed);
1386  }
1387 
1388  // Print the module to a buffer.
1390  {
1391  // Extract the top-level op so that aliases get printed.
1392  // FIXME: We should be able to enable aliases without having to do this!
1393  OwningOpRef<Operation *> topOp = &parsedBlock.front();
1394  topOp->remove();
1395 
1396  AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
1397  /*locationMap=*/nullptr, &fallbackResourceMap);
1398 
1399  llvm::raw_string_ostream os(result.output);
1400  topOp->print(os, state);
1401  }
1402  return std::move(result);
1403 }
1404 
1406 lsp::MLIRServer::convertToBytecode(const URIForFile &uri) {
1407  auto fileIt = impl->files.find(uri.file());
1408  if (fileIt == impl->files.end()) {
1409  return llvm::make_error<llvm::lsp::LSPError>(
1410  "language server does not contain an entry for this source file",
1411  llvm::lsp::ErrorCode::RequestFailed);
1412  }
1413  return fileIt->second->convertToBytecode();
1414 }
static std::string toString(bytecode::Section::ID sectionID)
Stringify the given section ID.
union mlir::linalg::@1243::ArityGroupAndKind::Kind kind
static void collectLocationsFromLoc(Location loc, std::vector< lsp::Location > &locations, const lsp::URIForFile &uri)
Collect all of the locations from the given MLIR location that are not contained within the given URI...
Definition: MLIRServer.cpp:91
static unsigned getBlockNumber(Block *block)
Given a block, return its position in its parent region.
Definition: MLIRServer.cpp:176
static std::optional< lsp::Location > getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc)
Returns a language server location from the given MLIR file location.
Definition: MLIRServer.cpp:37
static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, SMRange *overlappedRange=nullptr)
Returns true if the given location is contained by the definition or one of the uses of the given SMD...
Definition: MLIRServer.cpp:119
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
Definition: MLIRServer.cpp:111
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, Diagnostic &diag, const lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
Definition: MLIRServer.cpp:199
static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc={})
Given a block and source location, print the source name of the block to the given output stream.
Definition: MLIRServer.cpp:182
static std::optional< unsigned > getResultNumberFromLoc(SMLoc loc)
Given a location pointing to a result, return the result number it refers to or std::nullopt if it re...
Definition: MLIRServer.cpp:141
static SMRange convertTokenLocToRange(SMLoc loc)
Returns the range of a lexical token given an SMLoc corresponding to the start of a token location.
Definition: MLIRServer.cpp:31
static std::optional< StringRef > getTextFromRange(SMRange range)
Given a source location range, return the text covered by the given range.
Definition: MLIRServer.cpp:168
@ Error
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class provides an abstract interface into the parser for hooking in code completion events.
Definition: CodeComplete.h:24
This class represents state from a parsed MLIR textual format string.
This class provides management for the lifetime of the state used when printing the IR.
Definition: AsmState.h:542
This class represents an argument of a Block.
Definition: Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition: Block.h:248
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
SuccessorRange getSuccessors()
Definition: Block.h:270
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:240
Operation & front()
Definition: Block.h:153
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:245
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
This class contains the configuration used for the bytecode writer.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A fallback map containing external resources not explicitly handled by another parser/printer.
Definition: AsmState.h:421
An instance of this location represents a tuple of file, line number, and column number.
Definition: Location.h:174
unsigned getLine() const
Definition: Location.cpp:173
StringAttr getFilename() const
Definition: Location.cpp:169
unsigned getColumn() const
Definition: Location.cpp:175
WalkResult walk(function_ref< WalkResult(Location)> walkFn)
Walk all of the locations nested directly under, and including, the current.
Definition: Location.cpp:124
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
void allowUnregisteredDialects(bool allow=true)
Enables creating operations in unregistered dialects.
Set of flags used to control the behavior of the various IR print methods (e.g.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
void print(raw_ostream &os, const OpPrintingFlags &flags={})
result_range getResults()
Definition: Operation.h:415
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition: OwningOpRef.h:29
This class represents a configuration for the MLIR assembly parser.
Definition: AsmState.h:469
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator begin()
Definition: Region.h:55
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Definition: Diagnostics.h:522
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
void addOrUpdateDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add or update the document, with the provided version, at the given URI.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
llvm::Expected< MLIRConvertBytecodeResult > convertFromBytecode(const URIForFile &uri)
Convert the given bytecode file to the textual format.
llvm::Expected< MLIRConvertBytecodeResult > convertToBytecode(const URIForFile &uri)
Convert the given textual file to the bytecode format.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
void getCodeActions(const URIForFile &uri, const Range &pos, const CodeActionContext &context, std::vector< CodeAction > &actions)
Get the set of code actions within the file.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
SMRange convertTokenLocToRange(SMLoc loc, StringRef identifierChars="")
Returns the range of a lexical token given an SMLoc corresponding to the start of a token location.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const char *const kDefaultSplitMarker
LogicalResult parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, AsmParserState *asmState=nullptr, AsmParserCodeCompleteContext *codeCompleteContext=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition: Parser.cpp:2918
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition: Parser.cpp:38
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config={})
Write the bytecode for the given operation to the provided output stream.
This class represents the result of converting between MLIR's bytecode and textual format.
Definition: Protocol.h:48
std::string output
The resultant output of the conversion.
Definition: Protocol.h:50
Impl(lsp::DialectRegistryFn registry_fn)
lsp::DialectRegistryFn registry_fn
The registry factory for containing dialects that can be recognized in parsed .mlir files.
llvm::StringMap< std::unique_ptr< MLIRTextFile > > files
The files held by the server, mapped by their URI file name.
This class represents the information for an attribute alias definition within the input file.
StringRef name
The name of the attribute alias.
Attribute value
The value of the alias.
This class represents the information for a block definition within the input file.
Block * block
The block representing this definition.
SMDefinition definition
The source location for the block, i.e.
This class represents the information for an operation definition within an input file.
Operation * op
The operation representing this definition.
This class represents a definition within the source manager, containing its defining location and l...
SmallVector< SMRange > uses
The source location of all uses of the definition.
SMRange loc
The source location of the definition.
This class represents the information for type definition within the input file.
StringRef name
The name of the type alias.