1#if defined(CONF_WEBSOCKETS)
2
3#include "websockets.h"
4
5#include <base/dbg.h>
6#include <base/log.h>
7#include <base/mem.h>
8#include <base/str.h>
9
10#include <engine/shared/config.h>
11#include <engine/shared/network.h>
12#include <engine/shared/protocol.h>
13#include <engine/shared/ringbuffer.h>
14
15#if defined(CONF_FAMILY_UNIX)
16#include <arpa/inet.h>
17#elif defined(CONF_FAMILY_WINDOWS)
18#include <ws2tcpip.h>
19#endif
20#include <libwebsockets.h>
21
22#include <cstdlib>
23#include <map>
24#include <string>
25
26struct websocket_chunk
27{
28 size_t size;
29 size_t read;
30 NETADDR addr;
31 unsigned char data[0];
32};
33
34// Client opens two connections for whatever reason
35typedef CStaticRingBuffer<websocket_chunk, (MAX_CLIENTS * 2) * NET_CONN_BUFFERSIZE,
36 CRingBufferBase::FLAG_RECYCLE>
37 TRecvBuffer;
38typedef CStaticRingBuffer<websocket_chunk, NET_CONN_BUFFERSIZE,
39 CRingBufferBase::FLAG_RECYCLE>
40 TSendBuffer;
41
42struct per_session_data
43{
44 lws *wsi;
45 NETADDR addr;
46 TSendBuffer send_buffer;
47};
48
49struct context_data
50{
51 char bindaddr_str[NETADDR_MAXSTRSIZE];
52 lws_context_creation_info creation_info;
53 lws_context *context;
54 std::map<NETADDR, per_session_data *> port_map;
55 TRecvBuffer recv_buffer;
56};
57
58// Client has main, dummy and contact connections with IPv4 and IPv6
59static context_data contexts[3 * 2];
60static std::map<lws_context *, context_data *> contexts_map;
61
62static lws_context *websocket_context(int socket)
63{
64 dbg_assert(socket >= 0 && socket < (int)std::size(contexts), "socket index invalid: %d", socket);
65 lws_context *context = contexts[socket].context;
66 dbg_assert(context != nullptr, "socket context not initialized: %d", socket);
67 return context;
68}
69
70static void receive_chunk(context_data *ctx_data, per_session_data *pss, const void *in, size_t len)
71{
72 websocket_chunk *chunk = ctx_data->recv_buffer.Allocate(Size: len + sizeof(websocket_chunk));
73 dbg_assert(chunk != nullptr, "failed to allocate websocket receive buffer chunk of size %" PRIzu, len);
74 chunk->size = len;
75 chunk->read = 0;
76 chunk->addr = pss->addr;
77 mem_copy(dest: &chunk->data[0], source: in, size: len);
78}
79
80static void sockaddr_to_netaddr_websocket(const sockaddr *src, socklen_t src_len, NETADDR *dst)
81{
82 *dst = NETADDR_ZEROED;
83 if(src->sa_family == AF_INET && src_len >= (socklen_t)sizeof(sockaddr_in))
84 {
85 const sockaddr_in *src_in = (const sockaddr_in *)src;
86 dst->type = NETTYPE_WEBSOCKET_IPV4;
87 dst->port = htons(hostshort: src_in->sin_port);
88 static_assert(sizeof(dst->ip) >= sizeof(src_in->sin_addr.s_addr));
89 mem_copy(dest: dst->ip, source: &src_in->sin_addr.s_addr, size: sizeof(src_in->sin_addr.s_addr));
90 }
91 else if(src->sa_family == AF_INET6 && src_len >= (socklen_t)sizeof(sockaddr_in6))
92 {
93 const sockaddr_in6 *src_in6 = (const sockaddr_in6 *)src;
94 dst->type = NETTYPE_WEBSOCKET_IPV6;
95 dst->port = htons(hostshort: src_in6->sin6_port);
96 static_assert(sizeof(dst->ip) >= sizeof(src_in6->sin6_addr.s6_addr));
97 mem_copy(dest: dst->ip, source: &src_in6->sin6_addr.s6_addr, size: sizeof(src_in6->sin6_addr.s6_addr));
98 }
99 else
100 {
101 log_warn("websockets", "Cannot convert sockaddr of family %d", src->sa_family);
102 }
103}
104
105static int websocket_protocol_callback(lws *wsi, enum lws_callback_reasons reason, void *user, void *in, size_t len)
106{
107 per_session_data *pss = (per_session_data *)user;
108 lws_context *context = lws_get_context(wsi);
109 context_data *ctx_data = contexts_map[context];
110 switch(reason)
111 {
112 case LWS_CALLBACK_WSI_CREATE:
113 if(pss == nullptr)
114 {
115 return 0;
116 }
117 [[fallthrough]];
118 case LWS_CALLBACK_ESTABLISHED:
119 {
120 sockaddr_storage peersockaddr;
121 socklen_t peersockaddr_size = sizeof(peersockaddr);
122 getpeername(fd: lws_get_socket_fd(wsi), addr: (sockaddr *)&peersockaddr, len: &peersockaddr_size);
123 NETADDR addr;
124 sockaddr_to_netaddr_websocket(src: (sockaddr *)&peersockaddr, src_len: peersockaddr_size, dst: &addr);
125 if(addr.type == NETTYPE_INVALID)
126 {
127 return 0;
128 }
129
130 pss->wsi = wsi;
131 pss->addr = addr;
132 pss->send_buffer.Init();
133 ctx_data->port_map[addr] = pss;
134
135 char addr_str[NETADDR_MAXSTRSIZE];
136 net_addr_str(addr: &addr, string: addr_str, max_length: sizeof(addr_str), add_port: true);
137 log_trace("websockets", "Connection established with '%s'", addr_str);
138 return 0;
139 }
140
141 case LWS_CALLBACK_CLIENT_CLOSED:
142 [[fallthrough]];
143 case LWS_CALLBACK_CLOSED:
144 {
145 char addr_str[NETADDR_MAXSTRSIZE];
146 net_addr_str(addr: &pss->addr, string: addr_str, max_length: sizeof(addr_str), add_port: true);
147 log_trace("websockets", "Connection closed with '%s'", addr_str);
148
149 static const unsigned char CLOSE_PACKET[] = {0x10, 0x0e, 0x00, 0x04};
150 receive_chunk(ctx_data, pss, in: &CLOSE_PACKET, len: sizeof(CLOSE_PACKET));
151 return 0;
152 }
153
154 case LWS_CALLBACK_WSI_DESTROY:
155 {
156 if(pss == nullptr)
157 {
158 return 0;
159 }
160 pss->wsi = nullptr;
161 ctx_data->port_map.erase(x: pss->addr);
162 return 0;
163 }
164
165 case LWS_CALLBACK_CLIENT_WRITEABLE:
166 [[fallthrough]];
167 case LWS_CALLBACK_SERVER_WRITEABLE:
168 {
169 websocket_chunk *chunk = pss->send_buffer.First();
170 if(chunk == nullptr)
171 {
172 return 0;
173 }
174
175 int chunk_len = chunk->size - chunk->read;
176 int n = lws_write(wsi, buf: &chunk->data[LWS_SEND_BUFFER_PRE_PADDING + chunk->read], len: chunk->size - chunk->read, protocol: LWS_WRITE_BINARY);
177 if(n < 0)
178 {
179 return 1;
180 }
181
182 if(n < chunk_len)
183 {
184 chunk->read += n;
185 lws_callback_on_writable(wsi);
186 return 0;
187 }
188
189 pss->send_buffer.PopFirst();
190 lws_callback_on_writable(wsi);
191 return 0;
192 }
193
194 case LWS_CALLBACK_CLIENT_RECEIVE:
195 [[fallthrough]];
196 case LWS_CALLBACK_RECEIVE:
197 receive_chunk(ctx_data, pss, in, len);
198 return 0;
199
200 default:
201 return 0;
202 }
203}
204
205static const lws_protocols protocols[] = {
206 {.name: "binary", .callback: websocket_protocol_callback, .per_session_data_size: sizeof(per_session_data)},
207 {.name: "base64", .callback: websocket_protocol_callback, .per_session_data_size: sizeof(per_session_data)},
208 {.name: nullptr, .callback: nullptr, .per_session_data_size: 0}};
209
210static LEVEL websocket_level_to_loglevel(int level)
211{
212 switch(level)
213 {
214 case LLL_ERR:
215 return LEVEL_ERROR;
216 case LLL_WARN:
217 return LEVEL_WARN;
218 case LLL_NOTICE:
219 case LLL_INFO:
220 return LEVEL_DEBUG;
221 default:
222 dbg_assert_failed("invalid log level: %d", level);
223 }
224}
225
226static void websocket_log_callback(int level, const char *line)
227{
228 if((level == LLL_NOTICE || level == LLL_INFO) && !g_Config.m_DbgWebsockets)
229 {
230 return;
231 }
232
233 // Truncate duplicate timestamp from beginning and newline from end
234 char line_truncated[4096]; // Longest log line length
235 const char *line_time_end = str_find(haystack: line, needle: "] ");
236 dbg_assert(line_time_end != nullptr, "unexpected log format");
237 str_copy(dst&: line_truncated, src: line_time_end + 2);
238 const int length = str_length(str: line_truncated);
239 if(line_truncated[length - 1] == '\n')
240 {
241 line_truncated[length - 1] = '\0';
242 }
243 if(line_truncated[length - 2] == '\r')
244 {
245 line_truncated[length - 2] = '\0';
246 }
247 log_log(level: websocket_level_to_loglevel(level), sys: "websockets", fmt: "%s", line_truncated);
248}
249
250void websocket_init()
251{
252 lws_set_log_level(LLL_ERR | LLL_WARN | LLL_NOTICE | LLL_INFO, log_emit_function: websocket_log_callback);
253}
254
255int websocket_create(const NETADDR *bindaddr)
256{
257 // find free context
258 int first_free = -1;
259 for(int i = 0; i < (int)std::size(contexts); i++)
260 {
261 if(contexts[i].context == nullptr)
262 {
263 first_free = i;
264 break;
265 }
266 }
267 if(first_free == -1)
268 {
269 log_error("websockets", "Failed to create websocket: no free contexts available");
270 return -1;
271 }
272
273 context_data *ctx_data = &contexts[first_free];
274 mem_zero(block: &ctx_data->creation_info, size: sizeof(ctx_data->creation_info));
275 ctx_data->creation_info.options = LWS_SERVER_OPTION_FAIL_UPON_UNABLE_TO_BIND;
276 if(bindaddr->type == NETTYPE_WEBSOCKET_IPV6)
277 {
278 // Set IPv6-only mode and socket option for IPv6 Websockets.
279 ctx_data->creation_info.options |= LWS_SERVER_OPTION_IPV6_V6ONLY_VALUE | LWS_SERVER_OPTION_IPV6_V6ONLY_MODIFY;
280 }
281 net_addr_str(addr: bindaddr, string: ctx_data->bindaddr_str, max_length: sizeof(ctx_data->bindaddr_str), add_port: false);
282 if(ctx_data->bindaddr_str[0] == '[' && ctx_data->bindaddr_str[str_length(str: ctx_data->bindaddr_str) - 1] == ']')
283 {
284 // Bindaddr must not be enclosed in brackets for IPv6 Websockets.
285 ctx_data->bindaddr_str[str_length(str: ctx_data->bindaddr_str) - 1] = '\0';
286 mem_move(dest: &ctx_data->bindaddr_str[0], source: &ctx_data->bindaddr_str[1], size: str_length(str: ctx_data->bindaddr_str) + 1);
287 }
288 ctx_data->creation_info.iface = ctx_data->bindaddr_str;
289 ctx_data->creation_info.port = bindaddr->port;
290 ctx_data->creation_info.protocols = protocols;
291 ctx_data->creation_info.gid = -1;
292 ctx_data->creation_info.uid = -1;
293 ctx_data->creation_info.user = ctx_data;
294
295 ctx_data->context = lws_create_context(info: &ctx_data->creation_info);
296 if(ctx_data->context == nullptr)
297 {
298 return -1;
299 }
300 contexts_map[ctx_data->context] = ctx_data;
301 ctx_data->recv_buffer.Init();
302 return first_free;
303}
304
305void websocket_destroy(int socket)
306{
307 lws_context *context = websocket_context(socket);
308 lws_context_destroy(context);
309 contexts_map.erase(x: context);
310 contexts[socket].context = nullptr;
311}
312
313int websocket_recv(int socket, unsigned char *data, size_t maxsize, NETADDR *addr)
314{
315 lws_context *context = websocket_context(socket);
316 const int service_result = lws_service(context, timeout_ms: -1);
317 if(service_result < 0)
318 {
319 return service_result;
320 }
321
322 context_data *ctx_data = contexts_map[context];
323 websocket_chunk *chunk = ctx_data->recv_buffer.First();
324 if(chunk == nullptr)
325 {
326 return 0;
327 }
328
329 if(maxsize >= chunk->size - chunk->read)
330 {
331 const int len = chunk->size - chunk->read;
332 mem_copy(dest: data, source: &chunk->data[chunk->read], size: len);
333 *addr = chunk->addr;
334 ctx_data->recv_buffer.PopFirst();
335 return len;
336 }
337 else
338 {
339 mem_copy(dest: data, source: &chunk->data[chunk->read], size: maxsize);
340 *addr = chunk->addr;
341 chunk->read += maxsize;
342 return maxsize;
343 }
344}
345
346int websocket_send(int socket, const unsigned char *data, size_t size, const NETADDR *addr)
347{
348 lws_context *context = websocket_context(socket);
349 context_data *ctx_data = contexts_map[context];
350 per_session_data *pss = ctx_data->port_map[*addr];
351 if(pss == nullptr)
352 {
353 char addr_str[NETADDR_MAXSTRSIZE];
354 net_addr_str(addr, string: addr_str, max_length: sizeof(addr_str), add_port: false);
355 lws_client_connect_info ccinfo = {.context: 0};
356 ccinfo.context = context;
357 ccinfo.address = addr_str;
358 ccinfo.port = addr->port;
359 ccinfo.protocol = protocols[0].name;
360 lws *wsi = lws_client_connect_via_info(ccinfo: &ccinfo);
361 if(wsi == nullptr)
362 {
363 return -1;
364 }
365 lws_service(context, timeout_ms: -1);
366 pss = ctx_data->port_map[*addr];
367 if(pss == nullptr)
368 {
369 return -1;
370 }
371 }
372
373 const size_t chunk_size = size + sizeof(websocket_chunk) + LWS_SEND_BUFFER_PRE_PADDING + LWS_SEND_BUFFER_POST_PADDING;
374 websocket_chunk *chunk = pss->send_buffer.Allocate(Size: chunk_size);
375 dbg_assert(chunk != nullptr, "failed to allocate websocket send buffer chunk of size %" PRIzu, size);
376 mem_zero(block: chunk, size: chunk_size);
377 chunk->size = size;
378 chunk->read = 0;
379 chunk->addr = pss->addr;
380 mem_copy(dest: &chunk->data[LWS_SEND_BUFFER_PRE_PADDING], source: data, size);
381 lws_callback_on_writable(wsi: pss->wsi);
382 lws_service(context, timeout_ms: -1);
383 return size;
384}
385
386int websocket_fd_set(int socket, fd_set *set)
387{
388 lws_context *context = websocket_context(socket);
389 lws_service(context, timeout_ms: -1);
390
391 context_data *ctx_data = contexts_map[context];
392 int max = 0;
393 for(const auto &[_, pss] : ctx_data->port_map)
394 {
395 if(pss == nullptr)
396 {
397 continue;
398 }
399 int fd = lws_get_socket_fd(wsi: pss->wsi);
400 max = std::max(a: fd, b: max);
401 FD_SET(fd, set);
402 }
403 return max;
404}
405
406int websocket_fd_get(int socket, fd_set *set)
407{
408 lws_context *context = websocket_context(socket);
409 lws_service(context, timeout_ms: -1);
410
411 context_data *ctx_data = contexts_map[context];
412 for(const auto &[_, pss] : ctx_data->port_map)
413 {
414 if(pss == nullptr)
415 {
416 continue;
417 }
418 if(FD_ISSET(lws_get_socket_fd(pss->wsi), set))
419 {
420 return 1;
421 }
422 }
423 return 0;
424}
425
426#endif
427