/* Copyright (C) 2026 Wildfire Games.
 * This file is part of 0 A.D.
 *
 * 0 A.D. is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 2 of the License, or
 * (at your option) any later version.
 *
 * 0 A.D. is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with 0 A.D. If not, see <http://www.gnu.org/licenses/>.
 */

#include "precompiled.h"

#include "NetServerSession.h"

#include "lib/external_libraries/enet.h"
#include "network/NetMessage.h"
#include "network/NetMessages.h"
#include "network/NetServer.h"
#include "ps/CLogger.h"

#include
#include

namespace
{
/**
 * Send one UDP datagram containing @p data to @p remote over @p socketFd
 * without blocking (MSG_DONTWAIT). sendmsg() is retried while it is
 * interrupted by a signal (EINTR).
 * @throws std::system_error carrying errno on any other sendmsg() failure,
 *         including EAGAIN/EWOULDBLOCK (WriteToStream treats those as "stop for now").
 */
void SendPacket(const int socketFd, const std::span data, const ngtcp2_addr remote)
{
	iovec iov{
		.iov_base{const_cast(data.data())},
		.iov_len{data.size()}
	};

	msghdr msg{};
	msg.msg_name = remote.addr;
	msg.msg_namelen = remote.addrlen;
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	ssize_t ret;
	do
		ret = sendmsg(socketFd, &msg, MSG_DONTWAIT);
	while (ret < 0 && errno == EINTR);

	if (ret < 0)
		throw std::system_error{errno, std::generic_category(), "Error sending message"};
}

/**
 * ngtcp2_crypto_conn_ref::get_conn callback. user_data is set to the address
 * of the owning Connection's m_QuicConnection in Connection's constructor.
 */
ngtcp2_conn* GetConnection(ngtcp2_crypto_conn_ref* connRef)
{
	return reinterpret_cast*>(connRef->user_data)->get();
}

/**
 * ngtcp2 recv_stream_data callback. The received bytes are fed to the
 * session's Stream; when Receive() returns a value (presumably one complete
 * serialized message) it is deserialized and handed to the server. The
 * stream's flow-control window is then extended by @p datalen.
 * @return 0 (always).
 */
int OnReceiveStreamData(ngtcp2_conn* conn, const std::uint32_t /*flags*/, const std::int64_t streamId,
	const std::uint64_t /*offset*/, const std::uint8_t* data, const std::size_t datalen, void* userData,
	void* /*streamUserData*/)
{
	CNetServerSession& session{*static_cast<CNetServerSession*>(userData)};
	const auto messageData = session.m_Connection.m_Stream.value().Receive({data, datalen});
	if (messageData.has_value())
	{
		// NOTE(review): CreateMessage's result is not checked for null here; confirm
		// HandleMessageReceive tolerates a null message.
		std::unique_ptr<CNetMessage> message{CNetMessageFactory::CreateMessage(messageData.value(),
			session.GetServer().GetScriptInterface())};
		session.GetServer().HandleMessageReceive(message.get(), &session);
	}

	increaseWindow(conn, streamId, datalen);
	return 0;
}

/**
 * ngtcp2 acked_stream_data_offset callback: tells the session's Stream that
 * everything up to offset + dataLength has been acknowledged by the peer.
 * @return 0 (always).
 */
int OnAcknowledgedStreamData(ngtcp2_conn*, const std::int64_t, const std::uint64_t offset,
	const std::uint64_t dataLength, void* userData, void*)
{
	Connection& connection{static_cast<CNetServerSession*>(userData)->m_Connection};
	Stream& stream{connection.GetStream()};
	stream.MarkAcknowledged(offset + dataLength);
	return 0;
}

/**
 * ngtcp2 recv_tx_key callback. Keys for encryption levels other than 1-RTT are
 * ignored. Once the 1-RTT key is installed, the server opens its bidirectional
 * stream and queues the server handshake message on it.
 * @return 0, or NGTCP2_ERR_CALLBACK_FAILURE if the stream cannot be opened.
 */
int OnConnect(ngtcp2_conn*, const ngtcp2_encryption_level level, void* userData)
{
	if (level != NGTCP2_ENCRYPTION_LEVEL_1RTT)
		return 0;

	Connection& connection{static_cast<CNetServerSession*>(userData)->m_Connection};
	try
	{
		connection.OpenStream();
	}
	catch (const std::runtime_error&)
	{
		return NGTCP2_ERR_CALLBACK_FAILURE;
	}

	const CSrvHandshakeMessage handshake(CreateHandshake());
	connection.GetStream().PushMessage(&handshake);
	return 0;
}

/**
 * Callback table passed to ngtcp2_conn_server_new(). The crypto-related entries
 * are ngtcp2_crypto's stock helpers; the rest are the handlers in this file
 * (or helpers declared elsewhere).
 */
constexpr ngtcp2_callbacks callbacks{
	.recv_client_initial{&ngtcp2_crypto_recv_client_initial_cb},
	.recv_crypto_data{&ngtcp2_crypto_recv_crypto_data_cb},
	.encrypt{&ngtcp2_crypto_encrypt_cb},
	.decrypt{&ngtcp2_crypto_decrypt_cb},
	.hp_mask{&ngtcp2_crypto_hp_mask_cb},
	.recv_stream_data{&OnReceiveStreamData},
	.acked_stream_data_offset{&OnAcknowledgedStreamData},
	.recv_retry{&ngtcp2_crypto_recv_retry_cb},
	.rand{&OnRandomRequest},
	.get_new_connection_id{&OnNewConnectionIdRequest},
	.update_key{&ngtcp2_crypto_update_key_cb},
	.delete_crypto_aead_ctx{&ngtcp2_crypto_delete_crypto_aead_ctx_cb},
	.delete_crypto_cipher_ctx{&ngtcp2_crypto_delete_crypto_cipher_ctx_cb},
	.get_path_challenge_data{&ngtcp2_crypto_get_path_challenge_data_cb},
	.recv_tx_key{&OnConnect}
};

/**
 * Let ngtcp2 build packets for @p conn and send them to @p remote.
 *
 * With @p stream == nullptr only packets without stream data are produced.
 * Otherwise the stream's unsent bytes (PeekData()) are offered to ngtcp2, and
 * the number of bytes it consumed is recorded with MarkSent().
 *
 * The loop stops when:
 * - ngtcp2 has nothing more to write;
 * - the socket would block;
 * - a packet without stream data has been sent on the stream path.
 * If the peer's flow-control window is exhausted, stream data is withheld for
 * the rest of this call.
 * @throws std::runtime_error on a fatal ngtcp2 error.
 * @throws std::system_error on a socket error other than EAGAIN/EWOULDBLOCK.
 */
void WriteToStream(const int socketFd, ngtcp2_conn* conn, Stream* stream, const ngtcp2_addr remote)
{
	std::array buf;
	ngtcp2_path_storage ps;
	ngtcp2_path_storage_zero(&ps);
	ngtcp2_pkt_info pi;
	const std::uint64_t ts{timestamp()};
	std::uint32_t flags{NGTCP2_WRITE_STREAM_FLAG_MORE};
	// Set when ngtcp2 reports the stream blocked by flow control. Offering the
	// same data again would just be rejected again, so stream data is withheld.
	// That lets the packet being built be finished and sent. The stream is
	// retried on a later Write() once the peer extends the window.
	bool streamBlocked{false};

	while (true)
	{
		ngtcp2_vec datav;
		std::int64_t stream_id;
		if (stream)
		{
			auto bytesToSend = stream->PeekData();
			if (bytesToSend.has_value() && !streamBlocked)
			{
				datav.base = const_cast(bytesToSend->data());
				datav.len = bytesToSend->size();
				stream_id = stream->m_Id;
			}
			else
			{
				/* No stream data to be sent */
				datav.base = nullptr;
				datav.len = 0;
				stream_id = -1;
				flags &= ~NGTCP2_WRITE_STREAM_FLAG_MORE;
			}
		}
		else
		{
			datav.base = nullptr;
			datav.len = 0;
			stream_id = -1;
		}

		ngtcp2_ssize n_read;
		const ngtcp2_ssize n_written{ngtcp2_conn_writev_stream(conn, &ps.path, &pi, buf.data(), buf.size(),
			&n_read, flags, stream_id, &datav, 1, ts)};
		if (n_written < 0)
		{
			if (n_written == NGTCP2_ERR_STREAM_DATA_BLOCKED)
			{
				// Recoverable: finish the pending packet without stream data.
				streamBlocked = true;
				continue;
			}
			if (n_written != NGTCP2_ERR_WRITE_MORE)
			{
				throw std::runtime_error{fmt::format("ngtcp2_conn_writev_stream: {}",
					ngtcp2_strerror(static_cast<int>(n_written)))};
			}
			// WRITE_MORE: stream data was packed into the packet being built, which
			// still has room; record it and keep filling the same packet.
			if (stream && n_read > 0)
				stream->MarkSent(n_read);
			continue;
		}

		if (n_written == 0)
			break;

		if (stream && n_read > 0)
			stream->MarkSent(n_read);

		try
		{
			SendPacket(socketFd, {buf.data(), static_cast<std::size_t>(n_written)}, remote);
		}
		catch (const std::system_error& e)
		{
			// Socket buffer is full: stop writing for this round.
			if (e.code().value() == EAGAIN || e.code().value() == EWOULDBLOCK)
				break;
			throw;
		}

		/* No stream data to be sent */
		if (stream && datav.len == 0)
			break;
	}

	ngtcp2_conn_update_pkt_tx_time(conn, timestamp());
}

/**
 * Create a GnuTLS server session with early data enabled, the TLS_PRIORITY
 * priority string and the certificate credentials @p cred.
 * @throws std::runtime_error naming the failing GnuTLS call.
 */
std::unique_ptr createTlsSession(
	const gnutls_certificate_credentials_t cred)
{
	gnutls_session_t tempSession;
	if (const int ret{gnutls_init(&tempSession,
		GNUTLS_SERVER | GNUTLS_ENABLE_EARLY_DATA | GNUTLS_NO_END_OF_EARLY_DATA)})
	{
		throw std::runtime_error{fmt::format("gnutls_init: {}", gnutls_strerror(ret))};
	}
	// Take ownership immediately so the session is released if a later step throws.
	std::unique_ptr session{tempSession};

	if (const int ret{gnutls_priority_set_direct(session.get(), TLS_PRIORITY, NULL)})
		throw std::runtime_error{fmt::format("gnutls_priority_set_direct: {}", gnutls_strerror(ret))};

	if (const int ret{gnutls_credentials_set(session.get(), GNUTLS_CRD_CERTIFICATE, cred)})
		throw std::runtime_error{fmt::format("gnutls_credentials_set: {}", gnutls_strerror(ret))};

	return session;
}
}

