mirror of
https://gitea.wildfiregames.com/0ad/0ad
synced 2026-06-16 05:13:58 -07:00
564 lines
15 KiB
C++
564 lines
15 KiB
C++
/* Copyright (C) 2026 Wildfire Games.
|
|
* This file is part of 0 A.D.
|
|
*
|
|
* 0 A.D. is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU General Public License as published by
|
|
* the Free Software Foundation, either version 2 of the License, or
|
|
* (at your option) any later version.
|
|
*
|
|
* 0 A.D. is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU General Public License
|
|
* along with 0 A.D. If not, see <http://www.gnu.org/licenses/>.
|
|
*/
|
|
|
|
#include "precompiled.h"
|
|
|
|
#include "NetClientSession.h"
|
|
|
|
#include "lib/code_generation.h"
|
|
#include "lib/debug.h"
|
|
#include "network/NetClient.h"
|
|
#include "network/NetMessage.h"
|
|
#include "network/NetProtocol.h"
|
|
#include "network/NetStats.h"
|
|
#include "ps/CLogger.h"
|
|
#include "ps/ProfileViewer.h"
|
|
|
|
#include <cstddef>
|
|
|
|
// Timeout handed to poll() by CNetClientSession::Poll(), in milliseconds.
constexpr int NETCLIENT_POLL_TIMEOUT{50};
|
|
|
|
#include <chrono>
#include <cstring>
#include <thread>

#include <time.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
#include <arpa/inet.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
#include <poll.h>
#include <unistd.h>

#include <ngtcp2/ngtcp2.h>
#include <ngtcp2/ngtcp2_crypto.h>
#include <ngtcp2/ngtcp2_crypto_gnutls.h>

#include <gnutls/crypto.h>
#include <gnutls/gnutls.h>
|
|
|
|
struct CNetClientSession::Quic
|
|
{
|
|
AddressStorage localAddress;
|
|
std::unique_ptr<gnutls_certificate_credentials_st, CredentialsDeleter> credentials;
|
|
std::unique_ptr<gnutls_session_int, SessionDeleter> session;
|
|
std::unique_ptr<ngtcp2_conn, ConnectionDeleter> quicConnection;
|
|
ngtcp2_crypto_conn_ref connectionReference;
|
|
int fd;
|
|
|
|
std::optional<Stream> streams;
|
|
};
|
|
|
|
namespace
|
|
{
|
|
|
|
struct CreateSocketResult
|
|
{
|
|
int descriptor;
|
|
AddressStorage address;
|
|
};
|
|
/**
 * Resolves host:port (IPv4 or IPv6) and creates a UDP socket for the first usable result.
 *
 * @param host host name or numeric address of the server.
 * @param port UDP port of the server.
 * @return the new socket descriptor and the complete resolved remote address.
 * @throws std::runtime_error if resolution fails or no socket could be created.
 */
CreateSocketResult CreateSocket(const char* host, const std::uint16_t port)
{
	addrinfo hints{};
	// Accept results of any address family (this field was previously set on ai_flags by mistake).
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_DGRAM;

	addrinfo* res;
	const int rv{getaddrinfo(host, fmt::format("{}", port).c_str(), &hints, &res)};
	if (rv)
		throw std::runtime_error{fmt::format("getaddrinfo: {}", gai_strerror(rv))};
	std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> infoList{res, &freeaddrinfo};

	for (const addrinfo* rp{res}; rp; rp = rp->ai_next)
	{
		AddressStorage remote{};
		// Copy the full address (ai_addrlen bytes), not just the generic 16-byte `sockaddr`
		// prefix: a sockaddr_in6 is larger and would otherwise be truncated.
		if (rp->ai_addrlen > sizeof(remote.address))
			continue;

		const int fd{socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)};
		if (fd == -1)
			continue;

		std::memcpy(&remote.address, rp->ai_addr, rp->ai_addrlen);
		remote.length = rp->ai_addrlen;
		return CreateSocketResult{.descriptor{fd}, .address{remote}};
	}

	throw std::runtime_error{"unable to create a socket"};
}
|
|
|
|
/**
 * Connects the UDP socket `fd` to `remoteAddress` and reports the local address it ended up bound to.
 *
 * @return the local address and its length, as filled in by getsockname().
 * @throws std::runtime_error if connect() or getsockname() fails.
 */
