1#if defined(CONF_WEBSOCKETS)
2
3#include <cstdlib>
4#include <map>
5#include <string>
6
7#include "protocol.h"
8#include "ringbuffer.h"
9#include <base/system.h>
10#if defined(CONF_FAMILY_UNIX)
11#include <arpa/inet.h>
12#elif defined(CONF_FAMILY_WINDOWS)
13#include <ws2tcpip.h>
14#endif
15#include <libwebsockets.h>
16
17#include "websockets.h"
18
19// not sure why would anyone need more than one but well...
20#define WS_CONTEXTS 4
21// ddnet client opens two connections for whatever reason
22#define WS_CLIENTS (MAX_CLIENTS * 2)
23
24typedef CStaticRingBuffer<unsigned char, WS_CLIENTS * 4 * 1024,
25 CRingBufferBase::FLAG_RECYCLE>
26 TRecvBuffer;
27typedef CStaticRingBuffer<unsigned char, 4 * 1024,
28 CRingBufferBase::FLAG_RECYCLE>
29 TSendBuffer;
30
// Header of one message stored in a ring buffer (a context's recv_buffer or a
// session's send_buffer). The payload bytes follow the header in `data`.
struct websocket_chunk
{
	// Payload length in bytes.
	size_t size;
	// Payload bytes already consumed: handed out by websocket_recv() for
	// received chunks, or written by lws_write() for outgoing chunks.
	size_t read;
	// Peer address, copied from the owning session's per_session_data.
	sockaddr_in addr;
	// Payload. Outgoing chunks place their payload at
	// data[LWS_SEND_BUFFER_PRE_PADDING] so lws_write() gets the headroom it
	// requires in front of the buffer.
	// NOTE(review): a zero-length array is a compiler extension, not standard C++.
	unsigned char data[0];
};
38
// Per-connection state. lws reserves per_session_data_size bytes for each
// connection (see `protocols`) and passes them as `user` to websocket_callback.
// NOTE(review): lws is a C library and presumably allocates this as raw zeroed
// memory, so constructors/destructors of the C++ members do not run — confirm.
struct per_session_data
{
	// lws handle of this connection.
	struct lws *wsi;
	// "a.b.c.d:port" of the peer; this session's key in context_data::port_map.
	std::string addr_str;
	// Peer address as filled in by getpeername().
	sockaddr_in addr;
	// Outgoing chunks waiting for the connection to become writable.
	TSendBuffer send_buffer;
};
46
// State of one websocket context, i.e. one slot of the `contexts` array.
struct context_data
{
	// lws context handle; NULL while the slot is unused.
	lws_context *context;
	// Sessions keyed by "a.b.c.d:port". Values may be NULL; readers skip them.
	std::map<std::string, per_session_data *> port_map;
	// Received chunks of all sessions, queued by receive_chunk() and read via
	// First()/PopFirst() in websocket_recv().
	TRecvBuffer recv_buffer;
};
53
54static int receive_chunk(context_data *ctx_data, struct per_session_data *pss,
55 void *in, size_t len)
56{
57 websocket_chunk *chunk = (websocket_chunk *)ctx_data->recv_buffer.Allocate(
58 Size: len + sizeof(websocket_chunk));
59 if(chunk == 0)
60 return 1;
61 chunk->size = len;
62 chunk->read = 0;
63 mem_copy(dest: &chunk->addr, source: &pss->addr, size: sizeof(sockaddr_in));
64 mem_copy(dest: &chunk->data[0], source: in, size: len);
65 return 0;
66}
67
68static int websocket_callback(struct lws *wsi, enum lws_callback_reasons reason,
69 void *user, void *in, size_t len)
70{
71 struct per_session_data *pss = (struct per_session_data *)user;
72 lws_context *context = lws_get_context(wsi);
73 context_data *ctx_data = (context_data *)lws_context_user(context);
74 switch(reason)
75 {
76 case LWS_CALLBACK_WSI_CREATE:
77 if(pss == NULL)
78 {
79 return 0;
80 }
81 [[fallthrough]];
82 case LWS_CALLBACK_ESTABLISHED:
83 {
84 pss->wsi = wsi;
85 int fd = lws_get_socket_fd(wsi);
86 socklen_t addr_size = sizeof(pss->addr);
87 getpeername(fd: fd, addr: (struct sockaddr *)&pss->addr, len: &addr_size);
88 int orig_port = ntohs(netshort: pss->addr.sin_port);
89 pss->send_buffer.Init();
90
91 char addr_str[NETADDR_MAXSTRSIZE];
92 int ip_uint32 = pss->addr.sin_addr.s_addr;
93 str_format(buffer: addr_str, buffer_size: sizeof(addr_str), format: "%d.%d.%d.%d:%d", (ip_uint32)&0xff, (ip_uint32 >> 8) & 0xff, (ip_uint32 >> 16) & 0xff, (ip_uint32 >> 24) & 0xff, orig_port);
94
95 dbg_msg(sys: "websockets", fmt: "connection established with %s", addr_str);
96
97 std::string addr_str_final;
98 addr_str_final.append(s: addr_str);
99
100 pss->addr_str = addr_str_final;
101 ctx_data->port_map[pss->addr_str] = pss;
102 }
103 break;
104
105 case LWS_CALLBACK_CLOSED:
106 {
107 dbg_msg(sys: "websockets", fmt: "connection with addr string %s closed", pss->addr_str.c_str());
108 if(!pss->addr_str.empty())
109 {
110 unsigned char close_packet[] = {0x10, 0x0e, 0x00, 0x04};
111 receive_chunk(ctx_data, pss, in: &close_packet, len: sizeof(close_packet));
112 pss->wsi = 0;
113 ctx_data->port_map.erase(x: pss->addr_str);
114 }
115 }
116 break;
117
118 case LWS_CALLBACK_CLIENT_WRITEABLE:
119 [[fallthrough]];
120 case LWS_CALLBACK_SERVER_WRITEABLE:
121 {
122 websocket_chunk *chunk = (websocket_chunk *)pss->send_buffer.First();
123 if(chunk == NULL)
124 break;
125 int chunk_len = chunk->size - chunk->read;
126 int n =
127 lws_write(wsi, buf: &chunk->data[LWS_SEND_BUFFER_PRE_PADDING + chunk->read],
128 len: chunk->size - chunk->read, protocol: LWS_WRITE_BINARY);
129 if(n < 0)
130 return 1;
131 if(n < chunk_len)
132 {
133 chunk->read += n;
134 lws_callback_on_writable(wsi);
135 break;
136 }
137 pss->send_buffer.PopFirst();
138 lws_callback_on_writable(wsi);
139 }
140 break;
141
142 case LWS_CALLBACK_CLIENT_RECEIVE:
143 [[fallthrough]];
144 case LWS_CALLBACK_RECEIVE:
145 if(pss->addr_str.empty())
146 return -1;
147 if(receive_chunk(ctx_data, pss, in, len))
148 return 1;
149 break;
150
151 default:
152 break;
153 }
154
155 return 0;
156}
157
158static struct lws_protocols protocols[] = {
159 {
160 .name: "binary", /* name */
161 .callback: websocket_callback, /* callback */
162 .per_session_data_size: sizeof(struct per_session_data) /* per_session_data_size */
163 },
164 {
165 NULL, NULL, .per_session_data_size: 0 /* End of list */
166 }};
167
168static context_data contexts[WS_CONTEXTS];
169
170int websocket_create(const char *addr, int port)
171{
172 struct lws_context_creation_info info;
173 mem_zero(block: &info, size: sizeof(info));
174 info.port = port;
175 info.iface = addr;
176 info.protocols = protocols;
177 info.gid = -1;
178 info.uid = -1;
179
180 // find free context
181 int first_free = -1;
182 for(int i = 0; i < WS_CONTEXTS; i++)
183 {
184 if(contexts[i].context == NULL)
185 {
186 first_free = i;
187 break;
188 }
189 }
190 if(first_free == -1)
191 return -1;
192
193 context_data *ctx_data = &contexts[first_free];
194 info.user = (void *)ctx_data;
195
196 ctx_data->context = lws_create_context(info: &info);
197 if(ctx_data->context == NULL)
198 {
199 return -1;
200 }
201 ctx_data->recv_buffer.Init();
202 return first_free;
203}
204
205int websocket_destroy(int socket)
206{
207 lws_context *context = contexts[socket].context;
208 if(context == NULL)
209 return -1;
210 lws_context_destroy(context);
211 contexts[socket].context = NULL;
212 return 0;
213}
214
215int websocket_recv(int socket, unsigned char *data, size_t maxsize,
216 struct sockaddr_in *sockaddrbuf, size_t fromLen)
217{
218 lws_context *context = contexts[socket].context;
219 if(context == NULL)
220 return -1;
221 int n = lws_service(context, timeout_ms: -1);
222 if(n < 0)
223 return n;
224 context_data *ctx_data = (context_data *)lws_context_user(context);
225 websocket_chunk *chunk = (websocket_chunk *)ctx_data->recv_buffer.First();
226 if(chunk == 0)
227 return 0;
228 if(maxsize >= chunk->size - chunk->read)
229 {
230 int len = chunk->size - chunk->read;
231 mem_copy(dest: data, source: &chunk->data[chunk->read], size: len);
232 mem_copy(dest: sockaddrbuf, source: &chunk->addr, size: fromLen);
233 ctx_data->recv_buffer.PopFirst();
234 return len;
235 }
236 else
237 {
238 mem_copy(dest: data, source: &chunk->data[chunk->read], size: maxsize);
239 mem_copy(dest: sockaddrbuf, source: &chunk->addr, size: fromLen);
240 chunk->read += maxsize;
241 return maxsize;
242 }
243}
244
245int websocket_send(int socket, const unsigned char *data, size_t size,
246 const char *addr_str, int port)
247{
248 lws_context *context = contexts[socket].context;
249 if(context == NULL)
250 {
251 return -1;
252 }
253 context_data *ctx_data = (context_data *)lws_context_user(context);
254 char aBuf[100];
255 snprintf(s: aBuf, maxlen: sizeof(aBuf), format: "%s:%d", addr_str, port);
256 std::string addr_str_with_port = std::string(aBuf);
257 struct per_session_data *pss = ctx_data->port_map[addr_str_with_port];
258 if(pss == NULL)
259 {
260 struct lws_client_connect_info ccinfo = {.context: 0};
261 ccinfo.context = context;
262 ccinfo.address = addr_str;
263 ccinfo.port = port;
264 ccinfo.protocol = protocols[0].name;
265 lws *wsi = lws_client_connect_via_info(ccinfo: &ccinfo);
266 if(wsi == NULL)
267 {
268 return -1;
269 }
270 lws_service(context, timeout_ms: -1);
271 pss = ctx_data->port_map[addr_str_with_port];
272 if(pss == NULL)
273 {
274 return -1;
275 }
276 }
277 websocket_chunk *chunk = (websocket_chunk *)pss->send_buffer.Allocate(
278 Size: size + sizeof(websocket_chunk) + LWS_SEND_BUFFER_PRE_PADDING +
279 LWS_SEND_BUFFER_POST_PADDING);
280 if(chunk == NULL)
281 return -1;
282 chunk->size = size;
283 chunk->read = 0;
284 mem_copy(dest: &chunk->addr, source: &pss->addr, size: sizeof(sockaddr_in));
285 mem_copy(dest: &chunk->data[LWS_SEND_BUFFER_PRE_PADDING], source: data, size);
286 lws_callback_on_writable(wsi: pss->wsi);
287 lws_service(context, timeout_ms: -1);
288 return size;
289}
290
291int websocket_fd_set(int socket, fd_set *set)
292{
293 lws_context *context = contexts[socket].context;
294 if(context == NULL)
295 return -1;
296 lws_service(context, timeout_ms: -1);
297 context_data *ctx_data = (context_data *)lws_context_user(context);
298 int max = 0;
299 for(auto const &x : ctx_data->port_map)
300 {
301 if(x.second == NULL)
302 continue;
303 int fd = lws_get_socket_fd(wsi: x.second->wsi);
304 if(fd > max)
305 max = fd;
306 FD_SET(fd, set);
307 }
308 return max;
309}
310
311#endif
312