0ad/source/network/NetClientSession.cpp
2026-05-21 18:42:04 +02:00

564 lines
15 KiB
C++

/* Copyright (C) 2026 Wildfire Games.
* This file is part of 0 A.D.
*
* 0 A.D. is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* 0 A.D. is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with 0 A.D. If not, see <http://www.gnu.org/licenses/>.
*/
#include "precompiled.h"
#include "NetClientSession.h"
#include "lib/code_generation.h"
#include "lib/debug.h"
#include "network/NetClient.h"
#include "network/NetMessage.h"
#include "network/NetProtocol.h"
#include "network/NetStats.h"
#include "ps/CLogger.h"
#include "ps/ProfileViewer.h"
#include <chrono>
#include <cstddef>
#include <cstring>
#include <thread>
#include <time.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
#include <arpa/inet.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
#include <poll.h>
#include <unistd.h>
#include <ngtcp2/ngtcp2.h>
#include <ngtcp2/ngtcp2_crypto.h>
#include <ngtcp2/ngtcp2_crypto_gnutls.h>
#include <gnutls/crypto.h>
#include <gnutls/gnutls.h>

constexpr int NETCLIENT_POLL_TIMEOUT = 50;
/**
 * Per-connection QUIC/TLS state of a CNetClientSession (held through m_Quic).
 * ClientInit() resets it to a default-constructed value and then fills it in.
 */
struct CNetClientSession::Quic
{
	// Local endpoint of the connected UDP socket, as reported by getsockname().
	AddressStorage localAddress;
	// GnuTLS certificate credentials attached to `session`.
	std::unique_ptr<gnutls_certificate_credentials_st, CredentialsDeleter> credentials;
	// GnuTLS client session, handed to ngtcp2 as its TLS native handle.
	std::unique_ptr<gnutls_session_int, SessionDeleter> session;
	// The ngtcp2 client connection; null until ClientQuicInit() succeeded.
	std::unique_ptr<ngtcp2_conn, ConnectionDeleter> quicConnection;
	// Installed as the GnuTLS session pointer so the ngtcp2 crypto helpers can
	// get back to quicConnection (see GetConnection).
	ngtcp2_crypto_conn_ref connectionReference;
	// UDP socket descriptor. -1 means "no socket", so a default-constructed
	// Quic never refers to descriptor 0 (stdin) by accident.
	int fd{-1};
	// The stream reported through the stream_open callback (OpenStream), if any.
	std::optional<Stream> streams;
};
namespace
{
// Result of CreateSocket(): a freshly created (not yet connected) UDP socket
// together with the remote address it should be connected to.
struct CreateSocketResult
{
	// Socket file descriptor returned by socket().
	int descriptor;
	// Remote address chosen from the getaddrinfo() results.
	AddressStorage address;
};
/**
 * Resolves host:port for UDP and creates a socket for the first result the
 * local system can open a socket for.
 *
 * The complete socket address (ai_addrlen bytes) is copied, so IPv6
 * sockaddr_in6 addresses are preserved. Results that do not fit into
 * AddressStorage are skipped.
 *
 * @return The socket descriptor and the matching remote address.
 * @throws std::runtime_error if name resolution fails or no socket can be created.
 */
CreateSocketResult CreateSocket(const char* host, const std::uint16_t port)
{
	addrinfo hints{};
	// Accept both IPv4 and IPv6 results.
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_DGRAM;
	addrinfo* res{nullptr};
	const int rv{getaddrinfo(host, fmt::format("{}", port).c_str(), &hints, &res)};
	if (rv)
		throw std::runtime_error{fmt::format("getaddrinfo: {}", gai_strerror(rv))};
	const std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> infoList{res, &freeaddrinfo};

	for (const addrinfo* rp{res}; rp; rp = rp->ai_next)
	{
		AddressStorage remote{};
		if (rp->ai_addrlen > sizeof(remote.address))
			continue;

		const int fd{socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)};
		if (fd == -1)
			continue;

		// Copy the whole address. Initialising from *ai_addr would only copy a
		// `sockaddr` (16 bytes) and truncate a sockaddr_in6.
		std::memcpy(&remote.address, rp->ai_addr, rp->ai_addrlen);
		remote.length = rp->ai_addrlen;
		return CreateSocketResult{
			.descriptor{fd},
			.address{remote}
		};
	}
	throw std::runtime_error{"unable to create a socket"};
}
/**
 * Connects the UDP socket @p fd to @p remoteAddress and reports the local
 * address the kernel bound it to.
 *
 * Connecting lets later sendmsg()/recvmsg() calls omit the peer address.
 * The returned local address is needed to build ngtcp2 paths.
 *
 * @throws std::runtime_error if connect() or getsockname() fails.
 */
