diff --git a/Source/Core/Common/Network.cpp b/Source/Core/Common/Network.cpp index 2c548347a3..faa2a33c05 100644 --- a/Source/Core/Common/Network.cpp +++ b/Source/Core/Common/Network.cpp @@ -186,4 +186,22 @@ u16 ComputeNetworkChecksum(const void* data, u16 length, u32 initial_value) checksum = (checksum >> 16) + (checksum & 0xFFFF); return ~static_cast(checksum); } + +NetworkErrorState SaveNetworkErrorState() +{ + return { + errno, +#ifdef _WIN32 + WSAGetLastError(), +#endif + }; +} + +void RestoreNetworkErrorState(const NetworkErrorState& state) +{ + errno = state.error; +#ifdef _WIN32 + WSASetLastError(state.wsa_error); +#endif +} } // namespace Common diff --git a/Source/Core/Common/Network.h b/Source/Core/Common/Network.h index a1fcc9ea42..0b56e09d34 100644 --- a/Source/Core/Common/Network.h +++ b/Source/Core/Common/Network.h @@ -99,8 +99,18 @@ struct UDPHeader }; static_assert(sizeof(UDPHeader) == UDPHeader::SIZE); +struct NetworkErrorState +{ + int error; +#ifdef _WIN32 + int wsa_error; +#endif +}; + MACAddress GenerateMacAddress(MACConsumer type); std::string MacAddressToString(const MACAddress& mac); std::optional StringToMacAddress(std::string_view mac_string); u16 ComputeNetworkChecksum(const void* data, u16 length, u32 initial_value = 0); +NetworkErrorState SaveNetworkErrorState(); +void RestoreNetworkErrorState(const NetworkErrorState& state); } // namespace Common diff --git a/Source/Core/Core/NetworkCaptureLogger.cpp b/Source/Core/Core/NetworkCaptureLogger.cpp index 2e4242fd88..064b5ed25b 100644 --- a/Source/Core/Core/NetworkCaptureLogger.cpp +++ b/Source/Core/Core/NetworkCaptureLogger.cpp @@ -15,6 +15,7 @@ #include "Common/IOFile.h" #include "Common/Network.h" #include "Common/PcapFile.h" +#include "Common/ScopeGuard.h" #include "Core/Config/MainSettings.h" #include "Core/ConfigManager.h" @@ -90,24 +91,6 @@ void PCAPSSLCaptureLogger::OnNewSocket(s32 socket) m_write_sequence_number[socket] = 0; } -PCAPSSLCaptureLogger::ErrorState PCAPSSLCaptureLogger::SaveState() const -{ - return { - errno, -#ifdef _WIN32 - WSAGetLastError(), -#endif - }; -} - -void PCAPSSLCaptureLogger::RestoreState(const PCAPSSLCaptureLogger::ErrorState& state) const -{ - errno = state.error; -#ifdef _WIN32 - WSASetLastError(state.wsa_error); -#endif -} - void PCAPSSLCaptureLogger::LogSSLRead(const void* data, std::size_t length, s32 socket) { if (!Config::Get(Config::MAIN_NETWORK_SSL_DUMP_READ)) @@ -135,7 +118,8 @@ void PCAPSSLCaptureLogger::LogWrite(const void* data, std::size_t length, s32 so void PCAPSSLCaptureLogger::Log(LogType log_type, const void* data, std::size_t length, s32 socket, sockaddr* other) { - const auto state = SaveState(); + const auto state = Common::SaveNetworkErrorState(); + Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); }); sockaddr_in sock; sockaddr_in peer; sockaddr_in* from; @@ -144,16 +128,10 @@ void PCAPSSLCaptureLogger::Log(LogType log_type, const void* data, std::size_t l socklen_t peer_len = sizeof(sock); if (getsockname(socket, reinterpret_cast(&sock), &sock_len) != 0) - { - RestoreState(state); return; - } if (other == nullptr && getpeername(socket, reinterpret_cast(&peer), &peer_len) != 0) - { - RestoreState(state); return; - } if (log_type == LogType::Read) { @@ -168,7 +146,6 @@ void PCAPSSLCaptureLogger::Log(LogType log_type, const void* data, std::size_t l LogIPv4(log_type, reinterpret_cast(data), static_cast(length), socket, *from, *to); - RestoreState(state); } void PCAPSSLCaptureLogger::LogIPv4(LogType log_type, const u8* data, u16 length, s32 socket, diff --git a/Source/Core/Core/NetworkCaptureLogger.h b/Source/Core/Core/NetworkCaptureLogger.h index 032fde267c..d5067252d4 100644 --- a/Source/Core/Core/NetworkCaptureLogger.h +++ b/Source/Core/Core/NetworkCaptureLogger.h @@ -99,15 +99,6 @@ private: Read, Write, }; - struct ErrorState - { - int error; -#ifdef _WIN32 - int wsa_error; -#endif - }; - ErrorState SaveState() const; - void RestoreState(const ErrorState& state) const; void Log(LogType log_type, const void* data, std::size_t length, s32 socket, sockaddr* other); void LogIPv4(LogType log_type, const u8* data, u16 length, s32 socket, const sockaddr_in& from,