/* Copyright (C) 2026 Wildfire Games.
 * This file is part of 0 A.D.
 *
 * 0 A.D. is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 2 of the License, or
 * (at your option) any later version.
 *
 * 0 A.D. is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with 0 A.D.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "precompiled.h"

#include "NetClientSession.h"

#include "lib/code_generation.h"
#include "lib/debug.h"
#include "network/NetClient.h"
#include "network/NetMessage.h"
#include "network/NetProtocol.h"
#include "network/NetStats.h"
#include "ps/CLogger.h"
#include "ps/ProfileViewer.h"

#include <array>
#include <cerrno>
#include <chrono>
#include <concepts>
#include <cstdint>
#include <cstring>
#include <fmt/format.h>
#include <gnutls/crypto.h>
#include <gnutls/gnutls.h>
#include <memory>
#include <netdb.h>
#include <ngtcp2/ngtcp2.h>
#include <ngtcp2/ngtcp2_crypto.h>
#include <ngtcp2/ngtcp2_crypto_gnutls.h>
#include <optional>
#include <poll.h>
#include <span>
#include <stdexcept>
#include <sys/socket.h>
#include <thread>
#include <type_traits>
#include <unistd.h>
#include <variant>
#include <vector>

constexpr int NETCLIENT_POLL_TIMEOUT = 50;

namespace
{
/**
 * unique_ptr deleter releasing a C library handle through @p Destroy.
 */
template<auto Destroy>
struct CFunctionDeleter
{
	template<typename T>
	void operator()(T* handle) const noexcept
	{
		Destroy(handle);
	}
};
} // anonymous namespace

/**
 * All QUIC/TLS state of a client session. The members are destroyed in reverse order, so the
 * QUIC connection is released before the TLS session and credentials it was configured with.
 */
struct CNetClientSession::Quic
{
	AddressStorage localAddress;
	std::unique_ptr<std::remove_pointer_t<gnutls_certificate_credentials_t>,
		CFunctionDeleter<&gnutls_certificate_free_credentials>> credentials;
	std::unique_ptr<std::remove_pointer_t<gnutls_session_t>, CFunctionDeleter<&gnutls_deinit>> session;
	std::unique_ptr<ngtcp2_conn, CFunctionDeleter<&ngtcp2_conn_del>> quicConnection;
	ngtcp2_crypto_conn_ref connectionReference;
	// -1 until ClientInit() creates the socket; closed by ~CNetClientSession().
	int fd{-1};
	// NOTE(review): the stream type name was lost in this copy of the file; confirm it against
	// its declaration (it provides Receive, PeekData, MarkSent, PushData and m_Id).
	std::optional<Stream> streams;
};

namespace
{
struct CreateSocketResult
{
	int descriptor;
	AddressStorage address;
};

/**
 * Resolves @p host:@p port and opens a UDP socket for the first usable result.
 * @return The socket descriptor and the resolved remote address.
 * @throws std::runtime_error if resolution fails or no socket could be created.
 */
CreateSocketResult CreateSocket(const char* host, const std::uint16_t port)
{
	addrinfo hints{};
	// Accept both IPv4 and IPv6 results.
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_DGRAM;

	addrinfo* res;
	const int rv{getaddrinfo(host, fmt::format("{}", port).c_str(), &hints, &res)};
	if (rv)
		throw std::runtime_error{fmt::format("getaddrinfo: {}", gai_strerror(rv))};
	const std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> infoList{res, &freeaddrinfo};

	for (const addrinfo* rp{res}; rp; rp = rp->ai_next)
	{
		const int fd{socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)};
		if (fd == -1)
			continue;

		AddressStorage address{};
		if (rp->ai_addrlen > sizeof(address.address))
		{
			close(fd);
			continue;
		}
		// Copy the whole address: ai_addr can be larger than a plain sockaddr (e.g. sockaddr_in6),
		// so copying only a sockaddr would truncate IPv6 addresses.
		std::memcpy(&address.address, rp->ai_addr, rp->ai_addrlen);
		address.length = rp->ai_addrlen;
		return CreateSocketResult{.descriptor{fd}, .address{address}};
	}

	throw std::runtime_error{"unable to create a socket"};
}