AddressStorage ConnectSocket(const int fd, const AddressStorage& remoteAddress)
{
	const int connectResult{connect(fd, &remoteAddress.address.sa, remoteAddress.length)};
	if (connectResult != 0)
		throw std::runtime_error{fmt::format("connect: {}", strerror(errno))};

	ngtcp2_sockaddr_union bound;
	ngtcp2_socklen boundLength{sizeof(bound)};
	const int nameResult{getsockname(fd, &bound.sa, &boundLength)};
	if (nameResult == -1)
		throw std::runtime_error{fmt::format("getsockname: {}", strerror(errno))};

	return {bound, boundLength};
}
/**
 * Prepares the GnuTLS side of the connection stored in @p c. It does the following:
 * - allocates certificate credentials;
 * - creates a client session and lets ngtcp2_crypto configure it for QUIC;
 * - applies TLS_PRIORITY;
 * - installs &c->connectionReference as the session pointer;
 * - attaches the credentials.
 *
 * Each object is handed to @p c as soon as it exists, so anything created
 * before a failure is still released by its owning smart pointer.
 *
 * @throws std::runtime_error if any GnuTLS / ngtcp2_crypto call fails.
 */
void ClientGnutlsInit(CNetClientSession::Quic* c)
{
	gnutls_certificate_credentials_t newCredentials;
	if (const int error{gnutls_certificate_allocate_credentials(&newCredentials)}; error != 0)
		throw std::runtime_error{fmt::format("cred init failed: {}: {}", error, gnutls_strerror(error))};
	c->credentials.reset(newCredentials);

	gnutls_session_t newSession;
	if (const int error{gnutls_init(&newSession,
		GNUTLS_CLIENT | GNUTLS_ENABLE_EARLY_DATA | GNUTLS_NO_END_OF_EARLY_DATA)}; error != 0)
	{
		throw std::runtime_error{fmt::format("gnutls_init: {}", gnutls_strerror(error))};
	}
	c->session.reset(newSession);

	const gnutls_session_t tls{c->session.get()};
	if (ngtcp2_crypto_gnutls_configure_client_session(tls) != 0)
		throw std::runtime_error{"ngtcp2_crypto_gnutls_configure_client_session failed"};

	if (const int error{gnutls_priority_set_direct(tls, TLS_PRIORITY, nullptr)}; error != 0)
		throw std::runtime_error{fmt::format("gnutls_priority_set_direct: {}", gnutls_strerror(error))};

	// The ngtcp2 crypto callbacks recover the QUIC connection through this pointer (see GetConnection).
	gnutls_session_set_ptr(tls, &c->connectionReference);

	if (const int error{gnutls_credentials_set(tls, GNUTLS_CRD_CERTIFICATE, c->credentials.get())}; error != 0)
		throw std::runtime_error{fmt::format("gnutls_credentials_set: {}", gnutls_strerror(error))};
}
/**
 * ngtcp2 stream_open callback. Records the new stream and marks the session as
 * connected, then queues ConnectionEstablished for the main thread.
 *
 * @param userData The CNetClientSession registered as user_data in ClientQuicInit().
 * @return Always 0.
 */
int OpenStream(ngtcp2_conn*, const std::int64_t streamId, void* userData)
{
	auto* const owner = static_cast<CNetClientSession*>(userData);
	owner->m_Connected = true;
	owner->m_WasConnected = true;
	owner->m_Quic->streams.emplace(streamId);
	owner->m_IncomingMessages.push(CNetClientSession::ConnectionEstablished{});
	return 0;
}
/**
 * ngtcp2 recv_stream_data callback. Feeds the received bytes into the
 * session's stream. Every completed message is queued for the main thread,
 * and the flow-control window is extended by the amount consumed.
 *
 * This is called from inside ngtcp2 (C code), so no exception may escape.
 * Failures are reported to ngtcp2 with NGTCP2_ERR_CALLBACK_FAILURE, which
 * makes ngtcp2_conn_read_pkt() fail for the caller.
 *
 * @return 0 on success, NGTCP2_ERR_CALLBACK_FAILURE if there is no matching
 *         stream or processing the data failed.
 */
