MLIR 22.0.0git
MLIRServer.cpp
Go to the documentation of this file.
1//===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "MLIRServer.h"
10#include "Protocol.h"
15#include "mlir/IR/Operation.h"
17#include "mlir/Parser/Parser.h"
20#include "llvm/ADT/StringExtras.h"
21#include "llvm/Support/Base64.h"
22#include "llvm/Support/LSP/Logging.h"
23#include "llvm/Support/SourceMgr.h"
24#include <optional>
25
26using namespace mlir;
27
28/// Returns the range of a lexical token given a SMLoc corresponding to the
29/// start of an token location. The range is computed heuristically, and
30/// supports identifier-like tokens, strings, etc.
31static SMRange convertTokenLocToRange(SMLoc loc) {
32 return lsp::convertTokenLocToRange(loc, "$-.");
33}
34
35/// Returns a language server location from the given MLIR file location.
36/// `uriScheme` is the scheme to use when building new uris.
37static std::optional<lsp::Location> getLocationFromLoc(StringRef uriScheme,
38 FileLineColLoc loc) {
40 lsp::URIForFile::fromFile(loc.getFilename(), uriScheme);
41 if (!sourceURI) {
42 llvm::lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
43 loc.getFilename(),
44 llvm::toString(sourceURI.takeError()));
45 return std::nullopt;
46 }
47
48 lsp::Position position;
49 position.line = loc.getLine() - 1;
50 position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
51 return lsp::Location{*sourceURI, lsp::Range(position)};
52}
53
54/// Returns a language server location from the given MLIR location, or
55/// std::nullopt if one couldn't be created. `uriScheme` is the scheme to use
56/// when building new uris. `uri` is an optional additional filter that, when
57/// present, is used to filter sub locations that do not share the same uri.
58static std::optional<lsp::Location>
59getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
60 StringRef uriScheme, const lsp::URIForFile *uri = nullptr) {
61 std::optional<lsp::Location> location;
62 loc->walk([&](Location nestedLoc) {
63 FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
64 if (!fileLoc)
65 return WalkResult::advance();
66
67 std::optional<lsp::Location> sourceLoc =
68 getLocationFromLoc(uriScheme, fileLoc);
69 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
70 location = *sourceLoc;
71 SMLoc loc = sourceMgr.FindLocForLineAndColumn(
72 sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
73
74 // Use range of potential identifier starting at location, else length 1
75 // range.
76 location->range.end.character += 1;
77 if (std::optional<SMRange> range = convertTokenLocToRange(loc)) {
78 auto lineCol = sourceMgr.getLineAndColumn(range->End);
79 location->range.end.character =
80 std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
81 }
82 return WalkResult::interrupt();
83 }
84 return WalkResult::advance();
85 });
86 return location;
87}
88
89/// Collect all of the locations from the given MLIR location that are not
90/// contained within the given URI.
92 std::vector<lsp::Location> &locations,
93 const lsp::URIForFile &uri) {
94 SetVector<Location> visitedLocs;
95 loc->walk([&](Location nestedLoc) {
96 FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
97 if (!fileLoc || !visitedLocs.insert(nestedLoc))
98 return WalkResult::advance();
99
100 std::optional<lsp::Location> sourceLoc =
101 getLocationFromLoc(uri.scheme(), fileLoc);
102 if (sourceLoc && sourceLoc->uri != uri)
103 locations.push_back(*sourceLoc);
104 return WalkResult::advance();
105 });
106}
107
108/// Returns true if the given range contains the given source location. Note
109/// that this has slightly different behavior than SMRange because it is
110/// inclusive of the end location.
111static bool contains(SMRange range, SMLoc loc) {
112 return range.Start.getPointer() <= loc.getPointer() &&
113 loc.getPointer() <= range.End.getPointer();
114}
115
116/// Returns true if the given location is contained by the definition or one of
117/// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
118/// the range within `def` that the provided `loc` overlapped with.
119static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
120 SMRange *overlappedRange = nullptr) {
121 // Check the main definition.
122 if (contains(def.loc, loc)) {
123 if (overlappedRange)
124 *overlappedRange = def.loc;
125 return true;
126 }
127
128 // Check the uses.
129 const auto *useIt = llvm::find_if(
130 def.uses, [&](const SMRange &range) { return contains(range, loc); });
131 if (useIt != def.uses.end()) {
132 if (overlappedRange)
133 *overlappedRange = *useIt;
134 return true;
135 }
136 return false;
137}
138
139/// Given a location pointing to a result, return the result number it refers
140/// to or std::nullopt if it refers to all of the results.
141static std::optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
142 // Skip all of the identifier characters.
143 auto isIdentifierChar = [](char c) {
144 return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
145 c == '-';
146 };
147 const char *curPtr = loc.getPointer();
148 while (isIdentifierChar(*curPtr))
149 ++curPtr;
150
151 // Check to see if this location indexes into the result group, via `#`. If it
152 // doesn't, we can't extract a sub result number.
153 if (*curPtr != '#')
154 return std::nullopt;
155
156 // Compute the sub result number from the remaining portion of the string.
157 const char *numberStart = ++curPtr;
158 while (llvm::isDigit(*curPtr))
159 ++curPtr;
160 StringRef numberStr(numberStart, curPtr - numberStart);
161 unsigned resultNumber = 0;
162 return numberStr.consumeInteger(10, resultNumber) ? std::optional<unsigned>()
163 : resultNumber;
164}
165
166/// Given a source location range, return the text covered by the given range.
167/// If the range is invalid, returns std::nullopt.
168static std::optional<StringRef> getTextFromRange(SMRange range) {
169 if (!range.isValid())
170 return std::nullopt;
171 const char *startPtr = range.Start.getPointer();
172 return StringRef(startPtr, range.End.getPointer() - startPtr);
173}
174
175/// Given a block, return its position in its parent region.
176static unsigned getBlockNumber(Block *block) {
177 return std::distance(block->getParent()->begin(), block->getIterator());
178}
179
180/// Given a block and source location, print the source name of the block to the
181/// given output stream.
182static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
183 // Try to extract a name from the source location.
184 std::optional<StringRef> text = getTextFromRange(loc);
185 if (text && text->starts_with("^")) {
186 os << *text;
187 return;
188 }
189
190 // Otherwise, we don't have a name so print the block number.
191 os << "<Block #" << getBlockNumber(block) << ">";
192}
196}
197
198/// Convert the given MLIR diagnostic to the LSP form.
199static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
201 const lsp::URIForFile &uri) {
202 lsp::Diagnostic lspDiag;
203 lspDiag.source = "mlir";
204
205 // Note: Right now all of the diagnostics are treated as parser issues, but
206 // some are parser and some are verifier.
207 lspDiag.category = "Parse Error";
208
209 // Try to grab a file location for this diagnostic.
210 // TODO: For simplicity, we just grab the first one. It may be likely that we
211 // will need a more interesting heuristic here.'
212 StringRef uriScheme = uri.scheme();
213 std::optional<lsp::Location> lspLocation =
214 getLocationFromLoc(sourceMgr, diag.getLocation(), uriScheme, &uri);
215 if (lspLocation)
216 lspDiag.range = lspLocation->range;
217
218 // Convert the severity for the diagnostic.
219 switch (diag.getSeverity()) {
221 llvm_unreachable("expected notes to be handled separately");
223 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
224 break;
226 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
227 break;
229 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
230 break;
231 }
232 lspDiag.message = diag.str();
233
234 // Attach any notes to the main diagnostic as related information.
235 std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
236 for (Diagnostic &note : diag.getNotes()) {
237 lsp::Location noteLoc;
238 if (std::optional<lsp::Location> loc =
239 getLocationFromLoc(sourceMgr, note.getLocation(), uriScheme))
240 noteLoc = *loc;
241 else
242 noteLoc.uri = uri;
243 relatedDiags.emplace_back(noteLoc, note.str());
244 }
245 if (!relatedDiags.empty())
246 lspDiag.relatedInformation = std::move(relatedDiags);
247
248 return lspDiag;
249}
250
251//===----------------------------------------------------------------------===//
252// MLIRDocument
253//===----------------------------------------------------------------------===//
254
namespace {
/// This class represents all of the information pertaining to a specific MLIR
/// document: the source buffer, the IR parsed from it, and the parser state
/// that maps IR entities back to source ranges. Instances are non-copyable.
struct MLIRDocument {
  /// Parse `contents` as the document identified by `uri`. Diagnostics emitted
  /// while parsing are converted to LSP form and appended to `diagnostics`.
  MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
               StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
  MLIRDocument(const MLIRDocument &) = delete;
  MLIRDocument &operator=(const MLIRDocument &) = delete;

