20#include "llvm/ADT/StringExtras.h"
21#include "llvm/Support/Base64.h"
22#include "llvm/Support/FileSystem.h"
23#include "llvm/Support/LSP/Logging.h"
24#include "llvm/Support/Path.h"
25#include "llvm/Support/SourceMgr.h"
39static std::optional<lsp::Location>
41 StringRef workspaceRoot) {
46 if (!llvm::sys::path::is_absolute(filename) && !filename.starts_with(
"/") &&
47 !filename.starts_with(
"\\")) {
48 if (!workspaceRoot.empty())
49 llvm::sys::path::append(absPath, workspaceRoot, filename);
52 llvm::sys::fs::make_absolute(absPath);
57 lsp::URIForFile::fromFile(filename, uriScheme);
59 llvm::lsp::Logger::error(
"Failed to create URI for file `{0}`: {1}",
60 filename, llvm::toString(sourceURI.takeError()));
64 lsp::Position position;
65 position.line = loc.
getLine() - 1;
67 return lsp::Location{*sourceURI, lsp::Range(position)};
74static std::optional<lsp::Location>
76 StringRef uriScheme, StringRef workspaceRoot,
77 const lsp::URIForFile *uri =
nullptr) {
78 std::optional<lsp::Location> location;
80 auto fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
84 std::optional<lsp::Location> sourceLoc =
86 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
87 location = *sourceLoc;
88 SMLoc loc = sourceMgr.FindLocForLineAndColumn(
89 sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
93 location->range.end.character += 1;
95 auto lineCol = sourceMgr.getLineAndColumn(range->End);
96 location->range.end.character =
97 std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
109 std::vector<lsp::Location> &locations,
110 const lsp::URIForFile &uri,
111 StringRef workspaceRoot) {
115 if (!fileLoc || !visitedLocs.insert(nestedLoc))
118 std::optional<lsp::Location> sourceLoc =
120 if (sourceLoc && sourceLoc->uri != uri)
121 locations.push_back(*sourceLoc);
130 return range.Start.getPointer() <= loc.getPointer() &&
131 loc.getPointer() <= range.End.getPointer();
138 SMRange *overlappedRange =
nullptr) {
142 *overlappedRange = def.
loc;
147 const auto *useIt = llvm::find_if(
148 def.
uses, [&](
const SMRange &range) { return contains(range, loc); });
149 if (useIt != def.
uses.end()) {
151 *overlappedRange = *useIt;
161 auto isIdentifierChar = [](
char c) {
162 return isalnum(c) || c ==
'%' || c ==
'$' || c ==
'.' || c ==
'_' ||
165 const char *curPtr = loc.getPointer();
166 while (isIdentifierChar(*curPtr))
175 const char *numberStart = ++curPtr;
176 while (llvm::isDigit(*curPtr))
178 StringRef numberStr(numberStart, curPtr - numberStart);
179 unsigned resultNumber = 0;
180 return numberStr.consumeInteger(10, resultNumber) ? std::optional<unsigned>()
187 if (!range.isValid())
189 const char *startPtr = range.Start.getPointer();
190 return StringRef(startPtr, range.End.getPointer() - startPtr);
198 if (text && text->starts_with(
"^")) {
214 const lsp::URIForFile &uri,
215 StringRef workspaceRoot) {
216 lsp::Diagnostic lspDiag;
217 lspDiag.source =
"mlir";
221 lspDiag.category =
"Parse Error";
226 StringRef uriScheme = uri.scheme();
228 sourceMgr,
diag.getLocation(), uriScheme, workspaceRoot, &uri);
230 lspDiag.range = lspLocation->range;
233 switch (
diag.getSeverity()) {
235 llvm_unreachable(
"expected notes to be handled separately");
237 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
240 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
243 lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
246 lspDiag.message =
diag.str();
249 std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
251 lsp::Location noteLoc;
253 sourceMgr, note.getLocation(), uriScheme, workspaceRoot))
257 relatedDiags.emplace_back(noteLoc, note.str());
259 if (!relatedDiags.empty())
260 lspDiag.relatedInformation = std::move(relatedDiags);
273 MLIRDocument(MLIRContext &context,
const lsp::URIForFile &uri,
274 StringRef contents, StringRef workspaceRoot,
275 std::vector<lsp::Diagnostic> &diagnostics);
276 MLIRDocument(
const MLIRDocument &) =
delete;
277 MLIRDocument &operator=(
const MLIRDocument &) =
delete;
283 void getLocationsOf(
const lsp::URIForFile &uri,
const lsp::Position &defPos,
284 std::vector<lsp::Location> &locations);
285 void findReferencesOf(
const lsp::URIForFile &uri,
const lsp::Position &pos,
286 std::vector<lsp::Location> &references);
292 std::optional<lsp::Hover> findHover(
const lsp::URIForFile &uri,
293 const lsp::Position &hoverPos);
294 std::optional<lsp::Hover>
295 buildHoverForOperation(SMRange hoverRange,
296 const AsmParserState::OperationDefinition &op);
297 lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
298 unsigned resultStart,
299 unsigned resultEnd, SMLoc posLoc);
300 lsp::Hover buildHoverForBlock(SMRange hoverRange,
301 const AsmParserState::BlockDefinition &block);
303 buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
304 const AsmParserState::BlockDefinition &block);
306 lsp::Hover buildHoverForAttributeAlias(
307 SMRange hoverRange,
const AsmParserState::AttributeAliasDefinition &attr);
309 buildHoverForTypeAlias(SMRange hoverRange,
310 const AsmParserState::TypeAliasDefinition &type);
316 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
317 void findDocumentSymbols(Operation *op,
318 std::vector<lsp::DocumentSymbol> &symbols);
324 lsp::CompletionList getCodeCompletion(
const lsp::URIForFile &uri,
325 const lsp::Position &completePos,
326 const DialectRegistry ®istry);
332 void getCodeActionForDiagnostic(
const lsp::URIForFile &uri,
333 lsp::Position &pos, StringRef severity,
335 std::vector<llvm::lsp::TextEdit> &edits);
341 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
349 AsmParserState asmState;
356 FallbackAsmResourceMap fallbackResourceMap;
359 llvm::SourceMgr sourceMgr;
362 std::string workspaceRoot;
366MLIRDocument::MLIRDocument(
MLIRContext &context,
const lsp::URIForFile &uri,
367 StringRef contents, StringRef workspaceRoot,
368 std::vector<lsp::Diagnostic> &diagnostics)
369 : workspaceRoot(workspaceRoot.str()) {
371 diagnostics.push_back(
376 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
378 llvm::lsp::Logger::error(
"Failed to create memory buffer for file",
384 &fallbackResourceMap);
385 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
399void MLIRDocument::getLocationsOf(
const lsp::URIForFile &uri,
400 const lsp::Position &defPos,
401 std::vector<lsp::Location> &locations) {
402 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
405 auto containsPosition = [&](
const AsmParserState::SMDefinition &def) {
408 locations.emplace_back(uri, sourceMgr, def.loc);
413 for (
const AsmParserState::OperationDefinition &op : asmState.
getOpDefs()) {
417 for (
const auto &
result : op.resultGroups)
418 if (containsPosition(
result.definition))
421 for (
const auto &symUse : op.symbolUses) {
423 locations.emplace_back(uri, sourceMgr, op.loc);
431 for (
const AsmParserState::BlockDefinition &block : asmState.
getBlockDefs()) {
432 if (containsPosition(block.definition))
434 for (
const AsmParserState::SMDefinition &arg : block.arguments)
435 if (containsPosition(arg))
440 for (
const AsmParserState::AttributeAliasDefinition &attr :
442 if (containsPosition(attr.definition))
445 for (
const AsmParserState::TypeAliasDefinition &type :
447 if (containsPosition(type.definition))
452void MLIRDocument::findReferencesOf(
const lsp::URIForFile &uri,
453 const lsp::Position &pos,
454 std::vector<lsp::Location> &references) {
457 auto appendSMDef = [&](
const AsmParserState::SMDefinition &def) {
458 references.emplace_back(uri, sourceMgr, def.loc);
459 for (
const SMRange &use : def.uses)
460 references.emplace_back(uri, sourceMgr, use);
463 SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
466 for (
const AsmParserState::OperationDefinition &op : asmState.
getOpDefs()) {
468 for (
const auto &
result : op.resultGroups)
469 appendSMDef(
result.definition);
470 for (
const auto &symUse : op.symbolUses)
472 references.emplace_back(uri, sourceMgr, symUse);
475 for (
const auto &
result : op.resultGroups)
477 return appendSMDef(
result.definition);
478 for (
const auto &symUse : op.symbolUses) {
481 for (
const auto &symUse : op.symbolUses)
482 references.emplace_back(uri, sourceMgr, symUse);
488 for (
const AsmParserState::BlockDefinition &block : asmState.
getBlockDefs()) {
490 return appendSMDef(block.definition);
492 for (
const AsmParserState::SMDefinition &arg : block.arguments)
494 return appendSMDef(arg);
498 for (
const AsmParserState::AttributeAliasDefinition &attr :
501 return appendSMDef(attr.definition);
503 for (
const AsmParserState::TypeAliasDefinition &type :
506 return appendSMDef(type.definition);
514std::optional<lsp::Hover>
515MLIRDocument::findHover(
const lsp::URIForFile &uri,
516 const lsp::Position &hoverPos) {
517 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
521 for (
const AsmParserState::OperationDefinition &op : asmState.
getOpDefs()) {
524 return buildHoverForOperation(op.loc, op);
527 for (
auto &use : op.symbolUses)
529 return buildHoverForOperation(use, op);
532 for (
unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
533 const auto &
result = op.resultGroups[i];
538 unsigned resultStart =
result.startIndex;
539 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
540 : op.resultGroups[i + 1].startIndex;
541 return buildHoverForOperationResult(hoverRange, op.op, resultStart,
547 for (
const AsmParserState::BlockDefinition &block : asmState.
getBlockDefs()) {
548 if (
isDefOrUse(block.definition, posLoc, &hoverRange))
549 return buildHoverForBlock(hoverRange, block);
551 for (
const auto &arg : llvm::enumerate(block.arguments)) {
552 if (!
isDefOrUse(arg.value(), posLoc, &hoverRange))
555 return buildHoverForBlockArgument(
556 hoverRange, block.block->
getArgument(arg.index()), block);
561 for (
const AsmParserState::AttributeAliasDefinition &attr :
563 if (
isDefOrUse(attr.definition, posLoc, &hoverRange))
564 return buildHoverForAttributeAlias(hoverRange, attr);
566 for (
const AsmParserState::TypeAliasDefinition &type :
568 if (
isDefOrUse(type.definition, posLoc, &hoverRange))
569 return buildHoverForTypeAlias(hoverRange, type);
575std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
576 SMRange hoverRange,
const AsmParserState::OperationDefinition &op) {
577 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
578 llvm::raw_string_ostream os(hover.contents.value);
582 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.
op))
583 os <<
" : " << symbol.getVisibility() <<
" @" << symbol.getName() <<
"";
586 os <<
"Generic Form:\n\n```mlir\n";
588 op.
op->
print(os, OpPrintingFlags()
589 .printGenericOpForm()
590 .elideLargeElementsAttrs()
597lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
599 unsigned resultStart,
602 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
603 llvm::raw_string_ostream os(hover.contents.value);
606 os <<
"Operation: \"" << op->
getName() <<
"\"\n\n";
611 if ((resultStart + *resultNumber) < resultEnd) {
612 resultStart += *resultNumber;
613 resultEnd = resultStart + 1;
618 if ((resultStart + 1) == resultEnd) {
619 os <<
"Result #" << resultStart <<
"\n\n"
622 os <<
"Result #[" << resultStart <<
", " << (resultEnd - 1) <<
"]\n\n"
624 llvm::interleaveComma(
625 op->
getResults().slice(resultStart, resultEnd), os,
626 [&](Value
result) { os <<
"`" << result.getType() <<
"`"; });
633MLIRDocument::buildHoverForBlock(SMRange hoverRange,
634 const AsmParserState::BlockDefinition &block) {
635 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
636 llvm::raw_string_ostream os(hover.contents.value);
639 auto printBlockToHover = [&](
Block *newBlock) {
640 if (
const auto *def = asmState.
getBlockDef(newBlock))
650 os <<
"Predecessors: ";
656 os <<
"Successors: ";
664lsp::Hover MLIRDocument::buildHoverForBlockArgument(
665 SMRange hoverRange, BlockArgument arg,
666 const AsmParserState::BlockDefinition &block) {
667 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
668 llvm::raw_string_ostream os(hover.contents.value);
675 <<
"Type: `" << arg.
getType() <<
"`\n\n";
680lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
681 SMRange hoverRange,
const AsmParserState::AttributeAliasDefinition &attr) {
682 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
683 llvm::raw_string_ostream os(hover.contents.value);
685 os <<
"Attribute Alias: \"" << attr.
name <<
"\n\n";
686 os <<
"Value: ```mlir\n" << attr.
value <<
"\n```\n\n";
691lsp::Hover MLIRDocument::buildHoverForTypeAlias(
692 SMRange hoverRange,
const AsmParserState::TypeAliasDefinition &type) {
693 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
694 llvm::raw_string_ostream os(hover.contents.value);
696 os <<
"Type Alias: \"" << type.
name <<
"\n\n";
697 os <<
"Value: ```mlir\n" << type.
value <<
"\n```\n\n";
706void MLIRDocument::findDocumentSymbols(
707 std::vector<lsp::DocumentSymbol> &symbols) {
708 for (Operation &op : parsedIR)
709 findDocumentSymbols(&op, symbols);
712void MLIRDocument::findDocumentSymbols(
713 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
714 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
717 if (
const AsmParserState::OperationDefinition *def = asmState.
getOpDef(op)) {
719 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
720 symbols.emplace_back(symbol.getName(),
721 isa<FunctionOpInterface>(op)
722 ? llvm::lsp::SymbolKind::Function
723 : llvm::lsp::SymbolKind::Class,
724 lsp::Range(sourceMgr, def->scopeLoc),
725 lsp::Range(sourceMgr, def->loc));
726 childSymbols = &symbols.back().children;
728 }
else if (op->
hasTrait<OpTrait::SymbolTable>()) {
731 llvm::lsp::SymbolKind::Namespace,
732 llvm::lsp::Range(sourceMgr, def->scopeLoc),
733 llvm::lsp::Range(sourceMgr, def->loc));
734 childSymbols = &symbols.back().children;
742 for (Operation &childOp : region.getOps())
743 findDocumentSymbols(&childOp, *childSymbols);
751class LSPCodeCompleteContext :
public AsmParserCodeCompleteContext {
753 LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
755 : AsmParserCodeCompleteContext(completeLoc),
756 completionList(completionList), ctx(ctx) {}
759 void completeDialectName(StringRef prefix)
final {
761 llvm::lsp::CompletionItem item(prefix + dialect,
762 llvm::lsp::CompletionItemKind::Module,
764 item.detail =
"dialect";
765 completionList.items.emplace_back(item);
771 void completeOperationName(StringRef dialectName)
final {
780 llvm::lsp::CompletionItem item(
781 op.getStringRef().drop_front(dialectName.size() + 1),
782 llvm::lsp::CompletionItemKind::Field,
784 item.detail =
"operation";
785 completionList.items.emplace_back(item);
791 void appendSSAValueCompletion(StringRef name, std::string typeData)
final {
793 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] ==
'%';
795 llvm::lsp::CompletionItem item(name,
796 llvm::lsp::CompletionItemKind::Variable);
798 item.insertText = name.drop_front(1).str();
799 item.detail = std::move(typeData);
800 completionList.items.emplace_back(item);
805 void appendBlockCompletion(StringRef name)
final {
807 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] ==
'^';
809 llvm::lsp::CompletionItem item(name, llvm::lsp::CompletionItemKind::Field);
811 item.insertText = name.drop_front(1).str();
812 completionList.items.emplace_back(item);
816 void completeExpectedTokens(ArrayRef<StringRef> tokens,
bool optional)
final {
817 for (StringRef token : tokens) {
818 llvm::lsp::CompletionItem item(token,
819 llvm::lsp::CompletionItemKind::Keyword,
821 item.detail = optional ?
"optional" :
"";
822 completionList.items.emplace_back(item);
827 void completeAttribute(
const llvm::StringMap<Attribute> &aliases)
override {
828 appendSimpleCompletions({
"affine_set",
"affine_map",
"dense",
829 "dense_resource",
"false",
"loc",
"sparse",
"true",
831 llvm::lsp::CompletionItemKind::Field,
834 completeDialectName(
"#");
835 completeAliases(aliases,
"#");
837 void completeDialectAttributeOrAlias(
838 const llvm::StringMap<Attribute> &aliases)
override {
839 completeDialectName();
840 completeAliases(aliases);
844 void completeType(
const llvm::StringMap<Type> &aliases)
override {
846 appendSimpleCompletions({
"memref",
"tensor",
"complex",
"tuple",
"vector",
847 "bf16",
"f16",
"f32",
"f64",
"f80",
"f128",
849 llvm::lsp::CompletionItemKind::Field,
853 for (StringRef type : {
"i",
"si",
"ui"}) {
854 llvm::lsp::CompletionItem item(type +
"<N>",
855 llvm::lsp::CompletionItemKind::Field,
857 item.insertText = type.str();
858 completionList.items.emplace_back(item);
862 completeDialectName(
"!");
863 completeAliases(aliases,
"!");
866 completeDialectTypeOrAlias(
const llvm::StringMap<Type> &aliases)
override {
867 completeDialectName();
868 completeAliases(aliases);
872 template <
typename T>
873 void completeAliases(
const llvm::StringMap<T> &aliases,
874 StringRef prefix =
"") {
875 for (
const auto &alias : aliases) {
876 llvm::lsp::CompletionItem item(prefix + alias.getKey(),
877 llvm::lsp::CompletionItemKind::Field,
879 llvm::raw_string_ostream(item.detail) <<
"alias: " << alias.getValue();
880 completionList.items.emplace_back(item);
885 void appendSimpleCompletions(ArrayRef<StringRef> completions,
886 llvm::lsp::CompletionItemKind kind,
887 StringRef sortText =
"") {
888 for (StringRef completion : completions)
889 completionList.items.emplace_back(completion, kind, sortText);
893 lsp::CompletionList &completionList;
899MLIRDocument::getCodeCompletion(
const lsp::URIForFile &uri,
900 const lsp::Position &completePos,
901 const DialectRegistry ®istry) {
902 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
903 if (!posLoc.isValid())
904 return lsp::CompletionList();
908 MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
909 tmpContext.allowUnregisteredDialects();
910 lsp::CompletionList completionList;
911 LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
915 AsmParserState tmpState;
917 &lspCompleteContext);
918 return completionList;
925void MLIRDocument::getCodeActionForDiagnostic(
926 const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
927 StringRef message, std::vector<llvm::lsp::TextEdit> &edits) {
931 if (message.starts_with(
"see current operation: "))
935 const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
936 const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
939 StringRef line(lineStart, pos.character);
943 llvm::lsp::TextEdit edit;
944 edit.range = lsp::Range(lsp::Position(pos.line, 0));
947 size_t indent = line.find_first_not_of(
' ');
948 if (indent == StringRef::npos)
949 indent = line.size();
951 edit.newText.append(indent,
' ');
952 llvm::raw_string_ostream(edit.newText)
953 <<
"// expected-" << severity <<
" @below {{" << message <<
"}}\n";
954 edits.emplace_back(std::move(edit));
961llvm::Expected<lsp::MLIRConvertBytecodeResult>
962MLIRDocument::convertToBytecode() {
965 if (!llvm::hasSingleElement(parsedIR)) {
966 if (parsedIR.
empty()) {
967 return llvm::make_error<llvm::lsp::LSPError>(
968 "expected a single and valid top-level operation, please ensure "
969 "there are no errors",
970 llvm::lsp::ErrorCode::RequestFailed);
972 return llvm::make_error<llvm::lsp::LSPError>(
973 "expected a single top-level operation",
974 llvm::lsp::ErrorCode::RequestFailed);
977 lsp::MLIRConvertBytecodeResult
result;
979 BytecodeWriterConfig writerConfig(fallbackResourceMap);
981 std::string rawBytecodeBuffer;
982 llvm::raw_string_ostream os(rawBytecodeBuffer);
985 result.output = llvm::encodeBase64(rawBytecodeBuffer);
996struct MLIRTextFileChunk {
997 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
998 const lsp::URIForFile &uri, StringRef contents,
999 StringRef workspaceRoot,
1000 std::vector<lsp::Diagnostic> &diagnostics)
1001 : lineOffset(lineOffset),
1002 document(context, uri, contents, workspaceRoot, diagnostics) {}
1006 void adjustLocForChunkOffset(lsp::Range &range) {
1007 adjustLocForChunkOffset(range.start);
1008 adjustLocForChunkOffset(range.end);
1012 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
1015 uint64_t lineOffset;
1017 MLIRDocument document;
1029 MLIRTextFile(
const lsp::URIForFile &uri, StringRef fileContents,
1031 StringRef workspaceRoot,
1032 std::vector<lsp::Diagnostic> &diagnostics);
1035 int64_t getVersion()
const {
return version; }
1041 void getLocationsOf(
const lsp::URIForFile &uri, lsp::Position defPos,
1042 std::vector<lsp::Location> &locations);
1043 void findReferencesOf(
const lsp::URIForFile &uri, lsp::Position pos,
1044 std::vector<lsp::Location> &references);
1045 std::optional<lsp::Hover> findHover(
const lsp::URIForFile &uri,
1046 lsp::Position hoverPos);
1047 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1048 lsp::CompletionList getCodeCompletion(
const lsp::URIForFile &uri,
1049 lsp::Position completePos);
1050 void getCodeActions(
const lsp::URIForFile &uri,
const lsp::Range &pos,
1051 const lsp::CodeActionContext &context,
1052 std::vector<lsp::CodeAction> &actions);
1053 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
1059 MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
1062 MLIRContext context;
1065 std::string contents;
1071 int64_t totalNumLines = 0;
1075 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
1079MLIRTextFile::MLIRTextFile(
const lsp::URIForFile &uri, StringRef fileContents,
1081 StringRef workspaceRoot,
1082 std::vector<lsp::Diagnostic> &diagnostics)
1083 : context(registryFn(uri), MLIRContext::Threading::
DISABLED),
1084 contents(fileContents.str()), version(version) {
1090 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
1091 context, 0, uri, subContents.front(), workspaceRoot,
1094 uint64_t lineOffset = subContents.front().count(
'\n');
1095 for (StringRef docContents : llvm::drop_begin(subContents)) {
1096 unsigned currentNumDiags = diagnostics.size();
1097 auto chunk = std::make_unique<MLIRTextFileChunk>(
1098 context, lineOffset, uri, docContents, workspaceRoot, diagnostics);
1099 lineOffset += docContents.count(
'\n');
1103 for (lsp::Diagnostic &
diag :
1104 llvm::drop_begin(diagnostics, currentNumDiags)) {
1105 chunk->adjustLocForChunkOffset(
diag.range);
1107 if (!
diag.relatedInformation)
1109 for (
auto &it : *
diag.relatedInformation)
1110 if (it.location.uri == uri)
1111 chunk->adjustLocForChunkOffset(it.location.range);
1113 chunks.emplace_back(std::move(chunk));
1115 totalNumLines = lineOffset;
1118void MLIRTextFile::getLocationsOf(
const lsp::URIForFile &uri,
1119 lsp::Position defPos,
1120 std::vector<lsp::Location> &locations) {
1121 MLIRTextFileChunk &chunk = getChunkFor(defPos);
1122 chunk.document.getLocationsOf(uri, defPos, locations);
1125 if (chunk.lineOffset == 0)
1127 for (lsp::Location &loc : locations)
1129 chunk.adjustLocForChunkOffset(loc.range);
1132void MLIRTextFile::findReferencesOf(
const lsp::URIForFile &uri,
1134 std::vector<lsp::Location> &references) {
1135 MLIRTextFileChunk &chunk = getChunkFor(pos);
1136 chunk.document.findReferencesOf(uri, pos, references);
1139 if (chunk.lineOffset == 0)
1141 for (lsp::Location &loc : references)
1143 chunk.adjustLocForChunkOffset(loc.range);
1146std::optional<lsp::Hover> MLIRTextFile::findHover(
const lsp::URIForFile &uri,
1147 lsp::Position hoverPos) {
1148 MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
1149 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1152 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1153 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1157void MLIRTextFile::findDocumentSymbols(
1158 std::vector<lsp::DocumentSymbol> &symbols) {
1159 if (chunks.size() == 1)
1160 return chunks.front()->document.findDocumentSymbols(symbols);
1164 for (
unsigned i = 0, e = chunks.size(); i < e; ++i) {
1165 MLIRTextFileChunk &chunk = *chunks[i];
1166 lsp::Position startPos(chunk.lineOffset);
1167 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1168 : chunks[i + 1]->lineOffset);
1169 lsp::DocumentSymbol symbol(
"<file-split-" + Twine(i) +
">",
1170 llvm::lsp::SymbolKind::Namespace,
1171 lsp::Range(startPos, endPos),
1172 lsp::Range(startPos));
1173 chunk.document.findDocumentSymbols(symbol.children);
1177 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1178 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1179 symbolsToFix.push_back(&childSymbol);
1181 while (!symbolsToFix.empty()) {
1182 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1183 chunk.adjustLocForChunkOffset(symbol->range);
1184 chunk.adjustLocForChunkOffset(symbol->selectionRange);
1186 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1187 symbolsToFix.push_back(&childSymbol);
1192 symbols.emplace_back(std::move(symbol));
1196lsp::CompletionList MLIRTextFile::getCodeCompletion(
const lsp::URIForFile &uri,
1197 lsp::Position completePos) {
1198 MLIRTextFileChunk &chunk = getChunkFor(completePos);
1199 lsp::CompletionList completionList = chunk.document.getCodeCompletion(
1203 for (llvm::lsp::CompletionItem &item : completionList.items) {
1205 chunk.adjustLocForChunkOffset(item.textEdit->range);
1206 for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
1207 chunk.adjustLocForChunkOffset(edit.range);
1209 return completionList;
1212void MLIRTextFile::getCodeActions(
const lsp::URIForFile &uri,
1213 const lsp::Range &pos,
1214 const lsp::CodeActionContext &context,
1215 std::vector<lsp::CodeAction> &actions) {
1217 for (
auto &
diag : context.diagnostics) {
1218 if (
diag.source !=
"mlir")
1220 lsp::Position diagPos =
diag.range.start;
1221 MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1224 lsp::CodeAction action;
1225 action.title =
"Add expected-* diagnostic checks";
1226 action.kind = lsp::CodeAction::kQuickFix.str();
1229 switch (
diag.severity) {
1230 case llvm::lsp::DiagnosticSeverity::Error:
1233 case llvm::lsp::DiagnosticSeverity::Warning:
1234 severity =
"warning";
1241 std::vector<llvm::lsp::TextEdit> edits;
1242 chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1243 diag.message, edits);
1246 if (
diag.relatedInformation) {
1247 for (
auto ¬eDiag : *
diag.relatedInformation) {
1248 if (noteDiag.location.uri != uri)
1250 diagPos = noteDiag.location.range.start;
1251 diagPos.line -= chunk.lineOffset;
1252 chunk.document.getCodeActionForDiagnostic(uri, diagPos,
"note",
1253 noteDiag.message, edits);
1257 for (llvm::lsp::TextEdit &edit : edits)
1258 chunk.adjustLocForChunkOffset(edit.range);
1260 action.edit.emplace();
1261 action.edit->changes[uri.uri().str()] = std::move(edits);
1262 action.diagnostics = {
diag};
1264 actions.emplace_back(std::move(action));
1268llvm::Expected<lsp::MLIRConvertBytecodeResult>
1269MLIRTextFile::convertToBytecode() {
1271 if (chunks.size() != 1) {
1272 return llvm::make_error<llvm::lsp::LSPError>(
1273 "unexpected split file, please remove all `// -----`",
1274 llvm::lsp::ErrorCode::RequestFailed);
1276 return chunks.front()->document.convertToBytecode();
1279MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1280 if (chunks.size() == 1)
1281 return *chunks.front();
1285 auto it = llvm::upper_bound(
1286 chunks, pos, [](
const lsp::Position &pos,
const auto &chunk) {
1287 return static_cast<uint64_t
>(pos.line) < chunk->lineOffset;
1289 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1290 pos.line -= chunk.lineOffset;
1306 llvm::StringMap<std::unique_ptr<MLIRTextFile>>
files;
1317 :
impl(std::make_unique<
Impl>(registryFn)) {}
1321 const URIForFile &uri, StringRef contents,
int64_t version,
1322 std::vector<llvm::lsp::Diagnostic> &diagnostics) {
1323 impl->files[uri.file()] =
1324 std::make_unique<MLIRTextFile>(uri, contents, version,
impl->registryFn,
1325 impl->workspaceRoot, diagnostics);
1329 auto it =
impl->files.find(uri.file());
1330 if (it ==
impl->files.end())
1331 return std::nullopt;
1333 int64_t version = it->second->getVersion();
1334 impl->files.erase(it);
1339 const URIForFile &uri,
const Position &defPos,
1340 std::vector<llvm::lsp::Location> &locations) {
1341 auto fileIt =
impl->files.find(uri.file());
1342 if (fileIt !=
impl->files.end())
1343 fileIt->second->getLocationsOf(uri, defPos, locations);
1347 const URIForFile &uri,
const Position &pos,
1348 std::vector<llvm::lsp::Location> &references) {
1349 auto fileIt =
impl->files.find(uri.file());
1350 if (fileIt !=
impl->files.end())
1351 fileIt->second->findReferencesOf(uri, pos, references);
1355 const Position &hoverPos) {
1356 auto fileIt =
impl->files.find(uri.file());
1357 if (fileIt !=
impl->files.end())
1358 return fileIt->second->findHover(uri, hoverPos);
1359 return std::nullopt;
1363 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1364 auto fileIt =
impl->files.find(uri.file());
1365 if (fileIt !=
impl->files.end())
1366 fileIt->second->findDocumentSymbols(symbols);
1371 const Position &completePos) {
1372 auto fileIt =
impl->files.find(uri.file());
1373 if (fileIt !=
impl->files.end())
1374 return fileIt->second->getCodeCompletion(uri, completePos);
1375 return CompletionList();
1379 const CodeActionContext &context,
1380 std::vector<CodeAction> &actions) {
1381 auto fileIt =
impl->files.find(uri.file());
1382 if (fileIt !=
impl->files.end())
1383 fileIt->second->getCodeActions(uri, pos, context, actions);
1392 std::string errorMsg;
1402 &fallbackResourceMap);
1406 if (failed(
parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
1407 return llvm::make_error<llvm::lsp::LSPError>(
1408 "failed to parse bytecode source file: " + errorMsg,
1409 llvm::lsp::ErrorCode::RequestFailed);
1414 if (!llvm::hasSingleElement(parsedBlock)) {
1415 return llvm::make_error<llvm::lsp::LSPError>(
1416 "expected bytecode to contain a single top-level operation",
1417 llvm::lsp::ErrorCode::RequestFailed);
1429 nullptr, &fallbackResourceMap);
1431 llvm::raw_string_ostream os(
result.output);
1432 topOp->print(os, state);
1434 return std::move(
result);
1439 auto fileIt =
impl->files.find(uri.file());
1440 if (fileIt ==
impl->files.end()) {
1441 return llvm::make_error<llvm::lsp::LSPError>(
1442 "language server does not contain an entry for this source file",
1443 llvm::lsp::ErrorCode::RequestFailed);
1445 return fileIt->second->convertToBytecode();
1449 impl->workspaceRoot = root.str();
static std::optional< unsigned > getResultNumberFromLoc(SMLoc loc)
Given a location pointing to a result, return the result number it refers to or std::nullopt if it re...
static std::optional< StringRef > getTextFromRange(SMRange range)
Given a source location range, return the text covered by the given range.
static std::optional< lsp::Location > getLocationFromLoc(StringRef uriScheme, FileLineColLoc loc, StringRef workspaceRoot)
Returns a language server location from the given MLIR file location.
static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, SMRange *overlappedRange=nullptr)
Returns true if the given location is contained by the definition or one of the uses of the given SMD...
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
static void collectLocationsFromLoc(Location loc, std::vector< lsp::Location > &locations, const lsp::URIForFile &uri, StringRef workspaceRoot)
Collect all of the locations from the given MLIR location that are not contained within the given URI...
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, Diagnostic &diag, const lsp::URIForFile &uri, StringRef workspaceRoot)
Convert the given MLIR diagnostic to the LSP form.
static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc={})
Given a block and source location, print the source name of the block to the given output stream.
static SMRange convertTokenLocToRange(SMLoc loc)
Returns the range of a lexical token given a SMLoc corresponding to the start of an token location.
static std::string diag(const llvm::Value &value)
void completeDialectName()
This class represents state from a parsed MLIR textual format string.
iterator_range< AttributeDefIterator > getAttributeAliasDefs() const
Return a range of the AttributeAliasDefinitions held by the current parser state.
iterator_range< BlockDefIterator > getBlockDefs() const
Return a range of the BlockDefinitions held by the current parser state.
const OperationDefinition * getOpDef(Operation *op) const
Return the definition for the given operation, or nullptr if the given operation does not have a defi...
const BlockDefinition * getBlockDef(Block *block) const
Return the definition for the given block, or nullptr if the given block does not have a definition.
iterator_range< OperationDefIterator > getOpDefs() const
Return a range of the OperationDefinitions held by the current parser state.
iterator_range< TypeDefIterator > getTypeAliasDefs() const
Return a range of the TypeAliasDefinitions held by the current parser state.
This class provides management for the lifetime of the state used when printing the IR.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
bool hasNoSuccessors()
Returns true if this block has no successors.
iterator_range< pred_iterator > getPredecessors()
SuccessorRange getSuccessors()
bool hasNoPredecessors()
Return true if this block has no predecessors.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
unsigned computeBlockNumber()
Compute the position of this block within its parent region using an O(N) linear scan.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A fallback map containing external resources not explicitly handled by another parser/printer.
An instance of this location represents a tuple of file, line number, and column number.
StringAttr getFilename() const
unsigned getColumn() const
WalkResult walk(function_ref< WalkResult(Location)> walkFn)
Walk all of the locations nested directly under, and including, the current.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
std::vector< StringRef > getAvailableDialects()
Return information about all available dialects in the registry in this context.
void allowUnregisteredDialects(bool allow=true)
Enables creating operations in unregistered dialects.
Set of flags used to control the behavior of the various IR print methods (e.g.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
This class represents a configuration for the MLIR assembly parser.
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
void addOrUpdateDocument(const URIForFile &uri, StringRef contents, int64_t version, std::vector< Diagnostic > &diagnostics)
Add or update the document, with the provided version, at the given URI.
std::optional< int64_t > removeDocument(const URIForFile &uri)
Remove the document with the given uri.
void findReferencesOf(const URIForFile &uri, const Position &pos, std::vector< Location > &references)
Find all references of the object pointed at by the given position.
void getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector< Location > &locations)
Return the locations of the object pointed at by the given position.
std::optional< Hover > findHover(const URIForFile &uri, const Position &hoverPos)
Find a hover description for the given hover position, or std::nullopt if one couldn't be found.
llvm::Expected< MLIRConvertBytecodeResult > convertFromBytecode(const URIForFile &uri)
Convert the given bytecode file to the textual format.
llvm::Expected< MLIRConvertBytecodeResult > convertToBytecode(const URIForFile &uri)
Convert the given textual file to the bytecode format.
void setWorkspaceRoot(StringRef root)
Set the workspace root for the server.
CompletionList getCodeCompletion(const URIForFile &uri, const Position &completePos)
Get the code completion list for the position within the given file.
void findDocumentSymbols(const URIForFile &uri, std::vector< DocumentSymbol > &symbols)
Find all of the document symbols within the given file.
MLIRServer(DialectRegistryFn registry_fn)
Construct a new server with the given dialect registry function.
void getCodeActions(const URIForFile &uri, const Range &pos, const CodeActionContext &context, std::vector< CodeAction > &actions)
Get the set of code actions within the file.
llvm::function_ref< DialectRegistry &(const llvm::lsp::URIForFile &uri)> DialectRegistryFn
SMRange convertTokenLocToRange(SMLoc loc, StringRef identifierChars="")
Returns the range of a lexical token given a SMLoc corresponding to the start of a token location.
Include the generated interface declarations.
const char *const kDefaultSplitMarker
LogicalResult parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, AsmParserState *asmState=nullptr, AsmParserCodeCompleteContext *codeCompleteContext=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
llvm::SetVector< T, Vector, Set, N > SetVector
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config={})
Write the bytecode for the given operation to the provided output stream.
This class represents the result of converting between MLIR's bytecode and textual format.
std::string workspaceRoot
The workspace root of the server.
lsp::DialectRegistryFn registryFn
The registry factory for containing dialects that can be recognized in parsed .mlir files.
llvm::StringMap< std::unique_ptr< MLIRTextFile > > files
The files held by the server, mapped by their URI file name.
Impl(lsp::DialectRegistryFn registryFn)
StringRef name
The name of the attribute alias.
Attribute value
The value of the alias.
This class represents the information for a block definition within the input file.
Block * block
The block representing this definition.
SMDefinition definition
The source location for the block, i.e.
Operation * op
The operation representing this definition.
This class represents a definition within the source manager, containing its defining location and the locations of its uses.
SmallVector< SMRange > uses
The source location of all uses of the definition.
SMRange loc
The source location of the definition.
Type value
The value of the alias.
StringRef name
The name of the type alias.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...