1#if defined(CONF_WEBSOCKETS)
2
3#include "websockets.h"
4
5#include <base/log.h>
6#include <base/system.h>
7
8#include <engine/shared/network.h>
9#include <engine/shared/protocol.h>
10#include <engine/shared/ringbuffer.h>
11
12#if defined(CONF_FAMILY_UNIX)
13#include <arpa/inet.h>
14#elif defined(CONF_FAMILY_WINDOWS)
15#include <ws2tcpip.h>
16#endif
17#include <libwebsockets.h>
18
19#include <cstdlib>
20#include <map>
21#include <string>
22
/**
 * Header of one variable-length entry in the send/receive ring buffers.
 * Payload bytes follow the header directly inside the same ring-buffer
 * allocation (callers allocate `sizeof(websocket_chunk) + payload`).
 * For send chunks, the payload starts at data[LWS_SEND_BUFFER_PRE_PADDING]
 * so lws_write has its required headroom in front of it.
 */
struct websocket_chunk
{
	// Number of payload bytes stored in the chunk (excluding any lws padding).
	size_t size;
	// Number of payload bytes already consumed (recv side) or written (send side).
	size_t read;
	// Peer address the payload came from / is destined to.
	NETADDR addr;
	// Zero-length trailing array (compiler extension) marking the start of the payload.
	unsigned char data[0];
};
30
31// Client opens two connections for whatever reason
32typedef CStaticRingBuffer<websocket_chunk, (MAX_CLIENTS * 2) * NET_CONN_BUFFERSIZE,
33 CRingBufferBase::FLAG_RECYCLE>
34 TRecvBuffer;
35typedef CStaticRingBuffer<websocket_chunk, NET_CONN_BUFFERSIZE,
36 CRingBufferBase::FLAG_RECYCLE>
37 TSendBuffer;
38
// Per-connection state. lws allocates it (per_session_data_size in the protocol
// table) and passes it to websocket_protocol_callback as `user`.
struct per_session_data
{
	// Connection handle; set when the session is registered, reset to nullptr on close.
	lws *wsi;
	// Peer address; also this session's key in context_data::port_map.
	NETADDR addr;
	// Outgoing chunks queued by websocket_send and drained in the WRITEABLE callbacks.
	TSendBuffer send_buffer;
};
45
// State of one websocket "socket", i.e. one lws context, stored in a slot of `contexts`.
struct context_data
{
	// Bind address handed to lws via creation_info.iface. It is kept here so that
	// pointer stays valid for as long as lws may read it.
	char bindaddr_str[NETADDR_MAXSTRSIZE];
	// Parameters used to create `context`; creation_info.user points back at this object.
	lws_context_creation_info creation_info;
	// The lws context; nullptr while this slot is unused.
	lws_context *context;
	// Registered sessions, keyed by peer address.
	std::map<NETADDR, per_session_data *> port_map;
	// Packets received from all sessions of this context, consumed by websocket_recv.
	TRecvBuffer recv_buffer;
};
54
// Fixed pool of context slots; the slot index is the "socket" handed to callers.
// Client has main, dummy and contact connections with IPv4 and IPv6 (3 * 2).
static context_data contexts[3 * 2];
// Reverse lookup from an lws context to the slot that owns it.
static std::map<lws_context *, context_data *> contexts_map;
58
59static lws_context *websocket_context(int socket)
60{
61 dbg_assert(socket >= 0 && socket < (int)std::size(contexts), "socket index invalid: %d", socket);
62 lws_context *context = contexts[socket].context;
63 dbg_assert(context != nullptr, "socket context not initialized: %d", socket);
64 return context;
65}
66
67static int receive_chunk(context_data *ctx_data, per_session_data *pss, const void *in, size_t len)
68{
69 websocket_chunk *chunk = ctx_data->recv_buffer.Allocate(Size: len + sizeof(websocket_chunk));
70 if(chunk == nullptr)
71 {
72 return 1;
73 }
74
75 chunk->size = len;
76 chunk->read = 0;
77 chunk->addr = pss->addr;
78 mem_copy(dest: &chunk->data[0], source: in, size: len);
79 return 0;
80}
81
82static void sockaddr_to_netaddr_websocket(const sockaddr *src, socklen_t src_len, NETADDR *dst)
83{
84 *dst = NETADDR_ZEROED;
85 if(src->sa_family == AF_INET && src_len >= (socklen_t)sizeof(sockaddr_in))
86 {
87 const sockaddr_in *src_in = (const sockaddr_in *)src;
88 dst->type = NETTYPE_WEBSOCKET_IPV4;
89 dst->port = htons(hostshort: src_in->sin_port);
90 static_assert(sizeof(dst->ip) >= sizeof(src_in->sin_addr.s_addr));
91 mem_copy(dest: dst->ip, source: &src_in->sin_addr.s_addr, size: sizeof(src_in->sin_addr.s_addr));
92 }
93 else if(src->sa_family == AF_INET6 && src_len >= (socklen_t)sizeof(sockaddr_in6))
94 {
95 const sockaddr_in6 *src_in6 = (const sockaddr_in6 *)src;
96 dst->type = NETTYPE_WEBSOCKET_IPV6;
97 dst->port = htons(hostshort: src_in6->sin6_port);
98 static_assert(sizeof(dst->ip) >= sizeof(src_in6->sin6_addr.s6_addr));
99 mem_copy(dest: dst->ip, source: &src_in6->sin6_addr.s6_addr, size: sizeof(src_in6->sin6_addr.s6_addr));
100 }
101 else
102 {
103 log_warn("websockets", "Cannot convert sockaddr of family %d", src->sa_family);
104 }
105}
106
107static int websocket_protocol_callback(lws *wsi, enum lws_callback_reasons reason, void *user, void *in, size_t len)
108{
109 per_session_data *pss = (per_session_data *)user;
110 lws_context *context = lws_get_context(wsi);
111 context_data *ctx_data = contexts_map[context];
112 switch(reason)
113 {
114 case LWS_CALLBACK_WSI_CREATE:
115 if(pss == nullptr)
116 {
117 return 0;
118 }
119 [[fallthrough]];
120 case LWS_CALLBACK_ESTABLISHED:
121 {
122 sockaddr_storage peersockaddr;
123 socklen_t peersockaddr_size = sizeof(peersockaddr);
124 getpeername(fd: lws_get_socket_fd(wsi), addr: (sockaddr *)&peersockaddr, len: &peersockaddr_size);
125 NETADDR addr;
126 sockaddr_to_netaddr_websocket(src: (sockaddr *)&peersockaddr, src_len: peersockaddr_size, dst: &addr);
127 if(addr.type == NETTYPE_INVALID)
128 {
129 return 0;
130 }
131
132 pss->wsi = wsi;
133 pss->addr = addr;
134 pss->send_buffer.Init();
135 ctx_data->port_map[addr] = pss;
136
137 char addr_str[NETADDR_MAXSTRSIZE];
138 net_addr_str(addr: &addr, string: addr_str, max_length: sizeof(addr_str), add_port: true);
139 log_trace("websockets", "Connection established with '%s'", addr_str);
140 return 0;
141 }
142
143 case LWS_CALLBACK_CLOSED:
144 {
145 char addr_str[NETADDR_MAXSTRSIZE];
146 net_addr_str(addr: &pss->addr, string: addr_str, max_length: sizeof(addr_str), add_port: true);
147 log_trace("websockets", "Connection closed with '%s'", addr_str);
148
149 static const unsigned char CLOSE_PACKET[] = {0x10, 0x0e, 0x00, 0x04};
150 receive_chunk(ctx_data, pss, in: &CLOSE_PACKET, len: sizeof(CLOSE_PACKET));
151 pss->wsi = nullptr;
152 ctx_data->port_map.erase(x: pss->addr);
153 return 0;
154 }
155
156 case LWS_CALLBACK_CLIENT_WRITEABLE:
157 [[fallthrough]];
158 case LWS_CALLBACK_SERVER_WRITEABLE:
159 {
160 websocket_chunk *chunk = pss->send_buffer.First();
161 if(chunk == nullptr)
162 {
163 return 0;
164 }
165
166 int chunk_len = chunk->size - chunk->read;
167 int n = lws_write(wsi, buf: &chunk->data[LWS_SEND_BUFFER_PRE_PADDING + chunk->read], len: chunk->size - chunk->read, protocol: LWS_WRITE_BINARY);
168 if(n < 0)
169 {
170 return 1;
171 }
172
173 if(n < chunk_len)
174 {
175 chunk->read += n;
176 lws_callback_on_writable(wsi);
177 return 0;
178 }
179
180 pss->send_buffer.PopFirst();
181 lws_callback_on_writable(wsi);
182 return 0;
183 }
184
185 case LWS_CALLBACK_CLIENT_RECEIVE:
186 [[fallthrough]];
187 case LWS_CALLBACK_RECEIVE:
188 return receive_chunk(ctx_data, pss, in, len);
189
190 default:
191 return 0;
192 }
193}
194
// Subprotocols registered with every context. Both entries use the same callback
// and per-session data size; this file handles "base64" exactly like "binary".
// The trailing null entry terminates the array for lws.
static const lws_protocols protocols[] = {
	{"binary", websocket_protocol_callback, sizeof(per_session_data)},
	{"base64", websocket_protocol_callback, sizeof(per_session_data)},
	{nullptr, nullptr, 0}};
