MLIR 23.0.0git
MLIRServer.cpp
Go to the documentation of this file.
1//===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "MLIRServer.h"
10#include "Protocol.h"
15#include "mlir/IR/Operation.h"
17#include "mlir/Parser/Parser.h"
20#include "llvm/ADT/StringExtras.h"
21#include "llvm/Support/Base64.h"
22#include "llvm/Support/FileSystem.h"
23#include "llvm/Support/LSP/Logging.h"
24#include "llvm/Support/Path.h"
25#include "llvm/Support/SourceMgr.h"
26#include <optional>
27
28using namespace mlir;
29
30/// Returns the range of a lexical token given a SMLoc corresponding to the
31/// start of an token location. The range is computed heuristically, and
32/// supports identifier-like tokens, strings, etc.
33static SMRange convertTokenLocToRange(SMLoc loc) {
34 return lsp::convertTokenLocToRange(loc, "$-.");
35}
36
37/// Returns a language server location from the given MLIR file location.
38/// `uriScheme` is the scheme to use when building new uris.
39static std::optional<lsp::Location>
40getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc,
41 StringRef workspaceRoot) {
42 StringRef filename = loc.getFilename();
43 SmallString<128> absPath;
44 // Always make the path absolute. Skip paths that start with a separator:
45 // prevents incorrect resolution of virtual paths used in tests on Windows.
46 if (!llvm::sys::path::is_absolute(filename) && !filename.starts_with("/") &&
47 !filename.starts_with("\\")) {
48 if (!workspaceRoot.empty())
49 llvm::sys::path::append(absPath, workspaceRoot, filename);
50 else
51 absPath = filename;
52 llvm::sys::fs::make_absolute(absPath);
53 filename = absPath;
54 }
55
57 lsp::URIForFile::fromFile(filename, uriScheme);
58 if (!sourceURI) {
59 llvm::lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
60 filename, llvm::toString(sourceURI.takeError()));
61 return std::nullopt;
62 }
63
64 lsp::Position position;
65 position.line = loc.getLine() - 1;
66 position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
67 return lsp::Location{*sourceURI, lsp::Range(position)};
68}
69
70/// Returns a language server location from the given MLIR location, or
71/// std::nullopt if one couldn't be created. `uriScheme` is the scheme to use
72/// when building new uris. `uri` is an optional additional filter that, when
73/// present, is used to filter sub locations that do not share the same uri.
74static std::optional<lsp::Location>
75getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
76 StringRef uriScheme, StringRef workspaceRoot,
77 const lsp::URIForFile *uri = nullptr) {
78 std::optional<lsp::Location> location;
79 loc->walk([&](Location nestedLoc) {
80 auto fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
81 if (!fileLoc)
82 return WalkResult::advance();
83
84 std::optional<lsp::Location> sourceLoc =
85 getLocationFromLoc(uriScheme, fileLoc, workspaceRoot);
86 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
87 location = *sourceLoc;
88 SMLoc loc = sourceMgr.FindLocForLineAndColumn(
89 sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
90
91 // Use range of potential identifier starting at location, else length 1
92 // range.
93 location->range.end.character += 1;
94 if (std::optional<SMRange> range = convertTokenLocToRange(loc)) {
95 auto lineCol = sourceMgr.getLineAndColumn(range->End);
96 location->range.end.character =
97 std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
98 }
99 return WalkResult::interrupt();
100 }
101 return WalkResult::advance();
102 });
103 return location;
104}
105
106/// Collect all of the locations from the given MLIR location that are not
107/// contained within the given URI.
109 std::vector<lsp::Location> &locations,
110 const lsp::URIForFile &uri,
111 StringRef workspaceRoot) {
112 SetVector<Location> visitedLocs;
113 loc->walk([&](Location nestedLoc) {
114 FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
115 if (!fileLoc || !visitedLocs.insert(nestedLoc))
116 return WalkResult::advance();
117
118 std::optional<lsp::Location> sourceLoc =
119 getLocationFromLoc(uri.scheme(), fileLoc, workspaceRoot);
120 if (sourceLoc && sourceLoc->uri != uri)
121 locations.push_back(*sourceLoc);
122 return WalkResult::advance();
123 });
124}
125
126/// Returns true if the given range contains the given source location. Note
127/// that this has slightly different behavior than SMRange because it is
128/// inclusive of the end location.
129static bool contains(SMRange range, SMLoc loc) {
130 return range.Start.getPointer() <= loc.getPointer() &&
131 loc.getPointer() <= range.End.getPointer();
132}
133
134/// Returns true if the given location is contained by the definition or one of
135/// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
136/// the range within `def` that the provided `loc` overlapped with.
137static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
138 SMRange *overlappedRange = nullptr) {
139 // Check the main definition.
140 if (contains(def.loc, loc)) {
141 if (overlappedRange)
142 *overlappedRange = def.loc;
143 return true;
144 }
145
146 // Check the uses.
147 const auto *useIt = llvm::find_if(
148 def.uses, [&](const SMRange &range) { return contains(range, loc); });
149 if (useIt != def.uses.end()) {
150 if (overlappedRange)
151 *overlappedRange = *useIt;
152 return true;
153 }
154 return false;
155}
156
157/// Given a location pointing to a result, return the result number it refers
158/// to or std::nullopt if it refers to all of the results.
159static std::optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
160 // Skip all of the identifier characters.
161 auto isIdentifierChar = [](char c) {
162 return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
163 c == '-';
164 };
165 const char *curPtr = loc.getPointer();
166 while (isIdentifierChar(*curPtr))
167 ++curPtr;
168
169 // Check to see if this location indexes into the result group, via `#`. If it
170 // doesn't, we can't extract a sub result number.
171 if (*curPtr != '#')
172 return std::nullopt;
173
174 // Compute the sub result number from the remaining portion of the string.
175 const char *numberStart = ++curPtr;
176 while (llvm::isDigit(*curPtr))
177 ++curPtr;
178 StringRef numberStr(numberStart, curPtr - numberStart);
179 unsigned resultNumber = 0;
180 return numberStr.consumeInteger(10, resultNumber) ? std::optional<unsigned>()
181 : resultNumber;
182}
183
184/// Given a source location range, return the text covered by the given range.
185/// If the range is invalid, returns std::nullopt.
186static std::optional<StringRef> getTextFromRange(SMRange range) {
187 if (!range.isValid())
188 return std::nullopt;
189 const char *startPtr = range.Start.getPointer();
190 return StringRef(startPtr, range.End.getPointer() - startPtr);
191}
192
193/// Given a block and source location, print the source name of the block to the
194/// given output stream.
195static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
196 // Try to extract a name from the source location.
197 std::optional<StringRef> text = getTextFromRange(loc);
198 if (text && text->starts_with("^")) {
199 os << *text;
200 return;
201 }
202
203 // Otherwise, we don't have a name so print the block number.
204 os << "<Block #" << block->computeBlockNumber() << ">";
205}
209}
210
211/// Convert the given MLIR diagnostic to the LSP form.
212static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
214 const lsp::URIForFile &uri,
215 StringRef workspaceRoot) {
216 lsp::Diagnostic lspDiag;
217 lspDiag.source = "mlir";
218
219 // Note: Right now all of the diagnostics are treated as parser issues, but
220 // some are parser and some are verifier.
221 lspDiag.category = "Parse Error";
222
223 // Try to grab a file location for this diagnostic.
224 // TODO: For simplicity, we just grab the first one. It may be likely that we
225 // will need a more interesting heuristic here.'
226 StringRef uriScheme = uri.scheme();
227 std::optional<lsp::Location> lspLocation = getLocationFromLoc(
228 sourceMgr, diag.getLocation(), uriScheme, workspaceRoot, &uri);
229 if (lspLocation)
230 lspDiag.range = lspLocation->range;
231
232 // Convert the severity for the diagnostic.
233 switch (diag.getSeverity()) {
235 llvm_unreachable("expected notes to be handled separately");
237 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
238 break;
240 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
241 break;
243 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
244 break;
245 }
246 lspDiag.message = diag.str();
247
248 // Attach any notes to the main diagnostic as related information.
249 std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
250 for (Diagnostic &note : diag.getNotes()) {
251 lsp::Location noteLoc;
252 if (std::optional<lsp::Location> loc = getLocationFromLoc(
253 sourceMgr, note.getLocation(), uriScheme, workspaceRoot))
254 noteLoc = *loc;
255 else
256 noteLoc.uri = uri;
257 relatedDiags.emplace_back(noteLoc, note.str());
258 }
259 if (!relatedDiags.empty())
260 lspDiag.relatedInformation = std::move(relatedDiags);
261
262 return lspDiag;
263}
264
265//===----------------------------------------------------------------------===//
266// MLIRDocument
267//===----------------------------------------------------------------------===//
268
namespace {
/// This class represents all of the information pertaining to a specific MLIR
/// document: the IR parsed from it, the parser state recorded while parsing
/// (used to map source positions back to IR entities), and the source manager
/// that owns the document text. The query methods below answer individual
/// language server requests against that state.
struct MLIRDocument {
  /// Parses `contents`, the text of the document identified by `uri`, using
  /// `context`. Every diagnostic emitted while parsing is converted to LSP form
  /// and appended to `diagnostics`. If parsing fails, the parsed IR and the
  /// parser state are reset to empty. `workspaceRoot` is stored and later used
  /// to resolve relative file names found in locations.
  MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
               StringRef contents, StringRef workspaceRoot,
               std::vector<lsp::Diagnostic> &diagnostics);
  // Documents are non-copyable.
  MLIRDocument(const MLIRDocument &) = delete;
  MLIRDocument &operator=(const MLIRDocument &) = delete;