AddressStorage ConnectSocket(const int fd, const AddressStorage& remoteAddress)
{
	const int connectStatus{connect(fd, &remoteAddress.address.sa, remoteAddress.length)};
	if (connectStatus != 0)
		throw std::runtime_error{fmt::format("connect: {}", strerror(errno))};

	ngtcp2_sockaddr_union bound;
	ngtcp2_socklen boundLength{sizeof(bound)};
	const int nameStatus{getsockname(fd, &bound.sa, &boundLength)};
	if (nameStatus == -1)
		throw std::runtime_error{fmt::format("getsockname: {}", strerror(errno))};

	return AddressStorage{bound, boundLength};
}
|
|
|
|
/**
 * Creates the GnuTLS credentials and client session stored in `c` and prepares the session
 * for use by ngtcp2.
 *
 * NOTE(review): nothing here loads trust anchors or enables server-certificate verification
 * (the credentials are allocated empty) — confirm this is intended.
 *
 * @throws std::runtime_error at the first GnuTLS/ngtcp2 call that fails.
 */
void ClientGnutlsInit(CNetClientSession::Quic* c)
{
	gnutls_certificate_credentials_t rawCredentials;
	if (const int status{gnutls_certificate_allocate_credentials(&rawCredentials)}; status != 0)
		throw std::runtime_error{fmt::format("cred init failed: {}: {}", status, gnutls_strerror(status))};
	c->credentials.reset(rawCredentials);

	// Client session with early data enabled and without the EndOfEarlyData message.
	gnutls_session_t rawSession;
	if (const int status{gnutls_init(&rawSession, GNUTLS_CLIENT | GNUTLS_ENABLE_EARLY_DATA |
		GNUTLS_NO_END_OF_EARLY_DATA)}; status != 0)
	{
		throw std::runtime_error{fmt::format("gnutls_init: {}", gnutls_strerror(status))};
	}
	c->session.reset(rawSession);

	gnutls_session_t tls{c->session.get()};

	// Configures the session for use by ngtcp2 as a QUIC client.
	if (ngtcp2_crypto_gnutls_configure_client_session(tls) != 0)
		throw std::runtime_error{"ngtcp2_crypto_gnutls_configure_client_session failed"};

	if (const int status{gnutls_priority_set_direct(tls, TLS_PRIORITY, nullptr)}; status != 0)
		throw std::runtime_error{fmt::format("gnutls_priority_set_direct: {}", gnutls_strerror(status))};

	// ngtcp2's GnuTLS helpers presumably retrieve this pointer to find the connection (see GetConnection).
	gnutls_session_set_ptr(tls, &c->connectionReference);

	if (const int status{gnutls_credentials_set(tls, GNUTLS_CRD_CERTIFICATE, c->credentials.get())}; status != 0)
		throw std::runtime_error{fmt::format("gnutls_credentials_set: {}", gnutls_strerror(status))};
}
|
|
|
|
/**
 * ngtcp2 `stream_open` callback: records the stream and reports the connection as established.
 *
 * Queues a ConnectionEstablished event for ProcessPolledMessages().
 *
 * @param streamId id of the newly opened stream.
 * @param userData the CNetClientSession passed to ngtcp2_conn_client_new().
 * @return 0 on success, NGTCP2_ERR_CALLBACK_FAILURE if recording the stream threw.
 */
int OpenStream(ngtcp2_conn*, const std::int64_t streamId, void* userData)
{
	CNetClientSession& session{*static_cast<CNetClientSession*>(userData)};

	// This is called from inside ngtcp2 (C code): exceptions must not unwind through it.
	try
	{
		session.m_Quic->streams.emplace(streamId);

		session.m_Connected = true;
		session.m_WasConnected = true;
		if (!session.m_IncomingMessages.push(CNetClientSession::ConnectionEstablished{}))
			LOGERROR("Net client: Failed to push the connection event on the incoming queue.");
	}
	catch (const std::exception& error)
	{
		LOGERROR("Net client: Failed to open stream: %s", error.what());
		return NGTCP2_ERR_CALLBACK_FAILURE;
	}

	return 0;
}
|
|
|
|
/**
 * ngtcp2 `recv_stream_data` callback: feeds received bytes to the session's stream and queues
 * every complete message for ProcessPolledMessages(), then extends the flow-control window.
 *
 * @param streamId stream the data arrived on.
 * @param data, dataSize the received bytes.
 * @param userData the CNetClientSession passed to ngtcp2_conn_client_new().
 * @return 0 on success; NGTCP2_ERR_CALLBACK_FAILURE if no stream exists yet or handling threw.
 */