199
200static LEVEL websocket_level_to_loglevel(int level)
201{
202 switch(level)
203 {
204 case LLL_ERR:
205 return LEVEL_ERROR;
206 case LLL_WARN:
207 return LEVEL_WARN;
208 case LLL_NOTICE:
209 case LLL_INFO:
210 return LEVEL_DEBUG;
211 default:
212 dbg_assert_failed("invalid log level: %d", level);
213 }
214}
215
216static void websocket_log_callback(int level, const char *line)
217{
218 // Truncate duplicate timestamp from beginning and newline from end
219 char line_truncated[4096]; // Longest log line length
220 const char *line_time_end = str_find(haystack: line, needle: "] ");
221 dbg_assert(line_time_end != nullptr, "unexpected log format");
222 str_copy(dst&: line_truncated, src: line_time_end + 2);
223 const int length = str_length(str: line_truncated);
224 if(line_truncated[length - 1] == '\n')
225 {
226 line_truncated[length - 1] = '\0';
227 }
228 if(line_truncated[length - 2] == '\r')
229 {
230 line_truncated[length - 2] = '\0';
231 }
232 log_log(level: websocket_level_to_loglevel(level), sys: "websockets", fmt: "%s", line_truncated);
233}
234
235void websocket_init()
236{
237 lws_set_log_level(LLL_ERR | LLL_WARN | LLL_NOTICE | LLL_INFO, log_emit_function: websocket_log_callback);
238}
239
240int websocket_create(const NETADDR *bindaddr)
241{
242 // find free context
243 int first_free = -1;
244 for(int i = 0; i < (int)std::size(contexts); i++)
245 {
246 if(contexts[i].context == nullptr)
247 {
248 first_free = i;
249 break;
250 }
251 }
252 if(first_free == -1)
253 {
254 log_error("websockets", "Failed to create websocket: no free contexts available");
255 return -1;
256 }
257
258 context_data *ctx_data = &contexts[first_free];
259 mem_zero(block: &ctx_data->creation_info, size: sizeof(ctx_data->creation_info));
260 ctx_data->creation_info.options = LWS_SERVER_OPTION_FAIL_UPON_UNABLE_TO_BIND;
261 if(bindaddr->type == NETTYPE_WEBSOCKET_IPV6)
262 {
263 // Set IPv6-only mode and socket option for IPv6 Websockets.
264 ctx_data->creation_info.options |= LWS_SERVER_OPTION_IPV6_V6ONLY_VALUE | LWS_SERVER_OPTION_IPV6_V6ONLY_MODIFY;
265 }
266 net_addr_str(addr: bindaddr, string: ctx_data->bindaddr_str, max_length: sizeof(ctx_data->bindaddr_str), add_port: false);
267 if(ctx_data->bindaddr_str[0] == '[' && ctx_data->bindaddr_str[str_length(str: ctx_data->bindaddr_str) - 1] == ']')
268 {
269 // Bindaddr must not be enclosed in brackets for IPv6 Websockets.
270 ctx_data->bindaddr_str[str_length(str: ctx_data->bindaddr_str) - 1] = '\0';
271 mem_move(dest: &ctx_data->bindaddr_str[0], source: &ctx_data->bindaddr_str[1], size: str_length(str: ctx_data->bindaddr_str) + 1);
272 }
273 ctx_data->creation_info.iface = ctx_data->bindaddr_str;
274 ctx_data->creation_info.port = bindaddr->port;
275 ctx_data->creation_info.protocols = protocols;
276 ctx_data->creation_info.gid = -1;
277 ctx_data->creation_info.uid = -1;
278 ctx_data->creation_info.user = ctx_data;
279
280 ctx_data->context = lws_create_context(info: &ctx_data->creation_info);
281 if(ctx_data->context == nullptr)
282 {
283 return -1;
284 }
285 contexts_map[ctx_data->context] = ctx_data;
286 ctx_data->recv_buffer.Init();
287 return first_free;
288}
289
290void websocket_destroy(int socket)
291{
292 lws_context *context = websocket_context(socket);
293 lws_context_destroy(context);
294 contexts_map.erase(x: context);
295 contexts[socket].context = nullptr;
296}
297
298int websocket_recv(int socket, unsigned char *data, size_t maxsize, NETADDR *addr)
299{
300 lws_context *context = websocket_context(socket);
301 const int service_result = lws_service(context, timeout_ms: -1);
302 if(service_result < 0)
303 {
304 return service_result;
305 }
306
307 context_data *ctx_data = contexts_map[context];
308 websocket_chunk *chunk = ctx_data->recv_buffer.First();
309 if(chunk == nullptr)
310 {
311 return 0;
312 }
313
314 if(maxsize >= chunk->size - chunk->read)
315 {
316 const int len = chunk->size - chunk->read;
317 mem_copy(dest: data, source: &chunk->data[chunk->read], size: len);
318 *addr = chunk->addr;
319 ctx_data->recv_buffer.PopFirst();
320 return len;
321 }
322 else
323 {
324 mem_copy(dest: data, source: &chunk->data[chunk->read], size: maxsize);
325 *addr = chunk->addr;
326 chunk->read += maxsize;
327 return maxsize;
328 }
329}
330
331int websocket_send(int socket, const unsigned char *data, size_t size, const NETADDR *addr)
332{
333 lws_context *context = websocket_context(socket);
334 context_data *ctx_data = contexts_map[context];
335 per_session_data *pss = ctx_data->port_map[*addr];
336 if(pss == nullptr)
337 {
338 char addr_str[NETADDR_MAXSTRSIZE];
339 net_addr_str(addr, string: addr_str, max_length: sizeof(addr_str), add_port: false);
340 lws_client_connect_info ccinfo = {.context: 0};
341 ccinfo.context = context;
342 ccinfo.address = addr_str;
343 ccinfo.port = addr->port;
344 ccinfo.protocol = protocols[0].name;
345 lws *wsi = lws_client_connect_via_info(ccinfo: &ccinfo);
346 if(wsi == nullptr)
347 {
348 return -1;
349 }
350 lws_service(context, timeout_ms: -1);
351 pss = ctx_data->port_map[*addr];
352 if(pss == nullptr)
353 {
354 return -1;
355 }
356 }
357
358 const size_t chunk_size = size + sizeof(websocket_chunk) + LWS_SEND_BUFFER_PRE_PADDING + LWS_SEND_BUFFER_POST_PADDING;
359 websocket_chunk *chunk = pss->send_buffer.Allocate(Size: chunk_size);
360 mem_zero(block: chunk, size: chunk_size);
361 if(chunk == nullptr)
362 {
363 return -1;
364 }
365
366 chunk->size = size;
367 chunk->read = 0;
368 chunk->addr = pss->addr;
369 mem_copy(dest: &chunk->data[LWS_SEND_BUFFER_PRE_PADDING], source: data, size);
370 lws_callback_on_writable(wsi: pss->wsi);
371 lws_service(context, timeout_ms: -1);
372 return size;
373}
374
375int websocket_fd_set(int socket, fd_set *set)
376{
377 lws_context *context = websocket_context(socket);
378 lws_service(context, timeout_ms: -1);
379
380 context_data *ctx_data = contexts_map[context];
381 int max = 0;
382 for(const auto &[_, pss] : ctx_data->port_map)
383 {
384 if(pss == nullptr)
385 {
386 continue;
387 }
388 int fd = lws_get_socket_fd(wsi: pss->wsi);
389 max = std::max(a: fd, b: max);
390 FD_SET(fd, set);
391 }
392 return max;
393}
394
395int websocket_fd_get(int socket, fd_set *set)
396{
397 lws_context *context = websocket_context(socket);
398 lws_service(context, timeout_ms: -1);
399
400 context_data *ctx_data = contexts_map[context];
401 for(const auto &[_, pss] : ctx_data->port_map)
402 {
403 if(pss == nullptr)
404 {
405 continue;
406 }
407 if(FD_ISSET(lws_get_socket_fd(pss->wsi), set))
408 {
409 return 1;
410 }
411 }
412 return 0;
413}
414
415#endif
416