0ad/source/network/NetServerSession.cpp
2026-05-21 18:42:04 +02:00

391 lines
11 KiB
C++

/* Copyright (C) 2026 Wildfire Games.
* This file is part of 0 A.D.
*
* 0 A.D. is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* 0 A.D. is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with 0 A.D. If not, see <http://www.gnu.org/licenses/>.
*/
#include "precompiled.h"
#include "NetServerSession.h"
#include "lib/external_libraries/enet.h"
#include "network/NetMessage.h"
#include "network/NetMessages.h"
#include "network/NetServer.h"
#include "ps/CLogger.h"
#include <gnutls/crypto.h>
#include <gnutls/gnutls.h>
namespace
{
/**
 * Send @p data as a single message to @p remote, without blocking.
 *
 * The send is repeated for as long as it is interrupted by a signal (EINTR).
 *
 * @param socketFd Descriptor of the socket handed to sendmsg().
 * @param data Bytes to transmit.
 * @param remote Destination address.
 * @throws std::system_error carrying errno when sendmsg() fails for any
 *	reason other than EINTR (this includes EAGAIN / EWOULDBLOCK).
 */
void SendPacket(const int socketFd, const std::span<const std::uint8_t> data, const ngtcp2_addr remote)
{
	// sendmsg() takes a non-const buffer pointer even though it only reads it.
	iovec chunk{};
	chunk.iov_base = const_cast<std::uint8_t*>(data.data());
	chunk.iov_len = data.size();

	msghdr header{};
	header.msg_name = remote.addr;
	header.msg_namelen = remote.addrlen;
	header.msg_iov = &chunk;
	header.msg_iovlen = 1;

	for (;;)
	{
		const ssize_t sent{sendmsg(socketFd, &header, MSG_DONTWAIT)};
		if (sent >= 0)
			return;
		if (errno != EINTR)
			throw std::system_error{errno, std::generic_category(), "Error sending message"};
	}
}
/**
 * Lookup callback installed in ngtcp2_crypto_conn_ref::get_conn.
 *
 * @param connRef Reference whose user_data points at the owning
 *	std::unique_ptr<ngtcp2_conn, ConnectionDeleter>.
 * @return The raw connection currently held by that unique_ptr.
 */
ngtcp2_conn* GetConnection(ngtcp2_crypto_conn_ref* connRef)
{
	using ConnectionPtr = std::unique_ptr<ngtcp2_conn, ConnectionDeleter>;
	const auto* owner = static_cast<ConnectionPtr*>(connRef->user_data);
	return owner->get();
}
/**
 * ngtcp2 recv_stream_data callback: feed received bytes into the session's
 * stream and dispatch a message to the server once one is complete.
 *
 * @param conn Connection the data arrived on.
 * @param streamId Stream the data arrived on.
 * @param data Pointer to the received bytes.
 * @param datalen Number of received bytes.
 * @param userData The CNetServerSession registered with the connection.
 * @return 0 on success, NGTCP2_ERR_CALLBACK_FAILURE if the session has no
 *	open stream to receive into.
 */
int OnReceiveStreamData(ngtcp2_conn* conn, const std::uint32_t /*flags*/, const std::int64_t streamId,
	const std::uint64_t /*offset*/, const std::uint8_t* data, const std::size_t datalen, void* userData,
	void* /*streamUserData*/)
{
	CNetServerSession& session{*static_cast<CNetServerSession*>(userData)};

	// Report a callback failure instead of letting std::bad_optional_access
	// unwind through the C library that invoked this callback.
	auto& stream = session.m_Connection.m_Stream;
	if (!stream.has_value())
		return NGTCP2_ERR_CALLBACK_FAILURE;

	const auto messageData = stream->Receive({data, datalen});
	if (messageData.has_value())
	{
		std::unique_ptr<CNetMessage> message{CNetMessageFactory::CreateMessage(messageData.value(),
			session.GetServer().GetScriptInterface())};
		// Never hand a null message to the server.
		// NOTE(review): presumably CreateMessage yields null for data it cannot parse - confirm.
		if (message)
			session.GetServer().HandleMessageReceive(message.get(), &session);
	}

	increaseWindow(conn, streamId, datalen);
	return 0;
}
/**
 * ngtcp2 acked_stream_data_offset callback: tell the session's stream that
 * everything up to the end of the acknowledged range has been acknowledged.
 *
 * @param offset Start offset of the acknowledged range.
 * @param dataLength Length of the acknowledged range.
 * @param userData The CNetServerSession registered with the connection.
 * @return Always 0.
 */
int OnAcknowledgedStreamData(ngtcp2_conn* /*conn*/, const std::int64_t /*streamId*/,
	const std::uint64_t offset, const std::uint64_t dataLength, void* userData, void* /*streamUserData*/)
{
	auto* const session = static_cast<CNetServerSession*>(userData);
	const std::uint64_t acknowledgedUpTo{offset + dataLength};
	session->m_Connection.GetStream().MarkAcknowledged(acknowledgedUpTo);
	return 0;
}
/**
 * ngtcp2 recv_tx_key callback. Once the 1-RTT key is reported, open the
 * session's stream and queue the server handshake message on it. Other
 * encryption levels are ignored.
 *
 * @param level Encryption level the key belongs to.
 * @param userData The CNetServerSession registered with the connection.
 * @return 0 on success or for ignored levels, NGTCP2_ERR_CALLBACK_FAILURE
 *	if the stream could not be opened.
 */
int OnConnect(ngtcp2_conn* /*conn*/, const ngtcp2_encryption_level level, void* userData)
{
	if (level == NGTCP2_ENCRYPTION_LEVEL_1RTT)
	{
		Connection& connection{static_cast<CNetServerSession*>(userData)->m_Connection};

		try
		{
			connection.OpenStream();
		}
		catch (const std::runtime_error&)
		{
			return NGTCP2_ERR_CALLBACK_FAILURE;
		}

		const CSrvHandshakeMessage handshake(CreateHandshake<CSrvHandshakeMessage>());
		connection.GetStream().PushMessage(&handshake);
	}
	return 0;
}
// Callback table handed to ngtcp2_conn_server_new(). Entries named
// ngtcp2_crypto_* come from the ngtcp2 crypto helper library; the On*
// entries are this project's own handlers.
constexpr ngtcp2_callbacks callbacks{
	// Handshake and packet protection.
	.recv_client_initial = &ngtcp2_crypto_recv_client_initial_cb,
	.recv_crypto_data = &ngtcp2_crypto_recv_crypto_data_cb,
	.encrypt = &ngtcp2_crypto_encrypt_cb,
	.decrypt = &ngtcp2_crypto_decrypt_cb,
	.hp_mask = &ngtcp2_crypto_hp_mask_cb,
	// Stream data in / acknowledgements of stream data out.
	.recv_stream_data = &OnReceiveStreamData,
	.acked_stream_data_offset = &OnAcknowledgedStreamData,
	.recv_retry = &ngtcp2_crypto_recv_retry_cb,
	// Randomness and connection ids.
	.rand = &OnRandomRequest,
	.get_new_connection_id = &OnNewConnectionIdRequest,
	// Key updates and crypto context cleanup.
	.update_key = &ngtcp2_crypto_update_key_cb,
	.delete_crypto_aead_ctx = &ngtcp2_crypto_delete_crypto_aead_ctx_cb,
	.delete_crypto_cipher_ctx = &ngtcp2_crypto_delete_crypto_cipher_ctx_cb,
	.get_path_challenge_data = &ngtcp2_crypto_get_path_challenge_data_cb,
	// Opens the stream and queues the handshake once the 1-RTT key arrives.
	.recv_tx_key = &OnConnect
};
void WriteToStream(const int socketFd, ngtcp2_conn* conn, Stream* stream, const ngtcp2_addr remote)
{
std::array<std::uint8_t, MAX_UDP_PAYLOAD_SIZE> buf;
ngtcp2_path_storage ps;
ngtcp2_path_storage_zero(&ps);
ngtcp2_pkt_info pi;
const std::uint64_t ts{timestamp()};
std::uint32_t flags{NGTCP2_WRITE_STREAM_FLAG_MORE};
while (true)
{
ngtcp2_vec datav;
std::int64_t stream_id;
if (stream)
{
auto bytesToSend = stream->PeekData();
if (bytesToSend.has_value())
{
datav.base = const_cast<uint8_t*>(bytesToSend->data());
datav.len = bytesToSend->size();
stream_id = stream->m_Id;
}
else
{
/* No stream data to be sent */
datav.base = nullptr;
datav.len = 0;
stream_id = -1;
flags &= ~NGTCP2_WRITE_STREAM_FLAG_MORE;
}
}
else
{
datav.base = NULL;
datav.len = 0;
stream_id = -1;
}
ngtcp2_ssize n_read;
const ngtcp2_ssize n_written{ngtcp2_conn_writev_stream(conn, &ps.path, &pi, buf.data(),
buf.size(), &n_read, flags, stream_id, &datav, 1, ts)};
if (n_written < 0)
{
if (n_written != NGTCP2_ERR_WRITE_MORE)
{
throw std::runtime_error{fmt::format("ngtcp2_conn_writev_stream: {}",
ngtcp2_strerror(static_cast<int>(n_written)))};
}
if (stream && n_read > 0)
stream->MarkSent(n_read);
continue;
}
if (n_written == 0)
break;
if (stream && n_read > 0)
stream->MarkSent(n_read);
try
{
SendPacket(socketFd, {buf.data(), static_cast<std::size_t>(n_written)}, remote);
}
catch (std::system_error& e)
{
if (e.code().value() == EAGAIN || e.code().value() == EWOULDBLOCK)
break;
throw;
}
/* No stream data to be sent */
if (stream && datav.len == 0)
break;
}
ngtcp2_conn_update_pkt_tx_time(conn, timestamp());
}
std::unique_ptr<gnutls_session_int, SessionDeleter> createTlsSession(
const gnutls_certificate_credentials_t cred)
{
gnutls_session_t tempSession;
if (const int ret{gnutls_init(&tempSession, GNUTLS_SERVER | GNUTLS_ENABLE_EARLY_DATA |
GNUTLS_NO_END_OF_EARLY_DATA)})
{
throw std::runtime_error{fmt::format("gnutls_init: {}", gnutls_strerror(ret))};
}
std::unique_ptr<gnutls_session_int, SessionDeleter> session{tempSession};
if (const int ret{gnutls_priority_set_direct(session.get(), TLS_PRIORITY, NULL)})
throw std::runtime_error{fmt::format("gnutls_priority_set_direct: {}", gnutls_strerror(ret))};
if (const int ret{gnutls_credentials_set(session.get(), GNUTLS_CRD_CERTIFICATE, cred)})
throw std::runtime_error{fmt::format("gnutls_credentials_set: {}", gnutls_strerror(ret))};
return session;
}
}
/**
 * Set up the server side of a QUIC connection: create the TLS session and
 * the ngtcp2 connection, remember the local and remote addresses of
 * @p path, tie the TLS session to the connection and create the timer
 * descriptor later armed by Write().
 *
 * @param session Session passed to ngtcp2 as user data for every callback.
 * @param settings ngtcp2 settings for the new connection.
 * @param credentials Certificate credentials for the TLS session.
 * @param header Header of the packet that started the connection; supplies
 *	the connection ids and the QUIC version.
 * @param newScid Source connection id for the new connection.
 * @param path Network path the packet arrived on.
 * @throws std::runtime_error if an address does not fit its storage, or if
 *	a GnuTLS / ngtcp2 setup call fails.
 * @throws std::system_error if timerfd_create fails.
 */
