20#include "llvm/ADT/StringExtras.h"
21#include "llvm/Support/Base64.h"
22#include "llvm/Support/LSP/Logging.h"
23#include "llvm/Support/SourceMgr.h"
40 lsp::URIForFile::fromFile(loc.
getFilename(), uriScheme);
42 llvm::lsp::Logger::error(
"Failed to create URI for file `{0}`: {1}",
44 llvm::toString(sourceURI.takeError()));
48 lsp::Position position;
49 position.line = loc.
getLine() - 1;
51 return lsp::Location{*sourceURI, lsp::Range(position)};
58static std::optional<lsp::Location>
60 StringRef uriScheme,
const lsp::URIForFile *uri =
nullptr) {
61 std::optional<lsp::Location> location;
67 std::optional<lsp::Location> sourceLoc =
69 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
70 location = *sourceLoc;
71 SMLoc loc = sourceMgr.FindLocForLineAndColumn(
76 location->range.end.character += 1;
78 auto lineCol = sourceMgr.getLineAndColumn(range->End);
79 location->range.end.character =
80 std::max(fileLoc.
getColumn() + 1, lineCol.second - 1);
92 std::vector<lsp::Location> &locations,
93 const lsp::URIForFile &uri) {
97 if (!fileLoc || !visitedLocs.insert(nestedLoc))
100 std::optional<lsp::Location> sourceLoc =
102 if (sourceLoc && sourceLoc->uri != uri)
103 locations.push_back(*sourceLoc);
112 return range.Start.getPointer() <= loc.getPointer() &&
113 loc.getPointer() <= range.End.getPointer();
120 SMRange *overlappedRange =
nullptr) {
124 *overlappedRange = def.
loc;
129 const auto *useIt = llvm::find_if(
130 def.
uses, [&](
const SMRange &range) { return contains(range, loc); });
131 if (useIt != def.
uses.end()) {
133 *overlappedRange = *useIt;
143 auto isIdentifierChar = [](
char c) {
144 return isalnum(c) || c ==
'%' || c ==
'$' || c ==
'.' || c ==
'_' ||
147 const char *curPtr = loc.getPointer();
148 while (isIdentifierChar(*curPtr))
157 const char *numberStart = ++curPtr;
158 while (llvm::isDigit(*curPtr))
160 StringRef numberStr(numberStart, curPtr - numberStart);
161 unsigned resultNumber = 0;
162 return numberStr.consumeInteger(10, resultNumber) ? std::optional<unsigned>()
169 if (!range.isValid())
171 const char *startPtr = range.Start.getPointer();
172 return StringRef(startPtr, range.End.getPointer() - startPtr);
180 if (text && text->starts_with(
"^")) {
196 const lsp::URIForFile &uri) {
197 lsp::Diagnostic lspDiag;
198 lspDiag.source =
"mlir";
202 lspDiag.category =
"Parse Error";
207 StringRef uriScheme = uri.scheme();
208 std::optional<lsp::Location> lspLocation =
211 lspDiag.range = lspLocation->range;
214 switch (
diag.getSeverity()) {
216 llvm_unreachable(
"expected notes to be handled separately");
218 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
221 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
224 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
227 lspDiag.message =
diag.str();
230 std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
232 lsp::Location noteLoc;
233 if (std::optional<lsp::Location> loc =
238 relatedDiags.emplace_back(noteLoc, note.str());
240 if (!relatedDiags.empty())
241 lspDiag.relatedInformation = std::move(relatedDiags);
254 MLIRDocument(MLIRContext &context,
const lsp::URIForFile &uri,
255 StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
256 MLIRDocument(
const MLIRDocument &) =
delete;
257 MLIRDocument &operator=(
const MLIRDocument &) =
delete;
263 void getLocationsOf(
const lsp::URIForFile &uri,
const lsp::Position &defPos,
264 std::vector<lsp::Location> &locations);
265 void findReferencesOf(
const lsp::URIForFile &uri,
const lsp::Position &pos,
266 std::vector<lsp::Location> &references);
272 std::optional<lsp::Hover> findHover(
const lsp::URIForFile &uri,
273 const lsp::Position &hoverPos);
274 std::optional<lsp::Hover>
275 buildHoverForOperation(SMRange hoverRange,
276 const AsmParserState::OperationDefinition &op);
277 lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
278 unsigned resultStart,
279 unsigned resultEnd, SMLoc posLoc);
280 lsp::Hover buildHoverForBlock(SMRange hoverRange,
281 const AsmParserState::BlockDefinition &block);
283 buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
284 const AsmParserState::BlockDefinition &block);
286 lsp::Hover buildHoverForAttributeAlias(
287 SMRange hoverRange,
const AsmParserState::AttributeAliasDefinition &attr);
289 buildHoverForTypeAlias(SMRange hoverRange,
290 const AsmParserState::TypeAliasDefinition &type);
296 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
297 void findDocumentSymbols(Operation *op,
298 std::vector<lsp::DocumentSymbol> &symbols);
304 lsp::CompletionList getCodeCompletion(
const lsp::URIForFile &uri,
305 const lsp::Position &completePos,
306 const DialectRegistry ®istry);
312 void getCodeActionForDiagnostic(
const lsp::URIForFile &uri,
313 lsp::Position &pos, StringRef severity,
315 std::vector<llvm::lsp::TextEdit> &edits);
321 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
329 AsmParserState asmState;
336 FallbackAsmResourceMap fallbackResourceMap;
339 llvm::SourceMgr sourceMgr;
343MLIRDocument::MLIRDocument(
MLIRContext &context,
const lsp::URIForFile &uri,
345 std::vector<lsp::Diagnostic> &diagnostics) {
346 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &
diag) {
351 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
353 llvm::lsp::Logger::error(
"Failed to create memory buffer for file",
358 ParserConfig
config(&context,
true,
359 &fallbackResourceMap);
360 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
364 asmState = AsmParserState();
365 fallbackResourceMap = FallbackAsmResourceMap();
374void MLIRDocument::getLocationsOf(
const lsp::URIForFile &uri,
375 const lsp::Position &defPos,
376 std::vector<lsp::Location> &locations) {
377 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
380 auto containsPosition = [&](
const AsmParserState::SMDefinition &def) {
383 locations.emplace_back(uri, sourceMgr, def.loc);
388 for (
const AsmParserState::OperationDefinition &op : asmState.
getOpDefs()) {
391 for (
const auto &
result : op.resultGroups)
392 if (containsPosition(
result.definition))
394 for (
const auto &symUse : op.symbolUses) {
396 locations.emplace_back(uri, sourceMgr, op.loc);
403 for (
const AsmParserState::BlockDefinition &block : asmState.
getBlockDefs()) {
404 if (containsPosition(block.definition))
406 for (
const AsmParserState::SMDefinition &arg : block.arguments)
407 if (containsPosition(arg))
412 for (
const AsmParserState::AttributeAliasDefinition &attr :
414 if (containsPosition(attr.definition))
417 for (
const AsmParserState::TypeAliasDefinition &type :
419 if (containsPosition(type.definition))
424void MLIRDocument::findReferencesOf(
const lsp::URIForFile &uri,
425 const lsp::Position &pos,
426 std::vector<lsp::Location> &references) {
429 auto appendSMDef = [&](
const AsmParserState::SMDefinition &def) {
430 references.emplace_back(uri, sourceMgr, def.loc);
431 for (
const SMRange &use : def.uses)
432 references.emplace_back(uri, sourceMgr, use);
435 SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
438 for (
const AsmParserState::OperationDefinition &op : asmState.
getOpDefs()) {
440 for (
const auto &
result : op.resultGroups)
441 appendSMDef(
result.definition);
442 for (
const auto &symUse : op.symbolUses)
444 references.emplace_back(uri, sourceMgr, symUse);
447 for (
const auto &
result : op.resultGroups)
449 return appendSMDef(
result.definition);
450 for (
const auto &symUse : op.symbolUses) {
453 for (
const auto &symUse : op.symbolUses)
454 references.emplace_back(uri, sourceMgr, symUse);
460 for (
const AsmParserState::BlockDefinition &block : asmState.
getBlockDefs()) {
462 return appendSMDef(block.definition);
464 for (
const AsmParserState::SMDefinition &arg : block.arguments)
466 return appendSMDef(arg);
470 for (
const AsmParserState::AttributeAliasDefinition &attr :
473 return appendSMDef(attr.definition);
475 for (
const AsmParserState::TypeAliasDefinition &type :
478 return appendSMDef(type.definition);
486std::optional<lsp::Hover>
487MLIRDocument::findHover(
const lsp::URIForFile &uri,
488 const lsp::Position &hoverPos) {
489 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
493 for (
const AsmParserState::OperationDefinition &op : asmState.
getOpDefs()) {
496 return buildHoverForOperation(op.loc, op);
499 for (
auto &use : op.symbolUses)
501 return buildHoverForOperation(use, op);
504 for (
unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
505 const auto &
result = op.resultGroups[i];
510 unsigned resultStart =
result.startIndex;
511 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
512 : op.resultGroups[i + 1].startIndex;
513 return buildHoverForOperationResult(hoverRange, op.op, resultStart,
519 for (
const AsmParserState::BlockDefinition &block : asmState.
getBlockDefs()) {
520 if (
isDefOrUse(block.definition, posLoc, &hoverRange))
521 return buildHoverForBlock(hoverRange, block);
523 for (
const auto &arg : llvm::enumerate(block.arguments)) {
524 if (!
isDefOrUse(arg.value(), posLoc, &hoverRange))
527 return buildHoverForBlockArgument(
528 hoverRange, block.block->
getArgument(arg.index()), block);
533 for (
const AsmParserState::AttributeAliasDefinition &attr :
535 if (
isDefOrUse(attr.definition, posLoc, &hoverRange))
536 return buildHoverForAttributeAlias(hoverRange, attr);
538 for (
const AsmParserState::TypeAliasDefinition &type :
540 if (
isDefOrUse(type.definition, posLoc, &hoverRange))
541 return buildHoverForTypeAlias(hoverRange, type);
547std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
548 SMRange hoverRange,
const AsmParserState::OperationDefinition &op) {
549 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
550 llvm::raw_string_ostream os(hover.contents.value);
554 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.
op))
555 os <<
" : " << symbol.getVisibility() <<
" @" << symbol.getName() <<
"";
558 os <<
"Generic Form:\n\n```mlir\n";
560 op.
op->
print(os, OpPrintingFlags()
561 .printGenericOpForm()
562 .elideLargeElementsAttrs()
569lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
571 unsigned resultStart,
574 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
575 llvm::raw_string_ostream os(hover.contents.value);
578 os <<
"Operation: \"" << op->
getName() <<
"\"\n\n";
583 if ((resultStart + *resultNumber) < resultEnd) {
584 resultStart += *resultNumber;
585 resultEnd = resultStart + 1;
590 if ((resultStart + 1) == resultEnd) {
591 os <<
"Result #" << resultStart <<
"\n\n"
594 os <<
"Result #[" << resultStart <<
", " << (resultEnd - 1) <<
"]\n\n"
596 llvm::interleaveComma(
597 op->
getResults().slice(resultStart, resultEnd), os,
598 [&](Value
result) { os <<
"`" << result.getType() <<
"`"; });
605MLIRDocument::buildHoverForBlock(SMRange hoverRange,
606 const AsmParserState::BlockDefinition &block) {
607 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
608 llvm::raw_string_ostream os(hover.contents.value);
611 auto printBlockToHover = [&](
Block *newBlock) {
612 if (
const auto *def = asmState.
getBlockDef(newBlock))
622 os <<
"Predecessors: ";
628 os <<
"Successors: ";
636lsp::Hover MLIRDocument::buildHoverForBlockArgument(
637 SMRange hoverRange, BlockArgument arg,
638 const AsmParserState::BlockDefinition &block) {
639 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
640 llvm::raw_string_ostream os(hover.contents.value);
647 <<
"Type: `" << arg.
getType() <<
"`\n\n";
652lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
653 SMRange hoverRange,
const AsmParserState::AttributeAliasDefinition &attr) {
654 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
655 llvm::raw_string_ostream os(hover.contents.value);
657 os <<
"Attribute Alias: \"" << attr.
name <<
"\n\n";
658 os <<
"Value: ```mlir\n" << attr.
value <<
"\n```\n\n";
663lsp::Hover MLIRDocument::buildHoverForTypeAlias(
664 SMRange hoverRange,
const AsmParserState::TypeAliasDefinition &type) {
665 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
666 llvm::raw_string_ostream os(hover.contents.value);
668 os <<
"Type Alias: \"" << type.
name <<
"\n\n";
669 os <<
"Value: ```mlir\n" << type.
value <<
"\n```\n\n";
678void MLIRDocument::findDocumentSymbols(
679 std::vector<lsp::DocumentSymbol> &symbols) {
680 for (Operation &op : parsedIR)
681 findDocumentSymbols(&op, symbols);
684void MLIRDocument::findDocumentSymbols(
685 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
686 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
689 if (
const AsmParserState::OperationDefinition *def = asmState.
getOpDef(op)) {
691 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
692 symbols.emplace_back(symbol.getName(),
693 isa<FunctionOpInterface>(op)
694 ? llvm::lsp::SymbolKind::Function
695 : llvm::lsp::SymbolKind::Class,
696 lsp::Range(sourceMgr, def->scopeLoc),
697 lsp::Range(sourceMgr, def->loc));
698 childSymbols = &symbols.back().children;
700 }
else if (op->
hasTrait<OpTrait::SymbolTable>()) {
703 llvm::lsp::SymbolKind::Namespace,
704 llvm::lsp::Range(sourceMgr, def->scopeLoc),
705 llvm::lsp::Range(sourceMgr, def->loc));
706 childSymbols = &symbols.back().children;
714 for (Operation &childOp : region.getOps())
715 findDocumentSymbols(&childOp, *childSymbols);
723class LSPCodeCompleteContext :
public AsmParserCodeCompleteContext {
725 LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
727 : AsmParserCodeCompleteContext(completeLoc),
728 completionList(completionList), ctx(ctx) {}
731 void completeDialectName(StringRef prefix)
final {
733 llvm::lsp::CompletionItem item(prefix + dialect,
734 llvm::lsp::CompletionItemKind::Module,
736 item.detail =
"dialect";
737 completionList.items.emplace_back(item);
743 void completeOperationName(StringRef dialectName)
final {
752 llvm::lsp::CompletionItem item(
753 op.getStringRef().drop_front(dialectName.size() + 1),
754 llvm::lsp::CompletionItemKind::Field,
756 item.detail =
"operation";
757 completionList.items.emplace_back(item);
763 void appendSSAValueCompletion(StringRef name, std::string typeData)
final {
765 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] ==
'%';
767 llvm::lsp::CompletionItem item(name,
768 llvm::lsp::CompletionItemKind::Variable);
770 item.insertText = name.drop_front(1).str();
771 item.detail = std::move(typeData);
772 completionList.items.emplace_back(item);
777 void appendBlockCompletion(StringRef name)
final {
779 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] ==
'^';
781 llvm::lsp::CompletionItem item(name, llvm::lsp::CompletionItemKind::Field);
783 item.insertText = name.drop_front(1).str();
784 completionList.items.emplace_back(item);
788 void completeExpectedTokens(ArrayRef<StringRef> tokens,
bool optional)
final {
789 for (StringRef token : tokens) {
790 llvm::lsp::CompletionItem item(token,
791 llvm::lsp::CompletionItemKind::Keyword,
793 item.detail = optional ?
"optional" :
"";
794 completionList.items.emplace_back(item);
799 void completeAttribute(
const llvm::StringMap<Attribute> &aliases)
override {
800 appendSimpleCompletions({
"affine_set",
"affine_map",
"dense",
801 "dense_resource",
"false",
"loc",
"sparse",
"true",
803 llvm::lsp::CompletionItemKind::Field,
806 completeDialectName(
"#");
807 completeAliases(aliases,
"#");
809 void completeDialectAttributeOrAlias(
810 const llvm::StringMap<Attribute> &aliases)
override {
811 completeDialectName();
812 completeAliases(aliases);
816 void completeType(
const llvm::StringMap<Type> &aliases)
override {
818 appendSimpleCompletions({
"memref",
"tensor",
"complex",
"tuple",
"vector",
819 "bf16",
"f16",
"f32",
"f64",
"f80",
"f128",
821 llvm::lsp::CompletionItemKind::Field,
825 for (StringRef type : {
"i",
"si",
"ui"}) {
826 llvm::lsp::CompletionItem item(type +
"<N>",
827 llvm::lsp::CompletionItemKind::Field,
829 item.insertText = type.str();
830 completionList.items.emplace_back(item);
834 completeDialectName(
"!");
835 completeAliases(aliases,
"!");
838 completeDialectTypeOrAlias(
const llvm::StringMap<Type> &aliases)
override {
839 completeDialectName();
840 completeAliases(aliases);
844 template <
typename T>
845 void completeAliases(
const llvm::StringMap<T> &aliases,
846 StringRef prefix =
"") {
847 for (
const auto &alias : aliases) {
848 llvm::lsp::CompletionItem item(prefix + alias.getKey(),
849 llvm::lsp::CompletionItemKind::Field,
851 llvm::raw_string_ostream(item.detail) <<
"alias: " << alias.getValue();
852 completionList.items.emplace_back(item);
857 void appendSimpleCompletions(ArrayRef<StringRef> completions,
858 llvm::lsp::CompletionItemKind kind,
859 StringRef sortText =
"") {
860 for (StringRef completion : completions)
861 completionList.items.emplace_back(completion, kind, sortText);
865 lsp::CompletionList &completionList;
871MLIRDocument::getCodeCompletion(
const lsp::URIForFile &uri,
872 const lsp::Position &completePos,
873 const DialectRegistry ®istry) {
874 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
875 if (!posLoc.isValid())
876 return lsp::CompletionList();
880 MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
881 tmpContext.allowUnregisteredDialects();
882 lsp::CompletionList completionList;
883 LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
887 AsmParserState tmpState;
889 &lspCompleteContext);
890 return completionList;
897void MLIRDocument::getCodeActionForDiagnostic(
898 const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
899 StringRef message, std::vector<llvm::lsp::TextEdit> &edits) {
903 if (message.starts_with(
"see current operation: "))
907 const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
908 const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
911 StringRef line(lineStart, pos.character);
915 llvm::lsp::TextEdit edit;
916 edit.range = lsp::Range(lsp::Position(pos.line, 0));
919 size_t indent = line.find_first_not_of(
' ');
920 if (indent == StringRef::npos)
921 indent = line.size();
923 edit.newText.append(indent,
' ');
924 llvm::raw_string_ostream(edit.newText)
925 <<
"// expected-" << severity <<
" @below {{" << message <<
"}}\n";
926 edits.emplace_back(std::move(edit));
933llvm::Expected<lsp::MLIRConvertBytecodeResult>
934MLIRDocument::convertToBytecode() {
937 if (!llvm::hasSingleElement(parsedIR)) {
938 if (parsedIR.
empty()) {
939 return llvm::make_error<llvm::lsp::LSPError>(
940 "expected a single and valid top-level operation, please ensure "
941 "there are no errors",
942 llvm::lsp::ErrorCode::RequestFailed);
944 return llvm::make_error<llvm::lsp::LSPError>(
945 "expected a single top-level operation",
946 llvm::lsp::ErrorCode::RequestFailed);
949 lsp::MLIRConvertBytecodeResult
result;
951 BytecodeWriterConfig writerConfig(fallbackResourceMap);
953 std::string rawBytecodeBuffer;
954 llvm::raw_string_ostream os(rawBytecodeBuffer);
957 result.output = llvm::encodeBase64(rawBytecodeBuffer);
968struct MLIRTextFileChunk {
969 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
970 const lsp::URIForFile &uri, StringRef contents,
971 std::vector<lsp::Diagnostic> &diagnostics)
972 : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
976 void adjustLocForChunkOffset(lsp::Range &range) {
977 adjustLocForChunkOffset(range.start);
978 adjustLocForChunkOffset(range.end);
982 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
987 MLIRDocument document;
999 MLIRTextFile(
const lsp::URIForFile &uri, StringRef fileContents,
1001 std::vector<lsp::Diagnostic> &diagnostics);
1004 int64_t getVersion()
const {
return version; }
1010 void getLocationsOf(
const lsp::URIForFile &uri, lsp::Position defPos,
1011 std::vector<lsp::Location> &locations);
1012 void findReferencesOf(
const lsp::URIForFile &uri, lsp::Position pos,
1013 std::vector<lsp::Location> &references);
1014 std::optional<lsp::Hover> findHover(
const lsp::URIForFile &uri,
1015 lsp::Position hoverPos);
1016 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1017 lsp::CompletionList getCodeCompletion(
const lsp::URIForFile &uri,
1018 lsp::Position completePos);
1019 void getCodeActions(
const lsp::URIForFile &uri,
const lsp::Range &pos,
1020 const lsp::CodeActionContext &context,
1021 std::vector<lsp::CodeAction> &actions);
1022 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
1028 MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
1031 MLIRContext context;
1034 std::string contents;
1040 int64_t totalNumLines = 0;
1044 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
1048MLIRTextFile::MLIRTextFile(
const lsp::URIForFile &uri, StringRef fileContents,
1050 std::vector<lsp::Diagnostic> &diagnostics)
1051 : context(registryFn(uri), MLIRContext::Threading::
DISABLED),
1052 contents(fileContents.str()), version(version) {
1058 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
1059 context, 0, uri, subContents.front(), diagnostics));
1061 uint64_t lineOffset = subContents.front().count(
'\n');
1062 for (StringRef docContents : llvm::drop_begin(subContents)) {
1063 unsigned currentNumDiags = diagnostics.size();
1064 auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
1065 docContents, diagnostics);
1066 lineOffset += docContents.count(
'\n');
1070 for (lsp::Diagnostic &
diag :
1071 llvm::drop_begin(diagnostics, currentNumDiags)) {
1072 chunk->adjustLocForChunkOffset(
diag.range);
1074 if (!
diag.relatedInformation)
1076 for (
auto &it : *
diag.relatedInformation)
1077 if (it.location.uri == uri)
1078 chunk->adjustLocForChunkOffset(it.location.range);
1080 chunks.emplace_back(std::move(chunk));
1082 totalNumLines = lineOffset;
1085void MLIRTextFile::getLocationsOf(
const lsp::URIForFile &uri,
1086 lsp::Position defPos,
1087 std::vector<lsp::Location> &locations) {
1088 MLIRTextFileChunk &chunk = getChunkFor(defPos);
1089 chunk.document.getLocationsOf(uri, defPos, locations);
1092 if (chunk.lineOffset == 0)
1094 for (lsp::Location &loc : locations)
1096 chunk.adjustLocForChunkOffset(loc.range);
1099void MLIRTextFile::findReferencesOf(
const lsp::URIForFile &uri,
1101 std::vector<lsp::Location> &references) {
1102 MLIRTextFileChunk &chunk = getChunkFor(pos);
1103 chunk.document.findReferencesOf(uri, pos, references);
1106 if (chunk.lineOffset == 0)
1108 for (lsp::Location &loc : references)
1110 chunk.adjustLocForChunkOffset(loc.range);
1113std::optional<lsp::Hover> MLIRTextFile::findHover(
const lsp::URIForFile &uri,
1114 lsp::Position hoverPos) {
1115 MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
1116 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1119 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1120 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1124void MLIRTextFile::findDocumentSymbols(
1125 std::vector<lsp::DocumentSymbol> &symbols) {
1126 if (chunks.size() == 1)
1127 return chunks.front()->document.findDocumentSymbols(symbols);
1131 for (
unsigned i = 0, e = chunks.size(); i < e; ++i) {
1132 MLIRTextFileChunk &chunk = *chunks[i];
1133 lsp::Position startPos(chunk.lineOffset);
1134 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1135 : chunks[i + 1]->lineOffset);
1136 lsp::DocumentSymbol symbol(
"<file-split-" + Twine(i) +
">",
1137 llvm::lsp::SymbolKind::Namespace,
1138 lsp::Range(startPos, endPos),
1139 lsp::Range(startPos));
1140 chunk.document.findDocumentSymbols(symbol.children);
1144 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1145 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1146 symbolsToFix.push_back(&childSymbol);
1148 while (!symbolsToFix.empty()) {
1149 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1150 chunk.adjustLocForChunkOffset(symbol->range);
1151 chunk.adjustLocForChunkOffset(symbol->selectionRange);
1153 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1154 symbolsToFix.push_back(&childSymbol);
1159 symbols.emplace_back(std::move(symbol));
1163lsp::CompletionList MLIRTextFile::getCodeCompletion(
const lsp::URIForFile &uri,
1164 lsp::Position completePos) {
1165 MLIRTextFileChunk &chunk = getChunkFor(completePos);
1166 lsp::CompletionList completionList = chunk.document.getCodeCompletion(
1170 for (llvm::lsp::CompletionItem &item : completionList.items) {
1172 chunk.adjustLocForChunkOffset(item.textEdit->range);
1173 for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
1174 chunk.adjustLocForChunkOffset(edit.range);
1176 return completionList;
1179void MLIRTextFile::getCodeActions(
const lsp::URIForFile &uri,
1180 const lsp::Range &pos,
1181 const lsp::CodeActionContext &context,
1182 std::vector<lsp::CodeAction> &actions) {
1184 for (
auto &
diag : context.diagnostics) {
1185 if (
diag.source !=
"mlir")
1187 lsp::Position diagPos =
diag.range.start;
1188 MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1191 lsp::CodeAction action;
1192 action.title =
"Add expected-* diagnostic checks";
1193 action.kind = lsp::CodeAction::kQuickFix.str();
1196 switch (
diag.severity) {
1197 case llvm::lsp::DiagnosticSeverity::Error:
1200 case llvm::lsp::DiagnosticSeverity::Warning:
1201 severity =
"warning";
1208 std::vector<llvm::lsp::TextEdit> edits;
1209 chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1210 diag.message, edits);
1213 if (
diag.relatedInformation) {
1214 for (
auto ¬eDiag : *
diag.relatedInformation) {
1215 if (noteDiag.location.uri != uri)
1217 diagPos = noteDiag.location.range.start;
1218 diagPos.line -= chunk.lineOffset;
1219 chunk.document.getCodeActionForDiagnostic(uri, diagPos,
"note",
1220 noteDiag.message, edits);
1224 for (llvm::lsp::TextEdit &edit : edits)
1225 chunk.adjustLocForChunkOffset(edit.range);
1227 action.edit.emplace();
1228 action.edit->changes[uri.uri().str()] = std::move(edits);
1229 action.diagnostics = {
diag};
1231 actions.emplace_back(std::move(action));
1235llvm::Expected<lsp::MLIRConvertBytecodeResult>
1236MLIRTextFile::convertToBytecode() {
1238 if (chunks.size() != 1) {
1239 return llvm::make_error<llvm::lsp::LSPError>(
1240 "unexpected split file, please remove all `// -----`",
1241 llvm::lsp::ErrorCode::RequestFailed);
1243 return chunks.front()->document.convertToBytecode();
1246MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1247 if (chunks.size() == 1)
1248 return *chunks.front();
1252 auto it = llvm::upper_bound(
1253 chunks, pos, [](
const lsp::Position &pos,
const auto &chunk) {
1254 return static_cast<uint64_t
>(pos.line) < chunk->lineOffset;
1256 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1257 pos.line -= chunk.lineOffset;
1273 llvm::StringMap<std::unique_ptr<MLIRTextFile>>
files;
1281 :
impl(std::make_unique<
Impl>(registryFn)) {}
1285 const URIForFile &uri, StringRef contents,
int64_t version,
1286 std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1287 impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1288 uri, contents, version,
impl->registryFn, diagnostics);
1292 auto it =
impl->files.find(uri.file());
1293 if (it ==
impl->files.end())
1294 return std::nullopt;
1296 int64_t version = it->second->getVersion();
1297 impl->files.erase(it);
1302 const URIForFile &uri,
const Position &defPos,
1303 std::vector<llvm::lsp::Location> &locations) {
1304 auto fileIt =
impl->files.find(uri.file());
1305 if (fileIt !=
impl->files.end())
1306 fileIt->second->getLocationsOf(uri, defPos, locations);
1310 const URIForFile &uri,
const Position &pos,
1311 std::vector<llvm::lsp::Location> &references) {
1312 auto fileIt =
impl->files.find(uri.file());
1313 if (fileIt !=
impl->files.end())
1314 fileIt->second->findReferencesOf(uri, pos, references);
1318 const Position &hoverPos) {
1319 auto fileIt =
impl->files.find(uri.file());
1320 if (fileIt !=
impl->files.end())
1321 return fileIt->second->findHover(uri, hoverPos);
1322 return std::nullopt;
1326 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1327 auto fileIt =
impl->files.find(uri.file());
1328 if (fileIt !=
impl->files.end())
1329 fileIt->second->findDocumentSymbols(symbols);
1334 const Position &completePos) {
1335 auto fileIt =
impl->files.find(uri.file());
1336 if (fileIt !=
impl->files.end())
1337 return fileIt->second->getCodeCompletion(uri, completePos);
1338 return CompletionList();
1342 const CodeActionContext &context,
1343 std::vector<CodeAction> &actions) {
1344 auto fileIt =
impl->files.find(uri.file());
1345 if (fileIt !=
impl->files.end())
1346 fileIt->second->getCodeActions(uri, pos, context, actions);
1355 std::string errorMsg;
1365 &fallbackResourceMap);
1369 if (failed(
parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
1370 return llvm::make_error<llvm::lsp::LSPError>(
1371 "failed to parse bytecode source file: " + errorMsg,
1372 llvm::lsp::ErrorCode::RequestFailed);
1377 if (!llvm::hasSingleElement(parsedBlock)) {
1378 return llvm::make_error<llvm::lsp::LSPError>(
1379 "expected bytecode to contain a single top-level operation",
1380 llvm::lsp::ErrorCode::RequestFailed);
1392 nullptr, &fallbackResourceMap);
1394 llvm::raw_string_ostream os(
result.output);
1395 topOp->print(os, state);
1397 return std::move(
result);
1402 auto fileIt =
impl->files.find(uri.file());
1403 if (fileIt ==
impl->files.end()) {
1404 return llvm::make_error<llvm::lsp::LSPError>(
1405 "language server does not contain an entry for this source file",
1406 llvm::lsp::ErrorCode::RequestFailed);
1408 return fileIt->second->convertToBytecode();
static void collectLocationsFromLoc(Location loc, std::vector< lsp::Location > &locations, const lsp::URIForFile &uri)
Collect all of the locations from the given MLIR location that are not contained within the given URI...
static std::optional< unsigned > getResultNumberFromLoc(SMLoc loc)
Given a location pointing to a result, return the result number it refers to or std::nullopt if it re...
static std::optional< StringRef > getTextFromRange(SMRange range)
Given a source location range, return the text covered by the given range.
static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, SMRange *overlappedRange=nullptr)
Returns true if the given location is contained by the definition or one of the uses of the given SMD...
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, Diagnostic &diag, const lsp::URIForFile &uri)
Convert the given MLIR diagnostic to the LSP form.
static std::optional< lsp::Location > getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc)
Returns a language server location from the given MLIR file location.
static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc={})
Given a block and source location, print the source name of the block to the given output stream.
static SMRange convertTokenLocToRange(SMLoc loc)
Returns the range of a lexical token given a SMLoc corresponding to the start of an token location.
static std::string diag(const llvm::Value &value)
void completeDialectName()
iterator_range< AttributeDefIterator > getAttributeAliasDefs() const
Return a range of the AttributeAliasDefinitions held by the current parser state.
iterator_range< BlockDefIterator > getBlockDefs() const
Return a range of the BlockDefinitions held by the current parser state.
const OperationDefinition * getOpDef(Operation *op) const
Return the definition for the given operation, or nullptr if the given operation does not have a defi...
const BlockDefinition * getBlockDef(Block *block) const
Return the definition for the given block, or nullptr if the given block does not have a definition.
iterator_range< OperationDefIterator > getOpDefs() const
Return a range of the OperationDefinitions held by the current parser state.
iterator_range< TypeDefIterator > getTypeAliasDefs() const
Return a range of the TypeAliasDefinitions held by the current parser state.
This class provides management for the lifetime of the state used when printing the IR.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
bool hasNoSuccessors()
Returns true if this blocks has no successors.
iterator_range< pred_iterator > getPredecessors()
SuccessorRange getSuccessors()
bool hasNoPredecessors()
Return true if this block has no predecessors.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
unsigned computeBlockNumber()
Compute the position of this block within its parent region using an O(N) linear scan.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A fallback map containing external resources not explicitly handled by another parser/printer.
An instance of this location represents a tuple of file, line number, and column number.
StringAttr getFilename() const
unsigned getColumn() const
WalkResult walk(function_ref< WalkResult(Location)> walkFn)
Walk all of the locations nested directly under, and including, the current.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
std::vector< StringRef > getAvailableDialects()
Return information about all available dialects in the registry in this context.
void allowUnregisteredDialects(bool allow=true)
Enables creating operations in unregistered dialects.
Set of flags used to control the behavior of the various IR print methods (e.g.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
This class represents a configuration for the MLIR assembly parser.
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
void addOrUpdateDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add or update the document, with the provided version, at the given URI.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
llvm::Expected< MLIRConvertBytecodeResult > convertFromBytecode(const URIForFile &uri)
Convert the given bytecode file to the textual format.
llvm::Expected< MLIRConvertBytecodeResult > convertToBytecode(const URIForFile &uri)
Convert the given textual file to the bytecode format.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
MLIRServer(DialectRegistryFn registry_fn)
Construct a new server with the given dialect registry function.
void getCodeActions(const URIForFile &uri, const Range &pos, const CodeActionContext &context, std::vector< CodeAction > &actions)
Get the set of code actions within the file.
llvm::function_ref< DialectRegistry &(const llvm::lsp::URIForFile &uri)> DialectRegistryFn
SMRange convertTokenLocToRange(SMLoc loc, StringRef identifierChars="")
Returns the range of a lexical token given a SMLoc corresponding to the start of an token location.
Include the generated interface declarations.
const char *const kDefaultSplitMarker
LogicalResult parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, AsmParserState *asmState=nullptr, AsmParserCodeCompleteContext *codeCompleteContext=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::SetVector< T, Vector, Set, N > SetVector
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config={})
Write the bytecode for the given operation to the provided output stream.
This class represents the result of converting between MLIR's bytecode and textual format.
lsp::DialectRegistryFn registryFn
The registry factory for containing dialects that can be recognized in parsed .mlir files.
llvm::StringMap< std::unique_ptr< MLIRTextFile > > files
The files held by the server, mapped by their URI file name.
Impl(lsp::DialectRegistryFn registryFn)
StringRef name
The name of the attribute alias.
Attribute value
The value of the alias.
This class represents the information for a block definition within the input file.
Block * block
The block representing this definition.
SMDefinition definition
The source location for the block, i.e.
Operation * op
The operation representing this definition.
This class represents a definition within the source manager, containing it's defining location and l...
SmallVector< SMRange > uses
The source location of all uses of the definition.
SMRange loc
The source location of the definition.
Type value
The value of the alias.
StringRef name
The name of the attribute alias.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...