int OnStreamDataReceive(ngtcp2_conn* conn, const std::uint32_t /*flags*/, const std::int64_t streamId,
	const std::size_t /*offset*/, const std::uint8_t* data, std::size_t dataSize, void* userData, void*)
{
	CNetClientSession& session{*static_cast<CNetClientSession*>(userData)};

	// This is called from inside ngtcp2 (C code): exceptions must not unwind through it.
	try
	{
		if (!session.m_Quic->streams.has_value())
		{
			LOGERROR("Net client: Received stream data before any stream was opened.");
			return NGTCP2_ERR_CALLBACK_FAILURE;
		}

		auto message = session.m_Quic->streams->Receive({data, dataSize});
		if (message.has_value())
		{
			// Owned here until the queue accepts it, so a rejected message is not leaked.
			auto packet = std::make_unique<std::vector<std::uint8_t>>(std::move(message).value());
			if (session.m_IncomingMessages.push(packet.get()))
				packet.release();
			else
				LOGERROR("Net client: Failed to push a received message on the incoming queue.");
		}

		increaseWindow(conn, streamId, dataSize);
	}
	catch (const std::exception& error)
	{
		LOGERROR("Net client: Failed to handle stream data: %s", error.what());
		return NGTCP2_ERR_CALLBACK_FAILURE;
	}

	return 0;
}
|
|
|
|
/**
 * Creates the ngtcp2 client connection for `session` over the local/remote path and attaches
 * the GnuTLS session that ClientGnutlsInit() stored in `session.m_Quic`.
 *
 * @param remote address of the server.
 * @param local address of the local end of the connected socket.
 * @throws std::runtime_error if random connection IDs cannot be generated or ngtcp2 fails.
 */
void ClientQuicInit(CNetClientSession& session, const AddressStorage& remote, const AddressStorage& local)
{
	// ngtcp2_path holds non-const address pointers, hence the const_casts.
	const ngtcp2_path path{
		.local{
			.addr{const_cast<ngtcp2_sockaddr*>(&local.address.sa)},
			.addrlen{local.length},
		},
		.remote{
			.addr{const_cast<ngtcp2_sockaddr*>(&remote.address.sa)},
			.addrlen{remote.length},
		}
	};

	// ngtcp2's crypto helper callbacks plus this file's stream handlers (OpenStream,
	// OnStreamDataReceive); OnRandomRequest and OnNewConnectionIdRequest are defined elsewhere.
	constexpr ngtcp2_callbacks callbacks{
		.client_initial{&ngtcp2_crypto_client_initial_cb},
		.recv_crypto_data{&ngtcp2_crypto_recv_crypto_data_cb},
		.encrypt{&ngtcp2_crypto_encrypt_cb},
		.decrypt{&ngtcp2_crypto_decrypt_cb},
		.hp_mask{&ngtcp2_crypto_hp_mask_cb},
		.recv_stream_data{&OnStreamDataReceive},
		.stream_open{&OpenStream},
		.recv_retry{&ngtcp2_crypto_recv_retry_cb},
		.rand{&OnRandomRequest},
		.get_new_connection_id{&OnNewConnectionIdRequest},
		.update_key{&ngtcp2_crypto_update_key_cb},
		.delete_crypto_aead_ctx{&ngtcp2_crypto_delete_crypto_aead_ctx_cb},
		.delete_crypto_cipher_ctx{&ngtcp2_crypto_delete_crypto_cipher_ctx_cb},
		.get_path_challenge_data{&ngtcp2_crypto_get_path_challenge_data_cb},
		.version_negotiation{&ngtcp2_crypto_version_negotiation_cb}
	};

	// Random destination and source connection IDs.
	ngtcp2_cid dcid;
	dcid.datalen = NGTCP2_MIN_INITIAL_DCIDLEN;
	if (gnutls_rnd(GNUTLS_RND_RANDOM, dcid.data, dcid.datalen))
		throw std::runtime_error{"gnutls_rnd failed"};

	ngtcp2_cid scid;
	scid.datalen = 8;
	if (gnutls_rnd(GNUTLS_RND_RANDOM, scid.data, scid.datalen))
		throw std::runtime_error{"gnutls_rnd failed"};

	ngtcp2_settings settings;
	ngtcp2_settings_default(&settings);
	settings.initial_ts = timestamp();

	ngtcp2_transport_params params;
	ngtcp2_transport_params_default(&params);
	// Flow-control limits for remotely-initiated bidirectional streams and for the whole connection.
	params.initial_max_stream_data_bidi_remote = 128 * KiB;
	params.initial_max_data = 1 * MiB;
	// At most one bidirectional stream may be opened by the peer.
	params.initial_max_streams_bidi = 1;
	params.max_udp_payload_size = MAX_UDP_PAYLOAD_SIZE;
	params.grease_quic_bit = 1;

	// `&session` becomes the user_data passed to the callbacks (see OpenStream).
	ngtcp2_conn* tempConn;
	if (const int rv{ngtcp2_conn_client_new(&tempConn, &dcid, &scid, &path,
		NGTCP2_PROTO_VER_V1, &callbacks, &settings, &params, nullptr, &session)})
	{
		throw std::runtime_error{fmt::format("ngtcp2_conn_client_new: {}", ngtcp2_strerror(rv))};
	}
	session.m_Quic->quicConnection.reset(tempConn);

	// Hand the GnuTLS session to ngtcp2's crypto layer.
	ngtcp2_conn_set_tls_native_handle(session.m_Quic->quicConnection.get(),
		session.m_Quic->session.get());
}
|
|
|
|
/**
 * Drains every datagram currently queued on the socket and passes each one to ngtcp2.
 *
 * Returns once recvmsg() would block, or after logging any other recvmsg() error.
 * An interrupted recvmsg() (EINTR) is retried.
 *
 * @throws std::runtime_error if ngtcp2_conn_read_pkt() rejects a packet.
 */
