1#include "network.h"
2
3#include <base/log.h>
4#include <base/system.h>
5
6#include <algorithm>
7
8static int IndexFromNetType(int NetType)
9{
10 switch(NetType)
11 {
12 case NETTYPE_IPV6:
13 return 0;
14 case NETTYPE_IPV4:
15 return 1;
16 }
17 return -1;
18}
19
20static const char *IndexToSystem(int Index)
21{
22 switch(Index)
23 {
24 case 0:
25 return "stun/6";
26 case 1:
27 return "stun/4";
28 }
29 dbg_assert_failed("invalid index %d", Index);
30}
31
// Exponential backoff for STUN retries, in seconds: 2^n with n clamped to
// [0, 9]. That gives 1 s for counts <= 0 (including the -1 stored after a
// successful response) and caps the wait at 512 s.
static int RetryWaitSeconds(int NumUnsuccessfulTries)
{
	const int Exponent = std::min(std::max(NumUnsuccessfulTries, 0), 9);
	return 1 << Exponent;
}
36
37CStun::CProtocol::CProtocol(int Index, NETSOCKET Socket) :
38 m_Index(Index),
39 m_Socket(Socket)
40{
41 mem_zero(block: &m_StunServer, size: sizeof(NETADDR));
42 // Initialize `m_Stun` with random data.
43 unsigned char aBuf[32];
44 StunMessagePrepare(pBuffer: aBuf, BufferSize: sizeof(aBuf), pData: &m_Stun);
45}
46
47void CStun::CProtocol::FeedStunServer(NETADDR StunServer)
48{
49 if(m_HaveStunServer && net_addr_comp(a: &m_StunServer, b: &StunServer) == 0)
50 {
51 return;
52 }
53 m_HaveStunServer = true;
54 m_StunServer = StunServer;
55 m_NumUnsuccessfulTries = 0;
56 Refresh();
57}
58
59void CStun::CProtocol::Refresh()
60{
61 m_NextTry = time_get();
62}
63
64void CStun::CProtocol::Update()
65{
66 int64_t Now = time_get();
67 if(m_NextTry == -1 || Now < m_NextTry || !m_HaveStunServer)
68 {
69 return;
70 }
71 m_NextTry = Now + RetryWaitSeconds(NumUnsuccessfulTries: m_NumUnsuccessfulTries) * time_freq();
72 m_NumUnsuccessfulTries += 1;
73 unsigned char aBuf[32];
74 int Size = StunMessagePrepare(pBuffer: aBuf, BufferSize: sizeof(aBuf), pData: &m_Stun);
75 if(net_udp_send(sock: m_Socket, addr: &m_StunServer, data: aBuf, size: Size) == -1)
76 {
77 log_debug(IndexToSystem(m_Index), "couldn't send stun request");
78 return;
79 }
80}
81
82bool CStun::CProtocol::OnPacket(NETADDR Addr, unsigned char *pData, int DataSize)
83{
84 if(m_NextTry < 0 || !m_HaveStunServer)
85 {
86 return false;
87 }
88 bool Success;
89 NETADDR StunAddr;
90 if(StunMessageParse(pMessage: pData, MessageSize: DataSize, pData: &m_Stun, pSuccess: &Success, pAddr: &StunAddr))
91 {
92 return false;
93 }
94 m_LastResponse = time_get();
95 if(!Success)
96 {
97 m_HaveAddr = false;
98 log_debug(IndexToSystem(m_Index), "got error response");
99 return true;
100 }
101 m_NextTry = -1;
102 m_NumUnsuccessfulTries = -1;
103 m_HaveAddr = true;
104 m_Addr = StunAddr;
105
106 char aStunAddr[NETADDR_MAXSTRSIZE];
107 net_addr_str(addr: &StunAddr, string: aStunAddr, max_length: sizeof(aStunAddr), add_port: true);
108 log_debug(IndexToSystem(m_Index), "got address: %s", aStunAddr);
109 return true;
110}
111
112CONNECTIVITY CStun::CProtocol::GetConnectivity(NETADDR *pGlobalAddr)
113{
114 if(!m_HaveStunServer)
115 {
116 return CONNECTIVITY::UNKNOWN;
117 }
118 int64_t Now = time_get();
119 int64_t Freq = time_freq();
120 bool HaveTriedALittle = m_NumUnsuccessfulTries >= 5 && (m_LastResponse == -1 || Now - m_LastResponse >= 30 * Freq);
121 if(m_LastResponse == -1 && !HaveTriedALittle)
122 {
123 return CONNECTIVITY::CHECKING;
124 }
125 else if(HaveTriedALittle)
126 {
127 return CONNECTIVITY::UNREACHABLE;
128 }
129 else if(!m_HaveAddr)
130 {
131 return CONNECTIVITY::REACHABLE;
132 }
133 else
134 {
135 *pGlobalAddr = m_Addr;
136 return CONNECTIVITY::ADDRESS_KNOWN;
137 }
138}
139
140CStun::CStun(NETSOCKET Socket) :
141 m_aProtocols{CProtocol(0, Socket), CProtocol(1, Socket)}
142{
143}
144
145void CStun::FeedStunServer(NETADDR StunServer)
146{
147 int Index = IndexFromNetType(NetType: StunServer.type);
148 if(Index < 0)
149 {
150 return;
151 }
152 m_aProtocols[Index].FeedStunServer(StunServer);
153}
154
155void CStun::Refresh()
156{
157 for(auto &Protocol : m_aProtocols)
158 {
159 Protocol.Refresh();
160 }
161}
162
163void CStun::Update()
164{
165 for(auto &Protocol : m_aProtocols)
166 {
167 Protocol.Update();
168 }
169}
170
171bool CStun::OnPacket(NETADDR Addr, unsigned char *pData, int DataSize)
172{
173 int Index = IndexFromNetType(NetType: Addr.type);
174 if(Index < 0)
175 {
176 return false;
177 }
178 return m_aProtocols[Index].OnPacket(Addr, pData, DataSize);
179}
180
181CONNECTIVITY CStun::GetConnectivity(int NetType, NETADDR *pGlobalAddr)
182{
183 int Index = IndexFromNetType(NetType);
184 dbg_assert(Index != -1, "invalid nettype");
185 return m_aProtocols[Index].GetConnectivity(pGlobalAddr);
186}
187