MLIR  14.0.0git
MLIRServer.cpp
Go to the documentation of this file.
1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MLIRServer.h"
10 #include "lsp/Logging.h"
11 #include "lsp/Protocol.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/Parser.h"
15 #include "llvm/Support/SourceMgr.h"
16 
17 using namespace mlir;
18 
19 /// Returns a language server position for the given source location.
20 static lsp::Position getPosFromLoc(llvm::SourceMgr &mgr, llvm::SMLoc loc) {
21  std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc);
22  lsp::Position pos;
23  pos.line = lineAndCol.first - 1;
24  pos.character = lineAndCol.second - 1;
25  return pos;
26 }
27 
28 /// Returns a source location from the given language server position.
29 static llvm::SMLoc getPosFromLoc(llvm::SourceMgr &mgr, lsp::Position pos) {
30  return mgr.FindLocForLineAndColumn(mgr.getMainFileID(), pos.line + 1,
31  pos.character);
32 }
33 
34 /// Returns a language server range for the given source range.
35 static lsp::Range getRangeFromLoc(llvm::SourceMgr &mgr, llvm::SMRange range) {
36  return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, range.End)};
37 }
38 
39 /// Returns a language server location from the given source range.
40 static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr,
41  llvm::SMRange range,
42  const lsp::URIForFile &uri) {
43  return lsp::Location{uri, getRangeFromLoc(mgr, range)};
44 }
45 
46 /// Returns a language server location from the given MLIR file location.
47 static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) {
49  lsp::URIForFile::fromFile(loc.getFilename());
50  if (!sourceURI) {
51  lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
52  loc.getFilename(),
53  llvm::toString(sourceURI.takeError()));
54  return llvm::None;
55  }
56 
57  lsp::Position position;
58  position.line = loc.getLine() - 1;
59  position.character = loc.getColumn();
60  return lsp::Location{*sourceURI, lsp::Range(position)};
61 }
62 
63 /// Returns a language server location from the given MLIR location, or None if
64 /// one couldn't be created. `uri` is an optional additional filter that, when
65 /// present, is used to filter sub locations that do not share the same uri.
///
/// The nested locations of `loc` are walked and the first `FileLineColLoc`
/// that yields a URI (and passes the `uri` filter) is used; the walk is
/// interrupted at that point. The resulting range starts at that location and
/// is widened, when possible, to cover the identifier starting there.
67 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
68  const lsp::URIForFile *uri = nullptr) {
69  Optional<lsp::Location> location;
70  loc->walk([&](Location nestedLoc) {
  // Only file/line/column locations can be mapped to an LSP location.
71  FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
72  if (!fileLoc)
73  return WalkResult::advance();
74 
75  Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
76  if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
77  location = *sourceLoc;
  // NOTE(review): this `loc` shadows the `Location loc` parameter. The
  // FileLineColLoc line/column are passed through unchanged, i.e. used as
  // 1-based SourceMgr coordinates into the main buffer.
78  llvm::SMLoc loc = sourceMgr.FindLocForLineAndColumn(
79  sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
80 
81  // Use range of potential identifier starting at location, else length 1
82  // range.
83  location->range.end.character += 1;
84  if (Optional<llvm::SMRange> range =
86  auto lineCol = sourceMgr.getLineAndColumn(range->End);
  // `lineCol.second - 1` turns SourceMgr's 1-based column into a 0-based
  // LSP character, matching the conversion in `getPosFromLoc`.
87  location->range.end.character =
88  std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
89  }
90  return WalkResult::interrupt();
91  }
92  return WalkResult::advance();
93  });
  // None if no nested file location matched.
94  return location;
95 }
96 
97 /// Collect all of the locations from the given MLIR location that are not
98 /// contained within the given URI.
///
/// Only nested `FileLineColLoc`s are considered. Each distinct one is processed
/// once, and those resolving to a URI other than `uri` are appended to
/// `locations`; existing entries of `locations` are kept.
100  std::vector<lsp::Location> &locations,
101  const lsp::URIForFile &uri) {
  // Tracks already-processed nested locations so duplicates are skipped.
102  SetVector<Location> visitedLocs;
103  loc->walk([&](Location nestedLoc) {
104  FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
105  if (!fileLoc || !visitedLocs.insert(nestedLoc))
106  return WalkResult::advance();
107 
108  Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
109  if (sourceLoc && sourceLoc->uri != uri)
110  locations.push_back(*sourceLoc);
  // Keep walking: every external location is collected, not just the first.
111  return WalkResult::advance();
112  });
113 }
114 
115 /// Returns true if the given range contains the given source location. Note
116 /// that this has slightly different behavior than SMRange because it is
117 /// inclusive of the end location.
118 static bool contains(llvm::SMRange range, llvm::SMLoc loc) {
119  return range.Start.getPointer() <= loc.getPointer() &&
120  loc.getPointer() <= range.End.getPointer();
121 }
122 
123 /// Returns true if the given location is contained by the definition or one of
124 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
125 /// the range within `def` that the provided `loc` overlapped with.
126 static bool isDefOrUse(const AsmParserState::SMDefinition &def, llvm::SMLoc loc,
127  llvm::SMRange *overlappedRange = nullptr) {
128  // Check the main definition.
129  if (contains(def.loc, loc)) {
130  if (overlappedRange)
131  *overlappedRange = def.loc;
132  return true;
133  }
134 
135  // Check the uses.
136  const auto *useIt = llvm::find_if(def.uses, [&](const llvm::SMRange &range) {
137  return contains(range, loc);
138  });
139  if (useIt != def.uses.end()) {
140  if (overlappedRange)
141  *overlappedRange = *useIt;
142  return true;
143  }
144  return false;
145 }
146 
147 /// Given a location pointing to a result, return the result number it refers
148 /// to or None if it refers to all of the results.
150  // Skip all of the identifier characters.
151  auto isIdentifierChar = [](char c) {
152  return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
153  c == '-';
154  };
155  const char *curPtr = loc.getPointer();
156  while (isIdentifierChar(*curPtr))
157  ++curPtr;
158 
159  // Check to see if this location indexes into the result group, via `#`. If it
160  // doesn't, we can't extract a sub result number.
161  if (*curPtr != '#')
162  return llvm::None;
163 
164  // Compute the sub result number from the remaining portion of the string.
165  const char *numberStart = ++curPtr;
166  while (llvm::isDigit(*curPtr))
167  ++curPtr;
168  StringRef numberStr(numberStart, curPtr - numberStart);
169  unsigned resultNumber = 0;
170  return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>()
171  : resultNumber;
172 }
173 
174 /// Given a source location range, return the text covered by the given range.
175 /// If the range is invalid, returns None.
176 static Optional<StringRef> getTextFromRange(llvm::SMRange range) {
177  if (!range.isValid())
178  return None;
179  const char *startPtr = range.Start.getPointer();
180  return StringRef(startPtr, range.End.getPointer() - startPtr);
181 }
182 
183 /// Given a block, return its position in its parent region.
184 static unsigned getBlockNumber(Block *block) {
185  return std::distance(block->getParent()->begin(), block->getIterator());
186 }
187 
188 /// Given a block and source location, print the source name of the block to the
189 /// given output stream.
///
/// The text of `loc` is printed verbatim when it begins with `^` (a block
/// label); otherwise the block is printed as `<Block #N>`, where N is its index
/// within its parent region.
190 static void printDefBlockName(raw_ostream &os, Block *block,
191  llvm::SMRange loc = {}) {
192  // Try to extract a name from the source location.
194  if (text && text->startswith("^")) {
195  os << *text;
196  return;
197  }
198 
199  // Otherwise, we don't have a name so print the block number.
200  os << "<Block #" << getBlockNumber(block) << ">";
201 }
202 static void printDefBlockName(raw_ostream &os,
203  const AsmParserState::BlockDefinition &def) {
204  printDefBlockName(os, def.block, def.definition.loc);
205 }
206 
207 /// Convert the given MLIR diagnostic to the LSP form.
///
/// `uri` identifies the document being processed. The diagnostic's range comes
/// from the first nested file location inside that document and is left
/// default-constructed if there is none. Notes become related information;
/// a note whose location cannot be resolved is attributed to `uri` with a
/// default range.
208 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
209  Diagnostic &diag,
210  const lsp::URIForFile &uri) {
211  lsp::Diagnostic lspDiag;
212  lspDiag.source = "mlir";
213 
214  // Note: Right now all of the diagnostics are treated as parser issues, but
215  // some are parser and some are verifier.
216  lspDiag.category = "Parse Error";
217 
218  // Try to grab a file location for this diagnostic.
219  // TODO: For simplicity, we just grab the first one. It may be likely that we
220  // will need a more interesting heuristic here.
221  Optional<lsp::Location> lspLocation =
222  getLocationFromLoc(sourceMgr, diag.getLocation(), &uri);
223  if (lspLocation)
224  lspDiag.range = lspLocation->range;
225 
226  // Convert the severity for the diagnostic.
  // Notes never arrive here on their own; they are attached below as related
  // information of their parent diagnostic.
227  switch (diag.getSeverity()) {
229  llvm_unreachable("expected notes to be handled separately");
232  break;
235  break;
238  break;
239  }
240  lspDiag.message = diag.str();
241 
242  // Attach any notes to the main diagnostic as related information.
243  std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
244  for (Diagnostic &note : diag.getNotes()) {
245  lsp::Location noteLoc;
  // Unlike the main diagnostic, notes are not filtered by `uri`, so they may
  // point into other files.
246  if (Optional<lsp::Location> loc =
247  getLocationFromLoc(sourceMgr, note.getLocation()))
248  noteLoc = *loc;
249  else
250  noteLoc.uri = uri;
251  relatedDiags.emplace_back(noteLoc, note.str());
252  }
253  if (!relatedDiags.empty())
254  lspDiag.relatedInformation = std::move(relatedDiags);
255 
256  return lspDiag;
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // MLIRDocument
261 //===----------------------------------------------------------------------===//
262 
263 namespace {
264 /// This class represents all of the information pertaining to a specific MLIR
265 /// document.
/// When the document is one chunk of a larger text file, positions and ranges
/// handled here are relative to the start of that chunk.
266 struct MLIRDocument {
  /// Parses `contents` (named by `uri`) with `context`, appending every
  /// diagnostic emitted during parsing to `diagnostics`.
267  MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
268  StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
  // Documents own their source manager and parsed IR, so copying is disabled.
269  MLIRDocument(const MLIRDocument &) = delete;
270  MLIRDocument &operator=(const MLIRDocument &) = delete;
271 
272  //===--------------------------------------------------------------------===//
273  // Definitions and References
274  //===--------------------------------------------------------------------===//
275 
  /// Appends to `locations` the definition location(s) of the entity at
  /// `defPos`.
276  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
277  std::vector<lsp::Location> &locations);
  /// Appends to `references` the definition and uses of the entity at `pos`.
278  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
279  std::vector<lsp::Location> &references);
280 
281  //===--------------------------------------------------------------------===//
282  // Hover
283  //===--------------------------------------------------------------------===//
284 
  /// Returns hover information for the entity at `hoverPos`, or None when the
  /// position is not on an operation, result, symbol use, block, or block
  /// argument. The `buildHoverFor*` helpers below format each case.
285  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
286  const lsp::Position &hoverPos);
288  buildHoverForOperation(llvm::SMRange hoverRange,
290  lsp::Hover buildHoverForOperationResult(llvm::SMRange hoverRange,
291  Operation *op, unsigned resultStart,
292  unsigned resultEnd,
293  llvm::SMLoc posLoc);
294  lsp::Hover buildHoverForBlock(llvm::SMRange hoverRange,
295  const AsmParserState::BlockDefinition &block);
296  lsp::Hover
297  buildHoverForBlockArgument(llvm::SMRange hoverRange, BlockArgument arg,
298  const AsmParserState::BlockDefinition &block);
299 
300  //===--------------------------------------------------------------------===//
301  // Document Symbols
302  //===--------------------------------------------------------------------===//
303 
  /// Appends symbols for every top-level operation of the parsed IR.
304  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
  /// Appends symbols for `op` and, recursively, the operations nested in it.
305  void findDocumentSymbols(Operation *op,
306  std::vector<lsp::DocumentSymbol> &symbols);
307 
308  //===--------------------------------------------------------------------===//
309  // Fields
310  //===--------------------------------------------------------------------===//
311 
312  /// The high level parser state used to find definitions and references within
313  /// the source file.
  /// Reset to an empty state when parsing fails.
314  AsmParserState asmState;
315 
316  /// The container for the IR parsed from the input file.
  /// Cleared when parsing fails.
317  Block parsedIR;
318 
319  /// The source manager containing the contents of the input file.
320  llvm::SourceMgr sourceMgr;
321 };
322 } // namespace
323 
324 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
325  StringRef contents,
326  std::vector<lsp::Diagnostic> &diagnostics) {
327  ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
328  diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
329  });
330 
331  // Try to parsed the given IR string.
332  auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
333  if (!memBuffer) {
334  lsp::Logger::error("Failed to create memory buffer for file", uri.file());
335  return;
336  }
337 
338  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), llvm::SMLoc());
339  if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr,
340  &asmState))) {
341  // If parsing failed, clear out any of the current state.
342  parsedIR.clear();
343  asmState = AsmParserState();
344  return;
345  }
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // MLIRDocument: Definitions and References
350 //===----------------------------------------------------------------------===//
351 
352 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
353  const lsp::Position &defPos,
354  std::vector<lsp::Location> &locations) {
355  llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, defPos);
356 
357  // Functor used to check if an SM definition contains the position.
358  auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
359  if (!isDefOrUse(def, posLoc))
360  return false;
361  locations.push_back(getLocationFromLoc(sourceMgr, def.loc, uri));
362  return true;
363  };
364 
365  // Check all definitions related to operations.
366  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
367  if (contains(op.loc, posLoc))
368  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
369  for (const auto &result : op.resultGroups)
370  if (containsPosition(result.definition))
371  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
372  for (const auto &symUse : op.symbolUses) {
373  if (contains(symUse, posLoc)) {
374  locations.push_back(getLocationFromLoc(sourceMgr, op.loc, uri));
375  return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
376  }
377  }
378  }
379 
380  // Check all definitions related to blocks.
381  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
382  if (containsPosition(block.definition))
383  return;
384  for (const AsmParserState::SMDefinition &arg : block.arguments)
385  if (containsPosition(arg))
386  return;
387  }
388 }
389 
390 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
391  const lsp::Position &pos,
392  std::vector<lsp::Location> &references) {
393  // Functor used to append all of the definitions/uses of the given SM
394  // definition to the reference list.
395  auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
396  references.push_back(getLocationFromLoc(sourceMgr, def.loc, uri));
397  for (const llvm::SMRange &use : def.uses)
398  references.push_back(getLocationFromLoc(sourceMgr, use, uri));
399  };
400 
401  llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, pos);
402 
403  // Check all definitions related to operations.
404  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
405  if (contains(op.loc, posLoc)) {
406  for (const auto &result : op.resultGroups)
407  appendSMDef(result.definition);
408  for (const auto &symUse : op.symbolUses)
409  if (contains(symUse, posLoc))
410  references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
411  return;
412  }
413  for (const auto &result : op.resultGroups)
414  if (isDefOrUse(result.definition, posLoc))
415  return appendSMDef(result.definition);
416  for (const auto &symUse : op.symbolUses) {
417  if (!contains(symUse, posLoc))
418  continue;
419  for (const auto &symUse : op.symbolUses)
420  references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
421  return;
422  }
423  }
424 
425  // Check all definitions related to blocks.
426  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
427  if (isDefOrUse(block.definition, posLoc))
428  return appendSMDef(block.definition);
429 
430  for (const AsmParserState::SMDefinition &arg : block.arguments)
431  if (isDefOrUse(arg, posLoc))
432  return appendSMDef(arg);
433  }
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // MLIRDocument: Hover
438 //===----------------------------------------------------------------------===//
439 
440 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri,
441  const lsp::Position &hoverPos) {
442  llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, hoverPos);
443  llvm::SMRange hoverRange;
444 
445  // Check for Hovers on operations and results.
446  for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
447  // Check if the position points at this operation.
448  if (contains(op.loc, posLoc))
449  return buildHoverForOperation(op.loc, op);
450 
451  // Check if the position points at the symbol name.
452  for (auto &use : op.symbolUses)
453  if (contains(use, posLoc))
454  return buildHoverForOperation(use, op);
455 
456  // Check if the position points at a result group.
457  for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
458  const auto &result = op.resultGroups[i];
459  if (!isDefOrUse(result.definition, posLoc, &hoverRange))
460  continue;
461 
462  // Get the range of results covered by the over position.
463  unsigned resultStart = result.startIndex;
464  unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
465  : op.resultGroups[i + 1].startIndex;
466  return buildHoverForOperationResult(hoverRange, op.op, resultStart,
467  resultEnd, posLoc);
468  }
469  }
470 
471  // Check to see if the hover is over a block argument.
472  for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
473  if (isDefOrUse(block.definition, posLoc, &hoverRange))
474  return buildHoverForBlock(hoverRange, block);
475 
476  for (const auto &arg : llvm::enumerate(block.arguments)) {
477  if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
478  continue;
479 
480  return buildHoverForBlockArgument(
481  hoverRange, block.block->getArgument(arg.index()), block);
482  }
483  }
484  return llvm::None;
485 }
486 
487 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
488  llvm::SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
489  lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
490  llvm::raw_string_ostream os(hover.contents.value);
491 
492  // Add the operation name to the hover.
493  os << "\"" << op.op->getName() << "\"";
494  if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
495  os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
496  os << "\n\n";
497 
498  os << "Generic Form:\n\n```mlir\n";
499 
500  // Temporary drop the regions of this operation so that they don't get
501  // printed in the output. This helps keeps the size of the output hover
502  // small.
504  for (Region &region : op.op->getRegions()) {
505  regions.emplace_back(std::make_unique<Region>());
506  regions.back()->takeBody(region);
507  }
508 
509  op.op->print(
510  os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
511  os << "\n```\n";
512 
513  // Move the regions back to the current operation.
514  for (Region &region : op.op->getRegions())
515  region.takeBody(*regions.back());
516 
517  return hover;
518 }
519 
520 lsp::Hover MLIRDocument::buildHoverForOperationResult(llvm::SMRange hoverRange,
521  Operation *op,
522  unsigned resultStart,
523  unsigned resultEnd,
524  llvm::SMLoc posLoc) {
525  lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
526  llvm::raw_string_ostream os(hover.contents.value);
527 
528  // Add the parent operation name to the hover.
529  os << "Operation: \"" << op->getName() << "\"\n\n";
530 
531  // Check to see if the location points to a specific result within the
532  // group.
533  if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
534  if ((resultStart + *resultNumber) < resultEnd) {
535  resultStart += *resultNumber;
536  resultEnd = resultStart + 1;
537  }
538  }
539 
540  // Add the range of results and their types to the hover info.
541  if ((resultStart + 1) == resultEnd) {
542  os << "Result #" << resultStart << "\n\n"
543  << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
544  } else {
545  os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
546  << "Types: ";
547  llvm::interleaveComma(
548  op->getResults().slice(resultStart, resultEnd), os,
549  [&](Value result) { os << "`" << result.getType() << "`"; });
550  }
551 
552  return hover;
553 }
554 
556 MLIRDocument::buildHoverForBlock(llvm::SMRange hoverRange,
557  const AsmParserState::BlockDefinition &block) {
  // Builds a hover showing the block's parent operation, its index within its
  // region, and (when non-empty) its predecessors and successors.
558  lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
559  llvm::raw_string_ostream os(hover.contents.value);
560 
561  // Print the given block to the hover output stream.
  // Blocks with parser state use their source label; others fall back to
  // `<Block #N>`.
562  auto printBlockToHover = [&](Block *newBlock) {
563  if (const auto *def = asmState.getBlockDef(newBlock))
564  printDefBlockName(os, *def);
565  else
566  printDefBlockName(os, newBlock);
567  };
568 
569  // Display the parent operation, block number, predecessors, and successors.
570  os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
571  << "Block #" << getBlockNumber(block.block) << "\n\n";
572  if (!block.block->hasNoPredecessors()) {
573  os << "Predecessors: ";
574  llvm::interleaveComma(block.block->getPredecessors(), os,
575  printBlockToHover);
576  os << "\n\n";
577  }
578  if (!block.block->hasNoSuccessors()) {
579  os << "Successors: ";
580  llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
581  os << "\n\n";
582  }
583 
584  return hover;
585 }
586 
587 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
588  llvm::SMRange hoverRange, BlockArgument arg,
589  const AsmParserState::BlockDefinition &block) {
590  lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
591  llvm::raw_string_ostream os(hover.contents.value);
592 
593  // Display the parent operation, block, the argument number, and the type.
594  os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
595  << "Block: ";
596  printDefBlockName(os, block);
597  os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
598  << "Type: `" << arg.getType() << "`\n\n";
599 
600  return hover;
601 }
602 
603 //===----------------------------------------------------------------------===//
604 // MLIRDocument: Document Symbols
605 //===----------------------------------------------------------------------===//
606 
607 void MLIRDocument::findDocumentSymbols(
608  std::vector<lsp::DocumentSymbol> &symbols) {
609  for (Operation &op : parsedIR)
610  findDocumentSymbols(&op, symbols);
611 }
612 
/// Appends the document symbols for `op` to `symbols`, then recurses into the
/// operations nested in its regions.
///
/// An operation that has parser state and either defines a symbol or is a
/// symbol table gets its own entry, and nested symbols become that entry's
/// children. Otherwise nested symbols are added directly to `symbols`.
613 void MLIRDocument::findDocumentSymbols(
614  Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
  // Destination for the symbols of nested operations; redirected to the
  // children of the entry created for `op`, if any.
615  std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
616 
617  // Check for the source information of this operation.
618  if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
619  // If this operation defines a symbol, record it.
620  if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
621  symbols.emplace_back(symbol.getName(),
622  isa<FunctionOpInterface>(op)
625  getRangeFromLoc(sourceMgr, def->scopeLoc),
626  getRangeFromLoc(sourceMgr, def->loc));
627  childSymbols = &symbols.back().children;
628 
629  } else if (op->hasTrait<OpTrait::SymbolTable>()) {
630  // Otherwise, if this is a symbol table push an anonymous document symbol.
631  symbols.emplace_back("<" + op->getName().getStringRef() + ">",
633  getRangeFromLoc(sourceMgr, def->scopeLoc),
634  getRangeFromLoc(sourceMgr, def->loc));
635  childSymbols = &symbols.back().children;
636  }
637  }
638 
639  // Recurse into the regions of this operation.
  // `childSymbols` may point into `symbols.back()`. The recursion only appends
  // to that children vector, never to `symbols` itself, so the pointer stays
  // valid.