/**
 * Connects the UDP socket @p fd to @p remoteAddress.
 * @return The local address the socket got bound to.
 * @throws std::runtime_error if connect() or getsockname() fails.
 */
AddressStorage ConnectSocket(const int fd, const AddressStorage& remoteAddress)
{
	if (connect(fd, &remoteAddress.address.sa, remoteAddress.length))
		throw std::runtime_error{fmt::format("connect: {}", strerror(errno))};

	ngtcp2_sockaddr_union localAddress;
	ngtcp2_socklen localAddressLength{sizeof(localAddress)};
	if (getsockname(fd, &localAddress.sa, &localAddressLength) == -1)
		throw std::runtime_error{fmt::format("getsockname: {}", strerror(errno))};

	return {localAddress, localAddressLength};
}

/**
 * Creates the GnuTLS credentials and client session stored in @p c.
 * @throws std::runtime_error if any GnuTLS/ngtcp2-crypto setup call fails.
 */
void ClientGnutlsInit(CNetClientSession::Quic* c)
{
	gnutls_certificate_credentials_t tempCred;
	const int allocRet{gnutls_certificate_allocate_credentials(&tempCred)};
	if (allocRet)
	{
		throw std::runtime_error{fmt::format("cred init failed: {}: {}", allocRet,
			gnutls_strerror(allocRet))};
	}
	c->credentials.reset(tempCred);

	gnutls_session_t tempSession;
	if (const int initRet{gnutls_init(&tempSession,
		GNUTLS_CLIENT | GNUTLS_ENABLE_EARLY_DATA | GNUTLS_NO_END_OF_EARLY_DATA)})
	{
		throw std::runtime_error{fmt::format("gnutls_init: {}", gnutls_strerror(initRet))};
	}
	c->session.reset(tempSession);

	if (ngtcp2_crypto_gnutls_configure_client_session(c->session.get()))
		throw std::runtime_error{"ngtcp2_crypto_gnutls_configure_client_session failed"};

	if (const int priorityRet{gnutls_priority_set_direct(c->session.get(), TLS_PRIORITY, nullptr)})
	{
		throw std::runtime_error{fmt::format("gnutls_priority_set_direct: {}",
			gnutls_strerror(priorityRet))};
	}

	// ngtcp2's crypto helpers find the connection through this reference (see GetConnection).
	gnutls_session_set_ptr(c->session.get(), &c->connectionReference);

	if (const int setRet{gnutls_credentials_set(c->session.get(), GNUTLS_CRD_CERTIFICATE,
		c->credentials.get())})
	{
		throw std::runtime_error{fmt::format("gnutls_credentials_set: {}", gnutls_strerror(setRet))};
	}
}

/**
 * ngtcp2 `stream_open` callback: the server opened the session's stream.
 */
int OpenStream(ngtcp2_conn*, const std::int64_t streamId, void* userData)
{
	CNetClientSession& session{*static_cast<CNetClientSession*>(userData)};
	session.m_Connected = true;
	session.m_WasConnected = true;
	session.m_Quic->streams.emplace(streamId);
	session.m_IncomingMessages.push(CNetClientSession::ConnectionEstablished{});
	return 0;
}

/**
 * ngtcp2 `recv_stream_data` callback: feeds stream bytes to the stream reassembler and queues
 * every completed message for the main thread.
 * @return 0, or NGTCP2_ERR_CALLBACK_FAILURE to make ngtcp2 fail the connection.
 */
int OnStreamDataReceive(ngtcp2_conn* conn, const std::uint32_t /*flags*/, const std::int64_t streamId,
	const std::size_t /*offset*/, const std::uint8_t* data, const std::size_t dataSize, void* userData,
	void* /*streamUserData*/)
{
	CNetClientSession& session{*static_cast<CNetClientSession*>(userData)};
	auto& stream = session.m_Quic->streams;
	if (!stream.has_value())
	{
		LOGERROR("Net client: Received stream data before the stream was opened");
		return NGTCP2_ERR_CALLBACK_FAILURE;
	}

	// Exceptions must not unwind through ngtcp2's C frames; report them as a callback failure.
	try
	{
		auto message = stream->Receive({data, dataSize});
		if (message.has_value())
		{
			auto* const packet{new std::vector{std::move(message).value()}};
			if (!session.m_IncomingMessages.push(packet))
			{
				delete packet;
				// Dropping a message from an ordered, reliable stream would silently desynchronize
				// the client, so fail the connection instead.
				LOGERROR("Net client: Failed to push message on the incoming queue.");
				return NGTCP2_ERR_CALLBACK_FAILURE;
			}
		}
	}
	catch (const std::exception& error)
	{
		LOGERROR("Net client: %s", error.what());
		return NGTCP2_ERR_CALLBACK_FAILURE;
	}

	increaseWindow(conn, streamId, dataSize);
	return 0;
}