void ClientRead(CNetClientSession::Quic* c)
{
	std::array<std::uint8_t, MAX_UDP_PAYLOAD_SIZE> buf;
	struct sockaddr_storage addr;
	iovec iov{
		.iov_base = buf.data(),
		.iov_len = buf.size(),
	};
	msghdr msg{};
	msg.msg_name = &addr;
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	ngtcp2_pkt_info pi{};
	while (true)
	{
		// recvmsg() overwrites this with the size of the sender's address, so reset it every time.
		msg.msg_namelen = sizeof(addr);

		const ssize_t nread{recvmsg(c->fd, &msg, MSG_DONTWAIT)};
		if (nread == -1)
		{
			// A signal interrupting the call is not an error: try again.
			if (errno == EINTR)
				continue;

			if (errno != EAGAIN && errno != EWOULDBLOCK)
				LOGERROR("recvmsg: %s", strerror(errno));

			break;
		}

		ngtcp2_path path{
			.local{
				.addr{&c->localAddress.address.sa},
				.addrlen{c->localAddress.length}
			},
			.remote{
				.addr{static_cast<ngtcp2_sockaddr*>(msg.msg_name)},
				.addrlen{msg.msg_namelen}
			}
		};

		const int rv{ngtcp2_conn_read_pkt(c->quicConnection.get(), &path, &pi, buf.data(),
			static_cast<std::size_t>(nread), timestamp())};
		if (rv != 0)
			throw std::runtime_error{fmt::format("ngtcp2_conn_read_pkt: {}", ngtcp2_strerror(rv))};
	}
}
|
|
|
|
/**
 * Sends `data` as one datagram on the socket, retrying while sendmsg() is interrupted by a signal.
 *
 * No destination is given: the socket was connected by ConnectSocket().
 *
 * @throws std::runtime_error if sendmsg() fails for any other reason.
 */
void ClientSendDatagram(CNetClientSession::Quic* c, const std::span<const std::uint8_t> data)
{
	iovec segment{};
	segment.iov_base = const_cast<std::uint8_t*>(data.data());
	segment.iov_len = data.size();

	msghdr header{};
	header.msg_iov = &segment;
	header.msg_iovlen = 1;

	ssize_t sent{sendmsg(c->fd, &header, 0)};
	while (sent == -1 && errno == EINTR)
		sent = sendmsg(c->fd, &header, 0);

	if (sent == -1)
		throw std::runtime_error{fmt::format("sendmsg: {}", strerror(errno))};
}
|
|
|
|
/**
 * Writes every packet ngtcp2 currently wants to send and sends each one as a datagram.
 *
 * Pending stream data (Stream::PeekData) is offered to ngtcp2 and marked as sent
 * (Stream::MarkSent) as ngtcp2 consumes it. If flow control blocks the stream, the loop carries on
 * without stream data, so the remaining frames (e.g. ACKs) still go out.
 *
 * @throws std::runtime_error if ngtcp2 fails to build a packet or sending fails.
 */