640  if (!op->getNumRegions())
641  return;
642  for (Region &region : op->getRegions())
643  for (Operation &childOp : region.getOps())
644  findDocumentSymbols(&childOp, *childSymbols);
645 }
646 
647 //===----------------------------------------------------------------------===//
648 // MLIRTextFileChunk
649 //===----------------------------------------------------------------------===//
650 
651 namespace {
652 /// This class represents a single chunk of an MLIR text file.
653 struct MLIRTextFileChunk {
654  MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
655  const lsp::URIForFile &uri, StringRef contents,
656  std::vector<lsp::Diagnostic> &diagnostics)
657  : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
658 
659  /// Adjust the line number of the given range to anchor at the beginning of
660  /// the file, instead of the beginning of this chunk.
661  void adjustLocForChunkOffset(lsp::Range &range) {
662  adjustLocForChunkOffset(range.start);
663  adjustLocForChunkOffset(range.end);
664  }
665  /// Adjust the line number of the given position to anchor at the beginning of
666  /// the file, instead of the beginning of this chunk.
667  void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
668 
669  /// The line offset of this chunk from the beginning of the file.
670  uint64_t lineOffset;
671  /// The document referred to by this chunk.
672  MLIRDocument document;
673 };
674 } // namespace
675 
676 //===----------------------------------------------------------------------===//
677 // MLIRTextFile
678 //===----------------------------------------------------------------------===//
679 
680 namespace {
681 /// This class represents a text file containing one or more MLIR documents.
682 class MLIRTextFile {
683 public:
684  MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
685  int64_t version, DialectRegistry &registry,
686  std::vector<lsp::Diagnostic> &diagnostics);
687 
688  /// Return the current version of this text file.
689  int64_t getVersion() const { return version; }
690 
691  //===--------------------------------------------------------------------===//
692  // LSP Queries
693  //===--------------------------------------------------------------------===//
694 
695  void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
696  std::vector<lsp::Location> &locations);
697  void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
698  std::vector<lsp::Location> &references);
699  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
700  lsp::Position hoverPos);
701  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
702 
703 private:
704  /// Find the MLIR document that contains the given position, and update the
705  /// position to be anchored at the start of the found chunk instead of the
706  /// beginning of the file.
707  MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
708 
709  /// The context used to hold the state contained by the parsed document.
710  MLIRContext context;
711 
712  /// The full string contents of the file.
713  std::string contents;
714 
715  /// The version of this file.
716  int64_t version;
717 
718  /// The number of lines in the file.
719  int64_t totalNumLines;
720 
721  /// The chunks of this file. The order of these chunks is the order in which
722  /// they appear in the text file.
723  std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
724 };
725 } // namespace
726 
727 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
728  int64_t version, DialectRegistry &registry,
729  std::vector<lsp::Diagnostic> &diagnostics)
730  : context(registry, MLIRContext::Threading::DISABLED),
731  contents(fileContents.str()), version(version), totalNumLines(0) {
732  context.allowUnregisteredDialects();
733 
734  // Split the file into separate MLIR documents.
735  // TODO: Find a way to share the split file marker with other tools. We don't
736  // want to use `splitAndProcessBuffer` here, but we do want to make sure this
737  // marker doesn't go out of sync.
738  SmallVector<StringRef, 8> subContents;
739  StringRef(contents).split(subContents, "// -----");
740  chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
741  context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
742 
743  uint64_t lineOffset = subContents.front().count('\n');
744  for (StringRef docContents : llvm::drop_begin(subContents)) {
745  unsigned currentNumDiags = diagnostics.size();
746  auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
747  docContents, diagnostics);
748  lineOffset += docContents.count('\n');
749 
750  // Adjust locations used in diagnostics to account for the offset from the
751  // beginning of the file.
752  for (lsp::Diagnostic &diag :
753  llvm::drop_begin(diagnostics, currentNumDiags)) {
754  chunk->adjustLocForChunkOffset(diag.range);
755 
756  if (!diag.relatedInformation)
757  continue;
758  for (auto &it : *diag.relatedInformation)
759  if (it.location.uri == uri)
760  chunk->adjustLocForChunkOffset(it.location.range);
761  }
762  chunks.emplace_back(std::move(chunk));
763  }
764  totalNumLines = lineOffset;
765 }
766 
767 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
768  lsp::Position defPos,
769  std::vector<lsp::Location> &locations) {
770  MLIRTextFileChunk &chunk = getChunkFor(defPos);
771  chunk.document.getLocationsOf(uri, defPos, locations);
772 
773  // Adjust any locations within this file for the offset of this chunk.
774  if (chunk.lineOffset == 0)
775  return;
776  for (lsp::Location &loc : locations)
777  if (loc.uri == uri)
778  chunk.adjustLocForChunkOffset(loc.range);
779 }
780 
781 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
782  lsp::Position pos,
783  std::vector<lsp::Location> &references) {
784  MLIRTextFileChunk &chunk = getChunkFor(pos);
785  chunk.document.findReferencesOf(uri, pos, references);
786 
787  // Adjust any locations within this file for the offset of this chunk.
788  if (chunk.lineOffset == 0)
789  return;
790  for (lsp::Location &loc : references)
791  if (loc.uri == uri)
792  chunk.adjustLocForChunkOffset(loc.range);
793 }
794 
795 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
796  lsp::Position hoverPos) {
797  MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
798  Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
799 
800  // Adjust any locations within this file for the offset of this chunk.
801  if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
802  chunk.adjustLocForChunkOffset(*hoverInfo->range);
803  return hoverInfo;
804 }
805 
806 void MLIRTextFile::findDocumentSymbols(
807  std::vector<lsp::DocumentSymbol> &symbols) {
  // A file without split markers has a single chunk at line 0, whose symbols
  // are already file-relative and are reported directly.
808  if (chunks.size() == 1)
809  return chunks.front()->document.findDocumentSymbols(symbols);
810 
811  // If there are multiple chunks in this file, we create top-level symbols for
812  // each chunk.
813  for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
814  MLIRTextFileChunk &chunk = *chunks[i];
  // The chunk symbol spans from the chunk's first line to the first line of
  // the next chunk, or to `totalNumLines - 1` for the final chunk.
815  lsp::Position startPos(chunk.lineOffset);
816  lsp::Position endPos((i == e - 1) ? totalNumLines - 1
817  : chunks[i + 1]->lineOffset);
818  lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
820  /*range=*/lsp::Range(startPos, endPos),
821  /*selectionRange=*/lsp::Range(startPos));
822  chunk.document.findDocumentSymbols(symbol.children);
823 
824  // Fixup the locations of document symbols within this chunk.
  // The document reports chunk-relative ranges. Walk the whole symbol tree
  // with an explicit worklist and shift every range. The first chunk starts at
  // line 0 and needs no adjustment.
