MLIR  16.0.0git
MLIRServer.cpp
Go to the documentation of this file.
1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MLIRServer.h"
10 #include "../lsp-server-support/Logging.h"
11 #include "../lsp-server-support/SourceMgrUtils.h"
12 #include "Protocol.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/Parser/Parser.h"
20 #include "llvm/Support/Base64.h"
21 #include "llvm/Support/SourceMgr.h"
22 
23 using namespace mlir;
24 
25 /// Returns a language server location from the given MLIR file location.
26 /// `uriScheme` is the scheme to use when building new uris.
27 static Optional<lsp::Location> getLocationFromLoc(StringRef uriScheme,
28  FileLineColLoc loc) {
30  lsp::URIForFile::fromFile(loc.getFilename(), uriScheme);
31  if (!sourceURI) {
32  lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
33  loc.getFilename(),
34  llvm::toString(sourceURI.takeError()));
35  return llvm::None;
36  }
37 
38  lsp::Position position;
39  position.line = loc.getLine() - 1;
40  position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
41  return lsp::Location{*sourceURI, lsp::Range(position)};
42 }
43 
/// Returns a language server location from the given MLIR location, or None if
/// one couldn't be created. `uriScheme` is the scheme to use when building new
/// uris. `uri` is an optional additional filter that, when present, is used to
/// filter sub locations that do not share the same uri.
///
/// The nested locations of `loc` are walked and the walk stops at the first
/// FileLineColLoc that converts to an LSP location and passes the `uri` filter.
/// The end of the returned range is widened past the start column.
/// NOTE(review): the return-type line of this declaration and the line that
/// opens the conditional defining `range` are missing from this copy of the
/// file; confirm against the upstream source.
getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
                   StringRef uriScheme, const lsp::URIForFile *uri = nullptr) {
  Optional<lsp::Location> location;
  loc->walk([&](Location nestedLoc) {
    // Only file/line/column locations can be mapped to a document position.
    FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
    if (!fileLoc)
      return WalkResult::advance();

    Optional<lsp::Location> sourceLoc = getLocationFromLoc(uriScheme, fileLoc);
    if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
      location = *sourceLoc;
      // This `loc` shadows the `Location loc` parameter: it is the position of
      // the file location inside the main buffer of `sourceMgr`.
      SMLoc loc = sourceMgr.FindLocForLineAndColumn(
          sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());

      // Use range of potential identifier starting at location, else length 1
      // range.
      location->range.end.character += 1;
        auto lineCol = sourceMgr.getLineAndColumn(range->End);
        location->range.end.character =
            std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
      }
      // First matching location wins; stop walking.
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  return location;
}
77 
78 /// Collect all of the locations from the given MLIR location that are not
79 /// contained within the given URI.
81  std::vector<lsp::Location> &locations,
82  const lsp::URIForFile &uri) {
83  SetVector<Location> visitedLocs;
84  loc->walk([&](Location nestedLoc) {
85  FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
86  if (!fileLoc || !visitedLocs.insert(nestedLoc))
87  return WalkResult::advance();
88 
89  Optional<lsp::Location> sourceLoc =
90  getLocationFromLoc(uri.scheme(), fileLoc);
91  if (sourceLoc && sourceLoc->uri != uri)
92  locations.push_back(*sourceLoc);
93  return WalkResult::advance();
94  });
95 }
96 
97 /// Returns true if the given range contains the given source location. Note
98 /// that this has slightly different behavior than SMRange because it is
99 /// inclusive of the end location.
100 static bool contains(SMRange range, SMLoc loc) {
101  return range.Start.getPointer() <= loc.getPointer() &&
102  loc.getPointer() <= range.End.getPointer();
103 }
104 
105 /// Returns true if the given location is contained by the definition or one of
106 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
107 /// the range within `def` that the provided `loc` overlapped with.
108 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
109  SMRange *overlappedRange = nullptr) {
110  // Check the main definition.
111  if (contains(def.loc, loc)) {
112  if (overlappedRange)
113  *overlappedRange = def.loc;
114  return true;
115  }
116 
117  // Check the uses.
118  const auto *useIt = llvm::find_if(
119  def.uses, [&](const SMRange &range) { return contains(range, loc); });
120  if (useIt != def.uses.end()) {
121  if (overlappedRange)
122  *overlappedRange = *useIt;
123  return true;
124  }
125  return false;
126 }
127 
128 /// Given a location pointing to a result, return the result number it refers
129 /// to or None if it refers to all of the results.
131  // Skip all of the identifier characters.
132  auto isIdentifierChar = [](char c) {
133  return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
134  c == '-';
135  };
136  const char *curPtr = loc.getPointer();
137  while (isIdentifierChar(*curPtr))
138  ++curPtr;
139 
140  // Check to see if this location indexes into the result group, via `#`. If it
141  // doesn't, we can't extract a sub result number.
142  if (*curPtr != '#')
143  return llvm::None;
144 
145  // Compute the sub result number from the remaining portion of the string.
146  const char *numberStart = ++curPtr;
147  while (llvm::isDigit(*curPtr))
148  ++curPtr;
149  StringRef numberStr(numberStart, curPtr - numberStart);
150  unsigned resultNumber = 0;
151  return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>()
152  : resultNumber;
153 }
154 
155 /// Given a source location range, return the text covered by the given range.
156 /// If the range is invalid, returns None.
157 static Optional<StringRef> getTextFromRange(SMRange range) {
158  if (!range.isValid())
159  return None;
160  const char *startPtr = range.Start.getPointer();
161  return StringRef(startPtr, range.End.getPointer() - startPtr);
162 }
163 
164 /// Given a block, return its position in its parent region.
165 static unsigned getBlockNumber(Block *block) {
166  return std::distance(block->getParent()->begin(), block->getIterator());
167 }
168 
169 /// Given a block and source location, print the source name of the block to the
170 /// given output stream.
171 static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
172  // Try to extract a name from the source location.
174  if (text && text->startswith("^")) {
175  os << *text;
176  return;
177  }
178 
179  // Otherwise, we don't have a name so print the block number.
180  os << "<Block #" << getBlockNumber(block) << ">";
181 }
182 static void printDefBlockName(raw_ostream &os,
183  const AsmParserState::BlockDefinition &def) {
184  printDefBlockName(os, def.block, def.definition.loc);
185 }
186 
/// Convert the given MLIR diagnostic to the LSP form.
///
/// The diagnostic range comes from the first nested file location of the
/// diagnostic that resolves to `uri` (left default if none does). The message
/// comes from `diag.str()`. Each attached note becomes an entry in
/// `relatedInformation`.
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
                                               Diagnostic &diag,
                                               const lsp::URIForFile &uri) {
  lsp::Diagnostic lspDiag;
  lspDiag.source = "mlir";

  // Note: Right now all of the diagnostics are treated as parser issues, but
  // some are parser and some are verifier.
  lspDiag.category = "Parse Error";

  // Try to grab a file location for this diagnostic.
  // TODO: For simplicity, we just grab the first one. It may be likely that we
  // will need a more interesting heuristic here.
  StringRef uriScheme = uri.scheme();
  Optional<lsp::Location> lspLocation =
      getLocationFromLoc(sourceMgr, diag.getLocation(), uriScheme, &uri);
  if (lspLocation)
    lspDiag.range = lspLocation->range;

  // Convert the severity for the diagnostic.
  // NOTE(review): the `case` labels and severity assignments of this switch are
  // missing from this copy of the file; only the unreachable branch (for notes)
  // and the `break` statements are present. Confirm against upstream.
  switch (diag.getSeverity()) {
    llvm_unreachable("expected notes to be handled separately");
    break;
    break;
    break;
  }
  lspDiag.message = diag.str();

  // Attach any notes to the main diagnostic as related information. Notes are
  // resolved without the `uri` filter, so they may point into other files. A
  // note whose location cannot be resolved is anchored to this document's uri.
  std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
  for (Diagnostic &note : diag.getNotes()) {
    lsp::Location noteLoc;
    if (Optional<lsp::Location> loc =
            getLocationFromLoc(sourceMgr, note.getLocation(), uriScheme))
      noteLoc = *loc;
    else
      noteLoc.uri = uri;
    relatedDiags.emplace_back(noteLoc, note.str());
  }
  if (!relatedDiags.empty())
    lspDiag.relatedInformation = std::move(relatedDiags);

  return lspDiag;
}
239 
240 //===----------------------------------------------------------------------===//
241 // MLIRDocument
242 //===----------------------------------------------------------------------===//
243 
namespace {
/// This class represents all of the information pertaining to a specific MLIR
/// document.
///
/// The constructor parses the document contents and reports diagnostics. The
/// other members answer language-server queries (definitions, references,
/// hover, document symbols, completion, code actions, bytecode conversion)
/// against the parsed state. Copying is disabled.
/// NOTE(review): a few declaration lines (part of the `buildHoverForOperation`
/// declaration and the bytecode conversion declaration) are missing from this
/// copy of the file; confirm against upstream.
struct MLIRDocument {
  MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
               StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
  MLIRDocument(const MLIRDocument &) = delete;
  MLIRDocument &operator=(const MLIRDocument &) = delete;