/**
 * Create the server side of a QUIC connection.
 *
 * - @p header is the client packet header.
 * - @p path is the network path it arrived on.
 * - @p newScid is this server's connection ID.
 *
 * The constructor also sets up the GnuTLS session and a non-blocking timerfd
 * that is armed with ngtcp2's timer expiry in Write().
 * @throws std::runtime_error if the TLS session or the ngtcp2 connection cannot be created.
 * @throws std::system_error if timerfd_create() fails.
 */
Connection::Connection(CNetServerSession& session, const ngtcp2_settings& settings,
	gnutls_certificate_credentials_t credentials, const ngtcp2_pkt_hd& header, const ngtcp2_cid& newScid,
	const ngtcp2_path& path):
	m_TlsSession{createTlsSession(credentials)},
	m_ConnRef{
		.get_conn{&GetConnection},
		.user_data{&m_QuicConnection}
	}
{
	ngtcp2_transport_params params;
	ngtcp2_transport_params_default(&params);
	params.initial_max_stream_data_bidi_local = 128 * KiB;
	params.initial_max_data = 1 * MiB;
	params.max_udp_payload_size = MAX_UDP_PAYLOAD_SIZE;
	// A server must echo the DCID from the client's first packet (RFC 9000, 7.3).
	ngtcp2_cid_init(&params.original_dcid, header.dcid.data, header.dcid.datalen);
	params.original_dcid_present = true;
	params.grease_quic_bit = 1;

	ngtcp2_conn* tempConn;
	if (const int ret = ngtcp2_conn_server_new(&tempConn, &header.scid, &newScid, &path, header.version,
		&callbacks, &settings, &params, nullptr, &session))
	{
		throw std::runtime_error{fmt::format("ngtcp2_conn_server_new: {}", ngtcp2_strerror(ret))};
	}
	m_QuicConnection.reset(tempConn);

	memcpy(&m_LocalAddress, path.local.addr, path.local.addrlen);
	m_LocalAddressLength = path.local.addrlen;
	memcpy(&m_RemoteAddress, path.remote.addr, path.remote.addrlen);
	m_RemoteAddressLength = path.remote.addrlen;

	// Wire TLS and QUIC together: ngtcp2 drives the GnuTLS session, and GnuTLS
	// finds the ngtcp2 connection through m_ConnRef (see GetConnection).
	ngtcp2_crypto_gnutls_configure_server_session(m_TlsSession.get());
	ngtcp2_conn_set_tls_native_handle(m_QuicConnection.get(), m_TlsSession.get());
	gnutls_session_set_ptr(m_TlsSession.get(), &m_ConnRef);

	m_TimerFd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK);
	if (m_TimerFd < 0)
		throw std::system_error(errno, std::generic_category(), "timerfd_create");
}