825  if (i != 0) {
827  for (lsp::DocumentSymbol &childSymbol : symbol.children)
828  symbolsToFix.push_back(&childSymbol);
829 
830  while (!symbolsToFix.empty()) {
  // NOTE(review): this `symbol` shadows the chunk-level `symbol` above.
831  lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
832  chunk.adjustLocForChunkOffset(symbol->range);
833  chunk.adjustLocForChunkOffset(symbol->selectionRange);
834 
835  for (lsp::DocumentSymbol &childSymbol : symbol->children)
836  symbolsToFix.push_back(&childSymbol);
837  }
838  }
839 
840  // Push the symbol for this chunk.
841  symbols.emplace_back(std::move(symbol));
842  }
843 }
844 
845 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
846  if (chunks.size() == 1)
847  return *chunks.front();
848 
849  // Search for the first chunk with a greater line offset, the previous chunk
850  // is the one that contains `pos`.
851  auto it = llvm::upper_bound(
852  chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
853  return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
854  });
855  MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
856  pos.line -= chunk.lineOffset;
857  return chunk;
858 }
859 
860 //===----------------------------------------------------------------------===//
861 // MLIRServer::Impl
862 //===----------------------------------------------------------------------===//
863 
  /// Creates the server state. Dialects available to parsed documents come
  /// from `registry`.