  //===--------------------------------------------------------------------===//
  // Definitions and References
  //===--------------------------------------------------------------------===//

  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
                      std::vector<lsp::Location> &locations);
  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
                        std::vector<lsp::Location> &references);

  //===--------------------------------------------------------------------===//
  // Hover
  //===--------------------------------------------------------------------===//

  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                 const lsp::Position &hoverPos);
  buildHoverForOperation(SMRange hoverRange,
  lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
                                          unsigned resultStart,
                                          unsigned resultEnd, SMLoc posLoc);
  lsp::Hover buildHoverForBlock(SMRange hoverRange,
                                const AsmParserState::BlockDefinition &block);
  lsp::Hover
  buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
                             const AsmParserState::BlockDefinition &block);

  //===--------------------------------------------------------------------===//
  // Document Symbols
  //===--------------------------------------------------------------------===//

  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
  void findDocumentSymbols(Operation *op,
                           std::vector<lsp::DocumentSymbol> &symbols);

  //===--------------------------------------------------------------------===//
  // Code Completion
  //===--------------------------------------------------------------------===//

  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
                                        const lsp::Position &completePos,
                                        const DialectRegistry &registry);

  //===--------------------------------------------------------------------===//
  // Code Action
  //===--------------------------------------------------------------------===//

  void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
                                  lsp::Position &pos, StringRef severity,
                                  StringRef message,
                                  std::vector<lsp::TextEdit> &edits);

  //===--------------------------------------------------------------------===//
  // Bytecode
  //===--------------------------------------------------------------------===//


  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The high level parser state used to find definitions and references within
  /// the source file. Reset to an empty state if parsing fails.
  AsmParserState asmState;

  /// The container for the IR parsed from the input file. Cleared if parsing
  /// fails.
  Block parsedIR;

  /// A collection of external resources, which we want to propagate up to the
  /// user. Reset if parsing fails.
  FallbackAsmResourceMap fallbackResourceMap;

  /// The source manager containing the contents of the input file.
  llvm::SourceMgr sourceMgr;
};
} // namespace
330 
331 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
332  StringRef contents,
333  std::vector<lsp::Diagnostic> &diagnostics) {
334  ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
335  diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
336  });
337 
338  // Try to parsed the given IR string.
339  auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
340  if (!memBuffer) {
341  lsp::Logger::error("Failed to create memory buffer for file", uri.file());
342  return;
343  }
344 
345  ParserConfig config(&context, /*verifyAfterParse=*/true,
346  &fallbackResourceMap);
347  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
348  if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
349  // If parsing failed, clear out any of the current state.
350  parsedIR.clear();
351  asmState = AsmParserState();
352  fallbackResourceMap = FallbackAsmResourceMap();
353  return;
354  }
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // MLIRDocument: Definitions and References
359 //===----------------------------------------------------------------------===//
360 
361 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
362  const lsp::Position &defPos,
363  std::vector<lsp::Location> &locations) {
364  SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
365 
366  // Functor used to check if an SM definition contains the position.
367  auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
368  if (!isDefOrUse(def, posLoc))
369  return false;
370  locations.emplace_back(uri, sourceMgr, def.loc);
371  return true;
372  };
373 
374  // Check all definitions related to operations.
375  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
376  if (contains(op.loc, posLoc))
377  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
378  for (const auto &result : op.resultGroups)
379  if (containsPosition(result.definition))
380  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
381  for (const auto &symUse : op.symbolUses) {
382  if (contains(symUse, posLoc)) {
383  locations.emplace_back(uri, sourceMgr, op.loc);
384  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
385  }
386  }
387  }
388 
389  // Check all definitions related to blocks.
390  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
391  if (containsPosition(block.definition))
392  return;
393  for (const AsmParserState::SMDefinition &arg : block.arguments)
394  if (containsPosition(arg))
395  return;
396  }
397 }
398 
399 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
400  const lsp::Position &pos,
401  std::vector<lsp::Location> &references) {
402  // Functor used to append all of the definitions/uses of the given SM
403  // definition to the reference list.
404  auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
405  references.emplace_back(uri, sourceMgr, def.loc);
406  for (const SMRange &use : def.uses)
407  references.emplace_back(uri, sourceMgr, use);
408  };
409 
410  SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
411 
412  // Check all definitions related to operations.
413  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
414  if (contains(op.loc, posLoc)) {
415  for (const auto &result : op.resultGroups)
416  appendSMDef(result.definition);
417  for (const auto &symUse : op.symbolUses)
418  if (contains(symUse, posLoc))
419  references.emplace_back(uri, sourceMgr, symUse);
420  return;
421  }
422  for (const auto &result : op.resultGroups)
423  if (isDefOrUse(result.definition, posLoc))
424  return appendSMDef(result.definition);
425  for (const auto &symUse : op.symbolUses) {
426  if (!contains(symUse, posLoc))
427  continue;
428  for (const auto &symUse : op.symbolUses)
429  references.emplace_back(uri, sourceMgr, symUse);
430  return;
431  }
432  }
433 
434  // Check all definitions related to blocks.
435  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
436  if (isDefOrUse(block.definition, posLoc))
437  return appendSMDef(block.definition);
438 
439  for (const AsmParserState::SMDefinition &arg : block.arguments)
440  if (isDefOrUse(arg, posLoc))
441  return appendSMDef(arg);
442  }
443 }
444 
445 //===----------------------------------------------------------------------===//
446 // MLIRDocument: Hover
447 //===----------------------------------------------------------------------===//
448 
449 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri,
450  const lsp::Position &hoverPos) {
451  SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
452  SMRange hoverRange;
453 
454  // Check for Hovers on operations and results.
455  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
456  // Check if the position points at this operation.
457  if (contains(op.loc, posLoc))
458  return buildHoverForOperation(op.loc, op);
459 
460  // Check if the position points at the symbol name.
461  for (auto &use : op.symbolUses)
462  if (contains(use, posLoc))
463  return buildHoverForOperation(use, op);
464 
465  // Check if the position points at a result group.
466  for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
467  const auto &result = op.resultGroups[i];
468  if (!isDefOrUse(result.definition, posLoc, &hoverRange))
469  continue;
470 
471  // Get the range of results covered by the over position.
472  unsigned resultStart = result.startIndex;
473  unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
474  : op.resultGroups[i + 1].startIndex;
475  return buildHoverForOperationResult(hoverRange, op.op, resultStart,
476  resultEnd, posLoc);
477  }
478  }
479 
480  // Check to see if the hover is over a block argument.
481  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
482  if (isDefOrUse(block.definition, posLoc, &hoverRange))
483  return buildHoverForBlock(hoverRange, block);
484 
485  for (const auto &arg : llvm::enumerate(block.arguments)) {
486  if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
487  continue;
488 
489  return buildHoverForBlockArgument(
490  hoverRange, block.block->getArgument(arg.index()), block);
491  }
492  }
493  return llvm::None;
494 }
495 
496 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
497  SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
498  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
499  llvm::raw_string_ostream os(hover.contents.value);
500 
501  // Add the operation name to the hover.
502  os << "\"" << op.op->getName() << "\"";
503  if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
504  os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
505  os << "\n\n";
506 
507  os << "Generic Form:\n\n```mlir\n";
508 
509  // Temporary drop the regions of this operation so that they don't get
510  // printed in the output. This helps keeps the size of the output hover
511  // small.
513  for (Region &region : op.op->getRegions()) {
514  regions.emplace_back(std::make_unique<Region>());
515  regions.back()->takeBody(region);
516  }
517 
518  op.op->print(
519  os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
520  os << "\n```\n";
521 
522  // Move the regions back to the current operation.
523  for (Region &region : op.op->getRegions())
524  region.takeBody(*regions.back());
525 
526  return hover;
527 }
528 
529 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
530  Operation *op,
531  unsigned resultStart,
532  unsigned resultEnd,
533  SMLoc posLoc) {
534  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
535  llvm::raw_string_ostream os(hover.contents.value);
536 
537  // Add the parent operation name to the hover.
538  os << "Operation: \"" << op->getName() << "\"\n\n";
539 
540  // Check to see if the location points to a specific result within the
541  // group.
542  if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
543  if ((resultStart + *resultNumber) < resultEnd) {
544  resultStart += *resultNumber;
545  resultEnd = resultStart + 1;
546  }
547  }
548 
549  // Add the range of results and their types to the hover info.
550  if ((resultStart + 1) == resultEnd) {
551  os << "Result #" << resultStart << "\n\n"
552  << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
553  } else {
554  os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
555  << "Types: ";
556  llvm::interleaveComma(
557  op->getResults().slice(resultStart, resultEnd), os,
558  [&](Value result) { os << "`" << result.getType() << "`"; });
559  }
560 
561  return hover;
562 }
563 
565 MLIRDocument::buildHoverForBlock(SMRange hoverRange,
566  const AsmParserState::BlockDefinition &block) {
567  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
568  llvm::raw_string_ostream os(hover.contents.value);
569 
570  // Print the given block to the hover output stream.
571  auto printBlockToHover = [&](Block *newBlock) {
572  if (const auto *def = asmState.getBlockDef(newBlock))
573  printDefBlockName(os, *def);
574  else
575  printDefBlockName(os, newBlock);
576  };
577 
578  // Display the parent operation, block number, predecessors, and successors.
579  os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
580  << "Block #" << getBlockNumber(block.block) << "\n\n";
581  if (!block.block->hasNoPredecessors()) {
582  os << "Predecessors: ";
583  llvm::interleaveComma(block.block->getPredecessors(), os,
584  printBlockToHover);
585  os << "\n\n";
586  }
587  if (!block.block->hasNoSuccessors()) {
588  os << "Successors: ";
589  llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
590  os << "\n\n";
591  }
592 
593  return hover;
594 }
595 
596 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
597  SMRange hoverRange, BlockArgument arg,
598  const AsmParserState::BlockDefinition &block) {
599  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
600  llvm::raw_string_ostream os(hover.contents.value);
601 
602  // Display the parent operation, block, the argument number, and the type.
603  os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
604  << "Block: ";
605  printDefBlockName(os, block);
606  os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
607  << "Type: `" << arg.getType() << "`\n\n";
608 
609  return hover;
610 }
611 
612 //===----------------------------------------------------------------------===//
613 // MLIRDocument: Document Symbols
614 //===----------------------------------------------------------------------===//
615 
616 void MLIRDocument::findDocumentSymbols(
617  std::vector<lsp::DocumentSymbol> &symbols) {
618  for (Operation &op : parsedIR)
619  findDocumentSymbols(&op, symbols);
620 }
621 
/// Recursively collects document symbols for `op` and the operations nested in
/// its regions.
///
/// If the parser recorded source information for `op`, then:
///  - an op implementing SymbolOpInterface adds a symbol named after it;
///  - otherwise, an op with the SymbolTable trait adds an anonymous symbol
///    named `<op-name>`.
/// In both cases, nested results go into the new symbol's `children`.
/// Otherwise, nested results are appended to `symbols` itself.
/// NOTE(review): the symbol-kind arguments of the two `emplace_back` calls are
/// missing from this copy of the file; confirm against upstream.
void MLIRDocument::findDocumentSymbols(
    Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
  // Where symbols of nested operations are recorded; redirected below when
  // this op itself produces a symbol.
  std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;

  // Check for the source information of this operation.
  if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
    // If this operation defines a symbol, record it.
    if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
      symbols.emplace_back(symbol.getName(),
                           isa<FunctionOpInterface>(op)
                           lsp::Range(sourceMgr, def->scopeLoc),
                           lsp::Range(sourceMgr, def->loc));
      childSymbols = &symbols.back().children;

    } else if (op->hasTrait<OpTrait::SymbolTable>()) {
      // Otherwise, if this is a symbol table push an anonymous document symbol.
      symbols.emplace_back("<" + op->getName().getStringRef() + ">",
                           lsp::Range(sourceMgr, def->scopeLoc),
                           lsp::Range(sourceMgr, def->loc));
      childSymbols = &symbols.back().children;
    }
  }

  // Recurse into the regions of this operation.
  if (!op->getNumRegions())
    return;
  for (Region &region : op->getRegions())
    for (Operation &childOp : region.getOps())
      findDocumentSymbols(&childOp, *childSymbols);
}
655 
656 //===----------------------------------------------------------------------===//
657 // MLIRDocument: Code Completion
658 //===----------------------------------------------------------------------===//
659 
namespace {
/// Adapter that receives code-completion callbacks from the MLIR asm parser
/// and appends LSP CompletionItems to `completionList`.
///
/// The sortText values used below order the results: expected tokens use "0",
/// operations and builtin attributes/types use "1", aliases use "2", and
/// dialect names use "3".
/// NOTE(review): several lines of this class (mostly completion item kinds and
/// a few declarations) are missing from this copy of the file; confirm against
/// upstream.
class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
public:
  LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
                         MLIRContext *ctx)
      : AsmParserCodeCompleteContext(completeLoc),
        completionList(completionList), ctx(ctx) {}

  /// Signal code completion for a dialect name, with an optional prefix.
  /// One item is added per dialect reported by the context as available.
  void completeDialectName(StringRef prefix) final {
    for (StringRef dialect : ctx->getAvailableDialects()) {
      lsp::CompletionItem item(prefix + dialect,
                               /*sortText=*/"3");
      item.detail = "dialect";
      completionList.items.emplace_back(item);
    }
  }

  /// Signal code completion for an operation name within the given dialect.
  /// The dialect is loaded if needed; the inserted name has the
  /// `<dialect>.` prefix dropped.
  void completeOperationName(StringRef dialectName) final {
    Dialect *dialect = ctx->getOrLoadDialect(dialectName);
    if (!dialect)
      return;

    for (const auto &op : ctx->getRegisteredOperations()) {
      if (&op.getDialect() != dialect)
        continue;

      lsp::CompletionItem item(
          op.getStringRef().drop_front(dialectName.size() + 1),
          /*sortText=*/"1");
      item.detail = "operation";
      completionList.items.emplace_back(item);
    }
  }

  /// Append the given SSA value as a code completion result for SSA value
  /// completions.
  void appendSSAValueCompletion(StringRef name, std::string typeData) final {
    // Check if we need to insert the `%` or not: if the character just before
    // the completion point is already `%`, insert the name without it.
    // NOTE(review): this reads one byte before the completion location; confirm
    // it can never be the first byte of the buffer.
    bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';

    if (stripPrefix)
      item.insertText = name.drop_front(1).str();
    item.detail = std::move(typeData);
    completionList.items.emplace_back(item);
  }

  /// Append the given block as a code completion result for block name
  /// completions.
  void appendBlockCompletion(StringRef name) final {
    // Check if we need to insert the `^` or not (same scheme as `%` above).
    bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';

    if (stripPrefix)
      item.insertText = name.drop_front(1).str();
    completionList.items.emplace_back(item);
  }

  /// Signal a completion for the given expected token. Optional tokens are
  /// marked with an "optional" detail.
  void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
    for (StringRef token : tokens) {
                               /*sortText=*/"0");
      item.detail = optional ? "optional" : "";
      completionList.items.emplace_back(item);
    }
  }

  /// Signal a completion for an attribute: builtin attribute keywords, then
  /// `#`-prefixed dialect names and attribute aliases.
  void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
    appendSimpleCompletions({"affine_set", "affine_map", "dense",
                             "dense_resource", "false", "loc", "sparse", "true",
                             "unit"},
                            /*sortText=*/"1");

    completeDialectName("#");
    completeAliases(aliases, "#");
  }
  /// Completion after the `#` sigil has already been typed: dialect names and
  /// aliases without a prefix.
  void completeDialectAttributeOrAlias(
      const llvm::StringMap<Attribute> &aliases) override {
    completeDialectName();
    completeAliases(aliases);
  }

  /// Signal a completion for a type.
  void completeType(const llvm::StringMap<Type> &aliases) override {
    // Handle the various builtin types.
    appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
                             "bf16", "f16", "f32", "f64", "f80", "f128",
                             "index", "none"},
                            /*sortText=*/"1");

    // Handle the builtin integer types.
    for (StringRef type : {"i", "si", "ui"}) {
                               /*sortText=*/"1");
      item.insertText = type.str();
      completionList.items.emplace_back(item);
    }

    // Insert completions for dialect types and aliases.
    completeDialectName("!");
    completeAliases(aliases, "!");
  }
  /// Completion after the `!` sigil has already been typed: dialect names and
  /// aliases without a prefix.
  void
  completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
    completeDialectName();
    completeAliases(aliases);
  }

  /// Add completion results for the given set of aliases. Each item's detail is
  /// "alias: " followed by the printed aliased value.
  template <typename T>
  void completeAliases(const llvm::StringMap<T> &aliases,
                       StringRef prefix = "") {
    for (const auto &alias : aliases) {
      lsp::CompletionItem item(prefix + alias.getKey(),
                               /*sortText=*/"2");
      llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
      completionList.items.emplace_back(item);
    }
  }

  /// Add a set of simple completions that all have the same kind.
  void appendSimpleCompletions(ArrayRef<StringRef> completions,
                               StringRef sortText = "") {
    for (StringRef completion : completions)
      completionList.items.emplace_back(completion, kind, sortText);
  }