Connection::Connection(CNetServerSession& session, const ngtcp2_settings& settings,
	gnutls_certificate_credentials_t credentials, const ngtcp2_pkt_hd& header,
	const ngtcp2_cid& newScid, const ngtcp2_path& path):
	m_TlsSession{createTlsSession(credentials)},
	m_ConnRef{
		.get_conn{&GetConnection},
		.user_data{&m_QuicConnection}
	}
{
	// The address lengths come from the caller; never copy past our storage.
	if (path.local.addrlen > sizeof(m_LocalAddress) || path.remote.addrlen > sizeof(m_RemoteAddress))
		throw std::runtime_error{"Connection: socket address does not fit its storage"};

	ngtcp2_transport_params params;
	ngtcp2_transport_params_default(&params);
	params.initial_max_stream_data_bidi_local = 128 * KiB;
	params.initial_max_data = 1 * MiB;
	params.max_udp_payload_size = MAX_UDP_PAYLOAD_SIZE;
	ngtcp2_cid_init(&params.original_dcid, header.dcid.data, header.dcid.datalen);
	params.original_dcid_present = true;
	params.grease_quic_bit = 1;

	ngtcp2_conn* tempConn;
	if (const int ret = ngtcp2_conn_server_new(&tempConn, &header.scid, &newScid, &path, header.version,
		&callbacks, &settings, &params, nullptr, &session))
	{
		throw std::runtime_error{fmt::format("ngtcp2_conn_server_new: {}",
			ngtcp2_strerror(ret))};
	}
	m_QuicConnection.reset(tempConn);

	memcpy(&m_LocalAddress, path.local.addr, path.local.addrlen);
	m_LocalAddressLength = path.local.addrlen;
	memcpy(&m_RemoteAddress, path.remote.addr, path.remote.addrlen);
	m_RemoteAddressLength = path.remote.addrlen;

	// A TLS session that could not be configured for QUIC must not be used.
	if (ngtcp2_crypto_gnutls_configure_server_session(m_TlsSession.get()))
		throw std::runtime_error{"ngtcp2_crypto_gnutls_configure_server_session failed"};
	ngtcp2_conn_set_tls_native_handle(m_QuicConnection.get(), m_TlsSession.get());
	// The session pointer is m_ConnRef, whose get_conn is GetConnection.
	gnutls_session_set_ptr(m_TlsSession.get(), &m_ConnRef);

	m_TimerFd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK);
	if (m_TimerFd < 0)
		throw std::system_error(errno, std::generic_category(), "timerfd_create");
}
void Connection::OpenStream()
{
std::int64_t streamId;
if (ngtcp2_conn_open_bidi_stream(m_QuicConnection.get(), &streamId, nullptr))
throw std::runtime_error{""};
m_Stream.emplace(streamId);
}
/**
 * Access the stream created by OpenStream().
 *
 * @return Reference to the stored stream.
 * @throws std::bad_optional_access if no stream has been stored yet
 *	(NOTE(review): assumes m_Stream is a std::optional - confirm in the header).
 */