865  Impl(DialectRegistry &registry) : registry(registry) {}
866 
867  /// The registry containing dialects that can be recognized in parsed .mlir
868  /// files.
870 
871  /// The files held by the server, mapped by their URI file name.
  /// Entries are created or replaced when a document is added or updated, and
  /// erased when it is removed.
872  llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
873 };
874 
875 //===----------------------------------------------------------------------===//
876 // MLIRServer
877 //===----------------------------------------------------------------------===//
878 
880  : impl(std::make_unique<Impl>(registry)) {}
// The destructor is defined here, where `Impl` is a complete type, presumably
// so that the owning `impl` member can destroy it.
881 lsp::MLIRServer::~MLIRServer() = default;
882 
884  const URIForFile &uri, StringRef contents, int64_t version,
885  std::vector<Diagnostic> &diagnostics) {
  // Reparse the whole document, replacing any previous state for this file.
  // Diagnostics produced while parsing are appended to `diagnostics`.
886  impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
887  uri, contents, version, impl->registry, diagnostics);
888 }
889 
  // Returns the version of the removed file, or None if no file is open
  // under this URI.
891  auto it = impl->files.find(uri.file());
892  if (it == impl->files.end())
893  return llvm::None;
894 
  // Read the version before erasing, since erasing destroys the file.
