MLIR  19.0.0git
Transport.cpp
Go to the documentation of this file.
1 //===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "llvm/ADT/SmallString.h"
14 #include "llvm/Support/Errno.h"
15 #include "llvm/Support/Error.h"
16 #include <optional>
17 #include <system_error>
18 #include <utility>
19 
20 using namespace mlir;
21 using namespace mlir::lsp;
22 
23 //===----------------------------------------------------------------------===//
24 // Reply
25 //===----------------------------------------------------------------------===//
26 
27 namespace {
28 /// Function object to reply to an LSP call.
29 /// Each instance must be called exactly once, otherwise:
30 /// - if there was no reply, an error reply is sent
31 /// - if there were multiple replies, only the first is sent
32 class Reply {
33 public:
34  Reply(const llvm::json::Value &id, StringRef method, JSONTransport &transport,
35  std::mutex &transportOutputMutex);
36  Reply(Reply &&other);
37  Reply &operator=(Reply &&) = delete;
38  Reply(const Reply &) = delete;
39  Reply &operator=(const Reply &) = delete;
40 
41  void operator()(llvm::Expected<llvm::json::Value> reply);
42 
43 private:
44  StringRef method;
45  std::atomic<bool> replied = {false};
46  llvm::json::Value id;
47  JSONTransport *transport;
48  std::mutex &transportOutputMutex;
49 };
50 } // namespace
51 
52 Reply::Reply(const llvm::json::Value &id, llvm::StringRef method,
53  JSONTransport &transport, std::mutex &transportOutputMutex)
54  : id(id), transport(&transport),
55  transportOutputMutex(transportOutputMutex) {}
56 
57 Reply::Reply(Reply &&other)
58  : replied(other.replied.load()), id(std::move(other.id)),
59  transport(other.transport),
60  transportOutputMutex(other.transportOutputMutex) {
61  other.transport = nullptr;
62 }
63 
64 void Reply::operator()(llvm::Expected<llvm::json::Value> reply) {
65  if (replied.exchange(true)) {
66  Logger::error("Replied twice to message {0}({1})", method, id);
67  assert(false && "must reply to each call only once!");
68  return;
69  }
70  assert(transport && "expected valid transport to reply to");
71 
72  std::lock_guard<std::mutex> transportLock(transportOutputMutex);
73  if (reply) {
74  Logger::info("--> reply:{0}({1})", method, id);
75  transport->reply(std::move(id), std::move(reply));
76  } else {
77  llvm::Error error = reply.takeError();
78  Logger::info("--> reply:{0}({1})", method, id, error);
79  transport->reply(std::move(id), std::move(error));
80  }
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // MessageHandler
85 //===----------------------------------------------------------------------===//
86 
87 bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) {
88  Logger::info("--> {0}", method);
89 
90  if (method == "exit")
91  return false;
92  if (method == "$cancel") {
93  // TODO: Add support for cancelling requests.
94  } else {
95  auto it = notificationHandlers.find(method);
96  if (it != notificationHandlers.end())
97  it->second(std::move(value));
98  }
99  return true;
100 }
101 
102 bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
103  llvm::json::Value id) {
104  Logger::info("--> {0}({1})", method, id);
105 
106  Reply reply(id, method, transport, transportOutputMutex);
107 
108  auto it = methodHandlers.find(method);
109  if (it != methodHandlers.end()) {
110  it->second(std::move(params), std::move(reply));
111  } else {
112  reply(llvm::make_error<LSPError>("method not found: " + method.str(),
114  }
115  return true;
116 }
117 
118 bool MessageHandler::onReply(llvm::json::Value id,
120  // TODO: Add support for reply callbacks when support for outgoing messages is
121  // added. For now, we just log an error on any replies received.
122  Callback<llvm::json::Value> replyHandler =
123  [&id](llvm::Expected<llvm::json::Value> result) {
125  "received a reply with ID {0}, but there was no such call", id);
126  if (!result)
127  llvm::consumeError(result.takeError());
128  };
129 
130  // Log and run the reply handler.
131  if (result)
132  replyHandler(std::move(result));
133  else
134  replyHandler(result.takeError());
135  return true;
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // JSONTransport
140 //===----------------------------------------------------------------------===//
141 
142 /// Encode the given error as a JSON object.
143 static llvm::json::Object encodeError(llvm::Error error) {
144  std::string message;
146  auto handlerFn = [&](const LSPError &lspError) -> llvm::Error {
147  message = lspError.message;
148  code = lspError.code;
149  return llvm::Error::success();
150  };
151  if (llvm::Error unhandled = llvm::handleErrors(std::move(error), handlerFn))
152  message = llvm::toString(std::move(unhandled));
153 
154  return llvm::json::Object{
155  {"message", std::move(message)},
156  {"code", int64_t(code)},
157  };
158 }
159 
160 /// Decode the given JSON object into an error.
161 llvm::Error decodeError(const llvm::json::Object &o) {
162  StringRef msg = o.getString("message").value_or("Unspecified error");
163  if (std::optional<int64_t> code = o.getInteger("code"))
164  return llvm::make_error<LSPError>(msg.str(), ErrorCode(*code));
165  return llvm::make_error<llvm::StringError>(llvm::inconvertibleErrorCode(),
166  msg.str());
167 }
168 
169 void JSONTransport::notify(StringRef method, llvm::json::Value params) {
170  sendMessage(llvm::json::Object{
171  {"jsonrpc", "2.0"},
172  {"method", method},
173  {"params", std::move(params)},
174  });
175 }
176 void JSONTransport::call(StringRef method, llvm::json::Value params,
177  llvm::json::Value id) {
178  sendMessage(llvm::json::Object{
179  {"jsonrpc", "2.0"},
180  {"id", std::move(id)},
181  {"method", method},
182  {"params", std::move(params)},
183  });
184 }
185 void JSONTransport::reply(llvm::json::Value id,
187  if (result) {
188  return sendMessage(llvm::json::Object{
189  {"jsonrpc", "2.0"},
190  {"id", std::move(id)},
191  {"result", std::move(*result)},
192  });
193  }
194 
195  sendMessage(llvm::json::Object{
196  {"jsonrpc", "2.0"},
197  {"id", std::move(id)},
198  {"error", encodeError(result.takeError())},
199  });
200 }
201 
202 llvm::Error JSONTransport::run(MessageHandler &handler) {
203  std::string json;
204  while (!feof(in)) {
205  if (ferror(in)) {
206  return llvm::errorCodeToError(
207  std::error_code(errno, std::system_category()));
208  }
209 
210  if (succeeded(readMessage(json))) {
212  if (!handleMessage(std::move(*doc), handler))
213  return llvm::Error::success();
214  } else {
215  Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError()));
216  }
217  }
218  }
219  return llvm::errorCodeToError(std::make_error_code(std::errc::io_error));
220 }
221 
222 void JSONTransport::sendMessage(llvm::json::Value msg) {
223  outputBuffer.clear();
224  llvm::raw_svector_ostream os(outputBuffer);
225  os << llvm::formatv(prettyOutput ? "{0:2}\n" : "{0}", msg);
226  out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n"
227  << outputBuffer;
228  out.flush();
229  Logger::debug(">>> {0}\n", outputBuffer);
230 }
231 
232 bool JSONTransport::handleMessage(llvm::json::Value msg,
233  MessageHandler &handler) {
234  // Message must be an object with "jsonrpc":"2.0".
235  llvm::json::Object *object = msg.getAsObject();
236  if (!object ||
237  object->getString("jsonrpc") != std::optional<StringRef>("2.0"))
238  return false;
239 
240  // `id` may be any JSON value. If absent, this is a notification.
241  std::optional<llvm::json::Value> id;
242  if (llvm::json::Value *i = object->get("id"))
243  id = std::move(*i);
244  std::optional<StringRef> method = object->getString("method");
245 
246  // This is a response.
247  if (!method) {
248  if (!id)
249  return false;
250  if (auto *err = object->getObject("error"))
251  return handler.onReply(std::move(*id), decodeError(*err));
252  // result should be given, use null if not.
253  llvm::json::Value result = nullptr;
254  if (llvm::json::Value *r = object->get("result"))
255  result = std::move(*r);
256  return handler.onReply(std::move(*id), std::move(result));
257  }
258 
259  // Params should be given, use null if not.
260  llvm::json::Value params = nullptr;
261  if (llvm::json::Value *p = object->get("params"))
262  params = std::move(*p);
263 
264  if (id)
265  return handler.onCall(*method, std::move(params), std::move(*id));
266  return handler.onNotify(*method, std::move(params));
267 }
268 
269 /// Tries to read a line up to and including \n.
270 /// If failing, feof(), ferror(), or shutdownRequested() will be set.
272  // Big enough to hold any reasonable header line. May not fit content lines
273  // in delimited mode, but performance doesn't matter for that mode.
274  static constexpr int bufSize = 128;
275  size_t size = 0;
276  out.clear();
277  for (;;) {
278  out.resize_for_overwrite(size + bufSize);
279  if (!std::fgets(&out[size], bufSize, in))
280  return failure();
281 
282  clearerr(in);
283 
284  // If the line contained null bytes, anything after it (including \n) will
285  // be ignored. Fortunately this is not a legal header or JSON.
286  size_t read = std::strlen(&out[size]);
287  if (read > 0 && out[size + read - 1] == '\n') {
288  out.resize(size + read);
289  return success();
290  }
291  size += read;
292  }
293 }
294 
295 // Returns std::nullopt when:
296 // - ferror(), feof(), or shutdownRequested() are set.
297 // - Content-Length is missing or empty (protocol error)
298 LogicalResult JSONTransport::readStandardMessage(std::string &json) {
299  // A Language Server Protocol message starts with a set of HTTP headers,
300  // delimited by \r\n, and terminated by an empty line (\r\n).
301  unsigned long long contentLength = 0;
303  while (true) {
304  if (feof(in) || ferror(in) || failed(readLine(in, line)))
305  return failure();
306 
307  // Content-Length is a mandatory header, and the only one we handle.
308  StringRef lineRef = line;
309  if (lineRef.consume_front("Content-Length: ")) {
310  llvm::getAsUnsignedInteger(lineRef.trim(), 0, contentLength);
311  } else if (!lineRef.trim().empty()) {
312  // It's another header, ignore it.
313  continue;
314  } else {
315  // An empty line indicates the end of headers. Go ahead and read the JSON.
316  break;
317  }
318  }
319 
320  // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999"
321  if (contentLength == 0 || contentLength > 1 << 30)
322  return failure();
323 
324  json.resize(contentLength);
325  for (size_t pos = 0, read; pos < contentLength; pos += read) {
326  read = std::fread(&json[pos], 1, contentLength - pos, in);
327  if (read == 0)
328  return failure();
329 
330  // If we're done, the error was transient. If we're not done, either it was
331  // transient or we'll see it again on retry.
332  clearerr(in);
333  pos += read;
334  }
335  return success();
336 }
337 
338 /// For lit tests we support a simplified syntax:
339 /// - messages are delimited by '// -----' on a line by itself
340 /// - lines starting with // are ignored.
341 /// This is a testing path, so favor simplicity over performance here.
342 /// When returning failure: feof(), ferror(), or shutdownRequested() will be
343 /// set.
344 LogicalResult JSONTransport::readDelimitedMessage(std::string &json) {
345  json.clear();
347  while (succeeded(readLine(in, line))) {
348  StringRef lineRef = line.str().trim();
349  if (lineRef.starts_with("//")) {
350  // Found a delimiter for the message.
351  if (lineRef == kDefaultSplitMarker)
352  break;
353  continue;
354  }
355 
356  json += line;
357  }
358 
359  return failure(ferror(in));
360 }
static std::string toString(bytecode::Section::ID sectionID)
Stringify the given section ID.
static llvm::json::Object encodeError(llvm::Error error)
Encode the given error as a JSON object.
Definition: Transport.cpp:143
llvm::Error decodeError(const llvm::json::Object &o)
Decode the given JSON object into an error.
Definition: Transport.cpp:161
LogicalResult readLine(std::FILE *in, SmallVectorImpl< char > &out)
Tries to read a line up to and including the trailing newline (\n).
Definition: Transport.cpp:271
A transport class that performs the JSON-RPC communication with the LSP client.
Definition: Transport.h:48
void notify(StringRef method, llvm::json::Value params)
The following methods are used to send a message to the LSP client.
Definition: Transport.cpp:169
void call(StringRef method, llvm::json::Value params, llvm::json::Value id)
Definition: Transport.cpp:176
llvm::Error run(MessageHandler &handler)
Start executing the JSON-RPC transport.
Definition: Transport.cpp:202
void reply(llvm::json::Value id, llvm::Expected< llvm::json::Value > result)
Definition: Transport.cpp:185
This class models an LSP error as an llvm::Error.
Definition: Protocol.h:78
static void debug(const char *fmt, Ts &&...vals)
Initiate a log message at various severity levels.
Definition: Logging.h:34
static void info(const char *fmt, Ts &&...vals)
Definition: Logging.h:38
static void error(const char *fmt, Ts &&...vals)
Definition: Logging.h:42
A handler used to process the incoming transport messages.
Definition: Transport.h:104
bool onCall(StringRef method, llvm::json::Value params, llvm::json::Value id)
Definition: Transport.cpp:102
void method(llvm::StringLiteral method, ThisT *thisPtr, void(ThisT::*handler)(const Param &, Callback< Result >))
Definition: Transport.h:133
bool onReply(llvm::json::Value id, llvm::Expected< llvm::json::Value > result)
Definition: Transport.cpp:118
bool onNotify(StringRef method, llvm::json::Value value)
Definition: Transport.cpp:87
llvm::unique_function< void(llvm::Expected< T >)> Callback
A Callback<T> is a void function that accepts Expected<T>.
Definition: Transport.h:96
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
const char *const kDefaultSplitMarker
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26