  //===--------------------------------------------------------------------===//
  // Definitions and References
  //===--------------------------------------------------------------------===//

  /// Append to `locations` the definition location(s) of the entity found at
  /// `defPos`.
  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
                      std::vector<lsp::Location> &locations);
  /// Append to `references` the definition and uses of the entity found at
  /// `pos`.
  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
                        std::vector<lsp::Location> &references);

  //===--------------------------------------------------------------------===//
  // Hover
  //===--------------------------------------------------------------------===//

  /// Build hover information for the entity at `hoverPos`, if any.
  std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                      const lsp::Position &hoverPos);
  /// Helpers that render the hover text for each kind of entity; `hoverRange`
  /// is the source range the hover is attached to.
  std::optional<lsp::Hover>
  buildHoverForOperation(SMRange hoverRange,
                         const AsmParserState::OperationDefinition &op);
  lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
                                          unsigned resultStart,
                                          unsigned resultEnd, SMLoc posLoc);
  lsp::Hover buildHoverForBlock(SMRange hoverRange,
                                const AsmParserState::BlockDefinition &block);
  lsp::Hover
  buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
                             const AsmParserState::BlockDefinition &block);

  lsp::Hover buildHoverForAttributeAlias(
      SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr);
  lsp::Hover
  buildHoverForTypeAlias(SMRange hoverRange,
                         const AsmParserState::TypeAliasDefinition &type);

  //===--------------------------------------------------------------------===//
  // Document Symbols
  //===--------------------------------------------------------------------===//

  /// Collect the document symbols of all top-level operations.
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
  /// Collect the document symbols of `op` and, recursively, its nested ops.
  void findDocumentSymbols(Operation *op,
                           std::vector<lsp::DocumentSymbol> &symbols);

  //===--------------------------------------------------------------------===//
  // Code Completion
  //===--------------------------------------------------------------------===//

  /// Compute completion candidates at `completePos` by re-parsing the buffer
  /// with a code completion context attached.
  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
                                        const lsp::Position &completePos,
                                        const DialectRegistry &registry);

  //===--------------------------------------------------------------------===//
  // Code Action
  //===--------------------------------------------------------------------===//

  /// Append to `edits` the text edits of a quick fix for the diagnostic with
  /// the given position, severity and message.
  void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
                                  lsp::Position &pos, StringRef severity,
                                  StringRef message,
                                  std::vector<llvm::lsp::TextEdit> &edits);

  //===--------------------------------------------------------------------===//
  // Bytecode
  //===--------------------------------------------------------------------===//

  /// Presumably serializes the parsed document to MLIR bytecode (definition
  /// not visible here — confirm).
  llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//
  // NOTE(review): members are destroyed in reverse declaration order, and
  // `asmState` holds SMRanges that are compared against pointers obtained from
  // `sourceMgr`; confirm lifetimes before reordering these fields.

  /// The high level parser state used to find definitions and references within
  /// the source file.
  AsmParserState asmState;

  /// The container for the IR parsed from the input file.
  Block parsedIR;

  /// A collection of external resources, which we want to propagate up to the
  /// user.
  FallbackAsmResourceMap fallbackResourceMap;

  /// The source manager containing the contents of the input file.
  llvm::SourceMgr sourceMgr;
};
} // namespace
347
348MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
349 StringRef contents,
350 std::vector<lsp::Diagnostic> &diagnostics) {
351 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
352 diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
353 });
354
355 // Try to parsed the given IR string.
356 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
357 if (!memBuffer) {
358 llvm::lsp::Logger::error("Failed to create memory buffer for file",
359 uri.file());
360 return;
361 }
362
363 ParserConfig config(&context, /*verifyAfterParse=*/true,
364 &fallbackResourceMap);
365 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
366 if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
367 // If parsing failed, clear out any of the current state.
368 parsedIR.clear();
369 asmState = AsmParserState();
370 fallbackResourceMap = FallbackAsmResourceMap();
371 return;
372 }
373}
374
375//===----------------------------------------------------------------------===//
376// MLIRDocument: Definitions and References
377//===----------------------------------------------------------------------===//
378
379void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
380 const lsp::Position &defPos,
381 std::vector<lsp::Location> &locations) {
382 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
383
384 // Functor used to check if an SM definition contains the position.
385 auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
386 if (!isDefOrUse(def, posLoc))
387 return false;
388 locations.emplace_back(uri, sourceMgr, def.loc);
389 return true;
390 };
391
392 // Check all definitions related to operations.
393 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
394 if (contains(op.loc, posLoc))
395 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
396 for (const auto &result : op.resultGroups)
397 if (containsPosition(result.definition))
398 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
399 for (const auto &symUse : op.symbolUses) {
400 if (contains(symUse, posLoc)) {
401 locations.emplace_back(uri, sourceMgr, op.loc);
402 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
403 }
404 }
405 }
406
407 // Check all definitions related to blocks.
408 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
409 if (containsPosition(block.definition))
410 return;
411 for (const AsmParserState::SMDefinition &arg : block.arguments)
412 if (containsPosition(arg))
413 return;
414 }
415
416 // Check all alias definitions.
417 for (const AsmParserState::AttributeAliasDefinition &attr :
418 asmState.getAttributeAliasDefs()) {
419 if (containsPosition(attr.definition))
420 return;
421 }
422 for (const AsmParserState::TypeAliasDefinition &type :
423 asmState.getTypeAliasDefs()) {
424 if (containsPosition(type.definition))
425 return;
426 }
427}
428
429void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
430 const lsp::Position &pos,
431 std::vector<lsp::Location> &references) {
432 // Functor used to append all of the definitions/uses of the given SM
433 // definition to the reference list.
434 auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
435 references.emplace_back(uri, sourceMgr, def.loc);
436 for (const SMRange &use : def.uses)
437 references.emplace_back(uri, sourceMgr, use);
438 };
439
440 SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
441
442 // Check all definitions related to operations.
443 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
444 if (contains(op.loc, posLoc)) {
445 for (const auto &result : op.resultGroups)
446 appendSMDef(result.definition);
447 for (const auto &symUse : op.symbolUses)
448 if (contains(symUse, posLoc))
449 references.emplace_back(uri, sourceMgr, symUse);
450 return;
451 }
452 for (const auto &result : op.resultGroups)
453 if (isDefOrUse(result.definition, posLoc))
454 return appendSMDef(result.definition);
455 for (const auto &symUse : op.symbolUses) {
456 if (!contains(symUse, posLoc))
457 continue;
458 for (const auto &symUse : op.symbolUses)
459 references.emplace_back(uri, sourceMgr, symUse);
460 return;
461 }
462 }
463
464 // Check all definitions related to blocks.
465 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
466 if (isDefOrUse(block.definition, posLoc))
467 return appendSMDef(block.definition);
468
469 for (const AsmParserState::SMDefinition &arg : block.arguments)
470 if (isDefOrUse(arg, posLoc))
471 return appendSMDef(arg);
472 }
473
474 // Check all alias definitions.
475 for (const AsmParserState::AttributeAliasDefinition &attr :
476 asmState.getAttributeAliasDefs()) {
477 if (isDefOrUse(attr.definition, posLoc))
478 return appendSMDef(attr.definition);
479 }
480 for (const AsmParserState::TypeAliasDefinition &type :
481 asmState.getTypeAliasDefs()) {
482 if (isDefOrUse(type.definition, posLoc))
483 return appendSMDef(type.definition);
484 }
485}
486
487//===----------------------------------------------------------------------===//
488// MLIRDocument: Hover
489//===----------------------------------------------------------------------===//
490
491std::optional<lsp::Hover>
492MLIRDocument::findHover(const lsp::URIForFile &uri,
493 const lsp::Position &hoverPos) {
494 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
495 SMRange hoverRange;
496
497 // Check for Hovers on operations and results.
498 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
499 // Check if the position points at this operation.
500 if (contains(op.loc, posLoc))
501 return buildHoverForOperation(op.loc, op);
502
503 // Check if the position points at the symbol name.
504 for (auto &use : op.symbolUses)
505 if (contains(use, posLoc))
506 return buildHoverForOperation(use, op);
507
508 // Check if the position points at a result group.
509 for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
510 const auto &result = op.resultGroups[i];
511 if (!isDefOrUse(result.definition, posLoc, &hoverRange))
512 continue;
513
514 // Get the range of results covered by the over position.
515 unsigned resultStart = result.startIndex;
516 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
517 : op.resultGroups[i + 1].startIndex;
518 return buildHoverForOperationResult(hoverRange, op.op, resultStart,
519 resultEnd, posLoc);
520 }
521 }
522
523 // Check to see if the hover is over a block argument.
524 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
525 if (isDefOrUse(block.definition, posLoc, &hoverRange))
526 return buildHoverForBlock(hoverRange, block);
527
528 for (const auto &arg : llvm::enumerate(block.arguments)) {
529 if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
530 continue;
531
532 return buildHoverForBlockArgument(
533 hoverRange, block.block->getArgument(arg.index()), block);
534 }
535 }
536
537 // Check to see if the hover is over an alias.
538 for (const AsmParserState::AttributeAliasDefinition &attr :
539 asmState.getAttributeAliasDefs()) {
540 if (isDefOrUse(attr.definition, posLoc, &hoverRange))
541 return buildHoverForAttributeAlias(hoverRange, attr);
542 }
543 for (const AsmParserState::TypeAliasDefinition &type :
544 asmState.getTypeAliasDefs()) {
545 if (isDefOrUse(type.definition, posLoc, &hoverRange))
546 return buildHoverForTypeAlias(hoverRange, type);
547 }
548
549 return std::nullopt;
550}
551
552std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
553 SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
554 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
555 llvm::raw_string_ostream os(hover.contents.value);
556
557 // Add the operation name to the hover.
558 os << "\"" << op.op->getName() << "\"";
559 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
560 os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
561 os << "\n\n";
562
563 os << "Generic Form:\n\n```mlir\n";
564
565 op.op->print(os, OpPrintingFlags()
566 .printGenericOpForm()
567 .elideLargeElementsAttrs()
568 .skipRegions());
569 os << "\n```\n";
570
571 return hover;
572}
573
574lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
575 Operation *op,
576 unsigned resultStart,
577 unsigned resultEnd,
578 SMLoc posLoc) {
579 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
580 llvm::raw_string_ostream os(hover.contents.value);
581
582 // Add the parent operation name to the hover.
583 os << "Operation: \"" << op->getName() << "\"\n\n";
584
585 // Check to see if the location points to a specific result within the
586 // group.
587 if (std::optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
588 if ((resultStart + *resultNumber) < resultEnd) {
589 resultStart += *resultNumber;
590 resultEnd = resultStart + 1;
591 }
592 }
593
594 // Add the range of results and their types to the hover info.
595 if ((resultStart + 1) == resultEnd) {
596 os << "Result #" << resultStart << "\n\n"
597 << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
598 } else {
599 os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
600 << "Types: ";
601 llvm::interleaveComma(
602 op->getResults().slice(resultStart, resultEnd), os,
603 [&](Value result) { os << "`" << result.getType() << "`"; });
604 }
605
606 return hover;
607}
608
609lsp::Hover
610MLIRDocument::buildHoverForBlock(SMRange hoverRange,
611 const AsmParserState::BlockDefinition &block) {
612 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
613 llvm::raw_string_ostream os(hover.contents.value);
614
615 // Print the given block to the hover output stream.
616 auto printBlockToHover = [&](Block *newBlock) {
617 if (const auto *def = asmState.getBlockDef(newBlock))
618 printDefBlockName(os, *def);
619 else
620 printDefBlockName(os, newBlock);
621 };
622
623 // Display the parent operation, block number, predecessors, and successors.
624 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
625 << "Block #" << getBlockNumber(block.block) << "\n\n";
626 if (!block.block->hasNoPredecessors()) {
627 os << "Predecessors: ";
628 llvm::interleaveComma(block.block->getPredecessors(), os,
629 printBlockToHover);
630 os << "\n\n";
631 }
632 if (!block.block->hasNoSuccessors()) {
633 os << "Successors: ";
634 llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
635 os << "\n\n";
636 }
637
638 return hover;
639}
640
641lsp::Hover MLIRDocument::buildHoverForBlockArgument(
642 SMRange hoverRange, BlockArgument arg,
643 const AsmParserState::BlockDefinition &block) {
644 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
645 llvm::raw_string_ostream os(hover.contents.value);
646
647 // Display the parent operation, block, the argument number, and the type.
648 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
649 << "Block: ";
650 printDefBlockName(os, block);
651 os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
652 << "Type: `" << arg.getType() << "`\n\n";
653
654 return hover;
655}
656
657lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
658 SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr) {
659 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
660 llvm::raw_string_ostream os(hover.contents.value);
661
662 os << "Attribute Alias: \"" << attr.name << "\n\n";
663 os << "Value: ```mlir\n" << attr.value << "\n```\n\n";
664
665 return hover;
666}
667
668lsp::Hover MLIRDocument::buildHoverForTypeAlias(
669 SMRange hoverRange, const AsmParserState::TypeAliasDefinition &type) {
670 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
671 llvm::raw_string_ostream os(hover.contents.value);
672
673 os << "Type Alias: \"" << type.name << "\n\n";
674 os << "Value: ```mlir\n" << type.value << "\n```\n\n";
675
676 return hover;
677}
678
679//===----------------------------------------------------------------------===//
680// MLIRDocument: Document Symbols
681//===----------------------------------------------------------------------===//
682
683void MLIRDocument::findDocumentSymbols(
684 std::vector<lsp::DocumentSymbol> &symbols) {
685 for (Operation &op : parsedIR)
686 findDocumentSymbols(&op, symbols);
687}
688
689void MLIRDocument::findDocumentSymbols(
690 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
691 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
692
693 // Check for the source information of this operation.
694 if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
695 // If this operation defines a symbol, record it.
696 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
697 symbols.emplace_back(symbol.getName(),
698 isa<FunctionOpInterface>(op)
699 ? llvm::lsp::SymbolKind::Function
700 : llvm::lsp::SymbolKind::Class,
701 lsp::Range(sourceMgr, def->scopeLoc),
702 lsp::Range(sourceMgr, def->loc));
703 childSymbols = &symbols.back().children;
704
705 } else if (op->hasTrait<OpTrait::SymbolTable>()) {
706 // Otherwise, if this is a symbol table push an anonymous document symbol.
707 symbols.emplace_back("<" + op->getName().getStringRef() + ">",
708 llvm::lsp::SymbolKind::Namespace,
709 llvm::lsp::Range(sourceMgr, def->scopeLoc),
710 llvm::lsp::Range(sourceMgr, def->loc));
711 childSymbols = &symbols.back().children;
712 }
713 }
714
715 // Recurse into the regions of this operation.
716 if (!op->getNumRegions())
717 return;
718 for (Region &region : op->getRegions())
719 for (Operation &childOp : region.getOps())
720 findDocumentSymbols(&childOp, *childSymbols);
721}
722
723//===----------------------------------------------------------------------===//
724// MLIRDocument: Code Completion
725//===----------------------------------------------------------------------===//
726
727namespace {
728class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
729public:
730 LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
731 MLIRContext *ctx)
732 : AsmParserCodeCompleteContext(completeLoc),
733 completionList(completionList), ctx(ctx) {}
734
735 /// Signal code completion for a dialect name, with an optional prefix.
736 void completeDialectName(StringRef prefix) final {
737 for (StringRef dialect : ctx->getAvailableDialects()) {
738 llvm::lsp::CompletionItem item(prefix + dialect,
739 llvm::lsp::CompletionItemKind::Module,
740 /*sortText=*/"3");
741 item.detail = "dialect";
742 completionList.items.emplace_back(item);
743 }
744 }
746
747 /// Signal code completion for an operation name within the given dialect.
748 void completeOperationName(StringRef dialectName) final {
749 Dialect *dialect = ctx->getOrLoadDialect(dialectName);
750 if (!dialect)
751 return;
752
753 for (const auto &op : ctx->getRegisteredOperations()) {
754 if (&op.getDialect() != dialect)
755 continue;
756
757 llvm::lsp::CompletionItem item(
758 op.getStringRef().drop_front(dialectName.size() + 1),
759 llvm::lsp::CompletionItemKind::Field,
760 /*sortText=*/"1");
761 item.detail = "operation";
762 completionList.items.emplace_back(item);
763 }
764 }
765
766 /// Append the given SSA value as a code completion result for SSA value
767 /// completions.
768 void appendSSAValueCompletion(StringRef name, std::string typeData) final {
769 // Check if we need to insert the `%` or not.
770 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';
771
772 llvm::lsp::CompletionItem item(name,
773 llvm::lsp::CompletionItemKind::Variable);
774 if (stripPrefix)
775 item.insertText = name.drop_front(1).str();
776 item.detail = std::move(typeData);
777 completionList.items.emplace_back(item);
778 }
779
780 /// Append the given block as a code completion result for block name
781 /// completions.
782 void appendBlockCompletion(StringRef name) final {
783 // Check if we need to insert the `^` or not.
784 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';
785
786 llvm::lsp::CompletionItem item(name, llvm::lsp::CompletionItemKind::Field);
787 if (stripPrefix)
788 item.insertText = name.drop_front(1).str();
789 completionList.items.emplace_back(item);
790 }
791
792 /// Signal a completion for the given expected token.
793 void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
794 for (StringRef token : tokens) {
795 llvm::lsp::CompletionItem item(token,
796 llvm::lsp::CompletionItemKind::Keyword,
797 /*sortText=*/"0");
798 item.detail = optional ? "optional" : "";
799 completionList.items.emplace_back(item);
800 }
801 }
802
803 /// Signal a completion for an attribute.
804 void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
805 appendSimpleCompletions({"affine_set", "affine_map", "dense",
806 "dense_resource", "false", "loc", "sparse", "true",
807 "unit"},
808 llvm::lsp::CompletionItemKind::Field,
809 /*sortText=*/"1");
810
811 completeDialectName("#");
812 completeAliases(aliases, "#");
813 }
814 void completeDialectAttributeOrAlias(
815 const llvm::StringMap<Attribute> &aliases) override {
816 completeDialectName();
817 completeAliases(aliases);
818 }
819
820 /// Signal a completion for a type.
821 void completeType(const llvm::StringMap<Type> &aliases) override {
822 // Handle the various builtin types.
823 appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
824 "bf16", "f16", "f32", "f64", "f80", "f128",
825 "index", "none"},
826 llvm::lsp::CompletionItemKind::Field,
827 /*sortText=*/"1");
828
829 // Handle the builtin integer types.
830 for (StringRef type : {"i", "si", "ui"}) {
831 llvm::lsp::CompletionItem item(type + "<N>",
832 llvm::lsp::CompletionItemKind::Field,
833 /*sortText=*/"1");
834 item.insertText = type.str();
835 completionList.items.emplace_back(item);
836 }
837
838 // Insert completions for dialect types and aliases.
839 completeDialectName("!");
840 completeAliases(aliases, "!");
841 }
842 void
843 completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
844 completeDialectName();
845 completeAliases(aliases);
846 }
847
848 /// Add completion results for the given set of aliases.
849 template <typename T>
850 void completeAliases(const llvm::StringMap<T> &aliases,
851 StringRef prefix = "") {
852 for (const auto &alias : aliases) {
853 llvm::lsp::CompletionItem item(prefix + alias.getKey(),
854 llvm::lsp::CompletionItemKind::Field,
855 /*sortText=*/"2");
856 llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
857 completionList.items.emplace_back(item);
858 }
859 }
860
861 /// Add a set of simple completions that all have the same kind.
862 void appendSimpleCompletions(ArrayRef<StringRef> completions,
863 llvm::lsp::CompletionItemKind kind,
864 StringRef sortText = "") {
865 for (StringRef completion : completions)
866 completionList.items.emplace_back(completion, kind, sortText);
867 }
868
869private:
870 lsp::CompletionList &completionList;
871 MLIRContext *ctx;
872};
873} // namespace
874
875lsp::CompletionList
876MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
877 const lsp::Position &completePos,
878 const DialectRegistry &registry) {
879 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
880 if (!posLoc.isValid())
881 return lsp::CompletionList();
882
883 // To perform code completion, we run another parse of the module with the
884 // code completion context provided.
885 MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
886 tmpContext.allowUnregisteredDialects();
887 lsp::CompletionList completionList;
888 LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
889 &tmpContext);
890
891 Block tmpIR;
892 AsmParserState tmpState;
893 (void)parseAsmSourceFile(sourceMgr, &tmpIR, &tmpContext, &tmpState,
894 &lspCompleteContext);
895 return completionList;
896}
897
898//===----------------------------------------------------------------------===//
899// MLIRDocument: Code Action
900//===----------------------------------------------------------------------===//
901
902void MLIRDocument::getCodeActionForDiagnostic(
903 const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
904 StringRef message, std::vector<llvm::lsp::TextEdit> &edits) {
905 // Ignore diagnostics that print the current operation. These are always
906 // enabled for the language server, but not generally during normal
907 // parsing/verification.
908 if (message.starts_with("see current operation: "))
909 return;
910
911 // Get the start of the line containing the diagnostic.
912 const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
913 const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
914 if (!lineStart)
915 return;
916 StringRef line(lineStart, pos.character);
917
918 // Add a text edit for adding an expected-* diagnostic check for this
919 // diagnostic.
920 llvm::lsp::TextEdit edit;
921 edit.range = lsp::Range(lsp::Position(pos.line, 0));
922
923 // Use the indent of the current line for the expected-* diagnostic.
924 size_t indent = line.find_first_not_of(' ');
925 if (indent == StringRef::npos)
926 indent = line.size();
927
928 edit.newText.append(indent, ' ');
929 llvm::raw_string_ostream(edit.newText)
930 << "// expected-" << severity << " @below {{" << message << "}}\n";
931 edits.emplace_back(std::move(edit));
932}
933
934//===----------------------------------------------------------------------===//
935// MLIRDocument: Bytecode
936//===----------------------------------------------------------------------===//
937
938llvm::Expected<lsp::MLIRConvertBytecodeResult>
939MLIRDocument::convertToBytecode() {
940 // TODO: We currently require a single top-level operation, but this could
941 // conceptually be relaxed.
942 if (!llvm::hasSingleElement(parsedIR)) {
943 if (parsedIR.empty()) {
944 return llvm::make_error<llvm::lsp::LSPError>(
945 "expected a single and valid top-level operation, please ensure "
946 "there are no errors",
947 llvm::lsp::ErrorCode::RequestFailed);
948 }
949 return llvm::make_error<llvm::lsp::LSPError>(
950 "expected a single top-level operation",
951 llvm::lsp::ErrorCode::RequestFailed);
952 }
953
954 lsp::MLIRConvertBytecodeResult result;
955 {
956 BytecodeWriterConfig writerConfig(fallbackResourceMap);
957
958 std::string rawBytecodeBuffer;
959 llvm::raw_string_ostream os(rawBytecodeBuffer);
960 // No desired bytecode version set, so no need to check for error.
961 (void)writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
962 result.output = llvm::encodeBase64(rawBytecodeBuffer);
963 }
964 return result;
965}
966
967//===----------------------------------------------------------------------===//
968// MLIRTextFileChunk
969//===----------------------------------------------------------------------===//
970
971namespace {
972/// This class represents a single chunk of an MLIR text file.
973struct MLIRTextFileChunk {
974 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
975 const lsp::URIForFile &uri, StringRef contents,
976 std::vector<lsp::Diagnostic> &diagnostics)
977 : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
978
979 /// Adjust the line number of the given range to anchor at the beginning of
980 /// the file, instead of the beginning of this chunk.
981 void adjustLocForChunkOffset(lsp::Range &range) {
982 adjustLocForChunkOffset(range.start);
983 adjustLocForChunkOffset(range.end);
984 }
985 /// Adjust the line number of the given position to anchor at the beginning of
986 /// the file, instead of the beginning of this chunk.
987 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
988
989 /// The line offset of this chunk from the beginning of the file.
990 uint64_t lineOffset;
991 /// The document referred to by this chunk.
992 MLIRDocument document;
993};
994} // namespace
995
996//===----------------------------------------------------------------------===//
997// MLIRTextFile
998//===----------------------------------------------------------------------===//
999
1000namespace {
1001/// This class represents a text file containing one or more MLIR documents.
1002class MLIRTextFile {
1003public:
1004 MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1005 int64_t version, lsp::DialectRegistryFn registry_fn,
1006 std::vector<lsp::Diagnostic> &diagnostics);
1007
1008 /// Return the current version of this text file.
1009 int64_t getVersion() const { return version; }
1010
1011 //===--------------------------------------------------------------------===//
1012 // LSP Queries
1013 //===--------------------------------------------------------------------===//
1014
1015 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1016 std::vector<lsp::Location> &locations);
1017 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1018 std::vector<lsp::Location> &references);
1019 std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1020 lsp::Position hoverPos);
1021 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1022 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1023 lsp::Position completePos);
1024 void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos,
1025 const lsp::CodeActionContext &context,
1026 std::vector<lsp::CodeAction> &actions);
1027 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
1028
1029private:
1030 /// Find the MLIR document that contains the given position, and update the
1031 /// position to be anchored at the start of the found chunk instead of the
1032 /// beginning of the file.
1033 MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
1034
1035 /// The context used to hold the state contained by the parsed document.
1036 MLIRContext context;
1037
1038 /// The full string contents of the file.
1039 std::string contents;
1040
1041 /// The version of this file.
1042 int64_t version;
1043
1044 /// The number of lines in the file.
1045 int64_t totalNumLines = 0;
1046
1047 /// The chunks of this file. The order of these chunks is the order in which
1048 /// they appear in the text file.
1049 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
1050};
1051} // namespace
1052
1053MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1054 int64_t version, lsp::DialectRegistryFn registry_fn,
1055 std::vector<lsp::Diagnostic> &diagnostics)
1056 : context(registry_fn(uri), MLIRContext::Threading::DISABLED),
1057 contents(fileContents.str()), version(version) {
1058 context.allowUnregisteredDialects();
1059
1060 // Split the file into separate MLIR documents.
1061 SmallVector<StringRef, 8> subContents;
1062 StringRef(contents).split(subContents, kDefaultSplitMarker);
1063 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
1064 context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
1065
1066 uint64_t lineOffset = subContents.front().count('\n');
1067 for (StringRef docContents : llvm::drop_begin(subContents)) {
1068 unsigned currentNumDiags = diagnostics.size();
1069 auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
1070 docContents, diagnostics);
1071 lineOffset += docContents.count('\n');
1072
1073 // Adjust locations used in diagnostics to account for the offset from the
1074 // beginning of the file.
1075 for (lsp::Diagnostic &diag :
1076 llvm::drop_begin(diagnostics, currentNumDiags)) {
1077 chunk->adjustLocForChunkOffset(diag.range);
1078
1079 if (!diag.relatedInformation)
1080 continue;
1081 for (auto &it : *diag.relatedInformation)
1082 if (it.location.uri == uri)
1083 chunk->adjustLocForChunkOffset(it.location.range);
1084 }
1085 chunks.emplace_back(std::move(chunk));
1086 }
1087 totalNumLines = lineOffset;
1088}
1089
1090void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
1091 lsp::Position defPos,
1092 std::vector<lsp::Location> &locations) {
1093 MLIRTextFileChunk &chunk = getChunkFor(defPos);
1094 chunk.document.getLocationsOf(uri, defPos, locations);
1095
1096 // Adjust any locations within this file for the offset of this chunk.
1097 if (chunk.lineOffset == 0)
1098 return;
1099 for (lsp::Location &loc : locations)
1100 if (loc.uri == uri)
1101 chunk.adjustLocForChunkOffset(loc.range);
1102}
1103
1104void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
1105 lsp::Position pos,
1106 std::vector<lsp::Location> &references) {
1107 MLIRTextFileChunk &chunk = getChunkFor(pos);
1108 chunk.document.findReferencesOf(uri, pos, references);
1109
1110 // Adjust any locations within this file for the offset of this chunk.
1111 if (chunk.lineOffset == 0)
1112 return;
1113 for (lsp::Location &loc : references)
1114 if (loc.uri == uri)
1115 chunk.adjustLocForChunkOffset(loc.range);
1116}
1117
1118std::optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
1119 lsp::Position hoverPos) {
1120 MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
1121 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1122
1123 // Adjust any locations within this file for the offset of this chunk.
1124 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1125 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1126 return hoverInfo;
1127}
1128
1129void MLIRTextFile::findDocumentSymbols(
1130 std::vector<lsp::DocumentSymbol> &symbols) {
1131 if (chunks.size() == 1)
1132 return chunks.front()->document.findDocumentSymbols(symbols);
1133
1134 // If there are multiple chunks in this file, we create top-level symbols for
1135 // each chunk.
1136 for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1137 MLIRTextFileChunk &chunk = *chunks[i];
1138 lsp::Position startPos(chunk.lineOffset);
1139 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1140 : chunks[i + 1]->lineOffset);
1141 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1142 llvm::lsp::SymbolKind::Namespace,
1143 /*range=*/lsp::Range(startPos, endPos),
1144 /*selectionRange=*/lsp::Range(startPos));
1145 chunk.document.findDocumentSymbols(symbol.children);
1146
1147 // Fixup the locations of document symbols within this chunk.
1148 if (i != 0) {
1149 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1150 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1151 symbolsToFix.push_back(&childSymbol);
1152
1153 while (!symbolsToFix.empty()) {
1154 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1155 chunk.adjustLocForChunkOffset(symbol->range);
1156 chunk.adjustLocForChunkOffset(symbol->selectionRange);
1157
1158 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1159 symbolsToFix.push_back(&childSymbol);
1160 }
1161 }
1162
1163 // Push the symbol for this chunk.
1164 symbols.emplace_back(std::move(symbol));
1165 }
1166}
1167
1168lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1169 lsp::Position completePos) {
1170 MLIRTextFileChunk &chunk = getChunkFor(completePos);
1171 lsp::CompletionList completionList = chunk.document.getCodeCompletion(
1172 uri, completePos, context.getDialectRegistry());
1173
1174 // Adjust any completion locations.
1175 for (llvm::lsp::CompletionItem &item : completionList.items) {
1176 if (item.textEdit)
1177 chunk.adjustLocForChunkOffset(item.textEdit->range);
1178 for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
1179 chunk.adjustLocForChunkOffset(edit.range);
1180 }
1181 return completionList;
1182}
1183
1184void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
1185 const lsp::Range &pos,
1186 const lsp::CodeActionContext &context,
1187 std::vector<lsp::CodeAction> &actions) {
1188 // Create actions for any diagnostics in this file.
1189 for (auto &diag : context.diagnostics) {
1190 if (diag.source != "mlir")
1191 continue;
1192 lsp::Position diagPos = diag.range.start;
1193 MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1194
1195 // Add a new code action that inserts a "expected" diagnostic check.
1196 lsp::CodeAction action;
1197 action.title = "Add expected-* diagnostic checks";
1198 action.kind = lsp::CodeAction::kQuickFix.str();
1199
1200 StringRef severity;
1201 switch (diag.severity) {
1202 case llvm::lsp::DiagnosticSeverity::Error:
1203 severity = "error";
1204 break;
1205 case llvm::lsp::DiagnosticSeverity::Warning:
1206 severity = "warning";
1207 break;
1208 default:
1209 continue;
1210 }
1211
1212 // Get edits for the diagnostic.
1213 std::vector<llvm::lsp::TextEdit> edits;
1214 chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1215 diag.message, edits);
1216
1217 // Walk the related diagnostics, this is how we encode notes.
1218 if (diag.relatedInformation) {
1219 for (auto &noteDiag : *diag.relatedInformation) {
1220 if (noteDiag.location.uri != uri)
1221 continue;
1222 diagPos = noteDiag.location.range.start;
1223 diagPos.line -= chunk.lineOffset;
1224 chunk.document.getCodeActionForDiagnostic(uri, diagPos, "note",
1225 noteDiag.message, edits);
1226 }
1227 }
1228 // Fixup the locations for any edits.
1229 for (llvm::lsp::TextEdit &edit : edits)
1230 chunk.adjustLocForChunkOffset(edit.range);
1231
1232 action.edit.emplace();
1233 action.edit->changes[uri.uri().str()] = std::move(edits);
1234 action.diagnostics = {diag};
1235
1236 actions.emplace_back(std::move(action));
1237 }
1238}
1239
1240llvm::Expected<lsp::MLIRConvertBytecodeResult>
1241MLIRTextFile::convertToBytecode() {
1242 // Bail out if there is more than one chunk, bytecode wants a single module.
1243 if (chunks.size() != 1) {
1244 return llvm::make_error<llvm::lsp::LSPError>(
1245 "unexpected split file, please remove all `// -----`",
1246 llvm::lsp::ErrorCode::RequestFailed);
1247 }
1248 return chunks.front()->document.convertToBytecode();
1249}
1250
1251MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1252 if (chunks.size() == 1)
1253 return *chunks.front();
1254
1255 // Search for the first chunk with a greater line offset, the previous chunk
1256 // is the one that contains `pos`.
1257 auto it = llvm::upper_bound(
1258 chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1259 return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1260 });
1261 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1262 pos.line -= chunk.lineOffset;
1263 return chunk;
1264}
1265
1266//===----------------------------------------------------------------------===//
1267// MLIRServer::Impl
1268//===----------------------------------------------------------------------===//
1269
1272
1273 /// The registry factory for containing dialects that can be recognized in
1274 /// parsed .mlir files.
1276
1277 /// The files held by the server, mapped by their URI file name.
1278 llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
1279};
1280
1281//===----------------------------------------------------------------------===//
1282// MLIRServer
1283//===----------------------------------------------------------------------===//
1284
1286 : impl(std::make_unique<Impl>(registry_fn)) {}
1288
1290 const URIForFile &uri, StringRef contents, int64_t version,
1291 std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1292 impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1293 uri, contents, version, impl->registry_fn, diagnostics);
1294}
1295
1296std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
1297 auto it = impl->files.find(uri.file());
1298 if (it == impl->files.end())
1299 return std::nullopt;
1300
1301 int64_t version = it->second->getVersion();
1302 impl->files.erase(it);
1303 return version;
1304}
1305
1307 const URIForFile &uri, const Position &defPos,
1308 std::vector<llvm::lsp::Location> &locations) {
1309 auto fileIt = impl->files.find(uri.file());
1310 if (fileIt != impl->files.end())
1311 fileIt->second->getLocationsOf(uri, defPos, locations);
1312}
1313
1315 const URIForFile &uri, const Position &pos,
1316 std::vector<llvm::lsp::Location> &references) {
1317 auto fileIt = impl->files.find(uri.file());
1318 if (fileIt != impl->files.end())
1319 fileIt->second->findReferencesOf(uri, pos, references);
1320}
1321
1322std::optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
1323 const Position &hoverPos) {
1324 auto fileIt = impl->files.find(uri.file());
1325 if (fileIt != impl->files.end())
1326 return fileIt->second->findHover(uri, hoverPos);
1327 return std::nullopt;
1328}
1329
1331 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1332 auto fileIt = impl->files.find(uri.file());
1333 if (fileIt != impl->files.end())
1334 fileIt->second->findDocumentSymbols(symbols);
1335}
1336
1337lsp::CompletionList
1339 const Position &completePos) {
1340 auto fileIt = impl->files.find(uri.file());
1341 if (fileIt != impl->files.end())
1342 return fileIt->second->getCodeCompletion(uri, completePos);
1343 return CompletionList();
1344}
1345
1346void lsp::MLIRServer::getCodeActions(const URIForFile &uri, const Range &pos,
1347 const CodeActionContext &context,
1348 std::vector<CodeAction> &actions) {
1349 auto fileIt = impl->files.find(uri.file());
1350 if (fileIt != impl->files.end())
1351 fileIt->second->getCodeActions(uri, pos, context, actions);
1352}
1353
1356 MLIRContext tempContext(impl->registry_fn(uri));
1357 tempContext.allowUnregisteredDialects();
1358
1359 // Collect any errors during parsing.
1360 std::string errorMsg;
1361 ScopedDiagnosticHandler diagHandler(
1362 &tempContext,
1363 [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
1364
1365 // Handling for external resources, which we want to propagate up to the user.
1366 FallbackAsmResourceMap fallbackResourceMap;
1367
1368 // Setup the parser config.
1369 ParserConfig parserConfig(&tempContext, /*verifyAfterParse=*/true,
1370 &fallbackResourceMap);
1371
1372 // Try to parse the given source file.
1373 Block parsedBlock;
1374 if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
1375 return llvm::make_error<llvm::lsp::LSPError>(
1376 "failed to parse bytecode source file: " + errorMsg,
1377 llvm::lsp::ErrorCode::RequestFailed);
1378 }
1379
1380 // TODO: We currently expect a single top-level operation, but this could
1381 // conceptually be relaxed.
1382 if (!llvm::hasSingleElement(parsedBlock)) {
1383 return llvm::make_error<llvm::lsp::LSPError>(
1384 "expected bytecode to contain a single top-level operation",
1385 llvm::lsp::ErrorCode::RequestFailed);
1386 }
1387
1388 // Print the module to a buffer.
1390 {
1391 // Extract the top-level op so that aliases get printed.
1392 // FIXME: We should be able to enable aliases without having to do this!
1393 OwningOpRef<Operation *> topOp = &parsedBlock.front();
1394 topOp->remove();
1395
1396 AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
1397 /*locationMap=*/nullptr, &fallbackResourceMap);
1398
1399 llvm::raw_string_ostream os(result.output);
1400 topOp->print(os, state);
1401 }
1402 return std::move(result);
1403}
1404
1407 auto fileIt = impl->files.find(uri.file());
1408 if (fileIt == impl->files.end()) {
1409 return llvm::make_error<llvm::lsp::LSPError>(
1410 "language server does not contain an entry for this source file",
1411 llvm::lsp::ErrorCode::RequestFailed);
1412 }
1413 return fileIt->second->convertToBytecode();
1414}
static void collectLocationsFromLoc(Location loc, std::vector< lsp::Location > &locations, const lsp::URIForFile &uri)
Collect all of the locations from the given MLIR location that are not contained within the given URI...
static unsigned getBlockNumber(Block *block)
Given a block, return its position in its parent region.
static std::optional< unsigned > getResultNumberFromLoc(SMLoc loc)
Given a location pointing to a result, return the result number it refers to or std::nullopt if it re...
static std::optional< StringRef > getTextFromRange(SMRange range)
Given a source location range, return the text covered by the given range.
static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, SMRange *overlappedRange=nullptr)
Returns true if the given location is contained by the definition or one of the uses of the given SMD...
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, Diagnostic &diag, const lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
static std::optional< lsp::Location > getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc)
Returns a language server location from the given MLIR file location.
static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc={})
Given a block and source location, print the source name of the block to the given output stream.
static SMRange convertTokenLocToRange(SMLoc loc)
Returns the range of a lexical token given a SMLoc corresponding to the start of a token location.
static std::string diag(const llvm::Value &value)
iterator_range< AttributeDefIterator > getAttributeAliasDefs() const
Return a range of the AttributeAliasDefinitions held by the current parser state.
iterator_range< BlockDefIterator > getBlockDefs() const
Return a range of the BlockDefinitions held by the current parser state.
const OperationDefinition * getOpDef(Operation *op) const
Return the definition for the given operation, or nullptr if the given operation does not have a defi...
const BlockDefinition * getBlockDef(Block *block) const
Return the definition for the given block, or nullptr if the given block does not have a definition.
iterator_range< OperationDefIterator > getOpDefs() const
Return a range of the OperationDefinitions held by the current parser state.
iterator_range< TypeDefIterator > getTypeAliasDefs() const
Return a range of the TypeAliasDefinitions held by the current parser state.
This class provides management for the lifetime of the state used when printing the IR.
Definition AsmState.h:542
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition Block.h:248
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:240
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:153
SuccessorRange getSuccessors()
Definition Block.h:270
void clear()
Definition Block.h:38
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition Block.h:245
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A fallback map containing external resources not explicitly handled by another parser/printer.
Definition AsmState.h:421
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
unsigned getLine() const
Definition Location.cpp:173
StringAttr getFilename() const
Definition Location.cpp:169
unsigned getColumn() const
Definition Location.cpp:175
WalkResult walk(function_ref< WalkResult(Location)> walkFn)
Walk all of the locations nested directly under, and including, the current.
Definition Location.cpp:124
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
std::vector< StringRef > getAvailableDialects()
Return information about all available dialects in the registry in this context.
void allowUnregisteredDialects(bool allow=true)
Enables creating operations in unregistered dialects.
Set of flags used to control the behavior of the various IR print methods (e.g.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
void print(raw_ostream &os, const OpPrintingFlags &flags={})
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_range getResults()
Definition Operation.h:415
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition OwningOpRef.h:29
This class represents a configuration for the MLIR assembly parser.
Definition AsmState.h:469
iterator begin()
Definition Region.h:55
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
void addOrUpdateDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add or update the document, with the provided version, at the given URI.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
llvm::Expected< MLIRConvertBytecodeResult > convertFromBytecode(const URIForFile &uri)
Convert the given bytecode file to the textual format.
llvm::Expected< MLIRConvertBytecodeResult > convertToBytecode(const URIForFile &uri)
Convert the given textual file to the bytecode format.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
MLIRServer(DialectRegistryFn registry_fn)
Construct a new server with the given dialect registry function.
void getCodeActions(const URIForFile &uri, const Range &pos, const CodeActionContext &context, std::vector< CodeAction > &actions)
Get the set of code actions within the file.
llvm::function_ref< DialectRegistry &(const llvm::lsp::URIForFile &uri)> DialectRegistryFn
SMRange convertTokenLocToRange(SMLoc loc, StringRef identifierChars="")
Returns the range of a lexical token given a SMLoc corresponding to the start of a token location.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
const char *const kDefaultSplitMarker
LogicalResult parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, AsmParserState *asmState=nullptr, AsmParserCodeCompleteContext *codeCompleteContext=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition Parser.cpp:2918
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition Parser.cpp:38
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config={})
Write the bytecode for the given operation to the provided output stream.
This class represents the result of converting between MLIR's bytecode and textual format.
Definition Protocol.h:48
Impl(lsp::DialectRegistryFn registry_fn)
lsp::DialectRegistryFn registry_fn
The registry factory for containing dialects that can be recognized in parsed .mlir files.
llvm::StringMap< std::unique_ptr< MLIRTextFile > > files
The files held by the server, mapped by their URI file name.
StringRef name
The name of the attribute alias.
Attribute value
The value of the alias.
This class represents the information for a block definition within the input file.
Block * block
The block representing this definition.
SMDefinition definition
The source location for the block, i.e.
Operation * op
The operation representing this definition.
This class represents a definition within the source manager, containing its defining location and l...
SmallVector< SMRange > uses
The source location of all uses of the definition.
SMRange loc
The source location of the definition.
StringRef name
The name of the attribute alias.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...