/**
 * Open the bidirectional stream used for game messages and store it in m_Stream.
 * @throws std::runtime_error if ngtcp2 refuses to open the stream.
 */
void Connection::OpenStream()
{
	std::int64_t streamId;
	if (ngtcp2_conn_open_bidi_stream(m_QuicConnection.get(), &streamId, nullptr))
		throw std::runtime_error{""};

	m_Stream.emplace(streamId);
}

/**
 * @return the opened stream.
 * @throws std::bad_optional_access if OpenStream() has not succeeded yet.
 */
Stream& Connection::GetStream()
{
	return m_Stream.value();
}

/**
 * Feed one received UDP datagram from @p remote into ngtcp2. The local side
 * of the path is taken from the connection's current path.
 * @throws std::runtime_error if ngtcp2_conn_read_pkt() fails.
 */
void Connection::Read(const ngtcp2_addr remote, const std::span data)
{
	const ngtcp2_path path{
		.local{ngtcp2_conn_get_path(m_QuicConnection.get())->local},
		.remote{remote}
	};
	ngtcp2_pkt_info pi{};
	if (ngtcp2_conn_read_pkt(m_QuicConnection.get(), &path, &pi, data.data(), data.size(), timestamp()) != 0)
		throw std::runtime_error{"Destroy connection"};
}