void ClientWriteStreams(CNetClientSession::Quic* c)
{
	const ngtcp2_tstamp ts{timestamp()};
	ngtcp2_pkt_info pi;
	std::array<uint8_t, MAX_UDP_PAYLOAD_SIZE> buffer;

	ngtcp2_path_storage ps;
	ngtcp2_path_storage_zero(&ps);
	// While stream data is pending, ask ngtcp2 to keep filling the current packet.
	std::uint32_t flags{NGTCP2_WRITE_STREAM_FLAG_MORE};
	// Set once ngtcp2 reports the stream as flow-control blocked for this call.
	bool streamBlocked{false};

	while (true)
	{
		ngtcp2_vec datav;
		std::int64_t streamId;

		const auto bytesToSend = (c->streams.has_value() && !streamBlocked) ?
			c->streams.value().PeekData() : std::nullopt;
		if (bytesToSend.has_value())
		{
			datav.base = const_cast<uint8_t*>(bytesToSend->data());
			datav.len = bytesToSend->size();
			streamId = c->streams.value().m_Id;
		}
		else
		{
			datav.base = nullptr;
			datav.len = 0;
			streamId = -1;
			// Nothing more to coalesce: let ngtcp2 finish the packet.
			if (c->streams.has_value())
				flags &= ~NGTCP2_WRITE_STREAM_FLAG_MORE;
		}

		ngtcp2_ssize wdatalen;
		const ngtcp2_ssize nwrite = ngtcp2_conn_writev_stream(c->quicConnection.get(), &ps.path, &pi,
			buffer.data(), buffer.size(), &wdatalen, flags, streamId, &datav, 1, ts);
		if (nwrite < 0)
		{
			if (nwrite == NGTCP2_ERR_STREAM_DATA_BLOCKED)
			{
				// Stream data must wait for more flow-control credit. Keep writing without the
				// stream instead of stopping, so ACKs and other frames are not held back.
				streamBlocked = true;
				continue;
			}

			if (nwrite != NGTCP2_ERR_WRITE_MORE)
			{
				throw std::runtime_error{fmt::format("ngtcp2_conn_writev_stream: {}",
					ngtcp2_strerror(static_cast<int>(nwrite)))};
			}

			// The packet still has room: record what was consumed and keep filling it.
			if (c->streams.has_value() && wdatalen > 0)
				c->streams.value().MarkSent(wdatalen);
			continue;
		}

		if (nwrite == 0)
			break;

		if (c->streams.has_value() && wdatalen > 0)
			c->streams.value().MarkSent(wdatalen);

		ClientSendDatagram(c, {buffer.data(), static_cast<std::size_t>(nwrite)});
	}

	ngtcp2_conn_update_pkt_tx_time(c->quicConnection.get(), timestamp());
}
|
|
|
|
/**
 * Lets ngtcp2 handle its expired timers at the current timestamp().
 *
 * @throws std::runtime_error if ngtcp2_conn_handle_expiry() returns a non-zero error code.
 */
void ClientHandleExpiry(CNetClientSession::Quic* c)
{
	const int status{ngtcp2_conn_handle_expiry(c->quicConnection.get(), timestamp())};
	if (status != 0)
		throw std::runtime_error{fmt::format("ngtcp2_conn_handle_expiry: {}", ngtcp2_strerror(status))};
}
|
|
|
|
/**
 * `ngtcp2_crypto_conn_ref::get_conn` hook: `user_data` holds the session's Quic state
 * (set in ClientInit), whose ngtcp2 connection is returned.
 */
ngtcp2_conn* GetConnection(ngtcp2_crypto_conn_ref* conn_ref)
{
	auto* const quic = static_cast<CNetClientSession::Quic*>(conn_ref->user_data);
	return quic->quicConnection.get();
}
|
|
|
|
/**
 * (Re)initialises the QUIC state of `session`:
 * - resolves and connects the UDP socket,
 * - sets up GnuTLS,
 * - creates the ngtcp2 client connection.
 *
 * If any step after socket creation fails, the socket is closed again before the exception
 * propagates, so a failed Connect() does not leak the descriptor.
 *
 * @throws std::runtime_error if any step fails.
 */