int OnStreamDataReceive(ngtcp2_conn* conn, const std::uint32_t /*flags*/, const std::int64_t streamId,
	const std::size_t /*offset*/, const std::uint8_t* data, std::size_t dataSize, void* userData, void*)
{
	auto& session = *static_cast<CNetClientSession*>(userData);
	auto& stream = session.m_Quic->streams;
	if (!stream.has_value() || stream->m_Id != streamId)
	{
		LOGERROR("Net client: received data on unexpected stream %lld", static_cast<long long>(streamId));
		return NGTCP2_ERR_CALLBACK_FAILURE;
	}

	try
	{
		auto message = stream->Receive({data, dataSize});
		if (message.has_value())
		{
			// The queue carries ownership of the heap buffer. It is released only after
			// push() returned, so it is not leaked if push() throws.
			auto payload = std::make_unique<std::vector<std::uint8_t>>(std::move(message).value());
			session.m_IncomingMessages.push(payload.get());
			static_cast<void>(payload.release());
		}
	}
	catch (const std::exception& e)
	{
		LOGERROR("Net client: failed to process stream data: %s", e.what());
		return NGTCP2_ERR_CALLBACK_FAILURE;
	}

	increaseWindow(conn, streamId, dataSize);
	return 0;
}
/**
 * Creates the ngtcp2 client connection for @p session and stores it in
 * session.m_Quic->quicConnection.
 *
 * Expects session.m_Quic->session (the GnuTLS session) to be populated already:
 * it is attached as the connection's TLS native handle at the end.
 *
 * @param session Owning session. It is passed to ngtcp2 as user_data, so the
 *                recv_stream_data / stream_open callbacks receive it back.
 * @param remote  Address of the server.
 * @param local   Local address of the connected UDP socket.
 * @throws std::runtime_error if random connection IDs cannot be generated or
 *         ngtcp2_conn_client_new() fails.
 */
void ClientQuicInit(CNetClientSession& session, const AddressStorage& remote, const AddressStorage& local)
{
	// ngtcp2_addr::addr is a non-const pointer, hence the const_casts. This
	// function itself does not modify the addresses.
	const ngtcp2_path path{
		.local{
			.addr{const_cast<ngtcp2_sockaddr*>(&local.address.sa)},
			.addrlen{local.length},
		},
		.remote{
			.addr{const_cast<ngtcp2_sockaddr*>(&remote.address.sa)},
			.addrlen{remote.length},
		}
	};
	// Mostly the stock ngtcp2_crypto helpers. Stream events go to this file's
	// OnStreamDataReceive / OpenStream. OnRandomRequest and
	// OnNewConnectionIdRequest are defined outside this file.
	constexpr ngtcp2_callbacks callbacks{
		.client_initial{&ngtcp2_crypto_client_initial_cb},
		.recv_crypto_data{&ngtcp2_crypto_recv_crypto_data_cb},
		.encrypt{&ngtcp2_crypto_encrypt_cb},
		.decrypt{&ngtcp2_crypto_decrypt_cb},
		.hp_mask{&ngtcp2_crypto_hp_mask_cb},
		.recv_stream_data{&OnStreamDataReceive},
		.stream_open{&OpenStream},
		.recv_retry{&ngtcp2_crypto_recv_retry_cb},
		.rand{&OnRandomRequest},
		.get_new_connection_id{&OnNewConnectionIdRequest},
		.update_key{&ngtcp2_crypto_update_key_cb},
		.delete_crypto_aead_ctx{&ngtcp2_crypto_delete_crypto_aead_ctx_cb},
		.delete_crypto_cipher_ctx{&ngtcp2_crypto_delete_crypto_cipher_ctx_cb},
		.get_path_challenge_data{&ngtcp2_crypto_get_path_challenge_data_cb},
		.version_negotiation{&ngtcp2_crypto_version_negotiation_cb}
	};
	// Random initial destination connection ID of the minimum allowed length.
	ngtcp2_cid dcid;
	dcid.datalen = NGTCP2_MIN_INITIAL_DCIDLEN;
	if (gnutls_rnd(GNUTLS_RND_RANDOM, dcid.data, dcid.datalen))
		throw std::runtime_error{"gnutls_rnd failed"};
	// Random 8-byte source connection ID for this endpoint.
	ngtcp2_cid scid;
	scid.datalen = 8;
	if (gnutls_rnd(GNUTLS_RND_RANDOM, scid.data, scid.datalen))
		throw std::runtime_error{"gnutls_rnd failed"};
	ngtcp2_settings settings;
	ngtcp2_settings_default(&settings);
	settings.initial_ts = timestamp();
	ngtcp2_transport_params params;
	ngtcp2_transport_params_default(&params);
	// Limits advertised to the server:
	// - 128 KiB per bidirectional stream the server opens;
	// - 1 MiB for the whole connection;
	// - the server may open at most one bidirectional stream.
	params.initial_max_stream_data_bidi_remote = 128 * KiB;
	params.initial_max_data = 1 * MiB;
	params.initial_max_streams_bidi = 1;
	params.max_udp_payload_size = MAX_UDP_PAYLOAD_SIZE;
	// Allow the peer to grease the QUIC bit (RFC 9287).
	params.grease_quic_bit = 1;
	ngtcp2_conn* tempConn;
	if (const int rv{ngtcp2_conn_client_new(&tempConn, &dcid, &scid, &path,
		NGTCP2_PROTO_VER_V1, &callbacks, &settings, &params, nullptr, &session)})
	{
		throw std::runtime_error{fmt::format("ngtcp2_conn_client_new: {}", ngtcp2_strerror(rv))};
	}
	session.m_Quic->quicConnection.reset(tempConn);
	// Use the GnuTLS session created by ClientGnutlsInit() for the handshake.
	ngtcp2_conn_set_tls_native_handle(session.m_Quic->quicConnection.get(),
		session.m_Quic->session.get());
}
/**
 * Drains every datagram currently waiting on the socket and feeds each one to
 * ngtcp2_conn_read_pkt().
 *
 * The socket is read non-blockingly (MSG_DONTWAIT). The loop ends when no
 * datagram is left (EAGAIN/EWOULDBLOCK) or recvmsg() fails; a failure is
 * logged. Calls interrupted by a signal (EINTR) are retried, matching
 * ClientSendDatagram().
 *
 * @throws std::runtime_error if ngtcp2 rejects a packet (this includes
 *         failures reported by ngtcp2 callbacks).
 */
