1 | #if defined(CONF_WEBSOCKETS) |
2 | |
3 | #include <cstdlib> |
4 | #include <map> |
5 | #include <string> |
6 | |
7 | #include "protocol.h" |
8 | #include "ringbuffer.h" |
9 | #include <base/system.h> |
10 | #if defined(CONF_FAMILY_UNIX) |
11 | #include <arpa/inet.h> |
12 | #elif defined(CONF_FAMILY_WINDOWS) |
13 | #include <ws2tcpip.h> |
14 | #endif |
15 | #include <libwebsockets.h> |
16 | |
17 | #include "websockets.h" |
18 | |
19 | // not sure why would anyone need more than one but well... |
20 | #define WS_CONTEXTS 4 |
21 | // ddnet client opens two connections for whatever reason |
22 | #define WS_CLIENTS (MAX_CLIENTS * 2) |
23 | |
24 | typedef CStaticRingBuffer<unsigned char, WS_CLIENTS * 4 * 1024, |
25 | CRingBufferBase::FLAG_RECYCLE> |
26 | TRecvBuffer; |
27 | typedef CStaticRingBuffer<unsigned char, 4 * 1024, |
28 | CRingBufferBase::FLAG_RECYCLE> |
29 | TSendBuffer; |
30 | |
// Header of one message stored in a ring buffer. The header is placed at the
// start of a raw ring-buffer allocation and the message bytes follow it in the
// same allocation, so `data` must remain the last member.
// For chunks in a session's send buffer, `data` additionally starts with
// LWS_SEND_BUFFER_PRE_PADDING bytes reserved for libwebsockets.
struct websocket_chunk
{
	size_t size; // number of payload bytes
	size_t read; // payload bytes already consumed (handed to the reader or written to the socket)
	sockaddr_in addr; // peer address of the session this chunk belongs to
	unsigned char data[0]; // payload bytes (zero-length trailing array, a compiler extension)
};
38 | |
39 | struct per_session_data |
40 | { |
41 | struct lws *wsi; |
42 | std::string addr_str; |
43 | sockaddr_in addr; |
44 | TSendBuffer send_buffer; |
45 | }; |
46 | |
47 | struct context_data |
48 | { |
49 | lws_context *context; |
50 | std::map<std::string, per_session_data *> port_map; |
51 | TRecvBuffer recv_buffer; |
52 | }; |
53 | |
54 | static int receive_chunk(context_data *ctx_data, struct per_session_data *pss, |
55 | void *in, size_t len) |
56 | { |
57 | websocket_chunk *chunk = (websocket_chunk *)ctx_data->recv_buffer.Allocate( |
58 | Size: len + sizeof(websocket_chunk)); |
59 | if(chunk == 0) |
60 | return 1; |
61 | chunk->size = len; |
62 | chunk->read = 0; |
63 | mem_copy(dest: &chunk->addr, source: &pss->addr, size: sizeof(sockaddr_in)); |
64 | mem_copy(dest: &chunk->data[0], source: in, size: len); |
65 | return 0; |
66 | } |
67 | |
68 | static int websocket_callback(struct lws *wsi, enum lws_callback_reasons reason, |
69 | void *user, void *in, size_t len) |
70 | { |
71 | struct per_session_data *pss = (struct per_session_data *)user; |
72 | lws_context *context = lws_get_context(wsi); |
73 | context_data *ctx_data = (context_data *)lws_context_user(context); |
74 | switch(reason) |
75 | { |
76 | case LWS_CALLBACK_WSI_CREATE: |
77 | if(pss == NULL) |
78 | { |
79 | return 0; |
80 | } |
81 | [[fallthrough]]; |
82 | case LWS_CALLBACK_ESTABLISHED: |
83 | { |
84 | pss->wsi = wsi; |
85 | int fd = lws_get_socket_fd(wsi); |
86 | socklen_t addr_size = sizeof(pss->addr); |
87 | getpeername(fd: fd, addr: (struct sockaddr *)&pss->addr, len: &addr_size); |
88 | int orig_port = ntohs(netshort: pss->addr.sin_port); |
89 | pss->send_buffer.Init(); |
90 | |
91 | char addr_str[NETADDR_MAXSTRSIZE]; |
92 | int ip_uint32 = pss->addr.sin_addr.s_addr; |
93 | str_format(buffer: addr_str, buffer_size: sizeof(addr_str), format: "%d.%d.%d.%d:%d" , (ip_uint32)&0xff, (ip_uint32 >> 8) & 0xff, (ip_uint32 >> 16) & 0xff, (ip_uint32 >> 24) & 0xff, orig_port); |
94 | |
95 | dbg_msg(sys: "websockets" , fmt: "connection established with %s" , addr_str); |
96 | |
97 | std::string addr_str_final; |
98 | addr_str_final.append(s: addr_str); |
99 | |
100 | pss->addr_str = addr_str_final; |
101 | ctx_data->port_map[pss->addr_str] = pss; |
102 | } |
103 | break; |
104 | |
105 | case LWS_CALLBACK_CLOSED: |
106 | { |
107 | dbg_msg(sys: "websockets" , fmt: "connection with addr string %s closed" , pss->addr_str.c_str()); |
108 | if(!pss->addr_str.empty()) |
109 | { |
110 | unsigned char close_packet[] = {0x10, 0x0e, 0x00, 0x04}; |
111 | receive_chunk(ctx_data, pss, in: &close_packet, len: sizeof(close_packet)); |
112 | pss->wsi = 0; |
113 | ctx_data->port_map.erase(x: pss->addr_str); |
114 | } |
115 | } |
116 | break; |
117 | |
118 | case LWS_CALLBACK_CLIENT_WRITEABLE: |
119 | [[fallthrough]]; |
120 | case LWS_CALLBACK_SERVER_WRITEABLE: |
121 | { |
122 | websocket_chunk *chunk = (websocket_chunk *)pss->send_buffer.First(); |
123 | if(chunk == NULL) |
124 | break; |
125 | int chunk_len = chunk->size - chunk->read; |
126 | int n = |
127 | lws_write(wsi, buf: &chunk->data[LWS_SEND_BUFFER_PRE_PADDING + chunk->read], |
128 | len: chunk->size - chunk->read, protocol: LWS_WRITE_BINARY); |
129 | if(n < 0) |
130 | return 1; |
131 | if(n < chunk_len) |
132 | { |
133 | chunk->read += n; |
134 | lws_callback_on_writable(wsi); |
135 | break; |
136 | } |
137 | pss->send_buffer.PopFirst(); |
138 | lws_callback_on_writable(wsi); |
139 | } |
140 | break; |
141 | |
142 | case LWS_CALLBACK_CLIENT_RECEIVE: |
143 | [[fallthrough]]; |
144 | case LWS_CALLBACK_RECEIVE: |
145 | if(pss->addr_str.empty()) |
146 | return -1; |
147 | if(receive_chunk(ctx_data, pss, in, len)) |
148 | return 1; |
149 | break; |
150 | |
151 | default: |
152 | break; |
153 | } |
154 | |
155 | return 0; |
156 | } |
157 | |
158 | static struct lws_protocols protocols[] = { |
159 | { |
160 | .name: "binary" , /* name */ |
161 | .callback: websocket_callback, /* callback */ |
162 | .per_session_data_size: sizeof(struct per_session_data) /* per_session_data_size */ |
163 | }, |
164 | { |
165 | NULL, NULL, .per_session_data_size: 0 /* End of list */ |
166 | }}; |
167 | |
168 | static context_data contexts[WS_CONTEXTS]; |
169 | |
170 | int websocket_create(const char *addr, int port) |
171 | { |
172 | struct lws_context_creation_info info; |
173 | mem_zero(block: &info, size: sizeof(info)); |
174 | info.port = port; |
175 | info.iface = addr; |
176 | info.protocols = protocols; |
177 | info.gid = -1; |
178 | info.uid = -1; |
179 | |
180 | // find free context |
181 | int first_free = -1; |
182 | for(int i = 0; i < WS_CONTEXTS; i++) |
183 | { |
184 | if(contexts[i].context == NULL) |
185 | { |
186 | first_free = i; |
187 | break; |
188 | } |
189 | } |
190 | if(first_free == -1) |
191 | return -1; |
192 | |
193 | context_data *ctx_data = &contexts[first_free]; |
194 | info.user = (void *)ctx_data; |
195 | |
196 | ctx_data->context = lws_create_context(info: &info); |
197 | if(ctx_data->context == NULL) |
198 | { |
199 | return -1; |
200 | } |
201 | ctx_data->recv_buffer.Init(); |
202 | return first_free; |
203 | } |
204 | |
205 | int websocket_destroy(int socket) |
206 | { |
207 | lws_context *context = contexts[socket].context; |
208 | if(context == NULL) |
209 | return -1; |
210 | lws_context_destroy(context); |
211 | contexts[socket].context = NULL; |
212 | return 0; |
213 | } |
214 | |
215 | int websocket_recv(int socket, unsigned char *data, size_t maxsize, |
216 | struct sockaddr_in *sockaddrbuf, size_t fromLen) |
217 | { |
218 | lws_context *context = contexts[socket].context; |
219 | if(context == NULL) |
220 | return -1; |
221 | int n = lws_service(context, timeout_ms: -1); |
222 | if(n < 0) |
223 | return n; |
224 | context_data *ctx_data = (context_data *)lws_context_user(context); |
225 | websocket_chunk *chunk = (websocket_chunk *)ctx_data->recv_buffer.First(); |
226 | if(chunk == 0) |
227 | return 0; |
228 | if(maxsize >= chunk->size - chunk->read) |
229 | { |
230 | int len = chunk->size - chunk->read; |
231 | mem_copy(dest: data, source: &chunk->data[chunk->read], size: len); |
232 | mem_copy(dest: sockaddrbuf, source: &chunk->addr, size: fromLen); |
233 | ctx_data->recv_buffer.PopFirst(); |
234 | return len; |
235 | } |
236 | else |
237 | { |
238 | mem_copy(dest: data, source: &chunk->data[chunk->read], size: maxsize); |
239 | mem_copy(dest: sockaddrbuf, source: &chunk->addr, size: fromLen); |
240 | chunk->read += maxsize; |
241 | return maxsize; |
242 | } |
243 | } |
244 | |
245 | int websocket_send(int socket, const unsigned char *data, size_t size, |
246 | const char *addr_str, int port) |
247 | { |
248 | lws_context *context = contexts[socket].context; |
249 | if(context == NULL) |
250 | { |
251 | return -1; |
252 | } |
253 | context_data *ctx_data = (context_data *)lws_context_user(context); |
254 | char aBuf[100]; |
255 | snprintf(s: aBuf, maxlen: sizeof(aBuf), format: "%s:%d" , addr_str, port); |
256 | std::string addr_str_with_port = std::string(aBuf); |
257 | struct per_session_data *pss = ctx_data->port_map[addr_str_with_port]; |
258 | if(pss == NULL) |
259 | { |
260 | struct lws_client_connect_info ccinfo = {.context: 0}; |
261 | ccinfo.context = context; |
262 | ccinfo.address = addr_str; |
263 | ccinfo.port = port; |
264 | ccinfo.protocol = protocols[0].name; |
265 | lws *wsi = lws_client_connect_via_info(ccinfo: &ccinfo); |
266 | if(wsi == NULL) |
267 | { |
268 | return -1; |
269 | } |
270 | lws_service(context, timeout_ms: -1); |
271 | pss = ctx_data->port_map[addr_str_with_port]; |
272 | if(pss == NULL) |
273 | { |
274 | return -1; |
275 | } |
276 | } |
277 | websocket_chunk *chunk = (websocket_chunk *)pss->send_buffer.Allocate( |
278 | Size: size + sizeof(websocket_chunk) + LWS_SEND_BUFFER_PRE_PADDING + |
279 | LWS_SEND_BUFFER_POST_PADDING); |
280 | if(chunk == NULL) |
281 | return -1; |
282 | chunk->size = size; |
283 | chunk->read = 0; |
284 | mem_copy(dest: &chunk->addr, source: &pss->addr, size: sizeof(sockaddr_in)); |
285 | mem_copy(dest: &chunk->data[LWS_SEND_BUFFER_PRE_PADDING], source: data, size); |
286 | lws_callback_on_writable(wsi: pss->wsi); |
287 | lws_service(context, timeout_ms: -1); |
288 | return size; |
289 | } |
290 | |
291 | int websocket_fd_set(int socket, fd_set *set) |
292 | { |
293 | lws_context *context = contexts[socket].context; |
294 | if(context == NULL) |
295 | return -1; |
296 | lws_service(context, timeout_ms: -1); |
297 | context_data *ctx_data = (context_data *)lws_context_user(context); |
298 | int max = 0; |
299 | for(auto const &x : ctx_data->port_map) |
300 | { |
301 | if(x.second == NULL) |
302 | continue; |
303 | int fd = lws_get_socket_fd(wsi: x.second->wsi); |
304 | if(fd > max) |
305 | max = fd; |
306 | FD_SET(fd, set); |
307 | } |
308 | return max; |
309 | } |
310 | |
311 | #endif |
312 | |