/**
 * Creates the ngtcp2 client connection for the path @p local -> @p remote and attaches the TLS
 * session created by ClientGnutlsInit().
 * @throws std::runtime_error if random generation or connection creation fails.
 */
void ClientQuicInit(CNetClientSession& session, const AddressStorage& remote, const AddressStorage& local)
{
	const ngtcp2_path path{
		.local{
			.addr{const_cast<ngtcp2_sockaddr*>(&local.address.sa)},
			.addrlen{local.length},
		},
		.remote{
			.addr{const_cast<ngtcp2_sockaddr*>(&remote.address.sa)},
			.addrlen{remote.length},
		}
	};

	constexpr ngtcp2_callbacks callbacks{
		.client_initial{&ngtcp2_crypto_client_initial_cb},
		.recv_crypto_data{&ngtcp2_crypto_recv_crypto_data_cb},
		.encrypt{&ngtcp2_crypto_encrypt_cb},
		.decrypt{&ngtcp2_crypto_decrypt_cb},
		.hp_mask{&ngtcp2_crypto_hp_mask_cb},
		.recv_stream_data{&OnStreamDataReceive},
		.stream_open{&OpenStream},
		.recv_retry{&ngtcp2_crypto_recv_retry_cb},
		.rand{&OnRandomRequest},
		.get_new_connection_id{&OnNewConnectionIdRequest},
		.update_key{&ngtcp2_crypto_update_key_cb},
		.delete_crypto_aead_ctx{&ngtcp2_crypto_delete_crypto_aead_ctx_cb},
		.delete_crypto_cipher_ctx{&ngtcp2_crypto_delete_crypto_cipher_ctx_cb},
		.get_path_challenge_data{&ngtcp2_crypto_get_path_challenge_data_cb},
		.version_negotiation{&ngtcp2_crypto_version_negotiation_cb}
	};

	ngtcp2_cid dcid;
	dcid.datalen = NGTCP2_MIN_INITIAL_DCIDLEN;
	if (gnutls_rnd(GNUTLS_RND_RANDOM, dcid.data, dcid.datalen))
		throw std::runtime_error{"gnutls_rnd failed"};

	ngtcp2_cid scid;
	scid.datalen = 8;
	if (gnutls_rnd(GNUTLS_RND_RANDOM, scid.data, scid.datalen))
		throw std::runtime_error{"gnutls_rnd failed"};

	ngtcp2_settings settings;
	ngtcp2_settings_default(&settings);
	settings.initial_ts = timestamp();

	ngtcp2_transport_params params;
	ngtcp2_transport_params_default(&params);
	params.initial_max_stream_data_bidi_remote = 128 * KiB;
	params.initial_max_data = 1 * MiB;
	params.initial_max_streams_bidi = 1;
	params.max_udp_payload_size = MAX_UDP_PAYLOAD_SIZE;
	params.grease_quic_bit = 1;

	ngtcp2_conn* tempConn;
	if (const int rv{ngtcp2_conn_client_new(&tempConn, &dcid, &scid, &path, NGTCP2_PROTO_VER_V1,
		&callbacks, &settings, &params, nullptr, &session)})
	{
		throw std::runtime_error{fmt::format("ngtcp2_conn_client_new: {}", ngtcp2_strerror(rv))};
	}
	session.m_Quic->quicConnection.reset(tempConn);

	ngtcp2_conn_set_tls_native_handle(session.m_Quic->quicConnection.get(), session.m_Quic->session.get());
}

/**
 * Reads every datagram currently queued on the socket (non-blocking) and passes it to ngtcp2.
 * @throws std::runtime_error if ngtcp2 rejects a packet.
 */