void ClientRead(CNetClientSession::Quic* c)
{
	std::array<std::uint8_t, MAX_UDP_PAYLOAD_SIZE> buf;
	struct sockaddr_storage addr;
	iovec iov{
		.iov_base = buf.data(),
		.iov_len = buf.size(),
	};
	msghdr msg{};
	msg.msg_name = &addr;
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	ngtcp2_pkt_info pi{};
	while (true)
	{
		// recvmsg() overwrites msg_namelen, so reset it for every datagram.
		msg.msg_namelen = sizeof(addr);
		const ssize_t nread{recvmsg(c->fd, &msg, MSG_DONTWAIT)};
		if (nread == -1)
		{
			if (errno == EINTR)
				continue;
			if (errno != EAGAIN && errno != EWOULDBLOCK)
				LOGERROR("recvmsg: %s", strerror(errno));
			break;
		}
		ngtcp2_path path{
			.local{
				.addr{&c->localAddress.address.sa},
				.addrlen{c->localAddress.length}
			},
			.remote{
				.addr{static_cast<ngtcp2_sockaddr*>(msg.msg_name)},
				.addrlen{msg.msg_namelen}
			}
		};
		const int rv{ngtcp2_conn_read_pkt(c->quicConnection.get(), &path, &pi, buf.data(),
			static_cast<std::size_t>(nread), timestamp())};
		if (rv != 0)
			throw std::runtime_error{fmt::format("ngtcp2_conn_read_pkt: {}", ngtcp2_strerror(rv))};
	}
}
/**
 * Sends @p data as one datagram on c->fd.
 *
 * No destination address is attached, so the socket must already be
 * connected (see ConnectSocket). Calls interrupted by a signal are retried.
 *
 * @throws std::runtime_error if sendmsg() fails for any other reason.
 */