Stream& Connection::GetStream()
{
	Stream& stream{m_Stream.value()};
	return stream;
}
void Connection::Read(const ngtcp2_addr remote, const std::span<std::uint8_t> data)
{
const ngtcp2_path path{
.local{ngtcp2_conn_get_path(m_QuicConnection.get())->local},
.remote{remote}
};
ngtcp2_pkt_info pi{};
if (const int ret{ngtcp2_conn_read_pkt(m_QuicConnection.get(), &path, &pi, data.data(), data.size(),
timestamp())})
{
throw std::runtime_error{"Destroy connection"};
}
}
void Connection::Write(const int socketFd)
{
WriteToStream(socketFd, m_QuicConnection.get(), nullptr,
{&m_RemoteAddress.sa, m_RemoteAddressLength});
if (m_Stream.has_value())
{
WriteToStream(socketFd, m_QuicConnection.get(), &m_Stream.value(),
{&m_RemoteAddress.sa, m_RemoteAddressLength});
}
const ngtcp2_tstamp expiry{ngtcp2_conn_get_expiry(m_QuicConnection.get())};
const ngtcp2_tstamp now{timestamp()};
itimerspec it{};
if (const int ret{timerfd_settime(m_TimerFd, 0, &it, nullptr)})
throw std::system_error{errno, std::generic_category(), "timerfd_settime"};
if (expiry < now)
{
it.it_value.tv_sec = 0;
it.it_value.tv_nsec = 1;
}
else
{
it.it_value.tv_sec = (expiry - now) / NGTCP2_SECONDS;
it.it_value.tv_nsec = ((expiry - now) % NGTCP2_SECONDS) / NGTCP2_NANOSECONDS;
}
if (const int ret{timerfd_settime(m_TimerFd, 0, &it, nullptr)})
throw std::system_error{errno, std::generic_category(), "timerfd_settime"};
return;
}
/**
 * Create the session for one client connection.
 *
 * @param server Worker this session belongs to.
 * @param socketFd Socket descriptor stored for sending packets.
 * @param settings Forwarded to the Connection constructor.
 * @param credentials Forwarded to the Connection constructor.
 * @param header Forwarded to the Connection constructor.
 * @param newScid Forwarded to the Connection constructor.
 * @param path Forwarded to the Connection constructor.
 */
