1#include "network.h"
2
3#include <base/dbg.h>
4#include <base/log.h>
5#include <base/mem.h>
6#include <base/net.h>
7#include <base/time.h>
8
9#include <algorithm>
10
11static int IndexFromNetType(int NetType)
12{
13 switch(NetType)
14 {
15 case NETTYPE_IPV6:
16 return 0;
17 case NETTYPE_IPV4:
18 return 1;
19 }
20 return -1;
21}
22
23static const char *IndexToSystem(int Index)
24{
25 switch(Index)
26 {
27 case 0:
28 return "stun/6";
29 case 1:
30 return "stun/4";
31 }
32 dbg_assert_failed("invalid index %d", Index);
33}
34
// Backoff delay, in seconds, before the next STUN request: a power of two
// whose exponent is the try count clamped to [0, 9], i.e. 1..512 seconds.
// Negative counts (used after a successful response) map to 1 second.
static int RetryWaitSeconds(int NumUnsuccessfulTries)
{
	const int Exponent = std::max(0, std::min(NumUnsuccessfulTries, 9));
	return 1 << Exponent;
}
39
40CStun::CProtocol::CProtocol(int Index, NETSOCKET Socket) :
41 m_Index(Index),
42 m_Socket(Socket)
43{
44 mem_zero(block: &m_StunServer, size: sizeof(NETADDR));
45 // Initialize `m_Stun` with random data.
46 unsigned char aBuf[32];
47 StunMessagePrepare(pBuffer: aBuf, BufferSize: sizeof(aBuf), pData: &m_Stun);
48}
49
50void CStun::CProtocol::FeedStunServer(NETADDR StunServer)
51{
52 if(m_HaveStunServer && net_addr_comp(a: &m_StunServer, b: &StunServer) == 0)
53 {
54 return;
55 }
56 m_HaveStunServer = true;
57 m_StunServer = StunServer;
58 m_NumUnsuccessfulTries = 0;
59 Refresh();
60}
61
62void CStun::CProtocol::Refresh()
63{
64 m_NextTry = time_get();
65}
66
67void CStun::CProtocol::Update()
68{
69 int64_t Now = time_get();
70 if(m_NextTry == -1 || Now < m_NextTry || !m_HaveStunServer)
71 {
72 return;
73 }
74 m_NextTry = Now + RetryWaitSeconds(NumUnsuccessfulTries: m_NumUnsuccessfulTries) * time_freq();
75 m_NumUnsuccessfulTries += 1;
76 unsigned char aBuf[32];
77 int Size = StunMessagePrepare(pBuffer: aBuf, BufferSize: sizeof(aBuf), pData: &m_Stun);
78 if(net_udp_send(sock: m_Socket, addr: &m_StunServer, data: aBuf, size: Size) == -1)
79 {
80 log_debug(IndexToSystem(m_Index), "couldn't send stun request");
81 return;
82 }
83}
84
85bool CStun::CProtocol::OnPacket(NETADDR Addr, unsigned char *pData, int DataSize)
86{
87 if(m_NextTry < 0 || !m_HaveStunServer)
88 {
89 return false;
90 }
91 bool Success;
92 NETADDR StunAddr;
93 if(StunMessageParse(pMessage: pData, MessageSize: DataSize, pData: &m_Stun, pSuccess: &Success, pAddr: &StunAddr))
94 {
95 return false;
96 }
97 m_LastResponse = time_get();
98 if(!Success)
99 {
100 m_HaveAddr = false;
101 log_debug(IndexToSystem(m_Index), "got error response");
102 return true;
103 }
104 m_NextTry = -1;
105 m_NumUnsuccessfulTries = -1;
106 m_HaveAddr = true;
107 m_Addr = StunAddr;
108
109 char aStunAddr[NETADDR_MAXSTRSIZE];
110 net_addr_str(addr: &StunAddr, string: aStunAddr, max_length: sizeof(aStunAddr), add_port: true);
111 log_debug(IndexToSystem(m_Index), "got address: %s", aStunAddr);
112 return true;
113}
114
115CONNECTIVITY CStun::CProtocol::GetConnectivity(NETADDR *pGlobalAddr)
116{
117 if(!m_HaveStunServer)
118 {
119 return CONNECTIVITY::UNKNOWN;
120 }
121 int64_t Now = time_get();
122 int64_t Freq = time_freq();
123 bool HaveTriedALittle = m_NumUnsuccessfulTries >= 5 && (m_LastResponse == -1 || Now - m_LastResponse >= 30 * Freq);
124 if(m_LastResponse == -1 && !HaveTriedALittle)
125 {
126 return CONNECTIVITY::CHECKING;
127 }
128 else if(HaveTriedALittle)
129 {
130 return CONNECTIVITY::UNREACHABLE;
131 }
132 else if(!m_HaveAddr)
133 {
134 return CONNECTIVITY::REACHABLE;
135 }
136 else
137 {
138 *pGlobalAddr = m_Addr;
139 return CONNECTIVITY::ADDRESS_KNOWN;
140 }
141}
142
143CStun::CStun(NETSOCKET Socket) :
144 m_aProtocols{CProtocol(0, Socket), CProtocol(1, Socket)}
145{
146}
147
148void CStun::FeedStunServer(NETADDR StunServer)
149{
150 int Index = IndexFromNetType(NetType: StunServer.type);
151 if(Index < 0)
152 {
153 return;
154 }
155 m_aProtocols[Index].FeedStunServer(StunServer);
156}
157
158void CStun::Refresh()
159{
160 for(auto &Protocol : m_aProtocols)
161 {
162 Protocol.Refresh();
163 }
164}
165
166void CStun::Update()
167{
168 for(auto &Protocol : m_aProtocols)
169 {
170 Protocol.Update();
171 }
172}
173
174bool CStun::OnPacket(NETADDR Addr, unsigned char *pData, int DataSize)
175{
176 int Index = IndexFromNetType(NetType: Addr.type);
177 if(Index < 0)
178 {
179 return false;
180 }
181 return m_aProtocols[Index].OnPacket(Addr, pData, DataSize);
182}
183
184CONNECTIVITY CStun::GetConnectivity(int NetType, NETADDR *pGlobalAddr)
185{
186 int Index = IndexFromNetType(NetType);
187 dbg_assert(Index != -1, "invalid nettype");
188 return m_aProtocols[Index].GetConnectivity(pGlobalAddr);
189}
190