void ClientSendDatagram(CNetClientSession::Quic* c, const std::span<const std::uint8_t> data)
{
	iovec segment{};
	segment.iov_base = const_cast<std::uint8_t*>(data.data());
	segment.iov_len = data.size();

	msghdr header{};
	header.msg_iov = &segment;
	header.msg_iovlen = 1;

	for (;;)
	{
		const ssize_t written{sendmsg(c->fd, &header, 0)};
		if (written != -1)
			return;
		if (errno != EINTR)
			throw std::runtime_error{fmt::format("sendmsg: {}", strerror(errno))};
	}
}
/**
 * Writes every QUIC packet ngtcp2 currently wants to send and transmits each
 * one with ClientSendDatagram().
 *
 * If the stream exists and has queued bytes (PeekData), those bytes are offered
 * to ngtcp2. Otherwise only non-stream frames are written (stream id -1).
 * The loop stops when ngtcp2 produces nothing more or reports the stream as
 * flow-control blocked.
 *
 * @throws std::runtime_error on any other ngtcp2_conn_writev_stream() error,
 *         or if sending a datagram fails.
 */
void ClientWriteStreams(CNetClientSession::Quic* c)
{
	// A single timestamp is used for the whole batch of packets.
	const ngtcp2_tstamp ts{timestamp()};
	ngtcp2_pkt_info pi;
	std::array<uint8_t, MAX_UDP_PAYLOAD_SIZE> buffer;
	ngtcp2_path_storage ps;
	ngtcp2_path_storage_zero(&ps);
	// Ask ngtcp2 to coalesce further stream data into the same packet. The flag
	// is cleared once the existing stream has nothing left to send, and then
	// stays cleared for the rest of this call.
	// NOTE(review): the flag is kept when no stream exists at all; confirm this is intended.
	std::uint32_t flags{NGTCP2_WRITE_STREAM_FLAG_MORE};
	while (true)
	{
		ngtcp2_vec datav;
		std::int64_t streamId;
		const auto bytesToSend = c->streams.has_value() ? c->streams.value().PeekData() : std::nullopt;
		if (c->streams.has_value() && bytesToSend.has_value())
		{
			datav.base = const_cast<uint8_t*>(bytesToSend->data());
			datav.len = bytesToSend->size();
			streamId = c->streams.value().m_Id;
		}
		else
		{
			// No stream data: stream id -1 lets ngtcp2 write only non-stream frames.
			datav.base = nullptr;
			datav.len = 0;
			streamId = -1;
			if (c->streams.has_value())
				flags &= ~NGTCP2_WRITE_STREAM_FLAG_MORE;
		}
		ngtcp2_ssize wdatalen;
		const ngtcp2_ssize nwrite = ngtcp2_conn_writev_stream(c->quicConnection.get(), &ps.path, &pi,
			buffer.data(), buffer.size(), &wdatalen, flags, streamId, &datav, 1, ts);
		if (nwrite < 0)
		{
			// Flow control currently forbids sending the pending stream data; a later call retries.
			if (nwrite == NGTCP2_ERR_STREAM_DATA_BLOCKED)
			{
				LOGWARNING("blocked");
				break;
			}
			if (nwrite != NGTCP2_ERR_WRITE_MORE)
			{
				throw std::runtime_error{fmt::format("ngtcp2_conn_writev_stream: {}",
					ngtcp2_strerror(static_cast<int>(nwrite)))};
			}
			// WRITE_MORE: the packet still has room. Account for the bytes
			// ngtcp2 consumed (MarkSent presumably advances the stream's send
			// position) and loop to append more to the same packet.
			if (c->streams.has_value() && wdatalen > 0)
				c->streams.value().MarkSent(wdatalen);
			continue;
		}
		// Nothing left to write.
		if (nwrite == 0)
			break;
		if (c->streams.has_value() && wdatalen > 0)
			c->streams.value().MarkSent(wdatalen);
		ClientSendDatagram(c, {buffer.data(), static_cast<std::size_t>(nwrite)});
	}
	// Record the packet transmission time with ngtcp2 (used for pacing).
	ngtcp2_conn_update_pkt_tx_time(c->quicConnection.get(), timestamp());
}
/**
 * Lets ngtcp2 process any of its timers that are due at the current time.
 *
 * @throws std::runtime_error if ngtcp2 reports an error (for example an idle
 *         or handshake timeout).
 */
void ClientHandleExpiry(CNetClientSession::Quic* c)
{
	const int rv{ngtcp2_conn_handle_expiry(c->quicConnection.get(), timestamp())};
	if (rv == 0)
		return;
	throw std::runtime_error{fmt::format("ngtcp2_conn_handle_expiry: {}", ngtcp2_strerror(rv))};
}
/**
 * ngtcp2_crypto_conn_ref::get_conn hook. Maps the reference back to the QUIC
 * connection. ClientInit() sets user_data to the owning Quic state.
 */