private:
  /// The list that completion items are appended to (not owned).
  lsp::CompletionList &completionList;
  /// The context used to query dialects and registered operations (not owned).
  MLIRContext *ctx;
};
} // namespace
804 
806 MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
807  const lsp::Position &completePos,
808  const DialectRegistry &registry) {
809  SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
810  if (!posLoc.isValid())
811  return lsp::CompletionList();
812 
813  // To perform code completion, we run another parse of the module with the
814  // code completion context provided.
815  MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
816  tmpContext.allowUnregisteredDialects();
817  lsp::CompletionList completionList;
818  LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
819  &tmpContext);
820 
821  Block tmpIR;
822  AsmParserState tmpState;
823  (void)parseAsmSourceFile(sourceMgr, &tmpIR, &tmpContext, &tmpState,
824  &lspCompleteContext);
825  return completionList;
826 }
827 
828 //===----------------------------------------------------------------------===//
829 // MLIRDocument: Code Action
830 //===----------------------------------------------------------------------===//
831 
832 void MLIRDocument::getCodeActionForDiagnostic(
833  const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
834  StringRef message, std::vector<lsp::TextEdit> &edits) {
835  // Ignore diagnostics that print the current operation. These are always
836  // enabled for the language server, but not generally during normal
837  // parsing/verification.
838  if (message.startswith("see current operation: "))
839  return;
840 
841  // Get the start of the line containing the diagnostic.
842  const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
843  const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
844  if (!lineStart)
845  return;
846  StringRef line(lineStart, pos.character);
847 
848  // Add a text edit for adding an expected-* diagnostic check for this
849  // diagnostic.
850  lsp::TextEdit edit;
851  edit.range = lsp::Range(lsp::Position(pos.line, 0));
852 
853  // Use the indent of the current line for the expected-* diagnostic.
854  size_t indent = line.find_first_not_of(" ");
855  if (indent == StringRef::npos)
856  indent = line.size();
857 
858  edit.newText.append(indent, ' ');
859  llvm::raw_string_ostream(edit.newText)
860  << "// expected-" << severity << " @below {{" << message << "}}\n";
861  edits.emplace_back(std::move(edit));
862 }
863 
864 //===----------------------------------------------------------------------===//
865 // MLIRDocument: Bytecode
866 //===----------------------------------------------------------------------===//
867 
/// Serializes the document's single top-level operation to MLIR bytecode and
/// stores the base64-encoded bytes in `result.output`.
///
/// Returns an LSPError if the parsed IR does not contain exactly one top-level
/// operation. An empty `parsedIR` (which is what remains after a failed parse)
/// gets a message asking the user to fix errors first. The writer is given
/// `fallbackResourceMap`, presumably so that external resources collected
/// during parsing are emitted as well (confirm against BytecodeWriterConfig).
/// NOTE(review): the return-type line, the error code of the first error, and
/// the declaration of `result` are missing from this copy of the file.
MLIRDocument::convertToBytecode() {
  // TODO: We currently require a single top-level operation, but this could
  // conceptually be relaxed.
  if (!llvm::hasSingleElement(parsedIR)) {
    if (parsedIR.empty()) {
      return llvm::make_error<lsp::LSPError>(
          "expected a single and valid top-level operation, please ensure "
          "there are no errors",
    }
    return llvm::make_error<lsp::LSPError>(
        "expected a single top-level operation", lsp::ErrorCode::RequestFailed);
  }

  {
    BytecodeWriterConfig writerConfig(fallbackResourceMap);

    // Write the raw bytecode into a string, then base64-encode it for the
    // response.
    std::string rawBytecodeBuffer;
    llvm::raw_string_ostream os(rawBytecodeBuffer);
    writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
    result.output = llvm::encodeBase64(rawBytecodeBuffer);
  }
  return result;
}
894 
895 //===----------------------------------------------------------------------===//
896 // MLIRTextFileChunk
897 //===----------------------------------------------------------------------===//
898 
899 namespace {
900 /// This class represents a single chunk of an MLIR text file.
901 struct MLIRTextFileChunk {
902  MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
903  const lsp::URIForFile &uri, StringRef contents,
904  std::vector<lsp::Diagnostic> &diagnostics)
905  : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
906 
907  /// Adjust the line number of the given range to anchor at the beginning of
908  /// the file, instead of the beginning of this chunk.
909  void adjustLocForChunkOffset(lsp::Range &range) {
910  adjustLocForChunkOffset(range.start);
911  adjustLocForChunkOffset(range.end);
912  }
913  /// Adjust the line number of the given position to anchor at the beginning of
914  /// the file, instead of the beginning of this chunk.
915  void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
916 
917  /// The line offset of this chunk from the beginning of the file.
918  uint64_t lineOffset;
919  /// The document referred to by this chunk.
920  MLIRDocument document;
921 };
922 } // namespace
923 
924 //===----------------------------------------------------------------------===//
925 // MLIRTextFile
926 //===----------------------------------------------------------------------===//
927 
928 namespace {
929 /// This class represents a text file containing one or more MLIR documents.
930 class MLIRTextFile {
931 public:
932  MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
933  int64_t version, DialectRegistry &registry,
934  std::vector<lsp::Diagnostic> &diagnostics);
935 
936  /// Return the current version of this text file.
937  int64_t getVersion() const { return version; }
938 
939  //===--------------------------------------------------------------------===//
940  // LSP Queries
941  //===--------------------------------------------------------------------===//
942 
943  void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
944  std::vector<lsp::Location> &locations);
945  void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
946  std::vector<lsp::Location> &references);
947  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
948  lsp::Position hoverPos);
949  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
950  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
951  lsp::Position completePos);
952  void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos,
953  const lsp::CodeActionContext &context,
954  std::vector<lsp::CodeAction> &actions);
956 
957 private:
958  /// Find the MLIR document that contains the given position, and update the
959  /// position to be anchored at the start of the found chunk instead of the
960  /// beginning of the file.
961  MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
962 
963  /// The context used to hold the state contained by the parsed document.
964  MLIRContext context;
965 
966  /// The full string contents of the file.
967  std::string contents;
968 
969  /// The version of this file.
970  int64_t version;
971 
972  /// The number of lines in the file.
973  int64_t totalNumLines = 0;
974 
975  /// The chunks of this file. The order of these chunks is the order in which
976  /// they appear in the text file.
977  std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
978 };
979 } // namespace
980 
981 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
982  int64_t version, DialectRegistry &registry,
983  std::vector<lsp::Diagnostic> &diagnostics)
984  : context(registry, MLIRContext::Threading::DISABLED),
985  contents(fileContents.str()), version(version) {
986  context.allowUnregisteredDialects();
987 
988  // Split the file into separate MLIR documents.
989  // TODO: Find a way to share the split file marker with other tools. We don't
990  // want to use `splitAndProcessBuffer` here, but we do want to make sure this
991  // marker doesn't go out of sync.
992  SmallVector<StringRef, 8> subContents;
993  StringRef(contents).split(subContents, "// -----");
994  chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
995  context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
996 
997  uint64_t lineOffset = subContents.front().count('\n');
998  for (StringRef docContents : llvm::drop_begin(subContents)) {
999  unsigned currentNumDiags = diagnostics.size();
1000  auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
1001  docContents, diagnostics);
1002  lineOffset += docContents.count('\n');
1003 
1004  // Adjust locations used in diagnostics to account for the offset from the
1005  // beginning of the file.
1006  for (lsp::Diagnostic &diag :
1007  llvm::drop_begin(diagnostics, currentNumDiags)) {
1008  chunk->adjustLocForChunkOffset(diag.range);
1009 
1010  if (!diag.relatedInformation)
1011  continue;
1012  for (auto &it : *diag.relatedInformation)
1013  if (it.location.uri == uri)
1014  chunk->adjustLocForChunkOffset(it.location.range);
1015  }
1016  chunks.emplace_back(std::move(chunk));
1017  }
1018  totalNumLines = lineOffset;
1019 }
1020 
1021 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
1022  lsp::Position defPos,
1023  std::vector<lsp::Location> &locations) {
1024  MLIRTextFileChunk &chunk = getChunkFor(defPos);
1025  chunk.document.getLocationsOf(uri, defPos, locations);
1026 
1027  // Adjust any locations within this file for the offset of this chunk.
1028  if (chunk.lineOffset == 0)
1029  return;
1030  for (lsp::Location &loc : locations)
1031  if (loc.uri == uri)
1032  chunk.adjustLocForChunkOffset(loc.range);
1033 }
1034 
1035 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
1036  lsp::Position pos,
1037  std::vector<lsp::Location> &references) {
1038  MLIRTextFileChunk &chunk = getChunkFor(pos);
1039  chunk.document.findReferencesOf(uri, pos, references);
1040 
1041  // Adjust any locations within this file for the offset of this chunk.
1042  if (chunk.lineOffset == 0)
1043  return;
1044  for (lsp::Location &loc : references)
1045  if (loc.uri == uri)
1046  chunk.adjustLocForChunkOffset(loc.range);
1047 }
1048 
1049 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
1050  lsp::Position hoverPos) {
1051  MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
1052  Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1053 
1054  // Adjust any locations within this file for the offset of this chunk.
1055  if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1056  chunk.adjustLocForChunkOffset(*hoverInfo->range);
1057  return hoverInfo;
1058 }
1059 
1060 void MLIRTextFile::findDocumentSymbols(
1061  std::vector<lsp::DocumentSymbol> &symbols) {
1062  if (chunks.size() == 1)
1063  return chunks.front()->document.findDocumentSymbols(symbols);
1064 
1065  // If there are multiple chunks in this file, we create top-level symbols for
1066  // each chunk.
1067  for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1068  MLIRTextFileChunk &chunk = *chunks[i];
1069  lsp::Position startPos(chunk.lineOffset);
1070  lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1071  : chunks[i + 1]->lineOffset);
1072  lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1073  lsp::SymbolKind::Namespace,
1074  /*range=*/lsp::Range(startPos, endPos),
1075  /*selectionRange=*/lsp::Range(startPos));
1076  chunk.document.findDocumentSymbols(symbol.children);
1077 
1078  // Fixup the locations of document symbols within this chunk.
1079  if (i != 0) {
1081  for (lsp::DocumentSymbol &childSymbol : symbol.children)
1082  symbolsToFix.push_back(&childSymbol);
1083 
1084  while (!symbolsToFix.empty()) {
1085  lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1086  chunk.adjustLocForChunkOffset(symbol->range);
1087  chunk.adjustLocForChunkOffset(symbol->selectionRange);
1088 
1089  for (lsp::DocumentSymbol &childSymbol : symbol->children)
1090  symbolsToFix.push_back(&childSymbol);
1091  }
1092  }
1093 
1094  // Push the symbol for this chunk.
1095  symbols.emplace_back(std::move(symbol));
1096  }
1097 }
1098 
1099 lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1100  lsp::Position completePos) {
1101  MLIRTextFileChunk &chunk = getChunkFor(completePos);
1102  lsp::CompletionList completionList = chunk.document.getCodeCompletion(
1103  uri, completePos, context.getDialectRegistry());
1104 
1105  // Adjust any completion locations.
1106  for (lsp::CompletionItem &item : completionList.items) {
1107  if (item.textEdit)
1108  chunk.adjustLocForChunkOffset(item.textEdit->range);
1109  for (lsp::TextEdit &edit : item.additionalTextEdits)
1110  chunk.adjustLocForChunkOffset(edit.range);
1111  }
1112  return completionList;
1113 }
1114 
1115 void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
1116  const lsp::Range &pos,
1117  const lsp::CodeActionContext &context,
1118  std::vector<lsp::CodeAction> &actions) {
1119  // Create actions for any diagnostics in this file.
1120  for (auto &diag : context.diagnostics) {
1121  if (diag.source != "mlir")
1122  continue;
1123  lsp::Position diagPos = diag.range.start;
1124  MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1125 
1126  // Add a new code action that inserts a "expected" diagnostic check.
1127  lsp::CodeAction action;
1128  action.title = "Add expected-* diagnostic checks";
1129  action.kind = lsp::CodeAction::kQuickFix.str();
1130 
1131  StringRef severity;
1132  switch (diag.severity) {
1134  severity = "error";
1135  break;
1136  case lsp::DiagnosticSeverity::Warning:
1137  severity = "warning";
1138  break;
1139  default:
1140  continue;
1141  }
1142 
1143  // Get edits for the diagnostic.
1144  std::vector<lsp::TextEdit> edits;
1145  chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1146  diag.message, edits);
1147 
1148  // Walk the related diagnostics, this is how we encode notes.
1149  if (diag.relatedInformation) {
1150  for (auto &noteDiag : *diag.relatedInformation) {
1151  if (noteDiag.location.uri != uri)
1152  continue;
1153  diagPos = noteDiag.location.range.start;
1154  diagPos.line -= chunk.lineOffset;
1155  chunk.document.getCodeActionForDiagnostic(uri, diagPos, "note",
1156  noteDiag.message, edits);
1157  }
1158  }
1159  // Fixup the locations for any edits.
1160  for (lsp::TextEdit &edit : edits)
1161  chunk.adjustLocForChunkOffset(edit.range);
1162 
1163  action.edit.emplace();
1164  action.edit->changes[uri.uri().str()] = std::move(edits);
1165  action.diagnostics = {diag};
1166 
1167  actions.emplace_back(std::move(action));
1168  }
1169 }
1170 
1172 MLIRTextFile::convertToBytecode() {
1173  // Bail out if there is more than one chunk, bytecode wants a single module.
1174  if (chunks.size() != 1) {
1175  return llvm::make_error<lsp::LSPError>(
1176  "unexpected split file, please remove all `// -----`",
1177  lsp::ErrorCode::RequestFailed);
1178  }
1179  return chunks.front()->document.convertToBytecode();
1180 }
1181 
1182 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1183  if (chunks.size() == 1)
1184  return *chunks.front();
1185 
1186  // Search for the first chunk with a greater line offset, the previous chunk
1187  // is the one that contains `pos`.
1188  auto it = llvm::upper_bound(
1189  chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1190  return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1191  });
1192  MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1193  pos.line -= chunk.lineOffset;
1194  return chunk;
1195 }
1196 
1197 //===----------------------------------------------------------------------===//
1198 // MLIRServer::Impl
1199 //===----------------------------------------------------------------------===//
1200 
1203 
1204  /// The registry containing dialects that can be recognized in parsed .mlir
1205  /// files.
1207 
1208  /// The files held by the server, mapped by their URI file name.
1209  llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
1210 };
1211 
1212 //===----------------------------------------------------------------------===//
1213 // MLIRServer
1214 //===----------------------------------------------------------------------===//
1215 
1216 lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
1217  : impl(std::make_unique<Impl>(registry)) {}
1218 lsp::MLIRServer::~MLIRServer() = default;
1219 
1221  const URIForFile &uri, StringRef contents, int64_t version,
1222  std::vector<Diagnostic> &diagnostics) {
1223  impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1224  uri, contents, version, impl->registry, diagnostics);
1225 }
1226 
1228  auto it = impl->files.find(uri.file());
1229  if (it == impl->files.end())
1230  return llvm::None;
1231 
1232  int64_t version = it->second->getVersion();
1233  impl->files.erase(it);
1234  return version;
1235 }
1236 
1238  const Position &defPos,
1239  std::vector<Location> &locations) {
1240  auto fileIt = impl->files.find(uri.file());
1241  if (fileIt != impl->files.end())
1242  fileIt->second->getLocationsOf(uri, defPos, locations);
1243 }
1244 
1246  const Position &pos,
1247  std::vector<Location> &references) {
1248  auto fileIt = impl->files.find(uri.file());
1249  if (fileIt != impl->files.end())
1250  fileIt->second->findReferencesOf(uri, pos, references);
1251 }
1252 
1254  const Position &hoverPos) {
1255  auto fileIt = impl->files.find(uri.file());
1256  if (fileIt != impl->files.end())
1257  return fileIt->second->findHover(uri, hoverPos);
1258  return llvm::None;
1259 }
1260 
1262  const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1263  auto fileIt = impl->files.find(uri.file());
1264  if (fileIt != impl->files.end())
1265  fileIt->second->findDocumentSymbols(symbols);
1266 }
1267 
1270  const Position &completePos) {
1271  auto fileIt = impl->files.find(uri.file());
1272  if (fileIt != impl->files.end())
1273  return fileIt->second->getCodeCompletion(uri, completePos);
1274  return CompletionList();
1275 }
1276 
1278  const CodeActionContext &context,
1279  std::vector<CodeAction> &actions) {
1280  auto fileIt = impl->files.find(uri.file());
1281  if (fileIt != impl->files.end())
1282  fileIt->second->getCodeActions(uri, pos, context, actions);
1283 }
1284 
1287  MLIRContext tempContext(impl->registry);
1288  tempContext.allowUnregisteredDialects();
1289 
1290  // Collect any errors during parsing.
1291  std::string errorMsg;
1292  ScopedDiagnosticHandler diagHandler(
1293  &tempContext,
1294  [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
1295 
1296  // Handling for external resources, which we want to propagate up to the user.
1297  FallbackAsmResourceMap fallbackResourceMap;
1298 
1299  // Setup the parser config.
1300  ParserConfig parserConfig(&tempContext, /*verifyAfterParse=*/true,
1301  &fallbackResourceMap);
1302 
1303  // Try to parse the given source file.
1304  Block parsedBlock;
1305  if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
1306  return llvm::make_error<lsp::LSPError>(
1307  "failed to parse bytecode source file: " + errorMsg,
1309  }
1310 
1311  // TODO: We currently expect a single top-level operation, but this could
1312  // conceptually be relaxed.
1313  if (!llvm::hasSingleElement(parsedBlock)) {
1314  return llvm::make_error<lsp::LSPError>(
1315  "expected bytecode to contain a single top-level operation",
1317  }
1318 
1319  // Print the module to a buffer.
1321  {
1322  // Extract the top-level op so that aliases get printed.
1323  // FIXME: We should be able to enable aliases without having to do this!
1324  OwningOpRef<Operation *> topOp = &parsedBlock.front();
1325  (*topOp)->remove();
1326 
1327  AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
1328  /*locationMap=*/nullptr, &fallbackResourceMap);
1329 
1330  llvm::raw_string_ostream os(result.output);
1331  (*topOp)->print(os, state);
1332  }
1333  return std::move(result);
1334 }
1335 
1338  auto fileIt = impl->files.find(uri.file());
1339  if (fileIt == impl->files.end()) {
1340  return llvm::make_error<lsp::LSPError>(
1341  "language server does not contain an entry for this source file",
1343  }
1344  return fileIt->second->convertToBytecode();
1345 }
static std::string toString(bytecode::Section::ID sectionID)
Stringify the given section ID.
static std::string diag(llvm::Value &value)
static void collectLocationsFromLoc(Location loc, std::vector< lsp::Location > &locations, const lsp::URIForFile &uri)
Collect all of the locations from the given MLIR location that are not contained within the given URI...
Definition: MLIRServer.cpp:80
static Optional< unsigned > getResultNumberFromLoc(SMLoc loc)
Given a location pointing to a result, return the result number it refers to or None if it refers to ...
Definition: MLIRServer.cpp:130
static Optional< lsp::Location > getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc)
Returns a language server location from the given MLIR file location.
Definition: MLIRServer.cpp:27
static unsigned getBlockNumber(Block *block)
Given a block, return its position in its parent region.
Definition: MLIRServer.cpp:165
static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, SMRange *overlappedRange=nullptr)
Returns true if the given location is contained by the definition or one of the uses of the given SMD...
Definition: MLIRServer.cpp:108
static Optional< StringRef > getTextFromRange(SMRange range)
Given a source location range, return the text covered by the given range.
Definition: MLIRServer.cpp:157
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
Definition: MLIRServer.cpp:100
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, Diagnostic &diag, const lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
Definition: MLIRServer.cpp:188
static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc={})
Given a block and source location, print the source name of the block to the given output stream.
Definition: MLIRServer.cpp:171
@ Error
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class provides an abstract interface into the parser for hooking in code completion events.
Definition: CodeComplete.h:24
This class represents state from a parsed MLIR textual format string.
This class provides management for the lifetime of the state used when printing the IR.
Definition: AsmState.h:524
This class represents an argument of a Block.
Definition: Value.h:296
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:308
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:118
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition: Block.h:231
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
SuccessorRange getSuccessors()
Definition: Block.h:253
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:223
Operation & front()
Definition: Block.h:142
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:228
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class contains the configuration used for the bytecode writer.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
A fallback map containing external resources not explicitly handled by another parser/printer.
Definition: AsmState.h:409
WalkResult walk(function_ref< WalkResult(Location)> walkFn)
Walk all of the locations nested under, and including, the current.
Definition: Location.cpp:40
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
U dyn_cast() const
Definition: Location.h:86
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
void allowUnregisteredDialects(bool allow=true)
Enables creating operations in unregistered dialects.
Set of flags used to control the behavior of the various IR print methods (e.g.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:341
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:528
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:151
void print(raw_ostream &os, const OpPrintingFlags &flags=llvm::None)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
void remove()
Remove the operation from its parent block, but don't delete it.
Definition: Operation.cpp:426
result_range getResults()
Definition: Operation.h:332
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition: OwningOpRef.h:28
This class represents a configuration for the MLIR assembly parser.
Definition: AsmState.h:457
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator begin()
Definition: Region.h:55
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Definition: Diagnostics.h:515
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
static void error(const char *fmt, Ts &&...vals)
Definition: Logging.h:42
void addOrUpdateDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add or update the document, with the provided version, at the given URI.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
llvm::Expected< MLIRConvertBytecodeResult > convertFromBytecode(const URIForFile &uri)
Convert the given bytecode file to the textual format.
llvm::Expected< MLIRConvertBytecodeResult > convertToBytecode(const URIForFile &uri)
Convert the given textual file to the bytecode format.
Optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
Optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or None if one couldn't be found.
void getCodeActions(const URIForFile &uri, const Range &pos, const CodeActionContext &context, std::vector< CodeAction > &actions)
Get the set of code actions within the file.
URI in "file" scheme for a file.
Definition: Protocol.h:100
StringRef uri() const
Returns the original uri of the file.
Definition: Protocol.h:116
static llvm::Expected< URIForFile > fromFile(StringRef absoluteFilepath, StringRef scheme="file")
Try to build a URIForFile from the given absolute file path and optional scheme.
Definition: Protocol.cpp:232
StringRef scheme() const
Return the scheme of the uri.
Definition: Protocol.cpp:241
StringRef file() const
Returns the absolute path to the file.
Definition: Protocol.h:113
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
SMRange convertTokenLocToRange(SMLoc loc)
Returns the range of a lexical token given a SMLoc corresponding to the start of an token location.
CompletionItemKind
The kind of a completion entry.
Definition: Protocol.h:739
Include the generated interface declarations.
LogicalResult parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, AsmParserState *asmState=nullptr, AsmParserCodeCompleteContext *codeCompleteContext=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition: Parser.cpp:2647
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition: Parser.cpp:20
void writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config={})
Write the bytecode for the given operation to the provided output stream.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Impl(DialectRegistry &registry)
DialectRegistry & registry
The registry containing dialects that can be recognized in parsed .mlir files.
llvm::StringMap< std::unique_ptr< MLIRTextFile > > files
The files held by the server, mapped by their URI file name.
This class represents the information for a block definition within the input file.
Block * block
The block representing this definition.
SMDefinition definition
The source location for the block, i.e.
This class represents the information for an operation definition within an input file.
Operation * op
The operation representing this definition.
This class represents a definition within the source manager, containing its defining location and l...
SmallVector< SMRange > uses
The source location of all uses of the definition.
SMRange loc
The source location of the definition.
std::vector< Diagnostic > diagnostics
An array of diagnostics known on the client side overlapping the range provided to the textDocument/c...
Definition: Protocol.h:1122
A code action represents a change that can be performed in code, e.g.
Definition: Protocol.h:1180
Optional< std::string > kind
The kind of the code action.
Definition: Protocol.h:1186
Optional< std::vector< Diagnostic > > diagnostics
The diagnostics that this code action resolves.
Definition: Protocol.h:1192
Optional< WorkspaceEdit > edit
The workspace edit this code action performs.
Definition: Protocol.h:1202
std::string title
A short, human-readable, title for this code action.
Definition: Protocol.h:1182
Optional< TextEdit > textEdit
An edit which is applied to a document when selecting this completion.
Definition: Protocol.h:847
std::vector< TextEdit > additionalTextEdits
An optional array of additional text edits that are applied when selecting this completion.
Definition: Protocol.h:852
Represents a collection of completion items to be presented in the editor.
Definition: Protocol.h:868
std::vector< CompletionItem > items
The completion items.
Definition: Protocol.h:874
std::string source
A human-readable string describing the source of this diagnostic, e.g.
Definition: Protocol.h:671
Optional< std::string > category
The diagnostic's category.
Definition: Protocol.h:684
DiagnosticSeverity severity
The diagnostic's severity.
Definition: Protocol.h:667
Optional< std::vector< DiagnosticRelatedInformation > > relatedInformation
An array of related diagnostic information, e.g.
Definition: Protocol.h:678
Range range
The source range where the message applies.
Definition: Protocol.h:663
std::string message
The diagnostic's message.
Definition: Protocol.h:674
Represents programming constructs like variables, classes, interfaces etc.
Definition: Protocol.h:577
Range range
The range enclosing this symbol not including leading/trailing whitespace but everything else like co...
Definition: Protocol.h:598
Range selectionRange
The range that should be selected and revealed when this symbol is being picked, e....
Definition: Protocol.h:602
std::vector< DocumentSymbol > children
Children of this symbol, e.g. properties of a class.
Definition: Protocol.h:605
URIForFile uri
The text document's URI.
Definition: Protocol.h:375
This class represents the result of converting between MLIR's bytecode and textual format.
Definition: Protocol.h:48
std::string output
The resultant output of the conversion.
Definition: Protocol.h:50
int line
Line position in a document (zero-based).
Definition: Protocol.h:274
int character
Character offset on a line in a document (zero-based).
Definition: Protocol.h:277
SMLoc getAsSMLoc(llvm::SourceMgr &mgr) const
Convert this position into a source location in the main file of the given source manager.
Definition: Protocol.h:297
Position end
The range's end position.
Definition: Protocol.h:326
Position start
The range's start position.
Definition: Protocol.h:323
std::string newText
The string to be inserted.
Definition: Protocol.h:722
Range range
The range of the text document to be manipulated.
Definition: Protocol.h:718