20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/Support/Base64.h"
22 #include "llvm/Support/LSP/Logging.h"
23 #include "llvm/Support/SourceMgr.h"
40 lsp::URIForFile::fromFile(loc.
getFilename(), uriScheme);
42 llvm::lsp::Logger::error(
"Failed to create URI for file `{0}`: {1}",
48 lsp::Position position;
49 position.line = loc.
getLine() - 1;
51 return lsp::Location{*sourceURI, lsp::Range(position)};
58 static std::optional<lsp::Location>
60 StringRef uriScheme,
const lsp::URIForFile *uri =
nullptr) {
61 std::optional<lsp::Location> location;
67 std::optional<lsp::Location> sourceLoc =
69 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
70 location = *sourceLoc;
71 SMLoc loc = sourceMgr.FindLocForLineAndColumn(
76 location->range.end.character += 1;
78 auto lineCol = sourceMgr.getLineAndColumn(range->End);
79 location->range.end.character =
92 std::vector<lsp::Location> &locations,
93 const lsp::URIForFile &uri) {
97 if (!fileLoc || !visitedLocs.insert(nestedLoc))
100 std::optional<lsp::Location> sourceLoc =
102 if (sourceLoc && sourceLoc->uri != uri)
103 locations.push_back(*sourceLoc);
112 return range.Start.getPointer() <= loc.getPointer() &&
113 loc.getPointer() <= range.End.getPointer();
120 SMRange *overlappedRange =
nullptr) {
124 *overlappedRange = def.
loc;
129 const auto *useIt = llvm::find_if(
130 def.
uses, [&](
const SMRange &range) { return contains(range, loc); });
131 if (useIt != def.
uses.end()) {
133 *overlappedRange = *useIt;
143 auto isIdentifierChar = [](
char c) {
144 return isalnum(c) || c ==
'%' || c ==
'$' || c ==
'.' || c ==
'_' ||
147 const char *curPtr = loc.getPointer();
148 while (isIdentifierChar(*curPtr))
157 const char *numberStart = ++curPtr;
158 while (llvm::isDigit(*curPtr))
160 StringRef numberStr(numberStart, curPtr - numberStart);
161 unsigned resultNumber = 0;
162 return numberStr.consumeInteger(10, resultNumber) ? std::optional<unsigned>()
169 if (!range.isValid())
171 const char *startPtr = range.Start.getPointer();
172 return StringRef(startPtr, range.End.getPointer() - startPtr);
177 return std::distance(block->
getParent()->
begin(), block->getIterator());
185 if (text && text->starts_with(
"^")) {
201 const lsp::URIForFile &uri) {
202 lsp::Diagnostic lspDiag;
203 lspDiag.source =
"mlir";
207 lspDiag.category =
"Parse Error";
212 StringRef uriScheme = uri.scheme();
213 std::optional<lsp::Location> lspLocation =
216 lspDiag.range = lspLocation->range;
219 switch (
diag.getSeverity()) {
221 llvm_unreachable(
"expected notes to be handled separately");
223 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
229 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
232 lspDiag.message =
diag.str();
235 std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
237 lsp::Location noteLoc;
238 if (std::optional<lsp::Location> loc =
243 relatedDiags.emplace_back(noteLoc, note.str());
245 if (!relatedDiags.empty())
246 lspDiag.relatedInformation = std::move(relatedDiags);
258 struct MLIRDocument {
259 MLIRDocument(
MLIRContext &context,
const lsp::URIForFile &uri,
260 StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
261 MLIRDocument(
const MLIRDocument &) =
delete;
262 MLIRDocument &operator=(
const MLIRDocument &) =
delete;
268 void getLocationsOf(
const lsp::URIForFile &uri,
const lsp::Position &defPos,
269 std::vector<lsp::Location> &locations);
270 void findReferencesOf(
const lsp::URIForFile &uri,
const lsp::Position &pos,
271 std::vector<lsp::Location> &references);
277 std::optional<lsp::Hover> findHover(
const lsp::URIForFile &uri,
278 const lsp::Position &hoverPos);
279 std::optional<lsp::Hover>
280 buildHoverForOperation(SMRange hoverRange,
282 lsp::Hover buildHoverForOperationResult(SMRange hoverRange,
Operation *op,
283 unsigned resultStart,
284 unsigned resultEnd, SMLoc posLoc);
285 lsp::Hover buildHoverForBlock(SMRange hoverRange,
288 buildHoverForBlockArgument(SMRange hoverRange,
BlockArgument arg,
291 lsp::Hover buildHoverForAttributeAlias(
294 buildHoverForTypeAlias(SMRange hoverRange,
301 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
303 std::vector<lsp::DocumentSymbol> &symbols);
309 lsp::CompletionList getCodeCompletion(
const lsp::URIForFile &uri,
310 const lsp::Position &completePos,
317 void getCodeActionForDiagnostic(
const lsp::URIForFile &uri,
318 lsp::Position &pos, StringRef severity,
320 std::vector<llvm::lsp::TextEdit> &edits);
344 llvm::SourceMgr sourceMgr;
348 MLIRDocument::MLIRDocument(
MLIRContext &context,
const lsp::URIForFile &uri,
350 std::vector<lsp::Diagnostic> &diagnostics) {
356 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
358 llvm::lsp::Logger::error(
"Failed to create memory buffer for file",
364 &fallbackResourceMap);
365 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
379 void MLIRDocument::getLocationsOf(
const lsp::URIForFile &uri,
380 const lsp::Position &defPos,
381 std::vector<lsp::Location> &locations) {
382 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
388 locations.emplace_back(uri, sourceMgr, def.loc);
396 for (
const auto &result : op.resultGroups)
397 if (containsPosition(result.definition))
399 for (
const auto &symUse : op.symbolUses) {
401 locations.emplace_back(uri, sourceMgr, op.loc);
409 if (containsPosition(block.definition))
412 if (containsPosition(arg))
418 asmState.getAttributeAliasDefs()) {
419 if (containsPosition(attr.definition))
423 asmState.getTypeAliasDefs()) {
424 if (containsPosition(type.definition))
429 void MLIRDocument::findReferencesOf(
const lsp::URIForFile &uri,
430 const lsp::Position &pos,
431 std::vector<lsp::Location> &references) {
435 references.emplace_back(uri, sourceMgr, def.loc);
436 for (
const SMRange &use : def.uses)
437 references.emplace_back(uri, sourceMgr, use);
440 SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
445 for (
const auto &result : op.resultGroups)
446 appendSMDef(result.definition);
447 for (
const auto &symUse : op.symbolUses)
449 references.emplace_back(uri, sourceMgr, symUse);
452 for (
const auto &result : op.resultGroups)
454 return appendSMDef(result.definition);
455 for (
const auto &symUse : op.symbolUses) {
458 for (
const auto &symUse : op.symbolUses)
459 references.emplace_back(uri, sourceMgr, symUse);
467 return appendSMDef(block.definition);
471 return appendSMDef(arg);
476 asmState.getAttributeAliasDefs()) {
478 return appendSMDef(attr.definition);
481 asmState.getTypeAliasDefs()) {
483 return appendSMDef(type.definition);
491 std::optional<lsp::Hover>
492 MLIRDocument::findHover(
const lsp::URIForFile &uri,
493 const lsp::Position &hoverPos) {
494 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
501 return buildHoverForOperation(op.loc, op);
504 for (
auto &use : op.symbolUses)
506 return buildHoverForOperation(use, op);
509 for (
unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
510 const auto &result = op.resultGroups[i];
511 if (!
isDefOrUse(result.definition, posLoc, &hoverRange))
515 unsigned resultStart = result.startIndex;
516 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
517 : op.resultGroups[i + 1].startIndex;
518 return buildHoverForOperationResult(hoverRange, op.op, resultStart,
525 if (
isDefOrUse(block.definition, posLoc, &hoverRange))
526 return buildHoverForBlock(hoverRange, block);
529 if (!
isDefOrUse(arg.value(), posLoc, &hoverRange))
532 return buildHoverForBlockArgument(
533 hoverRange, block.block->
getArgument(arg.index()), block);
539 asmState.getAttributeAliasDefs()) {
540 if (
isDefOrUse(attr.definition, posLoc, &hoverRange))
541 return buildHoverForAttributeAlias(hoverRange, attr);
544 asmState.getTypeAliasDefs()) {
545 if (
isDefOrUse(type.definition, posLoc, &hoverRange))
546 return buildHoverForTypeAlias(hoverRange, type);
552 std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
554 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
555 llvm::raw_string_ostream os(hover.contents.value);
559 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.
op))
560 os <<
" : " << symbol.getVisibility() <<
" @" << symbol.getName() <<
"";
563 os <<
"Generic Form:\n\n```mlir\n";
566 .printGenericOpForm()
567 .elideLargeElementsAttrs()
574 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
576 unsigned resultStart,
579 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
580 llvm::raw_string_ostream os(hover.contents.value);
583 os <<
"Operation: \"" << op->
getName() <<
"\"\n\n";
588 if ((resultStart + *resultNumber) < resultEnd) {
589 resultStart += *resultNumber;
590 resultEnd = resultStart + 1;
595 if ((resultStart + 1) == resultEnd) {
596 os <<
"Result #" << resultStart <<
"\n\n"
599 os <<
"Result #[" << resultStart <<
", " << (resultEnd - 1) <<
"]\n\n"
601 llvm::interleaveComma(
602 op->
getResults().slice(resultStart, resultEnd), os,
603 [&](
Value result) { os <<
"`" << result.getType() <<
"`"; });
610 MLIRDocument::buildHoverForBlock(SMRange hoverRange,
612 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
613 llvm::raw_string_ostream os(hover.contents.value);
616 auto printBlockToHover = [&](
Block *newBlock) {
617 if (
const auto *def = asmState.getBlockDef(newBlock))
627 os <<
"Predecessors: ";
633 os <<
"Successors: ";
641 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
644 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
645 llvm::raw_string_ostream os(hover.contents.value);
652 <<
"Type: `" << arg.
getType() <<
"`\n\n";
657 lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
659 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
660 llvm::raw_string_ostream os(hover.contents.value);
662 os <<
"Attribute Alias: \"" << attr.
name <<
"\n\n";
663 os <<
"Value: ```mlir\n" << attr.
value <<
"\n```\n\n";
668 lsp::Hover MLIRDocument::buildHoverForTypeAlias(
670 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
671 llvm::raw_string_ostream os(hover.contents.value);
673 os <<
"Type Alias: \"" << type.
name <<
"\n\n";
674 os <<
"Value: ```mlir\n" << type.
value <<
"\n```\n\n";
683 void MLIRDocument::findDocumentSymbols(
684 std::vector<lsp::DocumentSymbol> &symbols) {
686 findDocumentSymbols(&op, symbols);
689 void MLIRDocument::findDocumentSymbols(
690 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
691 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
696 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
697 symbols.emplace_back(symbol.getName(),
698 isa<FunctionOpInterface>(op)
699 ? llvm::lsp::SymbolKind::Function
700 : llvm::lsp::SymbolKind::Class,
701 lsp::Range(sourceMgr, def->scopeLoc),
702 lsp::Range(sourceMgr, def->loc));
703 childSymbols = &symbols.back().children;
708 llvm::lsp::SymbolKind::Namespace,
709 llvm::lsp::Range(sourceMgr, def->scopeLoc),
710 llvm::lsp::Range(sourceMgr, def->loc));
711 childSymbols = &symbols.back().children;
719 for (
Operation &childOp : region.getOps())
720 findDocumentSymbols(&childOp, *childSymbols);
730 LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
733 completionList(completionList), ctx(ctx) {}
736 void completeDialectName(StringRef prefix)
final {
737 for (StringRef dialect : ctx->getAvailableDialects()) {
738 llvm::lsp::CompletionItem item(prefix + dialect,
739 llvm::lsp::CompletionItemKind::Module,
741 item.detail =
"dialect";
742 completionList.items.emplace_back(item);
748 void completeOperationName(StringRef dialectName)
final {
749 Dialect *dialect = ctx->getOrLoadDialect(dialectName);
753 for (
const auto &op : ctx->getRegisteredOperations()) {
757 llvm::lsp::CompletionItem item(
758 op.getStringRef().drop_front(dialectName.size() + 1),
759 llvm::lsp::CompletionItemKind::Field,
761 item.detail =
"operation";
762 completionList.items.emplace_back(item);
768 void appendSSAValueCompletion(StringRef name, std::string typeData)
final {
770 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] ==
'%';
772 llvm::lsp::CompletionItem item(name,
773 llvm::lsp::CompletionItemKind::Variable);
775 item.insertText = name.drop_front(1).str();
776 item.detail = std::move(typeData);
777 completionList.items.emplace_back(item);
782 void appendBlockCompletion(StringRef name)
final {
784 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] ==
'^';
786 llvm::lsp::CompletionItem item(name, llvm::lsp::CompletionItemKind::Field);
788 item.insertText = name.drop_front(1).str();
789 completionList.items.emplace_back(item);
794 for (StringRef token : tokens) {
795 llvm::lsp::CompletionItem item(token,
796 llvm::lsp::CompletionItemKind::Keyword,
798 item.detail = optional ?
"optional" :
"";
799 completionList.items.emplace_back(item);
804 void completeAttribute(
const llvm::StringMap<Attribute> &aliases)
override {
805 appendSimpleCompletions({
"affine_set",
"affine_map",
"dense",
806 "dense_resource",
"false",
"loc",
"sparse",
"true",
808 llvm::lsp::CompletionItemKind::Field,
811 completeDialectName(
"#");
812 completeAliases(aliases,
"#");
814 void completeDialectAttributeOrAlias(
815 const llvm::StringMap<Attribute> &aliases)
override {
816 completeDialectName();
817 completeAliases(aliases);
821 void completeType(
const llvm::StringMap<Type> &aliases)
override {
823 appendSimpleCompletions({
"memref",
"tensor",
"complex",
"tuple",
"vector",
824 "bf16",
"f16",
"f32",
"f64",
"f80",
"f128",
826 llvm::lsp::CompletionItemKind::Field,
830 for (StringRef type : {
"i",
"si",
"ui"}) {
831 llvm::lsp::CompletionItem item(type +
"<N>",
832 llvm::lsp::CompletionItemKind::Field,
834 item.insertText = type.str();
835 completionList.items.emplace_back(item);
839 completeDialectName(
"!");
840 completeAliases(aliases,
"!");
843 completeDialectTypeOrAlias(
const llvm::StringMap<Type> &aliases)
override {
844 completeDialectName();
845 completeAliases(aliases);
849 template <
typename T>
850 void completeAliases(
const llvm::StringMap<T> &aliases,
851 StringRef prefix =
"") {
852 for (
const auto &alias : aliases) {
853 llvm::lsp::CompletionItem item(prefix + alias.getKey(),
854 llvm::lsp::CompletionItemKind::Field,
856 llvm::raw_string_ostream(item.detail) <<
"alias: " << alias.getValue();
857 completionList.items.emplace_back(item);
863 llvm::lsp::CompletionItemKind
kind,
864 StringRef sortText =
"") {
865 for (StringRef completion : completions)
866 completionList.items.emplace_back(completion,
kind, sortText);
870 lsp::CompletionList &completionList;
876 MLIRDocument::getCodeCompletion(
const lsp::URIForFile &uri,
877 const lsp::Position &completePos,
879 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
880 if (!posLoc.isValid())
881 return lsp::CompletionList();
886 tmpContext.allowUnregisteredDialects();
887 lsp::CompletionList completionList;
888 LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
894 &lspCompleteContext);
895 return completionList;
902 void MLIRDocument::getCodeActionForDiagnostic(
903 const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
904 StringRef message, std::vector<llvm::lsp::TextEdit> &edits) {
908 if (message.starts_with(
"see current operation: "))
912 const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
913 const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
916 StringRef line(lineStart, pos.character);
920 llvm::lsp::TextEdit edit;
921 edit.range = lsp::Range(lsp::Position(pos.line, 0));
924 size_t indent = line.find_first_not_of(
' ');
925 if (indent == StringRef::npos)
926 indent = line.size();
928 edit.newText.append(indent,
' ');
929 llvm::raw_string_ostream(edit.newText)
930 <<
"// expected-" << severity <<
" @below {{" << message <<
"}}\n";
931 edits.emplace_back(std::move(edit));
939 MLIRDocument::convertToBytecode() {
942 if (!llvm::hasSingleElement(parsedIR)) {
943 if (parsedIR.empty()) {
944 return llvm::make_error<llvm::lsp::LSPError>(
945 "expected a single and valid top-level operation, please ensure "
946 "there are no errors",
947 llvm::lsp::ErrorCode::RequestFailed);
949 return llvm::make_error<llvm::lsp::LSPError>(
950 "expected a single top-level operation",
951 llvm::lsp::ErrorCode::RequestFailed);
954 lsp::MLIRConvertBytecodeResult result;
958 std::string rawBytecodeBuffer;
959 llvm::raw_string_ostream os(rawBytecodeBuffer);
962 result.output = llvm::encodeBase64(rawBytecodeBuffer);
973 struct MLIRTextFileChunk {
974 MLIRTextFileChunk(
MLIRContext &context, uint64_t lineOffset,
975 const lsp::URIForFile &uri, StringRef contents,
976 std::vector<lsp::Diagnostic> &diagnostics)
977 : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
981 void adjustLocForChunkOffset(lsp::Range &range) {
982 adjustLocForChunkOffset(range.start);
983 adjustLocForChunkOffset(range.end);
987 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
992 MLIRDocument document;
1002 class MLIRTextFile {
1004 MLIRTextFile(
const lsp::URIForFile &uri, StringRef fileContents,
1006 std::vector<lsp::Diagnostic> &diagnostics);
1009 int64_t getVersion()
const {
return version; }
1015 void getLocationsOf(
const lsp::URIForFile &uri, lsp::Position defPos,
1016 std::vector<lsp::Location> &locations);
1017 void findReferencesOf(
const lsp::URIForFile &uri, lsp::Position pos,
1018 std::vector<lsp::Location> &references);
1019 std::optional<lsp::Hover> findHover(
const lsp::URIForFile &uri,
1020 lsp::Position hoverPos);
1021 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1022 lsp::CompletionList getCodeCompletion(
const lsp::URIForFile &uri,
1023 lsp::Position completePos);
1024 void getCodeActions(
const lsp::URIForFile &uri,
const lsp::Range &pos,
1025 const lsp::CodeActionContext &context,
1026 std::vector<lsp::CodeAction> &actions);
1033 MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
1039 std::string contents;
1045 int64_t totalNumLines = 0;
1049 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
1053 MLIRTextFile::MLIRTextFile(
const lsp::URIForFile &uri, StringRef fileContents,
1055 std::vector<lsp::Diagnostic> &diagnostics)
1056 : context(registry_fn(uri),
MLIRContext::Threading::DISABLED),
1057 contents(fileContents.str()), version(version) {
1063 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
1064 context, 0, uri, subContents.front(), diagnostics));
1066 uint64_t lineOffset = subContents.front().count(
'\n');
1067 for (StringRef docContents : llvm::drop_begin(subContents)) {
1068 unsigned currentNumDiags = diagnostics.size();
1069 auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
1070 docContents, diagnostics);
1071 lineOffset += docContents.count(
'\n');
1075 for (lsp::Diagnostic &
diag :
1076 llvm::drop_begin(diagnostics, currentNumDiags)) {
1077 chunk->adjustLocForChunkOffset(
diag.range);
1079 if (!
diag.relatedInformation)
1081 for (
auto &it : *
diag.relatedInformation)
1082 if (it.location.uri == uri)
1083 chunk->adjustLocForChunkOffset(it.location.range);
1085 chunks.emplace_back(std::move(chunk));
1087 totalNumLines = lineOffset;
1090 void MLIRTextFile::getLocationsOf(
const lsp::URIForFile &uri,
1091 lsp::Position defPos,
1092 std::vector<lsp::Location> &locations) {
1093 MLIRTextFileChunk &chunk = getChunkFor(defPos);
1094 chunk.document.getLocationsOf(uri, defPos, locations);
1097 if (chunk.lineOffset == 0)
1099 for (lsp::Location &loc : locations)
1101 chunk.adjustLocForChunkOffset(loc.range);
1104 void MLIRTextFile::findReferencesOf(
const lsp::URIForFile &uri,
1106 std::vector<lsp::Location> &references) {
1107 MLIRTextFileChunk &chunk = getChunkFor(pos);
1108 chunk.document.findReferencesOf(uri, pos, references);
1111 if (chunk.lineOffset == 0)
1113 for (lsp::Location &loc : references)
1115 chunk.adjustLocForChunkOffset(loc.range);
1118 std::optional<lsp::Hover> MLIRTextFile::findHover(
const lsp::URIForFile &uri,
1119 lsp::Position hoverPos) {
1120 MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
1121 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1124 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1125 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1129 void MLIRTextFile::findDocumentSymbols(
1130 std::vector<lsp::DocumentSymbol> &symbols) {
1131 if (chunks.size() == 1)
1132 return chunks.front()->document.findDocumentSymbols(symbols);
1136 for (
unsigned i = 0, e = chunks.size(); i < e; ++i) {
1137 MLIRTextFileChunk &chunk = *chunks[i];
1138 lsp::Position startPos(chunk.lineOffset);
1139 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1140 : chunks[i + 1]->lineOffset);
1141 lsp::DocumentSymbol symbol(
"<file-split-" + Twine(i) +
">",
1142 llvm::lsp::SymbolKind::Namespace,
1143 lsp::Range(startPos, endPos),
1144 lsp::Range(startPos));
1145 chunk.document.findDocumentSymbols(symbol.children);
1150 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1151 symbolsToFix.push_back(&childSymbol);
1153 while (!symbolsToFix.empty()) {
1154 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1155 chunk.adjustLocForChunkOffset(symbol->range);
1156 chunk.adjustLocForChunkOffset(symbol->selectionRange);
1158 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1159 symbolsToFix.push_back(&childSymbol);
1164 symbols.emplace_back(std::move(symbol));
1168 lsp::CompletionList MLIRTextFile::getCodeCompletion(
const lsp::URIForFile &uri,
1169 lsp::Position completePos) {
1170 MLIRTextFileChunk &chunk = getChunkFor(completePos);
1171 lsp::CompletionList completionList = chunk.document.getCodeCompletion(
1175 for (llvm::lsp::CompletionItem &item : completionList.items) {
1177 chunk.adjustLocForChunkOffset(item.textEdit->range);
1178 for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
1179 chunk.adjustLocForChunkOffset(edit.range);
1181 return completionList;
1184 void MLIRTextFile::getCodeActions(
const lsp::URIForFile &uri,
1185 const lsp::Range &pos,
1186 const lsp::CodeActionContext &context,
1187 std::vector<lsp::CodeAction> &actions) {
1189 for (
auto &
diag : context.diagnostics) {
1190 if (
diag.source !=
"mlir")
1192 lsp::Position diagPos =
diag.range.start;
1193 MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1196 lsp::CodeAction action;
1197 action.title =
"Add expected-* diagnostic checks";
1198 action.kind = lsp::CodeAction::kQuickFix.str();
1201 switch (
diag.severity) {
1205 case llvm::lsp::DiagnosticSeverity::Warning:
1206 severity =
"warning";
1213 std::vector<llvm::lsp::TextEdit> edits;
1214 chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1215 diag.message, edits);
1218 if (
diag.relatedInformation) {
1219 for (
auto ¬eDiag : *
diag.relatedInformation) {
1220 if (noteDiag.location.uri != uri)
1222 diagPos = noteDiag.location.range.start;
1223 diagPos.line -= chunk.lineOffset;
1224 chunk.document.getCodeActionForDiagnostic(uri, diagPos,
"note",
1225 noteDiag.message, edits);
1229 for (llvm::lsp::TextEdit &edit : edits)
1230 chunk.adjustLocForChunkOffset(edit.range);
1232 action.edit.emplace();
1233 action.edit->changes[uri.uri().str()] = std::move(edits);
1234 action.diagnostics = {
diag};
1236 actions.emplace_back(std::move(action));
1241 MLIRTextFile::convertToBytecode() {
1243 if (chunks.size() != 1) {
1244 return llvm::make_error<llvm::lsp::LSPError>(
1245 "unexpected split file, please remove all `// -----`",
1246 llvm::lsp::ErrorCode::RequestFailed);
1248 return chunks.front()->document.convertToBytecode();
1251 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1252 if (chunks.size() == 1)
1253 return *chunks.front();
1257 auto it = llvm::upper_bound(
1258 chunks, pos, [](
const lsp::Position &pos,
const auto &chunk) {
1259 return static_cast<uint64_t
>(pos.line) < chunk->lineOffset;
1261 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1262 pos.line -= chunk.lineOffset;
1278 llvm::StringMap<std::unique_ptr<MLIRTextFile>>
files;
1286 :
impl(std::make_unique<
Impl>(registry_fn)) {}
1290 const URIForFile &uri, StringRef contents, int64_t version,
1291 std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1292 impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1293 uri, contents, version,
impl->registry_fn, diagnostics);
1297 auto it =
impl->files.find(uri.file());
1298 if (it ==
impl->files.end())
1299 return std::nullopt;
1301 int64_t version = it->second->getVersion();
1302 impl->files.erase(it);
1307 const URIForFile &uri,
const Position &defPos,
1308 std::vector<llvm::lsp::Location> &locations) {
1309 auto fileIt =
impl->files.find(uri.file());
1310 if (fileIt !=
impl->files.end())
1311 fileIt->second->getLocationsOf(uri, defPos, locations);
1315 const URIForFile &uri,
const Position &pos,
1316 std::vector<llvm::lsp::Location> &references) {
1317 auto fileIt =
impl->files.find(uri.file());
1318 if (fileIt !=
impl->files.end())
1319 fileIt->second->findReferencesOf(uri, pos, references);
1323 const Position &hoverPos) {
1324 auto fileIt =
impl->files.find(uri.file());
1325 if (fileIt !=
impl->files.end())
1326 return fileIt->second->findHover(uri, hoverPos);
1327 return std::nullopt;
1331 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1332 auto fileIt =
impl->files.find(uri.file());
1333 if (fileIt !=
impl->files.end())
1334 fileIt->second->findDocumentSymbols(symbols);
1339 const Position &completePos) {
1340 auto fileIt =
impl->files.find(uri.file());
1341 if (fileIt !=
impl->files.end())
1342 return fileIt->second->getCodeCompletion(uri, completePos);
1343 return CompletionList();
1347 const CodeActionContext &context,
1348 std::vector<CodeAction> &actions) {
1349 auto fileIt =
impl->files.find(uri.file());
1350 if (fileIt !=
impl->files.end())
1351 fileIt->second->getCodeActions(uri, pos, context, actions);
1360 std::string errorMsg;
1370 &fallbackResourceMap);
1375 return llvm::make_error<llvm::lsp::LSPError>(
1376 "failed to parse bytecode source file: " + errorMsg,
1377 llvm::lsp::ErrorCode::RequestFailed);
1382 if (!llvm::hasSingleElement(parsedBlock)) {
1383 return llvm::make_error<llvm::lsp::LSPError>(
1384 "expected bytecode to contain a single top-level operation",
1385 llvm::lsp::ErrorCode::RequestFailed);
1397 nullptr, &fallbackResourceMap);
1399 llvm::raw_string_ostream os(result.
output);
1400 topOp->print(os, state);
1402 return std::move(result);
1407 auto fileIt =
impl->files.find(uri.file());
1408 if (fileIt ==
impl->files.end()) {
1409 return llvm::make_error<llvm::lsp::LSPError>(
1410 "language server does not contain an entry for this source file",
1411 llvm::lsp::ErrorCode::RequestFailed);
1413 return fileIt->second->convertToBytecode();
static std::string toString(bytecode::Section::ID sectionID)
Stringify the given section ID.
union mlir::linalg::@1243::ArityGroupAndKind::Kind kind
static void collectLocationsFromLoc(Location loc, std::vector< lsp::Location > &locations, const lsp::URIForFile &uri)
Collect all of the locations from the given MLIR location that are not contained within the given URI...
static unsigned getBlockNumber(Block *block)
Given a block, return its position in its parent region.
static std::optional< lsp::Location > getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc)
Returns a language server location from the given MLIR file location.
static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, SMRange *overlappedRange=nullptr)
Returns true if the given location is contained by the definition or one of the uses of the given SMD...
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, Diagnostic &diag, const lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc={})
Given a block and source location, print the source name of the block to the given output stream.
static std::optional< unsigned > getResultNumberFromLoc(SMLoc loc)
Given a location pointing to a result, return the result number it refers to or std::nullopt if it re...
static SMRange convertTokenLocToRange(SMLoc loc)
Returns the range of a lexical token given a SMLoc corresponding to the start of an token location.
static std::optional< StringRef > getTextFromRange(SMRange range)
Given a source location range, return the text covered by the given range.
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class provides an abstract interface into the parser for hooking in code completion events.
void completeDialectName()
This class represents state from a parsed MLIR textual format string.
This class provides management for the lifetime of the state used when printing the IR.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
bool hasNoSuccessors()
Returns true if this blocks has no successors.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
SuccessorRange getSuccessors()
iterator_range< pred_iterator > getPredecessors()
bool hasNoPredecessors()
Return true if this block has no predecessors.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class contains the configuration used for the bytecode writer.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A fallback map containing external resources not explicitly handled by another parser/printer.
An instance of this location represents a tuple of file, line number, and column number.
StringAttr getFilename() const
unsigned getColumn() const
WalkResult walk(function_ref< WalkResult(Location)> walkFn)
Walk all of the locations nested directly under, and including, the current.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
void allowUnregisteredDialects(bool allow=true)
Enables creating operations in unregistered dialects.
Set of flags used to control the behavior of the various IR print methods (e.g.
A trait used to provide symbol table functionalities to a region operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
result_range getResults()
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
This class represents a configuration for the MLIR assembly parser.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
void addOrUpdateDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add or update the document, with the provided version, at the given URI.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
llvm::Expected< MLIRConvertBytecodeResult > convertFromBytecode(const URIForFile &uri)
Convert the given bytecode file to the textual format.
llvm::Expected< MLIRConvertBytecodeResult > convertToBytecode(const URIForFile &uri)
Convert the given textual file to the bytecode format.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
void getCodeActions(const URIForFile &uri, const Range &pos, const CodeActionContext &context, std::vector< CodeAction > &actions)
Get the set of code actions within the file.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SMRange convertTokenLocToRange(SMLoc loc, StringRef identifierChars="")
Returns the range of a lexical token given a SMLoc corresponding to the start of an token location.
Include the generated interface declarations.
const char *const kDefaultSplitMarker
LogicalResult parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, AsmParserState *asmState=nullptr, AsmParserCodeCompleteContext *codeCompleteContext=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config={})
Write the bytecode for the given operation to the provided output stream.
This class represents the result of converting between MLIR's bytecode and textual format.
std::string output
The resultant output of the conversion.
Impl(lsp::DialectRegistryFn registry_fn)
lsp::DialectRegistryFn registry_fn
The registry factory for containing dialects that can be recognized in parsed .mlir files.
llvm::StringMap< std::unique_ptr< MLIRTextFile > > files
The files held by the server, mapped by their URI file name.
This class represents the information for an attribute alias definition within the input file.
StringRef name
The name of the attribute alias.
Attribute value
The value of the alias.
This class represents the information for a block definition within the input file.
Block * block
The block representing this definition.
SMDefinition definition
The source location for the block, i.e.
This class represents the information for an operation definition within an input file.
Operation * op
The operation representing this definition.
This class represents a definition within the source manager, containing it's defining location and l...
SmallVector< SMRange > uses
The source location of all uses of the definition.
SMRange loc
The source location of the definition.
This class represents the information for type definition within the input file.
Type value
The value of the alias.
StringRef name
The name of the attribute alias.