CNetServerSession::CNetServerSession(CNetServerWorker& server, const int socketFd,
const ngtcp2_settings& settings, gnutls_certificate_credentials_t credentials,
const ngtcp2_pkt_hd& header, const ngtcp2_cid& newScid, const ngtcp2_path& path) :
m_Server(server),
// The connection receives *this so that it can register the session with ngtcp2.
m_Connection{*this, settings, credentials, header, newScid, path},
m_FileTransferer(*this),
m_SocketFd{socketFd}
{
}
/**
 * @return The host address stored in the peer, or 0 when there is no peer
 *	(the same fallback the other peer-based getters use).
 */
u32 CNetServerSession::GetIPAddress() const
{
	if (!m_Peer)
		return 0;
	return m_Peer->address.host;
}
/**
 * @return The difference between the current ENet time and the peer's
 *	lastReceiveTime, or 0 when there is no peer.
 */
u32 CNetServerSession::GetLastReceivedTime() const
{
	if (m_Peer)
		return enet_time_get() - m_Peer->lastReceiveTime;
	return 0;
}
/**
 * @return The peer's roundTripTime, or 0 when there is no peer.
 */
u32 CNetServerSession::GetMeanRTT() const
{
	if (m_Peer)
		return m_Peer->roundTripTime;
	return 0;
}
/**
 * Close the connection to the client, telling it why.
 *
 * The session state machine is first updated with NMT_CONNECTION_LOST.
 * ngtcp2 then builds a connection-close packet carrying @p reason as an
 * application error code, and that packet is sent to the client.
 *
 * @param reason Reason to communicate; NDR_UNKNOWN only triggers a warning.
 * @throws std::system_error if sending the close packet fails.
 */