/**
 * Flush everything ngtcp2 wants to send on @p socketFd, and then re-arm the
 * timerfd for ngtcp2's next expiry.
 *
 * Packets without stream data are flushed first, followed by pending stream
 * data when the stream exists.
 * @throws std::runtime_error or std::system_error (see WriteToStream).
 * @throws std::system_error if timerfd_settime() fails.
 */
void Connection::Write(const int socketFd)
{
	const ngtcp2_addr remote{&m_RemoteAddress.sa, m_RemoteAddressLength};

	WriteToStream(socketFd, m_QuicConnection.get(), nullptr, remote);
	if (m_Stream.has_value())
		WriteToStream(socketFd, m_QuicConnection.get(), &m_Stream.value(), remote);

	const ngtcp2_tstamp expiry{ngtcp2_conn_get_expiry(m_QuicConnection.get())};
	const ngtcp2_tstamp now{timestamp()};

	itimerspec it{};
	if (expiry <= now)
	{
		// Already due. An all-zero it_value would *disarm* the timerfd, so
		// request the earliest possible expiry instead.
		it.it_value.tv_sec = 0;
		it.it_value.tv_nsec = 1;
	}
	else
	{
		const std::uint64_t remaining{expiry - now};
		it.it_value.tv_sec = remaining / NGTCP2_SECONDS;
		it.it_value.tv_nsec = (remaining % NGTCP2_SECONDS) / NGTCP2_NANOSECONDS;
	}

	// Replaces any previous setting; no need to disarm first.
	if (timerfd_settime(m_TimerFd, 0, &it, nullptr) != 0)
		throw std::system_error{errno, std::generic_category(), "timerfd_settime"};
}

