MLIR  18.0.0git
Transport.cpp
Go to the documentation of this file.
1 //===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "llvm/ADT/SmallString.h"
13 #include "llvm/Support/Errno.h"
14 #include "llvm/Support/Error.h"
15 #include <optional>
16 #include <system_error>
17 #include <utility>
18 
19 using namespace mlir;
20 using namespace mlir::lsp;
21 
22 //===----------------------------------------------------------------------===//
23 // Reply
24 //===----------------------------------------------------------------------===//
25 
26 namespace {
27 /// Function object to reply to an LSP call.
28 /// Each instance must be called exactly once, otherwise:
29 /// - if there was no reply, an error reply is sent
30 /// - if there were multiple replies, only the first is sent
31 class Reply {
32 public:
33  Reply(const llvm::json::Value &id, StringRef method, JSONTransport &transport,
34  std::mutex &transportOutputMutex);
35  Reply(Reply &&other);
36  Reply &operator=(Reply &&) = delete;
37  Reply(const Reply &) = delete;
38  Reply &operator=(const Reply &) = delete;
39 
40  void operator()(llvm::Expected<llvm::json::Value> reply);
41 
42 private:
43  StringRef method;
44  std::atomic<bool> replied = {false};
45  llvm::json::Value id;
46  JSONTransport *transport;
47  std::mutex &transportOutputMutex;
48 };
49 } // namespace
50 
51 Reply::Reply(const llvm::json::Value &id, llvm::StringRef method,
52  JSONTransport &transport, std::mutex &transportOutputMutex)
53  : id(id), transport(&transport),
54  transportOutputMutex(transportOutputMutex) {}
55 
56 Reply::Reply(Reply &&other)
57  : replied(other.replied.load()), id(std::move(other.id)),
58  transport(other.transport),
59  transportOutputMutex(other.transportOutputMutex) {
60  other.transport = nullptr;
61 }
62 
63 void Reply::operator()(llvm::Expected<llvm::json::Value> reply) {
64  if (replied.exchange(true)) {
65  Logger::error("Replied twice to message {0}({1})", method, id);
66  assert(false && "must reply to each call only once!");
67  return;
68  }
69  assert(transport && "expected valid transport to reply to");
70 
71  std::lock_guard<std::mutex> transportLock(transportOutputMutex);
72  if (reply) {
73  Logger::info("--> reply:{0}({1})", method, id);
74  transport->reply(std::move(id), std::move(reply));
75  } else {
76  llvm::Error error = reply.takeError();
77  Logger::info("--> reply:{0}({1})", method, id, error);
78  transport->reply(std::move(id), std::move(error));
79  }
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // MessageHandler
84 //===----------------------------------------------------------------------===//
85 
86 bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) {
87  Logger::info("--> {0}", method);
88 
89  if (method == "exit")
90  return false;
91  if (method == "$cancel") {
92  // TODO: Add support for cancelling requests.
93  } else {
94  auto it = notificationHandlers.find(method);
95  if (it != notificationHandlers.end())
96  it->second(std::move(value));
97  }
98  return true;
99 }
100 
101 bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
102  llvm::json::Value id) {
103  Logger::info("--> {0}({1})", method, id);
104 
105  Reply reply(id, method, transport, transportOutputMutex);
106 
107  auto it = methodHandlers.find(method);
108  if (it != methodHandlers.end()) {
109  it->second(std::move(params), std::move(reply));
110  } else {
111  reply(llvm::make_error<LSPError>("method not found: " + method.str(),
113  }
114  return true;
115 }
116 
117 bool MessageHandler::onReply(llvm::json::Value id,
119  // TODO: Add support for reply callbacks when support for outgoing messages is
120  // added. For now, we just log an error on any replies received.
121  Callback<llvm::json::Value> replyHandler =
122  [&id](llvm::Expected<llvm::json::Value> result) {
124  "received a reply with ID {0}, but there was no such call", id);
125  if (!result)
126  llvm::consumeError(result.takeError());
127  };
128 
129  // Log and run the reply handler.
130  if (result)
131  replyHandler(std::move(result));
132  else
133  replyHandler(result.takeError());
134  return true;
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // JSONTransport
139 //===----------------------------------------------------------------------===//
140 
141 /// Encode the given error as a JSON object.
142 static llvm::json::Object encodeError(llvm::Error error) {
143  std::string message;
145  auto handlerFn = [&](const LSPError &lspError) -> llvm::Error {
146  message = lspError.message;
147  code = lspError.code;
148  return llvm::Error::success();
149  };
150  if (llvm::Error unhandled = llvm::handleErrors(std::move(error), handlerFn))
151  message = llvm::toString(std::move(unhandled));
152 
153  return llvm::json::Object{
154  {"message", std::move(message)},
155  {"code", int64_t(code)},
156  };
157 }
158 
159 /// Decode the given JSON object into an error.
160 llvm::Error decodeError(const llvm::json::Object &o) {
161  StringRef msg = o.getString("message").value_or("Unspecified error");
162  if (std::optional<int64_t> code = o.getInteger("code"))
163  return llvm::make_error<LSPError>(msg.str(), ErrorCode(*code));
164  return llvm::make_error<llvm::StringError>(llvm::inconvertibleErrorCode(),
165  msg.str());
166 }
167 
168 void JSONTransport::notify(StringRef method, llvm::json::Value params) {
169  sendMessage(llvm::json::Object{
170  {"jsonrpc", "2.0"},
171  {"method", method},
172  {"params", std::move(params)},
173  });
174 }
175 void JSONTransport::call(StringRef method, llvm::json::Value params,
176  llvm::json::Value id) {
177  sendMessage(llvm::json::Object{
178  {"jsonrpc", "2.0"},
179  {"id", std::move(id)},
180  {"method", method},
181  {"params", std::move(params)},
182  });
183 }
184 void JSONTransport::reply(llvm::json::Value id,
186  if (result) {
187  return sendMessage(llvm::json::Object{
188  {"jsonrpc", "2.0"},
189  {"id", std::move(id)},
190  {"result", std::move(*result)},
191  });
192  }
193 
194  sendMessage(llvm::json::Object{
195  {"jsonrpc", "2.0"},
196  {"id", std::move(id)},
197  {"error", encodeError(result.takeError())},
198  });
199 }
200 
201 llvm::Error JSONTransport::run(MessageHandler &handler) {
202  std::string json;
203  while (!feof(in)) {
204  if (ferror(in)) {
205  return llvm::errorCodeToError(
206  std::error_code(errno, std::system_category()));
207  }
208 
209  if (succeeded(readMessage(json))) {
210  if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(json)) {
211  if (!handleMessage(std::move(*doc), handler))
212  return llvm::Error::success();
213  } else {
214  Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError()));
215  }
216  }
217  }
218  return llvm::errorCodeToError(std::make_error_code(std::errc::io_error));
219 }
220 
221 void JSONTransport::sendMessage(llvm::json::Value msg) {
222  outputBuffer.clear();
223  llvm::raw_svector_ostream os(outputBuffer);
224  os << llvm::formatv(prettyOutput ? "{0:2}\n" : "{0}", msg);
225  out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n"
226  << outputBuffer;
227  out.flush();
228  Logger::debug(">>> {0}\n", outputBuffer);
229 }
230 
231 bool JSONTransport::handleMessage(llvm::json::Value msg,
232  MessageHandler &handler) {
233  // Message must be an object with "jsonrpc":"2.0".
234  llvm::json::Object *object = msg.getAsObject();
235  if (!object ||
236  object->getString("jsonrpc") != std::optional<StringRef>("2.0"))
237  return false;
238 
239  // `id` may be any JSON value. If absent, this is a notification.
240  std::optional<llvm::json::Value> id;
241  if (llvm::json::Value *i = object->get("id"))
242  id = std::move(*i);
243  std::optional<StringRef> method = object->getString("method");
244 
245  // This is a response.
246  if (!method) {
247  if (!id)
248  return false;
249  if (auto *err = object->getObject("error"))
250  return handler.onReply(std::move(*id), decodeError(*err));
251  // result should be given, use null if not.
252  llvm::json::Value result = nullptr;
253  if (llvm::json::Value *r = object->get("result"))
254  result = std::move(*r);
255  return handler.onReply(std::move(*id), std::move(result));
256  }
257 
258  // Params should be given, use null if not.
259  llvm::json::Value params = nullptr;
260  if (llvm::json::Value *p = object->get("params"))
261  params = std::move(*p);
262 
263  if (id)
264  return handler.onCall(*method, std::move(params), std::move(*id));
265  return handler.onNotify(*method, std::move(params));
266 }
267 
268 /// Tries to read a line up to and including \n.
269 /// If failing, feof(), ferror(), or shutdownRequested() will be set.
271  // Big enough to hold any reasonable header line. May not fit content lines
272  // in delimited mode, but performance doesn't matter for that mode.
273  static constexpr int bufSize = 128;
274  size_t size = 0;
275  out.clear();
276  for (;;) {
277  out.resize_for_overwrite(size + bufSize);
278  if (!std::fgets(&out[size], bufSize, in))
279  return failure();
280 
281  clearerr(in);
282 
283  // If the line contained null bytes, anything after it (including \n) will
284  // be ignored. Fortunately this is not a legal header or JSON.
285  size_t read = std::strlen(&out[size]);
286  if (read > 0 && out[size + read - 1] == '\n') {
287  out.resize(size + read);
288  return success();
289  }
290  size += read;
291  }
292 }
293 
294 // Returns std::nullopt when:
295 // - ferror(), feof(), or shutdownRequested() are set.
296 // - Content-Length is missing or empty (protocol error)
297 LogicalResult JSONTransport::readStandardMessage(std::string &json) {
298  // A Language Server Protocol message starts with a set of HTTP headers,
299  // delimited by \r\n, and terminated by an empty line (\r\n).
300  unsigned long long contentLength = 0;
302  while (true) {
303  if (feof(in) || ferror(in) || failed(readLine(in, line)))
304  return failure();
305 
306  // Content-Length is a mandatory header, and the only one we handle.
307  StringRef lineRef = line;
308  if (lineRef.consume_front("Content-Length: ")) {
309  llvm::getAsUnsignedInteger(lineRef.trim(), 0, contentLength);
310  } else if (!lineRef.trim().empty()) {
311  // It's another header, ignore it.
312  continue;
313  } else {
314  // An empty line indicates the end of headers. Go ahead and read the JSON.
315  break;
316  }
317  }
318 
319  // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999"
320  if (contentLength == 0 || contentLength > 1 << 30)
321  return failure();
322 
323  json.resize(contentLength);
324  for (size_t pos = 0, read; pos < contentLength; pos += read) {
325  read = std::fread(&json[pos], 1, contentLength - pos, in);
326  if (read == 0)
327  return failure();
328 
329  // If we're done, the error was transient. If we're not done, either it was
330  // transient or we'll see it again on retry.
331  clearerr(in);
332  pos += read;
333  }
334  return success();
335 }
336 
337 /// For lit tests we support a simplified syntax:
338 /// - messages are delimited by '// -----' on a line by itself
339 /// - lines starting with // are ignored.
340 /// This is a testing path, so favor simplicity over performance here.
341 /// When returning failure: feof(), ferror(), or shutdownRequested() will be
342 /// set.
343 LogicalResult JSONTransport::readDelimitedMessage(std::string &json) {
344  json.clear();
346  while (succeeded(readLine(in, line))) {
347  StringRef lineRef = line.str().trim();
348  if (lineRef.startswith("//")) {
349  // Found a delimiter for the message.
350  if (lineRef == "// -----")
351  break;
352  continue;
353  }
354 
355  json += line;
356  }
357 
358  return failure(ferror(in));
359 }
static std::string toString(bytecode::Section::ID sectionID)
Stringify the given section ID.
static llvm::json::Object encodeError(llvm::Error error)
Encode the given error as a JSON object.
Definition: Transport.cpp:142
llvm::Error decodeError(const llvm::json::Object &o)
Decode the given JSON object into an error.
Definition: Transport.cpp:160
LogicalResult readLine(std::FILE *in, SmallVectorImpl< char > &out)
Tries to read a line up to and including \n.
Definition: Transport.cpp:270
A transport class that performs the JSON-RPC communication with the LSP client.
Definition: Transport.h:48
void notify(StringRef method, llvm::json::Value params)
The following methods are used to send a message to the LSP client.
Definition: Transport.cpp:168
void call(StringRef method, llvm::json::Value params, llvm::json::Value id)
Definition: Transport.cpp:175
llvm::Error run(MessageHandler &handler)
Start executing the JSON-RPC transport.
Definition: Transport.cpp:201
void reply(llvm::json::Value id, llvm::Expected< llvm::json::Value > result)
Definition: Transport.cpp:184
This class models an LSP error as an llvm::Error.
Definition: Protocol.h:78
static void debug(const char *fmt, Ts &&...vals)
Initiate a log message at various severity levels.
Definition: Logging.h:34
static void info(const char *fmt, Ts &&...vals)
Definition: Logging.h:38
static void error(const char *fmt, Ts &&...vals)
Definition: Logging.h:42
A handler used to process the incoming transport messages.
Definition: Transport.h:104
bool onCall(StringRef method, llvm::json::Value params, llvm::json::Value id)
Definition: Transport.cpp:101
void method(llvm::StringLiteral method, ThisT *thisPtr, void(ThisT::*handler)(const Param &, Callback< Result >))
Definition: Transport.h:133
bool onReply(llvm::json::Value id, llvm::Expected< llvm::json::Value > result)
Definition: Transport.cpp:117
bool onNotify(StringRef method, llvm::json::Value value)
Definition: Transport.cpp:86
llvm::unique_function< void(llvm::Expected< T >)> Callback
A Callback<T> is a void function that accepts Expected<T>.
Definition: Transport.h:96
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26