MLIR 22.0.0git
MLIRServer.cpp
Go to the documentation of this file.
1//===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "MLIRServer.h"
10#include "Protocol.h"
15#include "mlir/IR/Operation.h"
17#include "mlir/Parser/Parser.h"
20#include "llvm/ADT/StringExtras.h"
21#include "llvm/Support/Base64.h"
22#include "llvm/Support/LSP/Logging.h"
23#include "llvm/Support/SourceMgr.h"
24#include <optional>
25
26using namespace mlir;
27
28/// Returns the range of a lexical token given a SMLoc corresponding to the
29/// start of an token location. The range is computed heuristically, and
30/// supports identifier-like tokens, strings, etc.
31static SMRange convertTokenLocToRange(SMLoc loc) {
32 return lsp::convertTokenLocToRange(loc, "$-.");
33}
34
35/// Returns a language server location from the given MLIR file location.
36/// `uriScheme` is the scheme to use when building new uris.
37static std::optional<lsp::Location> getLocationFromLoc(StringRef uriScheme,
38 FileLineColLoc loc) {
40 lsp::URIForFile::fromFile(loc.getFilename(), uriScheme);
41 if (!sourceURI) {
42 llvm::lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
43 loc.getFilename(),
44 llvm::toString(sourceURI.takeError()));
45 return std::nullopt;
46 }
47
48 lsp::Position position;
49 position.line = loc.getLine() - 1;
50 position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
51 return lsp::Location{*sourceURI, lsp::Range(position)};
52}
53
54/// Returns a language server location from the given MLIR location, or
55/// std::nullopt if one couldn't be created. `uriScheme` is the scheme to use
56/// when building new uris. `uri` is an optional additional filter that, when
57/// present, is used to filter sub locations that do not share the same uri.
58static std::optional<lsp::Location>
59getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
60 StringRef uriScheme, const lsp::URIForFile *uri = nullptr) {
61 std::optional<lsp::Location> location;
62 loc->walk([&](Location nestedLoc) {
63 FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
64 if (!fileLoc)
65 return WalkResult::advance();
66
67 std::optional<lsp::Location> sourceLoc =
68 getLocationFromLoc(uriScheme, fileLoc);
69 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
70 location = *sourceLoc;
71 SMLoc loc = sourceMgr.FindLocForLineAndColumn(
72 sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
73
74 // Use range of potential identifier starting at location, else length 1
75 // range.
76 location->range.end.character += 1;
77 if (std::optional<SMRange> range = convertTokenLocToRange(loc)) {
78 auto lineCol = sourceMgr.getLineAndColumn(range->End);
79 location->range.end.character =
80 std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
81 }
82 return WalkResult::interrupt();
83 }
84 return WalkResult::advance();
85 });
86 return location;
87}
88
89/// Collect all of the locations from the given MLIR location that are not
90/// contained within the given URI.
92 std::vector<lsp::Location> &locations,
93 const lsp::URIForFile &uri) {
94 SetVector<Location> visitedLocs;
95 loc->walk([&](Location nestedLoc) {
96 FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
97 if (!fileLoc || !visitedLocs.insert(nestedLoc))
98 return WalkResult::advance();
99
100 std::optional<lsp::Location> sourceLoc =
101 getLocationFromLoc(uri.scheme(), fileLoc);
102 if (sourceLoc && sourceLoc->uri != uri)
103 locations.push_back(*sourceLoc);
104 return WalkResult::advance();
105 });
106}
107
108/// Returns true if the given range contains the given source location. Note
109/// that this has slightly different behavior than SMRange because it is
110/// inclusive of the end location.
111static bool contains(SMRange range, SMLoc loc) {
112 return range.Start.getPointer() <= loc.getPointer() &&
113 loc.getPointer() <= range.End.getPointer();
114}
115
116/// Returns true if the given location is contained by the definition or one of
117/// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
118/// the range within `def` that the provided `loc` overlapped with.
119static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
120 SMRange *overlappedRange = nullptr) {
121 // Check the main definition.
122 if (contains(def.loc, loc)) {
123 if (overlappedRange)
124 *overlappedRange = def.loc;
125 return true;
126 }
127
128 // Check the uses.
129 const auto *useIt = llvm::find_if(
130 def.uses, [&](const SMRange &range) { return contains(range, loc); });
131 if (useIt != def.uses.end()) {
132 if (overlappedRange)
133 *overlappedRange = *useIt;
134 return true;
135 }
136 return false;
137}
138
139/// Given a location pointing to a result, return the result number it refers
140/// to or std::nullopt if it refers to all of the results.
141static std::optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
142 // Skip all of the identifier characters.
143 auto isIdentifierChar = [](char c) {
144 return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
145 c == '-';
146 };
147 const char *curPtr = loc.getPointer();
148 while (isIdentifierChar(*curPtr))
149 ++curPtr;
150
151 // Check to see if this location indexes into the result group, via `#`. If it
152 // doesn't, we can't extract a sub result number.
153 if (*curPtr != '#')
154 return std::nullopt;
155
156 // Compute the sub result number from the remaining portion of the string.
157 const char *numberStart = ++curPtr;
158 while (llvm::isDigit(*curPtr))
159 ++curPtr;
160 StringRef numberStr(numberStart, curPtr - numberStart);
161 unsigned resultNumber = 0;
162 return numberStr.consumeInteger(10, resultNumber) ? std::optional<unsigned>()
163 : resultNumber;
164}
165
166/// Given a source location range, return the text covered by the given range.
167/// If the range is invalid, returns std::nullopt.
168static std::optional<StringRef> getTextFromRange(SMRange range) {
169 if (!range.isValid())
170 return std::nullopt;
171 const char *startPtr = range.Start.getPointer();
172 return StringRef(startPtr, range.End.getPointer() - startPtr);
173}
174
175/// Given a block and source location, print the source name of the block to the
176/// given output stream.
177static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
178 // Try to extract a name from the source location.
179 std::optional<StringRef> text = getTextFromRange(loc);
180 if (text && text->starts_with("^")) {
181 os << *text;
182 return;
183 }
184
185 // Otherwise, we don't have a name so print the block number.
186 os << "<Block #" << block->computeBlockNumber() << ">";
187}
191}
192
193/// Convert the given MLIR diagnostic to the LSP form.
194static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
196 const lsp::URIForFile &uri) {
197 lsp::Diagnostic lspDiag;
198 lspDiag.source = "mlir";
199
200 // Note: Right now all of the diagnostics are treated as parser issues, but
201 // some are parser and some are verifier.
202 lspDiag.category = "Parse Error";
203
204 // Try to grab a file location for this diagnostic.
205 // TODO: For simplicity, we just grab the first one. It may be likely that we
206 // will need a more interesting heuristic here.'
207 StringRef uriScheme = uri.scheme();
208 std::optional<lsp::Location> lspLocation =
209 getLocationFromLoc(sourceMgr, diag.getLocation(), uriScheme, &uri);
210 if (lspLocation)
211 lspDiag.range = lspLocation->range;
212
213 // Convert the severity for the diagnostic.
214 switch (diag.getSeverity()) {
216 llvm_unreachable("expected notes to be handled separately");
218 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
219 break;
221 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
222 break;
224 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
225 break;
226 }
227 lspDiag.message = diag.str();
228
229 // Attach any notes to the main diagnostic as related information.
230 std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
231 for (Diagnostic &note : diag.getNotes()) {
232 lsp::Location noteLoc;
233 if (std::optional<lsp::Location> loc =
234 getLocationFromLoc(sourceMgr, note.getLocation(), uriScheme))
235 noteLoc = *loc;
236 else
237 noteLoc.uri = uri;
238 relatedDiags.emplace_back(noteLoc, note.str());
239 }
240 if (!relatedDiags.empty())
241 lspDiag.relatedInformation = std::move(relatedDiags);
242
243 return lspDiag;
244}
245
246//===----------------------------------------------------------------------===//
247// MLIRDocument
248//===----------------------------------------------------------------------===//
249
250namespace {
251/// This class represents all of the information pertaining to a specific MLIR
252/// document.
253struct MLIRDocument {
254 MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
255 StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
256 MLIRDocument(const MLIRDocument &) = delete;
257 MLIRDocument &operator=(const MLIRDocument &) = delete;
258
259 //===--------------------------------------------------------------------===//
260 // Definitions and References
261 //===--------------------------------------------------------------------===//
262
263 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
264 std::vector<lsp::Location> &locations);
265 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
266 std::vector<lsp::Location> &references);
267
268 //===--------------------------------------------------------------------===//
269 // Hover
270 //===--------------------------------------------------------------------===//
271
272 std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
273 const lsp::Position &hoverPos);
274 std::optional<lsp::Hover>
275 buildHoverForOperation(SMRange hoverRange,
276 const AsmParserState::OperationDefinition &op);
277 lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
278 unsigned resultStart,
279 unsigned resultEnd, SMLoc posLoc);
280 lsp::Hover buildHoverForBlock(SMRange hoverRange,
281 const AsmParserState::BlockDefinition &block);
282 lsp::Hover
283 buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
284 const AsmParserState::BlockDefinition &block);
285
286 lsp::Hover buildHoverForAttributeAlias(
287 SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr);
288 lsp::Hover
289 buildHoverForTypeAlias(SMRange hoverRange,
290 const AsmParserState::TypeAliasDefinition &type);
291
292 //===--------------------------------------------------------------------===//
293 // Document Symbols
294 //===--------------------------------------------------------------------===//
295
296 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
297 void findDocumentSymbols(Operation *op,
298 std::vector<lsp::DocumentSymbol> &symbols);
299
300 //===--------------------------------------------------------------------===//
301 // Code Completion
302 //===--------------------------------------------------------------------===//
303
304 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
305 const lsp::Position &completePos,
306 const DialectRegistry &registry);
307
308 //===--------------------------------------------------------------------===//
309 // Code Action
310 //===--------------------------------------------------------------------===//
311
312 void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
313 lsp::Position &pos, StringRef severity,
314 StringRef message,
315 std::vector<llvm::lsp::TextEdit> &edits);
316
317 //===--------------------------------------------------------------------===//
318 // Bytecode
319 //===--------------------------------------------------------------------===//
320
321 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
322
323 //===--------------------------------------------------------------------===//
324 // Fields
325 //===--------------------------------------------------------------------===//
326
327 /// The high level parser state used to find definitions and references within
328 /// the source file.
329 AsmParserState asmState;
330
331 /// The container for the IR parsed from the input file.
332 Block parsedIR;
333
334 /// A collection of external resources, which we want to propagate up to the
335 /// user.
336 FallbackAsmResourceMap fallbackResourceMap;
337
338 /// The source manager containing the contents of the input file.
339 llvm::SourceMgr sourceMgr;
340};
341} // namespace
342
343MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
344 StringRef contents,
345 std::vector<lsp::Diagnostic> &diagnostics) {
346 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
347 diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
348 });
349
350 // Try to parsed the given IR string.
351 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
352 if (!memBuffer) {
353 llvm::lsp::Logger::error("Failed to create memory buffer for file",
354 uri.file());
355 return;
356 }
357
358 ParserConfig config(&context, /*verifyAfterParse=*/true,
359 &fallbackResourceMap);
360 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
361 if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
362 // If parsing failed, clear out any of the current state.
363 parsedIR.clear();
364 asmState = AsmParserState();
365 fallbackResourceMap = FallbackAsmResourceMap();
366 return;
367 }
368}
369
370//===----------------------------------------------------------------------===//
371// MLIRDocument: Definitions and References
372//===----------------------------------------------------------------------===//
373
374void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
375 const lsp::Position &defPos,
376 std::vector<lsp::Location> &locations) {
377 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
378
379 // Functor used to check if an SM definition contains the position.
380 auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
381 if (!isDefOrUse(def, posLoc))
382 return false;
383 locations.emplace_back(uri, sourceMgr, def.loc);
384 return true;
385 };
386
387 // Check all definitions related to operations.
388 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
389 if (contains(op.loc, posLoc))
390 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
391 for (const auto &result : op.resultGroups)
392 if (containsPosition(result.definition))
393 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
394 for (const auto &symUse : op.symbolUses) {
395 if (contains(symUse, posLoc)) {
396 locations.emplace_back(uri, sourceMgr, op.loc);
397 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
398 }
399 }
400 }
401
402 // Check all definitions related to blocks.
403 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
404 if (containsPosition(block.definition))
405 return;
406 for (const AsmParserState::SMDefinition &arg : block.arguments)
407 if (containsPosition(arg))
408 return;
409 }
410
411 // Check all alias definitions.
412 for (const AsmParserState::AttributeAliasDefinition &attr :
413 asmState.getAttributeAliasDefs()) {
414 if (containsPosition(attr.definition))
415 return;
416 }
417 for (const AsmParserState::TypeAliasDefinition &type :
418 asmState.getTypeAliasDefs()) {
419 if (containsPosition(type.definition))
420 return;
421 }
422}
423
424void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
425 const lsp::Position &pos,
426 std::vector<lsp::Location> &references) {
427 // Functor used to append all of the definitions/uses of the given SM
428 // definition to the reference list.
429 auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
430 references.emplace_back(uri, sourceMgr, def.loc);
431 for (const SMRange &use : def.uses)
432 references.emplace_back(uri, sourceMgr, use);
433 };
434
435 SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
436
437 // Check all definitions related to operations.
438 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
439 if (contains(op.loc, posLoc)) {
440 for (const auto &result : op.resultGroups)
441 appendSMDef(result.definition);
442 for (const auto &symUse : op.symbolUses)
443 if (contains(symUse, posLoc))
444 references.emplace_back(uri, sourceMgr, symUse);
445 return;
446 }
447 for (const auto &result : op.resultGroups)
448 if (isDefOrUse(result.definition, posLoc))
449 return appendSMDef(result.definition);
450 for (const auto &symUse : op.symbolUses) {
451 if (!contains(symUse, posLoc))
452 continue;
453 for (const auto &symUse : op.symbolUses)
454 references.emplace_back(uri, sourceMgr, symUse);
455 return;
456 }
457 }
458
459 // Check all definitions related to blocks.
460 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
461 if (isDefOrUse(block.definition, posLoc))
462 return appendSMDef(block.definition);
463
464 for (const AsmParserState::SMDefinition &arg : block.arguments)
465 if (isDefOrUse(arg, posLoc))
466 return appendSMDef(arg);
467 }
468
469 // Check all alias definitions.
470 for (const AsmParserState::AttributeAliasDefinition &attr :
471 asmState.getAttributeAliasDefs()) {
472 if (isDefOrUse(attr.definition, posLoc))
473 return appendSMDef(attr.definition);
474 }
475 for (const AsmParserState::TypeAliasDefinition &type :
476 asmState.getTypeAliasDefs()) {
477 if (isDefOrUse(type.definition, posLoc))
478 return appendSMDef(type.definition);
479 }
480}
481
482//===----------------------------------------------------------------------===//
483// MLIRDocument: Hover
484//===----------------------------------------------------------------------===//
485
486std::optional<lsp::Hover>
487MLIRDocument::findHover(const lsp::URIForFile &uri,
488 const lsp::Position &hoverPos) {
489 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
490 SMRange hoverRange;
491
492 // Check for Hovers on operations and results.
493 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
494 // Check if the position points at this operation.
495 if (contains(op.loc, posLoc))
496 return buildHoverForOperation(op.loc, op);
497
498 // Check if the position points at the symbol name.
499 for (auto &use : op.symbolUses)
500 if (contains(use, posLoc))
501 return buildHoverForOperation(use, op);
502
503 // Check if the position points at a result group.
504 for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
505 const auto &result = op.resultGroups[i];
506 if (!isDefOrUse(result.definition, posLoc, &hoverRange))
507 continue;
508
509 // Get the range of results covered by the over position.
510 unsigned resultStart = result.startIndex;
511 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
512 : op.resultGroups[i + 1].startIndex;
513 return buildHoverForOperationResult(hoverRange, op.op, resultStart,
514 resultEnd, posLoc);
515 }
516 }
517
518 // Check to see if the hover is over a block argument.
519 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
520 if (isDefOrUse(block.definition, posLoc, &hoverRange))
521 return buildHoverForBlock(hoverRange, block);
522
523 for (const auto &arg : llvm::enumerate(block.arguments)) {
524 if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
525 continue;
526
527 return buildHoverForBlockArgument(
528 hoverRange, block.block->getArgument(arg.index()), block);
529 }
530 }
531
532 // Check to see if the hover is over an alias.
533 for (const AsmParserState::AttributeAliasDefinition &attr :
534 asmState.getAttributeAliasDefs()) {
535 if (isDefOrUse(attr.definition, posLoc, &hoverRange))
536 return buildHoverForAttributeAlias(hoverRange, attr);
537 }
538 for (const AsmParserState::TypeAliasDefinition &type :
539 asmState.getTypeAliasDefs()) {
540 if (isDefOrUse(type.definition, posLoc, &hoverRange))
541 return buildHoverForTypeAlias(hoverRange, type);
542 }
543
544 return std::nullopt;
545}
546
547std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
548 SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
549 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
550 llvm::raw_string_ostream os(hover.contents.value);
551
552 // Add the operation name to the hover.
553 os << "\"" << op.op->getName() << "\"";
554 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
555 os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
556 os << "\n\n";
557
558 os << "Generic Form:\n\n```mlir\n";
559
560 op.op->print(os, OpPrintingFlags()
561 .printGenericOpForm()
562 .elideLargeElementsAttrs()
563 .skipRegions());
564 os << "\n```\n";
565
566 return hover;
567}
568
569lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
570 Operation *op,
571 unsigned resultStart,
572 unsigned resultEnd,
573 SMLoc posLoc) {
574 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
575 llvm::raw_string_ostream os(hover.contents.value);
576
577 // Add the parent operation name to the hover.
578 os << "Operation: \"" << op->getName() << "\"\n\n";
579
580 // Check to see if the location points to a specific result within the
581 // group.
582 if (std::optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
583 if ((resultStart + *resultNumber) < resultEnd) {
584 resultStart += *resultNumber;
585 resultEnd = resultStart + 1;
586 }
587 }
588
589 // Add the range of results and their types to the hover info.
590 if ((resultStart + 1) == resultEnd) {
591 os << "Result #" << resultStart << "\n\n"
592 << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
593 } else {
594 os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
595 << "Types: ";
596 llvm::interleaveComma(
597 op->getResults().slice(resultStart, resultEnd), os,
598 [&](Value result) { os << "`" << result.getType() << "`"; });
599 }
600
601 return hover;
602}
603
604lsp::Hover
605MLIRDocument::buildHoverForBlock(SMRange hoverRange,
606 const AsmParserState::BlockDefinition &block) {
607 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
608 llvm::raw_string_ostream os(hover.contents.value);
609
610 // Print the given block to the hover output stream.
611 auto printBlockToHover = [&](Block *newBlock) {
612 if (const auto *def = asmState.getBlockDef(newBlock))
613 printDefBlockName(os, *def);
614 else
615 printDefBlockName(os, newBlock);
616 };
617
618 // Display the parent operation, block number, predecessors, and successors.
619 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
620 << "Block #" << block.block->computeBlockNumber() << "\n\n";
621 if (!block.block->hasNoPredecessors()) {
622 os << "Predecessors: ";
623 llvm::interleaveComma(block.block->getPredecessors(), os,
624 printBlockToHover);
625 os << "\n\n";
626 }
627 if (!block.block->hasNoSuccessors()) {
628 os << "Successors: ";
629 llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
630 os << "\n\n";
631 }
632
633 return hover;
634}
635
636lsp::Hover MLIRDocument::buildHoverForBlockArgument(
637 SMRange hoverRange, BlockArgument arg,
638 const AsmParserState::BlockDefinition &block) {
639 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
640 llvm::raw_string_ostream os(hover.contents.value);
641
642 // Display the parent operation, block, the argument number, and the type.
643 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
644 << "Block: ";
645 printDefBlockName(os, block);
646 os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
647 << "Type: `" << arg.getType() << "`\n\n";
648
649 return hover;
650}
651
652lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
653 SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr) {
654 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
655 llvm::raw_string_ostream os(hover.contents.value);
656
657 os << "Attribute Alias: \"" << attr.name << "\n\n";
658 os << "Value: ```mlir\n" << attr.value << "\n```\n\n";
659
660 return hover;
661}
662
663lsp::Hover MLIRDocument::buildHoverForTypeAlias(
664 SMRange hoverRange, const AsmParserState::TypeAliasDefinition &type) {
665 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
666 llvm::raw_string_ostream os(hover.contents.value);
667
668 os << "Type Alias: \"" << type.name << "\n\n";
669 os << "Value: ```mlir\n" << type.value << "\n```\n\n";
670
671 return hover;
672}
673
674//===----------------------------------------------------------------------===//
675// MLIRDocument: Document Symbols
676//===----------------------------------------------------------------------===//
677
678void MLIRDocument::findDocumentSymbols(
679 std::vector<lsp::DocumentSymbol> &symbols) {
680 for (Operation &op : parsedIR)
681 findDocumentSymbols(&op, symbols);
682}
683
684void MLIRDocument::findDocumentSymbols(
685 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
686 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
687
688 // Check for the source information of this operation.
689 if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
690 // If this operation defines a symbol, record it.
691 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
692 symbols.emplace_back(symbol.getName(),
693 isa<FunctionOpInterface>(op)
694 ? llvm::lsp::SymbolKind::Function
695 : llvm::lsp::SymbolKind::Class,
696 lsp::Range(sourceMgr, def->scopeLoc),
697 lsp::Range(sourceMgr, def->loc));
698 childSymbols = &symbols.back().children;
699
700 } else if (op->hasTrait<OpTrait::SymbolTable>()) {
701 // Otherwise, if this is a symbol table push an anonymous document symbol.
702 symbols.emplace_back("<" + op->getName().getStringRef() + ">",
703 llvm::lsp::SymbolKind::Namespace,
704 llvm::lsp::Range(sourceMgr, def->scopeLoc),
705 llvm::lsp::Range(sourceMgr, def->loc));
706 childSymbols = &symbols.back().children;
707 }
708 }
709
710 // Recurse into the regions of this operation.
711 if (!op->getNumRegions())
712 return;
713 for (Region &region : op->getRegions())
714 for (Operation &childOp : region.getOps())
715 findDocumentSymbols(&childOp, *childSymbols);
716}
717
718//===----------------------------------------------------------------------===//
719// MLIRDocument: Code Completion
720//===----------------------------------------------------------------------===//
721
722namespace {
723class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
724public:
725 LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
726 MLIRContext *ctx)
727 : AsmParserCodeCompleteContext(completeLoc),
728 completionList(completionList), ctx(ctx) {}
729
730 /// Signal code completion for a dialect name, with an optional prefix.
731 void completeDialectName(StringRef prefix) final {
732 for (StringRef dialect : ctx->getAvailableDialects()) {
733 llvm::lsp::CompletionItem item(prefix + dialect,
734 llvm::lsp::CompletionItemKind::Module,
735 /*sortText=*/"3");
736 item.detail = "dialect";
737 completionList.items.emplace_back(item);
738 }
739 }
741
742 /// Signal code completion for an operation name within the given dialect.
743 void completeOperationName(StringRef dialectName) final {
744 Dialect *dialect = ctx->getOrLoadDialect(dialectName);
745 if (!dialect)
746 return;
747
748 for (const auto &op : ctx->getRegisteredOperations()) {
749 if (&op.getDialect() != dialect)
750 continue;
751
752 llvm::lsp::CompletionItem item(
753 op.getStringRef().drop_front(dialectName.size() + 1),
754 llvm::lsp::CompletionItemKind::Field,
755 /*sortText=*/"1");
756 item.detail = "operation";
757 completionList.items.emplace_back(item);
758 }
759 }
760
761 /// Append the given SSA value as a code completion result for SSA value
762 /// completions.
763 void appendSSAValueCompletion(StringRef name, std::string typeData) final {
764 // Check if we need to insert the `%` or not.
765 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';
766
767 llvm::lsp::CompletionItem item(name,
768 llvm::lsp::CompletionItemKind::Variable);
769 if (stripPrefix)
770 item.insertText = name.drop_front(1).str();
771 item.detail = std::move(typeData);
772 completionList.items.emplace_back(item);
773 }
774
775 /// Append the given block as a code completion result for block name
776 /// completions.
777 void appendBlockCompletion(StringRef name) final {
778 // Check if we need to insert the `^` or not.
779 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';
780
781 llvm::lsp::CompletionItem item(name, llvm::lsp::CompletionItemKind::Field);
782 if (stripPrefix)
783 item.insertText = name.drop_front(1).str();
784 completionList.items.emplace_back(item);
785 }
786
787 /// Signal a completion for the given expected token.
788 void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
789 for (StringRef token : tokens) {
790 llvm::lsp::CompletionItem item(token,
791 llvm::lsp::CompletionItemKind::Keyword,
792 /*sortText=*/"0");
793 item.detail = optional ? "optional" : "";
794 completionList.items.emplace_back(item);
795 }
796 }
797
798 /// Signal a completion for an attribute.
799 void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
800 appendSimpleCompletions({"affine_set", "affine_map", "dense",
801 "dense_resource", "false", "loc", "sparse", "true",
802 "unit"},
803 llvm::lsp::CompletionItemKind::Field,
804 /*sortText=*/"1");
805
806 completeDialectName("#");
807 completeAliases(aliases, "#");
808 }
809 void completeDialectAttributeOrAlias(
810 const llvm::StringMap<Attribute> &aliases) override {
811 completeDialectName();
812 completeAliases(aliases);
813 }
814
815 /// Signal a completion for a type.
816 void completeType(const llvm::StringMap<Type> &aliases) override {
817 // Handle the various builtin types.
818 appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
819 "bf16", "f16", "f32", "f64", "f80", "f128",
820 "index", "none"},
821 llvm::lsp::CompletionItemKind::Field,
822 /*sortText=*/"1");
823
824 // Handle the builtin integer types.
825 for (StringRef type : {"i", "si", "ui"}) {
826 llvm::lsp::CompletionItem item(type + "<N>",
827 llvm::lsp::CompletionItemKind::Field,
828 /*sortText=*/"1");
829 item.insertText = type.str();
830 completionList.items.emplace_back(item);
831 }
832
833 // Insert completions for dialect types and aliases.
834 completeDialectName("!");
835 completeAliases(aliases, "!");
836 }
837 void
838 completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
839 completeDialectName();
840 completeAliases(aliases);
841 }
842
843 /// Add completion results for the given set of aliases.
844 template <typename T>
845 void completeAliases(const llvm::StringMap<T> &aliases,
846 StringRef prefix = "") {
847 for (const auto &alias : aliases) {
848 llvm::lsp::CompletionItem item(prefix + alias.getKey(),
849 llvm::lsp::CompletionItemKind::Field,
850 /*sortText=*/"2");
851 llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
852 completionList.items.emplace_back(item);
853 }
854 }
855
856 /// Add a set of simple completions that all have the same kind.
857 void appendSimpleCompletions(ArrayRef<StringRef> completions,
858 llvm::lsp::CompletionItemKind kind,
859 StringRef sortText = "") {
860 for (StringRef completion : completions)
861 completionList.items.emplace_back(completion, kind, sortText);
862 }
863
864private:
865 lsp::CompletionList &completionList;
866 MLIRContext *ctx;
867};
868} // namespace
869
870lsp::CompletionList
871MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
872 const lsp::Position &completePos,
873 const DialectRegistry &registry) {
874 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
875 if (!posLoc.isValid())
876 return lsp::CompletionList();
877
878 // To perform code completion, we run another parse of the module with the
879 // code completion context provided.
880 MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
881 tmpContext.allowUnregisteredDialects();
882 lsp::CompletionList completionList;
883 LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
884 &tmpContext);
885
886 Block tmpIR;
887 AsmParserState tmpState;
888 (void)parseAsmSourceFile(sourceMgr, &tmpIR, &tmpContext, &tmpState,
889 &lspCompleteContext);
890 return completionList;
891}
892
893//===----------------------------------------------------------------------===//
894// MLIRDocument: Code Action
895//===----------------------------------------------------------------------===//
896
897void MLIRDocument::getCodeActionForDiagnostic(
898 const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
899 StringRef message, std::vector<llvm::lsp::TextEdit> &edits) {
900 // Ignore diagnostics that print the current operation. These are always
901 // enabled for the language server, but not generally during normal
902 // parsing/verification.
903 if (message.starts_with("see current operation: "))
904 return;
905
906 // Get the start of the line containing the diagnostic.
907 const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
908 const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
909 if (!lineStart)
910 return;
911 StringRef line(lineStart, pos.character);
912
913 // Add a text edit for adding an expected-* diagnostic check for this
914 // diagnostic.
915 llvm::lsp::TextEdit edit;
916 edit.range = lsp::Range(lsp::Position(pos.line, 0));
917
918 // Use the indent of the current line for the expected-* diagnostic.
919 size_t indent = line.find_first_not_of(' ');
920 if (indent == StringRef::npos)
921 indent = line.size();
922
923 edit.newText.append(indent, ' ');
924 llvm::raw_string_ostream(edit.newText)
925 << "// expected-" << severity << " @below {{" << message << "}}\n";
926 edits.emplace_back(std::move(edit));
927}
928
929//===----------------------------------------------------------------------===//
930// MLIRDocument: Bytecode
931//===----------------------------------------------------------------------===//
932
933llvm::Expected<lsp::MLIRConvertBytecodeResult>
934MLIRDocument::convertToBytecode() {
935 // TODO: We currently require a single top-level operation, but this could
936 // conceptually be relaxed.
937 if (!llvm::hasSingleElement(parsedIR)) {
938 if (parsedIR.empty()) {
939 return llvm::make_error<llvm::lsp::LSPError>(
940 "expected a single and valid top-level operation, please ensure "
941 "there are no errors",
942 llvm::lsp::ErrorCode::RequestFailed);
943 }
944 return llvm::make_error<llvm::lsp::LSPError>(
945 "expected a single top-level operation",
946 llvm::lsp::ErrorCode::RequestFailed);
947 }
948
949 lsp::MLIRConvertBytecodeResult result;
950 {
951 BytecodeWriterConfig writerConfig(fallbackResourceMap);
952
953 std::string rawBytecodeBuffer;
954 llvm::raw_string_ostream os(rawBytecodeBuffer);
955 // No desired bytecode version set, so no need to check for error.
956 (void)writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
957 result.output = llvm::encodeBase64(rawBytecodeBuffer);
958 }
959 return result;
960}
961
962//===----------------------------------------------------------------------===//
963// MLIRTextFileChunk
964//===----------------------------------------------------------------------===//
965
966namespace {
967/// This class represents a single chunk of an MLIR text file.
968struct MLIRTextFileChunk {
969 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
970 const lsp::URIForFile &uri, StringRef contents,
971 std::vector<lsp::Diagnostic> &diagnostics)
972 : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
973
974 /// Adjust the line number of the given range to anchor at the beginning of
975 /// the file, instead of the beginning of this chunk.
976 void adjustLocForChunkOffset(lsp::Range &range) {
977 adjustLocForChunkOffset(range.start);
978 adjustLocForChunkOffset(range.end);
979 }
980 /// Adjust the line number of the given position to anchor at the beginning of
981 /// the file, instead of the beginning of this chunk.
982 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
983
984 /// The line offset of this chunk from the beginning of the file.
985 uint64_t lineOffset;
986 /// The document referred to by this chunk.
987 MLIRDocument document;
988};
989} // namespace
990
991//===----------------------------------------------------------------------===//
992// MLIRTextFile
993//===----------------------------------------------------------------------===//
994
995namespace {
996/// This class represents a text file containing one or more MLIR documents.
997class MLIRTextFile {
998public:
999 MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1000 int64_t version, lsp::DialectRegistryFn registryFn,
1001 std::vector<lsp::Diagnostic> &diagnostics);
1002
1003 /// Return the current version of this text file.
1004 int64_t getVersion() const { return version; }
1005
1006 //===--------------------------------------------------------------------===//
1007 // LSP Queries
1008 //===--------------------------------------------------------------------===//
1009
1010 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1011 std::vector<lsp::Location> &locations);
1012 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1013 std::vector<lsp::Location> &references);
1014 std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1015 lsp::Position hoverPos);
1016 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1017 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1018 lsp::Position completePos);
1019 void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos,
1020 const lsp::CodeActionContext &context,
1021 std::vector<lsp::CodeAction> &actions);
1022 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
1023
1024private:
1025 /// Find the MLIR document that contains the given position, and update the
1026 /// position to be anchored at the start of the found chunk instead of the
1027 /// beginning of the file.
1028 MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
1029
1030 /// The context used to hold the state contained by the parsed document.
1031 MLIRContext context;
1032
1033 /// The full string contents of the file.
1034 std::string contents;
1035
1036 /// The version of this file.
1037 int64_t version;
1038
1039 /// The number of lines in the file.
1040 int64_t totalNumLines = 0;
1041
1042 /// The chunks of this file. The order of these chunks is the order in which
1043 /// they appear in the text file.
1044 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
1045};
1046} // namespace
1047
1048MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1049 int64_t version, lsp::DialectRegistryFn registryFn,
1050 std::vector<lsp::Diagnostic> &diagnostics)
1051 : context(registryFn(uri), MLIRContext::Threading::DISABLED),
1052 contents(fileContents.str()), version(version) {
1053 context.allowUnregisteredDialects();
1054
1055 // Split the file into separate MLIR documents.
1056 SmallVector<StringRef, 8> subContents;
1057 StringRef(contents).split(subContents, kDefaultSplitMarker);
1058 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
1059 context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
1060
1061 uint64_t lineOffset = subContents.front().count('\n');
1062 for (StringRef docContents : llvm::drop_begin(subContents)) {
1063 unsigned currentNumDiags = diagnostics.size();
1064 auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
1065 docContents, diagnostics);
1066 lineOffset += docContents.count('\n');
1067
1068 // Adjust locations used in diagnostics to account for the offset from the
1069 // beginning of the file.
1070 for (lsp::Diagnostic &diag :
1071 llvm::drop_begin(diagnostics, currentNumDiags)) {
1072 chunk->adjustLocForChunkOffset(diag.range);
1073
1074 if (!diag.relatedInformation)
1075 continue;
1076 for (auto &it : *diag.relatedInformation)
1077 if (it.location.uri == uri)
1078 chunk->adjustLocForChunkOffset(it.location.range);
1079 }
1080 chunks.emplace_back(std::move(chunk));
1081 }
1082 totalNumLines = lineOffset;
1083}
1084
1085void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
1086 lsp::Position defPos,
1087 std::vector<lsp::Location> &locations) {
1088 MLIRTextFileChunk &chunk = getChunkFor(defPos);
1089 chunk.document.getLocationsOf(uri, defPos, locations);
1090
1091 // Adjust any locations within this file for the offset of this chunk.
1092 if (chunk.lineOffset == 0)
1093 return;
1094 for (lsp::Location &loc : locations)
1095 if (loc.uri == uri)
1096 chunk.adjustLocForChunkOffset(loc.range);
1097}
1098
1099void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
1100 lsp::Position pos,
1101 std::vector<lsp::Location> &references) {
1102 MLIRTextFileChunk &chunk = getChunkFor(pos);
1103 chunk.document.findReferencesOf(uri, pos, references);
1104
1105 // Adjust any locations within this file for the offset of this chunk.
1106 if (chunk.lineOffset == 0)
1107 return;
1108 for (lsp::Location &loc : references)
1109 if (loc.uri == uri)
1110 chunk.adjustLocForChunkOffset(loc.range);
1111}
1112
1113std::optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
1114 lsp::Position hoverPos) {
1115 MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
1116 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1117
1118 // Adjust any locations within this file for the offset of this chunk.
1119 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1120 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1121 return hoverInfo;
1122}
1123
1124void MLIRTextFile::findDocumentSymbols(
1125 std::vector<lsp::DocumentSymbol> &symbols) {
1126 if (chunks.size() == 1)
1127 return chunks.front()->document.findDocumentSymbols(symbols);
1128
1129 // If there are multiple chunks in this file, we create top-level symbols for
1130 // each chunk.
1131 for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1132 MLIRTextFileChunk &chunk = *chunks[i];
1133 lsp::Position startPos(chunk.lineOffset);
1134 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1135 : chunks[i + 1]->lineOffset);
1136 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1137 llvm::lsp::SymbolKind::Namespace,
1138 /*range=*/lsp::Range(startPos, endPos),
1139 /*selectionRange=*/lsp::Range(startPos));
1140 chunk.document.findDocumentSymbols(symbol.children);
1141
1142 // Fixup the locations of document symbols within this chunk.
1143 if (i != 0) {
1144 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1145 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1146 symbolsToFix.push_back(&childSymbol);
1147
1148 while (!symbolsToFix.empty()) {
1149 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1150 chunk.adjustLocForChunkOffset(symbol->range);
1151 chunk.adjustLocForChunkOffset(symbol->selectionRange);
1152
1153 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1154 symbolsToFix.push_back(&childSymbol);
1155 }
1156 }
1157
1158 // Push the symbol for this chunk.
1159 symbols.emplace_back(std::move(symbol));
1160 }
1161}
1162
1163lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1164 lsp::Position completePos) {
1165 MLIRTextFileChunk &chunk = getChunkFor(completePos);
1166 lsp::CompletionList completionList = chunk.document.getCodeCompletion(
1167 uri, completePos, context.getDialectRegistry());
1168
1169 // Adjust any completion locations.
1170 for (llvm::lsp::CompletionItem &item : completionList.items) {
1171 if (item.textEdit)
1172 chunk.adjustLocForChunkOffset(item.textEdit->range);
1173 for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
1174 chunk.adjustLocForChunkOffset(edit.range);
1175 }
1176 return completionList;
1177}
1178
1179void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
1180 const lsp::Range &pos,
1181 const lsp::CodeActionContext &context,
1182 std::vector<lsp::CodeAction> &actions) {
1183 // Create actions for any diagnostics in this file.
1184 for (auto &diag : context.diagnostics) {
1185 if (diag.source != "mlir")
1186 continue;
1187 lsp::Position diagPos = diag.range.start;
1188 MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1189
1190 // Add a new code action that inserts a "expected" diagnostic check.
1191 lsp::CodeAction action;
1192 action.title = "Add expected-* diagnostic checks";
1193 action.kind = lsp::CodeAction::kQuickFix.str();
1194
1195 StringRef severity;
1196 switch (diag.severity) {
1197 case llvm::lsp::DiagnosticSeverity::Error:
1198 severity = "error";
1199 break;
1200 case llvm::lsp::DiagnosticSeverity::Warning:
1201 severity = "warning";
1202 break;
1203 default:
1204 continue;
1205 }
1206
1207 // Get edits for the diagnostic.
1208 std::vector<llvm::lsp::TextEdit> edits;
1209 chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1210 diag.message, edits);
1211
1212 // Walk the related diagnostics, this is how we encode notes.
1213 if (diag.relatedInformation) {
1214 for (auto &noteDiag : *diag.relatedInformation) {
1215 if (noteDiag.location.uri != uri)
1216 continue;
1217 diagPos = noteDiag.location.range.start;
1218 diagPos.line -= chunk.lineOffset;
1219 chunk.document.getCodeActionForDiagnostic(uri, diagPos, "note",
1220 noteDiag.message, edits);
1221 }
1222 }
1223 // Fixup the locations for any edits.
1224 for (llvm::lsp::TextEdit &edit : edits)
1225 chunk.adjustLocForChunkOffset(edit.range);
1226
1227 action.edit.emplace();
1228 action.edit->changes[uri.uri().str()] = std::move(edits);
1229 action.diagnostics = {diag};
1230
1231 actions.emplace_back(std::move(action));
1232 }
1233}
1234
1235llvm::Expected<lsp::MLIRConvertBytecodeResult>
1236MLIRTextFile::convertToBytecode() {
1237 // Bail out if there is more than one chunk, bytecode wants a single module.
1238 if (chunks.size() != 1) {
1239 return llvm::make_error<llvm::lsp::LSPError>(
1240 "unexpected split file, please remove all `// -----`",
1241 llvm::lsp::ErrorCode::RequestFailed);
1242 }
1243 return chunks.front()->document.convertToBytecode();
1244}
1245
1246MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1247 if (chunks.size() == 1)
1248 return *chunks.front();
1249
1250 // Search for the first chunk with a greater line offset, the previous chunk
1251 // is the one that contains `pos`.
1252 auto it = llvm::upper_bound(
1253 chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1254 return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1255 });
1256 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1257 pos.line -= chunk.lineOffset;
1258 return chunk;
1259}
1260
1261//===----------------------------------------------------------------------===//
1262// MLIRServer::Impl
1263//===----------------------------------------------------------------------===//
1264
1267
1268 /// The registry factory for containing dialects that can be recognized in
1269 /// parsed .mlir files.
1271
1272 /// The files held by the server, mapped by their URI file name.
1273 llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
1274};
1275
1276//===----------------------------------------------------------------------===//
1277// MLIRServer
1278//===----------------------------------------------------------------------===//
1279
1281 : impl(std::make_unique<Impl>(registryFn)) {}
1283
1285 const URIForFile &uri, StringRef contents, int64_t version,
1286 std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1287 impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1288 uri, contents, version, impl->registryFn, diagnostics);
1289}
1290
1291std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
1292 auto it = impl->files.find(uri.file());
1293 if (it == impl->files.end())
1294 return std::nullopt;
1295
1296 int64_t version = it->second->getVersion();
1297 impl->files.erase(it);
1298 return version;
1299}
1300
1302 const URIForFile &uri, const Position &defPos,
1303 std::vector<llvm::lsp::Location> &locations) {
1304 auto fileIt = impl->files.find(uri.file());
1305 if (fileIt != impl->files.end())
1306 fileIt->second->getLocationsOf(uri, defPos, locations);
1307}
1308
1310 const URIForFile &uri, const Position &pos,
1311 std::vector<llvm::lsp::Location> &references) {
1312 auto fileIt = impl->files.find(uri.file());
1313 if (fileIt != impl->files.end())
1314 fileIt->second->findReferencesOf(uri, pos, references);
1315}
1316
1317std::optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
1318 const Position &hoverPos) {
1319 auto fileIt = impl->files.find(uri.file());
1320 if (fileIt != impl->files.end())
1321 return fileIt->second->findHover(uri, hoverPos);
1322 return std::nullopt;
1323}
1324
1326 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1327 auto fileIt = impl->files.find(uri.file());
1328 if (fileIt != impl->files.end())
1329 fileIt->second->findDocumentSymbols(symbols);
1330}
1331
1332lsp::CompletionList
1334 const Position &completePos) {
1335 auto fileIt = impl->files.find(uri.file());
1336 if (fileIt != impl->files.end())
1337 return fileIt->second->getCodeCompletion(uri, completePos);
1338 return CompletionList();
1339}
1340
1341void lsp::MLIRServer::getCodeActions(const URIForFile &uri, const Range &pos,
1342 const CodeActionContext &context,
1343 std::vector<CodeAction> &actions) {
1344 auto fileIt = impl->files.find(uri.file());
1345 if (fileIt != impl->files.end())
1346 fileIt->second->getCodeActions(uri, pos, context, actions);
1347}
1348
1351 MLIRContext tempContext(impl->registryFn(uri));
1352 tempContext.allowUnregisteredDialects();
1353
1354 // Collect any errors during parsing.
1355 std::string errorMsg;
1356 ScopedDiagnosticHandler diagHandler(
1357 &tempContext,
1358 [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
1359
1360 // Handling for external resources, which we want to propagate up to the user.
1361 FallbackAsmResourceMap fallbackResourceMap;
1362
1363 // Setup the parser config.
1364 ParserConfig parserConfig(&tempContext, /*verifyAfterParse=*/true,
1365 &fallbackResourceMap);
1366
1367 // Try to parse the given source file.
1368 Block parsedBlock;
1369 if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
1370 return llvm::make_error<llvm::lsp::LSPError>(
1371 "failed to parse bytecode source file: " + errorMsg,
1372 llvm::lsp::ErrorCode::RequestFailed);
1373 }
1374
1375 // TODO: We currently expect a single top-level operation, but this could
1376 // conceptually be relaxed.
1377 if (!llvm::hasSingleElement(parsedBlock)) {
1378 return llvm::make_error<llvm::lsp::LSPError>(
1379 "expected bytecode to contain a single top-level operation",
1380 llvm::lsp::ErrorCode::RequestFailed);
1381 }
1382
1383 // Print the module to a buffer.
1385 {
1386 // Extract the top-level op so that aliases get printed.
1387 // FIXME: We should be able to enable aliases without having to do this!
1388 OwningOpRef<Operation *> topOp = &parsedBlock.front();
1389 topOp->remove();
1390
1391 AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
1392 /*locationMap=*/nullptr, &fallbackResourceMap);
1393
1394 llvm::raw_string_ostream os(result.output);
1395 topOp->print(os, state);
1396 }
1397 return std::move(result);
1398}
1399
1402 auto fileIt = impl->files.find(uri.file());
1403 if (fileIt == impl->files.end()) {
1404 return llvm::make_error<llvm::lsp::LSPError>(
1405 "language server does not contain an entry for this source file",
1406 llvm::lsp::ErrorCode::RequestFailed);
1407 }
1408 return fileIt->second->convertToBytecode();
1409}
static void collectLocationsFromLoc(Location loc, std::vector< lsp::Location > &locations, const lsp::URIForFile &uri)
Collect all of the locations from the given MLIR location that are not contained within the given URI...
static std::optional< unsigned > getResultNumberFromLoc(SMLoc loc)
Given a location pointing to a result, return the result number it refers to or std::nullopt if it re...
static std::optional< StringRef > getTextFromRange(SMRange range)
Given a source location range, return the text covered by the given range.
static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, SMRange *overlappedRange=nullptr)
Returns true if the given location is contained by the definition or one of the uses of the given SMD...
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, Diagnostic &diag, const lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
static std::optional< lsp::Location > getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc)
Returns a language server location from the given MLIR file location.
static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc={})
Given a block and source location, print the source name of the block to the given output stream.
static SMRange convertTokenLocToRange(SMLoc loc)
Returns the range of a lexical token given an SMLoc corresponding to the start of a token location.
static std::string diag(const llvm::Value &value)
iterator_range< AttributeDefIterator > getAttributeAliasDefs() const
Return a range of the AttributeAliasDefinitions held by the current parser state.
iterator_range< BlockDefIterator > getBlockDefs() const
Return a range of the BlockDefinitions held by the current parser state.
const OperationDefinition * getOpDef(Operation *op) const
Return the definition for the given operation, or nullptr if the given operation does not have a defi...
const BlockDefinition * getBlockDef(Block *block) const
Return the definition for the given block, or nullptr if the given block does not have a definition.
iterator_range< OperationDefIterator > getOpDefs() const
Return a range of the OperationDefinitions held by the current parser state.
iterator_range< TypeDefIterator > getTypeAliasDefs() const
Return a range of the TypeAliasDefinitions held by the current parser state.
This class provides management for the lifetime of the state used when printing the IR.
Definition AsmState.h:542
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
bool hasNoSuccessors()
Returns true if this block has no successors.
Definition Block.h:258
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:250
Operation & front()
Definition Block.h:163
SuccessorRange getSuccessors()
Definition Block.h:280
void clear()
Definition Block.h:38
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition Block.h:255
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
unsigned computeBlockNumber()
Compute the position of this block within its parent region using an O(N) linear scan.
Definition Block.cpp:144
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A fallback map containing external resources not explicitly handled by another parser/printer.
Definition AsmState.h:421
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
unsigned getLine() const
Definition Location.cpp:173
StringAttr getFilename() const
Definition Location.cpp:169
unsigned getColumn() const
Definition Location.cpp:175
WalkResult walk(function_ref< WalkResult(Location)> walkFn)
Walk all of the locations nested directly under, and including, the current.
Definition Location.cpp:124
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
std::vector< StringRef > getAvailableDialects()
Return information about all available dialects in the registry in this context.
void allowUnregisteredDialects(bool allow=true)
Enables creating operations in unregistered dialects.
Set of flags used to control the behavior of the various IR print methods (e.g.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
void print(raw_ostream &os, const OpPrintingFlags &flags={})
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_range getResults()
Definition Operation.h:415
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition OwningOpRef.h:29
This class represents a configuration for the MLIR assembly parser.
Definition AsmState.h:469
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
void addOrUpdateDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add or update the document, with the provided version, at the given URI.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
llvm::Expected< MLIRConvertBytecodeResult > convertFromBytecode(const URIForFile &uri)
Convert the given bytecode file to the textual format.
llvm::Expected< MLIRConvertBytecodeResult > convertToBytecode(const URIForFile &uri)
Convert the given textual file to the bytecode format.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
MLIRServer(DialectRegistryFn registry_fn)
Construct a new server with the given dialect registry function.
void getCodeActions(const URIForFile &uri, const Range &pos, const CodeActionContext &context, std::vector< CodeAction > &actions)
Get the set of code actions within the file.
llvm::function_ref< DialectRegistry &(const llvm::lsp::URIForFile &uri)> DialectRegistryFn
SMRange convertTokenLocToRange(SMLoc loc, StringRef identifierChars="")
Returns the range of a lexical token given an SMLoc corresponding to the start of a token location.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
const char *const kDefaultSplitMarker
LogicalResult parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, AsmParserState *asmState=nullptr, AsmParserCodeCompleteContext *codeCompleteContext=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition Parser.cpp:2918
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition Parser.cpp:38
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config={})
Write the bytecode for the given operation to the provided output stream.
This class represents the result of converting between MLIR's bytecode and textual format.
Definition Protocol.h:48
lsp::DialectRegistryFn registryFn
The registry factory for containing dialects that can be recognized in parsed .mlir files.
llvm::StringMap< std::unique_ptr< MLIRTextFile > > files
The files held by the server, mapped by their URI file name.
Impl(lsp::DialectRegistryFn registryFn)
StringRef name
The name of the attribute alias.
Attribute value
The value of the alias.
This class represents the information for a block definition within the input file.
Block * block
The block representing this definition.
SMDefinition definition
The source location for the block, i.e.
Operation * op
The operation representing this definition.
This class represents a definition within the source manager, containing its defining location and l...
SmallVector< SMRange > uses
The source location of all uses of the definition.
SMRange loc
The source location of the definition.
StringRef name
The name of the attribute alias.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...