void ClientInit(CNetClientSession& session, const char* host, const std::uint16_t port)
{
	CNetClientSession::Quic& quic{*session.m_Quic};
	quic = CNetClientSession::Quic{};

	// Used by the TLS side to reach the ngtcp2 connection; ready before any TLS/QUIC object exists.
	quic.connectionReference.get_conn = &GetConnection;
	quic.connectionReference.user_data = &quic;

	const auto [descriptor, remoteAddress] = CreateSocket(host, port);
	quic.fd = descriptor;

	try
	{
		quic.localAddress = ConnectSocket(quic.fd, remoteAddress);
		ClientGnutlsInit(&quic);
		ClientQuicInit(session, remoteAddress, quic.localAddress);
	}
	catch (...)
	{
		// The connection was never established: release the socket created above.
		::close(quic.fd);
		quic.fd = -1;
		throw;
	}
}
|
|
} // anonymous namespace
|
|
|
|
// Creates a session bound to `client`. The QUIC state is only allocated here; Connect() fills it in.
CNetClientSession::CNetClientSession(CNetClient& client)
	: m_Client(client),
	  m_FileTransferer(*this),
	  m_Quic(std::make_unique<Quic>())
{
}
|
|
|
|
/**
 * Sends a CONNECTION_CLOSE to the server (best effort) and closes the socket.
 *
 * Nothing is sent if no connection was ever created (Connect() not called or failed). A destructor
 * must not throw, so send failures are logged instead of propagated.
 */
CNetClientSession::~CNetClientSession()
{
	ENSURE(!m_LoopRunning);

	if (!m_Quic || !m_Quic->quicConnection)
		return;

	constexpr ngtcp2_ccerr reason{
		.type{NGTCP2_CCERR_TYPE_TRANSPORT},
		.error_code{0}
	};

	std::array<std::uint8_t, MAX_UDP_PAYLOAD_SIZE> buffer;
	const ngtcp2_ssize amount{ngtcp2_conn_write_connection_close(m_Quic->quicConnection.get(), nullptr, nullptr, buffer.data(),
		buffer.size(), &reason, timestamp())};
	if (amount < 0)
	{
		// Nothing valid was written into the buffer: do not send it.
		LOGERROR("closing connection %s", ngtcp2_strerror(static_cast<int>(amount)));
	}
	else if (amount > 0)
	{
		try
		{
			ClientSendDatagram(m_Quic.get(), {buffer.data(), static_cast<std::size_t>(amount)});
		}
		catch (const std::runtime_error& error)
		{
			LOGERROR("closing connection %s", error.what());
		}
	}

	if (m_Quic->fd >= 0)
		::close(m_Quic->fd);
}
|
|
|
|
bool CNetClientSession::Connect(const CStr& server, const u16 port)
|
|
{
|
|
ENSURE(!m_LoopRunning);
|
|
|
|
ClientInit(*this, server.c_str(), port);
|
|
ClientWriteStreams(m_Quic.get());
|
|
|
|
m_Stats = std::make_unique<CNetStatsTable>(m_Quic->quicConnection.get());
|
|
if (CProfileViewer::IsInitialised())
|
|
g_ProfileViewer.AddRootTable(m_Stats.get());
|
|
|
|
return true;
|
|
}
|
|
|
|
/**
 * Entry point of the network thread. Until Shutdown() is requested, it repeatedly:
 * - polls the file transferer,
 * - does socket I/O (Poll),
 * - flushes outgoing messages.
 * It then deletes `session`.
 *
 * When the connection fails, a Disconnect event is queued and network I/O stops. The session is
 * kept alive until Shutdown() is called, so its owner can still consume that event, and is then
 * deleted exactly like after a normal shutdown. Previously the early return leaked it.
 */