/**
 * Create a session for a new client and its QUIC connection.
 * The connection parameters are forwarded unchanged to Connection.
 */
CNetServerSession::CNetServerSession(CNetServerWorker& server, const int socketFd,
	const ngtcp2_settings& settings, gnutls_certificate_credentials_t credentials, const ngtcp2_pkt_hd& header,
	const ngtcp2_cid& newScid, const ngtcp2_path& path) :
	m_Server(server),
	m_Connection{*this, settings, credentials, header, newScid, path},
	m_FileTransferer(*this),
	m_SocketFd{socketFd}
{
}

/** @return the ENet peer's host address, or 0 when there is no peer. */
u32 CNetServerSession::GetIPAddress() const
{
	if (!m_Peer)
		return 0;

	return m_Peer->address.host;
}

/** @return the time elapsed since the peer last received something (ENet clock), or 0 without a peer. */
u32 CNetServerSession::GetLastReceivedTime() const
{
	if (!m_Peer)
		return 0;

	return enet_time_get() - m_Peer->lastReceiveTime;
}

/** @return the ENet peer's round-trip time, or 0 without a peer. */
u32 CNetServerSession::GetMeanRTT() const
{
	if (!m_Peer)
		return 0;

	return m_Peer->roundTripTime;
}

/**
 * Notify the session's state machine that the connection is lost. Then build
 * a QUIC CONNECTION_CLOSE carrying @p reason as an application error code and
 * send it to the client.
 *
 * Sending the close packet is best-effort. Failures are logged, not thrown.
 */
void CNetServerSession::Disconnect(NetDisconnectReason reason)
{
	if (reason == NDR_UNKNOWN)
		LOGWARNING("Disconnecting client without communicating the disconnect reason!");

	Update((uint)NMT_CONNECTION_LOST, NULL);

	const ngtcp2_ccerr quicReason{
		.type{NGTCP2_CCERR_TYPE_APPLICATION},
		.error_code{reason}
	};

	// ngtcp2 copies the packet's path into these buffers.
	ngtcp2_sockaddr_union local;
	ngtcp2_sockaddr_union remote;
	ngtcp2_path path{
		.local{.addr{&local.sa}},
		.remote{.addr{&remote.sa}}
	};
	std::array buffer;
	const ngtcp2_ssize amount{ngtcp2_conn_write_connection_close(m_Connection.m_QuicConnection.get(), &path,
		nullptr, buffer.data(), buffer.size(), &quicReason, timestamp())};
	if (amount <= 0)
	{
		LOGERROR("closing connection %s", ngtcp2_strerror(static_cast<int>(amount)));
		// Nothing usable was written; sending now would pass a negative length
		// (converted to a huge size) to SendPacket.
		return;
	}

	try
	{
		SendPacket(m_SocketFd, {buffer.data(), static_cast<std::size_t>(amount)}, path.remote);
	}
	catch (const std::system_error& e)
	{
		LOGERROR("closing connection %s", e.what());
	}
}

/**
 * Queue @p message on the session's stream.
 * @return false if the stream has not been opened yet (it is opened in OnConnect
 *         once 1-RTT keys are available), true otherwise.
 */
bool CNetServerSession::SendMessage(const CNetMessage* message)
{
	if (!m_Connection.m_Stream.has_value())
		return false;

	m_Connection.m_Stream.value().PushMessage(message);
	return true;
}