ngtcp2_conn* GetConnection(ngtcp2_crypto_conn_ref* conn_ref)
{
	auto* const quic = static_cast<CNetClientSession::Quic*>(conn_ref->user_data);
	return quic->quicConnection.get();
}
/**
 * Builds the complete client-side connection state in session.m_Quic. In order, it:
 * - creates and connects the UDP socket;
 * - initialises GnuTLS;
 * - creates the ngtcp2 connection;
 * - wires up the crypto connection reference.
 *
 * If any step after socket creation throws, the socket is closed and m_Quic->fd is
 * reset to -1 before the exception propagates, so the descriptor is not leaked.
 *
 * @throws std::runtime_error from any of the initialisation steps.
 */
void ClientInit(CNetClientSession& session, const char* host, const std::uint16_t port)
{
	*session.m_Quic = CNetClientSession::Quic{};
	const auto [descriptor, remoteAddress] = CreateSocket(host, port);
	session.m_Quic->fd = descriptor;
	try
	{
		session.m_Quic->localAddress = ConnectSocket(descriptor, remoteAddress);
		ClientGnutlsInit(session.m_Quic.get());
		ClientQuicInit(session, remoteAddress, session.m_Quic->localAddress);
	}
	catch (...)
	{
		close(descriptor);
		session.m_Quic->fd = -1;
		throw;
	}
	session.m_Quic->connectionReference.get_conn = &GetConnection;
	session.m_Quic->connectionReference.user_data = session.m_Quic.get();
}
} // anonymous namespace
CNetClientSession::CNetClientSession(CNetClient& client) :
m_Client(client), m_FileTransferer(*this),
m_Quic(std::make_unique<Quic>())
{
}
/**
 * Sends a best-effort CONNECTION_CLOSE to the server and closes the socket.
 *
 * If Connect() never created a QUIC connection, there is nothing to close.
 * A destructor must not throw, so failures here are only logged.
 */
CNetClientSession::~CNetClientSession()
{
	ENSURE(!m_LoopRunning);
	if (!m_Quic || !m_Quic->quicConnection)
		return;

	constexpr ngtcp2_ccerr reason{
		.type{NGTCP2_CCERR_TYPE_TRANSPORT},
		.error_code{0}
	};
	std::array<std::uint8_t, MAX_UDP_PAYLOAD_SIZE> buffer;
	const ngtcp2_ssize amount{ngtcp2_conn_write_connection_close(m_Quic->quicConnection.get(), nullptr, nullptr, buffer.data(),
		buffer.size(), &reason, timestamp())};
	if (amount <= 0)
	{
		// Nothing (or an error) was written: do not send a bogus datagram whose
		// size would be a negative value cast to size_t.
		LOGERROR("closing connection %s", ngtcp2_strerror(static_cast<int>(amount)));
	}
	else
	{
		try
		{
			ClientSendDatagram(m_Quic.get(), {buffer.data(), static_cast<std::size_t>(amount)});
		}
		catch (const std::exception& e)
		{
			LOGERROR("closing connection: %s", e.what());
		}
	}

	if (m_Quic->fd >= 0)
	{
		close(m_Quic->fd);
		m_Quic->fd = -1;
	}
}
bool CNetClientSession::Connect(const CStr& server, const u16 port)
{
ENSURE(!m_LoopRunning);
ClientInit(*this, server.c_str(), port);
ClientWriteStreams(m_Quic.get());
m_Stats = std::make_unique<CNetStatsTable>(m_Quic->quicConnection.get());
if (CProfileViewer::IsInitialised())
g_ProfileViewer.AddRootTable(m_Stats.get());
return true;
}
/**
 * Body of the network thread. Until Shutdown() is requested, it:
 * - services file transfers;
 * - polls the socket;
 * - flushes queued outgoing messages.
 *
 * It then deletes @p session.
 *
 * When the connection fails, a Disconnect message is queued for the main
 * thread. After that the loop stops touching the network, but it keeps
 * waiting for Shutdown(), so the session is still deleted here and
 * m_LoopRunning is cleared. Returning immediately would leak the session and
 * leave m_LoopRunning set.
 * NOTE(review): this relies on the owner calling Shutdown() after handling the disconnect.
 */