void CNetClientSession::RunNetLoop(CNetClientSession* session)
{
	ENSURE(!session->m_LoopRunning);
	session->m_LoopRunning = true;

	debug_SetThreadName("NetClientSession loop");

	while (!session->m_ShouldShutdown)
	{
		session->m_FileTransferer.Poll();
		try
		{
			session->Poll();
		}
		catch (const std::runtime_error& error)
		{
			// Report immediately.
			LOGMESSAGE("Net client: Disconnected (%s)", error.what());
			session->m_Connected = false;
			session->m_IncomingMessages.push(Disconnect{});

			// The connection is unusable, but the owner may still read the Disconnect event:
			// wait for Shutdown() and fall through to the regular cleanup below.
			while (!session->m_ShouldShutdown)
				std::this_thread::sleep_for(std::chrono::milliseconds{NETCLIENT_POLL_TIMEOUT});
			break;
		}
		session->Flush();

		// TODO: derive these from the QUIC connection.
		// session->m_LastReceivedTime = timestamp() - session->m_Server->lastReceiveTime;
		// session->m_MeanRTT = session->m_Server->roundTripTime;
	}

	session->m_LoopRunning = false;

	// Deleting the session is handled in this thread as it might outlive the CNetClient.
	SAFE_DELETE(session);
}
|
|
|
|
void CNetClientSession::Shutdown()
|
|
{
|
|
m_ShouldShutdown = true;
|
|
}
|
|
|
|
void CNetClientSession::Poll()
|
|
{
|
|
pollfd pfd{
|
|
.fd{m_Quic->fd},
|
|
.events{POLLIN}
|
|
};
|
|
|
|
const int ret{poll(&pfd, 1, NETCLIENT_POLL_TIMEOUT)};
|
|
if (ret < 0)
|
|
{
|
|
LOGERROR("Error while waiting for poll: %s", std::strerror(errno));
|
|
return;
|
|
}
|
|
if (ret == 0)
|
|
{
|
|
ClientHandleExpiry(m_Quic.get());
|
|
ClientWriteStreams(m_Quic.get());
|
|
return;
|
|
}
|
|
|
|
ClientRead(m_Quic.get());
|
|
ClientWriteStreams(m_Quic.get());
|
|
}
|
|
|
|
void CNetClientSession::Flush()
|
|
{
|
|
std::vector<std::uint8_t>* message;
|
|
while (m_OutgoingMessages.pop(message))
|
|
{
|
|
std::unique_ptr<std::vector<std::uint8_t>> data{message};
|
|
if (m_Quic->streams.has_value())
|
|
m_Quic->streams.value().PushData(std::move(*data));
|
|
else
|
|
LOGERROR("no stream to send message");
|
|
}
|
|
}
|
|
|
|
void CNetClientSession::ProcessPolledMessages()
|
|
{
|
|
IncommingMessage query{};
|
|
while(m_IncomingMessages.pop(query))
|
|
{
|
|
std::visit([&]<typename Message>(Message message)
|
|
{
|
|
if constexpr (std::same_as<Message, ConnectionEstablished>)
|
|
{
|
|
m_Client.HandleConnect();
|
|
}
|
|
else if constexpr (std::same_as<Message, Disconnect>)
|
|
{
|
|
m_Client.HandleDisconnect(NDR_UNKNOWN);
|
|
}
|
|
else
|
|
{
|
|
static_assert(std::same_as<Message, std::vector<std::uint8_t>*>);
|
|
std::unique_ptr<std::vector<std::uint8_t>> data{message};
|
|
CNetMessage* msg = CNetMessageFactory::CreateMessage(*data, m_Client.GetScriptInterface());
|
|
if (msg)
|
|
{
|
|
LOGMESSAGE("Net client: Received message %s of size %lu from server", msg->ToString().c_str(), (unsigned long)msg->GetSerializedLength());
|
|
|
|
m_Client.HandleMessage(msg);
|
|
}
|
|
}
|
|
}, query);
|
|
}
|
|
}
|
|
|
|
bool CNetClientSession::SendMessage(const CNetMessage* message)
|
|
{
|
|
// ENSURE(m_Host && m_Server);
|
|
|
|
if (!m_OutgoingMessages.push(new std::vector{CNetHost::CreatePacket(message)}))
|
|
{
|
|
LOGERROR("NetClient: Failed to push message on the outgoing queue.");
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Returns m_LastReceivedTime.
// NOTE(review): as far as this file shows, nothing assigns it (the update in RunNetLoop is
// commented out), so it presumably keeps its initial value.
u32 CNetClientSession::GetLastReceivedTime() const
{
	return this->m_LastReceivedTime;
}
|
|
|
|
// Returns m_MeanRTT.
// NOTE(review): as far as this file shows, nothing assigns it (the update in RunNetLoop is
// commented out), so it presumably keeps its initial value.
u32 CNetClientSession::GetMeanRTT() const
{
	return this->m_MeanRTT;
}
|
|
|