20#include "llvm/ADT/StringExtras.h"
21#include "llvm/Support/Base64.h"
22#include "llvm/Support/LSP/Logging.h"
23#include "llvm/Support/SourceMgr.h"
40 lsp::URIForFile::fromFile(loc.
getFilename(), uriScheme);
42 llvm::lsp::Logger::error(
"Failed to create URI for file `{0}`: {1}",
44 llvm::toString(sourceURI.takeError()));
48 lsp::Position position;
49 position.line = loc.
getLine() - 1;
51 return lsp::Location{*sourceURI, lsp::Range(position)};
58static std::optional<lsp::Location>
60 StringRef uriScheme,
const lsp::URIForFile *uri =
nullptr) {
61 std::optional<lsp::Location> location;
67 std::optional<lsp::Location> sourceLoc =
69 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
70 location = *sourceLoc;
71 SMLoc loc = sourceMgr.FindLocForLineAndColumn(
76 location->range.end.character += 1;
78 auto lineCol = sourceMgr.getLineAndColumn(range->End);
79 location->range.end.character =
80 std::max(fileLoc.
getColumn() + 1, lineCol.second - 1);
92 std::vector<lsp::Location> &locations,
93 const lsp::URIForFile &uri) {
97 if (!fileLoc || !visitedLocs.insert(nestedLoc))
100 std::optional<lsp::Location> sourceLoc =
102 if (sourceLoc && sourceLoc->uri != uri)
103 locations.push_back(*sourceLoc);
112 return range.Start.getPointer() <= loc.getPointer() &&
113 loc.getPointer() <= range.End.getPointer();
120 SMRange *overlappedRange =
nullptr) {
124 *overlappedRange = def.
loc;
129 const auto *useIt = llvm::find_if(
130 def.
uses, [&](
const SMRange &range) { return contains(range, loc); });
131 if (useIt != def.
uses.end()) {
133 *overlappedRange = *useIt;
143 auto isIdentifierChar = [](
char c) {
144 return isalnum(c) || c ==
'%' || c ==
'$' || c ==
'.' || c ==
'_' ||
147 const char *curPtr = loc.getPointer();
148 while (isIdentifierChar(*curPtr))
157 const char *numberStart = ++curPtr;
158 while (llvm::isDigit(*curPtr))
160 StringRef numberStr(numberStart, curPtr - numberStart);
161 unsigned resultNumber = 0;
162 return numberStr.consumeInteger(10, resultNumber) ? std::optional<unsigned>()
169 if (!range.isValid())
171 const char *startPtr = range.Start.getPointer();
172 return StringRef(startPtr, range.End.getPointer() - startPtr);
177 return std::distance(block->
getParent()->
begin(), block->getIterator());
185 if (text && text->starts_with(
"^")) {
201 const lsp::URIForFile &uri) {
202 lsp::Diagnostic lspDiag;
203 lspDiag.source =
"mlir";
207 lspDiag.category =
"Parse Error";
212 StringRef uriScheme = uri.scheme();
213 std::optional<lsp::Location> lspLocation =
216 lspDiag.range = lspLocation->range;
219 switch (
diag.getSeverity()) {
221 llvm_unreachable(
"expected notes to be handled separately");
223 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
226 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
229 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
232 lspDiag.message =
diag.str();
235 std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
237 lsp::Location noteLoc;
238 if (std::optional<lsp::Location> loc =
243 relatedDiags.emplace_back(noteLoc, note.str());
245 if (!relatedDiags.empty())
246 lspDiag.relatedInformation = std::move(relatedDiags);
259 MLIRDocument(MLIRContext &context,
const lsp::URIForFile &uri,
260 StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
261 MLIRDocument(
const MLIRDocument &) =
delete;
262 MLIRDocument &operator=(
const MLIRDocument &) =
delete;
268 void getLocationsOf(
const lsp::URIForFile &uri,
const lsp::Position &defPos,
269 std::vector<lsp::Location> &locations);
270 void findReferencesOf(
const lsp::URIForFile &uri,
const lsp::Position &pos,
271 std::vector<lsp::Location> &references);
277 std::optional<lsp::Hover> findHover(
const lsp::URIForFile &uri,
278 const lsp::Position &hoverPos);
279 std::optional<lsp::Hover>
280 buildHoverForOperation(SMRange hoverRange,
281 const AsmParserState::OperationDefinition &op);
282 lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
283 unsigned resultStart,
284 unsigned resultEnd, SMLoc posLoc);
285 lsp::Hover buildHoverForBlock(SMRange hoverRange,
286 const AsmParserState::BlockDefinition &block);
288 buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
289 const AsmParserState::BlockDefinition &block);
291 lsp::Hover buildHoverForAttributeAlias(
292 SMRange hoverRange,
const AsmParserState::AttributeAliasDefinition &attr);
294 buildHoverForTypeAlias(SMRange hoverRange,
295 const AsmParserState::TypeAliasDefinition &type);
301 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
302 void findDocumentSymbols(Operation *op,
303 std::vector<lsp::DocumentSymbol> &symbols);
309 lsp::CompletionList getCodeCompletion(
const lsp::URIForFile &uri,
310 const lsp::Position &completePos,
311 const DialectRegistry &registry);
317 void getCodeActionForDiagnostic(
const lsp::URIForFile &uri,
318 lsp::Position &pos, StringRef severity,
320 std::vector<llvm::lsp::TextEdit> &edits);
326 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
334 AsmParserState asmState;
341 FallbackAsmResourceMap fallbackResourceMap;
344 llvm::SourceMgr sourceMgr;
348MLIRDocument::MLIRDocument(
MLIRContext &context,
const lsp::URIForFile &uri,
350 std::vector<lsp::Diagnostic> &diagnostics) {
351 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &
diag) {
356 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
358 llvm::lsp::Logger::error(
"Failed to create memory buffer for file",
363 ParserConfig
config(&context,
true,
364 &fallbackResourceMap);
365 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
369 asmState = AsmParserState();
370 fallbackResourceMap = FallbackAsmResourceMap();
379void MLIRDocument::getLocationsOf(
const lsp::URIForFile &uri,
380 const lsp::Position &defPos,
381 std::vector<lsp::Location> &locations) {
382 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
385 auto containsPosition = [&](
const AsmParserState::SMDefinition &def) {
388 locations.emplace_back(uri, sourceMgr, def.loc);
393 for (
const AsmParserState::OperationDefinition &op : asmState.
getOpDefs()) {
396 for (
const auto &
result : op.resultGroups)
397 if (containsPosition(
result.definition))
399 for (
const auto &symUse : op.symbolUses) {
401 locations.emplace_back(uri, sourceMgr, op.loc);
408 for (
const AsmParserState::BlockDefinition &block : asmState.
getBlockDefs()) {
409 if (containsPosition(block.definition))
411 for (
const AsmParserState::SMDefinition &arg : block.arguments)
412 if (containsPosition(arg))
417 for (
const AsmParserState::AttributeAliasDefinition &attr :
419 if (containsPosition(attr.definition))
422 for (
const AsmParserState::TypeAliasDefinition &type :
424 if (containsPosition(type.definition))
429void MLIRDocument::findReferencesOf(
const lsp::URIForFile &uri,
430 const lsp::Position &pos,
431 std::vector<lsp::Location> &references) {
434 auto appendSMDef = [&](
const AsmParserState::SMDefinition &def) {
435 references.emplace_back(uri, sourceMgr, def.loc);
436 for (
const SMRange &use : def.uses)
437 references.emplace_back(uri, sourceMgr, use);
440 SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
443 for (
const AsmParserState::OperationDefinition &op : asmState.
getOpDefs()) {
445 for (
const auto &
result : op.resultGroups)
446 appendSMDef(
result.definition);
447 for (
const auto &symUse : op.symbolUses)
449 references.emplace_back(uri, sourceMgr, symUse);
452 for (
const auto &
result : op.resultGroups)
454 return appendSMDef(
result.definition);
455 for (
const auto &symUse : op.symbolUses) {
458 for (
const auto &symUse : op.symbolUses)
459 references.emplace_back(uri, sourceMgr, symUse);
465 for (
const AsmParserState::BlockDefinition &block : asmState.
getBlockDefs()) {
467 return appendSMDef(block.definition);
469 for (
const AsmParserState::SMDefinition &arg : block.arguments)
471 return appendSMDef(arg);
475 for (
const AsmParserState::AttributeAliasDefinition &attr :
478 return appendSMDef(attr.definition);
480 for (
const AsmParserState::TypeAliasDefinition &type :
483 return appendSMDef(type.definition);
491std::optional<lsp::Hover>
492MLIRDocument::findHover(
const lsp::URIForFile &uri,
493 const lsp::Position &hoverPos) {
494 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
498 for (
const AsmParserState::OperationDefinition &op : asmState.
getOpDefs()) {
501 return buildHoverForOperation(op.loc, op);
504 for (
auto &use : op.symbolUses)
506 return buildHoverForOperation(use, op);
509 for (
unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
510 const auto &
result = op.resultGroups[i];
515 unsigned resultStart =
result.startIndex;
516 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
517 : op.resultGroups[i + 1].startIndex;
518 return buildHoverForOperationResult(hoverRange, op.op, resultStart,
524 for (
const AsmParserState::BlockDefinition &block : asmState.
getBlockDefs()) {
525 if (
isDefOrUse(block.definition, posLoc, &hoverRange))
526 return buildHoverForBlock(hoverRange, block);
528 for (
const auto &arg : llvm::enumerate(block.arguments)) {
529 if (!
isDefOrUse(arg.value(), posLoc, &hoverRange))
532 return buildHoverForBlockArgument(
533 hoverRange, block.block->
getArgument(arg.index()), block);
538 for (
const AsmParserState::AttributeAliasDefinition &attr :
540 if (
isDefOrUse(attr.definition, posLoc, &hoverRange))
541 return buildHoverForAttributeAlias(hoverRange, attr);
543 for (
const AsmParserState::TypeAliasDefinition &type :
545 if (
isDefOrUse(type.definition, posLoc, &hoverRange))
546 return buildHoverForTypeAlias(hoverRange, type);
552std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
553 SMRange hoverRange,
const AsmParserState::OperationDefinition &op) {
554 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
555 llvm::raw_string_ostream os(hover.contents.value);
559 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.
op))
560 os <<
" : " << symbol.getVisibility() <<
" @" << symbol.getName() <<
"";
563 os <<
"Generic Form:\n\n```mlir\n";
565 op.
op->
print(os, OpPrintingFlags()
566 .printGenericOpForm()
567 .elideLargeElementsAttrs()
574lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
576 unsigned resultStart,
579 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
580 llvm::raw_string_ostream os(hover.contents.value);
583 os <<
"Operation: \"" << op->
getName() <<
"\"\n\n";
588 if ((resultStart + *resultNumber) < resultEnd) {
589 resultStart += *resultNumber;
590 resultEnd = resultStart + 1;
595 if ((resultStart + 1) == resultEnd) {
596 os <<
"Result #" << resultStart <<
"\n\n"
599 os <<
"Result #[" << resultStart <<
", " << (resultEnd - 1) <<
"]\n\n"
601 llvm::interleaveComma(
602 op->
getResults().slice(resultStart, resultEnd), os,
603 [&](Value
result) { os <<
"`" << result.getType() <<
"`"; });
610MLIRDocument::buildHoverForBlock(SMRange hoverRange,
611 const AsmParserState::BlockDefinition &block) {
612 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
613 llvm::raw_string_ostream os(hover.contents.value);
616 auto printBlockToHover = [&](
Block *newBlock) {
617 if (
const auto *def = asmState.
getBlockDef(newBlock))
627 os <<
"Predecessors: ";
633 os <<
"Successors: ";
641lsp::Hover MLIRDocument::buildHoverForBlockArgument(
642 SMRange hoverRange, BlockArgument arg,
643 const AsmParserState::BlockDefinition &block) {
644 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
645 llvm::raw_string_ostream os(hover.contents.value);
652 <<
"Type: `" << arg.
getType() <<
"`\n\n";
/// Builds the hover for an attribute alias: a `lsp::Hover` covering
/// `hoverRange` whose text gives the alias name and then its value inside a
/// ```mlir fence.
657lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
658 SMRange hoverRange,
const AsmParserState::AttributeAliasDefinition &attr) {
659 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
660 llvm::raw_string_ostream os(hover.contents.value);
// Close the quote that is opened before the alias name so the name is
// rendered as "name", the same way the `Operation: "..."` hover line is.
662 os <<
"Attribute Alias: \"" << attr.
name <<
"\"\n\n";
663 os <<
"Value: ```mlir\n" << attr.
value <<
"\n```\n\n";
/// Builds the hover for a type alias: a `lsp::Hover` covering `hoverRange`
/// whose text gives the alias name and then its value inside a ```mlir fence.
668lsp::Hover MLIRDocument::buildHoverForTypeAlias(
669 SMRange hoverRange,
const AsmParserState::TypeAliasDefinition &type) {
670 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
671 llvm::raw_string_ostream os(hover.contents.value);
// Close the quote that is opened before the alias name so the name is
// rendered as "name", the same way the `Operation: "..."` hover line is.
673 os <<
"Type Alias: \"" << type.
name <<
"\"\n\n";
674 os <<
"Value: ```mlir\n" << type.
value <<
"\n```\n\n";
/// Collects document symbols for the whole document by invoking the
/// per-operation overload on every operation in `parsedIR`; results are
/// appended to `symbols`.
683void MLIRDocument::findDocumentSymbols(
684 std::vector<lsp::DocumentSymbol> &symbols) {
685 for (Operation &op : parsedIR)
686 findDocumentSymbols(&op, symbols);
689void MLIRDocument::findDocumentSymbols(
690 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
691 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
694 if (
const AsmParserState::OperationDefinition *def = asmState.
getOpDef(op)) {
696 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
697 symbols.emplace_back(symbol.getName(),
698 isa<FunctionOpInterface>(op)
699 ? llvm::lsp::SymbolKind::Function
700 : llvm::lsp::SymbolKind::Class,
701 lsp::Range(sourceMgr, def->scopeLoc),
702 lsp::Range(sourceMgr, def->loc));
703 childSymbols = &symbols.back().children;
705 }
else if (op->
hasTrait<OpTrait::SymbolTable>()) {
708 llvm::lsp::SymbolKind::Namespace,
709 llvm::lsp::Range(sourceMgr, def->scopeLoc),
710 llvm::lsp::Range(sourceMgr, def->loc));
711 childSymbols = &symbols.back().children;
719 for (Operation &childOp : region.getOps())
720 findDocumentSymbols(&childOp, *childSymbols);
728class LSPCodeCompleteContext :
public AsmParserCodeCompleteContext {
730 LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
732 : AsmParserCodeCompleteContext(completeLoc),
733 completionList(completionList), ctx(ctx) {}
736 void completeDialectName(StringRef prefix)
final {
738 llvm::lsp::CompletionItem item(prefix + dialect,
739 llvm::lsp::CompletionItemKind::Module,
741 item.detail =
"dialect";
742 completionList.items.emplace_back(item);
748 void completeOperationName(StringRef dialectName)
final {
757 llvm::lsp::CompletionItem item(
758 op.getStringRef().drop_front(dialectName.size() + 1),
759 llvm::lsp::CompletionItemKind::Field,
761 item.detail =
"operation";
762 completionList.items.emplace_back(item);
768 void appendSSAValueCompletion(StringRef name, std::string typeData)
final {
770 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] ==
'%';
772 llvm::lsp::CompletionItem item(name,
773 llvm::lsp::CompletionItemKind::Variable);
775 item.insertText = name.drop_front(1).str();
776 item.detail = std::move(typeData);
777 completionList.items.emplace_back(item);
782 void appendBlockCompletion(StringRef name)
final {
784 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] ==
'^';
786 llvm::lsp::CompletionItem item(name, llvm::lsp::CompletionItemKind::Field);
788 item.insertText = name.drop_front(1).str();
789 completionList.items.emplace_back(item);
793 void completeExpectedTokens(ArrayRef<StringRef> tokens,
bool optional)
final {
794 for (StringRef token : tokens) {
795 llvm::lsp::CompletionItem item(token,
796 llvm::lsp::CompletionItemKind::Keyword,
798 item.detail = optional ?
"optional" :
"";
799 completionList.items.emplace_back(item);
804 void completeAttribute(
const llvm::StringMap<Attribute> &aliases)
override {
805 appendSimpleCompletions({
"affine_set",
"affine_map",
"dense",
806 "dense_resource",
"false",
"loc",
"sparse",
"true",
808 llvm::lsp::CompletionItemKind::Field,
811 completeDialectName(
"#");
812 completeAliases(aliases,
"#");
814 void completeDialectAttributeOrAlias(
815 const llvm::StringMap<Attribute> &aliases)
override {
816 completeDialectName();
817 completeAliases(aliases);
821 void completeType(
const llvm::StringMap<Type> &aliases)
override {
823 appendSimpleCompletions({
"memref",
"tensor",
"complex",
"tuple",
"vector",
824 "bf16",
"f16",
"f32",
"f64",
"f80",
"f128",
826 llvm::lsp::CompletionItemKind::Field,
830 for (StringRef type : {
"i",
"si",
"ui"}) {
831 llvm::lsp::CompletionItem item(type +
"<N>",
832 llvm::lsp::CompletionItemKind::Field,
834 item.insertText = type.str();
835 completionList.items.emplace_back(item);
839 completeDialectName(
"!");
840 completeAliases(aliases,
"!");
843 completeDialectTypeOrAlias(
const llvm::StringMap<Type> &aliases)
override {
844 completeDialectName();
845 completeAliases(aliases);
849 template <
typename T>
850 void completeAliases(
const llvm::StringMap<T> &aliases,
851 StringRef prefix =
"") {
852 for (
const auto &alias : aliases) {
853 llvm::lsp::CompletionItem item(prefix + alias.getKey(),
854 llvm::lsp::CompletionItemKind::Field,
856 llvm::raw_string_ostream(item.detail) <<
"alias: " << alias.getValue();
857 completionList.items.emplace_back(item);
862 void appendSimpleCompletions(ArrayRef<StringRef> completions,
863 llvm::lsp::CompletionItemKind kind,
864 StringRef sortText =
"") {
865 for (StringRef completion : completions)
866 completionList.items.emplace_back(completion, kind, sortText);
870 lsp::CompletionList &completionList;
876MLIRDocument::getCodeCompletion(
const lsp::URIForFile &uri,
877 const lsp::Position &completePos,
878 const DialectRegistry &registry) {
879 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
880 if (!posLoc.isValid())
881 return lsp::CompletionList();
885 MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
886 tmpContext.allowUnregisteredDialects();
887 lsp::CompletionList completionList;
888 LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
892 AsmParserState tmpState;
894 &lspCompleteContext);
895 return completionList;
902void MLIRDocument::getCodeActionForDiagnostic(
903 const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
904 StringRef message, std::vector<llvm::lsp::TextEdit> &edits) {
908 if (message.starts_with(
"see current operation: "))
912 const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
913 const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
916 StringRef line(lineStart, pos.character);
920 llvm::lsp::TextEdit edit;
921 edit.range = lsp::Range(lsp::Position(pos.line, 0));
924 size_t indent = line.find_first_not_of(
' ');
925 if (indent == StringRef::npos)
926 indent = line.size();
928 edit.newText.append(indent,
' ');
929 llvm::raw_string_ostream(edit.newText)
930 <<
"// expected-" << severity <<
" @below {{" << message <<
"}}\n";
931 edits.emplace_back(std::move(edit));
/// Converts this document to bytecode and returns it base64-encoded in
/// `result.output`, or an `LSPError` (RequestFailed) when `parsedIR` does not
/// hold exactly one top-level operation.
/// NOTE(review): this listing omits several original lines, including the
/// statement that writes into `os`; confirm against the full source.
938llvm::Expected<lsp::MLIRConvertBytecodeResult>
939MLIRDocument::convertToBytecode() {
// Anything other than exactly one top-level operation is rejected.
942 if (!llvm::hasSingleElement(parsedIR)) {
// An empty `parsedIR` gets a message that also asks the user to check for
// errors.
943 if (parsedIR.
empty()) {
944 return llvm::make_error<llvm::lsp::LSPError>(
945 "expected a single and valid top-level operation, please ensure "
946 "there are no errors",
947 llvm::lsp::ErrorCode::RequestFailed);
// Otherwise there is more than one top-level operation.
949 return llvm::make_error<llvm::lsp::LSPError>(
950 "expected a single top-level operation",
951 llvm::lsp::ErrorCode::RequestFailed);
954 lsp::MLIRConvertBytecodeResult
result;
// The writer config is built from this document's fallback resource map.
956 BytecodeWriterConfig writerConfig(fallbackResourceMap);
// `os` appends to `rawBytecodeBuffer`; presumably the elided lines emit the
// bytecode through it — TODO confirm.
958 std::string rawBytecodeBuffer;
959 llvm::raw_string_ostream os(rawBytecodeBuffer);
// The raw buffer is base64-encoded before being stored in the result.
962 result.output = llvm::encodeBase64(rawBytecodeBuffer);
/// One chunk of a text file: an `MLIRDocument` built from the chunk contents
/// together with a line offset that is added to positions reported for it.
973struct MLIRTextFileChunk {
// Stores `lineOffset` and constructs `document` from the remaining arguments.
974 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
975 const lsp::URIForFile &uri, StringRef contents,
976 std::vector<lsp::Diagnostic> &diagnostics)
977 : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
// Shifts both endpoints of `range` via the position overload.
981 void adjustLocForChunkOffset(lsp::Range &range) {
982 adjustLocForChunkOffset(range.start);
983 adjustLocForChunkOffset(range.end);
// Adds this chunk's `lineOffset` to `pos.line`, in place.
987 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
// The document holding this chunk's contents.
992 MLIRDocument document;
1004 MLIRTextFile(
const lsp::URIForFile &uri, StringRef fileContents,
1006 std::vector<lsp::Diagnostic> &diagnostics);
// Returns the `version` value stored for this file.
1009 int64_t getVersion()
const {
return version; }
1015 void getLocationsOf(
const lsp::URIForFile &uri, lsp::Position defPos,
1016 std::vector<lsp::Location> &locations);
1017 void findReferencesOf(
const lsp::URIForFile &uri, lsp::Position pos,
1018 std::vector<lsp::Location> &references);
1019 std::optional<lsp::Hover> findHover(
const lsp::URIForFile &uri,
1020 lsp::Position hoverPos);
1021 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1022 lsp::CompletionList getCodeCompletion(
const lsp::URIForFile &uri,
1023 lsp::Position completePos);
1024 void getCodeActions(
const lsp::URIForFile &uri,
const lsp::Range &pos,
1025 const lsp::CodeActionContext &context,
1026 std::vector<lsp::CodeAction> &actions);
1027 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
1033 MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
1036 MLIRContext context;
1039 std::string contents;
1045 int64_t totalNumLines = 0;
1049 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
1053MLIRTextFile::MLIRTextFile(
const lsp::URIForFile &uri, StringRef fileContents,
1055 std::vector<lsp::Diagnostic> &diagnostics)
1056 : context(registry_fn(uri), MLIRContext::Threading::DISABLED),
1057 contents(fileContents.str()), version(version) {
1063 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
1064 context, 0, uri, subContents.front(), diagnostics));
1066 uint64_t lineOffset = subContents.front().count(
'\n');
1067 for (StringRef docContents : llvm::drop_begin(subContents)) {
1068 unsigned currentNumDiags = diagnostics.size();
1069 auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
1070 docContents, diagnostics);
1071 lineOffset += docContents.count(
'\n');
1075 for (lsp::Diagnostic &
diag :
1076 llvm::drop_begin(diagnostics, currentNumDiags)) {
1077 chunk->adjustLocForChunkOffset(
diag.range);
1079 if (!
diag.relatedInformation)
1081 for (
auto &it : *
diag.relatedInformation)
1082 if (it.location.uri == uri)
1083 chunk->adjustLocForChunkOffset(it.location.range);
1085 chunks.emplace_back(std::move(chunk));
1087 totalNumLines = lineOffset;
1090void MLIRTextFile::getLocationsOf(
const lsp::URIForFile &uri,
1091 lsp::Position defPos,
1092 std::vector<lsp::Location> &locations) {
1093 MLIRTextFileChunk &chunk = getChunkFor(defPos);
1094 chunk.document.getLocationsOf(uri, defPos, locations);
1097 if (chunk.lineOffset == 0)
1099 for (lsp::Location &loc : locations)
1101 chunk.adjustLocForChunkOffset(loc.range);
1104void MLIRTextFile::findReferencesOf(
const lsp::URIForFile &uri,
1106 std::vector<lsp::Location> &references) {
1107 MLIRTextFileChunk &chunk = getChunkFor(pos);
1108 chunk.document.findReferencesOf(uri, pos, references);
1111 if (chunk.lineOffset == 0)
1113 for (lsp::Location &loc : references)
1115 chunk.adjustLocForChunkOffset(loc.range);
1118std::optional<lsp::Hover> MLIRTextFile::findHover(
const lsp::URIForFile &uri,
1119 lsp::Position hoverPos) {
1120 MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
1121 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1124 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1125 chunk.adjustLocForChunkOffset(*hoverInfo->range);
/// Collects document symbols for the file. With a single chunk the request is
/// forwarded unchanged; otherwise each chunk is wrapped in a synthetic
/// "<file-split-N>" namespace symbol whose children are the chunk's symbols
/// with their ranges shifted by the chunk offset.
/// NOTE(review): this listing omits some original lines between the numbered
/// ones shown; confirm details against the full source.
1129void MLIRTextFile::findDocumentSymbols(
1130 std::vector<lsp::DocumentSymbol> &symbols) {
// Single chunk: delegate straight to its document.
1131 if (chunks.size() == 1)
1132 return chunks.front()->document.findDocumentSymbols(symbols);
1136 for (
unsigned i = 0, e = chunks.size(); i < e; ++i) {
1137 MLIRTextFileChunk &chunk = *chunks[i];
// The wrapper symbol starts at this chunk's line offset and ends at the next
// chunk's offset, or at `totalNumLines - 1` for the last chunk.
1138 lsp::Position startPos(chunk.lineOffset);
1139 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1140 : chunks[i + 1]->lineOffset);
1141 lsp::DocumentSymbol symbol(
"<file-split-" + Twine(i) +
">",
1142 llvm::lsp::SymbolKind::Namespace,
1143 lsp::Range(startPos, endPos),
1144 lsp::Range(startPos));
// The chunk's own symbols become children of the wrapper symbol.
1145 chunk.document.findDocumentSymbols(symbol.children);
// Worklist over all nested symbols: each one has both of its ranges shifted by
// the chunk offset, then its children are queued.
1149 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1150 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1151 symbolsToFix.push_back(&childSymbol);
1153 while (!symbolsToFix.empty()) {
1154 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1155 chunk.adjustLocForChunkOffset(symbol->range);
1156 chunk.adjustLocForChunkOffset(symbol->selectionRange);
1158 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1159 symbolsToFix.push_back(&childSymbol);
// Hand the finished wrapper symbol to the caller.
1164 symbols.emplace_back(std::move(symbol));
1168lsp::CompletionList MLIRTextFile::getCodeCompletion(
const lsp::URIForFile &uri,
1169 lsp::Position completePos) {
1170 MLIRTextFileChunk &chunk = getChunkFor(completePos);
1171 lsp::CompletionList completionList = chunk.document.getCodeCompletion(
1175 for (llvm::lsp::CompletionItem &item : completionList.items) {
1177 chunk.adjustLocForChunkOffset(item.textEdit->range);
1178 for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
1179 chunk.adjustLocForChunkOffset(edit.range);
1181 return completionList;
1184void MLIRTextFile::getCodeActions(
const lsp::URIForFile &uri,
1185 const lsp::Range &pos,
1186 const lsp::CodeActionContext &context,
1187 std::vector<lsp::CodeAction> &actions) {
1189 for (
auto &
diag : context.diagnostics) {
1190 if (
diag.source !=
"mlir")
1192 lsp::Position diagPos =
diag.range.start;
1193 MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1196 lsp::CodeAction action;
1197 action.title =
"Add expected-* diagnostic checks";
1198 action.kind = lsp::CodeAction::kQuickFix.str();
1201 switch (
diag.severity) {
1202 case llvm::lsp::DiagnosticSeverity::Error:
1205 case llvm::lsp::DiagnosticSeverity::Warning:
1206 severity =
"warning";
1213 std::vector<llvm::lsp::TextEdit> edits;
1214 chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1215 diag.message, edits);
1218 if (
diag.relatedInformation) {
1219 for (
auto &noteDiag : *
diag.relatedInformation) {
1220 if (noteDiag.location.uri != uri)
1222 diagPos = noteDiag.location.range.start;
1223 diagPos.line -= chunk.lineOffset;
1224 chunk.document.getCodeActionForDiagnostic(uri, diagPos,
"note",
1225 noteDiag.message, edits);
1229 for (llvm::lsp::TextEdit &edit : edits)
1230 chunk.adjustLocForChunkOffset(edit.range);
1232 action.edit.emplace();
1233 action.edit->changes[uri.uri().str()] = std::move(edits);
1234 action.diagnostics = {
diag};
1236 actions.emplace_back(std::move(action));
/// Converts the file to bytecode by delegating to its only chunk's document.
/// Returns an `LSPError` (RequestFailed) when the file has more or fewer than
/// one chunk.
1240llvm::Expected<lsp::MLIRConvertBytecodeResult>
1241MLIRTextFile::convertToBytecode() {
// Split files (more than one chunk) are rejected with a hint to remove the
// `// -----` separators.
1243 if (chunks.size() != 1) {
1244 return llvm::make_error<llvm::lsp::LSPError>(
1245 "unexpected split file, please remove all `// -----`",
1246 llvm::lsp::ErrorCode::RequestFailed);
1248 return chunks.front()->document.convertToBytecode();
/// Finds the chunk that `pos` falls in and rewrites `pos.line` in place so it
/// is relative to that chunk's `lineOffset`.
/// NOTE(review): the final return statement is not shown in this listing;
/// presumably the selected chunk is returned — confirm against the full source.
1251MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
// With a single chunk, `pos` is returned untouched alongside that chunk.
1252 if (chunks.size() == 1)
1253 return *chunks.front();
// Find the first chunk whose `lineOffset` is greater than `pos.line`
// (assumes `chunks` is ordered by `lineOffset` — TODO confirm).
1257 auto it = llvm::upper_bound(
1258 chunks, pos, [](
const lsp::Position &pos,
const auto &chunk) {
1259 return static_cast<uint64_t
>(pos.line) < chunk->lineOffset;
// Use the chunk just before that one, or the last chunk when none starts
// after `pos.line`.
1261 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
// Make the position chunk-relative.
1262 pos.line -= chunk.lineOffset;
1278 llvm::StringMap<std::unique_ptr<MLIRTextFile>>
files;
1286 :
impl(std::make_unique<
Impl>(registry_fn)) {}
1290 const URIForFile &uri, StringRef contents,
int64_t version,
1291 std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1292 impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1293 uri, contents, version,
impl->registry_fn, diagnostics);
1297 auto it =
impl->files.find(uri.file());
1298 if (it ==
impl->files.end())
1299 return std::nullopt;
1301 int64_t version = it->second->getVersion();
1302 impl->files.erase(it);
1307 const URIForFile &uri,
const Position &defPos,
1308 std::vector<llvm::lsp::Location> &locations) {
1309 auto fileIt =
impl->files.find(uri.file());
1310 if (fileIt !=
impl->files.end())
1311 fileIt->second->getLocationsOf(uri, defPos, locations);
1315 const URIForFile &uri,
const Position &pos,
1316 std::vector<llvm::lsp::Location> &references) {
1317 auto fileIt =
impl->files.find(uri.file());
1318 if (fileIt !=
impl->files.end())
1319 fileIt->second->findReferencesOf(uri, pos, references);
1323 const Position &hoverPos) {
1324 auto fileIt =
impl->files.find(uri.file());
1325 if (fileIt !=
impl->files.end())
1326 return fileIt->second->findHover(uri, hoverPos);
1327 return std::nullopt;
1331 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1332 auto fileIt =
impl->files.find(uri.file());
1333 if (fileIt !=
impl->files.end())
1334 fileIt->second->findDocumentSymbols(symbols);
1339 const Position &completePos) {
1340 auto fileIt =
impl->files.find(uri.file());
1341 if (fileIt !=
impl->files.end())
1342 return fileIt->second->getCodeCompletion(uri, completePos);
1343 return CompletionList();
1347 const CodeActionContext &context,
1348 std::vector<CodeAction> &actions) {
1349 auto fileIt =
impl->files.find(uri.file());
1350 if (fileIt !=
impl->files.end())
1351 fileIt->second->getCodeActions(uri, pos, context, actions);
1360 std::string errorMsg;
1370 &fallbackResourceMap);
1374 if (failed(
parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
1375 return llvm::make_error<llvm::lsp::LSPError>(
1376 "failed to parse bytecode source file: " + errorMsg,
1377 llvm::lsp::ErrorCode::RequestFailed);
1382 if (!llvm::hasSingleElement(parsedBlock)) {
1383 return llvm::make_error<llvm::lsp::LSPError>(
1384 "expected bytecode to contain a single top-level operation",
1385 llvm::lsp::ErrorCode::RequestFailed);
1397 nullptr, &fallbackResourceMap);
1399 llvm::raw_string_ostream os(
result.output);
1400 topOp->print(os, state);
1402 return std::move(
result);
1407 auto fileIt =
impl->files.find(uri.file());
1408 if (fileIt ==
impl->files.end()) {
1409 return llvm::make_error<llvm::lsp::LSPError>(
1410 "language server does not contain an entry for this source file",
1411 llvm::lsp::ErrorCode::RequestFailed);
1413 return fileIt->second->convertToBytecode();
static void collectLocationsFromLoc(Location loc, std::vector< lsp::Location > &locations, const lsp::URIForFile &uri)
Collect all of the locations from the given MLIR location that are not contained within the given URI...
static unsigned getBlockNumber(Block *block)
Given a block, return its position in its parent region.
static std::optional< unsigned > getResultNumberFromLoc(SMLoc loc)
Given a location pointing to a result, return the result number it refers to or std::nullopt if it re...
static std::optional< StringRef > getTextFromRange(SMRange range)
Given a source location range, return the text covered by the given range.
static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, SMRange *overlappedRange=nullptr)
Returns true if the given location is contained by the definition or one of the uses of the given SMD...
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, Diagnostic &diag, const lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
static std::optional< lsp::Location > getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc)
Returns a language server location from the given MLIR file location.
static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc={})
Given a block and source location, print the source name of the block to the given output stream.
static SMRange convertTokenLocToRange(SMLoc loc)
Returns the range of a lexical token given a SMLoc corresponding to the start of a token location.
static std::string diag(const llvm::Value &value)
void completeDialectName()
iterator_range< AttributeDefIterator > getAttributeAliasDefs() const
Return a range of the AttributeAliasDefinitions held by the current parser state.
iterator_range< BlockDefIterator > getBlockDefs() const
Return a range of the BlockDefinitions held by the current parser state.
const OperationDefinition * getOpDef(Operation *op) const
Return the definition for the given operation, or nullptr if the given operation does not have a definition.
const BlockDefinition * getBlockDef(Block *block) const
Return the definition for the given block, or nullptr if the given block does not have a definition.
iterator_range< OperationDefIterator > getOpDefs() const
Return a range of the OperationDefinitions held by the current parser state.
iterator_range< TypeDefIterator > getTypeAliasDefs() const
Return a range of the TypeAliasDefinitions held by the current parser state.
This class provides management for the lifetime of the state used when printing the IR.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
bool hasNoSuccessors()
Returns true if this block has no successors.
iterator_range< pred_iterator > getPredecessors()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
SuccessorRange getSuccessors()
bool hasNoPredecessors()
Return true if this block has no predecessors.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A fallback map containing external resources not explicitly handled by another parser/printer.
An instance of this location represents a tuple of file, line number, and column number.
StringAttr getFilename() const
unsigned getColumn() const
WalkResult walk(function_ref< WalkResult(Location)> walkFn)
Walk all of the locations nested directly under, and including, the current.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
std::vector< StringRef > getAvailableDialects()
Return information about all available dialects in the registry in this context.
void allowUnregisteredDialects(bool allow=true)
Enables creating operations in unregistered dialects.
Set of flags used to control the behavior of the various IR print methods (e.g.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
This class represents a configuration for the MLIR assembly parser.
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
void addOrUpdateDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add or update the document, with the provided version, at the given URI.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
llvm::Expected< MLIRConvertBytecodeResult > convertFromBytecode(const URIForFile &uri)
Convert the given bytecode file to the textual format.
llvm::Expected< MLIRConvertBytecodeResult > convertToBytecode(const URIForFile &uri)
Convert the given textual file to the bytecode format.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
MLIRServer(DialectRegistryFn registry_fn)
Construct a new server with the given dialect registry function.
void getCodeActions(const URIForFile &uri, const Range &pos, const CodeActionContext &context, std::vector< CodeAction > &actions)
Get the set of code actions within the file.
llvm::function_ref< DialectRegistry &(const llvm::lsp::URIForFile &uri)> DialectRegistryFn
SMRange convertTokenLocToRange(SMLoc loc, StringRef identifierChars="")
Returns the range of a lexical token given a SMLoc corresponding to the start of a token location.
Include the generated interface declarations.
const char *const kDefaultSplitMarker
LogicalResult parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, AsmParserState *asmState=nullptr, AsmParserCodeCompleteContext *codeCompleteContext=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::SetVector< T, Vector, Set, N > SetVector
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config={})
Write the bytecode for the given operation to the provided output stream.
This class represents the result of converting between MLIR's bytecode and textual format.
Impl(lsp::DialectRegistryFn registry_fn)
lsp::DialectRegistryFn registry_fn
The registry factory for containing dialects that can be recognized in parsed .mlir files.
llvm::StringMap< std::unique_ptr< MLIRTextFile > > files
The files held by the server, mapped by their URI file name.
StringRef name
The name of the attribute alias.
Attribute value
The value of the alias.
This class represents the information for a block definition within the input file.
Block * block
The block representing this definition.
SMDefinition definition
The source location for the block, i.e.
Operation * op
The operation representing this definition.
This class represents a definition within the source manager, containing its defining location and l...
SmallVector< SMRange > uses
The source location of all uses of the definition.
SMRange loc
The source location of the definition.
Type value
The value of the alias.
StringRef name
The name of the type alias.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...