void CNetServerSession::Disconnect(NetDisconnectReason reason)
{
	if (reason == NDR_UNKNOWN)
		LOGWARNING("Disconnecting client without communicating the disconnect reason!");

	Update(static_cast<uint>(NMT_CONNECTION_LOST), nullptr);

	const ngtcp2_ccerr quicReason{
		.type{NGTCP2_CCERR_TYPE_APPLICATION},
		.error_code{reason}
	};

	// Storage that ngtcp2 fills with the path the close packet must take.
	ngtcp2_sockaddr_union local;
	ngtcp2_sockaddr_union remote;
	ngtcp2_path path{
		.local{.addr{&local.sa}},
		.remote{.addr{&remote.sa}}
	};

	std::array<std::uint8_t, MAX_UDP_PAYLOAD_SIZE> buffer;
	const ngtcp2_ssize amount{ngtcp2_conn_write_connection_close(m_Connection.m_QuicConnection.get(),
		&path, nullptr, buffer.data(), buffer.size(), &quicReason, timestamp())};
	if (amount <= 0)
	{
		// No packet was produced. Sending anyway would turn a negative
		// error code into a huge length.
		LOGERROR("closing connection %s", ngtcp2_strerror(static_cast<int>(amount)));
		return;
	}

	SendPacket(m_SocketFd, {buffer.data(), static_cast<std::size_t>(amount)}, path.remote);
}
bool CNetServerSession::SendMessage(const CNetMessage* message)
{
m_Connection.m_Stream.value().PushMessage(message);
return true;
}