void ClientRead(CNetClientSession::Quic* c)
{
	std::array<std::uint8_t, 65536> buf;
	sockaddr_storage addr;
	iovec iov{
		.iov_base = buf.data(),
		.iov_len = buf.size(),
	};
	msghdr msg{};
	msg.msg_name = &addr;
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	ngtcp2_pkt_info pi{};

	while (true)
	{
		msg.msg_namelen = sizeof(addr);

		const ssize_t nread{recvmsg(c->fd, &msg, MSG_DONTWAIT)};
		if (nread == -1)
		{
			if (errno != EAGAIN && errno != EWOULDBLOCK)
				LOGERROR("recvmsg: %s", strerror(errno));
			break;
		}

		ngtcp2_path path{
			.local{
				.addr{&c->localAddress.address.sa},
				.addrlen{c->localAddress.length}
			},
			.remote{
				.addr{static_cast<ngtcp2_sockaddr*>(msg.msg_name)},
				.addrlen{msg.msg_namelen}
			}
		};

		const int rv{ngtcp2_conn_read_pkt(c->quicConnection.get(), &path, &pi, buf.data(),
			static_cast<std::size_t>(nread), timestamp())};
		if (rv != 0)
			throw std::runtime_error{fmt::format("ngtcp2_conn_read_pkt: {}", ngtcp2_strerror(rv))};
	}
}

/**
 * Sends one datagram on the connected socket, retrying when interrupted by a signal.
 * @throws std::runtime_error if sendmsg() fails.
 */
void ClientSendDatagram(CNetClientSession::Quic* c, const std::span<const std::uint8_t> data)
{
	iovec iov{
		.iov_base = const_cast<std::uint8_t*>(data.data()),
		.iov_len = data.size(),
	};

	msghdr msg{};
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	ssize_t nwrite;
	do
	{
		nwrite = sendmsg(c->fd, &msg, 0);
	} while (nwrite == -1 && errno == EINTR);

	if (nwrite == -1)
		throw std::runtime_error{fmt::format("sendmsg: {}", strerror(errno))};
}

/**
 * Writes pending stream data and any control frames ngtcp2 wants to send, packet by packet.
 * @throws std::runtime_error on ngtcp2 or socket errors.
 */
void ClientWriteStreams(CNetClientSession::Quic* c)
{
	const ngtcp2_tstamp ts{timestamp()};
	ngtcp2_pkt_info pi{};
	std::array<std::uint8_t, MAX_UDP_PAYLOAD_SIZE> buffer;
	ngtcp2_path_storage ps;
	ngtcp2_path_storage_zero(&ps);

	std::uint32_t flags{NGTCP2_WRITE_STREAM_FLAG_MORE};
	// Set once ngtcp2 reports the stream as flow-control blocked. The remaining iterations then
	// pass stream id -1, so the packet being built and pending ACK/control frames still get sent
	// instead of being abandoned.
	bool streamBlocked{false};
	while (true)
	{
		ngtcp2_vec datav{};
		std::int64_t streamId{-1};

		const bool offerStreamData{c->streams.has_value() && !streamBlocked};
		const auto bytesToSend = offerStreamData ? c->streams->PeekData() : std::nullopt;
		if (bytesToSend.has_value())
		{
			datav.base = const_cast<std::uint8_t*>(bytesToSend->data());
			datav.len = bytesToSend->size();
			streamId = c->streams->m_Id;
		}
		else if (c->streams.has_value())
			flags &= ~NGTCP2_WRITE_STREAM_FLAG_MORE;

		ngtcp2_ssize wdatalen{-1};
		const ngtcp2_ssize nwrite{ngtcp2_conn_writev_stream(c->quicConnection.get(), &ps.path, &pi,
			buffer.data(), buffer.size(), &wdatalen, flags, streamId, &datav, 1, ts)};
		if (nwrite < 0)
		{
			if (nwrite == NGTCP2_ERR_STREAM_DATA_BLOCKED)
			{
				LOGWARNING("blocked");
				streamBlocked = true;
				continue;
			}
			if (nwrite != NGTCP2_ERR_WRITE_MORE)
			{
				throw std::runtime_error{fmt::format("ngtcp2_conn_writev_stream: {}",
					ngtcp2_strerror(static_cast<int>(nwrite)))};
			}
			// More data can be coalesced into the same packet.
			if (c->streams.has_value() && wdatalen > 0)
				c->streams->MarkSent(wdatalen);
			continue;
		}

		if (nwrite == 0)
			break;

		if (c->streams.has_value() && wdatalen > 0)
			c->streams->MarkSent(wdatalen);

		ClientSendDatagram(c, {buffer.data(), static_cast<std::size_t>(nwrite)});
	}

	ngtcp2_conn_update_pkt_tx_time(c->quicConnection.get(), timestamp());
}