895  int64_t version = it->second->getVersion();
896  impl->files.erase(it);
897  return version;
898 }
899 
901  const Position &defPos,
902  std::vector<Location> &locations) {
  // Forward to the open file; does nothing if the file is not open.
903  auto fileIt = impl->files.find(uri.file());
904  if (fileIt != impl->files.end())
905  fileIt->second->getLocationsOf(uri, defPos, locations);
906 }
907 
909  const Position &pos,
910  std::vector<Location> &references) {
  // Forward to the open file; does nothing if the file is not open.
911  auto fileIt = impl->files.find(uri.file());
912  if (fileIt != impl->files.end())
913  fileIt->second->findReferencesOf(uri, pos, references);
914 }
915 
917  const Position &hoverPos) {
  // Forward to the open file; returns None if the file is not open.
918  auto fileIt = impl->files.find(uri.file());
919  if (fileIt != impl->files.end())
920  return fileIt->second->findHover(uri, hoverPos);
921  return llvm::None;
922 }
923 
925  const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
  // Forward to the open file; does nothing if the file is not open.
926  auto fileIt = impl->files.find(uri.file());
927  if (fileIt != impl->files.end())
928  fileIt->second->findDocumentSymbols(symbols);
929 }
StringRef file() const
Returns the absolute path to the file.
Definition: Protocol.h:107
Include the generated interface declarations.
Documents should not be synced at all.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
static std::string diag(llvm::Value &v)
Optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
Definition: MLIRServer.cpp:890
This class represents a definition within the source manager, containing its defining location and l...
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
Definition: MLIRServer.cpp:900
This class represents state from a parsed MLIR textual format string.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:423
static bool isDefOrUse(const AsmParserState::SMDefinition &def, llvm::SMLoc loc, llvm::SMRange *overlappedRange=nullptr)
Returns true if the given location is contained by the definition or one of the uses of the given SMD...
Definition: MLIRServer.cpp:126
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, llvm::SMRange range, const lsp::URIForFile &uri)
Returns a language server location from the given source range.
Definition: MLIRServer.cpp:40
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:420
llvm::SMRange loc
The source location of the definition.
Block represents an ordered list of Operations.
Definition: Block.h:29
static unsigned getBlockNumber(Block *block)
Given a block, return its position in its parent region.
Definition: MLIRServer.cpp:184
static void error(const char *fmt, Ts &&... vals)
Definition: Logging.h:39
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Definition: Diagnostics.h:489
static lsp::Position getPosFromLoc(llvm::SourceMgr &mgr, llvm::SMLoc loc)
Returns a language server position for the given source location.
Definition: MLIRServer.cpp:20
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
Location getLocation() const
Returns the source location for this diagnostic.
Definition: Diagnostics.h:170
DiagnosticSeverity severity
The diagnostic's severity.
Definition: Protocol.h:601
Optional< std::string > category
The diagnostic's category.
Definition: Protocol.h:618
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:225
URI in "file" scheme for a file.
Definition: Protocol.h:96
Range range
The range enclosing this symbol not including leading/trailing whitespace but everything else like co...
Definition: Protocol.h:535
DiagnosticSeverity getSeverity() const
Returns the severity of this diagnostic.
Definition: Diagnostics.h:167
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:310
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
void allowUnregisteredDialects(bool allow=true)
Enables creating operations in unregistered dialects.
static void collectLocationsFromLoc(Location loc, std::vector< lsp::Location > &locations, const lsp::URIForFile &uri)
Collect all of the locations from the given MLIR location that are not contained within the given URI...
Definition: MLIRServer.cpp:99
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
static lsp::Range getRangeFromLoc(llvm::SourceMgr &mgr, llvm::SMRange range)
Returns a language server range for the given source range.
Definition: MLIRServer.cpp:35
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, Diagnostic &diag, const lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
Definition: MLIRServer.cpp:208
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, MLIRContext *context, LocationAttr *sourceFileLoc=nullptr, AsmParserState *asmState=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition: Parser.cpp:2216
MLIRServer(DialectRegistry &registry)
Construct a new server with the given dialect registry.
Definition: MLIRServer.cpp:879
static llvm::SMRange convertIdLocToRange(llvm::SMLoc loc)
Returns (heuristically) the range of an identifier given a SMLoc corresponding to the start of an ide...
static Optional< unsigned > getResultNumberFromLoc(llvm::SMLoc loc)
Given a location pointing to a result, return the result number it refers to or None if it refers to ...
Definition: MLIRServer.cpp:149
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::string source
A human-readable string describing the source of this diagnostic, e.g.
Definition: Protocol.h:605
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine...
Definition: Diagnostics.h:157
iterator begin()
Definition: Region.h:55
SmallVector< llvm::SMRange > uses
The source location of all uses of the definition.
int character
Character offset on a line in a document (zero-based).
Definition: Protocol.h:252
Block * block
The block representing this definition.
Optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or None if one couldn't be found...
Definition: MLIRServer.cpp:916
static bool contains(llvm::SMRange range, llvm::SMLoc loc)
Returns true if the given range contains the given source location.
Definition: MLIRServer.cpp:118
int line
Line position in a document (zero-based).
Definition: Protocol.h:249
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:470
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:338
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:276
static WalkResult advance()
Definition: Visitors.h:51
Impl(DialectRegistry &registry)
Definition: MLIRServer.cpp:865
auto getType() const
Position end
The range's end position.
Definition: Protocol.h:290
static WalkResult interrupt()
Definition: Visitors.h:50
This class represents an argument of a Block.
Definition: Value.h:298
void addOrUpdateDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add or update the document, with the provided version, at the given URI.
Definition: MLIRServer.cpp:883
MarkupContent contents
The hover's content.
Definition: Protocol.h:463
This class represents the information for an operation definition within an input file...
std::vector< DocumentSymbol > children
Children of this symbol, e.g. properties of a class.
Definition: Protocol.h:542
std::string message
The diagnostic's message.
Definition: Protocol.h:608
void print(raw_ostream &os, const OpPrintingFlags &flags=llvm::None)
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
Definition: MLIRServer.cpp:924
This class represents the information for a block definition within the input file.
DialectRegistry & registry
The registry containing dialects that can be recognized in parsed .mlir files.
Definition: MLIRServer.cpp:869
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
Position start
The range's start position.
Definition: Protocol.h:287
llvm::StringMap< std::unique_ptr< MLIRTextFile > > files
The files held by the server, mapped by their URI file name.
Definition: MLIRServer.cpp:872
static void printDefBlockName(raw_ostream &os, Block *block, llvm::SMRange loc={})
Given a block and source location, print the source name of the block to the given output stream...
Definition: MLIRServer.cpp:190
Optional< std::vector< DiagnosticRelatedInformation > > relatedInformation
An array of related diagnostic information, e.g.
Definition: Protocol.h:612
URIForFile uri
The text document's URI.
Definition: Protocol.h:320
U dyn_cast() const
Definition: Location.h:66
Set of flags used to control the behavior of the various IR print methods (e.g.
iterator_range< note_iterator > getNotes()
Returns the notes held by this diagnostic.
Definition: Diagnostics.h:258
Type getType() const
Return the type of this value.
Definition: Value.h:117
Range range
The source range where the message applies.
Definition: Protocol.h:597
SMDefinition definition
The source location for the block, i.e.
Range selectionRange
The range that should be selected and revealed when this symbol is being picked, e.g the name of a function.
Definition: Protocol.h:539
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Definition: Dialect.h:282
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition: Block.h:233
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition: Block.h:230
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
StringRef toString(IteratorType t)
Operation * op
The operation representing this definition.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
Definition: MLIRServer.cpp:908
static llvm::Expected< URIForFile > fromFile(StringRef absoluteFilepath)
Try to build a URIForFile from the given absolute file path.
Definition: Protocol.cpp:219
SuccessorRange getSuccessors()
Definition: Block.h:255
Represents programming constructs like variables, classes, interfaces etc.
Definition: Protocol.h:514
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:57
result_range getResults()
Definition: Operation.h:284
WalkResult walk(function_ref< WalkResult(Location)> walkFn)
Walk all of the locations nested under, and including, the current.
Definition: Location.cpp:40
std::string str() const
Converts the diagnostic to a string.
static Optional< StringRef > getTextFromRange(llvm::SMRange range)
Given a source location range, return the text covered by the given range.
Definition: MLIRServer.cpp:176
static Value max(ImplicitLocOpBuilder &builder, Value a, Value b)