1#include "network.h"
2
3#include <base/log.h>
4#include <base/system.h>
5
6static int IndexFromNetType(int NetType)
7{
8 switch(NetType)
9 {
10 case NETTYPE_IPV6:
11 return 0;
12 case NETTYPE_IPV4:
13 return 1;
14 }
15 return -1;
16}
17
18static const char *IndexToSystem(int Index)
19{
20 switch(Index)
21 {
22 case 0:
23 return "stun/6";
24 case 1:
25 return "stun/4";
26 }
27 dbg_break();
28 return nullptr;
29}
30
31static int RetryWaitSeconds(int NumUnsuccessfulTries)
32{
33 return (1 << clamp(val: NumUnsuccessfulTries, lo: 0, hi: 9));
34}
35
36CStun::CProtocol::CProtocol(int Index, NETSOCKET Socket) :
37 m_Index(Index),
38 m_Socket(Socket)
39{
40 mem_zero(block: &m_StunServer, size: sizeof(NETADDR));
41 // Initialize `m_Stun` with random data.
42 unsigned char aBuf[32];
43 StunMessagePrepare(pBuffer: aBuf, BufferSize: sizeof(aBuf), pData: &m_Stun);
44}
45
46void CStun::CProtocol::FeedStunServer(NETADDR StunServer)
47{
48 if(m_HaveStunServer && net_addr_comp(a: &m_StunServer, b: &StunServer) == 0)
49 {
50 return;
51 }
52 m_HaveStunServer = true;
53 m_StunServer = StunServer;
54 m_NumUnsuccessfulTries = 0;
55 Refresh();
56}
57
58void CStun::CProtocol::Refresh()
59{
60 m_NextTry = time_get();
61}
62
63void CStun::CProtocol::Update()
64{
65 int64_t Now = time_get();
66 if(m_NextTry == -1 || Now < m_NextTry || !m_HaveStunServer)
67 {
68 return;
69 }
70 m_NextTry = Now + RetryWaitSeconds(NumUnsuccessfulTries: m_NumUnsuccessfulTries) * time_freq();
71 m_NumUnsuccessfulTries += 1;
72 unsigned char aBuf[32];
73 int Size = StunMessagePrepare(pBuffer: aBuf, BufferSize: sizeof(aBuf), pData: &m_Stun);
74 if(net_udp_send(sock: m_Socket, addr: &m_StunServer, data: aBuf, size: Size) == -1)
75 {
76 log_debug(IndexToSystem(m_Index), "couldn't send stun request");
77 return;
78 }
79}
80
81bool CStun::CProtocol::OnPacket(NETADDR Addr, unsigned char *pData, int DataSize)
82{
83 if(m_NextTry < 0 || !m_HaveStunServer)
84 {
85 return false;
86 }
87 bool Success;
88 NETADDR StunAddr;
89 if(StunMessageParse(pMessage: pData, MessageSize: DataSize, pData: &m_Stun, pSuccess: &Success, pAddr: &StunAddr))
90 {
91 return false;
92 }
93 m_LastResponse = time_get();
94 if(!Success)
95 {
96 m_HaveAddr = false;
97 log_debug(IndexToSystem(m_Index), "got error response");
98 return true;
99 }
100 m_NextTry = -1;
101 m_NumUnsuccessfulTries = -1;
102 m_HaveAddr = true;
103 m_Addr = StunAddr;
104
105 char aStunAddr[NETADDR_MAXSTRSIZE];
106 net_addr_str(addr: &StunAddr, string: aStunAddr, max_length: sizeof(aStunAddr), add_port: true);
107 log_debug(IndexToSystem(m_Index), "got address: %s", aStunAddr);
108 return true;
109}
110
111CONNECTIVITY CStun::CProtocol::GetConnectivity(NETADDR *pGlobalAddr)
112{
113 if(!m_HaveStunServer)
114 {
115 return CONNECTIVITY::UNKNOWN;
116 }
117 int64_t Now = time_get();
118 int64_t Freq = time_freq();
119 bool HaveTriedALittle = m_NumUnsuccessfulTries >= 5 && (m_LastResponse == -1 || Now - m_LastResponse >= 30 * Freq);
120 if(m_LastResponse == -1 && !HaveTriedALittle)
121 {
122 return CONNECTIVITY::CHECKING;
123 }
124 else if(HaveTriedALittle)
125 {
126 return CONNECTIVITY::UNREACHABLE;
127 }
128 else if(!m_HaveAddr)
129 {
130 return CONNECTIVITY::REACHABLE;
131 }
132 else
133 {
134 *pGlobalAddr = m_Addr;
135 return CONNECTIVITY::ADDRESS_KNOWN;
136 }
137}
138
139CStun::CStun(NETSOCKET Socket) :
140 m_aProtocols{CProtocol(0, Socket), CProtocol(1, Socket)}
141{
142}
143
144void CStun::FeedStunServer(NETADDR StunServer)
145{
146 int Index = IndexFromNetType(NetType: StunServer.type);
147 if(Index < 0)
148 {
149 return;
150 }
151 m_aProtocols[Index].FeedStunServer(StunServer);
152}
153
154void CStun::Refresh()
155{
156 for(auto &Protocol : m_aProtocols)
157 {
158 Protocol.Refresh();
159 }
160}
161
162void CStun::Update()
163{
164 for(auto &Protocol : m_aProtocols)
165 {
166 Protocol.Update();
167 }
168}
169
170bool CStun::OnPacket(NETADDR Addr, unsigned char *pData, int DataSize)
171{
172 int Index = IndexFromNetType(NetType: Addr.type);
173 if(Index < 0)
174 {
175 return false;
176 }
177 return m_aProtocols[Index].OnPacket(Addr, pData, DataSize);
178}
179
180CONNECTIVITY CStun::GetConnectivity(int NetType, NETADDR *pGlobalAddr)
181{
182 int Index = IndexFromNetType(NetType);
183 dbg_assert(Index != -1, "invalid nettype");
184 return m_aProtocols[Index].GetConnectivity(pGlobalAddr);
185}
186