0ad/source/network/StunClient.cpp
josue 05afbf1805 Deduplicate the hole punching pacing between client and server
The synchronous client-side SendHolePunchingMessages and the server
worker each implemented the fw_punch config reading, the target
resolution and the message pacing. Move them to a
StunClient::HolePuncher class which sends due messages without
blocking, and express the blocking client-side variant as a loop over
it. The single-message StunClient::SendHolePunchingMessage is no
longer needed.

As a side effect, the client no longer sleeps for fw_punch.delay after
the last message, and skips sending if the server address cannot be
resolved instead of punching an unresolvable address.

Requested by Phosit in the review of #8977.
2026-06-12 18:56:32 +02:00

450 lines
12 KiB
C++

/* Copyright (C) 2026 Wildfire Games.
* Copyright (C) 2013-2016 SuperTuxKart-Team.
* This file is part of 0 A.D.
*
* 0 A.D. is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* 0 A.D. is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with 0 A.D. If not, see <http://www.gnu.org/licenses/>.
*/
#include "precompiled.h"
#include "StunClient.h"
#include "lib/byte_order.h"
#include "lib/code_annotation.h"
#include "lib/external_libraries/enet.h"
#include "ps/CLogger.h"
#include "ps/CStr.h"
#include "ps/ConfigDB.h"
#include <algorithm>
#include <bit>
#include <cerrno>
#include <chrono>
#include <concepts>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <optional>
#include <thread>
#include <type_traits>
#include <vector>
namespace StunClient
{
// STUN protocol constants, all taken from RFC 5389.

/// Fixed value of the magic cookie header field (Section 6).
constexpr u32 m_MagicCookie = 0x2112A442;
/// Binding method, as encoded in the message type field (Section 6).
constexpr u16 m_MethodTypeBinding = 0x0001;
/// Message type of a Binding success response: class 0b10, method Binding (Section 6).
constexpr u32 m_BindingSuccessResponse = 0x0101;

/// Attribute types with this bit set are comprehension-optional (Section 15).
constexpr u16 m_ComprehensionOptional = 0x1 << 15;
/// Attribute-type bit separating the IETF Review ranges from the Designated Expert
/// ranges of the attribute registry (Section 18.2).
constexpr u16 m_IETFReview = 0x1 << 14;

/// Address family value for IPv4 in (XOR-)MAPPED-ADDRESS attributes (Section 15.1).
constexpr u8 m_IPAddressFamilyIPv4 = 0x01;

/// Attribute type codes (Section 18.2).
constexpr u16 m_AttrTypeMappedAddress = 0x0001;
constexpr u16 m_AttrTypeXORMappedAddress = 0x0020;

// Mutable state shared by the functions of this namespace.

/// Transaction ID (Section 3) of the most recent request written by SendStunRequest.
u8 m_TransactionID[12];
/// STUN server the current transaction was sent to; replies are expected from it.
ENetAddress m_StunServer;
/// Public IP + port discovered via the STUN transaction.
ENetAddress m_PublicAddress;
/**
 * Append the n least significant bytes of value to buffer in network byte order
 * (most significant byte first).
 * @tparam T integral type of the value to serialize.
 * @tparam n number of bytes to write; defaults to the full width of T and must not exceed it.
 * @param buffer destination; bytes are appended at its end.
 * @param value value to serialize.
 */
template<std::integral T, size_t n = sizeof(T)>
void AddToBuffer(std::vector<u8>& buffer, const T value)
{
	static_assert(n <= sizeof(T), "AddToBuffer cannot write more bytes than T holds");
	// No reserve() here: reserving exactly size() + n on every call can force a
	// reallocation per append, defeating std::vector's geometric growth.
	// std::byte* can alias anything so this is legal.
	const std::byte* ptr = reinterpret_cast<const std::byte*>(&value);
	for (size_t a = 0; a < n; ++a)
	{
		// Emit the n low-order bytes from the most significant one downwards. On
		// big-endian hosts the low-order bytes are the last n bytes of the object.
		if constexpr (std::endian::native == std::endian::little)
			buffer.push_back(static_cast<u8>(*(ptr + n - 1 - a)));
		else
			buffer.push_back(static_cast<u8>(*(ptr + sizeof(T) - n + a)));
	}
}
/**
 * Read n network-byte-order bytes from buffer at offset into the low-order bytes of result.
 * @tparam T integral type to decode into.
 * @tparam n number of bytes to read; defaults to the full width of T and must not exceed it.
 * @param buffer source bytes.
 * @param offset read position; advanced by n on success.
 * @param result receives the decoded value. When n < sizeof(T) its remaining bytes are not touched.
 * @return false, without modifying offset or result, if fewer than n bytes remain.
 */
template<std::integral T, size_t n = sizeof(T)>
bool GetFromBuffer(const std::vector<u8>& buffer, u32& offset, T& result)
{
	static_assert(n <= sizeof(T), "GetFromBuffer cannot read more bytes than T holds");
	// Written so that offset + n cannot overflow.
	if (n > buffer.size() || offset > buffer.size() - n)
		return false;
	// std::byte* can alias anything so this is legal.
	std::byte* ptr = reinterpret_cast<std::byte*>(&result);
	for (size_t a = 0; a < n; ++a)
	{
		// Fill the n low-order bytes of result. On big-endian hosts those are the
		// last n bytes of the object.
		if constexpr (std::endian::native == std::endian::little)
			ptr[a] = static_cast<std::byte>(buffer[offset + n - 1 - a]);
		else
			ptr[sizeof(T) - n + a] = static_cast<std::byte>(buffer[offset + a]);
	}
	offset += n;
	return true;
}
/**
 * Sends an attribute-less STUN Binding request to addr through transactionHost's socket.
 * A fresh transaction ID is drawn from rand() and stored in m_TransactionID, so that
 * ParseStunResponse can match the answer. The result of the send is not checked.
 */
void SendStunRequest(ENetHost& transactionHost, ENetAddress addr)
{
	// Pick the new transaction ID up front so the header can copy it in one go.
	for (u8& idByte : m_TransactionID)
		idByte = rand() % 256;

	std::vector<u8> request;
	AddToBuffer<u16>(request, m_MethodTypeBinding);
	AddToBuffer<u16>(request, 0); // message length: no attributes follow the header
	AddToBuffer<u32>(request, m_MagicCookie);
	request.insert(request.end(), m_TransactionID, m_TransactionID + sizeof(m_TransactionID));

	ENetBuffer payload;
	payload.data = request.data();
	payload.dataLength = request.size();
	enet_socket_send(transactionHost.socket, &addr, &payload, 1);
}
/**
 * Creates a STUN request and sends it to the STUN server configured by
 * lobby.stun.server / lobby.stun.port.
 * The request is sent through transactionHost, from which the answer
 * will be retrieved by ReceiveStunResponse and interpreted by ParseStunResponse.
 * @return false (after logging why) if the configured port is out of range or the
 *         server name cannot be resolved; true once the request has been handed to the socket.
 */
bool CreateStunRequest(ENetHost& transactionHost)
{
	const std::string server_name{g_ConfigDB.Get("lobby.stun.server", std::string{})};
	const int port{g_ConfigDB.Get("lobby.stun.port", 0)};
	LOGMESSAGE("StunClient: Using STUN server %s:%d", server_name.c_str(), port);
	// The port is stored in 16 bits: reject values that would silently wrap, and
	// port 0, which is the "unset" default and cannot be sent to.
	if (port <= 0 || port > 0xFFFF)
	{
		LOGERROR("StunClient: invalid STUN server port %d", port);
		return false;
	}
	ENetAddress addr;
	addr.port = static_cast<u16>(port);
	if (enet_address_set_host(&addr, server_name.c_str()) != 0)
	{
		LOGERROR("StunClient: failed to resolve STUN server %s", server_name.c_str());
		return false;
	}
	m_StunServer = addr;
	StunClient::SendStunRequest(transactionHost, addr);
	return true;
}
/**
 * Gets the response from the STUN server and checks where it came from.
 * enet sockets are non-blocking, so after a first attempt the socket is polled up to
 * lobby.stun.max_tries more times (indefinitely when -1), sleeping lobby.stun.delay
 * milliseconds between attempts.
 * @param transactionHost host whose socket the request was sent through.
 * @param buffer receives the raw datagram on success; it is left unchanged on failure.
 * @return true if a datagram was received.
 */
bool ReceiveStunResponse(ENetHost& transactionHost, std::vector<u8>& buffer)
{
	constexpr size_t maxResponseSize = 2048;
	std::vector<u8> received(maxResponseSize);
	ENetBuffer enetBuffer;
	enetBuffer.data = received.data();
	enetBuffer.dataLength = received.size();
	ENetAddress sender = m_StunServer;
	int len = enet_socket_receive(transactionHost.socket, &sender, &enetBuffer, 1);
	const int delay{g_ConfigDB.Get("lobby.stun.delay", 10)};
	const int maxTries{g_ConfigDB.Get("lobby.stun.max_tries", 100)};
	// Wait to receive the message because enet sockets are non-blocking
	for (int count = 0; len <= 0 && (count < maxTries || maxTries == -1); ++count)
	{
		std::this_thread::sleep_for(std::chrono::milliseconds(delay));
		len = enet_socket_receive(transactionHost.socket, &sender, &enetBuffer, 1);
	}
	// A negative length is a socket error. Zero means nothing arrived before the tries
	// ran out, in which case errno would not describe the failure.
	if (len < 0)
	{
		LOGERROR("ReceiveStunResponse: recvfrom error (%d): %s", errno, strerror(errno));
		return false;
	}
	if (len == 0)
	{
		LOGERROR("ReceiveStunResponse: no response received from the STUN server");
		return false;
	}
	// Compare fields rather than memcmp'ing the struct, whose padding bytes are unspecified.
	if (sender.host != m_StunServer.host || sender.port != m_StunServer.port)
	{
		// host is in network byte order; let enet format it instead of shifting bytes out.
		char senderIp[64] = "(error)";
		enet_address_get_host_ip(&sender, senderIp, ARRAY_SIZE(senderIp));
		LOGERROR("ReceiveStunResponse: Received stun response from different address: %s:%d",
			senderIp, sender.port);
	}
	received.resize(static_cast<size_t>(len));
	buffer.swap(received);
	return true;
}
bool ParseStunResponse(const std::vector<u8>& buffer)
{
u32 offset = 0;
u16 responseType = 0;
if (!GetFromBuffer(buffer, offset, responseType) || responseType != m_BindingSuccessResponse)
{
LOGERROR("STUN response isn't a binding success response");
return false;
}
// Ignore message size
offset += 2;
u32 cookie = 0;
if (!GetFromBuffer(buffer, offset, cookie) || cookie != m_MagicCookie)
{
LOGERROR("STUN response doesn't contain the magic cookie");
return false;
}
for (std::size_t i = 0; i < sizeof(m_TransactionID); ++i)
{
u8 transactionChar = 0;
if (!GetFromBuffer(buffer, offset, transactionChar) || transactionChar != m_TransactionID[i])
{
LOGERROR("STUN response doesn't contain the transaction ID");
return false;
}
}
while (offset < buffer.size())
{
u16 type = 0;
u16 size = 0;
if (!GetFromBuffer(buffer, offset, type) ||
!GetFromBuffer(buffer, offset, size))
{
LOGERROR("STUN response contains invalid attribute");
return false;
}
// The first two bits are irrelevant to the type
type &= ~(m_ComprehensionOptional | m_IETFReview);
switch (type)
{
case m_AttrTypeMappedAddress:
case m_AttrTypeXORMappedAddress:
{
if (size != 8)
{
LOGERROR("Invalid STUN Mapped Address length");
return false;
}
// Ignore the first byte as mentioned in Section 15.1 of RFC 5389.
++offset;
u8 ipFamily = 0;
if (!GetFromBuffer(buffer, offset, ipFamily) || ipFamily != m_IPAddressFamilyIPv4)
{
LOGERROR("Unsupported address family, IPv4 is expected");
return false;
}
u16 port = 0;
u32 ip = 0;
if (!GetFromBuffer(buffer, offset, port) ||
!GetFromBuffer(buffer, offset, ip))
{
LOGERROR("Mapped address doesn't contain IP and port");
return false;
}
// Obfuscation is described in Section 15.2 of RFC 5389.
if (type == m_AttrTypeXORMappedAddress)
{
port ^= m_MagicCookie >> 16;
ip ^= m_MagicCookie;
}
// ENetAddress takes a host byte-order port and network byte-order IP.
// Network byte order is big endian, so convert appropriately.
m_PublicAddress.host = to_be32(ip);
m_PublicAddress.port = port;
break;
}
default:
{
// We don't care about other attributes at all
// Skip attribute
offset += size;
// Skip padding
int padding = size % 4;
if (padding)
offset += 4 - padding;
break;
}
}
}
return true;
}
/**
 * Runs one complete STUN transaction over transactionHost: send the request, wait for
 * the reply and parse it into m_PublicAddress. Stops at the first step that fails.
 */
bool STUNRequestAndResponse(ENetHost& transactionHost)
{
	if (!CreateStunRequest(transactionHost))
		return false;

	std::vector<u8> response;
	if (!ReceiveStunResponse(transactionHost, response))
		return false;

	return ParseStunResponse(response);
}
/**
 * Discovers this machine's public IP and port through a STUN transaction on transactionHost.
 * @param ip receives the public IP as text on success.
 * @param port receives the public port on success.
 * @return false if the STUN transaction failed; ip and port are then not written.
 */
bool FindPublicIP(ENetHost& transactionHost, CStr& ip, u16& port)
{
	if (!STUNRequestAndResponse(transactionHost))
		return false;

	// Render m_PublicAddress's IP as text, starting from an "(error)" placeholder.
	char hostIp[256] = "(error)";
	enet_address_get_host_ip(&m_PublicAddress, hostIp, ARRAY_SIZE(hostIp));
	port = m_PublicAddress.port;
	ip = hostIp;

	LOGMESSAGE("StunClient: external IP address is %s:%i", ip.c_str(), port);
	return true;
}
/**
 * Blocking hole punching towards serverAddress:serverPort through enetClient.
 * Drives a single-target HolePuncher, sleeping until each next message is due, and
 * returns once the puncher has no target left. It returns immediately if AddTarget
 * refuses the target (zero configured messages, or unresolvable address).
 */
void SendHolePunchingMessages(ENetHost& enetClient, const std::string& serverAddress, u16 serverPort)
{
	HolePuncher puncher;
	if (puncher.AddTarget(serverAddress, serverPort))
	{
		for (std::optional<std::chrono::steady_clock::time_point> wakeUp = puncher.Tick(enetClient);
			wakeUp.has_value();
			wakeUp = puncher.Tick(enetClient))
		{
			std::this_thread::sleep_until(*wakeUp);
		}
	}
}
bool HolePuncher::AddTarget(const std::string& address, u16 port)
{
const int numMessages{g_ConfigDB.Get("lobby.fw_punch.num_msg", 3)};
if (numMessages == 0)
return false;
ENetAddress addr;
addr.port = port;
if (enet_address_set_host(&addr, address.c_str()) != 0)
{
LOGWARNING("StunClient: failed to resolve hole punching target %s", address.c_str());
return false;
}
m_Targets.push_back({addr, numMessages, std::chrono::steady_clock::now()});
return true;
}
/**
 * Stops punching towards address: drops every target whose host and port both match.
 */
void HolePuncher::RemoveTarget(const ENetAddress& address)
{
	const auto isSameAddress = [&address](const Target& target)
	{
		return target.address.port == address.port && target.address.host == address.host;
	};
	std::erase_if(m_Targets, isSameAddress);
}
/**
 * Sends one STUN request through enetClient to every target whose next message is due.
 * Each of those targets is then rescheduled lobby.fw_punch.delay milliseconds later.
 * Targets whose message count reaches zero are dropped. Negative counts are never
 * decremented, so those targets persist until RemoveTarget.
 * @return the earliest time a remaining target is due, or nullopt when none remain.
 */
std::optional<std::chrono::steady_clock::time_point> HolePuncher::Tick(ENetHost& enetClient)
{
	using Clock = std::chrono::steady_clock;
	const Clock::time_point now = Clock::now();
	const std::chrono::milliseconds delay{g_ConfigDB.Get("lobby.fw_punch.delay", 200)};

	for (Target& target : m_Targets)
	{
		const bool due = target.nextSendTime <= now;
		if (!due)
			continue;
		SendStunRequest(enetClient, target.address);
		if (target.remainingMessages > 0)
			--target.remainingMessages;
		target.nextSendTime = now + delay;
	}

	std::erase_if(m_Targets, [](const Target& target) { return target.remainingMessages == 0; });

	std::optional<Clock::time_point> earliest;
	for (const Target& target : m_Targets)
	{
		if (!earliest || target.nextSendTime < *earliest)
			earliest = target.nextSendTime;
	}
	return earliest;
}
/**
 * Finds the IPv4 address of the local interface the OS would use to reach a
 * non-loopback address.
 * @param ip receives the address as text on success.
 * @return false if any socket or address operation fails; ip is then not written.
 */
bool FindLocalIP(CStr& ip)
{
	// Open an UDP socket.
	const ENetSocket socket = enet_socket_create(ENET_SOCKET_TYPE_DATAGRAM);
	if (socket == ENET_SOCKET_NULL)
		return false;

	ENetAddress addr;
	// Port 9 (discard); which port we pick does not matter since nothing is sent.
	addr.port = 9;
	// Target an arbitrary address. It does not need to be valid, only to not be the
	// loopback address. Connecting a UDP socket sends no traffic; it just binds the
	// socket to a valid local port, which lets us query the local IP of the machine.
	const bool gotLocalAddress =
		enet_address_set_host(&addr, "100.0.100.0") != -1 &&
		enet_socket_connect(socket, &addr) != -1 &&
		enet_socket_get_address(socket, &addr) != -1;

	// Release the socket on every path, including the failure ones.
	enet_socket_destroy(socket);
	if (!gotLocalAddress)
		return false;

	// Convert to a human readable string.
	char buf[50];
	if (enet_address_get_host_ip(&addr, buf, ARRAY_SIZE(buf)) == -1)
		return false;
	ip = buf;
	return true;
}
}