dolphin/Source/Core/Common/TraversalClient.cpp
Pierre Bourdon f4e34703c0
licensing: convert "public domain" to CC0 1.0
Public domain does not have an internationally agreed upon definition,
As such it's generally preferred to use an extremely liberal license,
which can explicitly list the rights granted by the copyright holder.
The CC0 license is the usual choice here.

This "relicensing" is done without hunting down copyright holders, since
it is presumed that their release of this work into the public domain
authorizes us to redistribute this code under any other license of our
choosing.
2021-07-05 04:43:55 +02:00

343 lines
8.5 KiB
C++

// SPDX-License-Identifier: CC0-1.0
#include "Common/TraversalClient.h"
#include <cstddef>
#include <cstring>
#include <string>
#include "Common/CommonTypes.h"
#include "Common/Logging/Log.h"
#include "Common/MsgHandler.h"
#include "Common/Random.h"
#include "Core/NetPlayProto.h"
TraversalClient::TraversalClient(ENetHost* netHost, const std::string& server, const u16 port)
: m_NetHost(netHost), m_Server(server), m_port(port)
{
netHost->intercept = TraversalClient::InterceptCallback;
Reset();
ReconnectToServer();
}
TraversalClient::~TraversalClient() = default;
TraversalHostId TraversalClient::GetHostID() const
{
return m_HostId;
}
TraversalClient::State TraversalClient::GetState() const
{
return m_State;
}
TraversalClient::FailureReason TraversalClient::GetFailureReason() const
{
return m_FailureReason;
}
void TraversalClient::ReconnectToServer()
{
if (enet_address_set_host(&m_ServerAddress, m_Server.c_str()))
{
OnFailure(FailureReason::BadHost);
return;
}
m_ServerAddress.port = m_port;
m_State = State::Connecting;
TraversalPacket hello = {};
hello.type = TraversalPacketType::HelloFromClient;
hello.helloFromClient.protoVersion = TraversalProtoVersion;
SendTraversalPacket(hello);
if (m_Client)
m_Client->OnTraversalStateChanged();
}
static ENetAddress MakeENetAddress(const TraversalInetAddress& address)
{
ENetAddress eaddr{};
if (address.isIPV6)
{
eaddr.port = 0; // no support yet :(
}
else
{
eaddr.host = address.address[0];
eaddr.port = ntohs(address.port);
}
return eaddr;
}
void TraversalClient::ConnectToClient(std::string_view host)
{
if (host.size() > sizeof(TraversalHostId))
{
PanicAlertFmt("Host too long");
return;
}
TraversalPacket packet = {};
packet.type = TraversalPacketType::ConnectPlease;
memcpy(packet.connectPlease.hostId.data(), host.data(), host.size());
m_ConnectRequestId = SendTraversalPacket(packet);
m_PendingConnect = true;
}
bool TraversalClient::TestPacket(u8* data, size_t size, ENetAddress* from)
{
if (from->host == m_ServerAddress.host && from->port == m_ServerAddress.port)
{
if (size < sizeof(TraversalPacket))
{
ERROR_LOG_FMT(NETPLAY, "Received too-short traversal packet.");
}
else
{
HandleServerPacket((TraversalPacket*)data);
return true;
}
}
return false;
}
//--Temporary until more of the old netplay branch is moved over
void TraversalClient::Update()
{
ENetEvent netEvent;
if (enet_host_service(m_NetHost, &netEvent, 4) > 0)
{
switch (netEvent.type)
{
case ENET_EVENT_TYPE_RECEIVE:
TestPacket(netEvent.packet->data, netEvent.packet->dataLength, &netEvent.peer->address);
enet_packet_destroy(netEvent.packet);
break;
default:
break;
}
}
HandleResends();
}
void TraversalClient::HandleServerPacket(TraversalPacket* packet)
{
u8 ok = 1;
switch (packet->type)
{
case TraversalPacketType::Ack:
if (!packet->ack.ok)
{
OnFailure(FailureReason::ServerForgotAboutUs);
break;
}
for (auto it = m_OutgoingTraversalPackets.begin(); it != m_OutgoingTraversalPackets.end(); ++it)
{
if (it->packet.requestId == packet->requestId)
{
m_OutgoingTraversalPackets.erase(it);
break;
}
}
break;
case TraversalPacketType::HelloFromServer:
if (!IsConnecting())
break;
if (!packet->helloFromServer.ok)
{
OnFailure(FailureReason::VersionTooOld);
break;
}
m_HostId = packet->helloFromServer.yourHostId;
m_State = State::Connected;
if (m_Client)
m_Client->OnTraversalStateChanged();
break;
case TraversalPacketType::PleaseSendPacket:
{
// security is overrated.
ENetAddress addr = MakeENetAddress(packet->pleaseSendPacket.address);
if (addr.port != 0)
{
char message[] = "Hello from Dolphin Netplay...";
ENetBuffer buf;
buf.data = message;
buf.dataLength = sizeof(message) - 1;
enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
}
else
{
// invalid IPV6
ok = 0;
}
break;
}
case TraversalPacketType::ConnectReady:
case TraversalPacketType::ConnectFailed:
{
if (!m_PendingConnect || packet->connectReady.requestId != m_ConnectRequestId)
break;
m_PendingConnect = false;
if (!m_Client)
break;
if (packet->type == TraversalPacketType::ConnectReady)
m_Client->OnConnectReady(MakeENetAddress(packet->connectReady.address));
else
m_Client->OnConnectFailed(packet->connectFailed.reason);
break;
}
default:
WARN_LOG_FMT(NETPLAY, "Received unknown packet with type {}", packet->type);
break;
}
if (packet->type != TraversalPacketType::Ack)
{
TraversalPacket ack = {};
ack.type = TraversalPacketType::Ack;
ack.requestId = packet->requestId;
ack.ack.ok = ok;
ENetBuffer buf;
buf.data = &ack;
buf.dataLength = sizeof(ack);
if (enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buf, 1) == -1)
OnFailure(FailureReason::SocketSendError);
}
}
void TraversalClient::OnFailure(FailureReason reason)
{
m_State = State::Failure;
m_FailureReason = reason;
if (m_Client)
m_Client->OnTraversalStateChanged();
}
void TraversalClient::ResendPacket(OutgoingTraversalPacketInfo* info)
{
info->sendTime = enet_time_get();
info->tries++;
ENetBuffer buf;
buf.data = &info->packet;
buf.dataLength = sizeof(info->packet);
if (enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buf, 1) == -1)
OnFailure(FailureReason::SocketSendError);
}
void TraversalClient::HandleResends()
{
const u32 now = enet_time_get();
for (auto& tpi : m_OutgoingTraversalPackets)
{
if (now - tpi.sendTime >= (u32)(300 * tpi.tries))
{
if (tpi.tries >= 5)
{
OnFailure(FailureReason::ResendTimeout);
m_OutgoingTraversalPackets.clear();
break;
}
else
{
ResendPacket(&tpi);
}
}
}
HandlePing();
}
void TraversalClient::HandlePing()
{
const u32 now = enet_time_get();
if (IsConnected() && now - m_PingTime >= 500)
{
TraversalPacket ping = {};
ping.type = TraversalPacketType::Ping;
ping.ping.hostId = m_HostId;
SendTraversalPacket(ping);
m_PingTime = now;
}
}
TraversalRequestId TraversalClient::SendTraversalPacket(const TraversalPacket& packet)
{
OutgoingTraversalPacketInfo info;
info.packet = packet;
info.packet.requestId = Common::Random::GenerateValue<TraversalRequestId>();
info.tries = 0;
m_OutgoingTraversalPackets.push_back(info);
ResendPacket(&m_OutgoingTraversalPackets.back());
return info.packet.requestId;
}
void TraversalClient::Reset()
{
m_PendingConnect = false;
m_Client = nullptr;
}
int ENET_CALLBACK TraversalClient::InterceptCallback(ENetHost* host, ENetEvent* event)
{
auto traversalClient = g_TraversalClient.get();
if (traversalClient->TestPacket(host->receivedData, host->receivedDataLength,
&host->receivedAddress) ||
(host->receivedDataLength == 1 && host->receivedData[0] == 0))
{
event->type = (ENetEventType)42;
return 1;
}
return 0;
}
std::unique_ptr<TraversalClient> g_TraversalClient;
std::unique_ptr<ENetHost> g_MainNetHost;
// The settings at the previous TraversalClient reset - notably, we
// need to know not just what port it's on, but whether it was
// explicitly requested.
static std::string g_OldServer;
static u16 g_OldServerPort;
static u16 g_OldListenPort;
bool EnsureTraversalClient(const std::string& server, u16 server_port, u16 listen_port)
{
if (!g_MainNetHost || !g_TraversalClient || server != g_OldServer ||
server_port != g_OldServerPort || listen_port != g_OldListenPort)
{
g_OldServer = server;
g_OldServerPort = server_port;
g_OldListenPort = listen_port;
ENetAddress addr = {ENET_HOST_ANY, listen_port};
ENetHost* host = enet_host_create(&addr, // address
50, // peerCount
NetPlay::CHANNEL_COUNT, // channelLimit
0, // incomingBandwidth
0); // outgoingBandwidth
if (!host)
{
g_MainNetHost.reset();
return false;
}
g_MainNetHost.reset(host);
g_TraversalClient.reset(new TraversalClient(g_MainNetHost.get(), server, server_port));
}
return true;
}
void ReleaseTraversalClient()
{
if (!g_TraversalClient)
return;
g_TraversalClient.reset();
g_MainNetHost.reset();
}