/**
 * Lets ngtcp2 process its expired timers.
 * @throws std::runtime_error if ngtcp2 reports an error.
 */
void ClientHandleExpiry(CNetClientSession::Quic* c)
{
	if (const int rv{ngtcp2_conn_handle_expiry(c->quicConnection.get(), timestamp())})
		throw std::runtime_error{fmt::format("ngtcp2_conn_handle_expiry: {}", ngtcp2_strerror(rv))};
}

/**
 * ngtcp2_crypto_conn_ref::get_conn implementation: maps the TLS session back to its connection.
 */
ngtcp2_conn* GetConnection(ngtcp2_crypto_conn_ref* conn_ref)
{
	CNetClientSession::Quic* c{static_cast<CNetClientSession::Quic*>(conn_ref->user_data)};
	return c->quicConnection.get();
}

/**
 * Resets the session's QUIC state and performs socket, TLS and QUIC setup towards @p host:@p port.
 * @throws std::runtime_error if any setup step fails.
 */
void ClientInit(CNetClientSession& session, const char* host, const std::uint16_t port)
{
	*session.m_Quic = CNetClientSession::Quic{};

	const auto [descriptor, remoteAddress] = CreateSocket(host, port);
	// Store the descriptor right away so the destructor closes it even if a later step throws.
	session.m_Quic->fd = descriptor;
	session.m_Quic->localAddress = ConnectSocket(session.m_Quic->fd, remoteAddress);

	ClientGnutlsInit(session.m_Quic.get());

	ClientQuicInit(session, remoteAddress, session.m_Quic->localAddress);
	session.m_Quic->connectionReference.get_conn = &GetConnection;
	session.m_Quic->connectionReference.user_data = session.m_Quic.get();
}
} // anonymous namespace

CNetClientSession::CNetClientSession(CNetClient& client) :
	m_Client(client), m_FileTransferer(*this), m_Quic(std::make_unique<Quic>())
{
}

CNetClientSession::~CNetClientSession()
{
	ENSURE(!m_LoopRunning);

	// There is no connection to close if Connect() never got as far as creating one.
	if (ngtcp2_conn* const connection{m_Quic->quicConnection.get()})
	{
		constexpr ngtcp2_ccerr reason{
			.type{NGTCP2_CCERR_TYPE_TRANSPORT},
			.error_code{0}
		};
		std::array<std::uint8_t, MAX_UDP_PAYLOAD_SIZE> buffer;
		const ngtcp2_ssize amount{ngtcp2_conn_write_connection_close(connection, nullptr, nullptr,
			buffer.data(), buffer.size(), &reason, timestamp())};
		if (amount < 0)
			LOGERROR("closing connection %s", ngtcp2_strerror(static_cast<int>(amount)));
		else if (amount > 0)
		{
			// Best effort: a destructor must not throw.
			try
			{
				ClientSendDatagram(m_Quic.get(), {buffer.data(), static_cast<std::size_t>(amount)});
			}
			catch (const std::runtime_error& error)
			{
				LOGERROR("closing connection %s", error.what());
			}
		}
	}

	if (m_Quic->fd != -1)
		close(m_Quic->fd);
}

bool CNetClientSession::Connect(const CStr& server, const u16 port)
{
	ENSURE(!m_LoopRunning);

	ClientInit(*this, server.c_str(), port);

	// Send the initial handshake packet(s).
	ClientWriteStreams(m_Quic.get());

	m_Stats = std::make_unique<CNetStatsTable>(m_Quic->quicConnection.get());
	if (CProfileViewer::IsInitialised())
		g_ProfileViewer.AddRootTable(m_Stats.get());

	return true;
}