  //===--------------------------------------------------------------------===//
  // Definitions and References
  //===--------------------------------------------------------------------===//

  /// Appends to `locations` the definition site of the entity at `defPos`.
  /// For operations, this also includes the file locations attached to the
  /// operation that point outside of `uri`.
  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
                      std::vector<lsp::Location> &locations);
  /// Appends to `references` the definition and uses of the entity at `pos`.
  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
                        std::vector<lsp::Location> &references);

  //===--------------------------------------------------------------------===//
  // Hover
  //===--------------------------------------------------------------------===//

  /// Returns hover information for the entity at `hoverPos`, or std::nullopt
  /// if nothing hoverable is found there.
  std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                      const lsp::Position &hoverPos);
  // The builders below format hover contents for one kind of entity each;
  // `hoverRange` is the source range the resulting hover is attached to.
  std::optional<lsp::Hover>
  buildHoverForOperation(SMRange hoverRange,
                         const AsmParserState::OperationDefinition &op);
  /// `resultStart`/`resultEnd` delimit the half-open range of result indices
  /// of the hovered result group; `posLoc` is the hover position.
  lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
                                          unsigned resultStart,
                                          unsigned resultEnd, SMLoc posLoc);
  lsp::Hover buildHoverForBlock(SMRange hoverRange,
                                const AsmParserState::BlockDefinition &block);
  lsp::Hover
  buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
                             const AsmParserState::BlockDefinition &block);

  lsp::Hover buildHoverForAttributeAlias(
      SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr);
  lsp::Hover
  buildHoverForTypeAlias(SMRange hoverRange,
                         const AsmParserState::TypeAliasDefinition &type);

  //===--------------------------------------------------------------------===//
  // Document Symbols
  //===--------------------------------------------------------------------===//

  /// Collects the symbols of every top-level operation into `symbols`.
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
  /// Collects the symbols of `op` and its nested operations; operations nested
  /// under a symbol or symbol table are recorded as its children.
  void findDocumentSymbols(Operation *op,
                           std::vector<lsp::DocumentSymbol> &symbols);

  //===--------------------------------------------------------------------===//
  // Code Completion
  //===--------------------------------------------------------------------===//

  /// Computes completion candidates at `completePos` by re-parsing the
  /// document; `registry` supplies the dialects available to that parse.
  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
                                        const lsp::Position &completePos,
                                        const DialectRegistry &registry);

  //===--------------------------------------------------------------------===//
  // Code Action
  //===--------------------------------------------------------------------===//

  // NOTE(review): presumably appends to `edits` fix-it text edits for the
  // diagnostic described by `severity`/`message` at `pos`; see definition.
  void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
                                  lsp::Position &pos, StringRef severity,
                                  StringRef message,
                                  std::vector<llvm::lsp::TextEdit> &edits);

  //===--------------------------------------------------------------------===//
  // Bytecode
  //===--------------------------------------------------------------------===//

  // NOTE(review): presumably serializes the parsed IR to bytecode, returning
  // an error on failure; see definition.
  llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//
  // NOTE(review): members are destroyed in reverse declaration order (e.g.
  // parsedIR before asmState); keep that in mind before reordering them.

  /// The high level parser state used to find definitions and references within
  /// the source file.
  AsmParserState asmState;

  /// The container for the IR parsed from the input file.
  Block parsedIR;

  /// A collection of external resources, which we want to propagate up to the
  /// user.
  FallbackAsmResourceMap fallbackResourceMap;

  /// The source manager containing the contents of the input file.
  llvm::SourceMgr sourceMgr;

  /// The workspace root of the server.
  std::string workspaceRoot;
};
} // namespace
365
366MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
367 StringRef contents, StringRef workspaceRoot,
368 std::vector<lsp::Diagnostic> &diagnostics)
369 : workspaceRoot(workspaceRoot.str()) {
370 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
371 diagnostics.push_back(
372 getLspDiagnoticFromDiag(sourceMgr, diag, uri, workspaceRoot));
373 });
374
375 // Try to parsed the given IR string.
376 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
377 if (!memBuffer) {
378 llvm::lsp::Logger::error("Failed to create memory buffer for file",
379 uri.file());
380 return;
381 }
382
383 ParserConfig config(&context, /*verifyAfterParse=*/true,
384 &fallbackResourceMap);
385 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
386 if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
387 // If parsing failed, clear out any of the current state.
388 parsedIR.clear();
389 asmState = AsmParserState();
390 fallbackResourceMap = FallbackAsmResourceMap();
391 return;
392 }
393}
394
395//===----------------------------------------------------------------------===//
396// MLIRDocument: Definitions and References
397//===----------------------------------------------------------------------===//
398
399void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
400 const lsp::Position &defPos,
401 std::vector<lsp::Location> &locations) {
402 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
403
404 // Functor used to check if an SM definition contains the position.
405 auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
406 if (!isDefOrUse(def, posLoc))
407 return false;
408 locations.emplace_back(uri, sourceMgr, def.loc);
409 return true;
410 };
411
412 // Check all definitions related to operations.
413 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
414 if (contains(op.loc, posLoc))
415 return collectLocationsFromLoc(op.op->getLoc(), locations, uri,
416 workspaceRoot);
417 for (const auto &result : op.resultGroups)
418 if (containsPosition(result.definition))
419 return collectLocationsFromLoc(op.op->getLoc(), locations, uri,
420 workspaceRoot);
421 for (const auto &symUse : op.symbolUses) {
422 if (contains(symUse, posLoc)) {
423 locations.emplace_back(uri, sourceMgr, op.loc);
424 return collectLocationsFromLoc(op.op->getLoc(), locations, uri,
425 workspaceRoot);
426 }
427 }
428 }
429
430 // Check all definitions related to blocks.
431 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
432 if (containsPosition(block.definition))
433 return;
434 for (const AsmParserState::SMDefinition &arg : block.arguments)
435 if (containsPosition(arg))
436 return;
437 }
438
439 // Check all alias definitions.
440 for (const AsmParserState::AttributeAliasDefinition &attr :
441 asmState.getAttributeAliasDefs()) {
442 if (containsPosition(attr.definition))
443 return;
444 }
445 for (const AsmParserState::TypeAliasDefinition &type :
446 asmState.getTypeAliasDefs()) {
447 if (containsPosition(type.definition))
448 return;
449 }
450}
451
452void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
453 const lsp::Position &pos,
454 std::vector<lsp::Location> &references) {
455 // Functor used to append all of the definitions/uses of the given SM
456 // definition to the reference list.
457 auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
458 references.emplace_back(uri, sourceMgr, def.loc);
459 for (const SMRange &use : def.uses)
460 references.emplace_back(uri, sourceMgr, use);
461 };
462
463 SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
464
465 // Check all definitions related to operations.
466 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
467 if (contains(op.loc, posLoc)) {
468 for (const auto &result : op.resultGroups)
469 appendSMDef(result.definition);
470 for (const auto &symUse : op.symbolUses)
471 if (contains(symUse, posLoc))
472 references.emplace_back(uri, sourceMgr, symUse);
473 return;
474 }
475 for (const auto &result : op.resultGroups)
476 if (isDefOrUse(result.definition, posLoc))
477 return appendSMDef(result.definition);
478 for (const auto &symUse : op.symbolUses) {
479 if (!contains(symUse, posLoc))
480 continue;
481 for (const auto &symUse : op.symbolUses)
482 references.emplace_back(uri, sourceMgr, symUse);
483 return;
484 }
485 }
486
487 // Check all definitions related to blocks.
488 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
489 if (isDefOrUse(block.definition, posLoc))
490 return appendSMDef(block.definition);
491
492 for (const AsmParserState::SMDefinition &arg : block.arguments)
493 if (isDefOrUse(arg, posLoc))
494 return appendSMDef(arg);
495 }
496
497 // Check all alias definitions.
498 for (const AsmParserState::AttributeAliasDefinition &attr :
499 asmState.getAttributeAliasDefs()) {
500 if (isDefOrUse(attr.definition, posLoc))
501 return appendSMDef(attr.definition);
502 }
503 for (const AsmParserState::TypeAliasDefinition &type :
504 asmState.getTypeAliasDefs()) {
505 if (isDefOrUse(type.definition, posLoc))
506 return appendSMDef(type.definition);
507 }
508}
509
510//===----------------------------------------------------------------------===//
511// MLIRDocument: Hover
512//===----------------------------------------------------------------------===//
513
514std::optional<lsp::Hover>
515MLIRDocument::findHover(const lsp::URIForFile &uri,
516 const lsp::Position &hoverPos) {
517 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
518 SMRange hoverRange;
519
520 // Check for Hovers on operations and results.
521 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
522 // Check if the position points at this operation.
523 if (contains(op.loc, posLoc))
524 return buildHoverForOperation(op.loc, op);
525
526 // Check if the position points at the symbol name.
527 for (auto &use : op.symbolUses)
528 if (contains(use, posLoc))
529 return buildHoverForOperation(use, op);
530
531 // Check if the position points at a result group.
532 for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
533 const auto &result = op.resultGroups[i];
534 if (!isDefOrUse(result.definition, posLoc, &hoverRange))
535 continue;
536
537 // Get the range of results covered by the over position.
538 unsigned resultStart = result.startIndex;
539 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
540 : op.resultGroups[i + 1].startIndex;
541 return buildHoverForOperationResult(hoverRange, op.op, resultStart,
542 resultEnd, posLoc);
543 }
544 }
545
546 // Check to see if the hover is over a block argument.
547 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
548 if (isDefOrUse(block.definition, posLoc, &hoverRange))
549 return buildHoverForBlock(hoverRange, block);
550
551 for (const auto &arg : llvm::enumerate(block.arguments)) {
552 if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
553 continue;
554
555 return buildHoverForBlockArgument(
556 hoverRange, block.block->getArgument(arg.index()), block);
557 }
558 }
559
560 // Check to see if the hover is over an alias.
561 for (const AsmParserState::AttributeAliasDefinition &attr :
562 asmState.getAttributeAliasDefs()) {
563 if (isDefOrUse(attr.definition, posLoc, &hoverRange))
564 return buildHoverForAttributeAlias(hoverRange, attr);
565 }
566 for (const AsmParserState::TypeAliasDefinition &type :
567 asmState.getTypeAliasDefs()) {
568 if (isDefOrUse(type.definition, posLoc, &hoverRange))
569 return buildHoverForTypeAlias(hoverRange, type);
570 }
571
572 return std::nullopt;
573}
574
575std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
576 SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
577 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
578 llvm::raw_string_ostream os(hover.contents.value);
579
580 // Add the operation name to the hover.
581 os << "\"" << op.op->getName() << "\"";
582 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
583 os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
584 os << "\n\n";
585
586 os << "Generic Form:\n\n```mlir\n";
587
588 op.op->print(os, OpPrintingFlags()
589 .printGenericOpForm()
590 .elideLargeElementsAttrs()
591 .skipRegions());
592 os << "\n```\n";
593
594 return hover;
595}
596
597lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
598 Operation *op,
599 unsigned resultStart,
600 unsigned resultEnd,
601 SMLoc posLoc) {
602 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
603 llvm::raw_string_ostream os(hover.contents.value);
604
605 // Add the parent operation name to the hover.
606 os << "Operation: \"" << op->getName() << "\"\n\n";
607
608 // Check to see if the location points to a specific result within the
609 // group.
610 if (std::optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
611 if ((resultStart + *resultNumber) < resultEnd) {
612 resultStart += *resultNumber;
613 resultEnd = resultStart + 1;
614 }
615 }
616
617 // Add the range of results and their types to the hover info.
618 if ((resultStart + 1) == resultEnd) {
619 os << "Result #" << resultStart << "\n\n"
620 << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
621 } else {
622 os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
623 << "Types: ";
624 llvm::interleaveComma(
625 op->getResults().slice(resultStart, resultEnd), os,
626 [&](Value result) { os << "`" << result.getType() << "`"; });
627 }
628
629 return hover;
630}
631
632lsp::Hover
633MLIRDocument::buildHoverForBlock(SMRange hoverRange,
634 const AsmParserState::BlockDefinition &block) {
635 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
636 llvm::raw_string_ostream os(hover.contents.value);
637
638 // Print the given block to the hover output stream.
639 auto printBlockToHover = [&](Block *newBlock) {
640 if (const auto *def = asmState.getBlockDef(newBlock))
641 printDefBlockName(os, *def);
642 else
643 printDefBlockName(os, newBlock);
644 };
645
646 // Display the parent operation, block number, predecessors, and successors.
647 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
648 << "Block #" << block.block->computeBlockNumber() << "\n\n";
649 if (!block.block->hasNoPredecessors()) {
650 os << "Predecessors: ";
651 llvm::interleaveComma(block.block->getPredecessors(), os,
652 printBlockToHover);
653 os << "\n\n";
654 }
655 if (!block.block->hasNoSuccessors()) {
656 os << "Successors: ";
657 llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
658 os << "\n\n";
659 }
660
661 return hover;
662}
663
664lsp::Hover MLIRDocument::buildHoverForBlockArgument(
665 SMRange hoverRange, BlockArgument arg,
666 const AsmParserState::BlockDefinition &block) {
667 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
668 llvm::raw_string_ostream os(hover.contents.value);
669
670 // Display the parent operation, block, the argument number, and the type.
671 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
672 << "Block: ";
673 printDefBlockName(os, block);
674 os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
675 << "Type: `" << arg.getType() << "`\n\n";
676
677 return hover;
678}
679
680lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
681 SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr) {
682 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
683 llvm::raw_string_ostream os(hover.contents.value);
684
685 os << "Attribute Alias: \"" << attr.name << "\n\n";
686 os << "Value: ```mlir\n" << attr.value << "\n```\n\n";
687
688 return hover;
689}
690
691lsp::Hover MLIRDocument::buildHoverForTypeAlias(
692 SMRange hoverRange, const AsmParserState::TypeAliasDefinition &type) {
693 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
694 llvm::raw_string_ostream os(hover.contents.value);
695
696 os << "Type Alias: \"" << type.name << "\n\n";
697 os << "Value: ```mlir\n" << type.value << "\n```\n\n";
698
699 return hover;
700}
701
702//===----------------------------------------------------------------------===//
703// MLIRDocument: Document Symbols
704//===----------------------------------------------------------------------===//
705
706void MLIRDocument::findDocumentSymbols(
707 std::vector<lsp::DocumentSymbol> &symbols) {
708 for (Operation &op : parsedIR)
709 findDocumentSymbols(&op, symbols);
710}
711
712void MLIRDocument::findDocumentSymbols(
713 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
714 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
715
716 // Check for the source information of this operation.
717 if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
718 // If this operation defines a symbol, record it.
719 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
720 symbols.emplace_back(symbol.getName(),
721 isa<FunctionOpInterface>(op)
722 ? llvm::lsp::SymbolKind::Function
723 : llvm::lsp::SymbolKind::Class,
724 lsp::Range(sourceMgr, def->scopeLoc),
725 lsp::Range(sourceMgr, def->loc));
726 childSymbols = &symbols.back().children;
727
728 } else if (op->hasTrait<OpTrait::SymbolTable>()) {
729 // Otherwise, if this is a symbol table push an anonymous document symbol.
730 symbols.emplace_back("<" + op->getName().getStringRef() + ">",
731 llvm::lsp::SymbolKind::Namespace,
732 llvm::lsp::Range(sourceMgr, def->scopeLoc),
733 llvm::lsp::Range(sourceMgr, def->loc));
734 childSymbols = &symbols.back().children;
735 }
736 }
737
738 // Recurse into the regions of this operation.
739 if (!op->getNumRegions())
740 return;
741 for (Region &region : op->getRegions())
742 for (Operation &childOp : region.getOps())
743 findDocumentSymbols(&childOp, *childSymbols);
744}
745
746//===----------------------------------------------------------------------===//
747// MLIRDocument: Code Completion
748//===----------------------------------------------------------------------===//
749
750namespace {
751class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
752public:
753 LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
754 MLIRContext *ctx)
755 : AsmParserCodeCompleteContext(completeLoc),
756 completionList(completionList), ctx(ctx) {}
757
758 /// Signal code completion for a dialect name, with an optional prefix.
759 void completeDialectName(StringRef prefix) final {
760 for (StringRef dialect : ctx->getAvailableDialects()) {
761 llvm::lsp::CompletionItem item(prefix + dialect,
762 llvm::lsp::CompletionItemKind::Module,
763 /*sortText=*/"3");
764 item.detail = "dialect";
765 completionList.items.emplace_back(item);
766 }
767 }
769
770 /// Signal code completion for an operation name within the given dialect.
771 void completeOperationName(StringRef dialectName) final {
772 Dialect *dialect = ctx->getOrLoadDialect(dialectName);
773 if (!dialect)
774 return;
775
776 for (const auto &op : ctx->getRegisteredOperations()) {
777 if (&op.getDialect() != dialect)
778 continue;
779
780 llvm::lsp::CompletionItem item(
781 op.getStringRef().drop_front(dialectName.size() + 1),
782 llvm::lsp::CompletionItemKind::Field,
783 /*sortText=*/"1");
784 item.detail = "operation";
785 completionList.items.emplace_back(item);
786 }
787 }
788
789 /// Append the given SSA value as a code completion result for SSA value
790 /// completions.
791 void appendSSAValueCompletion(StringRef name, std::string typeData) final {
792 // Check if we need to insert the `%` or not.
793 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';
794
795 llvm::lsp::CompletionItem item(name,
796 llvm::lsp::CompletionItemKind::Variable);
797 if (stripPrefix)
798 item.insertText = name.drop_front(1).str();
799 item.detail = std::move(typeData);
800 completionList.items.emplace_back(item);
801 }
802
803 /// Append the given block as a code completion result for block name
804 /// completions.
805 void appendBlockCompletion(StringRef name) final {
806 // Check if we need to insert the `^` or not.
807 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';
808
809 llvm::lsp::CompletionItem item(name, llvm::lsp::CompletionItemKind::Field);
810 if (stripPrefix)
811 item.insertText = name.drop_front(1).str();
812 completionList.items.emplace_back(item);
813 }
814
815 /// Signal a completion for the given expected token.
816 void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
817 for (StringRef token : tokens) {
818 llvm::lsp::CompletionItem item(token,
819 llvm::lsp::CompletionItemKind::Keyword,
820 /*sortText=*/"0");
821 item.detail = optional ? "optional" : "";
822 completionList.items.emplace_back(item);
823 }
824 }
825
826 /// Signal a completion for an attribute.
827 void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
828 appendSimpleCompletions({"affine_set", "affine_map", "dense",
829 "dense_resource", "false", "loc", "sparse", "true",
830 "unit"},
831 llvm::lsp::CompletionItemKind::Field,
832 /*sortText=*/"1");
833
834 completeDialectName("#");
835 completeAliases(aliases, "#");
836 }
837 void completeDialectAttributeOrAlias(
838 const llvm::StringMap<Attribute> &aliases) override {
839 completeDialectName();
840 completeAliases(aliases);
841 }
842
843 /// Signal a completion for a type.
844 void completeType(const llvm::StringMap<Type> &aliases) override {
845 // Handle the various builtin types.
846 appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
847 "bf16", "f16", "f32", "f64", "f80", "f128",
848 "index", "none"},
849 llvm::lsp::CompletionItemKind::Field,
850 /*sortText=*/"1");
851
852 // Handle the builtin integer types.
853 for (StringRef type : {"i", "si", "ui"}) {
854 llvm::lsp::CompletionItem item(type + "<N>",
855 llvm::lsp::CompletionItemKind::Field,
856 /*sortText=*/"1");
857 item.insertText = type.str();
858 completionList.items.emplace_back(item);
859 }
860
861 // Insert completions for dialect types and aliases.
862 completeDialectName("!");
863 completeAliases(aliases, "!");
864 }
865 void
866 completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
867 completeDialectName();
868 completeAliases(aliases);
869 }
870
871 /// Add completion results for the given set of aliases.
872 template <typename T>
873 void completeAliases(const llvm::StringMap<T> &aliases,
874 StringRef prefix = "") {
875 for (const auto &alias : aliases) {
876 llvm::lsp::CompletionItem item(prefix + alias.getKey(),
877 llvm::lsp::CompletionItemKind::Field,
878 /*sortText=*/"2");
879 llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
880 completionList.items.emplace_back(item);
881 }
882 }
883
884 /// Add a set of simple completions that all have the same kind.
885 void appendSimpleCompletions(ArrayRef<StringRef> completions,
886 llvm::lsp::CompletionItemKind kind,
887 StringRef sortText = "") {
888 for (StringRef completion : completions)
889 completionList.items.emplace_back(completion, kind, sortText);
890 }
891
892private:
893 lsp::CompletionList &completionList;
894 MLIRContext *ctx;
895};
896} // namespace
897
898lsp::CompletionList
899MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
900 const lsp::Position &completePos,
901 const DialectRegistry &registry) {
902 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
903 if (!posLoc.isValid())
904 return lsp::CompletionList();
905
906 // To perform code completion, we run another parse of the module with the
907 // code completion context provided.
908 MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
909 tmpContext.allowUnregisteredDialects();
910 lsp::CompletionList completionList;
911 LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
912 &tmpContext);
913
914 Block tmpIR;
915 AsmParserState tmpState;
916 (void)parseAsmSourceFile(sourceMgr, &tmpIR, &tmpContext, &tmpState,
917 &lspCompleteContext);
918 return completionList;
919}
920
921//===----------------------------------------------------------------------===//
922// MLIRDocument: Code Action
923//===----------------------------------------------------------------------===//
924
925void MLIRDocument::getCodeActionForDiagnostic(
926 const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
927 StringRef message, std::vector<llvm::lsp::TextEdit> &edits) {
928 // Ignore diagnostics that print the current operation. These are always
929 // enabled for the language server, but not generally during normal
930 // parsing/verification.
931 if (message.starts_with("see current operation: "))
932 return;
933
934 // Get the start of the line containing the diagnostic.
935 const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
936 const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
937 if (!lineStart)
938 return;
939 StringRef line(lineStart, pos.character);
940
941 // Add a text edit for adding an expected-* diagnostic check for this
942 // diagnostic.
943 llvm::lsp::TextEdit edit;
944 edit.range = lsp::Range(lsp::Position(pos.line, 0));
945
946 // Use the indent of the current line for the expected-* diagnostic.
947 size_t indent = line.find_first_not_of(' ');
948 if (indent == StringRef::npos)
949 indent = line.size();
950
951 edit.newText.append(indent, ' ');
952 llvm::raw_string_ostream(edit.newText)
953 << "// expected-" << severity << " @below {{" << message << "}}\n";
954 edits.emplace_back(std::move(edit));
955}
956
957//===----------------------------------------------------------------------===//
958// MLIRDocument: Bytecode
959//===----------------------------------------------------------------------===//
960
961llvm::Expected<lsp::MLIRConvertBytecodeResult>
962MLIRDocument::convertToBytecode() {
963 // TODO: We currently require a single top-level operation, but this could
964 // conceptually be relaxed.
965 if (!llvm::hasSingleElement(parsedIR)) {
966 if (parsedIR.empty()) {
967 return llvm::make_error<llvm::lsp::LSPError>(
968 "expected a single and valid top-level operation, please ensure "
969 "there are no errors",
970 llvm::lsp::ErrorCode::RequestFailed);
971 }
972 return llvm::make_error<llvm::lsp::LSPError>(
973 "expected a single top-level operation",
974 llvm::lsp::ErrorCode::RequestFailed);
975 }
976
977 lsp::MLIRConvertBytecodeResult result;
978 {
979 BytecodeWriterConfig writerConfig(fallbackResourceMap);
980
981 std::string rawBytecodeBuffer;
982 llvm::raw_string_ostream os(rawBytecodeBuffer);
983 // No desired bytecode version set, so no need to check for error.
984 (void)writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
985 result.output = llvm::encodeBase64(rawBytecodeBuffer);
986 }
987 return result;
988}
989
990//===----------------------------------------------------------------------===//
991// MLIRTextFileChunk
992//===----------------------------------------------------------------------===//
993
994namespace {
995/// This class represents a single chunk of an MLIR text file.
996struct MLIRTextFileChunk {
997 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
998 const lsp::URIForFile &uri, StringRef contents,
999 StringRef workspaceRoot,
1000 std::vector<lsp::Diagnostic> &diagnostics)
1001 : lineOffset(lineOffset),
1002 document(context, uri, contents, workspaceRoot, diagnostics) {}
1003
1004 /// Adjust the line number of the given range to anchor at the beginning of
1005 /// the file, instead of the beginning of this chunk.
1006 void adjustLocForChunkOffset(lsp::Range &range) {
1007 adjustLocForChunkOffset(range.start);
1008 adjustLocForChunkOffset(range.end);
1009 }
1010 /// Adjust the line number of the given position to anchor at the beginning of
1011 /// the file, instead of the beginning of this chunk.
1012 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
1013
1014 /// The line offset of this chunk from the beginning of the file.
1015 uint64_t lineOffset;
1016 /// The document referred to by this chunk.
1017 MLIRDocument document;
1018};
1019} // namespace
1020
1021//===----------------------------------------------------------------------===//
1022// MLIRTextFile
1023//===----------------------------------------------------------------------===//
1024
1025namespace {
1026/// This class represents a text file containing one or more MLIR documents.
1027class MLIRTextFile {
1028public:
1029 MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1030 int64_t version, lsp::DialectRegistryFn registryFn,
1031 StringRef workspaceRoot,
1032 std::vector<lsp::Diagnostic> &diagnostics);
1033
1034 /// Return the current version of this text file.
1035 int64_t getVersion() const { return version; }
1036
1037 //===--------------------------------------------------------------------===//
1038 // LSP Queries
1039 //===--------------------------------------------------------------------===//
1040
1041 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1042 std::vector<lsp::Location> &locations);
1043 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1044 std::vector<lsp::Location> &references);
1045 std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1046 lsp::Position hoverPos);
1047 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1048 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1049 lsp::Position completePos);
1050 void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos,
1051 const lsp::CodeActionContext &context,
1052 std::vector<lsp::CodeAction> &actions);
1053 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
1054
1055private:
1056 /// Find the MLIR document that contains the given position, and update the
1057 /// position to be anchored at the start of the found chunk instead of the
1058 /// beginning of the file.
1059 MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
1060
1061 /// The context used to hold the state contained by the parsed document.
1062 MLIRContext context;
1063
1064 /// The full string contents of the file.
1065 std::string contents;
1066
1067 /// The version of this file.
1068 int64_t version;
1069
1070 /// The number of lines in the file.
1071 int64_t totalNumLines = 0;
1072
1073 /// The chunks of this file. The order of these chunks is the order in which
1074 /// they appear in the text file.
1075 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
1076};
1077} // namespace
1078
1079MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1080 int64_t version, lsp::DialectRegistryFn registryFn,
1081 StringRef workspaceRoot,
1082 std::vector<lsp::Diagnostic> &diagnostics)
1083 : context(registryFn(uri), MLIRContext::Threading::DISABLED),
1084 contents(fileContents.str()), version(version) {
1085 context.allowUnregisteredDialects();
1086
1087 // Split the file into separate MLIR documents.
1088 SmallVector<StringRef, 8> subContents;
1089 StringRef(contents).split(subContents, kDefaultSplitMarker);
1090 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
1091 context, /*lineOffset=*/0, uri, subContents.front(), workspaceRoot,
1092 diagnostics));
1093
1094 uint64_t lineOffset = subContents.front().count('\n');
1095 for (StringRef docContents : llvm::drop_begin(subContents)) {
1096 unsigned currentNumDiags = diagnostics.size();
1097 auto chunk = std::make_unique<MLIRTextFileChunk>(
1098 context, lineOffset, uri, docContents, workspaceRoot, diagnostics);
1099 lineOffset += docContents.count('\n');
1100
1101 // Adjust locations used in diagnostics to account for the offset from the
1102 // beginning of the file.
1103 for (lsp::Diagnostic &diag :
1104 llvm::drop_begin(diagnostics, currentNumDiags)) {
1105 chunk->adjustLocForChunkOffset(diag.range);
1106
1107 if (!diag.relatedInformation)
1108 continue;
1109 for (auto &it : *diag.relatedInformation)
1110 if (it.location.uri == uri)
1111 chunk->adjustLocForChunkOffset(it.location.range);
1112 }
1113 chunks.emplace_back(std::move(chunk));
1114 }
1115 totalNumLines = lineOffset;
1116}
1117
1118void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
1119 lsp::Position defPos,
1120 std::vector<lsp::Location> &locations) {
1121 MLIRTextFileChunk &chunk = getChunkFor(defPos);
1122 chunk.document.getLocationsOf(uri, defPos, locations);
1123
1124 // Adjust any locations within this file for the offset of this chunk.
1125 if (chunk.lineOffset == 0)
1126 return;
1127 for (lsp::Location &loc : locations)
1128 if (loc.uri == uri)
1129 chunk.adjustLocForChunkOffset(loc.range);
1130}
1131
1132void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
1133 lsp::Position pos,
1134 std::vector<lsp::Location> &references) {
1135 MLIRTextFileChunk &chunk = getChunkFor(pos);
1136 chunk.document.findReferencesOf(uri, pos, references);
1137
1138 // Adjust any locations within this file for the offset of this chunk.
1139 if (chunk.lineOffset == 0)
1140 return;
1141 for (lsp::Location &loc : references)
1142 if (loc.uri == uri)
1143 chunk.adjustLocForChunkOffset(loc.range);
1144}
1145
1146std::optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
1147 lsp::Position hoverPos) {
1148 MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
1149 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1150
1151 // Adjust any locations within this file for the offset of this chunk.
1152 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1153 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1154 return hoverInfo;
1155}
1156
1157void MLIRTextFile::findDocumentSymbols(
1158 std::vector<lsp::DocumentSymbol> &symbols) {
1159 if (chunks.size() == 1)
1160 return chunks.front()->document.findDocumentSymbols(symbols);
1161
1162 // If there are multiple chunks in this file, we create top-level symbols for
1163 // each chunk.
1164 for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1165 MLIRTextFileChunk &chunk = *chunks[i];
1166 lsp::Position startPos(chunk.lineOffset);
1167 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1168 : chunks[i + 1]->lineOffset);
1169 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1170 llvm::lsp::SymbolKind::Namespace,
1171 /*range=*/lsp::Range(startPos, endPos),
1172 /*selectionRange=*/lsp::Range(startPos));
1173 chunk.document.findDocumentSymbols(symbol.children);
1174
1175 // Fixup the locations of document symbols within this chunk.
1176 if (i != 0) {
1177 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1178 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1179 symbolsToFix.push_back(&childSymbol);
1180
1181 while (!symbolsToFix.empty()) {
1182 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1183 chunk.adjustLocForChunkOffset(symbol->range);
1184 chunk.adjustLocForChunkOffset(symbol->selectionRange);
1185
1186 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1187 symbolsToFix.push_back(&childSymbol);
1188 }
1189 }
1190
1191 // Push the symbol for this chunk.
1192 symbols.emplace_back(std::move(symbol));
1193 }
1194}
1195
1196lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1197 lsp::Position completePos) {
1198 MLIRTextFileChunk &chunk = getChunkFor(completePos);
1199 lsp::CompletionList completionList = chunk.document.getCodeCompletion(
1200 uri, completePos, context.getDialectRegistry());
1201
1202 // Adjust any completion locations.
1203 for (llvm::lsp::CompletionItem &item : completionList.items) {
1204 if (item.textEdit)
1205 chunk.adjustLocForChunkOffset(item.textEdit->range);
1206 for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
1207 chunk.adjustLocForChunkOffset(edit.range);
1208 }
1209 return completionList;
1210}
1211
1212void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
1213 const lsp::Range &pos,
1214 const lsp::CodeActionContext &context,
1215 std::vector<lsp::CodeAction> &actions) {
1216 // Create actions for any diagnostics in this file.
1217 for (auto &diag : context.diagnostics) {
1218 if (diag.source != "mlir")
1219 continue;
1220 lsp::Position diagPos = diag.range.start;
1221 MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1222
1223 // Add a new code action that inserts a "expected" diagnostic check.
1224 lsp::CodeAction action;
1225 action.title = "Add expected-* diagnostic checks";
1226 action.kind = lsp::CodeAction::kQuickFix.str();
1227
1228 StringRef severity;
1229 switch (diag.severity) {
1230 case llvm::lsp::DiagnosticSeverity::Error:
1231 severity = "error";
1232 break;
1233 case llvm::lsp::DiagnosticSeverity::Warning:
1234 severity = "warning";
1235 break;
1236 default:
1237 continue;
1238 }
1239
1240 // Get edits for the diagnostic.
1241 std::vector<llvm::lsp::TextEdit> edits;
1242 chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1243 diag.message, edits);
1244
1245 // Walk the related diagnostics, this is how we encode notes.
1246 if (diag.relatedInformation) {
1247 for (auto &noteDiag : *diag.relatedInformation) {
1248 if (noteDiag.location.uri != uri)
1249 continue;
1250 diagPos = noteDiag.location.range.start;
1251 diagPos.line -= chunk.lineOffset;
1252 chunk.document.getCodeActionForDiagnostic(uri, diagPos, "note",
1253 noteDiag.message, edits);
1254 }
1255 }
1256 // Fixup the locations for any edits.
1257 for (llvm::lsp::TextEdit &edit : edits)
1258 chunk.adjustLocForChunkOffset(edit.range);
1259
1260 action.edit.emplace();
1261 action.edit->changes[uri.uri().str()] = std::move(edits);
1262 action.diagnostics = {diag};
1263
1264 actions.emplace_back(std::move(action));
1265 }
1266}
1267
1268llvm::Expected<lsp::MLIRConvertBytecodeResult>
1269MLIRTextFile::convertToBytecode() {
1270 // Bail out if there is more than one chunk, bytecode wants a single module.
1271 if (chunks.size() != 1) {
1272 return llvm::make_error<llvm::lsp::LSPError>(
1273 "unexpected split file, please remove all `// -----`",
1274 llvm::lsp::ErrorCode::RequestFailed);
1275 }
1276 return chunks.front()->document.convertToBytecode();
1277}
1278
1279MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1280 if (chunks.size() == 1)
1281 return *chunks.front();
1282
1283 // Search for the first chunk with a greater line offset, the previous chunk
1284 // is the one that contains `pos`.
1285 auto it = llvm::upper_bound(
1286 chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1287 return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1288 });
1289 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1290 pos.line -= chunk.lineOffset;
1291 return chunk;
1292}
1293
1294//===----------------------------------------------------------------------===//
1295// MLIRServer::Impl
1296//===----------------------------------------------------------------------===//
1297
1300
1301 /// The registry factory for containing dialects that can be recognized in
1302 /// parsed .mlir files.
1304
1305 /// The files held by the server, mapped by their URI file name.
1306 llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
1307
1308 /// The workspace root of the server.
1309 std::string workspaceRoot;
1310};
1311
1312//===----------------------------------------------------------------------===//
1313// MLIRServer
1314//===----------------------------------------------------------------------===//
1315
1317 : impl(std::make_unique<Impl>(registryFn)) {}
1319
1321 const URIForFile &uri, StringRef contents, int64_t version,
1322 std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1323 impl->files[uri.file()] =
1324 std::make_unique<MLIRTextFile>(uri, contents, version, impl->registryFn,
1325 impl->workspaceRoot, diagnostics);
1326}
1327
1328std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
1329 auto it = impl->files.find(uri.file());
1330 if (it == impl->files.end())
1331 return std::nullopt;
1332
1333 int64_t version = it->second->getVersion();
1334 impl->files.erase(it);
1335 return version;
1336}
1337
1339 const URIForFile &uri, const Position &defPos,
1340 std::vector<llvm::lsp::Location> &locations) {
1341 auto fileIt = impl->files.find(uri.file());
1342 if (fileIt != impl->files.end())
1343 fileIt->second->getLocationsOf(uri, defPos, locations);
1344}
1345
1347 const URIForFile &uri, const Position &pos,
1348 std::vector<llvm::lsp::Location> &references) {
1349 auto fileIt = impl->files.find(uri.file());
1350 if (fileIt != impl->files.end())
1351 fileIt->second->findReferencesOf(uri, pos, references);
1352}
1353
1354std::optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
1355 const Position &hoverPos) {
1356 auto fileIt = impl->files.find(uri.file());
1357 if (fileIt != impl->files.end())
1358 return fileIt->second->findHover(uri, hoverPos);
1359 return std::nullopt;
1360}
1361
1363 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1364 auto fileIt = impl->files.find(uri.file());
1365 if (fileIt != impl->files.end())
1366 fileIt->second->findDocumentSymbols(symbols);
1367}
1368
1369lsp::CompletionList
1371 const Position &completePos) {
1372 auto fileIt = impl->files.find(uri.file());
1373 if (fileIt != impl->files.end())
1374 return fileIt->second->getCodeCompletion(uri, completePos);
1375 return CompletionList();
1376}
1377
1378void lsp::MLIRServer::getCodeActions(const URIForFile &uri, const Range &pos,
1379 const CodeActionContext &context,
1380 std::vector<CodeAction> &actions) {
1381 auto fileIt = impl->files.find(uri.file());
1382 if (fileIt != impl->files.end())
1383 fileIt->second->getCodeActions(uri, pos, context, actions);
1384}
1385
1388 MLIRContext tempContext(impl->registryFn(uri));
1389 tempContext.allowUnregisteredDialects();
1390
1391 // Collect any errors during parsing.
1392 std::string errorMsg;
1393 ScopedDiagnosticHandler diagHandler(
1394 &tempContext,
1395 [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
1396
1397 // Handling for external resources, which we want to propagate up to the user.
1398 FallbackAsmResourceMap fallbackResourceMap;
1399
1400 // Setup the parser config.
1401 ParserConfig parserConfig(&tempContext, /*verifyAfterParse=*/true,
1402 &fallbackResourceMap);
1403
1404 // Try to parse the given source file.
1405 Block parsedBlock;
1406 if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
1407 return llvm::make_error<llvm::lsp::LSPError>(
1408 "failed to parse bytecode source file: " + errorMsg,
1409 llvm::lsp::ErrorCode::RequestFailed);
1410 }
1411
1412 // TODO: We currently expect a single top-level operation, but this could
1413 // conceptually be relaxed.
1414 if (!llvm::hasSingleElement(parsedBlock)) {
1415 return llvm::make_error<llvm::lsp::LSPError>(
1416 "expected bytecode to contain a single top-level operation",
1417 llvm::lsp::ErrorCode::RequestFailed);
1418 }
1419
1420 // Print the module to a buffer.
1422 {
1423 // Extract the top-level op so that aliases get printed.
1424 // FIXME: We should be able to enable aliases without having to do this!
1425 OwningOpRef<Operation *> topOp = &parsedBlock.front();
1426 topOp->remove();
1427
1428 AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
1429 /*locationMap=*/nullptr, &fallbackResourceMap);
1430
1431 llvm::raw_string_ostream os(result.output);
1432 topOp->print(os, state);
1433 }
1434 return std::move(result);
1435}
1436
1439 auto fileIt = impl->files.find(uri.file());
1440 if (fileIt == impl->files.end()) {
1441 return llvm::make_error<llvm::lsp::LSPError>(
1442 "language server does not contain an entry for this source file",
1443 llvm::lsp::ErrorCode::RequestFailed);
1444 }
1445 return fileIt->second->convertToBytecode();
1446}
1447
1449 impl->workspaceRoot = root.str();
1450}
static std::optional< unsigned > getResultNumberFromLoc(SMLoc loc)
Given a location pointing to a result, return the result number it refers to or std::nullopt if it re...
static std::optional< StringRef > getTextFromRange(SMRange range)
Given a source location range, return the text covered by the given range.
static std::optional< lsp::Location > getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc, StringRef workspaceRoot)
Returns a language server location from the given MLIR file location.
static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, SMRange *overlappedRange=nullptr)
Returns true if the given location is contained by the definition or one of the uses of the given SMD...
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
static void collectLocationsFromLoc(Location loc, std::vector< lsp::Location > &locations, const lsp::URIForFile &uri, StringRef workspaceRoot)
Collect all of the locations from the given MLIR location that are not contained within the given URI...
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, Diagnostic &diag, const lsp::URIForFile &uri, StringRef workspaceRoot)
Convert the given MLIR diagnostic to the LSP form.
static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc={})
Given a block and source location, print the source name of the block to the given output stream.
static SMRange convertTokenLocToRange(SMLoc loc)
Returns the range of a lexical token given a SMLoc corresponding to the start of an token location.
static std::string diag(const llvm::Value &value)
This class represents state from a parsed MLIR textual format string.
iterator_range< AttributeDefIterator > getAttributeAliasDefs() const
Return a range of the AttributeAliasDefinitions held by the current parser state.
iterator_range< BlockDefIterator > getBlockDefs() const
Return a range of the BlockDefinitions held by the current parser state.
const OperationDefinition * getOpDef(Operation *op) const
Return the definition for the given operation, or nullptr if the given operation does not have a defi...
const BlockDefinition * getBlockDef(Block *block) const
Return the definition for the given block, or nullptr if the given block does not have a definition.
iterator_range< OperationDefIterator > getOpDefs() const
Return a range of the OperationDefinitions held by the current parser state.
iterator_range< TypeDefIterator > getTypeAliasDefs() const
Return a range of the TypeAliasDefinitions held by the current parser state.
This class provides management for the lifetime of the state used when printing the IR.
Definition AsmState.h:542
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
bool hasNoSuccessors()
Returns true if this blocks has no successors.
Definition Block.h:258
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:250
Operation & front()
Definition Block.h:163
SuccessorRange getSuccessors()
Definition Block.h:280
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition Block.h:255
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
unsigned computeBlockNumber()
Compute the position of this block within its parent region using an O(N) linear scan.
Definition Block.cpp:144
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A fallback map containing external resources not explicitly handled by another parser/printer.
Definition AsmState.h:421
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
unsigned getLine() const
Definition Location.cpp:173
StringAttr getFilename() const
Definition Location.cpp:169
unsigned getColumn() const
Definition Location.cpp:175
WalkResult walk(function_ref< WalkResult(Location)> walkFn)
Walk all of the locations nested directly under, and including, the current.
Definition Location.cpp:124
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
std::vector< StringRef > getAvailableDialects()
Return information about all available dialects in the registry in this context.
void allowUnregisteredDialects(bool allow=true)
Enables creating operations in unregistered dialects.
Set of flags used to control the behavior of the various IR print methods (e.g.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:238
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
void print(raw_ostream &os, const OpPrintingFlags &flags={})
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
result_range getResults()
Definition Operation.h:441
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition OwningOpRef.h:29
This class represents a configuration for the MLIR assembly parser.
Definition AsmState.h:469
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
void addOrUpdateDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add or update the document, with the provided version, at the given URI.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
llvm::Expected< MLIRConvertBytecodeResult > convertFromBytecode(const URIForFile &uri)
Convert the given bytecode file to the textual format.
llvm::Expected< MLIRConvertBytecodeResult > convertToBytecode(const URIForFile &uri)
Convert the given textual file to the bytecode format.
void setWorkspaceRoot(StringRef root)
Set the workspace root for the server.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
MLIRServer(DialectRegistryFn registry_fn)
Construct a new server with the given dialect registry function.
void getCodeActions(const URIForFile &uri, const Range &pos, const CodeActionContext &context, std::vector< CodeAction > &actions)
Get the set of code actions within the file.
llvm::function_ref< DialectRegistry &(const llvm::lsp::URIForFile &uri)> DialectRegistryFn
SMRange convertTokenLocToRange(SMLoc loc, StringRef identifierChars="")
Returns the range of a lexical token given a SMLoc corresponding to the start of an token location.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
const char *const kDefaultSplitMarker
LogicalResult parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, AsmParserState *asmState=nullptr, AsmParserCodeCompleteContext *codeCompleteContext=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition Parser.cpp:2961
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition Parser.cpp:38
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config={})
Write the bytecode for the given operation to the provided output stream.
This class represents the result of converting between MLIR's bytecode and textual format.
Definition Protocol.h:48
std::string workspaceRoot
The workspace root of the server.
lsp::DialectRegistryFn registryFn
The registry factory for containing dialects that can be recognized in parsed .mlir files.
llvm::StringMap< std::unique_ptr< MLIRTextFile > > files
The files held by the server, mapped by their URI file name.
Impl(lsp::DialectRegistryFn registryFn)
StringRef name
The name of the attribute alias.
Attribute value
The value of the alias.
This class represents the information for a block definition within the input file.
Block * block
The block representing this definition.
SMDefinition definition
The source location for the block, i.e.
Operation * op
The operation representing this definition.
This class represents a definition within the source manager, containing it's defining location and l...
SmallVector< SMRange > uses
The source location of all uses of the definition.
SMRange loc
The source location of the definition.
StringRef name
The name of the attribute alias.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...