void CNetClientSession::RunNetLoop(CNetClientSession* session)
{
	ENSURE(!session->m_LoopRunning);
	session->m_LoopRunning = true;
	debug_SetThreadName("NetClientSession loop");
	bool connectionLost{false};
	while (!session->m_ShouldShutdown)
	{
		if (connectionLost)
		{
			std::this_thread::sleep_for(std::chrono::milliseconds{NETCLIENT_POLL_TIMEOUT});
			continue;
		}
		session->m_FileTransferer.Poll();
		try
		{
			session->Poll();
		}
		catch (const std::exception& e)
		{
			// Report immediately.
			LOGMESSAGE("Net client: Disconnected");
			LOGWARNING("Net client: %s", e.what());
			session->m_Connected = false;
			session->m_IncomingMessages.push(Disconnect{});
			connectionLost = true;
			continue;
		}
		session->Flush();
		// TODO: update these from the ngtcp2 connection.
		// session->m_LastReceivedTime = timestamp() - session->m_Server->lastReceiveTime;
		// session->m_MeanRTT = session->m_Server->roundTripTime;
	}
	session->m_LoopRunning = false;
	// Deleting the session is handled in this thread as it might outlive the CNetClient.
	SAFE_DELETE(session);
}
void CNetClientSession::Shutdown()
{
m_ShouldShutdown = true;
}
void CNetClientSession::Poll()
{
pollfd pfd{
.fd{m_Quic->fd},
.events{POLLIN}
};
const int ret{poll(&pfd, 1, NETCLIENT_POLL_TIMEOUT)};
if (ret < 0)
{
LOGERROR("Error while waiting for poll: %s", std::strerror(errno));
return;
}
if (ret == 0)
{
ClientHandleExpiry(m_Quic.get());
ClientWriteStreams(m_Quic.get());
return;
}
ClientRead(m_Quic.get());
ClientWriteStreams(m_Quic.get());
}
/**
 * Moves every message queued by SendMessage() into the QUIC stream's send
 * buffer. Each queued heap buffer is freed afterwards. Messages are dropped
 * with an error if no stream exists yet.
 */
void CNetClientSession::Flush()
{
	std::vector<std::uint8_t>* pending;
	while (m_OutgoingMessages.pop(pending))
	{
		const std::unique_ptr<std::vector<std::uint8_t>> owned{pending};
		if (!m_Quic->streams.has_value())
		{
			LOGERROR("no stream to send message");
			continue;
		}
		m_Quic->streams->PushData(std::move(*owned));
	}
}
void CNetClientSession::ProcessPolledMessages()
{
IncommingMessage query{};
while(m_IncomingMessages.pop(query))
{
std::visit([&]<typename Message>(Message message)
{
if constexpr (std::same_as<Message, ConnectionEstablished>)
{
m_Client.HandleConnect();
}
else if constexpr (std::same_as<Message, Disconnect>)
{
m_Client.HandleDisconnect(NDR_UNKNOWN);
}
else
{
static_assert(std::same_as<Message, std::vector<std::uint8_t>*>);
std::unique_ptr<std::vector<std::uint8_t>> data{message};
CNetMessage* msg = CNetMessageFactory::CreateMessage(*data, m_Client.GetScriptInterface());
if (msg)
{
LOGMESSAGE("Net client: Received message %s of size %lu from server", msg->ToString().c_str(), (unsigned long)msg->GetSerializedLength());
m_Client.HandleMessage(msg);
}
}
}, query);
}
}
bool CNetClientSession::SendMessage(const CNetMessage* message)
{
// ENSURE(m_Host && m_Server);
if (!m_OutgoingMessages.push(new std::vector{CNetHost::CreatePacket(message)}))
{
LOGERROR("NetClient: Failed to push message on the outgoing queue.");
return false;
}
return true;
}
/**
 * Returns the stored "time since last receive" value.
 * NOTE(review): no code in this file appears to assign m_LastReceivedTime,
 * so this likely still returns its initial value. Confirm.
 */
u32 CNetClientSession::GetLastReceivedTime() const
{
	return this->m_LastReceivedTime;
}
/**
 * Returns the stored mean round-trip time.
 * NOTE(review): no code in this file appears to assign m_MeanRTT, so this
 * likely still returns its initial value. Confirm.
 */
u32 CNetClientSession::GetMeanRTT() const
{
	return this->m_MeanRTT;
}