void CNetClientSession::RunNetLoop(CNetClientSession* session)
{
	ENSURE(!session->m_LoopRunning);
	session->m_LoopRunning = true;

	debug_SetThreadName("NetClientSession loop");

	while (!session->m_ShouldShutdown)
	{
		session->m_FileTransferer.Poll();

		try
		{
			session->Poll();
		}
		catch (const std::runtime_error& error)
		{
			// Report immediately.
			LOGMESSAGE("Net client: Disconnected");
			LOGMESSAGE("Net client: %s", error.what());
			session->m_Connected = false;
			session->m_IncomingMessages.push(Disconnect{});

			// The connection is unusable. As on the normal path, the session is only deleted once
			// Shutdown() has been requested, so wait for that instead of returning, which would
			// leak the session.
			while (!session->m_ShouldShutdown)
				std::this_thread::sleep_for(std::chrono::milliseconds{NETCLIENT_POLL_TIMEOUT});
			break;
		}

		session->Flush();

		// TODO: m_LastReceivedTime and m_MeanRTT are not updated anywhere in this file, so
		// GetLastReceivedTime() and GetMeanRTT() report stale values.
	}

	session->m_LoopRunning = false;

	// Deleting the session is handled in this thread as it might outlive the CNetClient.
	SAFE_DELETE(session);
}

void CNetClientSession::Shutdown()
{
	m_ShouldShutdown = true;
}

void CNetClientSession::Poll()
{
	pollfd pfd{
		.fd{m_Quic->fd},
		.events{POLLIN}
	};
	const int ret{poll(&pfd, 1, NETCLIENT_POLL_TIMEOUT)};
	if (ret < 0)
	{
		// Being interrupted by a signal is not an error; the next iteration simply polls again.
		if (errno != EINTR)
			LOGERROR("Error while waiting for poll: %s", std::strerror(errno));
		return;
	}

	if (ret > 0)
		ClientRead(m_Quic.get());

	// ngtcp2's timers have to be serviced once due, even while packets keep arriving, so do not
	// rely on poll() timing out.
	if (ret == 0 || ngtcp2_conn_get_expiry(m_Quic->quicConnection.get()) <= timestamp())
		ClientHandleExpiry(m_Quic.get());

	ClientWriteStreams(m_Quic.get());
}

void CNetClientSession::Flush()
{
	std::vector<u8>* message;
	while (m_OutgoingMessages.pop(message))
	{
		// Take ownership of the packet allocated by SendMessage().
		const std::unique_ptr<std::vector<u8>> data{message};
		if (m_Quic->streams.has_value())
			m_Quic->streams->PushData(std::move(*data));
		else
			LOGERROR("no stream to send message");
	}
}

void CNetClientSession::ProcessPolledMessages()
{
	IncommingMessage query{};
	while (m_IncomingMessages.pop(query))
	{
		std::visit([&]<typename Message>(Message message)
		{
			if constexpr (std::same_as<Message, ConnectionEstablished>)
				m_Client.HandleConnect();
			else if constexpr (std::same_as<Message, Disconnect>)
				m_Client.HandleDisconnect(NDR_UNKNOWN);
			else
			{
				static_assert(std::is_pointer_v<Message>, "unhandled IncommingMessage alternative");
				// Take ownership of the packet allocated on the network thread.
				const std::unique_ptr<std::remove_pointer_t<Message>> data{message};
				CNetMessage* msg{CNetMessageFactory::CreateMessage(*data, m_Client.GetScriptInterface())};
				if (msg)
				{
					LOGMESSAGE("Net client: Received message %s of size %lu from server",
						msg->ToString().c_str(), static_cast<unsigned long>(msg->GetSerializedLength()));

					m_Client.HandleMessage(msg);
				}
			}
		}, query);
	}
}

bool CNetClientSession::SendMessage(const CNetMessage* message)
{
	// Ownership passes to the network thread (Flush()) once the push succeeds.
	auto* const packet{new std::vector{CNetHost::CreatePacket(message)}};
	if (!m_OutgoingMessages.push(packet))
	{
		delete packet;
		LOGERROR("NetClient: Failed to push message on the outgoing queue.");
		return false;
	}

	return true;
}

u32 CNetClientSession::GetLastReceivedTime() const
{
	return m_LastReceivedTime;
}

u32 CNetClientSession::GetMeanRTT() const
{
	return m_MeanRTT;
}